Skip to content

Commit

Permalink
Revert end-to-end jitting with default qubit (#6869)
Browse files Browse the repository at this point in the history
**Context:**

In #6788 , we started allowing executions on default qubit to be jitted
from end-to-end. Unfortunately, we found that the compilation overheads
on these executions can get very, very expensive. So until we find a way
to reduce the compilation overheads, we are using pure callbacks and
conversion to numpy.

**Description of the Change:**

Default to `convert_to_numpy=False`, and xfail relevant tests. This
change can be undone once we figure out how to resolve the compilation
issue.

**Benefits:**

Reduced compilation overheads, because the execution itself does not get
compiled.

**Possible Drawbacks:**

Slow down on post-compiled workflows. No way to jit an entire execution
on default qubit.

**Related GitHub Issues:**

---------

Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
  • Loading branch information
albi3ro and PietropaoloFrisoni authored Jan 23, 2025
1 parent 63cca88 commit 875ae11
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 17 deletions.
8 changes: 4 additions & 4 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
'parameter-shift'
```

* Finite shot and parameter-shift executions on `default.qubit` can now
be natively jitted end-to-end, leading to performance improvements.
Devices can now configure whether or not ML framework data is sent to them
via an `ExecutionConfig.convert_to_numpy` parameter.
* Devices can now configure whether or not ML framework data is sent to them
via an `ExecutionConfig.convert_to_numpy` parameter. This is not used on
`default.qubit` due to compilation overheads when jitting.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
[(#6869)](https://github.com/PennyLaneAI/pennylane/pull/6869)

* The coefficients of observables now have improved differentiability.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)
Expand Down
17 changes: 10 additions & 7 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,15 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}

jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
updated_values["convert_to_numpy"] = (
execution_config.interface not in jax_interfaces
or execution_config.gradient_method == "adjoint"
# need numpy to use caching, and need caching higher order derivatives
or execution_config.derivative_order > 1
)
# uncomment once compilation overhead with jitting improved
# TODO: [sc-82874]
# jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
# updated_values["convert_to_numpy"] = (
# execution_config.interface not in jax_interfaces
# or execution_config.gradient_method == "adjoint"
# # need numpy to use caching, and need caching higher order derivatives
# or execution_config.derivative_order > 1
# )
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
Expand Down Expand Up @@ -643,6 +645,7 @@ def execute(
prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]

if max_workers is None:

return tuple(
_simulate_wrapper(
c,
Expand Down
2 changes: 2 additions & 0 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def shape(self, shots: Optional[int] = None, num_device_wires: int = 0) -> tuple
)
if self.obs:
num_values_per_shot = 1 # one single eigenvalue
elif self.mv is not None:
num_values_per_shot = 1 if isinstance(self.mv, MeasurementValue) else len(self.mv)
else:
# one value per wire
num_values_per_shot = len(self.wires) if len(self.wires) > 0 else num_device_wires
Expand Down
5 changes: 4 additions & 1 deletion tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,10 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
assert "pure_callback" not in jaxpr
# will change once we solve the compilation overhead issue
# assert "pure_callback" not in jaxpr
# TODO: [sc-82874]
assert "pure_callback" in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,17 @@ def circuit(x):
assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
def test_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

# separate test so we can easily update it once we solve the
# compilation overhead issue
# TODO: [sc-82874]
dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy
assert processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
Expand Down
1 change: 1 addition & 0 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,7 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()

@pytest.mark.xfail # TODO: [sc-82874]
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
Expand Down
6 changes: 4 additions & 2 deletions tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def circuit():
@pytest.mark.jax
@pytest.mark.parametrize("shots", (None, 500))
@pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1)))
@pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2]))
@pytest.mark.parametrize("params", (np.pi / 2, [np.pi / 2, np.pi / 2, np.pi / 2]))
def test_integration_jax(self, tol_stochastic, shots, obs, params, seed):
"""Test the probability is correct for a known state preparation when jitted with JAX."""
jax = pytest.importorskip("jax")
Expand All @@ -359,7 +359,9 @@ def circuit(x):
# expected probability, using [00, 01, 10, 11]
# ordering, is [0.5, 0.5, 0, 0]

assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params))
# TODO: [sc-82874]
# revert once we are able to jit end to end without extreme compilation overheads
assert "pure_callback" in str(jax.make_jaxpr(circuit)(params))

res = jax.jit(circuit)(params)
expected = np.array([0.5, 0.5, 0, 0])
Expand Down

0 comments on commit 875ae11

Please sign in to comment.