Skip to content

Commit

Permalink
[Capture] Add backprop validation (#6852)
Browse files Browse the repository at this point in the history
**Context:**

We currently use un-validated backprop for differentiation with program
capture. This leads to some unintuitive errors if you try and take a
gradient on lightning with capture enabled.

**Description of the Change:**

Adds some validation to make sure the device supports backprop. Adds the
backprop logic to a `_backprop` jvp function, and dispatches to that
method based on the diff method.

**Benefits:**

Improved error messages when backprop or the requested diff method isn't
supported.

**Possible Drawbacks:**

The code currently is a little clunky, but it is private so we should be
able to move things around once we have more information.

**Related GitHub Issues:**

[sc-82166]
  • Loading branch information
albi3ro authored Jan 21, 2025
1 parent 98bb29b commit 90dc57c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 12 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
* An informative error is raised when a `QNode` with `diff_method=None` is differentiated.
[(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770)

* The requested `diff_method` is now validated when program capture is enabled.
[(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)

<h3>Breaking changes 💔</h3>

* `MultiControlledX` no longer accepts strings as control values.
Expand Down
31 changes: 29 additions & 2 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def _qnode_batching_rule(
"using parameter broadcasting to a quantum operation that supports batching.",
UserWarning,
)

# To resolve this ambiguity, we might add more properties to the AbstractOperator
# class to indicate which operators support batching and check them here.
# As above, at this stage we raise a warning and give the user full flexibility.
Expand Down Expand Up @@ -277,15 +276,43 @@ def _qnode_batching_rule(
return result, (0,) * len(result)


### JVP CALCULATION #########################################################
# This structure will change as we add more diff methods


def _make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan


def _qnode_jvp(args, tangents, **impl_kwargs):
def _backprop(args, tangents, **impl_kwargs):
tangents = tuple(map(_make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)


diff_method_map = {"backprop": _backprop}


def _resolve_diff_method(diff_method: str, device) -> str:
# check if best is backprop
if diff_method == "best":
config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax")
diff_method = device.setup_execution_config(config).gradient_method

if diff_method not in diff_method_map:
raise NotImplementedError(f"diff_method {diff_method} not yet implemented.")

return diff_method


def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs):
diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device)
return diff_method_map[diff_method](
args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs
)


### END JVP CALCULATION #######################################################

ad.primitive_jvps[qnode_prim] = _qnode_jvp

batching.primitive_batchers[qnode_prim] = _qnode_batching_rule
Expand Down
63 changes: 53 additions & 10 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,18 +339,61 @@ def circuit(x):
assert list(out.keys()) == ["a", "b"]


def test_qnode_jvp():
"""Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule."""
class TestDifferentiation:

@qml.qnode(qml.device("default.qubit", wires=1))
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))
def test_error_backprop_unsupported(self):
"""Test an error is raised with backprop if the device does not support it."""

# pylint: disable=too-few-public-methods
class DummyDev(qml.devices.Device):

def execute(self, *_, **__):
return 0

with pytest.raises(qml.QuantumFunctionError, match="does not support backprop"):

@qml.qnode(DummyDev(wires=2), diff_method="backprop")
def _(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

def test_error_unsupported_diff_method(self):
"""Test an error is raised for a non-backprop diff method."""

@qml.qnode(qml.device("default.qubit", wires=2), diff_method="parameter-shift")
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

with pytest.raises(
NotImplementedError, match="diff_method parameter-shift not yet implemented."
):
jax.grad(circuit)(0.5)

@pytest.mark.parametrize("diff_method", ("best", "backprop"))
def test_default_qubit_backprop(self, diff_method):
"""Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule."""

@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

x = 0.9
xt = -0.6
jvp = jax.jvp(circuit, (x,), (xt,))
assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt))

def test_no_gradients_with_lightning(self):
"""Test that we get an error if we try and differentiate a lightning execution."""

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

x = 0.9
xt = -0.6
jvp = jax.jvp(circuit, (x,), (xt,))
assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt))
with pytest.raises(NotImplementedError, match=r"diff_method adjoint not yet implemented"):
jax.grad(circuit)(0.5)


def test_qnode_jit():
Expand Down

0 comments on commit 90dc57c

Please sign in to comment.