python3Packages.jaxlib: fix darwin build
This commit is contained in:
parent
83bcf0cef1
commit
633d5cb10f
@ -8,9 +8,11 @@
|
||||
, binutils
|
||||
, buildBazelPackage
|
||||
, buildPythonPackage
|
||||
, cctools
|
||||
, cython
|
||||
, fetchFromGitHub
|
||||
, git
|
||||
, IOKit
|
||||
, jsoncpp
|
||||
, pybind11
|
||||
, setuptools
|
||||
@ -55,8 +57,11 @@ let
|
||||
homepage = "https://github.com/google/jax";
|
||||
license = licenses.asl20;
|
||||
maintainers = with maintainers; [ ndl ];
|
||||
platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"];
|
||||
hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds
|
||||
platforms = platforms.unix;
|
||||
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
|
||||
# however even with that fix applied, it doesn't work for everyone:
|
||||
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
|
||||
broken = stdenv.isAarch64;
|
||||
};
|
||||
|
||||
cudatoolkit_joined = symlinkJoin {
|
||||
@ -117,6 +122,8 @@ let
|
||||
] ++ lib.optionals cudaSupport [
|
||||
cudatoolkit
|
||||
cudnn
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
IOKit
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
@ -201,9 +208,7 @@ let
|
||||
|
||||
# Make sure Bazel knows about our configuration flags during fetching so that the
|
||||
# relevant dependencies can be downloaded.
|
||||
bazelFetchFlags = bazel-build.bazelBuildFlags;
|
||||
|
||||
bazelBuildFlags = [
|
||||
bazelFlags = [
|
||||
"-c opt"
|
||||
] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
|
||||
"--config=avx_posix"
|
||||
@ -211,6 +216,11 @@ let
|
||||
"--config=cuda"
|
||||
] ++ lib.optional mklSupport [
|
||||
"--config=mkl_open_source_only"
|
||||
] ++ lib.optionals stdenv.cc.isClang [
|
||||
# bazel depends on the compiler frontend automatically selecting these flags based on file
|
||||
# extension but our clang doesn't.
|
||||
# https://github.com/NixOS/nixpkgs/issues/150655
|
||||
"--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
|
||||
];
|
||||
|
||||
fetchAttrs = {
|
||||
@ -218,7 +228,7 @@ let
|
||||
if cudaSupport then
|
||||
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
|
||||
else
|
||||
"sha256-6acSbBNcUBw177HMVOmpV7pUfP1aFSe5cP6/zWFdGFo=";
|
||||
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
@ -226,8 +236,8 @@ let
|
||||
|
||||
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
|
||||
# 1) Fix pybind11 include paths.
|
||||
# 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
|
||||
# in the same python program due to duplicate protobuf DBs.
|
||||
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
||||
# 3) Patch python path in the compiler driver.
|
||||
# 4) Patch tensorflow sources to work with later versions of protobuf. See
|
||||
# https://github.com/google/jax/issues/9534. Note that this should be
|
||||
@ -236,13 +246,25 @@ let
|
||||
for src in ./jaxlib/*.{cc,h}; do
|
||||
sed -i 's@include/pybind11@pybind11@g' $src
|
||||
done
|
||||
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
|
||||
--replace "status.message()" "std::string{status.message()}"
|
||||
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
|
||||
--replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
|
||||
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
|
||||
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
'';
|
||||
'' + lib.optionalString stdenv.isDarwin ''
|
||||
# Framework search paths aren't added by bintools hook
|
||||
# https://github.com/NixOS/nixpkgs/pull/41914
|
||||
export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
|
||||
'' + (if stdenv.cc.isGNU then ''
|
||||
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
'' else if stdenv.cc.isClang then ''
|
||||
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
|
||||
|
||||
installPhase = ''
|
||||
./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
|
||||
@ -251,13 +273,21 @@ let
|
||||
|
||||
inherit meta;
|
||||
};
|
||||
platformTag =
|
||||
if stdenv.targetPlatform.isLinux then
|
||||
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
|
||||
else if stdenv.system == "x86_64-darwin" then
|
||||
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
|
||||
else if stdenv.system == "aarch64-darwin" then
|
||||
"macosx_11_0_${stdenv.targetPlatform.linuxArch}"
|
||||
else throw "Unsupported target platform: ${stdenv.targetPlatform}";
|
||||
|
||||
in
|
||||
buildPythonPackage {
|
||||
inherit meta pname version;
|
||||
format = "wheel";
|
||||
|
||||
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
|
||||
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl";
|
||||
|
||||
# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
|
||||
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||
|
@ -4558,12 +4558,17 @@ in {
|
||||
cudaPackages = pkgs.cudaPackages_11_6;
|
||||
};
|
||||
|
||||
jaxlib-build = callPackage ../development/python-modules/jaxlib {
|
||||
jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
|
||||
inherit (pkgs.darwin) cctools;
|
||||
buildBazelPackage = pkgs.buildBazelPackage.override {
|
||||
stdenv = if stdenv.isDarwin then pkgs.darwin.apple_sdk_11_0.stdenv else stdenv;
|
||||
};
|
||||
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
|
||||
cudaSupport = pkgs.config.cudaSupport or false;
|
||||
# At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
|
||||
# pin to `cudaPackages_11_6` instead.
|
||||
cudaPackages = pkgs.cudaPackages_11_6;
|
||||
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
||||
};
|
||||
|
||||
jaxlib = self.jaxlib-build;
|
||||
|
Loading…
Reference in New Issue
Block a user