diff --git a/pkgs/development/python-modules/tinygrad/default.nix b/pkgs/development/python-modules/tinygrad/default.nix index 82a57f7d7f08..760b29c1adfc 100644 --- a/pkgs/development/python-modules/tinygrad/default.nix +++ b/pkgs/development/python-modules/tinygrad/default.nix @@ -1,12 +1,25 @@ { lib, + config, buildPythonPackage, fetchFromGitHub, + substituteAll, + addDriverRunpath, + cudaSupport ? config.cudaSupport, + rocmSupport ? config.rocmSupport, + cudaPackages, + ocl-icd, + stdenv, + rocmPackages, + # build-system setuptools, wheel, - gpuctypes, + # dependencies numpy, tqdm, + # nativeCheckInputs + clang, + hexdump, hypothesis, librosa, onnx, @@ -22,30 +35,67 @@ buildPythonPackage rec { pname = "tinygrad"; - version = "0.8.0"; + version = "0.9.0"; pyproject = true; src = fetchFromGitHub { owner = "tinygrad"; repo = "tinygrad"; rev = "refs/tags/v${version}"; - hash = "sha256-QAccZ79qUbe27yUykIf22WdkxYUlOffnMlShakKfp60="; + hash = "sha256-opBxciETZruZjHqz/3vO7rogzjvVJKItulIiok/Zs2Y="; }; - nativeBuildInputs = [ + patches = [ + (substituteAll { + src = ./fix-dlopen-cuda.patch; + inherit (addDriverRunpath) driverLink; + libnvrtc = + if cudaSupport then + "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so" + else + "Please import nixpkgs with `config.cudaSupport = true`"; + }) + ]; + + postPatch = + '' + substituteInPlace tinygrad/runtime/autogen/opencl.py \ + --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'" + '' + # hipGetDevicePropertiesR0600 is a symbol from rocm-6. We are currently at rocm-5. + # We are not sure that this works. Remove when rocm gets updated to version 6. 
+ + lib.optionalString rocmSupport '' + substituteInPlace extra/hip_gpu_driver/hip_ioctl.py \ + --replace-fail "processor = platform.processor()" "processor = '${stdenv.hostPlatform.linuxArch}'" + substituteInPlace tinygrad/runtime/autogen/hip.py \ + --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \ + --replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so" \ + --replace-fail "hipGetDevicePropertiesR0600" "hipGetDeviceProperties" + + substituteInPlace tinygrad/runtime/autogen/comgr.py \ + --replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so" + ''; + + build-system = [ setuptools wheel ]; - propagatedBuildInputs = [ - gpuctypes - numpy - tqdm - ]; + dependencies = + [ + numpy + tqdm + ] + ++ lib.optionals stdenv.isDarwin [ + # pyobjc-framework-libdispatch + # pyobjc-framework-metal + ]; pythonImportsCheck = [ "tinygrad" ]; nativeCheckInputs = [ + clang + hexdump hypothesis librosa onnx @@ -63,44 +113,60 @@ buildPythonPackage rec { export HOME=$(mktemp -d) ''; - disabledTests = [ - # Require internet access - "test_benchmark_openpilot_model" - "test_bn_alone" - "test_bn_linear" - "test_bn_mnist" - "test_car" - "test_chicken" - "test_chicken_bigbatch" - "test_conv_mnist" - "testCopySHMtoDefault" - "test_data_parallel_resnet" - "test_e2e_big" - "test_fetch_small" - "test_huggingface_enet_safetensors" - "test_linear_mnist" - "test_load_convnext" - "test_load_enet" - "test_load_enet_alt" - "test_load_llama2bfloat" - "test_load_resnet" - "test_openpilot_model" - "test_resnet" - "test_shufflenet" - "test_transcribe_batch12" - "test_transcribe_batch21" - "test_transcribe_file1" - "test_transcribe_file2" - "test_transcribe_long" - "test_transcribe_long_no_batch" - "test_vgg7" - ]; + disabledTests = + [ + # Require internet access + "test_benchmark_openpilot_model" + "test_bn_alone" + "test_bn_linear" + "test_bn_mnist" + "test_car" + "test_chicken" + 
"test_chicken_bigbatch" + "test_conv_mnist" + "testCopySHMtoDefault" + "test_data_parallel_resnet" + "test_e2e_big" + "test_fetch_small" + "test_huggingface_enet_safetensors" + "test_linear_mnist" + "test_load_convnext" + "test_load_enet" + "test_load_enet_alt" + "test_load_llama2bfloat" + "test_load_resnet" + "test_openpilot_model" + "test_resnet" + "test_shufflenet" + "test_transcribe_batch12" + "test_transcribe_batch21" + "test_transcribe_file1" + "test_transcribe_file2" + "test_transcribe_long" + "test_transcribe_long_no_batch" + "test_vgg7" + ] + # Fail on aarch64-linux with AssertionError + ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [ + "test_casts_to" + "test_casts_to" + "test_int8_to_uint16_negative" + "test_casts_to" + "test_casts_to" + "test_casts_from" + "test_casts_to" + "test_int8" + "test_casts_to" + ]; - disabledTestPaths = [ - "test/extra/test_lr_scheduler.py" - "test/models/test_mnist.py" - "test/models/test_real_world.py" - ]; + disabledTestPaths = + [ + # Require internet access + "test/models/test_mnist.py" + "test/models/test_real_world.py" + "test/testextra/test_lr_scheduler.py" + ] + ++ lib.optionals (!rocmSupport) [ "extra/hip_gpu_driver/" ]; meta = with lib; { description = "A simple and powerful neural network framework"; @@ -108,5 +174,7 @@ buildPythonPackage rec { changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}"; license = licenses.mit; maintainers = with maintainers; [ GaetanLepage ]; + # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal + broken = stdenv.isDarwin; }; } diff --git a/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch b/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch similarity index 67% rename from pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch rename to pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch index 8d3b69e35e11..6b77173b4ecc 100644 --- 
a/pkgs/development/python-modules/gpuctypes/fix-dlopen-cuda.patch +++ b/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch @@ -1,9 +1,9 @@ -diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py -index acba81c..aac5fc7 100644 ---- a/gpuctypes/cuda.py -+++ b/gpuctypes/cuda.py -@@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): - +diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py +index 359083a9..3cd5f7be 100644 +--- a/tinygrad/runtime/autogen/cuda.py ++++ b/tinygrad/runtime/autogen/cuda.py +@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): + return ctypes.cast(string, ctypes.POINTER(ctypes.c_char)) +NAME_TO_PATHS = { @@ -19,9 +19,9 @@ index acba81c..aac5fc7 100644 + try: + return ctypes.CDLL(candidate) + except OSError: -+ pass ++ pass + raise RuntimeError(f"{name} not found") -+ + _libraries = {} -_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) -_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))