diff --git a/pkgs/development/python-modules/rlax/default.nix b/pkgs/development/python-modules/rlax/default.nix new file mode 100644 index 000000000000..adff2f0ac5d3 --- /dev/null +++ b/pkgs/development/python-modules/rlax/default.nix @@ -0,0 +1,65 @@ +{ lib +, fetchPypi +, buildPythonPackage +, chex +, jaxlib +, tensorflow-probability +, optax +, dm-haiku +, bsuite +, frozendict +, pytestCheckHook +, dm-env +, distrax }: + +buildPythonPackage rec { + pname = "rlax"; + version = "0.1.2"; + + src = fetchPypi { + inherit pname version; + sha256 = "sha256-hAG0idz5VkGVvxaJWoxlVZ8myeHF6ndDxB0SyJm7qV8="; + }; + + buildInputs = [ + chex + jaxlib + distrax + tensorflow-probability + ]; + + checkInputs = [ + bsuite + dm-env + dm-haiku + frozendict + optax + pytestCheckHook + ]; + + pythonImportsCheck = [ + "rlax" + ]; + + disabledTests = [ + # RuntimeErrors + "test_cross_replica_scatter_add0" + "test_cross_replica_scatter_add1" + "test_cross_replica_scatter_add2" + "test_cross_replica_scatter_add3" + "test_cross_replica_scatter_add4" + "test_learn_scale_shift" + "test_normalize_unnormalize_is_identity" + "test_outputs_preserved" + "test_scale_bounded" + "test_slow_update" + "test_unnormalize_linear" + ]; + + meta = with lib; { + description = "Library of reinforcement learning building blocks in JAX"; + homepage = "https://github.com/deepmind/rlax"; + license = licenses.asl20; + maintainers = with maintainers; [ onny ]; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 2231b5cfa889..36ea5f508a17 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -9345,6 +9345,8 @@ in { rki-covid-parser = callPackage ../development/python-modules/rki-covid-parser { }; + rlax = callPackage ../development/python-modules/rlax { }; + rl-coach = callPackage ../development/python-modules/rl-coach { }; rlp = callPackage ../development/python-modules/rlp { };