diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index ff5c4ad3..126a4486 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -2,6 +2,7 @@ #![allow(unused)] pub mod convert; +pub mod cost; use derive_more::{From, Into}; use pyo3::prelude::*; @@ -14,14 +15,17 @@ use tket2::rewrite::CircuitRewrite; use tket_json_rs::circuit_json::SerialCircuit; pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit}; +pub use self::cost::PyCircuitCost; +pub use tket2::{Pauli, Tk2Op}; /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_circuit")?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(validate_hugr, m)?)?; m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?; diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index e63167d7..1980af87 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -1,19 +1,43 @@ //! Utilities for calling Hugr functions on generic python objects. -use pyo3::exceptions::PyAttributeError; +use hugr::ops::OpType; +use pyo3::exceptions::{PyAttributeError, PyValueError}; use pyo3::{prelude::*, PyTypeInfo}; use derive_more::From; use hugr::{Hugr, HugrView}; use serde::Serialize; +use tket2::circuit::CircuitHash; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; +use tket2::{Circuit, Tk2Op}; use tket_json_rs::circuit_json::SerialCircuit; use crate::rewrite::PyCircuitRewrite; -/// A manager for tket 2 operations on a tket 1 Circuit. +use super::{cost, PyCircuitCost}; + +/// A circuit in tket2 format. +/// +/// This can be freely converted to and from a `pytket.Circuit`. Prefer using +/// this class when applying multiple tket2 operations on a circuit, as it +/// avoids the overhead of converting to and from a `pytket.Circuit` each time. +/// +/// Node indices returned by this class are not stable across conversion to and +/// from a `pytket.Circuit`. +/// +/// # Examples +/// +/// Convert between `pytket.Circuit`s and `Tk2Circuit`s: +/// ```python +/// from pytket import Circuit +/// c = Circuit(2).H(0).CX(0, 1) +/// # Convert to a Tk2Circuit +/// t2c = Tk2Circuit(c) +/// # Convert back to a pytket.Circuit +/// c2 = t2c.to_tket1() +/// ``` #[pyclass] #[derive(Clone, Debug, PartialEq, From)] pub struct Tk2Circuit { @@ -37,7 +61,7 @@ impl Tk2Circuit { } /// Apply a rewrite on the circuit. - pub fn apply_match(&mut self, rw: PyCircuitRewrite) { + pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) { rw.rewrite.apply(&mut self.hugr).expect("Apply error."); } @@ -73,6 +97,50 @@ impl Tk2Circuit { hugr: tk1.decode()?, }) } + + /// Compute the cost of the circuit based on a per-operation cost function. + /// + /// :param cost_fn: A function that takes a `Tk2Op` and returns an arbitrary cost. + /// The cost must implement `__add__`, `__sub__`, `__lt__`, + /// `__eq__`, `__int__`, and integer `__div__`. + /// + /// :returns: The sum of all operation costs. + pub fn circuit_cost<'py>(&self, cost_fn: &'py PyAny) -> PyResult<&'py PyAny> { + let py = cost_fn.py(); + let cost_fn = |op: &OpType| -> PyResult { + let tk2_op: Tk2Op = op.try_into().map_err(|e| { + PyErr::new::(format!( + "Could not convert circuit operation to a `Tk2Op`: {e}" + )) + })?; + let cost = cost_fn.call1((tk2_op,))?; + Ok(PyCircuitCost { + cost: cost.to_object(py), + }) + }; + let circ_cost = self.hugr.circuit_cost(cost_fn)?; + Ok(circ_cost.cost.into_ref(py)) + } + + /// Returns a hash of the circuit. + pub fn hash(&self) -> u64 { + self.hugr.circuit_hash().unwrap() + } + + /// Hash the circuit + pub fn __hash__(&self) -> isize { + self.hash() as isize + } + + /// Copy the circuit. + pub fn __copy__(&self) -> PyResult { + Ok(self.clone()) + } + + /// Copy the circuit. + pub fn __deepcopy__(&self, _memo: Py) -> PyResult { + Ok(self.clone()) + } } impl Tk2Circuit { /// Tries to extract a Tk2Circuit from a python object. diff --git a/tket2-py/src/circuit/cost.rs b/tket2-py/src/circuit/cost.rs new file mode 100644 index 00000000..cbd734cf --- /dev/null +++ b/tket2-py/src/circuit/cost.rs @@ -0,0 +1,167 @@ +//! + +use std::cmp::Ordering; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Sub}; + +use pyo3::{prelude::*, PyTypeInfo}; +use tket2::circuit::cost::{CircuitCost, CostDelta}; + +/// A generic circuit cost, backed by an arbitrary python object. +#[pyclass] +#[derive(Clone, Debug)] +#[pyo3(name = "CircuitCost")] +pub struct PyCircuitCost { + /// Generic python cost object. + pub cost: PyObject, +} + +#[pymethods] +impl PyCircuitCost { + /// Create a new circuit cost. + #[new] + pub fn new(cost: PyObject) -> Self { + Self { cost } + } +} + +impl Default for PyCircuitCost { + fn default() -> Self { + Python::with_gil(|py| PyCircuitCost { cost: py.None() }) + } +} + +impl Add for PyCircuitCost { + type Output = PyCircuitCost; + + fn add(self, rhs: PyCircuitCost) -> Self::Output { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__add__", (rhs.cost,)) + .expect("Could not add circuit cost objects."); + PyCircuitCost { cost } + }) + } +} + +impl AddAssign for PyCircuitCost { + fn add_assign(&mut self, rhs: Self) { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__add__", (rhs.cost,)) + .expect("Could not add circuit cost objects."); + self.cost = cost; + }) + } +} + +impl Sub for PyCircuitCost { + type Output = PyCircuitCost; + + fn sub(self, rhs: PyCircuitCost) -> Self::Output { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__sub__", (rhs.cost,)) + .expect("Could not subtract circuit cost objects."); + PyCircuitCost { cost } + }) + } +} + +impl Sum for PyCircuitCost { + fn sum>(iter: I) -> Self { + Python::with_gil(|py| { + let cost = iter + .fold(None, |acc: Option, c| { + Some(match acc { + None => c.cost, + Some(cost) => cost + .call_method1(py, "__add__", (c.cost,)) + .expect("Could not add circuit cost objects."), + }) + }) + .unwrap_or_else(|| py.None()); + PyCircuitCost { cost } + }) + } +} + +impl PartialEq for PyCircuitCost { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + let res = self + .cost + .call_method1(py, "__eq__", (&other.cost,)) + .expect("Could not compare circuit cost objects."); + res.is_true(py) + .expect("Could not compare circuit cost objects.") + }) + } +} + +impl Eq for PyCircuitCost {} + +impl PartialOrd for PyCircuitCost { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PyCircuitCost { + fn cmp(&self, other: &Self) -> Ordering { + Python::with_gil(|py| -> PyResult { + let res = self.cost.call_method1(py, "__lt__", (&other.cost,))?; + if res.is_true(py)? { + return Ok(Ordering::Less); + } + let res = self.cost.call_method1(py, "__eq__", (&other.cost,))?; + if res.is_true(py)? { + return Ok(Ordering::Equal); + } + Ok(Ordering::Greater) + }) + .expect("Could not compare circuit cost objects.") + } +} + +impl CostDelta for PyCircuitCost { + fn as_isize(&self) -> isize { + Python::with_gil(|py| { + let res = self + .cost + .call_method0(py, "__int__") + .expect("Could not convert the circuit cost object to an integer."); + res.extract(py) + .expect("Could not convert the circuit cost object to an integer.") + }) + } +} + +impl CircuitCost for PyCircuitCost { + type CostDelta = PyCircuitCost; + + fn as_usize(&self) -> usize { + self.as_isize() as usize + } + + fn sub_cost(&self, other: &Self) -> Self::CostDelta { + self.clone() - other.clone() + } + + fn add_delta(&self, delta: &Self::CostDelta) -> Self { + self.clone() + delta.clone() + } + + fn div_cost(&self, n: std::num::NonZeroUsize) -> Self { + Python::with_gil(|py| { + let res = self + .cost + .call_method0(py, "__div__") + .expect("Could not divide the circuit cost object."); + Self { cost: res } + }) + } +} diff --git a/tket2-py/test/test_bindings.py b/tket2-py/test/test_bindings.py deleted file mode 100644 index 28aa8c0c..00000000 --- a/tket2-py/test/test_bindings.py +++ /dev/null @@ -1,93 +0,0 @@ -from dataclasses import dataclass -from pytket.circuit import Circuit - -from tket2 import passes -from tket2.passes import greedy_depth_reduce -from tket2.circuit import Tk2Circuit, to_hugr_dot -from tket2.pattern import Rule, RuleMatcher - - -def test_conversion(): - tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) - tk1_dot = to_hugr_dot(tk1) - - tk2 = Tk2Circuit(tk1) - tk2_dot = to_hugr_dot(tk2) - - assert type(tk2) == Tk2Circuit - assert tk1_dot == tk2_dot - - tk1_back = tk2.to_tket1() - - assert tk1_back == tk1 - assert type(tk1_back) == Circuit - - -@dataclass -class DepthOptimisePass: - def apply(self, circ: Circuit) -> Circuit: - (circ, n_moves) = greedy_depth_reduce(circ) - return circ - - -def test_depth_optimise(): - c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) - - assert c.depth() == 3 - - c = DepthOptimisePass().apply(c) - - assert c.depth() == 2 - - -def test_chunks(): - c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3) - - assert c.depth() == 3 - - chunks = passes.chunks(c, 2) - circuits = chunks.circuits() - chunks.update_circuit(0, circuits[0]) - c2 = chunks.reassemble() - - assert c2.depth() == 3 - assert type(c2) == Circuit - - # Split and reassemble, with a tket2 circuit - tk2_chunks = passes.chunks(Tk2Circuit(c2), 2) - tk2 = tk2_chunks.reassemble() - - assert type(tk2) == Tk2Circuit - - -def test_cx_rule(): - c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) - - rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2)) - matcher = RuleMatcher([rule]) - - mtch = matcher.find_match(c) - - c.apply_match(mtch) - - out = c.to_tket1() - - assert out == Circuit(4).CX(0, 2) - - -def test_multiple_rules(): - circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) - - rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0)) - rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1)) - matcher = RuleMatcher([rule1, rule2]) - - match_count = 0 - while match := matcher.find_match(circ): - match_count += 1 - circ.apply_match(match) - - assert match_count == 3 - - out = circ.to_tket1() - assert out == Circuit(3).CX(0, 1).X(0) diff --git a/tket2-py/test/test_circuit.py b/tket2-py/test/test_circuit.py new file mode 100644 index 00000000..e7bbb306 --- /dev/null +++ b/tket2-py/test/test_circuit.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from pytket.circuit import Circuit + +from tket2.circuit import Tk2Circuit, Tk2Op, to_hugr_dot + + +@dataclass +class CustomCost: + gate_count: int + h_count: int + + def __add__(self, other): + return CustomCost( + self.gate_count + other.gate_count, self.h_count + other.h_count + ) + + +def test_cost(): + circ = Tk2Circuit(Circuit(4).CX(0, 1).H(1).CX(1, 2).CX(0, 3).H(0)) + + print(circ.circuit_cost(lambda op: int(op == Tk2Op.CX))) + + assert circ.circuit_cost(lambda op: int(op == Tk2Op.CX)) == 3 + assert circ.circuit_cost(lambda op: CustomCost(1, op == Tk2Op.H)) == CustomCost( + 5, 2 + ) + + +def test_hash(): + circA = Tk2Circuit(Circuit(4).CX(0, 1).CX(1, 2).CX(0, 3)) + circB = Tk2Circuit(Circuit(4).CX(1, 2).CX(0, 1).CX(0, 3)) + circC = Tk2Circuit(Circuit(4).CX(0, 1).CX(0, 3).CX(1, 2)) + + assert hash(circA) != hash(circB) + assert hash(circA) == hash(circC) + + +def test_conversion(): + tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + tk1_dot = to_hugr_dot(tk1) + + tk2 = Tk2Circuit(tk1) + tk2_dot = to_hugr_dot(tk2) + + assert type(tk2) == Tk2Circuit + assert tk1_dot == tk2_dot + + tk1_back = tk2.to_tket1() + + assert tk1_back == tk1 + assert type(tk1_back) == Circuit diff --git a/tket2-py/test/test_pass.py b/tket2-py/test/test_pass.py index b0f1284b..366f34ff 100644 --- a/tket2-py/test/test_pass.py +++ b/tket2-py/test/test_pass.py @@ -1,5 +1,9 @@ from pytket import Circuit, OpType -from tket2.passes import badger_pass +from dataclasses import dataclass + +from tket2.passes import badger_pass, greedy_depth_reduce, chunks +from tket2.circuit import Tk2Circuit +from tket2.pattern import Rule, RuleMatcher def test_simple_badger_pass_no_opt(): @@ -7,3 +11,73 @@ def test_simple_badger_pass_no_opt(): badger = badger_pass(max_threads=1, timeout=0) badger.apply(c) assert c.n_gates_of_type(OpType.CX) == 6 + + +@dataclass +class DepthOptimisePass: + def apply(self, circ: Circuit) -> Circuit: + (circ, n_moves) = greedy_depth_reduce(circ) + return circ + + +def test_depth_optimise(): + c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + + assert c.depth() == 3 + + c = DepthOptimisePass().apply(c) + + assert c.depth() == 2 + + +def test_chunks(): + c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3) + + assert c.depth() == 3 + + circ_chunks = chunks(c, 2) + circuits = circ_chunks.circuits() + circ_chunks.update_circuit(0, circuits[0]) + c2 = circ_chunks.reassemble() + + assert c2.depth() == 3 + assert type(c2) == Circuit + + # Split and reassemble, with a tket2 circuit + tk2_chunks = chunks(Tk2Circuit(c2), 2) + tk2 = tk2_chunks.reassemble() + + assert type(tk2) == Tk2Circuit + + +def test_cx_rule(): + c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) + + rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2)) + matcher = RuleMatcher([rule]) + + mtch = matcher.find_match(c) + + c.apply_rewrite(mtch) + + out = c.to_tket1() + + assert out == Circuit(4).CX(0, 2) + + +def test_multiple_rules(): + circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) + + rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0)) + rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1)) + matcher = RuleMatcher([rule1, rule2]) + + match_count = 0 + while match := matcher.find_match(circ): + match_count += 1 + circ.apply_rewrite(match) + + assert match_count == 3 + + out = circ.to_tket1() + assert out == Circuit(3).CX(0, 1).X(0) diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 31f95671..6dd3e346 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -5,6 +5,8 @@ pub mod cost; mod hash; pub mod units; +use std::iter::Sum; + pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; use itertools::Either::{Left, Right}; @@ -25,7 +27,6 @@ pub use hugr::ops::OpType; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; -use self::cost::CircuitCost; use self::units::{filter, FilteredUnits, Units}; /// An object behaving like a quantum circuit. @@ -135,7 +136,7 @@ pub trait Circuit: HugrView { fn circuit_cost(&self, op_cost: F) -> C where Self: Sized, - C: CircuitCost, + C: Sum, F: Fn(&OpType) -> C, { self.commands().map(|cmd| op_cost(cmd.optype())).sum() @@ -146,7 +147,7 @@ pub trait Circuit: HugrView { #[inline] fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C where - C: CircuitCost, + C: Sum, F: Fn(&OpType) -> C, { nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum()