Skip to content

Commit

Permalink
Add Expr support in DAGCircuit.substitute_node
Browse files Browse the repository at this point in the history
As part of this minor refactor, this updates the logic to no longer
silently override conditions on the replacement `op`.  The method gains
a `propagate_condition` argument analogous to the same argument in
`substitute_node_with_dag`, which can be set `False` to specify that the
caller is aware that the new operation should implement the same
conditional logic.
  • Loading branch information
jakelishman committed Jul 4, 2023
1 parent ac942ee commit e70058f
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 9 deletions.
49 changes: 40 additions & 9 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ def edge_weight_map(wire):

return {k: self._multi_graph[v] for k, v in node_map.items()}

def substitute_node(self, node, op, inplace=False):
def substitute_node(self, node, op, inplace=False, propagate_condition=True):
"""Replace an DAGOpNode with a single operation. qargs, cargs and
conditions for the new operation will be inferred from the node to be
replaced. The new operation will be checked to match the shape of the
Expand All @@ -1273,6 +1273,10 @@ def substitute_node(self, node, op, inplace=False):
inplace (bool): Optional, default False. If True, existing DAG node
will be modified to include op. Otherwise, a new DAG node will
be used.
propagate_condition (bool): Optional, default True. If True, a condition on the
``node`` to be replaced will be applied to the new ``op``. This is the legacy
behaviour. If either node is a control-flow operation, this will be ignored. If
the ``op`` already has a condition, :exc:`.DAGCircuitError` is raised.
Returns:
DAGOpNode: the new node containing the added operation.
Expand All @@ -1293,23 +1297,50 @@ def substitute_node(self, node, op, inplace=False):
)
)

# This might include wires that are inherent to the node, like in its `condition` or
# `target` fields, so might be wider than `node.op.num_{qu,cl}bits`.
current_wires = {wire for _, _, wire in self.edges(node)}
new_wires = set(node.qargs) | set(node.cargs)
if (new_condition := getattr(op, "condition", None)) is not None:
new_wires.update(condition_resources(new_condition).clbits)
elif isinstance(op, SwitchCaseOp):
if isinstance(op.target, Clbit):
new_wires.add(op.target)
elif isinstance(op.target, ClassicalRegister):
new_wires.update(op.target)
else:
new_wires.update(node_resources(op.target).clbits)

if propagate_condition and not (
isinstance(node.op, ControlFlowOp) or isinstance(op, ControlFlowOp)
):
if new_condition is not None:
raise DAGCircuitError(
"Cannot propagate a condition to an operation that already has one."
)
if (old_condition := getattr(node.op, "condition", None)) is not None:
if not isinstance(op, Instruction):
raise DAGCircuitError("Cannot add a condition on a generic Operation.")
op.condition = old_condition
new_wires.update(condition_resources(old_condition).clbits)

if new_wires != current_wires:
# The new wires must be a non-strict subset of the current wires; if they add new wires,
# we'd not know where to cut the existing wire to insert the new dependency.
raise DAGCircuitError(
f"New operation '{op}' does not span the same wires as the old node '{node}'."
f" New wires: {new_wires}, old wires: {current_wires}."
)

if inplace:
if op.name != node.op.name:
self._increment_op(op)
self._decrement_op(node.op)
save_condition = getattr(node.op, "condition", None)
node.op = op
if save_condition and not isinstance(op, Instruction):
raise DAGCircuitError("Cannot add a condition on a generic Operation.")
node.op.condition = save_condition
return node

new_node = copy.copy(node)
save_condition = getattr(new_node.op, "condition", None)
new_node.op = op
if save_condition and not isinstance(new_node.op, Instruction):
raise DAGCircuitError("Cannot add a condition on a generic Operation.")
new_node.op.condition = save_condition
self._multi_graph[node._node_id] = new_node
if op.name != node.op.name:
self._increment_op(op)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
features:
- |
:meth:`.DAGCircuit.substitute_node` gained a ``propagate_condition`` keyword argument that is
analogous to the same argument in :meth:`~.DAGCircuit.substitute_node_with_dag`. Setting this
to ``False`` opts out of the legacy behaviour of copying a condition on the ``node`` onto the
new ``op`` that is replacing it.
This option is ignored for general control-flow operations, which will never propagate their
condition, nor accept a condition from another node.
fixes:
- |
:meth:`.DAGCircuit.substitute_node` will no longer silently overwrite an existing condition on
the given replacement ``op``. If ``propagate_condition`` is set to ``True`` (the default), a
:exc:`.DAGCircuitError` will be raised instead.
114 changes: 114 additions & 0 deletions test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,120 @@ def test_substituting_node_preserves_parents_children(self, inplace):
self.assertEqual(dag.descendants(replacement_node), descendants)
self.assertEqual(replacement_node is node_to_be_replaced, inplace)

@data(True, False)
def test_refuses_to_overwrite_condition(self, inplace):
"""Test that the method will not forcibly overwrite a condition."""
qr = QuantumRegister(1)
cr = ClassicalRegister(2)

dag = DAGCircuit()
dag.add_qreg(qr)
dag.add_creg(cr)
node = dag.apply_operation_back(XGate().c_if(cr, 2), qr, [])

with self.assertRaisesRegex(DAGCircuitError, "Cannot propagate a condition"):
dag.substitute_node(
node, XGate().c_if(cr, 1), inplace=inplace, propagate_condition=True
)

@data(True, False)
def test_replace_if_else_op_with_another(self, inplace):
"""Test that one `IfElseOp` can be replaced with another."""
body = QuantumCircuit(1)
body.x(0)

qr = QuantumRegister(1)
cr1 = ClassicalRegister(2)
cr2 = ClassicalRegister(2)
dag = DAGCircuit()
dag.add_qreg(qr)
dag.add_creg(cr1)
dag.add_creg(cr2)
node = dag.apply_operation_back(IfElseOp(expr.logic_not(cr1), body.copy(), None), qr, [])
dag.substitute_node(node, IfElseOp(expr.equal(cr1, 0), body.copy(), None), inplace=inplace)

expected = DAGCircuit()
expected.add_qreg(qr)
expected.add_creg(cr1)
expected.add_creg(cr2)
expected.apply_operation_back(IfElseOp(expr.equal(cr1, 0), body.copy(), None), qr, [])

self.assertEqual(dag, expected)

@data(True, False)
def test_reject_replace_if_else_op_with_other_resources(self, inplace):
"""Test that the resources in the `condition` of a `IfElseOp` are checked against those in
the node to be replaced."""
body = QuantumCircuit(1)
body.x(0)

qr = QuantumRegister(1)
cr1 = ClassicalRegister(2)
cr2 = ClassicalRegister(2)
dag = DAGCircuit()
dag.add_qreg(qr)
dag.add_creg(cr1)
dag.add_creg(cr2)
node = dag.apply_operation_back(IfElseOp(expr.logic_not(cr1), body.copy(), None), qr, [])

with self.assertRaisesRegex(DAGCircuitError, "does not span the same wires"):
dag.substitute_node(
node, IfElseOp(expr.logic_not(cr2), body.copy(), None), inplace=inplace
)

@data(True, False)
def test_replace_switch_with_another(self, inplace):
"""Test that one `SwitchCaseOp` can be replaced with another."""
case = QuantumCircuit(1)
case.x(0)

qr = QuantumRegister(1)
cr1 = ClassicalRegister(2)
cr2 = ClassicalRegister(2)
dag = DAGCircuit()
dag.add_qreg(qr)
dag.add_creg(cr1)
dag.add_creg(cr2)
node = dag.apply_operation_back(
SwitchCaseOp(expr.lift(cr1), [((1, 3), case.copy())]), qr, []
)
dag.substitute_node(
node, SwitchCaseOp(expr.bit_and(cr1, 1), [(1, case.copy())]), inplace=inplace
)

expected = DAGCircuit()
expected.add_qreg(qr)
expected.add_creg(cr1)
expected.add_creg(cr2)
expected.apply_operation_back(
SwitchCaseOp(expr.bit_and(cr1, 1), [(1, case.copy())]), qr, []
)

self.assertEqual(dag, expected)

@data(True, False)
def test_reject_replace_switch_with_other_resources(self, inplace):
"""Test that the resources in the `target` of a `SwitchCaseOp` are checked against those in
the node to be replaced."""
case = QuantumCircuit(1)
case.x(0)

qr = QuantumRegister(1)
cr1 = ClassicalRegister(2)
cr2 = ClassicalRegister(2)
dag = DAGCircuit()
dag.add_qreg(qr)
dag.add_creg(cr1)
dag.add_creg(cr2)
node = dag.apply_operation_back(
SwitchCaseOp(expr.lift(cr1), [((1, 3), case.copy())]), qr, []
)

with self.assertRaisesRegex(DAGCircuitError, "does not span the same wires"):
dag.substitute_node(
node, SwitchCaseOp(expr.lift(cr2), [((1, 3), case.copy())]), inplace=inplace
)


class TestReplaceBlock(QiskitTestCase):
"""Test replacing a block of nodes in a DAG."""
Expand Down

0 comments on commit e70058f

Please sign in to comment.