Skip to content

Commit

Permalink
Fix internals usage for dagcircuit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtreinish committed Jul 21, 2024
1 parent d6d3c35 commit 7c373c6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
29 changes: 29 additions & 0 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4634,6 +4634,35 @@ new_condition = (new_target, value)
self.dag
.contains_edge(NodeIndex::new(source), NodeIndex::new(target))
}

fn _is_dag(&self) -> bool {
match rustworkx_core::petgraph::algo::toposort(&self.dag, None) {
Ok(_) => true,
Err(_) => false,
}
}

fn _in_wires(&self, node_index: usize) -> Vec<&PyObject> {
self.dag
.edges_directed(NodeIndex::new(node_index), Incoming)
.map(|wire| match wire.weight() {
Wire::Qubit(qubit) => &self.qubits.bits()[qubit.0 as usize],
Wire::Clbit(clbit) => &self.clbits.bits()[clbit.0 as usize],
Wire::Var(var) => var,
})
.collect()
}

fn _out_wires(&self, node_index: usize) -> Vec<&PyObject> {
self.dag
.edges_directed(NodeIndex::new(node_index), Incoming)
.map(|wire| match wire.weight() {
Wire::Qubit(qubit) => &self.qubits.bits()[qubit.0 as usize],
Wire::Clbit(clbit) => &self.clbits.bits()[clbit.0 as usize],
Wire::Var(var) => var,
})
.collect()
}
}

impl DAGCircuit {
Expand Down
14 changes: 6 additions & 8 deletions test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,12 @@ def raise_if_dagcircuit_invalid(dag):
DAGCircuitError: if DAGCircuit._multi_graph is inconsistent.
"""

multi_graph = dag._multi_graph

if not rx.is_directed_acyclic_graph(multi_graph):
if not dag._is_dag():
raise DAGCircuitError("multi_graph is not a DAG.")

# Every node should be of type in, out, or op.
# All input/output nodes should be present in input_map/output_map.
for node in dag._multi_graph.nodes():
for node in dag.nodes():
if isinstance(node, DAGInNode):
assert node is dag.input_map[node.wire]
elif isinstance(node, DAGOutNode):
Expand Down Expand Up @@ -112,18 +110,18 @@ def raise_if_dagcircuit_invalid(dag):
# Wires can only terminate at input/output nodes.
op_counts = Counter()
for op_node in dag.op_nodes():
assert multi_graph.in_degree(op_node._node_id) == multi_graph.out_degree(op_node._node_id)
assert sum(1 for _ in dag.predecssors(op_node)) == sum(1 for _ in dag.successors(op_node))
op_counts[op_node.name] += 1
# The _op_names attribute should match the counted op names
assert op_counts == dag._op_names
assert op_counts == dag.count_ops()

# Node input/output edges should match node qarg/carg/condition.
for node in dag.op_nodes():
in_edges = dag._multi_graph.in_edges(node._node_id)
out_edges = dag._multi_graph.out_edges(node._node_id)

in_wires = {data for src, dest, data in in_edges}
out_wires = {data for src, dest, data in out_edges}
in_wires = set(dag.in_wires(node._node_id))
out_wires = set(dag.out_wires(node._node_id))

node_cond_bits = set(
node.op.condition[0][:] if getattr(node.op, "condition", None) is not None else []
Expand Down

0 comments on commit 7c373c6

Please sign in to comment.