diff --git a/crates/accelerate/src/commutation_checker.rs b/crates/accelerate/src/commutation_checker.rs index 005e37ecc375..16fcc5eca8fb 100644 --- a/crates/accelerate/src/commutation_checker.rs +++ b/crates/accelerate/src/commutation_checker.rs @@ -28,7 +28,7 @@ use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationF use qiskit_circuit::dag_node::DAGOpNode; use qiskit_circuit::imports::QI_OPERATOR; use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType}; -use qiskit_circuit::operations::{Operation, OperationRef, Param}; +use qiskit_circuit::operations::{Operation, OperationRef, Param, StandardGate}; use qiskit_circuit::{BitType, Clbit, Qubit}; use crate::unitary_compose; @@ -38,8 +38,28 @@ static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"]; static NO_CACHE_NAMES: [&str; 2] = ["annotated", "linear_function"]; static SUPPORTED_OP: Lazy> = Lazy::new(|| { HashSet::from([ - "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx", "cy", "cz", "swap", - "iswap", "ecr", "ccx", "cswap", + "rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx", + "cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap", + ]) +}); + +// map rotation gates to their generators, or to ``None`` if we cannot currently efficiently +// represent the generator in Rust and store the commutation relation in the commutation dictionary +static SUPPORTED_ROTATIONS: Lazy>> = Lazy::new(|| { + HashMap::from([ + ("rx", Some(OperationRef::Standard(StandardGate::XGate))), + ("ry", Some(OperationRef::Standard(StandardGate::YGate))), + ("rz", Some(OperationRef::Standard(StandardGate::ZGate))), + ("p", Some(OperationRef::Standard(StandardGate::ZGate))), + ("u1", Some(OperationRef::Standard(StandardGate::ZGate))), + ("crx", Some(OperationRef::Standard(StandardGate::CXGate))), + ("cry", Some(OperationRef::Standard(StandardGate::CYGate))), + ("crz", Some(OperationRef::Standard(StandardGate::CZGate))), + ("cp", Some(OperationRef::Standard(StandardGate::CZGate))), + ("rxx", None), // None means the gate is in the commutation dictionary + ("ryy", None), + ("rzx", None), + ("rzz", None), ]) }); @@ -89,6 +109,7 @@ impl CommutationChecker { ) -> Self { // Initialize sets before they are used in the commutation checker Lazy::force(&SUPPORTED_OP); + Lazy::force(&SUPPORTED_ROTATIONS); CommutationChecker { library: CommutationLibrary::new(standard_gate_commutations), cache: HashMap::new(), @@ -242,6 +263,23 @@ impl CommutationChecker { cargs2: &[Clbit], max_num_qubits: u32, ) -> PyResult { + // relative and absolute tolerance used to (1) check whether rotation gates commute + // trivially (i.e. the rotation angle is so small we assume it commutes) and (2) define + // comparison for the matrix-based commutation checks + let rtol = 1e-5; + let atol = 1e-8; + + // if we have rotation gates, we attempt to map them to their generators, for example + // RX -> X or CPhase -> CZ + let (op1, params1, trivial1) = map_rotation(op1, params1, rtol); + if trivial1 { + return Ok(true); + } + let (op2, params2, trivial2) = map_rotation(op2, params2, rtol); + if trivial2 { + return Ok(true); + } + if let Some(gates) = &self.gates { if !gates.is_empty() && (!gates.contains(op1.name()) || !gates.contains(op2.name())) { return Ok(false); @@ -286,7 +324,9 @@ impl CommutationChecker { NO_CACHE_NAMES.contains(&second_op.name()) || // Skip params that do not evaluate to floats for caching and commutation library first_params.iter().any(|p| !matches!(p, Param::Float(_))) || - second_params.iter().any(|p| !matches!(p, Param::Float(_))); + second_params.iter().any(|p| !matches!(p, Param::Float(_))) + && !SUPPORTED_OP.contains(op1.name()) + && !SUPPORTED_OP.contains(op2.name()); if skip_cache { return self.commute_matmul( @@ -297,6 +337,8 @@ impl CommutationChecker { second_op, second_params, second_qargs, + rtol, + atol, ); } @@ -331,6 +373,8 @@ impl CommutationChecker { second_op, second_params, second_qargs, + rtol, + atol, )?; // TODO: implement a LRU cache for this @@ -365,6 +409,8 @@ impl CommutationChecker { second_op: &OperationRef, second_params: &[Param], second_qargs: &[Qubit], + rtol: f64, + atol: f64, ) -> PyResult { // Compute relative positioning of qargs of the second gate to the first gate. // Since the qargs come out the same BitData, we already know there are no accidential @@ -405,8 +451,6 @@ impl CommutationChecker { None => return Ok(false), }; - let rtol = 1e-5; - let atol = 1e-8; if first_qarg == second_qarg { match first_qarg.len() { 1 => Ok(unitary_compose::commute_1q( @@ -568,6 +612,41 @@ where .any(|x| matches!(x, Param::ParameterExpression(_))) } +/// Check if a given operation can be mapped onto a generator. +/// +/// If ``op`` is in the ``SUPPORTED_ROTATIONS`` hashmap, it is a rotation and we +/// (1) check whether the rotation is so small (modulo pi) that we assume it is the +/// identity and it commutes trivially with every other operation +/// (2) otherwise, we check whether a generator of the rotation is given (e.g. X for RX) +/// and we return the generator +/// +/// Returns (operation, parameters, commutes_trivially). +fn map_rotation<'a>( + op: &'a OperationRef<'a>, + params: &'a [Param], + tol: f64, +) -> (&'a OperationRef<'a>, &'a [Param], bool) { + let name = op.name(); + if let Some(generator) = SUPPORTED_ROTATIONS.get(name) { + // if the rotation angle is below the tolerance, the gate is assumed to + // commute with everything, and we simply return the operation with the flag that + // it commutes trivially + if let Param::Float(angle) = params[0] { + if (angle % std::f64::consts::PI).abs() < tol { + return (op, params, true); + }; + }; + + // otherwise, we check if a generator is given -- if not, we'll just return the operation + // itself (e.g. RXX does not have a generator and is just stored in the commutations + // dictionary) + if let Some(gate) = generator { + return (gate, &[], false); + }; + } + (op, params, false) +} + fn get_relative_placement( first_qargs: &[Qubit], second_qargs: &[Qubit], diff --git a/qiskit/circuit/_standard_gates_commutations.py b/qiskit/circuit/_standard_gates_commutations.py index 9dc95b675db7..12899207b7f3 100644 --- a/qiskit/circuit/_standard_gates_commutations.py +++ b/qiskit/circuit/_standard_gates_commutations.py @@ -47,6 +47,7 @@ first cx. """ + standard_gates_commutations = { ("id", "id"): True, ("id", "sx"): True, @@ -70,6 +71,10 @@ ("id", "iswap"): True, ("id", "sxdg"): True, ("id", "tdg"): True, + ("id", "rxx"): True, + ("id", "ryy"): True, + ("id", "rzz"): True, + ("id", "rzx"): True, ("sx", "sx"): True, ("sx", "cx"): { (0,): False, @@ -109,6 +114,13 @@ ("sx", "iswap"): False, ("sx", "sxdg"): True, ("sx", "tdg"): False, + ("sx", "rxx"): True, + ("sx", "ryy"): False, + ("sx", "rzz"): False, + ("sx", "rzx"): { + (0,): False, + (1,): True, + }, ("x", "id"): True, ("x", "sx"): True, ("x", "x"): True, @@ -152,6 +164,13 @@ ("x", "tdg"): False, ("x", "y"): False, ("x", "z"): False, + ("x", "rxx"): True, + ("x", "ryy"): False, + ("x", "rzz"): False, + ("x", "rzx"): { + (0,): False, + (1,): True, + }, ("cx", "cx"): { (0, 1): True, (0, None): True, @@ -303,6 +322,31 @@ }, ("cx", "swap"): False, ("cx", "iswap"): False, + ("cx", "rxx"): { + (0, 1): False, + (0, None): False, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): True, + }, + ("cx", "ryy"): False, + ("cx", "rzz"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("cx", "rzx"): { + (0, 1): True, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): False, + (None, 1): True, + }, ("c3sx", "c3sx"): { (0, 1, 2, 3): True, (0, 1, 2, None): True, @@ -1029,6 +1073,10 @@ ("dcx", "csdg"): False, ("dcx", "swap"): False, ("dcx", "iswap"): False, + ("dcx", "rxx"): False, + ("dcx", "ryy"): False, + ("dcx", "rzz"): False, + ("dcx", "rzx"): False, ("ch", "cx"): { (0, 1): False, (0, None): True, @@ -1189,6 +1237,24 @@ }, ("ch", "swap"): False, ("ch", "iswap"): False, + ("ch", "rxx"): False, + ("ch", "ryy"): False, + ("ch", "rzz"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("ch", "rzx"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, ("cswap", "c3sx"): { (0, 1, 2): True, (0, 1, 3): False, @@ -1499,6 +1565,31 @@ }, ("csx", "swap"): False, ("csx", "iswap"): False, + ("csx", "rxx"): { + (0, 1): False, + (0, None): False, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): True, + }, + ("csx", "ryy"): False, + ("csx", "rzz"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("csx", "rzx"): { + (0, 1): True, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): False, + (None, 1): True, + }, ("cy", "c3sx"): { (0, 1): False, (0, 2): False, @@ -1635,6 +1726,31 @@ }, ("cy", "swap"): False, ("cy", "iswap"): False, + ("cy", "rxx"): False, + ("cy", "ryy"): { + (0, 1): False, + (0, None): False, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): True, + }, + ("cy", "rzz"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("cy", "rzx"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, ("cz", "c3sx"): { (0, 1): True, (0, 2): True, @@ -1750,6 +1866,17 @@ (None, 0): False, (None, 1): False, }, + ("cz", "rxx"): False, + ("cz", "ryy"): False, + ("cz", "rzz"): True, + ("cz", "rzx"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): False, + }, ("ccz", "c3sx"): { (0, 1, 2): True, (0, 1, 3): False, @@ -2000,6 +2127,10 @@ ("h", "tdg"): False, ("h", "y"): False, ("h", "z"): False, + ("h", "rxx"): False, + ("h", "ryy"): False, + ("h", "rzz"): False, + ("h", "rzx"): False, ("rccx", "c3sx"): { (0, 1, 2): False, (0, 1, 3): False, @@ -2479,6 +2610,24 @@ ("ecr", "csdg"): False, ("ecr", "swap"): False, ("ecr", "iswap"): False, + ("ecr", "rxx"): { + (0, 1): False, + (0, None): False, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): True, + }, + ("ecr", "ryy"): False, + ("ecr", "rzz"): False, + ("ecr", "rzx"): { + (0, 1): False, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): True, + }, ("s", "id"): True, ("s", "sx"): False, ("s", "x"): False, @@ -2540,6 +2689,13 @@ ("s", "tdg"): True, ("s", "y"): False, ("s", "z"): True, + ("s", "rxx"): False, + ("s", "ryy"): False, + ("s", "rzz"): True, + ("s", "rzx"): { + (0,): True, + (1,): False, + }, ("sdg", "cx"): { (0,): True, (1,): False, @@ -2594,6 +2750,13 @@ ("sdg", "iswap"): False, ("sdg", "sxdg"): False, ("sdg", "tdg"): True, + ("sdg", "rxx"): False, + ("sdg", "ryy"): False, + ("sdg", "rzz"): True, + ("sdg", "rzx"): { + (0,): True, + (1,): False, + }, ("cs", "cx"): { (0, 1): False, (0, None): True, @@ -2726,6 +2889,17 @@ (None, 0): False, (None, 1): False, }, + ("cs", "rxx"): False, + ("cs", "ryy"): False, + ("cs", "rzz"): True, + ("cs", "rzx"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): False, + (None, 0): True, + (None, 1): False, + }, ("csdg", "c3sx"): { (0, 1): True, (0, 2): True, @@ -3064,6 +3238,13 @@ ("sxdg", "swap"): False, ("sxdg", "iswap"): False, ("sxdg", "sxdg"): True, + ("sxdg", "rxx"): True, + ("sxdg", "ryy"): False, + ("sxdg", "rzz"): False, + ("sxdg", "rzx"): { + (0,): False, + (1,): True, + }, ("t", "id"): True, ("t", "sx"): False, ("t", "x"): False, @@ -3124,6 +3305,13 @@ ("t", "tdg"): True, ("t", "y"): False, ("t", "z"): True, + ("t", "rxx"): False, + ("t", "ryy"): False, + ("t", "rzz"): True, + ("t", "rzx"): { + (0,): True, + (1,): False, + }, ("tdg", "cx"): { (0,): True, (1,): False, @@ -3177,6 +3365,13 @@ ("tdg", "iswap"): False, ("tdg", "sxdg"): False, ("tdg", "tdg"): True, + ("tdg", "rxx"): False, + ("tdg", "ryy"): False, + ("tdg", "rzz"): True, + ("tdg", "rzx"): { + (0,): True, + (1,): False, + }, ("y", "id"): True, ("y", "sx"): False, ("y", "cx"): False, @@ -3204,6 +3399,10 @@ ("y", "tdg"): False, ("y", "y"): True, ("y", "z"): False, + ("y", "rxx"): False, + ("y", "ryy"): True, + ("y", "rzz"): False, + ("y", "rzx"): False, ("z", "id"): True, ("z", "sx"): False, ("z", "cx"): { @@ -3261,4 +3460,390 @@ ("z", "sxdg"): False, ("z", "tdg"): True, ("z", "z"): True, + ("z", "rxx"): False, + ("z", "ryy"): False, + ("z", "rzz"): True, + ("z", "rzx"): { + (0,): True, + (1,): False, + }, + ("rxx", "c3sx"): { + (0, 1): False, + (0, 2): False, + (0, 3): False, + (0, None): False, + (1, 0): False, + (1, 2): False, + (1, 3): False, + (1, None): False, + (2, 0): False, + (2, 1): False, + (2, 3): False, + (2, None): False, + (3, 0): False, + (3, 1): False, + (3, 2): False, + (3, None): True, + (None, 0): False, + (None, 1): False, + (None, 2): False, + (None, 3): True, + }, + ("rxx", "ccx"): { + (0, 1): False, + (0, 2): False, + (0, None): False, + (1, 0): False, + (1, 2): False, + (1, None): False, + (2, 0): False, + (2, 1): False, + (2, None): True, + (None, 0): False, + (None, 1): False, + (None, 2): True, + }, + ("rxx", "cswap"): { + (0, 1): False, + (0, 2): False, + (0, None): False, + (1, 0): False, + (1, 2): True, + (1, None): False, + (2, 0): False, + (2, 1): True, + (2, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + }, + ("rxx", "ccz"): False, + ("rxx", "rccx"): False, + ("rxx", "rcccx"): False, + ("rxx", "csdg"): False, + ("rxx", "swap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rxx", "iswap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rxx", "rxx"): True, + ("rxx", "ryy"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rxx", "rzz"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rxx", "rzx"): { + (0, 1): False, + (0, None): False, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): True, + }, + ("ryy", "c3sx"): False, + ("ryy", "ccx"): False, + ("ryy", "cswap"): { + (0, 1): False, + (0, 2): False, + (0, None): False, + (1, 0): False, + (1, 2): True, + (1, None): False, + (2, 0): False, + (2, 1): True, + (2, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + }, + ("ryy", "ccz"): False, + ("ryy", "rccx"): False, + ("ryy", "rcccx"): False, + ("ryy", "csdg"): False, + ("ryy", "swap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("ryy", "iswap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("ryy", "ryy"): True, + ("ryy", "rzz"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("ryy", "rzx"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rzz", "c3sx"): { + (0, 1): True, + (0, 2): True, + (0, 3): False, + (0, None): True, + (1, 0): True, + (1, 2): True, + (1, 3): False, + (1, None): True, + (2, 0): True, + (2, 1): True, + (2, 3): False, + (2, None): True, + (3, 0): False, + (3, 1): False, + (3, 2): False, + (3, None): False, + (None, 0): True, + (None, 1): True, + (None, 2): True, + (None, 3): False, + }, + ("rzz", "ccx"): { + (0, 1): True, + (0, 2): False, + (0, None): True, + (1, 0): True, + (1, 2): False, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, None): False, + (None, 0): True, + (None, 1): True, + (None, 2): False, + }, + ("rzz", "cswap"): { + (0, 1): False, + (0, 2): False, + (0, None): True, + (1, 0): False, + (1, 2): True, + (1, None): False, + (2, 0): False, + (2, 1): True, + (2, None): False, + (None, 0): True, + (None, 1): False, + (None, 2): False, + }, + ("rzz", "ccz"): True, + ("rzz", "rccx"): { + (0, 1): True, + (0, 2): False, + (0, None): True, + (1, 0): True, + (1, 2): False, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, None): False, + (None, 0): True, + (None, 1): True, + (None, 2): False, + }, + ("rzz", "rcccx"): { + (0, 1): True, + (0, 2): True, + (0, 3): False, + (0, None): True, + (1, 0): True, + (1, 2): True, + (1, 3): False, + (1, None): True, + (2, 0): True, + (2, 1): True, + (2, 3): False, + (2, None): True, + (3, 0): False, + (3, 1): False, + (3, 2): False, + (3, None): False, + (None, 0): True, + (None, 1): True, + (None, 2): True, + (None, 3): False, + }, + ("rzz", "csdg"): True, + ("rzz", "swap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rzz", "iswap"): { + (0, 1): True, + (0, None): False, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): False, + }, + ("rzz", "rzz"): True, + ("rzx", "c3sx"): { + (0, 1): False, + (0, 2): False, + (0, 3): True, + (0, None): True, + (1, 0): False, + (1, 2): False, + (1, 3): True, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, 3): True, + (2, None): True, + (3, 0): False, + (3, 1): False, + (3, 2): False, + (3, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + (None, 3): True, + }, + ("rzx", "ccx"): { + (0, 1): False, + (0, 2): True, + (0, None): True, + (1, 0): False, + (1, 2): True, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): True, + }, + ("rzx", "cswap"): { + (0, 1): False, + (0, 2): False, + (0, None): True, + (1, 0): False, + (1, 2): False, + (1, None): False, + (2, 0): False, + (2, 1): False, + (2, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + }, + ("rzx", "ccz"): { + (0, 1): False, + (0, 2): False, + (0, None): True, + (1, 0): False, + (1, 2): False, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, None): True, + (None, 0): False, + (None, 1): False, + (None, 2): False, + }, + ("rzx", "rccx"): { + (0, 1): False, + (0, 2): False, + (0, None): True, + (1, 0): False, + (1, 2): False, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + }, + ("rzx", "rcccx"): { + (0, 1): False, + (0, 2): False, + (0, 3): False, + (0, None): True, + (1, 0): False, + (1, 2): False, + (1, 3): False, + (1, None): True, + (2, 0): False, + (2, 1): False, + (2, 3): False, + (2, None): True, + (3, 0): False, + (3, 1): False, + (3, 2): False, + (3, None): False, + (None, 0): False, + (None, 1): False, + (None, 2): False, + (None, 3): False, + }, + ("rzx", "csdg"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("rzx", "swap"): False, + ("rzx", "iswap"): False, + ("rzx", "rzz"): { + (0, 1): False, + (0, None): True, + (1, 0): False, + (1, None): True, + (None, 0): False, + (None, 1): False, + }, + ("rzx", "rzx"): { + (0, 1): True, + (0, None): True, + (1, 0): True, + (1, None): False, + (None, 0): False, + (None, 1): True, + }, } diff --git a/qiskit/transpiler/passes/optimization/template_matching/backward_match.py b/qiskit/transpiler/passes/optimization/template_matching/backward_match.py index d194d1cbbddf..3869a078250a 100644 --- a/qiskit/transpiler/passes/optimization/template_matching/backward_match.py +++ b/qiskit/transpiler/passes/optimization/template_matching/backward_match.py @@ -242,7 +242,7 @@ def _is_same_op(self, node_circuit, node_template): Returns: bool: True if the same, False otherwise. """ - return node_circuit.op == node_template.op + return node_circuit.op.soft_compare(node_template.op) def _is_same_q_conf(self, node_circuit, node_template, qarg_circuit): """ diff --git a/releasenotes/notes/Parameterized-commutation-checker-8a78a4715bf78b4e.yaml b/releasenotes/notes/Parameterized-commutation-checker-8a78a4715bf78b4e.yaml new file mode 100644 index 000000000000..72bd9aefe726 --- /dev/null +++ b/releasenotes/notes/Parameterized-commutation-checker-8a78a4715bf78b4e.yaml @@ -0,0 +1,14 @@ +--- +features_circuits: + - | + Improved the functionality of :class:`.CommutationChecker` to include + support for the following parameterized gates with free parameters: + :class:`.RXXGate`,:class:`.RYYGate`,:class:`.RZZGate`,:class:`.RZXGate`, + :class:`.RXGate`,:class:`.RYGate`,:class:`.RZGate`,:class:`.PhaseGate`, + :class:`.U1Gate`,:class:`.CRXGate`,:class:`.CRYGate`,:class:`.CRZGate`, + :class:`.CPhaseGate`. + + Before these were only supported with bound parameters. + + + diff --git a/test/python/circuit/test_commutation_checker.py b/test/python/circuit/test_commutation_checker.py index a0aeae5ca2c3..e261389a3272 100644 --- a/test/python/circuit/test_commutation_checker.py +++ b/test/python/circuit/test_commutation_checker.py @@ -13,36 +13,47 @@ """Test commutation checker class .""" import unittest +from test import QiskitTestCase # pylint: disable=wrong-import-order import numpy as np +from ddt import data, ddt from qiskit import ClassicalRegister from qiskit.circuit import ( - QuantumRegister, - Parameter, - Qubit, AnnotatedOperation, - InverseModifier, ControlModifier, Gate, + InverseModifier, + Parameter, + QuantumRegister, + Qubit, ) from qiskit.circuit.commutation_library import SessionCommutationChecker as scc -from qiskit.dagcircuit import DAGOpNode from qiskit.circuit.library import ( - ZGate, - XGate, - CXGate, + Barrier, CCXGate, + CPhaseGate, + CRXGate, + CRYGate, + CRZGate, + CXGate, + LinearFunction, MCXGate, - RZGate, Measure, - Barrier, + PhaseGate, Reset, - LinearFunction, - SGate, + RXGate, RXXGate, + RYGate, + RYYGate, + RZGate, + RZXGate, + RZZGate, + SGate, + XGate, + ZGate, ) -from test import QiskitTestCase # pylint: disable=wrong-import-order +from qiskit.dagcircuit import DAGOpNode class NewGateCX(Gate): @@ -55,6 +66,7 @@ def to_matrix(self): return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex) +@ddt class TestCommutationChecker(QiskitTestCase): """Test CommutationChecker class.""" @@ -145,16 +157,10 @@ def test_caching_different_qubit_sets(self): scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], []) self.assertEqual(scc.num_cached_entries(), 1) - def test_cache_with_param_gates(self): + def test_zero_rotations(self): """Check commutativity between (non-parameterized) gates with parameters.""" - scc.clear_cached_commutations() - self.assertTrue(scc.commute(RZGate(0), [0], [], XGate(), [0], [])) - self.assertFalse(scc.commute(RZGate(np.pi / 2), [0], [], XGate(), [0], [])) - self.assertTrue(scc.commute(RZGate(np.pi / 2), [0], [], RZGate(0), [0], [])) - - self.assertFalse(scc.commute(RZGate(np.pi / 2), [1], [], XGate(), [1], [])) - self.assertEqual(scc.num_cached_entries(), 3) + self.assertTrue(scc.commute(XGate(), [0], [], RZGate(0), [0], [])) def test_gates_with_parameters(self): """Check commutativity between (non-parameterized) gates with parameters.""" @@ -172,6 +178,8 @@ def test_parameterized_gates(self): # gate that has parameters and is considered parameterized rz_gate_theta = RZGate(Parameter("Theta")) + rx_gate_theta = RXGate(Parameter("Theta")) + rxx_gate_theta = RXXGate(Parameter("Theta")) rz_gate_phi = RZGate(Parameter("Phi")) self.assertEqual(len(rz_gate_theta.params), 1) self.assertTrue(rz_gate_theta.is_parameterized()) @@ -193,7 +201,6 @@ def test_parameterized_gates(self): # We should detect that parameterized gates over disjoint qubit subsets commute self.assertTrue(scc.commute(rz_gate_theta, [0], [], rz_gate_phi, [1], [])) - # We should detect that parameterized gates over disjoint qubit subsets commute self.assertTrue(scc.commute(rz_gate_theta, [2], [], cx_gate, [1, 3], [])) # However, for now commutativity checker should return False when checking @@ -201,9 +208,14 @@ def test_parameterized_gates(self): # the two gates are over intersecting qubit subsets. # This check should be changed if commutativity checker is extended to # handle parameterized gates better. - self.assertFalse(scc.commute(rz_gate_theta, [0], [], cx_gate, [0, 1], [])) - - self.assertFalse(scc.commute(rz_gate_theta, [0], [], rz_gate, [0], [])) + self.assertFalse(scc.commute(rz_gate_theta, [1], [], cx_gate, [0, 1], [])) + self.assertTrue(scc.commute(rz_gate_theta, [0], [], rz_gate, [0], [])) + self.assertTrue(scc.commute(rz_gate_theta, [0], [], rz_gate_phi, [0], [])) + self.assertTrue(scc.commute(rxx_gate_theta, [0, 1], [], rx_gate_theta, [0], [])) + self.assertTrue(scc.commute(rxx_gate_theta, [0, 1], [], XGate(), [0], [])) + self.assertTrue(scc.commute(XGate(), [0], [], rxx_gate_theta, [0, 1], [])) + self.assertTrue(scc.commute(rx_gate_theta, [0], [], rxx_gate_theta, [0, 1], [])) + self.assertTrue(scc.commute(rz_gate_theta, [0], [], cx_gate, [0, 1], [])) def test_measure(self): """Check commutativity involving measures.""" @@ -354,6 +366,39 @@ def test_serialization(self): cc2.commute_nodes(dop1, dop2) self.assertEqual(cc2.num_cached_entries(), 1) + @data( + RXGate, + RYGate, + RZGate, + PhaseGate, + CRXGate, + CRYGate, + CRZGate, + CPhaseGate, + RXXGate, + RYYGate, + RZZGate, + RZXGate, + ) + def test_cutoff_angles(self, gate_cls): + """Check rotations with a small enough angle are cut off.""" + max_power = 30 + from qiskit.circuit.library import DCXGate + + generic_gate = DCXGate() # gate that does not commute with any rotation gate + + cutoff_angle = 1e-5 # this is the cutoff we use in the CommutationChecker + + for i in range(1, max_power + 1): + angle = 2 ** (-i) + gate = gate_cls(angle) + qargs = list(range(gate.num_qubits)) + + if angle < cutoff_angle: + self.assertTrue(scc.commute(generic_gate, [0, 1], [], gate, qargs, [])) + else: + self.assertFalse(scc.commute(generic_gate, [0, 1], [], gate, qargs, [])) + if __name__ == "__main__": unittest.main() diff --git a/test/python/transpiler/test_commutative_inverse_cancellation.py b/test/python/transpiler/test_commutative_inverse_cancellation.py index cd800a3bb46f..c84dfaea3071 100644 --- a/test/python/transpiler/test_commutative_inverse_cancellation.py +++ b/test/python/transpiler/test_commutative_inverse_cancellation.py @@ -13,15 +13,16 @@ """Test transpiler pass that cancels inverse gates while exploiting the commutation relations.""" import unittest +from test import QiskitTestCase # pylint: disable=wrong-import-order + import numpy as np -from ddt import ddt, data +from ddt import data, ddt from qiskit.circuit import Parameter, QuantumCircuit from qiskit.circuit.library import RZGate, UnitaryGate +from qiskit.quantum_info import Operator from qiskit.transpiler import PassManager from qiskit.transpiler.passes import CommutativeInverseCancellation -from qiskit.quantum_info import Operator -from test import QiskitTestCase # pylint: disable=wrong-import-order @ddt @@ -758,25 +759,26 @@ def test_no_cancellation_across_reset(self, matrix_based): self.assertEqual(circuit, new_circuit) @data(False, True) - def test_no_cancellation_across_parameterized_gates(self, matrix_based): - """Test that parameterized gates prevent cancellation. - This test should be modified when inverse and commutativity checking - get improved to handle parameterized gates. - """ + def test_cancellation_across_parameterized_gates(self, matrix_based): + """Test that parameterized gates do not prevent cancellation.""" + theta = Parameter("Theta") circuit = QuantumCircuit(1) circuit.rz(np.pi / 2, 0) - circuit.rz(Parameter("Theta"), 0) + circuit.rz(theta, 0) circuit.rz(-np.pi / 2, 0) + expected_circuit = QuantumCircuit(1) + expected_circuit.rz(theta, 0) + passmanager = PassManager(CommutativeInverseCancellation(matrix_based=matrix_based)) new_circuit = passmanager.run(circuit) - self.assertEqual(circuit, new_circuit) + self.assertEqual(expected_circuit, new_circuit) @data(False, True) def test_parameterized_gates_do_not_cancel(self, matrix_based): """Test that parameterized gates do not cancel. - This test should be modified when inverse and commutativity checking - get improved to handle parameterized gates. + This test should be modified when inverse checking + gets improved to handle parameterized gates. """ gate = RZGate(Parameter("Theta")) diff --git a/tools/build_standard_commutations.py b/tools/build_standard_commutations.py index 31f1fe03822b..2e13d741c93b 100644 --- a/tools/build_standard_commutations.py +++ b/tools/build_standard_commutations.py @@ -19,10 +19,19 @@ import itertools from functools import lru_cache from typing import List -from qiskit.circuit import Gate, CommutationChecker + import qiskit.circuit.library.standard_gates as stdg +from qiskit.circuit import CommutationChecker, Gate +from qiskit.circuit.library import PauliGate from qiskit.dagcircuit import DAGOpNode +SUPPORTED_ROTATIONS = { + "rxx": PauliGate("XX"), + "ryy": PauliGate("YY"), + "rzz": PauliGate("ZZ"), + "rzx": PauliGate("XZ"), +} + @lru_cache(maxsize=10**3) def _persistent_id(op_name: str) -> int: @@ -83,7 +92,6 @@ def _get_relative_placement(first_qargs, second_qargs) -> tuple: return tuple(qubits_g2.get(q_g0, None) for q_g0 in first_qargs) -@lru_cache(None) def _get_unparameterizable_gates() -> List[Gate]: """Retrieve a list of non-parmaterized gates with up to 3 qubits, using the python inspection module Return: @@ -95,6 +103,17 @@ def _get_unparameterizable_gates() -> List[Gate]: return [g for g in gates if len(g.params) == 0] +def _get_rotation_gates() -> List[Gate]: + """Retrieve a list of parmaterized gates we know the commutation relations of with up + to 3 qubits, using the python inspection module + Return: + A list of parameterized gates(that we know how to commute) to also be considered + in the commutation library + """ + gates = list(stdg.get_standard_gate_name_mapping().values()) + return [g for g in gates if g.name in SUPPORTED_ROTATIONS] + + def _generate_commutation_dict(considered_gates: List[Gate] = None) -> dict: """Compute the commutation relation of considered gates @@ -110,7 +129,11 @@ def _generate_commutation_dict(considered_gates: List[Gate] = None) -> dict: cc = CommutationChecker() for gate0 in considered_gates: - node0 = DAGOpNode(op=gate0, qargs=list(range(gate0.num_qubits)), cargs=[]) + node0 = DAGOpNode( + op=SUPPORTED_ROTATIONS.get(gate0.name, gate0), + qargs=list(range(gate0.num_qubits)), + cargs=[], + ) for gate1 in considered_gates: # only consider canonical entries @@ -143,12 +166,16 @@ def _generate_commutation_dict(considered_gates: List[Gate] = None) -> dict: gate1_qargs.append(next_non_overlapping_qubit_idx) next_non_overlapping_qubit_idx += 1 - node1 = DAGOpNode(op=gate1, qargs=gate1_qargs, cargs=[]) + node1 = DAGOpNode( + op=SUPPORTED_ROTATIONS.get(gate1.name, gate1), + qargs=gate1_qargs, + cargs=[], + ) # replace non-overlapping qubits with None to act as a key in the commutation library relative_placement = _get_relative_placement(node0.qargs, node1.qargs) - if not gate0.is_parameterized() and not gate1.is_parameterized(): + if not node0.op.is_parameterized() and not node1.op.is_parameterized(): # if no gate includes parameters, compute commutation relation using # matrix multiplication op1 = node0.op @@ -219,6 +246,7 @@ def _dump_commuting_dict_as_python( cgates = [ g for g in _get_unparameterizable_gates() if g.name not in ["reset", "measure", "delay"] ] + cgates += _get_rotation_gates() commutation_dict = _generate_commutation_dict(considered_gates=cgates) commutation_dict = _simplify_commuting_dict(commutation_dict) _dump_commuting_dict_as_python(commutation_dict)