Merge pull request #325222 from SomeoneSerge/fix/gpu-access/torch-bin
python3Packages.torch-bin: gpuChecks -> tests.tester-<name>.gpuCheck
This commit is contained in:
commit
4fbd4333e3
@ -121,7 +121,10 @@ buildPythonPackage {
|
||||
|
||||
pythonImportsCheck = [ "torch" ];
|
||||
|
||||
passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
|
||||
passthru.tests = callPackage ./tests.nix {
|
||||
torchWithCuda = torch-bin;
|
||||
torchWithRocm = torch-bin;
|
||||
};
|
||||
|
||||
meta = {
|
||||
description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
|
||||
|
@ -1,40 +0,0 @@
|
||||
# gpu-checks.nix (version deleted by this PR) — builds runtime GPU smoke tests
# for torch. Each tester runs a tiny Python script on real hardware that
# asserts `torch.cuda.is_available()` and the matching backend version attr.
{
  lib, # NOTE(review): not referenced below — presumably kept for callPackage compatibility; confirm
  torchWithCuda,
  torchWithRocm,
  callPackage,
}:

let
  # Generic tester factory: given a feature name ("cuda"/"rocm"), the
  # torch.version attribute to probe ("cuda"/"hip"), and a torch build,
  # produce a writeGpuTestPython derivation named "<feature>Available".
  # `cudaPackages` is auto-supplied by callPackage.
  accelAvailable =
    {
      feature,
      versionAttr,
      torch,
      cudaPackages,
    }:
    cudaPackages.writeGpuTestPython
      {
        inherit feature;
        libraries = [ torch ];
        name = "${feature}Available";
      }
      # ${versionAttr} below is Nix interpolation; the braces without `$`
      # are part of the Python f-string and reach the script verbatim.
      ''
        import torch
        message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
        assert torch.cuda.is_available() and torch.version.${versionAttr}, message
        print(message)
      '';
in
{
  # CUDA runtime availability check.
  tester-cudaAvailable = callPackage accelAvailable {
    feature = "cuda";
    versionAttr = "cuda";
    torch = torchWithCuda;
  };
  # ROCm runtime availability check (torch reports HIP via torch.version.hip).
  tester-rocmAvailable = callPackage accelAvailable {
    feature = "rocm";
    versionAttr = "hip";
    torch = torchWithRocm;
  };
}
|
19
pkgs/development/python-modules/torch/mk-runtime-check.nix
Normal file
19
pkgs/development/python-modules/torch/mk-runtime-check.nix
Normal file
@ -0,0 +1,19 @@
|
||||
# mk-runtime-check.nix (new file in this PR) — factory for a single torch GPU
# runtime check. Produces a writeGpuTestPython derivation whose gpuCheck runs
# a short Python script on hardware and asserts the backend is usable.
#
# Arguments (supplied via callPackage from tests.nix):
#   cudaPackages — provides writeGpuTestPython (auto-supplied by callPackage)
#   feature      — "cuda" or "rocm"; also names the resulting tester
#   torch        — the torch build to probe
#   versionAttr  — torch.version attribute to assert ("cuda" or "hip")
{
  cudaPackages,
  feature,
  torch,
  versionAttr,
}:

cudaPackages.writeGpuTestPython
  {
    inherit feature;
    libraries = [ torch ];
    name = "${feature}Available";
  }
  # ${versionAttr} is interpolated by Nix; the `$`-less braces belong to the
  # Python f-string and are passed through to the script unchanged.
  ''
    import torch
    message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
    assert torch.cuda.is_available() and torch.version.${versionAttr}, message
    print(message)
  ''
|
@ -1,3 +1,21 @@
|
||||
# tests.nix — torch test-suite entry point, rewritten by this PR.
# Previous version (removed by this diff):
#   { callPackage }:
#   callPackage ./gpu-checks.nix { }
{
  callPackage,
  torchWithCuda,
  torchWithRocm,
}:

{
  # To perform the runtime check use either
  # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
  # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
  tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
    feature = "cuda";
    versionAttr = "cuda";
    torch = torchWithCuda;
  };
  # Same runtime check against the ROCm build; torch exposes HIP support
  # through torch.version.hip.
  tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
    feature = "rocm";
    versionAttr = "hip";
    torch = torchWithRocm;
  };
}
|
||||
|
Loading…
Reference in New Issue
Block a user