Merge pull request #325222 from SomeoneSerge/fix/gpu-access/torch-bin

python3Packages.torch-bin: gpuChecks -> tests.tester-<name>.gpuCheck
This commit is contained in:
Someone 2024-07-07 19:25:27 +00:00 committed by GitHub
commit 4fbd4333e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 43 deletions

View File

@ -121,7 +121,10 @@ buildPythonPackage {
pythonImportsCheck = [ "torch" ];
passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
passthru.tests = callPackage ./tests.nix {
torchWithCuda = torch-bin;
torchWithRocm = torch-bin;
};
meta = {
description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";

View File

@ -1,40 +0,0 @@
{
lib,
torchWithCuda,
torchWithRocm,
callPackage,
}:
let
accelAvailable =
{
feature,
versionAttr,
torch,
cudaPackages,
}:
cudaPackages.writeGpuTestPython
{
inherit feature;
libraries = [ torch ];
name = "${feature}Available";
}
''
import torch
message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
assert torch.cuda.is_available() and torch.version.${versionAttr}, message
print(message)
'';
in
{
tester-cudaAvailable = callPackage accelAvailable {
feature = "cuda";
versionAttr = "cuda";
torch = torchWithCuda;
};
tester-rocmAvailable = callPackage accelAvailable {
feature = "rocm";
versionAttr = "hip";
torch = torchWithRocm;
};
}

View File

@ -0,0 +1,19 @@
{
cudaPackages,
feature,
torch,
versionAttr,
}:
cudaPackages.writeGpuTestPython
{
inherit feature;
libraries = [ torch ];
name = "${feature}Available";
}
''
import torch
message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
assert torch.cuda.is_available() and torch.version.${versionAttr}, message
print(message)
''

View File

@ -1,3 +1,21 @@
{ callPackage }:
{
callPackage,
torchWithCuda,
torchWithRocm,
}:
callPackage ./gpu-checks.nix { }
{
# To perform the runtime check use either
# `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
# `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
feature = "cuda";
versionAttr = "cuda";
torch = torchWithCuda;
};
tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
feature = "rocm";
versionAttr = "hip";
torch = torchWithRocm;
};
}