Merge pull request #328661 from SomeoneSerge/feat/torch-compile-test

python312Packages.torch: test torch.compile
This commit is contained in:
Someone 2024-07-20 19:02:47 +03:00 committed by GitHub
commit 10f0788379
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 127 additions and 41 deletions

View File

@ -1,29 +0,0 @@
{
lib,
writers,
runCommand,
}:
{
feature ? "cuda",
name ? feature,
libraries ? [ ],
}:
content:
let
tester = writers.writePython3Bin "tester-${name}" { inherit libraries; } content;
tester' = tester.overrideAttrs (oldAttrs: {
passthru.gpuCheck =
runCommand "test-${name}"
{
nativeBuildInputs = [ tester' ];
requiredSystemFeatures = [ feature ];
}
''
set -e
${tester.meta.mainProgram or (lib.getName tester')}
touch $out
'';
});
in
tester'

View File

@ -0,0 +1,66 @@
{
lib,
runCommand,
python3Packages,
makeWrapper,
}:
{
feature ? "cuda",
name ? if feature == null then "cpu" else feature,
libraries ? [ ], # [PythonPackage] | (PackageSet -> [PythonPackage])
...
}@args:
let
inherit (builtins) isFunction all;
librariesFun = if isFunction libraries then libraries else (_: libraries);
in
assert lib.assertMsg (
isFunction libraries || all (python3Packages.hasPythonModule) libraries
) "writeGpuTestPython was passed `libraries` from the wrong python release";
content:
let
interpreter = python3Packages.python.withPackages librariesFun;
tester =
runCommand "tester-${name}"
(
lib.removeAttrs args [
"libraries"
"name"
]
// {
inherit content;
nativeBuildInputs = args.nativeBuildInputs or [ ] ++ [ makeWrapper ];
passAsFile = args.passAsFile or [ ] ++ [ "content" ];
}
)
''
mkdir -p "$out"/bin
cat << EOF >"$out/bin/$name"
#!${lib.getExe interpreter}
EOF
cat "$contentPath" >>"$out/bin/$name"
chmod +x "$out/bin/$name"
if [[ -n "''${makeWrapperArgs+''${makeWrapperArgs[@]}}" ]] ; then
wrapProgram "$out/bin/$name" ''${makeWrapperArgs[@]}
fi
'';
tester' = tester.overrideAttrs (oldAttrs: {
passthru.gpuCheck =
runCommand "test-${name}"
{
nativeBuildInputs = [ tester' ];
requiredSystemFeatures = lib.optionals (feature != null) [ feature ];
}
''
set -e
${tester.meta.mainProgram or (lib.getName tester')}
touch $out
'';
});
in
tester'

View File

@ -1,14 +1,15 @@
{
cudaPackages,
feature,
torch,
libraries,
versionAttr,
pythonPackages,
}:
cudaPackages.writeGpuTestPython
(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; })
{
inherit feature;
libraries = [ torch ];
inherit libraries;
name = "${feature}Available";
}
''

View File

@ -0,0 +1,38 @@
{
cudaPackages,
feature ? null,
lib,
libraries,
name ? if feature == null then "torch-compile-cpu" else "torch-compile-${feature}",
pythonPackages,
stdenv,
}:
let
deviceStr = if feature == null then "" else '', device="cuda"'';
in
(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; })
{
inherit name feature libraries;
makeWrapperArgs = [
"--suffix"
"PATH"
":"
"${lib.getBin stdenv.cc}/bin"
];
}
''
import torch
@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(
opt_foo2(
torch.randn(10, 10${deviceStr}),
torch.randn(10, 10${deviceStr})))
''

View File

@ -1,21 +1,31 @@
{
callPackage,
torchWithCuda,
torchWithRocm,
}:
{ callPackage }:
{
rec {
# To perform the runtime check use either
# `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
# `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
feature = "cuda";
versionAttr = "cuda";
torch = torchWithCuda;
libraries = ps: [ ps.torchWithCuda ];
};
tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
feature = "rocm";
versionAttr = "hip";
torch = torchWithRocm;
libraries = ps: [ ps.torchWithRocm ];
};
compileCpu = tester-compileCpu.gpuCheck;
tester-compileCpu = callPackage ./mk-torch-compile-check.nix {
feature = null;
libraries = ps: [ ps.torch ];
};
tester-compileCuda = callPackage ./mk-torch-compile-check.nix {
feature = "cuda";
libraries = ps: [ ps.torchWithCuda ];
};
tester-compileRocm = callPackage ./mk-torch-compile-check.nix {
feature = "rocm";
libraries = ps: [ ps.torchWithRocm ];
};
}

View File

@ -81,7 +81,7 @@ let
nccl = final.callPackage ../development/cuda-modules/nccl { };
nccl-tests = final.callPackage ../development/cuda-modules/nccl-tests { };
writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-python-test.nix { };
writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-test-python.nix { };
});
mkVersionedPackageName =