Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove AllWires validation and associated tests #6373

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@

<h3>Breaking changes 💔</h3>

* `AllWires` validation in `QNode.construct` has been removed.
[(#6373)](https://github.com/PennyLaneAI/pennylane/pull/6373)

* The `simplify` argument in `qml.Hamiltonian` and `qml.ops.LinearCombination` has been removed.
Instead, `qml.simplify()` can be called on the constructed operator.
[(#6279)](https://github.com/PennyLaneAI/pennylane/pull/6279)
Expand Down
10 changes: 0 additions & 10 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,16 +861,6 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
"All measurements must be returned in the order they are measured."
)

num_wires = len(self.tape.wires) if not self.device.wires else len(self.device.wires)
for obj in self.tape.operations + self.tape.observables:
if (
getattr(obj, "num_wires", None) is qml.operation.WiresEnum.AllWires
and obj.wires
and len(obj.wires) != num_wires
):
# check here only if enough wires
raise qml.QuantumFunctionError(f"Operator {obj.name} must act on all wires")

def _execution_component(self, args: tuple, kwargs: dict) -> qml.typing.Result:
"""Construct the transform program and execute the tapes. Helper function for ``__call__``

Expand Down
22 changes: 0 additions & 22 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,28 +1070,6 @@ class DummyObserv(qml.operation.Observable):
class TestOperatorIntegration:
"""Integration tests for the Operator class"""

def test_all_wires_defined_but_init_with_one(self):
"""Test that an exception is raised if the class is defined with ALL wires,
but then instantiated with only one"""

dev1 = qml.device("default.qubit", wires=2)

class DummyOp(qml.operation.Operation):
r"""Dummy custom operator"""

num_wires = qml.operation.WiresEnum.AllWires

@qml.qnode(dev1)
def circuit():
DummyOp(wires=[0])
return qml.expval(qml.PauliZ(0))

with pytest.raises(
qml.QuantumFunctionError,
match=f"Operator {DummyOp.__name__} must act on all wires",
):
circuit()

def test_pow_method_with_non_numeric_power_raises_error(self):
"""Test that when raising an Operator to a power that is not a number raises
a ValueError."""
Expand Down
55 changes: 0 additions & 55 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,61 +645,6 @@ def func(x, y):
assert qn.qtape.operations == contents[0:3]
assert qn.qtape.measurements == contents[3:]

def test_operator_all_device_wires(self, monkeypatch, tol):
"""Test that an operator that must act on all wires raises an error
if the operator wires are not the device wires (when device wires
are defined)."""
monkeypatch.setattr(qml.RX, "num_wires", qml.operation.AllWires)

def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

dev = qml.device("default.qubit", wires=2)
qn = QNode(circuit, dev)

with pytest.raises(qml.QuantumFunctionError, match="Operator RX must act on all wires"):
qn(0.5)

dev = qml.device("default.qubit", wires=1)
qn = QNode(circuit, dev)
assert np.allclose(qn(0.5), np.cos(0.5), atol=tol, rtol=0)

def test_all_wires_new_device(self):
"""Test that an operator on AllWires must act on all device wires if they
are specified, and otherwise all tape wires, with the new device API."""

assert qml.GlobalPhase.num_wires == qml.operation.AllWires

dev = qml.device("default.qubit")
dev_with_wires = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit1(x):
qml.GlobalPhase(x, wires=0)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# fails when GlobalPhase is a strict subset of all tape wires
with pytest.raises(qml.QuantumFunctionError, match="GlobalPhase must act on all wires"):
circuit1(0.5)

@qml.qnode(dev)
def circuit2(x):
qml.GlobalPhase(x, wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# passes here, does not care for device.wires because it has none
assert circuit2(0.5) == 1

@qml.qnode(dev_with_wires)
def circuit3(x):
qml.GlobalPhase(x, wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# fails when GlobalPhase is a subset of device wires, even if it acts on all tape wires
with pytest.raises(qml.QuantumFunctionError, match="GlobalPhase must act on all wires"):
circuit3(0.5)

@pytest.mark.jax
def test_jit_counts_raises_error(self):
"""Test that returning counts in a quantum function with trainable parameters while
Expand Down
41 changes: 0 additions & 41 deletions tests/test_qnode_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,47 +529,6 @@ def func(x, y):
assert qn.qtape.operations == contents[0:3]
assert qn.qtape.measurements == contents[3:]

def test_operator_all_wires(self, monkeypatch, tol):
"""Test that an operator that must act on all wires
does, or raises an error."""
monkeypatch.setattr(qml.RX, "num_wires", qml.operation.AllWires)

def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

dev = qml.device("default.mixed", wires=2)
qn = QNode(circuit, dev)

with pytest.raises(qml.QuantumFunctionError, match="Operator RX must act on all wires"):
qn(0.5)

dev = qml.device("default.mixed", wires=1)
qn = QNode(circuit, dev)
assert np.allclose(qn(0.5), np.cos(0.5), atol=tol, rtol=0)

def test_all_wires_new_device(self):
"""Test that an operator must act on all tape wires with the new device API."""

def circuit1(x):
qml.GlobalPhase(x, wires=0)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

dev = qml.devices.DefaultQubit() # TODO: add wires, change comment below
qn = QNode(circuit1, dev)

# fails when GlobalPhase is a strict subset of all tape wires
with pytest.raises(qml.QuantumFunctionError, match="GlobalPhase must act on all wires"):
qn(0.5)

@qml.qnode(dev)
def circuit2(x):
qml.GlobalPhase(x, wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# passes here, does not care for device.wires because it has none
assert circuit2(0.5) == 1

@pytest.mark.jax
def test_jit_counts_raises_error(self):
"""Test that returning counts in a quantum function with trainable parameters while
Expand Down
Loading