Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OptimizeSwapBeforeMeasure only when last measurements #5906

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

"""Remove the swaps followed by measurement (and adapt the measurement)."""

from qiskit.circuit import Measure
from qiskit.circuit.library.standard_gates import SwapGate
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler import Layout
from qiskit.dagcircuit import DAGCircuit


Expand All @@ -35,29 +35,33 @@ def run(self, dag):
Returns:
DAGCircuit: the optimized DAG.
"""
swaps = dag.op_nodes(SwapGate)
for swap in swaps[::-1]:
final_successor = []
for successor in dag.successors(swap):
final_successor.append(successor.type == 'out' or (successor.type == 'op' and
successor.op.name == 'measure'))
if all(final_successor):
# the node swap needs to be removed and, if a measure follows, needs to be adapted
swap_qargs = swap.qargs
measure_layer = DAGCircuit()
for qreg in dag.qregs.values():
measure_layer.add_qreg(qreg)
for creg in dag.cregs.values():
measure_layer.add_creg(creg)
for successor in list(dag.successors(swap)):
if successor.type == 'op' and successor.op.name == 'measure':
# replace measure node with a new one, where qargs is set with the "other"
# swap qarg.
dag.remove_op_node(successor)
old_measure_qarg = successor.qargs[0]
new_measure_qarg = swap_qargs[swap_qargs.index(old_measure_qarg) - 1]
measure_layer.apply_operation_back(Measure(), [new_measure_qarg],
[successor.cargs[0]])
dag.compose(measure_layer)
dag.remove_op_node(swap)
return dag

new_dag = DAGCircuit()
new_dag.metadata = dag.metadata
new_dag._global_phase = dag._global_phase
for creg in dag.cregs.values():
new_dag.add_creg(creg)
for qreg in dag.qregs.values():
new_dag.add_qreg(qreg)
Comment on lines +39 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_dag = DAGCircuit()
new_dag.metadata = dag.metadata
new_dag._global_phase = dag._global_phase
for creg in dag.cregs.values():
new_dag.add_creg(creg)
for qreg in dag.qregs.values():
new_dag.add_qreg(qreg)
new_dag = dag._copy_circuit_metadata()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice one. Why is _copy_circuit_metadata a private method?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially because it was used internally, but then was useful to passes that needed to rebuild the dag (but preserve metadata) like the layout routing passes. Making it public seems like a good idea.


_layout = Layout.generate_trivial_layout(*dag.qregs.values())
_trivial_layout = Layout.generate_trivial_layout(*dag.qregs.values())

for node in dag.topological_op_nodes():
if node.type == 'op':
qargs = [_trivial_layout[_layout[qarg]] for qarg in node.qargs]
if isinstance(node.op, SwapGate):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this apply for all swaps or only those added by routing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not distinguish from where is the SWAP coming from.

swap = node
final_successor = []
for successor in dag.successors(swap):
if successor.type == 'op' and successor.op.name == 'measure':
is_final_measure = all([s.type == 'out'
for s in dag.successors(successor)])
else:
is_final_measure = False
final_successor.append(successor.type == 'out' or is_final_measure)
if all(final_successor):
_layout.swap(*qargs)
continue
new_dag.apply_operation_back(node.op, qargs, node.cargs)
return new_dag
172 changes: 106 additions & 66 deletions test/python/transpiler/test_optimize_swap_before_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,72 +73,6 @@ def test_optimize_1swap_2measure(self):

self.assertEqual(circuit_to_dag(expected), after)

def test_optimize_nswap_nmeasure(self):
""" Remove severals swap affecting multiple measurements
┌─┐ ┌─┐
q_0: ─X──X─────X────┤M├───────────────────────────────── q_0: ──────┤M├───────────────
│ │ │ └╥┘ ┌─┐ ┌─┐└╥┘
q_1: ─X──X──X──X──X──╫─────X────┤M├───────────────────── q_1: ───┤M├─╫────────────────
│ │ ║ │ └╥┘ ┌─┐ ┌─┐└╥┘ ║
q_2: ───────X──X──X──╫──X──X─────╫──X────┤M├──────────── q_2: ┤M├─╫──╫────────────────
│ ║ │ ║ │ └╥┘┌─┐ └╥┘ ║ ║ ┌─┐
q_3: ─X─────X──X─────╫──X──X──X──╫──X─────╫─┤M├───────── q_3: ─╫──╫──╫────┤M├─────────
│ │ ║ │ │ ║ ║ └╥┘┌─┐ ║ ║ ║ └╥┘ ┌─┐
q_4: ─X──X──X──X─────╫──X──X──X──╫──X─────╫──╫─┤M├────── ==> q_4: ─╫──╫──╫─────╫───────┤M├
│ │ ║ │ ║ │ ║ ║ └╥┘┌─┐ ║ ║ ║ ┌─┐ ║ └╥┘
q_5: ────X──X──X──X──╫──X──X─────╫──X──X──╫──╫──╫─┤M├─── q_5: ─╫──╫──╫─┤M├─╫────────╫─
│ │ ║ │ ║ │ ║ ║ ║ └╥┘┌─┐ ║ ║ ║ └╥┘ ║ ┌─┐ ║
q_6: ─X──X──X──X──X──╫──X──X─────╫─────X──╫──╫──╫──╫─┤M├ q_6: ─╫──╫──╫──╫──╫─┤M├────╫─
│ │ │ ║ │ ┌─┐ ║ ║ ║ ║ ║ └╥┘ ║ ║ ║ ║ ║ └╥┘┌─┐ ║
q_7: ─X──X─────X─────╫──X─┤M├────╫────────╫──╫──╫──╫──╫─ q_7: ─╫──╫──╫──╫──╫──╫─┤M├─╫─
║ └╥┘ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ └╥┘ ║
c: 8/════════════════╩═════╩═════╩════════╩══╩══╩══╩══╩═ c: 8/═╩══╩══╩══╩══╩══╩══╩══╩═
0 7 1 2 3 4 5 6 0 1 2 3 4 5 6 7
"""
circuit = QuantumCircuit(8, 8)
circuit.swap(3, 4)
circuit.swap(6, 7)
circuit.swap(0, 1)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(0, 1)
circuit.swap(5, 6)
circuit.swap(3, 4)
circuit.swap(1, 2)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(0, 1)
circuit.swap(5, 6)
circuit.swap(1, 2)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(3, 4)
circuit.swap(3, 4)
circuit.swap(5, 6)
circuit.swap(1, 2)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(5, 6)
circuit.measure(range(8), range(8))
dag = circuit_to_dag(circuit)

expected = QuantumCircuit(8, 8)
expected.measure(0, 2)
expected.measure(1, 1)
expected.measure(2, 0)
expected.measure(3, 4)
expected.measure(4, 7)
expected.measure(5, 3)
expected.measure(6, 5)
expected.measure(7, 6)

pass_ = OptimizeSwapBeforeMeasure()
after = pass_.run(dag)

self.assertEqual(circuit_to_dag(expected), after)

def test_cannot_optimize(self):
""" Cannot optimize when swap is not at the end in all of the successors
qr0:--X-----m--
Expand Down Expand Up @@ -218,6 +152,112 @@ def test_optimize_overlap_swap(self):

self.assertEqual(expected, after)

def test_optimize_nswap_nmeasure(self):
""" Remove several swap affecting multiple measurements
┌─┐ ┌─┐
q_0: ─X──X─────X────┤M├───────────────────────────────── q_0: ──────┤M├───────────────
│ │ │ └╥┘ ┌─┐ ┌─┐└╥┘
q_1: ─X──X──X──X──X──╫─────X────┤M├───────────────────── q_1: ───┤M├─╫────────────────
│ │ ║ │ └╥┘ ┌─┐ ┌─┐└╥┘ ║
q_2: ───────X──X──X──╫──X──X─────╫──X────┤M├──────────── q_2: ┤M├─╫──╫────────────────
│ ║ │ ║ │ └╥┘┌─┐ └╥┘ ║ ║ ┌─┐
q_3: ─X─────X──X─────╫──X──X──X──╫──X─────╫─┤M├───────── q_3: ─╫──╫──╫────┤M├─────────
│ │ ║ │ │ ║ ║ └╥┘┌─┐ ║ ║ ║ └╥┘ ┌─┐
q_4: ─X──X──X──X─────╫──X──X──X──╫──X─────╫──╫─┤M├────── ==> q_4: ─╫──╫──╫─────╫───────┤M├
│ │ ║ │ ║ │ ║ ║ └╥┘┌─┐ ║ ║ ║ ┌─┐ ║ └╥┘
q_5: ────X──X──X──X──╫──X──X─────╫──X──X──╫──╫──╫─┤M├─── q_5: ─╫──╫──╫─┤M├─╫────────╫─
│ │ ║ │ ║ │ ║ ║ ║ └╥┘┌─┐ ║ ║ ║ └╥┘ ║ ┌─┐ ║
q_6: ─X──X──X──X──X──╫──X──X─────╫─────X──╫──╫──╫──╫─┤M├ q_6: ─╫──╫──╫──╫──╫─┤M├────╫─
│ │ │ ║ │ ┌─┐ ║ ║ ║ ║ ║ └╥┘ ║ ║ ║ ║ ║ └╥┘┌─┐ ║
q_7: ─X──X─────X─────╫──X─┤M├────╫────────╫──╫──╫──╫──╫─ q_7: ─╫──╫──╫──╫──╫──╫─┤M├─╫─
║ └╥┘ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ └╥┘ ║
c: 8/════════════════╩═════╩═════╩════════╩══╩══╩══╩══╩═ c: 8/═╩══╩══╩══╩══╩══╩══╩══╩═
0 7 1 2 3 4 5 6 0 1 2 3 4 5 6 7
"""
circuit = QuantumCircuit(8, 8)
circuit.swap(3, 4)
circuit.swap(6, 7)
circuit.swap(0, 1)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(0, 1)
circuit.swap(5, 6)
circuit.swap(3, 4)
circuit.swap(1, 2)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(0, 1)
circuit.swap(5, 6)
circuit.swap(1, 2)
circuit.swap(6, 7)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(3, 4)
circuit.swap(3, 4)
circuit.swap(5, 6)
circuit.swap(1, 2)
circuit.swap(4, 5)
circuit.swap(2, 3)
circuit.swap(5, 6)
circuit.measure(range(8), range(8))

expected = QuantumCircuit(8, 8)
expected.measure(0, 2)
expected.measure(1, 1)
expected.measure(2, 0)
expected.measure(3, 4)
expected.measure(4, 7)
expected.measure(5, 3)
expected.measure(6, 5)
expected.measure(7, 6)

pass_manager = PassManager()
pass_manager.append(
[OptimizeSwapBeforeMeasure(), DAGFixedPoint()],
do_while=lambda property_set: not property_set['dag_fixed_point'])
after = pass_manager.run(circuit)

self.assertEqual(expected, after)


class TestOptimizeSwapBeforeMeasureMidMeasure(QiskitTestCase):
""" Test swap-followed-by-measure optimizations, with mid-circuit measurement."""

def test_mid_circuit(self):
"""Test mid-circuit measurement"""
qr1 = QuantumRegister(1, 'qr1')
qr2 = QuantumRegister(2, 'qr2')
cr = ClassicalRegister(3, 'cr')
circuit = QuantumCircuit(qr1, qr2, cr)
circuit.h(qr1[0])
circuit.h(qr2[1])
circuit.swap(qr1[0], qr2[0])
circuit.measure(qr1[0], cr[0])
circuit.measure(qr2[0], cr[1])
circuit.cx(qr1[0], qr2[1])
circuit.swap(qr1[0], qr2[0])
circuit.measure(qr1[0], cr[0])
circuit.measure(qr2[0], cr[1])

expected = QuantumCircuit(qr1, qr2, cr)
expected.h(qr1[0])
expected.h(qr2[1])
expected.swap(qr1[0], qr2[0])
expected.measure(qr1[0], cr[0])
expected.measure(qr2[0], cr[1])
expected.cx(qr1[0], qr2[1])
expected.measure(qr2[0], cr[0])
expected.measure(qr1[0], cr[1])

pass_manager = PassManager()
pass_manager.append(
[OptimizeSwapBeforeMeasure(), DAGFixedPoint()],
do_while=lambda property_set: not property_set['dag_fixed_point'])
after = pass_manager.run(circuit)

self.assertEqual(expected, after)


if __name__ == '__main__':
unittest.main()