diff --git a/pkgs/development/python-modules/dm-haiku/default.nix b/pkgs/development/python-modules/dm-haiku/default.nix index 6ff6ff412b7e..cb97e2f837af 100644 --- a/pkgs/development/python-modules/dm-haiku/default.nix +++ b/pkgs/development/python-modules/dm-haiku/default.nix @@ -1,17 +1,27 @@ -{ buildPythonPackage +{ lib +, buildPythonPackage , fetchFromGitHub , fetchpatch , absl-py , flax -, numpy -, callPackage -, lib -, jmp -, tabulate , jaxlib +, jmp +, numpy +, tabulate +, pytest-xdist +, pytestCheckHook +, bsuite +, chex +, cloudpickle +, dill +, dm-env +, dm-tree +, optax +, rlax +, tensorflow }: -buildPythonPackage rec { +let dm-haiku = buildPythonPackage rec { pname = "dm-haiku"; version = "0.0.11"; format = "setuptools"; @@ -32,11 +42,6 @@ buildPythonPackage rec { }) ]; - outputs = [ - "out" - "testsout" - ]; - propagatedBuildInputs = [ absl-py flax @@ -50,17 +55,56 @@ buildPythonPackage rec { "haiku" ]; - postInstall = '' - mkdir $testsout - cp -R examples $testsout/examples - ''; + nativeCheckInputs = [ + bsuite + chex + cloudpickle + dill + dm-env + dm-haiku + dm-tree + jaxlib + optax + pytest-xdist + pytestCheckHook + rlax + tensorflow + ]; + + disabledTests = [ + # See https://github.com/deepmind/dm-haiku/issues/366. + "test_jit_Recurrent" + + # Assertion errors + "testShapeChecking0" + "testShapeChecking1" + + # This test requires a more recent version of tensorflow. The current one (2.13) is not enough. + "test_reshape_convert" + + # This test requires JAX support for double precision (64bit), but enabling this causes several + # other tests to fail. + # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + "test_doctest_haiku.experimental" + ]; + + disabledTestPaths = [ + # Those tests requires a more recent version of tensorflow. The current one (2.13) is not enough. + "haiku/_src/integration/jax2tf_test.py" + ]; - # check in passthru.tests.pytest to escape infinite recursion with bsuite doCheck = false; - passthru.tests = { - pytest = callPackage ./tests.nix { }; - }; + # check in passthru.tests.pytest to escape infinite recursion with bsuite + passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: { + pname = "${pname}-tests"; + doCheck = true; + + # We don't have to install because the only purpose + # of this passthru test is to, well, test. + # This fixes having to set `catchConflicts` to false. + dontInstall = true; + }); meta = with lib; { description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet."; @@ -68,4 +112,5 @@ buildPythonPackage rec { license = licenses.asl20; maintainers = with maintainers; [ ndl ]; }; -} +}; +in dm-haiku diff --git a/pkgs/development/python-modules/dm-haiku/tests.nix b/pkgs/development/python-modules/dm-haiku/tests.nix deleted file mode 100644 index 3a99bd3ac85a..000000000000 --- a/pkgs/development/python-modules/dm-haiku/tests.nix +++ /dev/null @@ -1,69 +0,0 @@ -{ buildPythonPackage -, dm-haiku -, chex -, cloudpickle -, dill -, dm-tree -, jaxlib -, pytest-xdist -, pytestCheckHook -, tensorflow -, bsuite -, frozendict -, dm-env -, scikit-image -, rlax -, distrax -, tensorflow-probability -, optax -}: - -buildPythonPackage { - pname = "dm-haiku-tests"; - format = "other"; - inherit (dm-haiku) version; - - src = dm-haiku.testsout; - - dontBuild = true; - dontInstall = true; - - nativeCheckInputs = [ - bsuite - chex - cloudpickle - dill - distrax - dm-env - dm-haiku - dm-tree - frozendict - jaxlib - pytest-xdist - pytestCheckHook - optax - rlax - scikit-image - tensorflow - tensorflow-probability - ]; - - disabledTests = [ - # See https://github.com/deepmind/dm-haiku/issues/366. - "test_jit_Recurrent" - # Assertion errors - "test_connect_conv_padding_function_same0" - "test_connect_conv_padding_function_valid0" - "test_connect_conv_padding_function_same1" - "test_connect_conv_padding_function_same2" - "test_connect_conv_padding_function_valid1" - "test_connect_conv_padding_function_valid2" - "test_invalid_axis_ListString" - "test_invalid_axis_String" - "test_simple_case" - "test_simple_case_with_scale" - "test_slice_axis" - "test_zero_inputs" - ]; - -}