Skip to content

Commit

Permalink
Use retworkx for substitute_node_with_dag
Browse files Browse the repository at this point in the history
This commit leverage the substitute_node_with_subgraph method being
added Qiskit/rustworkx#312 for the dagcircuit method
substitute_node_with_dag.
  • Loading branch information
mtreinish committed Apr 26, 2021
1 parent 71f1c39 commit a32fef8
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 101 deletions.
146 changes: 46 additions & 100 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,63 +825,6 @@ def _check_wires_list(self, wires, node):
raise DAGCircuitError("expected %d wires, got %d"
% (wire_tot, len(wires)))

def _make_pred_succ_maps(self, node):
"""Return predecessor and successor dictionaries.
Args:
node (DAGNode): reference to multi_graph node
Returns:
tuple(dict): tuple(predecessor_map, successor_map)
These map from wire (Register, int) to the node ids for the
predecessor (successor) nodes of the input node.
"""

pred_map = {e[2]: e[0] for e in
self._multi_graph.in_edges(node._node_id)}
succ_map = {e[2]: e[1] for e in
self._multi_graph.out_edges(node._node_id)}
return pred_map, succ_map

def _full_pred_succ_maps(self, pred_map, succ_map, input_circuit,
wire_map):
"""Map all wires of the input circuit.
Map all wires of the input circuit to predecessor and
successor nodes in self, keyed on wires in self.
Args:
pred_map (dict): comes from _make_pred_succ_maps
succ_map (dict): comes from _make_pred_succ_maps
input_circuit (DAGCircuit): the input circuit
wire_map (dict): the map from wires of input_circuit to wires of self
Returns:
tuple: full_pred_map, full_succ_map (dict, dict)
Raises:
DAGCircuitError: if more than one predecessor for output nodes
"""
full_pred_map = {}
full_succ_map = {}
for w in input_circuit.input_map:
# If w is wire mapped, find the corresponding predecessor
# of the node
if w in wire_map:
full_pred_map[wire_map[w]] = pred_map[wire_map[w]]
full_succ_map[wire_map[w]] = succ_map[wire_map[w]]
else:
# Otherwise, use the corresponding output nodes of self
# and compute the predecessor.
full_succ_map[w] = self.output_map[w]
full_pred_map[w] = self._multi_graph.predecessors(
self.output_map[w])[0]
if len(self._multi_graph.predecessors(self.output_map[w])) != 1:
raise DAGCircuitError("too many predecessors for %s[%d] "
"output node" % (w.register, w.index))

return full_pred_map, full_succ_map

def __eq__(self, other):
# Try to convert to float, but in case of unbound ParameterExpressions
# a TypeError will be raise, fallback to normal equality in those
Expand Down Expand Up @@ -991,7 +934,7 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

if wires is None:
wires = in_dag.wires

wire_set = set(wires)
self._check_wires_list(wires, node)

# Create a proxy wire_map to identify fragments and duplicates
Expand All @@ -1014,11 +957,11 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):

condition_bit_list = self._bits_in_condition(condition)

wire_map = dict(zip(wires, list(node.qargs) + list(node.cargs) + list(condition_bit_list)))
new_wires = list(node.qargs) + list(node.cargs) + list(condition_bit_list)

wire_map = dict(zip(wires, new_wires))
reverse_wire_map = dict(zip(new_wires, wires))
self._check_wiremap_validity(wire_map, wires, self.input_map)
pred_map, succ_map = self._make_pred_succ_maps(node)
full_pred_map, full_succ_map = self._full_pred_succ_maps(pred_map, succ_map,
in_dag, wire_map)

if condition_bit_list:
# If we are replacing a conditional node, map input dag through
Expand All @@ -1033,48 +976,51 @@ def substitute_node_with_dag(self, node, input_dag, wires=None):
raise DAGCircuitError('Mapped DAG would alter clbits '
'on which it would be conditioned.')

# Now that we know the connections, delete node
self._multi_graph.remove_node(node._node_id)

# Iterate over nodes of input_circuit
for sorted_node in in_dag.topological_op_nodes():
# Insert a new node
def filter_fn(node):
if node.type != 'op':
return False
for qarg in node.qargs:
if qarg not in wire_set:
return False
return True

def edge_map_fn(source, target, self_wire):
wire = reverse_wire_map[self_wire]
# successor edge
if source == node._node_id:
wire_id = in_dag.output_map[wire]._node_id
out_index = in_dag._multi_graph.predecessor_indices(wire_id)[0]
# predecessor edge
else:
wire_id = in_dag.input_map[wire]._node_id
out_index = in_dag._multi_graph.successor_indices(wire_id)[0]
return out_index

def edge_weight_map(wire):
return wire_map[wire]

node_map = self._multi_graph.substitute_node_with_subgraph(
node._node_id, in_dag._multi_graph, edge_map_fn, filter_fn,
edge_weight_map)

# Iterate over nodes of input_circuit and update wires
for old_node_index in node_map:
# update node attributes
new_node_index = node_map[old_node_index]
old_node = in_dag._multi_graph[old_node_index]
new_node = copy.copy(old_node)
condition = self._map_condition(wire_map,
sorted_node.op.condition,
old_node.op.condition,
self.cregs.values())
m_qargs = list(map(lambda x: wire_map.get(x, x),
sorted_node.qargs))
old_node.qargs))
m_cargs = list(map(lambda x: wire_map.get(x, x),
sorted_node.cargs))
node_index = self._add_op_node(sorted_node.op, m_qargs, m_cargs)

# Add edges from predecessor nodes to new node
# and update predecessor nodes that change
all_cbits = self._bits_in_condition(condition)
all_cbits.extend(m_cargs)
al = [m_qargs, all_cbits]
for q in itertools.chain(*al):
self._multi_graph.add_edge(full_pred_map[q],
node_index,
q)
full_pred_map[q] = node_index

# Connect all predecessors and successors, and remove
# residual edges between input and output nodes
for w in full_pred_map:
self._multi_graph.add_edge(full_pred_map[w],
full_succ_map[w],
w)
o_pred = self._multi_graph.predecessors(self.output_map[w]._node_id)
if len(o_pred) > 1:
if len(o_pred) != 2:
raise DAGCircuitError("expected 2 predecessors here")

p = [x for x in o_pred if x != full_pred_map[w]]
if len(p) != 1:
raise DAGCircuitError("expected 1 predecessor to pass filter")

self._multi_graph.remove_edge(p[0], self.output_map[w])
old_node.cargs))
new_node.qargs = m_qargs
new_node.cargs = m_cargs
new_node._node_id = new_node_index
new_node.op.condition = condition
self._multi_graph[new_node_index] = new_node

def substitute_node(self, node, op, inplace=False):
"""Replace a DAGNode with a single instruction. qargs, cargs and
Expand Down
32 changes: 31 additions & 1 deletion test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,10 +1141,40 @@ def test_substitute_circuit_one_middle(self):
self.dag.substitute_node_with_dag(cx_node, flipped_cx_circuit, wires=[v[0], v[1]])

self.assertEqual(self.dag.count_ops()['h'], 5)
expected = DAGCircuit()
qreg = QuantumRegister(3, 'qr')
creg = ClassicalRegister(2, 'cr')
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(CXGate(), [qreg[1], qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(HGate(), [qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_front(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-front node with a DAG."""
pass
flipped_cx_circuit = DAGCircuit()
v = QuantumRegister(1, "v")
flipped_cx_circuit.add_qreg(v)
flipped_cx_circuit.apply_operation_back(HGate(), [v[0]], [])
flipped_cx_circuit.apply_operation_back(XGate(), [v[0]], [])

self.dag.substitute_node_with_dag(self.dag.op_nodes()[0],
flipped_cx_circuit)
expected = DAGCircuit()
qreg = QuantumRegister(3, 'qr')
creg = ClassicalRegister(2, 'cr')
expected.add_qreg(qreg)
expected.add_creg(creg)
expected.apply_operation_back(HGate(), [qreg[0]], [])
expected.apply_operation_back(XGate(), [qreg[0]], [])
expected.apply_operation_back(CXGate(), [qreg[0], qreg[1]], [])
expected.apply_operation_back(XGate(), [qreg[1]], [])
self.assertEqual(self.dag, expected)

def test_substitute_circuit_one_back(self):
"""The method substitute_node_with_dag() replaces a leaf-in-the-back node with a DAG."""
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ setenv =
QISKIT_TEST_CAPTURE_STREAMS=1
deps = -r{toxinidir}/requirements.txt
-r{toxinidir}/requirements-dev.txt
git+https://github.com/mtreinish/retworkx@test-stuff
commands =
stestr run {posargs}

Expand Down

0 comments on commit a32fef8

Please sign in to comment.