Skip to content

Commit

Permalink
Merge pull request #156808 from samuela/samuela/fix-jaxlib
Browse files Browse the repository at this point in the history
JAX ecosystem: `jaxlib` should never be in `propagatedBuildInputs`
  • Loading branch information
samuela authored Jan 28, 2022
2 parents 65a33fe + 93f88be commit d8ff157
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 5 deletions.
2 changes: 2 additions & 0 deletions pkgs/development/python-modules/dm-haiku/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
, dill
, dm-tree
, fetchFromGitHub
, jaxlib
, jmp
, lib
, pytestCheckHook
Expand Down Expand Up @@ -31,6 +32,7 @@ buildPythonPackage rec {
chex
cloudpickle
dm-tree
jaxlib
pytestCheckHook
tensorflow
];
Expand Down
3 changes: 3 additions & 0 deletions pkgs/development/python-modules/elegy/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
, deepmerge
, dm-haiku
, fetchFromGitHub
, jaxlib
, lib
, poetry
, pytestCheckHook
Expand Down Expand Up @@ -35,6 +36,8 @@ buildPythonPackage rec {
poetry
];

buildInputs = [ jaxlib ];

propagatedBuildInputs = [
cloudpickle
deepdish
Expand Down
3 changes: 3 additions & 0 deletions pkgs/development/python-modules/flax/default.nix
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{ buildPythonPackage
, fetchFromGitHub
, jaxlib
, keras
, lib
, matplotlib
Expand All @@ -21,6 +22,8 @@ buildPythonPackage rec {
sha256 = "0zvq0vl88hiwmss49bnm7gdmndr1dfza2bcs1fj88a9r7w9dmlsr";
};

buildInputs = [ jaxlib ];

propagatedBuildInputs = [
matplotlib
msgpack
Expand Down
4 changes: 2 additions & 2 deletions pkgs/development/python-modules/jmp/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ buildPythonPackage rec {
sha256 = "0hh4cmp93wjyidj48gh07vhx2kjvpwd23xvy79bsjn5qaaf6q4cm";
};

# Wheel requires only `numpy`, but the import needs both `jax` and `jaxlib`.
# Wheel requires only `numpy`, but the import needs `jax`.
propagatedBuildInputs = [
jax
jaxlib
];

pythonImportsCheck = [
"jmp"
];

checkInputs = [
jaxlib
pytestCheckHook
];

Expand Down
3 changes: 2 additions & 1 deletion pkgs/development/python-modules/optax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ buildPythonPackage rec {
sha256 = "1q8cxc42a5xais2ll1l238cnn3l7w28savhgiz0lg01ilz2ysbli";
};

buildInputs = [ jaxlib ];

propagatedBuildInputs = [
absl-py
chex
jaxlib
numpy
];

Expand Down
4 changes: 2 additions & 2 deletions pkgs/development/python-modules/treeo/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ buildPythonPackage rec {
poetry-core
];

# These deps are not needed for the wheel, but required during the import.
# jax is not declared in the dependencies, but is necessary.
propagatedBuildInputs = [
jax
jaxlib
];

checkInputs = [ jaxlib ];
pythonImportsCheck = [
"treeo"
];
Expand Down
3 changes: 3 additions & 0 deletions pkgs/development/python-modules/treex/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
, fetchFromGitHub
, flax
, hypothesis
, jaxlib
, keras
, lib
, poetry-core
Expand Down Expand Up @@ -38,6 +39,8 @@ buildPythonPackage rec {
poetry-core
];

buildInputs = [ jaxlib ];

propagatedBuildInputs = [
einops
flax
Expand Down

0 comments on commit d8ff157

Please sign in to comment.