Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance for default.qubit.compute_vjp #4841

Merged
merged 19 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
`qml.QNode` or `qml.execute`.
[(#4557)](https://github.com/PennyLaneAI/pennylane/pull/4557)
[(#4654)](https://github.com/PennyLaneAI/pennylane/pull/4654)
[(#4841)](https://github.com/PennyLaneAI/pennylane/pull/4841)

```pycon
>>> dev = qml.device('default.qubit')
Expand Down
19 changes: 18 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def name(self):
"""The name of the device."""
return "default.qubit"

_state_cache: Optional[dict] = None
"""
A cache to store the "pre-rotated state" for reuse between the forward pass call to ``execute`` and
subsequent calls to ``compute_vjp``. ``None`` indicates that no caching is required.
"""

# pylint:disable = too-many-arguments
def __init__(
self,
Expand Down Expand Up @@ -469,6 +475,7 @@ def execute(
circuits = [circuits]

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
execution_config.interface
if execution_config.gradient_method in {"backprop", None}
Expand All @@ -482,6 +489,7 @@ def execute(
prng_key=self._prng_key,
debugger=self._debugger,
interface=interface,
state_cache=self._state_cache,
)
for c in circuits
)
Expand Down Expand Up @@ -736,7 +744,16 @@ def compute_vjp(

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
res = tuple(adjoint_vjp(circuit, cots) for circuit, cots in zip(circuits, cotangents))

def _state(circuit):
return (
None if self._state_cache is None else self._state_cache.get(circuit.hash, None)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
)

res = tuple(
adjoint_vjp(circuit, cots, state=_state(circuit))
for circuit, cots in zip(circuits, cotangents)
)
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
Expand Down
18 changes: 14 additions & 4 deletions pennylane/devices/qubit/adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,20 @@ def adjoint_vjp(tape: QuantumTape, cotangents: Tuple[Number], state=None):

ket = state if state is not None else get_final_state(tape)[0]

if np.shape(cotangents) == tuple():
cotangents = (cotangents,)
obs = qml.dot(cotangents, tape.observables)
bra = apply_operation(obs, ket)
cotangents = (cotangents,) if qml.math.shape(cotangents) == tuple() else cotangents
new_cotangents, new_observables = [], []
for c, o in zip(cotangents, tape.observables):
if not np.allclose(c, 0.0):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
new_cotangents.append(c)
new_observables.append(o)
if len(new_cotangents) == 0:
return tuple(0.0 for _ in tape.trainable_params)
obs = qml.dot(new_cotangents, new_observables)
if obs._pauli_rep is not None:
flat_bra = obs._pauli_rep.dot(ket.flatten(), wire_order=list(range(tape.num_wires)))
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
bra = flat_bra.reshape(ket.shape)
else:
bra = apply_operation(obs, ket)

param_number = len(tape.get_parameters(trainable_only=False, operations_only=True)) - 1
trainable_param_number = len(tape.trainable_params) - 1
Expand Down
13 changes: 12 additions & 1 deletion pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Simulate a quantum script."""
# pylint: disable=protected-access
from typing import Optional

from numpy.random import default_rng
import numpy as np

Expand Down Expand Up @@ -197,8 +199,14 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non
return results


# pylint: disable=too-many-arguments
def simulate(
circuit: qml.tape.QuantumScript, rng=None, prng_key=None, debugger=None, interface=None
circuit: qml.tape.QuantumScript,
rng=None,
prng_key=None,
debugger=None,
interface=None,
state_cache: Optional[dict] = None,
) -> Result:
"""Simulate a single quantum script.
Expand All @@ -214,6 +222,7 @@ def simulate(
generated. Only for simulation using JAX.
debugger (_Debugger): The debugger to use
interface (str): The machine learning interface to create the initial state with
state_cache (Optional[dict]): A dictionary mapping the hash of a circuit to the pre-rotated state. Used to pass the state between forward passes and VJP calculations.
Returns:
tuple(TensorLike): The results of the simulation
Expand All @@ -229,4 +238,6 @@ def simulate(
"""
state, is_state_batched = get_final_state(circuit, debugger=debugger, interface=interface)
if state_cache is not None:
state_cache[circuit.hash] = state
return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key)
87 changes: 49 additions & 38 deletions tests/devices/qubit/test_adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class TestAdjointJacobian:
def test_custom_wire_labels(self, tol):
"""Test that adjoint_jacobian works as expected when custom wire labels are used."""
qs = QuantumScript(
[qml.RX(0.123, wires="a"), qml.RY(0.456, wires="b")], [qml.expval(qml.PauliX("a"))]
[qml.RX(0.123, wires="a"), qml.RY(0.456, wires="b")],
[qml.expval(qml.PauliX("a"))],
trainable_params=[0, 1],
)
qs.trainable_params = {0, 1}

calculated_val = adjoint_jacobian(qs)

Expand All @@ -49,11 +50,11 @@ def test_pauli_rotation_gradient(self, G, theta, tol):

prep_op = qml.StatePrep(np.array([1.0, -1.0], requires_grad=False) / np.sqrt(2), wires=0)
qs = QuantumScript(
ops=[prep_op, G(theta, wires=[0])], measurements=[qml.expval(qml.PauliZ(0))]
ops=[prep_op, G(theta, wires=[0])],
measurements=[qml.expval(qml.PauliZ(0))],
trainable_params=[1],
)

qs.trainable_params = {1}

calculated_val = adjoint_jacobian(qs)
# compare to finite differences
tapes, fn = qml.gradients.finite_diff(qs)
Expand All @@ -72,9 +73,9 @@ def test_Rot_gradient(self, theta, tol):
qs = QuantumScript(
ops=[prep_op, qml.Rot(*params, wires=[0])],
measurements=[qml.expval(qml.PauliZ(0))],
trainable_params=[1, 2, 3],
)

qs.trainable_params = {1, 2, 3}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs = qs_valid[0]

Expand Down Expand Up @@ -110,8 +111,7 @@ def test_gradients(self, op, obs, tol):
]
measurements = [qml.expval(obs(wires=0)), qml.expval(qml.PauliZ(wires=1))]

qs = QuantumScript(ops, measurements)
qs.trainable_params = set(range(1, 1 + op.num_params))
qs = QuantumScript(ops, measurements, trainable_params=list(range(1, 1 + op.num_params)))

qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]
Expand Down Expand Up @@ -188,9 +188,9 @@ def test_gradient_gate_with_multiple_parameters(self, tol):
qs = QuantumScript(
[qml.RX(0.4, wires=[0]), qml.Rot(x, y, z, wires=[0]), qml.RY(-0.2, wires=[0])],
[qml.expval(qml.PauliZ(0))],
trainable_params=[1, 2, 3],
)

qs.trainable_params = {1, 2, 3}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -218,9 +218,9 @@ def test_state_prep(self, prep_op, tol):
qs = QuantumScript(
[prep_op, qml.RX(0.4, wires=[0]), qml.Rot(x, y, z, wires=[0]), qml.RY(-0.2, wires=[0])],
[qml.expval(qml.PauliZ(0))],
trainable_params=[2, 3, 4],
)

qs.trainable_params = {2, 3, 4}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -248,9 +248,9 @@ def test_gradient_of_tape_with_hermitian(self, tol):
qml.CNOT(wires=[1, 2]),
],
[qml.expval(qml.Hermitian(mx, wires=[0, 2]))],
trainable_params=[0, 1, 2],
)

qs.trainable_params = {0, 1, 2}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -279,9 +279,9 @@ def test_gradient_of_tape_with_tensor(self, tol):
qml.CNOT(wires=[1, 2]),
],
[qml.expval(qml.PauliX(0) @ qml.PauliY(2))],
trainable_params=[0, 1, 2],
)

qs.trainable_params = {0, 1, 2}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand All @@ -304,8 +304,7 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

grad_adjoint = adjoint_jacobian(qs)
expected = [-np.sin(par)]
Expand All @@ -319,8 +318,7 @@ class TestAdjointJVP:
def test_single_param_single_obs(self, tangents, tol):
"""Test JVP is correct for a single parameter and observable"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))])
qs.trainable_params = {0}
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0])

actual = adjoint_jvp(qs, tangents)

Expand All @@ -331,8 +329,11 @@ def test_single_param_single_obs(self, tangents, tol):
def test_single_param_multi_obs(self, tangents, tol):
"""Test JVP is correct for a single parameter and multiple observables"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))])
qs.trainable_params = {0}
qs = QuantumScript(
[qml.RY(x, 0)],
[qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))],
trainable_params=[0],
)

actual = adjoint_jvp(qs, tangents)
assert isinstance(actual, tuple)
Expand All @@ -347,8 +348,9 @@ def test_multi_param_single_obs(self, tangents, tol):
x = np.array(0.654)
y = np.array(1.221)

qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))])
qs.trainable_params = {0, 1}
qs = QuantumScript(
[qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))], trainable_params=[0, 1]
)

actual = adjoint_jvp(qs, tangents)

Expand All @@ -364,8 +366,7 @@ def test_multi_param_multi_obs(self, tangents, tol):
y = np.array(1.221)

obs = [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))]
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs, trainable_params=[0, 1])

actual = adjoint_jvp(qs, tangents)
assert isinstance(actual, tuple)
Expand Down Expand Up @@ -393,8 +394,7 @@ def test_custom_wire_labels(self, tangents, wires, tol):
qml.expval(qml.PauliY(wires[1])),
qml.expval(qml.PauliX(wires[0])),
]
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs, trainable_params=[0, 1])
assert qs.wires.tolist() == wires

actual = adjoint_jvp(qs, tangents)
Expand All @@ -416,8 +416,7 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

jvp_adjoint = adjoint_jvp(qs, tangents)
expected = [-np.sin(par) * tangents[0]]
Expand All @@ -431,8 +430,7 @@ class TestAdjointVJP:
def test_single_param_single_obs(self, cotangents, tol):
"""Test VJP is correct for a single parameter and observable"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))])
qs.trainable_params = {0}
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0])

actual = adjoint_vjp(qs, cotangents)

Expand All @@ -444,8 +442,11 @@ def test_single_param_single_obs(self, cotangents, tol):
def test_single_param_multi_obs(self, cotangents, tol):
"""Test VJP is correct for a single parameter and multiple observables"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))])
qs.trainable_params = {0}
qs = QuantumScript(
[qml.RY(x, 0)],
[qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))],
trainable_params=[0],
)

actual = adjoint_vjp(qs, cotangents)

Expand All @@ -458,8 +459,9 @@ def test_multi_param_single_obs(self, cotangents, tol):
x = np.array(0.654)
y = np.array(1.221)

qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))])
qs.trainable_params = {0, 1}
qs = QuantumScript(
[qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))], trainable_params=[0, 1]
)

actual = adjoint_vjp(qs, cotangents)
assert isinstance(actual, tuple)
Expand All @@ -477,8 +479,7 @@ def test_multi_param_multi_obs(self, cotangents, tol):
y = np.array(1.221)

obs = [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))]
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs, trainable_params=[0, 1])

actual = adjoint_vjp(qs, cotangents)
assert isinstance(actual, tuple)
Expand Down Expand Up @@ -508,8 +509,7 @@ def test_custom_wire_labels(self, cotangents, wires, tol):
qml.expval(qml.PauliY(wires[1])),
qml.expval(qml.PauliX(wires[0])),
]
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs, trainable_params=[0, 1])
assert qs.wires.tolist() == wires

actual = adjoint_vjp(qs, cotangents)
Expand All @@ -531,9 +531,20 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

vjp_adjoint = adjoint_vjp(qs, cotangents)
expected = [-np.sin(par) * cotangents[0]]
assert np.allclose(vjp_adjoint, expected)

def test_hermitian_expval(self):
    """Check that adjoint_vjp handles the expectation value of a ``qml.Hermitian`` observable."""

    angle = 1.2
    # Single-entry cotangent, matching the single measurement of the script.
    cotangent = (0.5,)
    # Observable built from the explicit 2x2 matrix diag(1, -1) acting on wire 0.
    matrix = np.array([[1, 0], [0, -1]])
    observable = qml.Hermitian(matrix, wires=0)

    script = QuantumScript(
        [qml.RX(angle, wires=0)],
        [qml.expval(observable)],
        trainable_params=[0],
    )

    # Unpacking into a one-element target also checks that exactly one entry is returned.
    (vjp_value,) = adjoint_vjp(script, cotangent)
    expected = -0.5 * np.sin(angle)
    assert qml.math.allclose(vjp_value, expected)