-
Notifications
You must be signed in to change notification settings - Fork 616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Add finite differences jvps #6853
base: master
Are you sure you want to change the base?
Conversation
Hello. You may have forgotten to update the changelog!
|
…pennylane into finite-diff-capture
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6853 +/- ##
=======================================
Coverage 99.54% 99.55%
=======================================
Files 477 477
Lines 45246 45273 +27
=======================================
+ Hits 45042 45070 +28
+ Misses 204 203 -1 ☔ View full report in Codecov by Sentry. |
return qml.expval(qml.Z(0)) | ||
|
||
with pytest.warns(UserWarning, match="Detected float32 parameter with finite differences."): | ||
jax.grad(circuit)(jnp.array(0.5, dtype=jnp.float32)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we be using qml.grad
in the tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential over-testing suggestion: or maybe parametrizing over qml and jax diff functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a jax specific test.
jax.grad(circuit)(jnp.array(0.5, dtype=jnp.float32)) | ||
|
||
|
||
class TestGradients: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Non blocking questions:
Do we have tests to validate that the correct diff method is actually being used? Does it matter if we do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a test at the bottom where I double check the captured jaxpr.
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Forgot to say in the last review, but could you also add a test for differentiating for complex parameters? Other than that, I'm ready to approve.
Co-authored-by: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com>
Context:
Currently, program capture only supports differentiation with backprop. This is only supported on
default.qubit
Description of the Change:
Adds the capability to take finite difference derivatives.
Benefits:
We can now differentiate anything, and we can differentiate with the lightnings.
Possible Drawbacks:
finite diff tends to be noisy to the point of uselessness.
Related GitHub Issues:
[sc-82167]