JAX ecosystem: jaxlib should never be in propagatedBuildInputs #156808

Merged: 7 commits, Jan 28, 2022

Member

@samuela samuela commented Jan 26, 2022

Motivation for this change

Fix #156767.

There are multiple possible "versions" of the jaxlib package: CPU-only, CUDA-enabled, and TPU-enabled. Currently the jaxlib package defaults to the smallest (and most FOSS) version which is CPU-only. CUDA support can be enabled by end users via the cudaSupport = true option.

However, this presents an issue when using packages that depend on jax/jaxlib. If a package depends directly on jaxlib by adding it to propagatedBuildInputs, then the end user is stuck with the CPU-only version of jaxlib, even though the CUDA version may be desired. Using overlays is one possible workaround, but unfortunately that requires every end user to recompile the entire JAX ecosystem locally whenever they want to use a non-CPU jaxlib.

This PR removes jaxlib from any and all propagatedBuildInputs and patches packages that require jaxlib as a setup.py dependency. AFAIU this is in line with the JAX authors' intention to separate the jax and jaxlib packages. In fact, jax itself does not even have a setup.py dependency on jaxlib.

For future debugging purposes, an easy way to test if you have conflicting jaxlib versions is

echo $PYTHONPATH | tr ':' '\n' | grep jaxlib

This command should return at most one entry at all times. I've verified that the changes in this PR do in fact resolve this issue based on the shell.nix used in #156767.
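
As a rough sketch, the check above could be wrapped in a small shell function so it can be reused in scripts or CI. The names `count_jaxlib_entries` and `check_jaxlib_conflict` are hypothetical and not part of this PR, and the `sort -u` step is an added assumption: an identical entry repeated on PYTHONPATH is treated as one entry rather than as a conflict.

```shell
# Hypothetical helpers, not part of this PR.

# Print the number of distinct path entries that mention jaxlib.
# $1: a colon-separated path list (defaults to $PYTHONPATH when omitted).
count_jaxlib_entries() {
  # Split on ':', drop exact duplicates, then count lines containing "jaxlib".
  # grep -c exits non-zero when the count is 0, so `|| true` keeps the status clean.
  printf '%s\n' "${1-${PYTHONPATH-}}" | tr ':' '\n' | sort -u | grep -c jaxlib || true
}

# Succeed when at most one jaxlib entry is present, fail otherwise.
check_jaxlib_conflict() {
  n=$(count_jaxlib_entries "$@")
  if [ "$n" -gt 1 ]; then
    echo "conflict: $n distinct jaxlib entries on the path" >&2
    return 1
  fi
  echo "ok: $n jaxlib entries on the path"
}
```

Calling `check_jaxlib_conflict` with no arguments inside a nix-shell inspects the current PYTHONPATH; passing a path list explicitly makes it easy to try out on sample values.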

Assuming we reach consensus that this is the right path forward, I will document this practice on the wiki page.

Things done
  • Built on platform(s)
    • x86_64-linux
    • aarch64-linux
    • x86_64-darwin
    • aarch64-darwin
  • For non-Linux: Is sandbox = true set in nix.conf? (See Nix manual)
  • Tested, as applicable:
  • Tested compilation of all packages that depend on this change using nix-shell -p nixpkgs-review --run "nixpkgs-review rev HEAD". Note: all changes have to be committed, also see nixpkgs-review usage
  • Tested basic functionality of all binary files (usually in ./result/bin/)
  • 22.05 Release Notes (or backporting 21.11 Release notes)
    • (Package updates) Added a release notes entry if the change is major or breaking
    • (Module updates) Added a release notes entry if the change is significant
    • (Module addition) Added a release notes entry if adding a new NixOS module
    • (Release notes changes) Ran nixos/doc/manual/md-to-db.sh to update generated release notes
  • Fits CONTRIBUTING.md.

@samuela
Member Author

samuela commented Jan 26, 2022

Result of nixpkgs-review pr 156808 run on x86_64-linux

7 packages failed to build:
  • python310Packages.dm-haiku
  • python310Packages.elegy
  • python310Packages.flax
  • python310Packages.objax
  • python310Packages.optax
  • python310Packages.treex
  • python39Packages.objax
11 packages built:
  • python310Packages.chex
  • python310Packages.jmp
  • python310Packages.treeo
  • python39Packages.chex
  • python39Packages.dm-haiku
  • python39Packages.elegy
  • python39Packages.flax
  • python39Packages.jmp
  • python39Packages.optax
  • python39Packages.treeo
  • python39Packages.treex

@samuela
Member Author

samuela commented Jan 26, 2022

I'm not worried about python310 failures as that entire package set is mostly broken atm. python39Packages.objax builds successfully outside of nixpkgs-review for me. No idea why nixpkgs-review keeps doing this.

cc @Mic92 for any thoughts?

Member

@FRidh FRidh left a comment

Please report these issues also with the upstream users of jaxlib so they can fix their requirements.

@ndl
Contributor

ndl commented Jan 26, 2022

For the packages that explicitly depend on jaxlib in their setup.py: IIUC the alternative to patches is to move jaxlib from propagatedBuildInputs to buildInputs? That seems sufficient to satisfy the setup.py requirements and have import checks and tests pass for the package itself, while still not propagating jaxlib to the dependent packages?

@samuela
Member Author

samuela commented Jan 26, 2022

@ndl If that's the case it would certainly make my life easier, and does seem cleaner on the nix side of things. I'll give it a shot and see if I can get it to work!

@samuela samuela force-pushed the samuela/fix-jaxlib branch from 5a85f26 to 93f88be on January 27, 2022 20:52
@samuela
Member Author

samuela commented Jan 27, 2022

I tested the checkInputs solution as suggested by @ndl and sure enough it works! This PR is ready to review/merge. New commits pushed (93f88be).

objax did not need any changes after #156909.

Contributor

@ndl ndl left a comment

LG, thanks for fixing this!

@samuela
Member Author

samuela commented Jan 27, 2022

Thanks @ndl! I'll merge tomorrow unless anyone objects

@samuela samuela merged commit d8ff157 into NixOS:master Jan 28, 2022
@samuela samuela deleted the samuela/fix-jaxlib branch January 28, 2022 19:05
@nixos-discourse

This pull request has been mentioned on NixOS Discourse. There might be relevant details there:

https://discourse.nixos.org/t/petition-to-build-and-cache-unfree-packages-on-cache-nixos-org/17440/1

@dotlambda
Member

This is not the solution to go with! If people want to use a different version of jaxlib they can resort to Python's package overrides.

@samuela
Member Author

samuela commented Feb 14, 2022

@dotlambda there are two main problems with overrides/overlays:

  • overridden packages do not get built, tested, and cached by hydra or cache.nixos.org
  • the JAX authors intentionally separated jax and jaxlib for just this reason so that users have the freedom to select which jaxlib version suits them best

@dotlambda
Member

>   • overridden packages do not get built, tested, and cached by hydra or cache.nixos.org

neither does the unfree jaxlibWithCuda

>   • the JAX authors intentionally separated jax and jaxlib for just this reason so that users have the freedom to select which jaxlib version suits them best

And people easily get that freedom:

let
  python = python3.override {
    packageOverrides = self: super: {
      jaxlib = self.jaxlibWithCuda;
    };
  };
in python.pkgs.elegy

I suggest that this is put in the wiki instead of the current solution.

@samuela
Member Author

samuela commented Feb 15, 2022

> neither does the unfree jaxlibWithCuda

True, although this makes things much easier for end users maintaining their own CI infrastructure. For example, nixpkgs-upkeep builds jaxlibWithCuda automatically and pushes it to my public cachix cache.

> And people easily get that freedom:

Overlays are a constant source of confusion for users (see eg https://discourse.nixos.org/t/what-are-overlays/14680), and are not nearly as ergonomic as simply jaxlibWithCuda. At minimum, any solution here must work for users writing one-off scripts with nix-shell shebang lines.

By the way, there is extensive precedent for this pattern in nixpkgs already: tensorflowWithCuda, katagoWithCuda, librealsenseWithCuda, TheanoWithCuda, colmapWithCuda, etc.

@dotlambda
Member

cc @jonringer @FRidh @mweinelt

@ndl
Contributor

ndl commented Feb 15, 2022

FWIW I've actually used the overlay / override solution in the past for a locally-maintained CUDA version of jaxlib, but I like @samuela's approach better for a number of reasons:

  • No need for (not fully trivial) rebuilds of all JAX-related packages when switching jaxlib variants.
  • Very clear and intuitive composability: just specify which jaxlib version you want in your nix-shell or derivation and you're done (compare that to "Python's packageOverrides isn't composable" (#44426) and some of the solutions proposed there).
  • Matches the upstream approach - jax doesn't depend on jaxlib in its core set of dependencies.
  • If someone prefers to have jax bundled with jaxlib - then with the current approach it's pretty easy to produce such a derivation by composing jax and jaxlib of one's choice together, and then for all intents and purposes it will behave as if these were bundled originally, whereas doing the opposite (= unbundling them if they were originally bundled) is more difficult.

@dotlambda are there any particular issues that you see with this approach that outweigh the advantages listed above?

@dotlambda
Member

> are there any particular issues that you see with this approach that outweigh the advantages listed above?

Any package that depends on a package that depends on a package that ... that uses jax needs to have jaxlib in its buildInputs.

If you want to be able to switch out packages without rebuilding its reverse dependencies (like upstream wants you to be able to) then you should use something other than Nix.

@samuela
Member Author

samuela commented Feb 15, 2022

> If you want to be able to switch out packages without rebuilding its reverse dependencies (like upstream wants you to be able to) then you should use something other than Nix.

I don't feel like this is a practical solution to any user's problems. I feel as though the message here has devolved a bit into "no, not that way" without offering an alternative solution that actually meets our needs. As maintainers (and users!) of the JAX ecosystem, @ndl and I have found a solution that works very very well for us. I'm open to entertaining alternative designs to meet our specific needs, but we need concrete ideas not just "don't do it".

@ndl
Contributor

ndl commented Feb 15, 2022

> Any package that depends on a package that depends on a package that ... that uses jax needs to have jaxlib in its buildInputs.

Right now I think the split is actually closer to 50 / 50: 4 packages (chex, dm-haiku, jmp, treeo) don't have a jaxlib dependency in their requirements (and, correspondingly, don't need it in buildInputs) while 5 packages (elegy, flax, objax, optax, treex) do. Out of these 5, I believe exactly zero actually need that dependency at build time => the case can be made for raising the corresponding issues with the upstream to remove this dependency.

@dotlambda
Member

> we need concrete ideas not just "don't do it".

What about using packageOverrides in your shell.nix wouldn't work?
If you want cached packages then you could put all packages built with jaxlibWithCuda on cachix.

> I believe exactly zero actually need that dependency at build time => the case can be made for raising the corresponding issues with the upstream to remove this dependency.

propagatedBuildInputs is for run time dependencies, not build time

@dotlambda
Member

This is a debate between purity and convenience and I'd go with purity because I don't wanna be hit by bad decisions later on.

@samuela
Member Author

samuela commented Feb 15, 2022

> What about using packageOverrides in your shell.nix wouldn't work?

packageOverrides is too complex. Esp. for what is effectively the normative use-case of this particular piece of software; the vast majority of people want to use jaxlibWithCuda, not jaxlib. The average ML researcher is not excited about building/packaging/installing software. They want something that "just works". It can't be more complicated than the pip install that they're accustomed to.

packageOverrides also presents an issue when writing scripts with nix-shell shebang lines. Even if it's technically possible, it would certainly be messy.

> If you want cached packages then you could put all packages built with jaxlibWithCuda on cachix.

This means every package in the JAX ecosystem ever. It's just too much. Even forgetting about the compute costs, it would make CI much messier. Especially considering that upstream was very intentional in making the design decision to avoid this problem.

Bear in mind that we've only been considering the CPU and GPU versions of jaxlib, but there's a TPU version as well that's not even packaged yet! As new specialized DL hardware is developed other versions will likely be created as well.

> This is a debate between purity and convenience and I'd go with purity because I don't wanna be hit by bad decisions later on.

As with any software project, it's always possible that we may run into unforeseen complications in the future. And if/when that happens, I think it would be a more appropriate time to have this conversation. The current design works well for the road bumps that we've hit up to this point. And as @ndl mentioned, the modularity of our current design is strictly more flexible than putting jaxlib in propagatedBuildInputs, meaning that the transition in that direction would be much easier than the other way around.

Successfully merging this pull request may close these issues.

python3Packages.jaxlib: NVIDIA GPU not detected correctly