Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qml.wire.Wires accepts JAX arrays #6312

Merged
merged 26 commits into from
Oct 7, 2024
Merged

Conversation

PietropaoloFrisoni
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni commented Sep 30, 2024

Context: In JAX JIT/plxpr/qjit workflows, it is quite easy for wires to end up with jax array values. JAX arrays are not hashable and are rejected. Still, we can treat them as iterable and ensure they are accepted as wires.

Furthermore, if we try to use JAX tracers as wires in PennyLane using JAX 0.4.30+, we currently get FutureWarning in some of the tests:

FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

This future warning was introduced in JAX 0.4.30 and is still not implemented as an error in 0.4.32, but both PL and Catalyst currently use version 0.4.28.

Description of the Change: We use the to_list and item methods of JAX arrays as a trick to make them tuples before hashing them to ensure that they are accepted as iterable by the Wire class. Furthermore, we decided to implement in PL the same patch adopted in Catalyst to avoid the warning reported above as temporary solution.

Benefits: qml.wire.Wires can accept JAX arrays as input.

Possible Drawbacks: The first change should only be relevant when jax.numpy.arrays are explicitly provided. Therefore, we should not face problems with the rest of the code.
As for the second change, it is highly recommended to re-visit the way JAX tracers are handled in the PL pipeline since in a future JAX version we will not be able to hash tracers anymore.

Related GitHub Issues: None.

Related Shortcut Stories [sc-72593] [sc-74904]

@PietropaoloFrisoni PietropaoloFrisoni marked this pull request as ready for review September 30, 2024 22:50
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

Copy link

codecov bot commented Sep 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.39%. Comparing base (f1ab2bc) to head (0b6346e).
Report is 336 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6312   +/-   ##
=======================================
  Coverage   99.39%   99.39%           
=======================================
  Files         445      445           
  Lines       42286    42315   +29     
=======================================
+ Hits        42031    42060   +29     
  Misses        255      255           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

pennylane/wires.py Outdated Show resolved Hide resolved
pennylane/wires.py Outdated Show resolved Hide resolved
@albi3ro albi3ro self-requested a review October 2, 2024 13:25
pennylane/wires.py Outdated Show resolved Hide resolved
Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@PietropaoloFrisoni
Copy link
Contributor Author

@albi3ro @dime10 Since in the end we decided to use the same Catalyst trick, I will close this PR so that it is faster to make all changes in one single PR

@PietropaoloFrisoni PietropaoloFrisoni merged commit c87a498 into master Oct 7, 2024
37 checks passed
@PietropaoloFrisoni PietropaoloFrisoni deleted the Wires_accept_JAX_array branch October 7, 2024 18:21
austingmhuang pushed a commit that referenced this pull request Oct 23, 2024
**Context:** In JAX JIT/plxpr/qjit workflows, it is quite easy for wires
to end up with jax array values. JAX arrays are not hashable and are
rejected. Still, we can treat them as iterable and ensure they are
accepted as wires.

**Description of the Change:** We use the `to_list` and `item` methods
of `JAX` arrays as a trick to make them tuples before hashing them.

**Benefits:** `qml.wire.Wires` can accept `JAX` arrays as input.

**Possible Drawbacks:** The change should only be relevant when
`jax.numpy.array`s are explicitly provided. Therefore, in principle, we
should not face problems with the rest of the code.

**Related GitHub Issues:** None.

**Related Shortcut Stories** [sc-72593] [sc-74904]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants