Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rewrite tracing #267

Merged
merged 5 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Rewrite tracing
  • Loading branch information
aborgna-q committed Dec 14, 2023
commit a314f3f4969406d468021ebf055d6fc6df165cce
2 changes: 1 addition & 1 deletion badger-optimiser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license-file = { workspace = true }
[dependencies]
clap = { version = "4.4.2", features = ["derive"] }
serde_json = "1.0"
tket2 = { workspace = true, features = ["portmatching"] }
tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] }
quantinuum-hugr = { workspace = true }
itertools = { workspace = true }
tket-json-rs = { workspace = true }
Expand Down
12 changes: 11 additions & 1 deletion badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file};
use tket2::optimiser::badger::log::BadgerLogger;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser};
use tket2::rewrite::trace::RewriteTracer;

#[cfg(feature = "peak_alloc")]
#[global_allocator]
Expand Down Expand Up @@ -104,6 +105,12 @@ struct CmdLineArgs {
help = "The priority queue size. Defaults to 100."
)]
queue_size: usize,
/// Trace each rewrite applied to the circuit.
#[arg(
long = "rewrite-tracing",
help = "Trace each rewrite applied to the circuit. Prints statistics for the best circuit at the end of the optimisation."
)]
rewrite_tracing: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -129,7 +136,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let badger_logger = BadgerLogger::new(circ_candidates_csv);

let circ = load_tk1_json_file(input_path)?;
let mut circ = load_tk1_json_file(input_path)?;
if opts.rewrite_tracing {
circ.enable_rewrite_tracing();
}

println!("Loading optimiser...");
let Ok(optimiser) = load_optimiser(ecc_path) else {
Expand Down
10 changes: 9 additions & 1 deletion tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ name = "tket2"
path = "src/lib.rs"

[features]
# Enables some python bindings
pyo3 = ["dep:pyo3"]

# Enables search and replace optimisation passes using the `portmatching` crate.
portmatching = ["dep:portmatching", "dep:rmp-serde"]

# Stores a trace of the applied rewrites
rewrite-tracing = []

default = []

[dependencies]
lazy_static = "1.4.0"
cgmath = "0.18.0"
Expand All @@ -44,7 +52,7 @@ strum_macros = "0.25.2"
strum = "0.25.0"
fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"
delegate = "0.11.0"
csv = { version = "1.2.2" }
chrono = { version = "0.4.30" }
bytemuck = "1.14.0"
Expand Down
19 changes: 13 additions & 6 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::badger::worker::BadgerWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::trace::RewriteTracer;
use crate::rewrite::Rewriter;
use crate::Circuit;

Expand Down Expand Up @@ -158,7 +159,8 @@ where

let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(circ);
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let hash = circ.circuit_hash().unwrap();
Expand All @@ -181,7 +183,8 @@ where
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost.clone();
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
last_best_time = Instant::now();
}
circ_cnt += 1;
Expand Down Expand Up @@ -297,7 +300,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
if let Some(t) = opt.progress_timeout {
progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t));
}
Expand Down Expand Up @@ -337,7 +341,8 @@ where
if cost < best_circ_cost {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(&best_circ_cost, num_rewrites);
}
}
PriorityChannelLog::CircuitCount {
Expand Down Expand Up @@ -381,7 +386,8 @@ where
let mut chunks =
CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op));

logger.log_best(circ_cost.clone());
let num_rewrites = circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(circ_cost.clone(), num_rewrites);

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
Expand Down Expand Up @@ -420,7 +426,8 @@ where
let best_circ = chunks.reassemble()?;
let best_circ_cost = self.cost(&best_circ);
if best_circ_cost.clone() < circ_cost {
logger.log_best(best_circ_cost.clone());
let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len());
logger.log_best(best_circ_cost.clone(), num_rewrites);
}

logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false);
Expand Down
13 changes: 11 additions & 2 deletions tket2/src/optimiser/badger/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,17 @@ impl<'w> BadgerLogger<'w> {

/// Log a new best candidate
#[inline]
pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) {
self.log(format!("new best of size {:?}", best_cost));
pub fn log_best<C: Debug + serde::Serialize>(
&mut self,
best_cost: C,
num_rewrites: Option<usize>,
) {
match num_rewrites {
Some(rs) => self.log(format!(
"new best of size {best_cost:?} after {rs} rewrites"
)),
None => self.log(format!("new best of size {:?}", best_cost)),
}
if let Some(csv_writer) = self.circ_candidates_csv.as_mut() {
csv_writer.serialize(BestCircSer::new(best_cost)).unwrap();
csv_writer.flush().unwrap();
Expand Down
1 change: 1 addition & 0 deletions tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#[cfg(feature = "portmatching")]
pub mod ecc_rewriter;
pub mod strategy;
pub mod trace;

use bytemuck::TransparentWrapper;
#[cfg(feature = "portmatching")]
Expand Down
42 changes: 40 additions & 2 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use itertools::Itertools;
use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, MajorMinorCost};
use crate::Circuit;

use super::trace::{RewriteTrace, RewriteTracer};
use super::CircuitRewrite;

/// Rewriting strategies for circuit optimisation.
Expand Down Expand Up @@ -144,6 +145,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
}
changed_nodes.extend(rewrite.subcircuit().nodes().iter().copied());
cost_delta += rewrite.node_count_delta();
circ.add_rewrite_trace(RewriteTrace::new(1));
rewrite
.apply(&mut circ)
.expect("Could not perform rewrite in greedy strategy");
Expand Down Expand Up @@ -219,6 +221,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
let mut composed_rewrite_count = 0;
for (rewrite, delta) in &rewrites[i..] {
if !changed_nodes.is_empty()
&& rewrite
Expand All @@ -230,11 +233,15 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
changed_nodes.extend(rewrite.invalidation_set());
cost_delta += delta.clone();

composed_rewrite_count += 1;

rewrite
.clone()
.apply(&mut curr_circ)
.expect("Could not perform rewrite in exhaustive greedy strategy");
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
Expand Down Expand Up @@ -285,6 +292,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
return None;
}
let mut circ = circ.clone();
circ.add_rewrite_trace(RewriteTrace::new(1));
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)))
})
Expand Down Expand Up @@ -462,6 +470,7 @@ mod tests {
use hugr::{Hugr, Node};
use itertools::Itertools;

use crate::rewrite::trace::REWRITE_TRACING_ENABLED;
use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
Expand Down Expand Up @@ -494,9 +503,16 @@ mod tests {

#[test]
fn test_greedy_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();

assert_eq!(circ.rewrite_trace(), None);
circ.enable_rewrite_tracing();
match REWRITE_TRACING_ENABLED {
true => assert_eq!(circ.rewrite_trace(), Some(vec![])),
false => assert_eq!(circ.rewrite_trace(), None),
}

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
Expand All @@ -508,12 +524,17 @@ mod tests {
let rewritten = strategy.apply_rewrites(rws, &circ);
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
}
}

#[test]
fn test_exhaustive_default_strategy() {
let circ = n_cx(10);
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
circ.enable_rewrite_tracing();

let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
Expand All @@ -527,6 +548,23 @@ mod tests {
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
}

#[test]
Expand Down
Loading