Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Oxidize commutative cancellation #13091

Merged
merged 11 commits into from
Sep 10, 2024
Merged
2 changes: 1 addition & 1 deletion crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const MAX_NUM_QUBITS: u32 = 3;
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
fn analyze_commutations_inner(
pub(crate) fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
Expand Down
291 changes: 291 additions & 0 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use crate::commutation_analysis::analyze_commutations_inner;
use crate::commutation_checker::CommutationChecker;
use crate::{euler_one_qubit_decomposer, QiskitError};
use hashbrown::{HashMap, HashSet};
use pyo3::prelude::*;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use qiskit_circuit::operations::StandardGate::{
CXGate, CYGate, CZGate, HGate, PhaseGate, RXGate, RZGate, SGate, TGate, U1Gate, XGate, YGate,
ZGate,
};
use qiskit_circuit::operations::{Operation, Param, StandardGate};
use qiskit_circuit::Qubit;
use rustworkx_core::petgraph::stable_graph::NodeIndex;
use smallvec::{smallvec, SmallVec};
use std::f64::consts::PI;

const _CUTOFF_PRECISION: f64 = 1e-5;
static ROTATION_GATES: [&str; 4] = ["p", "u1", "rz", "rx"];
static HALF_TURNS: [&str; 2] = ["z", "x"];
static QUARTER_TURNS: [&str; 1] = ["s"];
static EIGHTH_TURNS: [&str; 1] = ["t"];

static VAR_Z_MAP: [(&str, StandardGate); 3] = [("rz", RZGate), ("p", PhaseGate), ("u1", U1Gate)];
static Z_ROTATIONS: [StandardGate; 6] = [PhaseGate, ZGate, U1Gate, RZGate, TGate, SGate];
static X_ROTATIONS: [StandardGate; 2] = [XGate, RXGate];
static SUPPORTED_GATES: [StandardGate; 5] = [CXGate, CYGate, CZGate, HGate, YGate];

#[derive(Hash, Eq, PartialEq, Debug)]
enum GateOrRotation {
Gate(StandardGate),
ZRotation,
XRotation,
}
#[derive(Hash, Eq, PartialEq, Debug)]
struct CancellationSetKey {
gate: GateOrRotation,
qubits: SmallVec<[Qubit; 2]>,
com_set_index: usize,
second_index: Option<usize>,
}

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker, basis_gates=None))]
pub(crate) fn cancel_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
basis_gates: Option<HashSet<String>>,
) -> PyResult<()> {
let basis: HashSet<String> = if let Some(basis) = basis_gates {
basis
} else {
HashSet::new()
};
let z_var_gate = dag
.op_names
.keys()
.find_map(|g| {
VAR_Z_MAP
.iter()
.find(|(key, _)| *key == g.as_str())
.map(|(_, gate)| gate)
})
.or_else(|| {
basis.iter().find_map(|g| {
VAR_Z_MAP
.iter()
.find(|(key, _)| *key == g.as_str())
.map(|(_, gate)| gate)
})
});
// Fallback to the first matching key from basis if there is no match in dag.op_names

// Gate sets to be cancelled
/* Traverse each qubit to generate the cancel dictionaries
Cancel dictionaries:
- For 1-qubit gates the key is (gate_type, qubit_id, commutation_set_id),
the value is the list of gates that share the same gate type, qubit, commutation set.
- For 2qbit gates the key: (gate_type, first_qbit, sec_qbit, first commutation_set_id,
sec_commutation_set_id), the value is the list gates that share the same gate type,
qubits and commutation sets.
*/
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
let mut cancellation_sets: HashMap<CancellationSetKey, Vec<NodeIndex>> = HashMap::new();

(0..dag.num_qubits() as u32).for_each(|qubit| {
let wire = Qubit(qubit);
if let Some(wire_commutation_set) = commutation_set.get(&Wire::Qubit(wire)) {
wire_commutation_set
.iter()
.enumerate()
.for_each(|(com_set_idx, com_set)| {
// This ensures that we only have DAGOPNodes in the current com_set, yuck...
if let NodeType::Operation(_node0) = &dag.dag[*com_set.first().unwrap()] {
com_set.iter().for_each(|node| {
let op = match &dag.dag[*node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
let num_qargs = dag.get_qargs(op.qubits).len();
// no support for cancellation of parameterized gates
if op
.params_view()
.iter()
.all(|p| !matches!(p, Param::ParameterExpression(_)))
{
if let Some(op_gate) = op.op.try_standard_gate() {
if num_qargs == 1usize && SUPPORTED_GATES.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}

if num_qargs == 1usize && Z_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::ZRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1usize && X_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::XRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2usize
&& dag.get_qargs(op.qubits).first().unwrap() == &wire
{
let second_qarg = dag.get_qargs(op.qubits)[1];
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire, second_qarg],
com_set_index: com_set_idx,
second_index: node_indices
.get(&(*node, Wire::Qubit(second_qarg)))
.copied(),
})
.or_insert_with(Vec::new)
.push(*node);
}
}
}
})
}
})
}
});

for (cancel_key, cancel_set) in &cancellation_sets {
if cancel_set.len() > 1 {
if let GateOrRotation::Gate(g) = cancel_key.gate {
if SUPPORTED_GATES.contains(&g) {
for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] {
dag.remove_op_node(c_node);
}
}
}
if matches!(cancel_key.gate, GateOrRotation::ZRotation) && z_var_gate.is_none() {
continue;
}
if matches!(cancel_key.gate, GateOrRotation::ZRotation)
|| matches!(cancel_key.gate, GateOrRotation::XRotation)
{
let run_op = match &dag.dag[*cancel_set.first().unwrap()] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};

let run_qarg = dag.get_qargs(run_op.qubits).first().unwrap();
let mut total_angle: f64 = 0.0;
let mut total_phase: f64 = 0.0;
for current_node in cancel_set {
let node_op = match &dag.dag[*current_node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};
let node_op_name = node_op.op.name();

let node_qargs = dag.get_qargs(node_op.qubits);
if node_op
.extra_attrs
.as_deref()
.is_some_and(|attr| attr.condition.is_some())
|| node_qargs.len() > 1
|| &node_qargs[0] != run_qarg
{
panic!("internal error");
}

let node_angle = if ROTATION_GATES.contains(&node_op_name) {
match node_op.params_view().first() {
Some(Param::Float(f)) => *f,
_ => return Err(QiskitError::new_err(format!(
"Rotational gate with parameter expression encoutned in cancellation {:?}",
node_op.op
)))
}
} else if HALF_TURNS.contains(&node_op_name) {
PI
} else if QUARTER_TURNS.contains(&node_op_name) {
PI / 2.0
} else if EIGHTH_TURNS.contains(&node_op_name) {
PI / 4.0
} else {
panic!("Angle for operation {node_op_name} is not defined")
};
total_angle += node_angle;

if let Some(definition) = node_op.op.definition(node_op.params_view()) {
total_phase += match definition.global_phase() {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")};
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the logic is right now this should return a python error:

Suggested change
total_phase += match definition.global_phase() {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")};
}
total_phase += match definition.global_phase() {
Param::Float(f) => f,
_ => return Err(QiskitError::new_err(format!("PackedInstruction with definition has no global phase set as floating point number")))
};
}

Although realistically if you make my above suggestion about how to get the definition directly from StandardGate I think you can change this to just use unreachable!() since the generator code for definitions of standard gates never has a parameterized phase, it's always a float 0.0. Something like:

Suggested change
total_phase += match definition.global_phase() {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")};
}
let Param::Float(new_phase) = definition.global_phase() else { unreachable!() };
total_phase += new_phase

(although maybe set an error message saying standard gate phase is never parameterized)

would do the trick. But we can't make this change unless we're working with a StandardGate when we call definition().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm missing something: one of Rz or P (whichever isn't exactly equal to u1) can have a parametric phase if the angle is parametric. GlobalPhaseGate definitely can (if the angle is parametric).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good point, I thought there was one, but I scanned through the match statement yesterday and missed it on rzgate. But regardless this statement still holds since the angle is always a float at this point

}

let new_op = if matches!(cancel_key.gate, GateOrRotation::ZRotation) {
z_var_gate.unwrap()
} else if matches!(cancel_key.gate, GateOrRotation::XRotation) {
&RXGate
} else {
return Err(QiskitError::new_err("impossible case!"));
};

let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.);

let new_op_phase: f64 = if gate_angle.abs() > _CUTOFF_PRECISION {
let new_index = dag.insert_1q_on_incoming_qubit(
(*new_op, &[total_angle]),
*cancel_set.first().unwrap(),
);
let new_node = match &dag.dag[new_index] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};

if let Some(definition) = new_node.op.definition(new_node.params_view()) {
match definition.global_phase() {Param::Float(f) => *f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")}
} else {
0.0
}
} else {
0.0
};

dag.add_global_phase(py, &Param::Float(total_phase - new_op_phase))?;

for node in cancel_set {
dag.remove_op_node(*node);
}
}
}
}

Ok(())
}

#[pymodule]
pub fn commutation_cancellation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cancel_commutations))?;
Ok(())
}
2 changes: 1 addition & 1 deletion crates/accelerate/src/euler_one_qubit_decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ pub fn det_one_qubit(mat: ArrayView2<Complex64>) -> Complex64 {

/// Wrap angle into interval [-π,π). If within atol of the endpoint, clamp to -π
#[inline]
fn mod_2pi(angle: f64, atol: f64) -> f64 {
pub(crate) fn mod_2pi(angle: f64, atol: f64) -> f64 {
// f64::rem_euclid() isn't exactly the same as Python's % operator, but because
// the RHS here is a constant and positive it is effectively equivalent for
// this case
Expand Down
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyo3::import_exception;

pub mod circuit_library;
pub mod commutation_analysis;
pub mod commutation_cancellation;
pub mod commutation_checker;
pub mod convert_2q_block_matrix;
pub mod dense_layout;
Expand Down
5 changes: 3 additions & 2 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub struct DAGCircuit {
var_output_map: _VarIndexMap,

/// Operation kind to count
op_names: IndexMap<String, usize, RandomState>,
pub op_names: IndexMap<String, usize, RandomState>,

// Python modules we need to frequently access (for now).
control_flow_module: PyControlFlowModule,
Expand Down Expand Up @@ -6260,7 +6260,7 @@ impl DAGCircuit {
&mut self,
new_gate: (StandardGate, &[f64]),
old_index: NodeIndex,
) {
) -> NodeIndex {
self.increment_op(new_gate.0.name());
let old_node = &self.dag[old_index];
let inst = if let NodeType::Operation(old_node) = old_node {
Expand All @@ -6287,6 +6287,7 @@ impl DAGCircuit {
self.dag.add_edge(parent_index, new_index, weight.clone());
self.dag.add_edge(new_index, old_index, weight);
self.dag.remove_edge(edge_index);
new_index
}

/// Remove a sequence of 1 qubit nodes from the dag
Expand Down
11 changes: 6 additions & 5 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ use pyo3::prelude::*;

use qiskit_accelerate::{
circuit_library::circuit_library, commutation_analysis::commutation_analysis,
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, filter_op_nodes::filter_op_nodes_mod,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval,
commutation_cancellation::commutation_cancellation, commutation_checker::commutation_checker,
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
filter_op_nodes::filter_op_nodes_mod, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
Expand Down Expand Up @@ -71,5 +71,6 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, vf2_layout, "vf2_layout")?;
add_submodule(m, commutation_checker, "commutation_checker")?;
add_submodule(m, commutation_analysis, "commutation_analysis")?;
add_submodule(m, commutation_cancellation, "commutation_cancellation")?;
Ok(())
}
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
sys.modules["qiskit._accelerate.commutation_cancellation"] = _accelerate.commutation_cancellation
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase
sys.modules["qiskit._accelerate.filter_op_nodes"] = _accelerate.filter_op_nodes

Expand Down
Loading
Loading