From 09358a09240513ad8cdf940dbd68d566b6ad58b2 Mon Sep 17 00:00:00 2001 From: Break Yang Date: Thu, 25 Jan 2024 18:47:24 -0800 Subject: [PATCH 1/2] python3Packages.jaxlib-bin: add cuda 11.8 wheels --- .../development/python-modules/jaxlib/bin.nix | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index f6f8f5e2b1b6..b1a9e8a6dfc5 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -33,7 +33,7 @@ }: let - inherit (cudaPackagesGoogle) cudatoolkit cudnn; + inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion; version = "0.4.23"; @@ -118,25 +118,41 @@ let }; }; - # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. + # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrcs = { - "3.9" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; - hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI="; + "cuda12.2" = { + "3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; + hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI="; + }; + "3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg="; + }; + "3.11" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow="; + }; + "3.12" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; + hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo="; + }; }; - "3.10" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg="; - }; - "3.11" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; - hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow="; - }; - "3.12" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; - hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo="; + "cuda11.8" = { + "3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; + hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60="; + }; + "3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0="; + }; + "3.11" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4="; + }; }; }; @@ -146,7 +162,10 @@ buildPythonPackage { inherit version; format = "wheel"; - disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12"); + # Note that the prebuilt jaxlib binary requires specific version of CUDA to + # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11 + # jaxlib binaries only works with CUDA 11.8. + disabled = !((cudaVersion == "11.8" && (pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11")) || (cudaVersion == "12.2" && (pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12"))); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. src = @@ -154,7 +173,7 @@ buildPythonPackage { ( cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}" or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}") - ) else gpuSrcs."${pythonVersion}"; + ) else gpuSrcs."cuda${cudaVersion}"."${pythonVersion}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. From 56cb3db6bab7ad66fdb6354c7fdb2ef298e99f8e Mon Sep 17 00:00:00 2001 From: Break Yang Date: Tue, 30 Jan 2024 14:32:41 -0800 Subject: [PATCH 2/2] python3Packages.jaxlib-bin: use correct cuda releases --- .../development/python-modules/jaxlib/bin.nix | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index b1a9e8a6dfc5..b7e20d6f55e7 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -118,41 +118,43 @@ let }; }; + # Note that the prebuilt jaxlib binary requires specific version of CUDA to + # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11 + # jaxlib binaries only works with CUDA 11.8. This is why we need to find a + # binary that matches the provided cudaVersion. + gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}"; + # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrcs = { - "cuda12.2" = { - "3.9" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; - hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI="; - }; - "3.10" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg="; - }; - "3.11" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; - hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow="; - }; - "3.12" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; - hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo="; - }; + "cuda12.2-3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; + hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI="; }; - "cuda11.8" = { - "3.9" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; - hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60="; - }; - "3.10" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0="; - }; - "3.11" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; - hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4="; - }; + "cuda12.2-3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg="; + }; + "cuda12.2-3.11" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow="; + }; + "cuda12.2-3.12" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; + hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo="; + }; + "cuda11.8-3.9" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; + hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60="; + }; + "cuda11.8-3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0="; + }; + "cuda11.8-3.11" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; + hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4="; }; }; @@ -162,10 +164,7 @@ buildPythonPackage { inherit version; format = "wheel"; - # Note that the prebuilt jaxlib binary requires specific version of CUDA to - # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11 - # jaxlib binaries only works with CUDA 11.8. - disabled = !((cudaVersion == "11.8" && (pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11")) || (cudaVersion == "12.2" && (pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12"))); + disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. src = @@ -173,7 +172,7 @@ buildPythonPackage { ( cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}" or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}") - ) else gpuSrcs."cuda${cudaVersion}"."${pythonVersion}"; + ) else gpuSrcs."${gpuSrcVersionString}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. @@ -231,6 +230,7 @@ buildPythonPackage { broken = !(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1") || !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2") - || !(cudaSupport -> stdenv.isLinux); + || !(cudaSupport -> stdenv.isLinux) + || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")); }; }