-
Notifications
You must be signed in to change notification settings - Fork 616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
qml.wire.Wires
accepts JAX
arrays
#6312
Conversation
Hello. You may have forgotten to update the changelog!
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6312 +/- ##
=======================================
Coverage 99.39% 99.39%
=======================================
Files 445 445
Lines 42286 42315 +29
=======================================
+ Hits 42031 42060 +29
Misses 255 255 ☔ View full report in Codecov by Sentry. |
…o Wires_accept_JAX_array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
…o Wires_accept_JAX_array
**Context:** In JAX JIT/plxpr/qjit workflows, it is quite easy for wires to end up with jax array values. JAX arrays are not hashable and are rejected. Still, we can treat them as iterable and ensure they are accepted as wires. **Description of the Change:** We use the `to_list` and `item` methods of `JAX` arrays as a trick to make them tuples before hashing them. **Benefits:** `qml.wire.Wires` can accept `JAX` arrays as input. **Possible Drawbacks:** The change should only be relevant when `jax.numpy.array`s are explicitly provided. Therefore, in principle, we should not face problems with the rest of the code. **Related GitHub Issues:** None. **Related Shortcut Stories** [sc-72593] [sc-74904]
Context: In JAX JIT/plxpr/qjit workflows, it is quite easy for wires to end up with jax array values. JAX arrays are not hashable and are rejected. Still, we can treat them as iterable and ensure they are accepted as wires.
Furthermore, if we try to use JAX tracers as wires in PennyLane using JAX 0.4.30+, we currently get FutureWarning in some of the tests:
This future warning was introduced in JAX 0.4.30 and is still not implemented as an error in 0.4.32, but both PL and Catalyst currently use version 0.4.28.
Description of the Change: We use the
to_list
anditem
methods ofJAX
arrays as a trick to make them tuples before hashing them to ensure that they are accepted as iterable by theWire
class. Furthermore, we decided to implement in PL the same patch adopted in Catalyst to avoid the warning reported above as temporary solution.Benefits:
qml.wire.Wires
can acceptJAX
arrays as input.Possible Drawbacks: The first change should only be relevant when
jax.numpy.array
s are explicitly provided. Therefore, we should not face problems with the rest of the code.As for the second change, it is highly recommended to re-visit the way JAX tracers are handled in the PL pipeline since in a future JAX version we will not be able to hash tracers anymore.
Related GitHub Issues: None.
Related Shortcut Stories [sc-72593] [sc-74904]