From 2a42503192f2fcf77009915424a60a016a8364ff Mon Sep 17 00:00:00 2001 From: Connor Baker Date: Fri, 27 Oct 2023 02:51:42 +0000 Subject: [PATCH] python3Packages.torch: patch `cpp_extension.py` for Jetson support --- pkgs/development/python-modules/torch/default.nix | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 59396d421ed9..993f49c41f9a 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -48,7 +48,10 @@ let inherit (lib) attrsets lists strings trivial; - inherit (cudaPackages) cudaFlags cudnn nccl; + inherit (cudaPackages) cudaFlags cudnn; + + # Some packages are not available on all platforms + nccl = cudaPackages.nccl or null; setBool = v: if v then "1" else "0"; @@ -178,6 +181,13 @@ in buildPythonPackage rec { 'message(FATAL_ERROR "Found NCCL header version and library version' \ 'message(WARNING "Found NCCL header version and library version' '' + # TODO(@connorbaker): Remove this patch after 2.1.0 lands. + + lib.optionalString cudaSupport '' + substituteInPlace torch/utils/cpp_extension.py \ + --replace \ + "'8.6', '8.9'" \ + "'8.6', '8.7', '8.9'" + '' # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc' # This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header. + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") '' @@ -253,6 +263,7 @@ in buildPythonPackage rec { PYTORCH_BUILD_VERSION = version; PYTORCH_BUILD_NUMBER = 0; + USE_NCCL = setBool (nccl != null); USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL USE_STATIC_NCCL = setBool useSystemNccl; @@ -316,6 +327,8 @@ in buildPythonPackage rec { libcusolver.lib libcusparse.dev libcusparse.lib + ] ++ lists.optionals (nccl != null) [ + # Some platforms do not support NCCL (i.e., Jetson) nccl.dev # Provides nccl.h AND a static copy of NCCL! ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ cuda_nvprof.dev #