diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index c7e84e4c11a2..34b9c429caf6 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -18,11 +18,12 @@ , autoPatchelfHook , buildPythonPackage , config -, cudnn ? cudaPackages.cudnn +, fetchPypi , fetchurl , flatbuffers -, isPy39 +, jaxlib , lib +, ml-dtypes , python , scipy , stdenv @@ -35,46 +36,57 @@ let inherit (cudaPackages) cudatoolkit cudnn; in -assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; -assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; +assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; let - version = "0.4.4"; + version = "0.4.12"; + + inherit (python) pythonVersion; + + # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the + # official instructions recommend installing CPU-only versions via PyPI. + cpuSrcs = + let + getSrcFromPypi = { platform, hash }: fetchPypi { + inherit version platform hash; + pname = "jaxlib"; + format = "wheel"; + # See the `disabled` attr comment below. + dist = "cp310"; + python = "cp310"; + abi = "cp310"; + }; + in + { + "x86_64-linux" = getSrcFromPypi { + platform = "manylinux2014_x86_64"; + hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; + }; + "aarch64-darwin" = getSrcFromPypi { + platform = "macosx_11_0_arm64"; + hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; + }; + "x86_64-darwin" = getSrcFromPypi { + platform = "macosx_10_14_x86_64"; + hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; + }; + }; - pythonVersion = python.pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See - # https://github.com/google/jax/issues/12879 as to why this specific URL is - # the correct index. - cpuSrcs = { - "x86_64-linux" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; - }; - "aarch64-darwin" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; - hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; - }; - "x86_64-darwin" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; - hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; - }; + # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. + gpuSrc = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; }; - gpuSrc = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; - }; in -buildPythonPackage rec { +buildPythonPackage { pname = "jaxlib"; inherit version; format = "wheel"; - # At the time of writing (2022-10-19), there are releases for <=3.10. - # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs - # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. @@ -87,9 +99,10 @@ buildPythonPackage rec { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; + nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] + ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; # Dynamic link dependencies - buildInputs = [ stdenv.cc.cc ]; + buildInputs = [ stdenv.cc.cc.lib ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or @@ -113,7 +126,12 @@ buildPythonPackage rec { done ''; - propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + propagatedBuildInputs = [ + absl-py + flatbuffers + ml-dtypes + scipy + ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for @@ -123,7 +141,7 @@ buildPythonPackage rec { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; - pythonImportsCheck = [ "jaxlib" ]; + inherit (jaxlib) pythonImportsCheck; meta = with lib; { description = "XLA library for JAX"; diff --git a/pkgs/development/python-modules/jaxlib/prefetch.sh b/pkgs/development/python-modules/jaxlib/prefetch.sh index 31db6530639f..3362e2d0b781 100755 --- a/pkgs/development/python-modules/jaxlib/prefetch.sh +++ b/pkgs/development/python-modules/jaxlib/prefetch.sh @@ -1,7 +1,15 @@ -version="$1" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" +#!/usr/bin/env bash + +prefetch () { + expr="(import { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" + url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) + echo "$url" + sha256=$(nix-prefetch-url "$url") + nix hash to-sri --type sha256 "$sha256" + echo +} + +prefetch "x86_64-linux" "false" +prefetch "aarch64-darwin" "false" +prefetch "x86_64-darwin" "false" +prefetch "x86_64-linux" "true"