diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9ca7f71..673f0d8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,6 +24,9 @@ on: nccl-version: required: true type: string + cudnn-version: + required: false + type: string cuda-samples-version: required: true type: string @@ -75,6 +78,7 @@ jobs: CUDA_VERSION_MINOR=${{ inputs.cuda-version-minor }} CUDA_VERSION_MAJOR=${{ inputs.cuda-version-major }} TARGET_NCCL_VERSION=${{ inputs.nccl-version }} + CUDNN_VERSION=${{ inputs.cudnn-version }} CUDA_SAMPLES_VERSION=${{ inputs.cuda-samples-version }} HPCX_VERSION=${{ inputs.hpcx-version }} HPCX_NCCL_VERSION=${{ inputs.hpcx-nccl-version }} diff --git a/.github/workflows/ubuntu-20.yml b/.github/workflows/ubuntu-20.yml index 847061a..de24521 100644 --- a/.github/workflows/ubuntu-20.yml +++ b/.github/workflows/ubuntu-20.yml @@ -77,6 +77,7 @@ jobs: cuda-version-minor: "12.2.0" cuda-version-major: "12.2" nccl-version: 2.18.3-1 + cudnn-version: "8.9.3.*-1+cuda12.?" cuda-samples-version: "12.2" hpcx-version: "2.15" hpcx-nccl-version: "2.17" diff --git a/Dockerfile.ubuntu20 b/Dockerfile.ubuntu20 index ad46790..e463ab6 100644 --- a/Dockerfile.ubuntu20 +++ b/Dockerfile.ubuntu20 @@ -1,6 +1,26 @@ ARG CUDA_VERSION_MINOR=11.7.1 ARG BASE_IMAGE=nvidia/cuda:${CUDA_VERSION_MINOR}-cudnn8-devel-ubuntu20.04 -FROM ${BASE_IMAGE} + +# Patch in cuDNN manually by specfiying a version here +# only if a base image with cuDNN is not available +ARG CUDNN_VERSION="" + +# Effectively a ternary _CUDNN_IMAGE = CUDNN_VERSION ? "add-cudnn" : BASE_IMAGE +ARG _CUDNN_IMAGE=${CUDNN_VERSION:+add-cudnn} +ARG _CUDNN_IMAGE=${_CUDNN_IMAGE:-$BASE_IMAGE} + +FROM ${BASE_IMAGE} as add-cudnn +ARG CUDNN_VERSION + +RUN apt-get -qq update && \ + apt-get -qq install -y --allow-change-held-packages --no-install-recommends \ + "libcudnn8=${CUDNN_VERSION}" \ + "libcudnn8-dev=${CUDNN_VERSION}" && \ + apt-mark hold libcudnn8 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +FROM ${_CUDNN_IMAGE} ARG CUDA_VERSION_MAJOR=11.7 ARG TARGET_NCCL_VERSION=2.14.3-1