tensorrt: support multiple CUDA versions

Refactor derivation to pick the version that supports the current CUDA version.
Based on the implementation of the same concept in the cudnn derivation.
This commit is contained in:
Aidan Gauland 2022-07-02 09:53:27 +12:00
parent d70b4df686
commit c8fba8254a
No known key found for this signature in database
GPG Key ID: 16E68DD2D0E77C91
3 changed files with 85 additions and 21 deletions

View File

@ -0,0 +1,63 @@
final: prev: let
inherit (final) callPackage;
inherit (prev) cudatoolkit cudaVersion lib pkgs;
### TensorRT
buildTensorRTPackage = args:
callPackage ./generic.nix { } args;
toUnderscore = str: lib.replaceStrings ["."] ["_"] str;
majorMinorPatch = str: lib.concatStringsSep "." (lib.take 3 (lib.splitVersion str));
tensorRTPackages = with lib; let
# Check whether a file is supported for our cuda version
isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
# Return the first file that is supported. In practice there should only ever be one anyway.
supportedFile = files: findFirst isSupported null files;
# Supported versions with versions as keys and file as value
supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);
# Compute versioned attribute name to be used in this package set
computeName = version: "tensorrt_${toUnderscore version}";
# Add all supported builds as attributes
allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions;
# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
defaultBuild = { "tensorrt" = allBuilds.${computeName tensorRTDefaultVersion}; };
in allBuilds // defaultBuild;
tensorRTVersions = {
"8.4.0" = [
rec {
fileVersionCuda = "11.6";
fileVersionCudnn = "8.3";
fullVersion = "8.4.0.6";
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" ];
}
rec {
fileVersionCuda = "10.2";
fileVersionCudnn = "8.3";
fullVersion = "8.4.0.6";
sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
supportedCudaVersions = [ "10.2" ];
}
];
};
# Default attributes
tensorRTDefaultVersion = {
"10.2" = "8.4.0";
"11.0" = "8.4.0";
"11.1" = "8.4.0";
"11.2" = "8.4.0";
"11.3" = "8.4.0";
"11.4" = "8.4.0";
"11.5" = "8.4.0";
"11.6" = "8.4.0";
}.${cudaVersion};
in tensorRTPackages

View File

@ -8,20 +8,25 @@
, cudnn
}:
assert lib.assertMsg (lib.strings.versionAtLeast cudaVersion "11.0")
"This version of TensorRT requires at least CUDA 11.0 (current version is ${cudaVersion})";
assert lib.assertMsg (lib.strings.versionAtLeast cudnn.version "8.3")
"This version of TensorRT requires at least cuDNN 8.3 (current version is ${cudnn.version})";
{ fullVersion
, fileVersionCudnn
, tarball
, sha256
, supportedCudaVersions ? [ ]
}:
assert lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
"This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";
stdenv.mkDerivation rec {
pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt";
version = "8.4.0.6";
version = fullVersion;
src = requireFile rec {
name = "TensorRT-${version}.Linux.x86_64-gnu.cuda-11.6.cudnn8.3.tar.gz";
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
name = tarball;
inherit sha256;
message = ''
To use the TensorRT derivation, you must join the NVIDIA Developer Program
and download the ${version} Linux x86_64 TAR package from
To use the TensorRT derivation, you must join the NVIDIA Developer Program and
download the ${version} Linux x86_64 TAR package for CUDA ${cudaVersion} from
${meta.homepage}.
Once you have downloaded the file, add it to the store with the following
@ -70,6 +75,12 @@ stdenv.mkDerivation rec {
'';
meta = with lib; {
# Check that the cudatoolkit version satisfies our min/max constraints (both
# inclusive). We mark the package as broken if it fails to satisfies the
# official version constraints (as recorded in default.nix). In some cases
# you _may_ be able to smudge version constraints, just know that you're
# embarking into unknown and unsupported territory when doing so.
broken = !(elem cudaVersion supportedCudaVersions);
description = "TensorRT: a high-performance deep learning interface";
homepage = "https://developer.nvidia.com/tensorrt";
license = licenses.unfree;

View File

@ -43,16 +43,6 @@ let
};
in { inherit cutensor; };
tensorrtExtension = final: prev: let
### Tensorrt
inherit (final) cudaMajorMinorVersion cudaMajorVersion;
# TODO: Add derivations for TensorRT versions that support older CUDA versions.
tensorrt = final.callPackage ../development/libraries/science/math/tensorrt/8.nix { };
in { inherit tensorrt; };
extraPackagesExtension = final: prev: {
nccl = final.callPackage ../development/libraries/science/math/nccl { };
@ -74,10 +64,10 @@ let
(import ../development/compilers/cudatoolkit/redist/extension.nix)
(import ../development/compilers/cudatoolkit/redist/overrides.nix)
(import ../development/libraries/science/math/cudnn/extension.nix)
(import ../development/libraries/science/math/tensorrt/extension.nix)
(import ../test/cuda/cuda-samples/extension.nix)
(import ../test/cuda/cuda-library-samples/extension.nix)
cutensorExtension
] ++ (lib.optional (lib.strings.versionAtLeast cudaVersion "11.0") tensorrtExtension));
# We only package the current version of TensorRT, which requires CUDA 11.
]);
in (scope.overrideScope' composedExtension)