tensorrt: support multiple CUDA versions
Refactor derivation to pick the version that supports the current CUDA version. Based on the implementation of the same concept in the cudnn derivation.
This commit is contained in:
parent
d70b4df686
commit
c8fba8254a
@ -0,0 +1,63 @@
|
||||
final: prev: let
|
||||
|
||||
inherit (final) callPackage;
|
||||
inherit (prev) cudatoolkit cudaVersion lib pkgs;
|
||||
|
||||
### TensorRT
|
||||
|
||||
buildTensorRTPackage = args:
|
||||
callPackage ./generic.nix { } args;
|
||||
|
||||
toUnderscore = str: lib.replaceStrings ["."] ["_"] str;
|
||||
|
||||
majorMinorPatch = str: lib.concatStringsSep "." (lib.take 3 (lib.splitVersion str));
|
||||
|
||||
tensorRTPackages = with lib; let
|
||||
# Check whether a file is supported for our cuda version
|
||||
isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
|
||||
# Return the first file that is supported. In practice there should only ever be one anyway.
|
||||
supportedFile = files: findFirst isSupported null files;
|
||||
# Supported versions with versions as keys and file as value
|
||||
supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);
|
||||
# Compute versioned attribute name to be used in this package set
|
||||
computeName = version: "tensorrt_${toUnderscore version}";
|
||||
# Add all supported builds as attributes
|
||||
allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions;
|
||||
# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
|
||||
defaultBuild = { "tensorrt" = allBuilds.${computeName tensorRTDefaultVersion}; };
|
||||
in allBuilds // defaultBuild;
|
||||
|
||||
tensorRTVersions = {
|
||||
"8.4.0" = [
|
||||
rec {
|
||||
fileVersionCuda = "11.6";
|
||||
fileVersionCudnn = "8.3";
|
||||
fullVersion = "8.4.0.6";
|
||||
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
|
||||
tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
|
||||
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" ];
|
||||
}
|
||||
rec {
|
||||
fileVersionCuda = "10.2";
|
||||
fileVersionCudnn = "8.3";
|
||||
fullVersion = "8.4.0.6";
|
||||
sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
|
||||
tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
|
||||
supportedCudaVersions = [ "10.2" ];
|
||||
}
|
||||
];
|
||||
};
|
||||
|
||||
# Default attributes
|
||||
tensorRTDefaultVersion = {
|
||||
"10.2" = "8.4.0";
|
||||
"11.0" = "8.4.0";
|
||||
"11.1" = "8.4.0";
|
||||
"11.2" = "8.4.0";
|
||||
"11.3" = "8.4.0";
|
||||
"11.4" = "8.4.0";
|
||||
"11.5" = "8.4.0";
|
||||
"11.6" = "8.4.0";
|
||||
}.${cudaVersion};
|
||||
|
||||
in tensorRTPackages
|
@ -8,20 +8,25 @@
|
||||
, cudnn
|
||||
}:
|
||||
|
||||
assert lib.assertMsg (lib.strings.versionAtLeast cudaVersion "11.0")
|
||||
"This version of TensorRT requires at least CUDA 11.0 (current version is ${cudaVersion})";
|
||||
assert lib.assertMsg (lib.strings.versionAtLeast cudnn.version "8.3")
|
||||
"This version of TensorRT requires at least cuDNN 8.3 (current version is ${cudnn.version})";
|
||||
{ fullVersion
|
||||
, fileVersionCudnn
|
||||
, tarball
|
||||
, sha256
|
||||
, supportedCudaVersions ? [ ]
|
||||
}:
|
||||
|
||||
assert lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
|
||||
"This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";
|
||||
|
||||
stdenv.mkDerivation rec {
|
||||
pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt";
|
||||
version = "8.4.0.6";
|
||||
version = fullVersion;
|
||||
src = requireFile rec {
|
||||
name = "TensorRT-${version}.Linux.x86_64-gnu.cuda-11.6.cudnn8.3.tar.gz";
|
||||
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
|
||||
name = tarball;
|
||||
inherit sha256;
|
||||
message = ''
|
||||
To use the TensorRT derivation, you must join the NVIDIA Developer Program
|
||||
and download the ${version} Linux x86_64 TAR package from
|
||||
To use the TensorRT derivation, you must join the NVIDIA Developer Program and
|
||||
download the ${version} Linux x86_64 TAR package for CUDA ${cudaVersion} from
|
||||
${meta.homepage}.
|
||||
|
||||
Once you have downloaded the file, add it to the store with the following
|
||||
@ -70,6 +75,12 @@ stdenv.mkDerivation rec {
|
||||
'';
|
||||
|
||||
meta = with lib; {
|
||||
# Check that the cudatoolkit version satisfies our min/max constraints (both
|
||||
# inclusive). We mark the package as broken if it fails to satisfies the
|
||||
# official version constraints (as recorded in default.nix). In some cases
|
||||
# you _may_ be able to smudge version constraints, just know that you're
|
||||
# embarking into unknown and unsupported territory when doing so.
|
||||
broken = !(elem cudaVersion supportedCudaVersions);
|
||||
description = "TensorRT: a high-performance deep learning interface";
|
||||
homepage = "https://developer.nvidia.com/tensorrt";
|
||||
license = licenses.unfree;
|
@ -43,16 +43,6 @@ let
|
||||
};
|
||||
in { inherit cutensor; };
|
||||
|
||||
tensorrtExtension = final: prev: let
|
||||
### Tensorrt
|
||||
|
||||
inherit (final) cudaMajorMinorVersion cudaMajorVersion;
|
||||
|
||||
# TODO: Add derivations for TensorRT versions that support older CUDA versions.
|
||||
|
||||
tensorrt = final.callPackage ../development/libraries/science/math/tensorrt/8.nix { };
|
||||
in { inherit tensorrt; };
|
||||
|
||||
extraPackagesExtension = final: prev: {
|
||||
|
||||
nccl = final.callPackage ../development/libraries/science/math/nccl { };
|
||||
@ -74,10 +64,10 @@ let
|
||||
(import ../development/compilers/cudatoolkit/redist/extension.nix)
|
||||
(import ../development/compilers/cudatoolkit/redist/overrides.nix)
|
||||
(import ../development/libraries/science/math/cudnn/extension.nix)
|
||||
(import ../development/libraries/science/math/tensorrt/extension.nix)
|
||||
(import ../test/cuda/cuda-samples/extension.nix)
|
||||
(import ../test/cuda/cuda-library-samples/extension.nix)
|
||||
cutensorExtension
|
||||
] ++ (lib.optional (lib.strings.versionAtLeast cudaVersion "11.0") tensorrtExtension));
|
||||
# We only package the current version of TensorRT, which requires CUDA 11.
|
||||
]);
|
||||
|
||||
in (scope.overrideScope' composedExtension)
|
||||
|
Loading…
Reference in New Issue
Block a user