Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where unexpected queuing occurs in qml.ctrl among other functions #6284

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/development/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Release notes

This page contains the release notes for PennyLane.

.. mdinclude:: ../releases/changelog-dev.md

.. mdinclude:: ../releases/changelog-0.39.0.md

.. mdinclude:: ../releases/changelog-0.38.0.md
Expand Down
24 changes: 24 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
:orphan:

# Release 0.40.0-dev (development release)

<h3>New features since last release</h3>

<h3>Improvements 🛠</h3>

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>

<h3>Documentation 📝</h3>

<h3>Bug fixes 🐛</h3>

* Fixes a bug where `qml.ctrl` and `qml.adjoint` queued extra operators if they were defined as the arguments.
[(#6284)](https://github.com/PennyLaneAI/pennylane/pull/6284)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Guillermo Alonso
1 change: 1 addition & 0 deletions pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
Version number (major.minor.patch[-label])
"""


__version__ = "0.39.0"
7 changes: 0 additions & 7 deletions pennylane/gradients/pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,6 @@ def _psr_and_contract(res_list, cjacs, int_prefactor):
# Single measurement without shot vector
return _psr_and_contract(results, cjacs, int_prefactor)

# Multiple measurements with shot vector. Not supported with broadcasting yet.
if use_broadcasting:
# TODO: Remove once #2690 is resolved
raise NotImplementedError(
"Broadcasting, multiple measurements and shot vectors are currently not "
"supported all simultaneously by stoch_pulse_grad."
)
return tuple(
tuple(_psr_and_contract(_r, cjacs, int_prefactor) for _r in zip(*r)) for r in zip(*results)
)
Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ def _adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
@wraps(qfunc)
def wrapper(*args, **kwargs):
qscript = make_qscript(qfunc)(*args, **kwargs)

leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
_ = [qml.QueuingManager.remove(l) for l in leaves if isinstance(l, Operator)]

if lazy:
adjoint_ops = [Adjoint(op) for op in reversed(qscript.operations)]
else:
Expand Down
3 changes: 3 additions & 0 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def _ctrl_transform(op, control, control_values, work_wires):
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)

leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
_ = [qml.QueuingManager.remove(l) for l in leaves if isinstance(l, Operator)]

# flip control_values == 0 wires here, so we don't have to do it for each individual op.
flip_control_on_zero = (len(qscript) > 1) and (control_values is not None)
op_control_values = None if flip_control_on_zero else control_values
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pytest>=7.1.2
pytest-cov>=3.0.0
pytest-mock>=3.7.0
pytest-xdist>=2.5.0
pytest-rng
flaky>=3.7.0
pytest-forked>=1.4.0
pytest-benchmark
Expand Down
11 changes: 6 additions & 5 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,13 +653,14 @@ def test_circuit_consts(self, pred, arg, expected):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.local_salt(1)
@pytest.mark.parametrize("reset", [True, False])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("shots", [None, 20])
def test_mcm_predicate_execution(self, reset, postselect, shots):
def test_mcm_predicate_execution(self, reset, postselect, shots, seed):
"""Test that QNodes executed with mid-circuit measurement predicates for
qml.cond give correct results."""
device = qml.device("default.qubit", wires=3, shots=shots, seed=jax.random.PRNGKey(1234))
device = qml.device("default.qubit", wires=3, shots=shots, seed=jax.random.PRNGKey(seed))

def true_fn(arg):
qml.RX(arg, 0)
Expand All @@ -682,7 +683,7 @@ def f(x, y):

assert np.allclose(res, expected), f"Expected {expected}, but got {res}"

@pytest.mark.parametrize("shots", [None, 100])
@pytest.mark.parametrize("shots", [None, 300])
@pytest.mark.parametrize(
"params, expected",
# The parameters used here will essentially apply a PauliX just before mid-circuit
Expand All @@ -696,11 +697,11 @@ def f(x, y):
([0, 0, 0, 0], (1 / np.sqrt(2), 0, 0, 1)), # false_fn, PauliZ basis
],
)
def test_mcm_predicate_execution_with_elifs(self, params, expected, shots, tol):
def test_mcm_predicate_execution_with_elifs(self, params, expected, shots, tol, seed):
"""Test that QNodes executed with mid-circuit measurement predicates for
qml.cond give correct results when there are also elifs present."""
# pylint: disable=expression-not-assigned
device = qml.device("default.qubit", wires=5, shots=shots, seed=jax.random.PRNGKey(10))
device = qml.device("default.qubit", wires=5, shots=shots, seed=jax.random.PRNGKey(seed))

def true_fn():
# Adjoint Hadamard diagonalizing gates to get Hadamard basis state
Expand Down
26 changes: 14 additions & 12 deletions tests/capture/test_capture_mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,12 @@ class TestMidMeasureExecute:
@pytest.mark.parametrize("reset", [True, False])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shots, mp_fn):
def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shots, mp_fn, seed):
"""Test that circuits with mid-circuit measurements can be executed in a QNode."""
if shots is None and mp_fn is qml.sample:
pytest.skip("Cannot measure samples in analytic mode")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x):
Expand All @@ -340,7 +340,7 @@ def f(x):
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
@pytest.mark.parametrize("multi_mcm", [True, False])
def test_circuit_with_terminal_measurement_execution(
self, phi, get_device, shots, mp_fn, multi_mcm
self, phi, get_device, shots, mp_fn, multi_mcm, seed
):
"""Test that circuits with mid-circuit measurements that also collect statistics
on the mid-circuit measurements can be executed in a QNode."""
Expand All @@ -350,7 +350,7 @@ def test_circuit_with_terminal_measurement_execution(
if multi_mcm and mp_fn in (qml.expval, qml.var):
pytest.skip("Cannot measure sequences of MCMs with expval or var")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x, y):
Expand All @@ -364,13 +364,13 @@ def f(x, y):

@pytest.mark.xfail
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
def test_circuit_with_boolean_arithmetic_execution(self, phi, get_device, shots, mp_fn):
def test_circuit_with_boolean_arithmetic_execution(self, phi, get_device, shots, mp_fn, seed):
"""Test that circuits that apply boolean logic to mid-circuit measurement values
can be executed."""
if shots is None and mp_fn is qml.sample:
pytest.skip("Cannot measure samples in analytic mode")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x, y):
Expand All @@ -386,13 +386,13 @@ def f(x, y):

@pytest.mark.xfail
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
def test_circuit_with_classical_processing_execution(self, phi, get_device, shots, mp_fn):
def test_circuit_with_classical_processing_execution(self, phi, get_device, shots, mp_fn, seed):
"""Test that circuits that apply non-boolean operations to mid-circuit measurement
values can be executed."""
if shots is None and mp_fn is qml.sample:
pytest.skip("Cannot measure samples in analytic mode")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x, y):
Expand All @@ -409,13 +409,15 @@ def f(x, y):
@pytest.mark.xfail
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
@pytest.mark.parametrize("fn", [jnp.sin, jnp.sqrt, jnp.log, jnp.exp])
def mid_measure_processed_with_jax_numpy_execution(self, phi, fn, get_device, shots, mp_fn):
def mid_measure_processed_with_jax_numpy_execution(
self, phi, fn, get_device, shots, mp_fn, seed
):
"""Test that a circuit containing mid-circuit measurements processed using jax.numpy
can be executed."""
if shots is None and mp_fn is qml.sample:
pytest.skip("Cannot measure samples in analytic mode")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x):
Expand All @@ -428,13 +430,13 @@ def f(x):

@pytest.mark.xfail
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
def test_mid_measure_as_gate_parameter_execution(self, phi, get_device, shots, mp_fn):
def test_mid_measure_as_gate_parameter_execution(self, phi, get_device, shots, mp_fn, seed):
"""Test that mid-circuit measurements (simple or classical processed) used as gate
parameters can be executed."""
if shots is None and mp_fn is qml.sample:
pytest.skip("Cannot measure samples in analytic mode")

dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(12345))
dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed))

@qml.qnode(dev)
def f(x):
Expand Down
12 changes: 6 additions & 6 deletions tests/capture/test_measurements_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,14 @@ def f():


@pytest.mark.parametrize("x64_mode", (True, False))
def test_shadow_expval(x64_mode):
def test_shadow_expval(x64_mode, seed):
"""Test that the shadow expval of an observable can be captured."""

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)

def f():
return qml.shadow_expval(qml.X(0), seed=887, k=4)
return qml.shadow_expval(qml.X(0), seed=seed, k=4)

jaxpr = jax.make_jaxpr(f)()

Expand All @@ -561,7 +561,7 @@ def f():

assert jaxpr.eqns[1].primitive == ShadowExpvalMP._obs_primitive
assert jaxpr.eqns[0].outvars == jaxpr.eqns[1].invars
assert jaxpr.eqns[1].params == {"seed": 887, "k": 4}
assert jaxpr.eqns[1].params == {"seed": seed, "k": 4}

am = jaxpr.eqns[1].outvars[0].aval
assert isinstance(am, AbstractMeasurement)
Expand Down Expand Up @@ -638,19 +638,19 @@ def f(w1, w2):
jax.config.update("jax_enable_x64", initial_mode)


def test_ClassicalShadow():
def test_ClassicalShadow(seed):
"""Test that the classical shadow measurement can be captured."""

def f():
return qml.classical_shadow(wires=(0, 1, 2), seed=95)
return qml.classical_shadow(wires=(0, 1, 2), seed=seed)

jaxpr = jax.make_jaxpr(f)()

jaxpr = jax.make_jaxpr(f)()
assert len(jaxpr.eqns) == 1

assert jaxpr.eqns[0].primitive == ClassicalShadowMP._wires_primitive
assert jaxpr.eqns[0].params == {"seed": 95}
assert jaxpr.eqns[0].params == {"seed": seed}
assert len(jaxpr.eqns[0].invars) == 3
mp = jaxpr.eqns[0].outvars[0].aval
assert isinstance(mp, AbstractMeasurement)
Expand Down
50 changes: 38 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ class DummyDevice(DefaultGaussian):
_operation_map["Kerr"] = lambda *x, **y: np.identity(2)


@pytest.fixture(autouse=True)
def set_numpy_seed():
np.random.seed(9872653)
yield


@pytest.fixture(scope="session")
def tol():
"""Numerical tolerance for equality tests."""
Expand Down Expand Up @@ -106,12 +100,6 @@ def qutrit_device_3_wires(request):
#######################################################################


@pytest.fixture(scope="module", params=[1, 2, 3])
def seed(request):
"""Different seeds."""
return request.param


@pytest.fixture(scope="function")
def mock_device(monkeypatch):
"""A mock instance of the abstract Device class"""
Expand Down Expand Up @@ -190,6 +178,44 @@ def legacy_opmath_only():
pytest.skip("This test exclusively tests legacy opmath")


#######################################################################


@pytest.fixture(autouse=True)
def restore_global_seed():
original_state = np.random.get_state()
yield
np.random.set_state(original_state)


@pytest.fixture
def seed(request):
"""An integer random number generator seed

This fixture overrides the ``seed`` fixture provided by pytest-rng, adding the flexibility
of locally getting a new seed for a test case by applying the ``local_salt`` marker. This is
useful when the seed from pytest-rng happens to be a bad seed that causes your test to fail.

.. code_block:: python

@pytest.mark.local_salt(42)
def test_something(seed):
...

The value passed to ``local_salt`` needs to be an integer.

"""

fixture_manager = request._fixturemanager # pylint:disable=protected-access
fixture_defs = fixture_manager.getfixturedefs("seed", request.node)
original_fixture_def = fixture_defs[0] # the original seed fixture provided by pytest-rng
original_seed = original_fixture_def.func(request)
marker = request.node.get_closest_marker("local_salt")
if marker and marker.args:
return original_seed + marker.args[0]
return original_seed


#######################################################################

try:
Expand Down
Loading
Loading