Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Expr support to QuantumCircuit.compose #10375

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 81 additions & 32 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,43 +949,16 @@ def compose(
)
edge_map.update(zip(other.clbits, dest.cbit_argument_conversion(clbits)))

# Cache for `map_register_to_dest`.
_map_register_cache = {}

def map_register_to_dest(theirs):
"""Map the target's registers to suitable equivalents in the destination, adding an
extra one if there's no exact match."""
if theirs.name in _map_register_cache:
return _map_register_cache[theirs.name]
mapped_bits = [edge_map[bit] for bit in theirs]
for ours in dest.cregs:
if mapped_bits == list(ours):
mapped_theirs = ours
break
else:
mapped_theirs = ClassicalRegister(bits=mapped_bits)
dest.add_register(mapped_theirs)
_map_register_cache[theirs.name] = mapped_theirs
return mapped_theirs

variable_mapper = _ComposeVariableMapper(dest, edge_map)
mapped_instrs: list[CircuitInstruction] = []
for instr in other.data:
n_qargs: list[Qubit] = [edge_map[qarg] for qarg in instr.qubits]
n_cargs: list[Clbit] = [edge_map[carg] for carg in instr.clbits]
n_op = instr.operation.copy()

if getattr(n_op, "condition", None) is not None:
target, value = n_op.condition
if isinstance(target, Clbit):
n_op.condition = (edge_map[target], value)
else:
n_op.condition = (map_register_to_dest(target), value)
elif isinstance(n_op, SwitchCaseOp):
if isinstance(n_op.target, Clbit):
n_op.target = edge_map[n_op.target]
else:
n_op.target = map_register_to_dest(n_op.target)

if (condition := getattr(n_op, "condition", None)) is not None:
n_op.condition = variable_mapper.map_condition(condition)
if isinstance(n_op, SwitchCaseOp):
n_op.target = variable_mapper.map_target(n_op.target)
mapped_instrs.append(CircuitInstruction(n_op, n_qargs, n_cargs))

if front:
Expand Down Expand Up @@ -5252,3 +5225,79 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_):
else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'"
)
raise CircuitError(message)


class _ComposeVariableMapper(expr.ExprVisitor[expr.Expr]):
"""Stateful helper class that manages the mapping of variables in conditions and expressions to
items in the destination ``circuit``.

This mutates ``circuit`` by adding registers as required."""

__slots__ = ("circuit", "register_map", "bit_map")

def __init__(self, circuit, bit_map):
self.circuit = circuit
self.register_map = {}
self.bit_map = bit_map

def _map_register(self, theirs):
"""Map the target's registers to suitable equivalents in the destination, adding an
extra one if there's no exact match."""
if (mapped_theirs := self.register_map.get(theirs.name)) is not None:
return mapped_theirs
mapped_bits = [self.bit_map[bit] for bit in theirs]
for ours in self.circuit.cregs:
if mapped_bits == list(ours):
mapped_theirs = ours
break
else:
mapped_theirs = ClassicalRegister(bits=mapped_bits)
self.circuit.add_register(mapped_theirs)
self.register_map[theirs.name] = mapped_theirs
return mapped_theirs

def map_condition(self, condition, /):
"""Map the given ``condition`` so that it only references variables in the destination
circuit (as given to this class on initialisation)."""
if condition is None:
return None
if isinstance(condition, expr.Expr):
return self.map_expr(condition)
target, value = condition
if isinstance(target, Clbit):
return (self.bit_map[target], value)
return (self._map_register(target), value)

def map_target(self, target, /):
"""Map the runtime variables in a ``target`` of a :class:`.SwitchCaseOp` to the new circuit,
as defined in the ``circuit`` argument of the initialiser of this class."""
if isinstance(target, Clbit):
return self.bit_map[target]
if isinstance(target, ClassicalRegister):
return self._map_register(target)
return self.map_expr(target)

def map_expr(self, node: expr.Expr, /) -> expr.Expr:
"""Map the variables in an :class:`~.expr.Expr` node to the new circuit."""
return node.accept(self)

def visit_var(self, node, /):
if isinstance(node.var, Clbit):
return expr.Var(self.bit_map[node.var], node.type)
if isinstance(node.var, ClassicalRegister):
return expr.Var(self._map_register(node.var), node.type)
# Defensive against the expansion of the variable system; we don't want to silently do the
# wrong thing (which would be `return node` without mapping, right now).
raise CircuitError(f"unhandled variable in 'compose': {node}") # pragma: no cover

def visit_value(self, node, /):
return expr.Value(node.value, node.type)

def visit_unary(self, node, /):
return expr.Unary(node.op, node.operand.accept(self), node.type)

def visit_binary(self, node, /):
return expr.Binary(node.op, node.left.accept(self), node.right.accept(self), node.type)

def visit_cast(self, node, /):
return expr.Cast(node.operand.accept(self), node.type, implicit=node.implicit)
Comment on lines +5293 to +5303
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it wouldn't really come up again, but if it does, it might make sense to have a transformer base class specific for expr.Expr that has default implementations like this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean like a subclass of ExprVisitor that (by default) copies the structure by producing the exact same Expr, but visiting each child node, with the intention that a consumer overrides only the behaviour they need?

I could potentially see that being useful, yeah - I'd maybe like to wait to see if we have other use-cases for it first, before adding more API surface.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, exactly 😄. Totally agree, makes sense to wait.

91 changes: 91 additions & 0 deletions test/python/circuit/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SwitchCaseOp,
)
from qiskit.circuit.library import HGate, RZGate, CXGate, CCXGate, TwoLocal
from qiskit.circuit.classical import expr
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -789,6 +790,96 @@ def test_compose_noclbits_registerless(self):
self.assertEqual(outer.clbits, inner.clbits)
self.assertEqual(outer.cregs, [])

def test_expr_condition_is_mapped(self):
"""Test that an expression in a condition involving several registers is mapped correctly to
the destination circuit."""
inner = QuantumCircuit(1)
inner.x(0)
a_src = ClassicalRegister(2, "a_src")
b_src = ClassicalRegister(2, "b_src")
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)

test_1 = lambda: expr.lift(a_src[0])
test_2 = lambda: expr.logic_not(b_src[1])
test_3 = lambda: expr.logic_and(expr.bit_and(b_src, 2), expr.less(c_src, 7))
source.if_test(test_1(), inner.copy(), [0], [])
source.if_else(test_2(), inner.copy(), inner.copy(), [0], [])
source.while_loop(test_3(), inner.copy(), [0], [])

a_dest = ClassicalRegister(2, "a_dest")
b_dest = ClassicalRegister(2, "b_dest")
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)

# Check that the input conditions weren't mutated.
for in_condition, instruction in zip((test_1, test_2, test_3), source.data):
self.assertEqual(in_condition(), instruction.operation.condition)

# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
self.assertEqual(len(dest.cregs), 3)
mapped_reg = dest.cregs[-1]

expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
expected.if_test(expr.lift(a_dest[0]), inner.copy(), [0], [])
expected.if_else(expr.logic_not(b_dest[1]), inner.copy(), inner.copy(), [0], [])
expected.while_loop(
expr.logic_and(expr.bit_and(b_dest, 2), expr.less(mapped_reg, 7)), inner.copy(), [0], []
)
self.assertEqual(dest, expected)

def test_expr_target_is_mapped(self):
"""Test that an expression in a switch statement's target is mapping correctly to the
destination circuit."""
inner1 = QuantumCircuit(1)
inner1.x(0)
inner2 = QuantumCircuit(1)
inner2.z(0)

a_src = ClassicalRegister(2, "a_src")
b_src = ClassicalRegister(2, "b_src")
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)

test_1 = lambda: expr.lift(a_src[0])
test_2 = lambda: expr.logic_not(b_src[1])
test_3 = lambda: expr.lift(b_src)
test_4 = lambda: expr.bit_and(c_src, 7)
source.switch(test_1(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
source.switch(test_2(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
source.switch(test_3(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])
source.switch(test_4(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])

a_dest = ClassicalRegister(2, "a_dest")
b_dest = ClassicalRegister(2, "b_dest")
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)

# Check that the input expressions weren't mutated.
for in_target, instruction in zip((test_1, test_2, test_3, test_4), source.data):
self.assertEqual(in_target(), instruction.operation.target)

# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
self.assertEqual(len(dest.cregs), 3)
mapped_reg = dest.cregs[-1]

expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
expected.switch(
expr.lift(a_dest[0]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
)
expected.switch(
expr.logic_not(b_dest[1]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
)
expected.switch(
expr.lift(b_dest), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], []
)
expected.switch(
expr.bit_and(mapped_reg, 7),
[(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())],
[0],
[],
)

self.assertEqual(dest, expected)


if __name__ == "__main__":
unittest.main()