python3Packages.jaxlib: fix darwin build

This commit is contained in:
Uri Baghin 2022-08-01 11:24:40 +10:00
parent 83bcf0cef1
commit 633d5cb10f
2 changed files with 48 additions and 13 deletions

View File

@ -8,9 +8,11 @@
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, pybind11
, setuptools
@ -55,8 +57,11 @@ let
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ ndl ];
platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"];
hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds
platforms = platforms.unix;
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
# however even with that fix applied, it doesn't work for everyone:
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
broken = stdenv.isAarch64;
};
cudatoolkit_joined = symlinkJoin {
@ -117,6 +122,8 @@ let
] ++ lib.optionals cudaSupport [
cudatoolkit
cudnn
] ++ lib.optionals stdenv.isDarwin [
IOKit
];
postPatch = ''
@ -201,9 +208,7 @@ let
# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
bazelFetchFlags = bazel-build.bazelBuildFlags;
bazelBuildFlags = [
bazelFlags = [
"-c opt"
] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
@ -211,6 +216,11 @@ let
"--config=cuda"
] ++ lib.optional mklSupport [
"--config=mkl_open_source_only"
] ++ lib.optionals stdenv.cc.isClang [
# bazel depends on the compiler frontend automatically selecting these flags based on file
# extension but our clang doesn't.
# https://github.com/NixOS/nixpkgs/issues/150655
"--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
];
fetchAttrs = {
@ -218,7 +228,7 @@ let
if cudaSupport then
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
else
"sha256-6acSbBNcUBw177HMVOmpV7pUfP1aFSe5cP6/zWFdGFo=";
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
};
buildAttrs = {
@ -226,8 +236,8 @@ let
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
# in the same python program due to duplicate protobuf DBs.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
# 4) Patch tensorflow sources to work with later versions of protobuf. See
# https://github.com/google/jax/issues/9534. Note that this should be
@ -236,13 +246,25 @@ let
for src in ./jaxlib/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
--replace "status.message()" "std::string{status.message()}"
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
--replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'';
'' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook
# https://github.com/NixOS/nixpkgs/pull/41914
export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
'' + (if stdenv.cc.isGNU then ''
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else if stdenv.cc.isClang then ''
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
@ -251,13 +273,21 @@ let
inherit meta;
};
platformTag =
if stdenv.targetPlatform.isLinux then
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "x86_64-darwin" then
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "aarch64-darwin" then
"macosx_11_0_${stdenv.targetPlatform.linuxArch}"
else throw "Unsupported target platform: ${stdenv.targetPlatform}";
in
buildPythonPackage {
inherit meta pname version;
format = "wheel";
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl";
# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for

View File

@ -4558,12 +4558,17 @@ in {
cudaPackages = pkgs.cudaPackages_11_6;
};
jaxlib-build = callPackage ../development/python-modules/jaxlib {
jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
inherit (pkgs.darwin) cctools;
buildBazelPackage = pkgs.buildBazelPackage.override {
stdenv = if stdenv.isDarwin then pkgs.darwin.apple_sdk_11_0.stdenv else stdenv;
};
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
cudaSupport = pkgs.config.cudaSupport or false;
# At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
};
jaxlib = self.jaxlib-build;