-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Oxidize Commutation Analysis #12995
Oxidize Commutation Analysis #12995
Changes from 114 commits
f421a26
7425d18
a7016b0
861b548
d310746
878e3d6
f01ddfa
4bd7391
c60bb8d
24659c5
1c3521e
bab0ee0
7e52136
665c6b3
ce512b2
cc2ba14
e6c0637
b4c2ee3
b24ae43
e34c3ec
1cbbf84
85041c7
e878608
f7de4e4
2a4c864
b59499a
f5bb696
93b2927
c911205
1e4e6f3
463b86a
01c06e5
f6b27ff
0e62ad0
aaf38b9
28f6de1
11c52ab
d68743f
e07f3d5
30a4a1a
6dec768
1de3290
5c6f006
b90b660
9a6f953
7970e3d
f317077
701c980
464e87b
7f1b451
6aab4a6
e5e57c6
584ee9b
bdeb5f6
1d249ab
40c15ec
39586ba
6c5e1e9
5c9d285
db89660
e9e4a27
6a5f389
de1ee92
eeded61
a0501d3
1225061
21dec35
e37c5b0
c0adbb8
d76fd80
22e0044
ee3f4d3
454208c
847b425
3724873
cd9c119
5266807
53c902c
6e33d38
2b45b04
18c406d
b42f7e4
c564b80
6292e93
8f0f7d2
6636860
2504d1c
52cf8bf
ac75c91
816f5eb
6e8f779
d312791
f4bfa9e
29f1c07
c554666
0041098
a05df8c
f231c83
2bbf36b
f8afd5d
ea09d48
62d2cdb
71ac2fd
12d0a05
9235004
af74cb9
d625539
e19f7c6
52ba2e0
6835c8f
c7dde73
34a8a14
d7cf2cc
8b49c19
7207998
c08cfae
7073570
68f4f4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
// This code is part of Qiskit. | ||
// | ||
// (C) Copyright IBM 2024 | ||
// | ||
// This code is licensed under the Apache License, Version 2.0. You may | ||
// obtain a copy of this license in the LICENSE.txt file in the root directory | ||
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
// | ||
// Any modifications or derivative works of this code must retain this | ||
// copyright notice, and modified files need to carry a notice indicating | ||
// that they have been altered from the originals. | ||
|
||
use pyo3::prelude::PyModule; | ||
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python}; | ||
use qiskit_circuit::operations::Param; | ||
use qiskit_circuit::Qubit; | ||
use smallvec::{smallvec, SmallVec}; | ||
use std::hash::BuildHasherDefault; | ||
|
||
use crate::commutation_checker::CommutationChecker; | ||
use ahash::AHasher; | ||
use hashbrown::HashMap; | ||
use indexmap::IndexSet; | ||
use pyo3::prelude::*; | ||
|
||
use pyo3::types::{PyDict, PyList}; | ||
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire}; | ||
use rustworkx_core::petgraph::stable_graph::NodeIndex; | ||
|
||
type AIndexSet<T> = IndexSet<T, BuildHasherDefault<AHasher>>; | ||
#[derive(Clone, Debug)] | ||
pub enum CommutationSetEntry { | ||
Index(usize), | ||
SetExists(Vec<AIndexSet<NodeIndex>>), | ||
} | ||
|
||
fn analyze_commutations_inner( | ||
py: Python, | ||
dag: &mut DAGCircuit, | ||
commutation_checker: &mut CommutationChecker, | ||
) -> HashMap<(Option<NodeIndex>, Wire), CommutationSetEntry> { | ||
let mut commutation_set: HashMap<(Option<NodeIndex>, Wire), CommutationSetEntry> = | ||
HashMap::new(); | ||
let max_num_qubits = 3; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not that it matters, but I'd have done this with: const MAX_NUM_QUBITS: u32 = 3; outside the function definition because it is never changed. But realistically the compiler will likely end up with the same generated code in both cases so it doesn't matter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still consistency is nice -- I'll move it outside 👍🏻 |
||
|
||
(0..dag.num_qubits()).for_each(|qubit| { | ||
let wire = Wire::Qubit(Qubit(qubit as u32)); | ||
dag.nodes_on_wire(py, &wire, false) | ||
.iter() | ||
.for_each(|current_gate_idx| { | ||
if let CommutationSetEntry::SetExists(ref mut commutation_entry) = commutation_set | ||
.entry((None, wire.clone())) | ||
.or_insert_with(|| { | ||
CommutationSetEntry::SetExists(vec![AIndexSet::from_iter([ | ||
*current_gate_idx, | ||
])]) | ||
}) | ||
{ | ||
let last = commutation_entry.last_mut().unwrap(); | ||
|
||
if !last.contains(current_gate_idx) { | ||
if last.iter().all(|prev_gate_idx| { | ||
if let ( | ||
NodeType::Operation(packed_inst0), | ||
NodeType::Operation(packed_inst1), | ||
) = (&dag.dag[*current_gate_idx], &dag.dag[*prev_gate_idx]) | ||
{ | ||
let empty_params: Box<SmallVec<[Param; 3]>> = Box::new(smallvec![]); | ||
let op1 = packed_inst0.op.view(); | ||
let op2 = packed_inst1.op.view(); | ||
let params1 = match packed_inst0.params.as_ref() { | ||
Some(params) => params, | ||
None => &empty_params, | ||
}; | ||
let params2 = match packed_inst1.params.as_ref() { | ||
Some(params) => params, | ||
None => &empty_params, | ||
}; | ||
Cryoris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let qargs1 = dag.qargs_interner.get(packed_inst0.qubits); | ||
let qargs2 = dag.qargs_interner.get(packed_inst1.qubits); | ||
let cargs1 = dag.cargs_interner.get(packed_inst0.clbits); | ||
let cargs2 = dag.cargs_interner.get(packed_inst1.clbits); | ||
Cryoris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// TODO preliminary interface, change this when dagcircuit merges | ||
commutation_checker | ||
.commute_inner( | ||
py, | ||
&op1, | ||
params1, | ||
packed_inst0.extra_attrs.as_deref(), | ||
qargs1, | ||
cargs1, | ||
&op2, | ||
params2, | ||
packed_inst1.extra_attrs.as_deref(), | ||
qargs2, | ||
cargs2, | ||
max_num_qubits, | ||
) | ||
.unwrap() | ||
} else { | ||
false | ||
} | ||
}) { | ||
// all commute, add to current list | ||
last.insert(*current_gate_idx); | ||
} else { | ||
// does not commute, create new list | ||
commutation_entry.push(AIndexSet::from_iter([*current_gate_idx])) | ||
} | ||
} | ||
} else { | ||
panic!("Wrong type in dictionary!"); | ||
} | ||
if let CommutationSetEntry::SetExists(last_entry) = | ||
commutation_set.get(&(None, wire.clone())).unwrap() | ||
{ | ||
commutation_set.insert( | ||
(Some(*current_gate_idx), wire.clone()), | ||
CommutationSetEntry::Index(last_entry.len() - 1), | ||
); | ||
} | ||
}) | ||
}); | ||
commutation_set | ||
} | ||
|
||
#[pyfunction] | ||
#[pyo3(signature = (dag, commutation_checker))] | ||
pub(crate) fn analyze_commutations( | ||
py: Python, | ||
dag: &mut DAGCircuit, | ||
commutation_checker: &mut CommutationChecker, | ||
) -> PyResult<Py<PyDict>> { | ||
let commutations = analyze_commutations_inner(py, dag, commutation_checker); | ||
let out_dict = PyDict::new_bound(py); | ||
for (k, comms) in commutations { | ||
let nidx = k.0; | ||
let wire = match k.1 { | ||
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py), | ||
Wire::Clbit(c) => dag.clbits.get(c).unwrap().to_object(py), | ||
Wire::Var(v) => v, | ||
}; | ||
|
||
if nidx.is_some() { | ||
match comms { | ||
CommutationSetEntry::Index(idx) => { | ||
out_dict.set_item((dag.get_node(py, nidx.unwrap())?, wire), idx)? | ||
} | ||
_ => panic!("Wrong format in commutation analysis"), | ||
}; | ||
} else { | ||
match comms { | ||
CommutationSetEntry::SetExists(comm_set) => out_dict.set_item( | ||
wire, | ||
PyList::new_bound( | ||
py, | ||
comm_set.iter().map(|inner| { | ||
PyList::new_bound( | ||
py, | ||
inner | ||
.into_iter() | ||
.map(|ndidx| dag.get_node(py, *ndidx).unwrap()), | ||
) | ||
}), | ||
), | ||
)?, | ||
_ => panic!("Wrong format in commutation analysis"), | ||
} | ||
} | ||
} | ||
Ok(out_dict.unbind()) | ||
} | ||
|
||
#[pymodule] | ||
pub fn commutation_analysis(m: &Bound<PyModule>) -> PyResult<()> { | ||
m.add_wrapped(wrap_pyfunction!(analyze_commutations))?; | ||
Ok(()) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- | ||
features_transpiler: | ||
- | | ||
Added a Rust implementation of :class:`.CommutationAnalysis` in :func:`.analyze_commutations`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be expressed a bit more succinctly with:
That being said I'm not sure I'm a fan of using a type alias for this instead of just using
IndexSet<NodeIndex, RandomState>
(assuming a useahash::RandomState
is added) where we need to define a new indexset.