diff --git a/README.md b/README.md index 7fa7ec36..769c84c0 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,13 @@ Version 2 of the TKET compiler. ## Features - `pyo3` -This optional feature enables some python bindings via pyo3. See the `tket2-py` folder for more. + Enables some python bindings via pyo3. See the `tket2-py` folder for more. - `portmatching` - This enables pattern matching using the `portmatching` crate. + Enables pattern matching using the `portmatching` crate. + +- `rewrite-tracing` + Adds opt-in tracking of the rewrites applied to a circuit. ## Developing TKET2 diff --git a/badger-optimiser/Cargo.toml b/badger-optimiser/Cargo.toml index 45a50b44..b60ff20a 100644 --- a/badger-optimiser/Cargo.toml +++ b/badger-optimiser/Cargo.toml @@ -11,7 +11,7 @@ license-file = { workspace = true } [dependencies] clap = { version = "4.4.2", features = ["derive"] } serde_json = "1.0" -tket2 = { workspace = true, features = ["portmatching"] } +tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] } quantinuum-hugr = { workspace = true } itertools = { workspace = true } tket-json-rs = { workspace = true } diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index 861f94af..b87ecaef 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -15,6 +15,7 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file}; use tket2::optimiser::badger::log::BadgerLogger; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}; +use tket2::rewrite::trace::RewriteTracer; #[cfg(feature = "peak_alloc")] #[global_allocator] @@ -104,6 +105,12 @@ struct CmdLineArgs { help = "The priority queue size. Defaults to 100." )] queue_size: usize, + /// Trace each rewrite applied to the circuit. + #[arg( + long = "rewrite-tracing", + help = "Trace each rewrite applied to the circuit. Prints statistics for the best circuit at the end of the optimisation." 
+ )] + rewrite_tracing: bool, } fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -129,7 +136,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { let badger_logger = BadgerLogger::new(circ_candidates_csv); - let circ = load_tk1_json_file(input_path)?; + let mut circ = load_tk1_json_file(input_path)?; + if opts.rewrite_tracing { + circ.enable_rewrite_tracing(); + } println!("Loading optimiser..."); let Ok(optimiser) = load_optimiser(ecc_path) else { diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 31f17864..fdc0189c 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -15,9 +15,17 @@ name = "tket2" path = "src/lib.rs" [features] +# Enables some python bindings pyo3 = ["dep:pyo3"] + +# Enables search and replace optimisation passes using the `portmatching` crate. portmatching = ["dep:portmatching", "dep:rmp-serde"] +# Stores a trace of the applied rewrites +rewrite-tracing = [] + +default = [] + [dependencies] lazy_static = "1.4.0" cgmath = "0.18.0" @@ -44,7 +52,7 @@ strum_macros = "0.25.2" strum = "0.25.0" fxhash = "0.2.1" rmp-serde = { version = "1.1.2", optional = true } -delegate = "0.10.0" +delegate = "0.11.0" csv = { version = "1.2.2" } chrono = { version = "0.4.30" } bytemuck = "1.14.0" diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 687d2813..789237b9 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -37,6 +37,7 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::badger::worker::BadgerWorker; use crate::passes::CircuitChunks; use crate::rewrite::strategy::RewriteStrategy; +use crate::rewrite::trace::RewriteTracer; use crate::rewrite::Rewriter; use crate::Circuit; @@ -158,7 +159,8 @@ where let mut best_circ = circ.clone(); let mut best_circ_cost = self.cost(circ); - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); // Hash of seen circuits. 
Dot not store circuits as this map gets huge let hash = circ.circuit_hash().unwrap(); @@ -181,7 +183,8 @@ where if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost.clone(); - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); last_best_time = Instant::now(); } circ_cnt += 1; @@ -297,7 +300,8 @@ where if cost < best_circ_cost { best_circ = circ; best_circ_cost = cost; - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); if let Some(t) = opt.progress_timeout { progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t)); } @@ -337,7 +341,8 @@ where if cost < best_circ_cost { best_circ = circ; best_circ_cost = cost; - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); } } PriorityChannelLog::CircuitCount { @@ -381,7 +386,8 @@ where let mut chunks = CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op)); - logger.log_best(circ_cost.clone()); + let num_rewrites = circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(circ_cost.clone(), num_rewrites); let (joins, rx_work): (Vec<_>, Vec<_>) = chunks .iter_mut() @@ -420,7 +426,8 @@ where let best_circ = chunks.reassemble()?; let best_circ_cost = self.cost(&best_circ); if best_circ_cost.clone() < circ_cost { - logger.log_best(best_circ_cost.clone()); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(best_circ_cost.clone(), num_rewrites); } logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false); diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index 3c01fdf6..e0fb5eea 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs 
@@ -49,8 +49,17 @@ impl<'w> BadgerLogger<'w> { /// Log a new best candidate #[inline] - pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) { - self.log(format!("new best of size {:?}", best_cost)); + pub fn log_best<C: Debug + serde::Serialize>( + &mut self, + best_cost: C, + num_rewrites: Option<usize>, + ) { + match num_rewrites { + Some(rs) => self.log(format!( + "new best of size {best_cost:?} after {rs} rewrites" + )), + None => self.log(format!("new best of size {:?}", best_cost)), + } if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); csv_writer.flush().unwrap(); diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index c8e44e85..45dfafec 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -3,12 +3,12 @@ #[cfg(feature = "portmatching")] pub mod ecc_rewriter; pub mod strategy; +pub mod trace; use bytemuck::TransparentWrapper; #[cfg(feature = "portmatching")] pub use ecc_rewriter::ECCRewriter; -use delegate::delegate; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph}; use hugr::Node; @@ -19,6 +19,8 @@ use hugr::{ use crate::circuit::Circuit; +use self::trace::RewriteTracer; + /// A subcircuit of a circuit. #[derive(Debug, Clone, From, Into)] #[repr(transparent)] @@ -107,11 +109,17 @@ impl CircuitRewrite { self.0.invalidation_set() } - delegate! { - to self.0 { - /// Apply the rewrite rule to a circuit. - pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError>; - } + /// Apply the rewrite rule to a circuit. + #[inline] + pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + circ.add_rewrite_trace(&self); + self.0.apply(circ) + } + + /// Apply the rewrite rule to a circuit, without registering it in the rewrite trace. 
+ #[inline] + pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + self.0.apply(circ) } } diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 456e21bf..899b7909 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -28,6 +28,7 @@ use itertools::Itertools; use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, MajorMinorCost}; use crate::Circuit; +use super::trace::{RewriteTrace, RewriteTracer}; use super::CircuitRewrite; /// Rewriting strategies for circuit optimisation. @@ -219,6 +220,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> { let mut curr_circ = circ.clone(); let mut changed_nodes = HashSet::new(); let mut cost_delta = Default::default(); + let mut composed_rewrite_count = 0; for (rewrite, delta) in &rewrites[i..] { if !changed_nodes.is_empty() && rewrite @@ -230,11 +232,15 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> { changed_nodes.extend(rewrite.invalidation_set()); cost_delta += delta.clone(); + composed_rewrite_count += 1; + rewrite .clone() - .apply(&mut curr_circ) + .apply_notrace(&mut curr_circ) .expect("Could not perform rewrite in exhaustive greedy strategy"); } + + curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count)); rewrite_sets.circs.push(curr_circ); rewrite_sets.cost_deltas.push(cost_delta); } @@ -462,6 +468,7 @@ mod tests { use hugr::{Hugr, Node}; use itertools::Itertools; + use crate::rewrite::trace::REWRITE_TRACING_ENABLED; use crate::{ circuit::Circuit, rewrite::{CircuitRewrite, Subcircuit}, @@ -494,9 +501,16 @@ mod tests { #[test] fn test_greedy_strategy() { - let circ = n_cx(10); + let mut circ = n_cx(10); let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); + assert_eq!(circ.rewrite_trace(), None); + circ.enable_rewrite_tracing(); + match REWRITE_TRACING_ENABLED { + true => assert_eq!(circ.rewrite_trace(), Some(vec![])), + false => 
assert_eq!(circ.rewrite_trace(), None), + } + let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), rw_to_full(&circ, cx_gates[4..7].to_vec()), @@ -508,12 +522,17 @@ mod tests { let rewritten = strategy.apply_rewrites(rws, &circ); assert_eq!(rewritten.len(), 1); assert_eq!(rewritten.circs[0].num_gates(), 5); + + if REWRITE_TRACING_ENABLED { + assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3); + } } #[test] fn test_exhaustive_default_strategy() { - let circ = n_cx(10); + let mut circ = n_cx(10); let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); + circ.enable_rewrite_tracing(); let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), @@ -527,6 +546,23 @@ mod tests { let exp_circ_lens = HashSet::from_iter([3, 7, 9]); let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); + + if REWRITE_TRACING_ENABLED { + // Each strategy branch applies a single rewrite, composed of + // multiple individual elements from `rws`. + assert_eq!( + rewritten.circs[0].rewrite_trace().unwrap(), + vec![RewriteTrace::new(3)] + ); + assert_eq!( + rewritten.circs[1].rewrite_trace().unwrap(), + vec![RewriteTrace::new(2)] + ); + assert_eq!( + rewritten.circs[2].rewrite_trace().unwrap(), + vec![RewriteTrace::new(1)] + ); + } } #[test] diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs new file mode 100644 index 00000000..4d0a4770 --- /dev/null +++ b/tket2/src/rewrite/trace.rs @@ -0,0 +1,128 @@ +//! Utilities for tracing the rewrites applied to a circuit. +//! +//! This is only tracked if the `rewrite-tracing` feature is enabled. + +use hugr::hugr::hugrmut::HugrMut; +use hugr::hugr::NodeMetadata; +use itertools::Itertools; + +use crate::Circuit; + +use super::CircuitRewrite; + +/// Metadata key for the circuit rewrite trace. +pub const METADATA_REWRITES: &str = "TKET2.rewrites"; + +/// Global read-only flag for enabling rewrite tracing. 
+/// Enable it by setting the `rewrite-tracing` feature. +/// +/// Note that circuits must be explicitly enabled for rewrite tracing by calling +/// [`RewriteTracer::enable_rewrite_tracing`]. +pub const REWRITE_TRACING_ENABLED: bool = cfg!(feature = "rewrite-tracing"); + +/// The trace of a rewrite applied to a circuit. +/// +/// Traces are only enabled if the `rewrite-tracing` feature is enabled and +/// [`RewriteTracer::enable_rewrite_tracing`] is called on the circuit. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RewriteTrace { + /// A count of the number of individual patterns matched for this rewrite step. + /// + /// This is relevant when using a greedy rewrite strategy. + individual_matches: u16, +} + +impl From<&CircuitRewrite> for RewriteTrace { + #[inline] + fn from(_rewrite: &CircuitRewrite) -> Self { + // NOTE: We don't currently track any actual information about the rewrite. + Self { + individual_matches: 1, + } + } +} + +impl RewriteTrace { + /// Create a new trace. + #[inline] + pub fn new(individual_matches: u16) -> Self { + Self { individual_matches } + } +} + +impl From<&serde_json::Value> for RewriteTrace { + #[inline] + fn from(value: &serde_json::Value) -> Self { + Self { + individual_matches: value.as_u64().unwrap() as u16, + } + } +} + +impl From<RewriteTrace> for serde_json::Value { + #[inline] + fn from(trace: RewriteTrace) -> Self { + serde_json::Value::from(trace.individual_matches) + } +} + +/// Extension trait for circuits that can trace rewrites applied to them. +/// +/// This is only tracked if the `rewrite-tracing` feature is enabled and +/// `enable_rewrite_tracing` is called on the circuit before. +pub trait RewriteTracer: Circuit + HugrMut + Sized { + /// Enable rewrite tracing for the circuit. 
+ #[inline] + fn enable_rewrite_tracing(&mut self) { + if !REWRITE_TRACING_ENABLED { + return; + } + let meta = self + .get_metadata_mut(self.root(), METADATA_REWRITES) + .unwrap(); + if *meta == NodeMetadata::Null { + *meta = NodeMetadata::Array(vec![]); + } + } + + /// Register a rewrite applied to the circuit. + /// + /// Returns `true` if the rewrite was successfully registered, or `false` if it was ignored. + #[inline] + fn add_rewrite_trace(&mut self, rewrite: impl Into<RewriteTrace>) -> bool { + if !REWRITE_TRACING_ENABLED { + return false; + } + match self + .get_metadata_mut(self.root(), METADATA_REWRITES) + .ok() + .and_then(|m| m.as_array_mut()) + { + Some(meta) => { + let rewrite = rewrite.into(); + meta.push(rewrite.into()); + true + } + // Tracing was not enabled for this circuit. + None => false, + } + } + + /// Returns the traces of rewrites applied to the circuit. + /// + /// Returns `None` if rewrite tracing is not enabled for this circuit. + // + // TODO return an `impl Iterator` once rust 1.75 lands. + #[inline] + fn rewrite_trace(&self) -> Option<Vec<RewriteTrace>> { + if !REWRITE_TRACING_ENABLED { + return None; + } + let meta = self.get_metadata(self.root(), METADATA_REWRITES)?; + let rewrites = meta.as_array()?; + Some(rewrites.iter().map_into().collect_vec()) + } +} + +impl<T: Circuit + HugrMut> RewriteTracer for T {}