Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rewrite tracing #267

Merged
merged 5 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ Version 2 of the TKET compiler.
## Features

- `pyo3`
This optional feature enables some python bindings via pyo3. See the `tket2-py` folder for more.
Enables some python bindings via pyo3. See the `tket2-py` folder for more.

- `portmatching`
This enables pattern matching using the `portmatching` crate.
Enables pattern matching using the `portmatching` crate.

- `rewrite-tracing`
Adds opt-in tracking of the rewrites applied to a circuit.

## Developing TKET2

Expand Down
2 changes: 1 addition & 1 deletion badger-optimiser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license-file = { workspace = true }
[dependencies]
clap = { version = "4.4.2", features = ["derive"] }
serde_json = "1.0"
tket2 = { workspace = true, features = ["portmatching"] }
tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] }
quantinuum-hugr = { workspace = true }
itertools = { workspace = true }
tket-json-rs = { workspace = true }
Expand Down
12 changes: 11 additions & 1 deletion badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file};
use tket2::optimiser::badger::log::BadgerLogger;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser};
use tket2::rewrite::trace::RewriteTracer;

#[cfg(feature = "peak_alloc")]
#[global_allocator]
Expand Down Expand Up @@ -104,6 +105,12 @@ struct CmdLineArgs {
help = "The priority queue size. Defaults to 100."
)]
queue_size: usize,
/// Trace each rewrite applied to the circuit.
#[arg(
long = "rewrite-tracing",
help = "Trace each rewrite applied to the circuit. Prints statistics for the best circuit at the end of the optimisation."
)]
rewrite_tracing: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -129,7 +136,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let badger_logger = BadgerLogger::new(circ_candidates_csv);

let circ = load_tk1_json_file(input_path)?;
let mut circ = load_tk1_json_file(input_path)?;
if opts.rewrite_tracing {
circ.enable_rewrite_tracing();
}

println!("Loading optimiser...");
let Ok(optimiser) = load_optimiser(ecc_path) else {
Expand Down
10 changes: 9 additions & 1 deletion tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ name = "tket2"
path = "src/lib.rs"

[features]
# Enables some python bindings
pyo3 = ["dep:pyo3"]

# Enables search and replace optimisation passes using the `portmatching` crate.
portmatching = ["dep:portmatching", "dep:rmp-serde"]

# Stores a trace of the applied rewrites
rewrite-tracing = []

default = []

[dependencies]
lazy_static = "1.4.0"
cgmath = "0.18.0"
Expand All @@ -44,7 +52,7 @@ strum_macros = "0.25.2"
strum = "0.25.0"
fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"
delegate = "0.11.0"
csv = { version = "1.2.2" }
chrono = { version = "0.4.30" }
bytemuck = "1.14.0"
Expand Down
19 changes: 13 additions & 6 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::badger::worker::BadgerWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::trace::RewriteTracer;
use crate::rewrite::Rewriter;
use crate::Circuit;

Expand Down Expand Up @@ -158,7 +159,8 @@ where

let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(circ);
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let hash = circ.circuit_hash().unwrap();
Expand All @@ -181,7 +183,8 @@ where
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost.clone();
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
last_best_time = Instant::now();
}
circ_cnt += 1;
Expand Down Expand Up @@ -297,7 +300,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
if let Some(t) = opt.progress_timeout {
progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t));
}
Expand Down Expand Up @@ -337,7 +341,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
}
}
PriorityChannelLog::CircuitCount {
Expand Down Expand Up @@ -381,7 +386,8 @@ where
let mut chunks =
CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op));

logger.log_best(circ_cost.clone());
let num_rewrites = circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(circ_cost.clone(), num_rewrites);

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
Expand Down Expand Up @@ -420,7 +426,8 @@ where
let best_circ = chunks.reassemble()?;
let best_circ_cost = self.cost(&best_circ);
if best_circ_cost.clone() < circ_cost {
logger.log_best(best_circ_cost.clone());
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(best_circ_cost.clone(), num_rewrites);
}

logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false);
Expand Down
13 changes: 11 additions & 2 deletions tket2/src/optimiser/badger/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,17 @@ impl<'w> BadgerLogger<'w> {

/// Log a new best candidate
#[inline]
pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) {
self.log(format!("new best of size {:?}", best_cost));
pub fn log_best<C: Debug + serde::Serialize>(
&mut self,
best_cost: C,
num_rewrites: Option<usize>,
) {
match num_rewrites {
Some(rs) => self.log(format!(
"new best of size {best_cost:?} after {rs} rewrites"
)),
None => self.log(format!("new best of size {:?}", best_cost)),
}
if let Some(csv_writer) = self.circ_candidates_csv.as_mut() {
csv_writer.serialize(BestCircSer::new(best_cost)).unwrap();
csv_writer.flush().unwrap();
Expand Down
20 changes: 14 additions & 6 deletions tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#[cfg(feature = "portmatching")]
pub mod ecc_rewriter;
pub mod strategy;
pub mod trace;

use bytemuck::TransparentWrapper;
#[cfg(feature = "portmatching")]
pub use ecc_rewriter::ECCRewriter;

use delegate::delegate;
use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph};
use hugr::Node;
Expand All @@ -19,6 +19,8 @@ use hugr::{

use crate::circuit::Circuit;

use self::trace::RewriteTracer;

/// A subcircuit of a circuit.
#[derive(Debug, Clone, From, Into)]
#[repr(transparent)]
Expand Down Expand Up @@ -107,11 +109,17 @@ impl CircuitRewrite {
self.0.invalidation_set()
}

delegate! {
to self.0 {
/// Apply the rewrite rule to a circuit.
pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError>;
}
/// Apply the rewrite rule to a circuit.
#[inline]
pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
circ.add_rewrite_trace(&self);
self.0.apply(circ)
}

/// Apply the rewrite rule to a circuit, without registering it in the rewrite trace.
#[inline]
pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
self.0.apply(circ)
}
}

Expand Down
42 changes: 39 additions & 3 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use itertools::Itertools;
use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, MajorMinorCost};
use crate::Circuit;

use super::trace::{RewriteTrace, RewriteTracer};
use super::CircuitRewrite;

/// Rewriting strategies for circuit optimisation.
Expand Down Expand Up @@ -219,6 +220,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
let mut composed_rewrite_count = 0;
for (rewrite, delta) in &rewrites[i..] {
if !changed_nodes.is_empty()
&& rewrite
Expand All @@ -230,11 +232,15 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
changed_nodes.extend(rewrite.invalidation_set());
cost_delta += delta.clone();

composed_rewrite_count += 1;

rewrite
.clone()
.apply(&mut curr_circ)
.apply_notrace(&mut curr_circ)
.expect("Could not perform rewrite in exhaustive greedy strategy");
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
Expand Down Expand Up @@ -462,6 +468,7 @@ mod tests {
use hugr::{Hugr, Node};
use itertools::Itertools;

use crate::rewrite::trace::REWRITE_TRACING_ENABLED;
use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
Expand Down Expand Up @@ -494,9 +501,16 @@ mod tests {

#[test]
fn test_greedy_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();

assert_eq!(circ.rewrite_trace(), None);
circ.enable_rewrite_tracing();
match REWRITE_TRACING_ENABLED {
true => assert_eq!(circ.rewrite_trace(), Some(vec![])),
false => assert_eq!(circ.rewrite_trace(), None),
}

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
Expand All @@ -508,12 +522,17 @@ mod tests {
let rewritten = strategy.apply_rewrites(rws, &circ);
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
}
}

#[test]
fn test_exhaustive_default_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
circ.enable_rewrite_tracing();

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
Expand All @@ -527,6 +546,23 @@ mod tests {
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
}

#[test]
Expand Down
Loading
Loading