Skip to content

Commit

Permalink
feat: check_lowered function for checking all Tk2ops have been removed
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Sep 5, 2024
1 parent 637450f commit 0a7d537
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tket2-hseries/src/extension/hseries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::extension::futures;
use super::futures::future_type;

mod lower;
pub use lower::lower_tk2_op;
use lower::pi_mul;
pub use lower::{check_lowered, lower_tk2_op};

/// The "tket2.hseries" extension id.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.hseries");
Expand Down
44 changes: 37 additions & 7 deletions tket2-hseries/src/extension/hseries/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use hugr::{
ops::{self, OpTrait},
std_extensions::arithmetic::float_types::ConstF64,
types::Signature,
Hugr, Node, Wire,
Hugr, HugrView, Node, Wire,
};
use itertools::Either;
use thiserror::Error;
Expand Down Expand Up @@ -113,10 +113,32 @@ fn lower_direct(hugr: &mut impl HugrMut) -> Result<Vec<Node>, HugrError> {
.collect())
}

/// Check there are no "tket2.quantum" ops left in the HUGR.
///
/// # Errors
/// Returns vector of nodes that are not lowered.
pub fn check_lowered(hugr: &impl HugrView) -> Result<(), Vec<Node>> {
let unlowered: Vec<Node> = hugr
.nodes()
.filter_map(|node| {
let optype = hugr.get_optype(node);
optype.as_extension_op().and_then(|ext| {
(ext.def().extension() == &tket2::extension::TKET2_EXTENSION_ID).then_some(node)
})
})
.collect();

if unlowered.is_empty() {
Ok(())
} else {
Err(unlowered)
}
}

#[cfg(test)]
mod test {
use hugr::{builder::FunctionBuilder, type_row, HugrView};
use tket2::Circuit;
use tket2::{extension::angle::ANGLE_TYPE, Circuit};

use super::*;
use rstest::rstest;
Expand Down Expand Up @@ -150,6 +172,7 @@ mod test {
HSeriesOp::QFree
]
);
assert_eq!(check_lowered(&h), Ok(()));
}

#[rstest]
Expand Down Expand Up @@ -182,20 +205,27 @@ mod test {
if let Some(hseries_ops) = hseries_ops {
assert_eq!(ops, hseries_ops);
}

assert_eq!(check_lowered(&h), Ok(()));
}

#[test]
fn test_mixed() {
let mut b = DFGBuilder::new(Signature::new_endo(type_row![])).unwrap();
let mut b = DFGBuilder::new(Signature::new(type_row![ANGLE_TYPE], type_row![])).unwrap();
let [angle] = b.input_wires_arr();
let [q] = b.add_dataflow_op(Tk2Op::QAlloc, []).unwrap().outputs_arr();
let [q] = b.add_dataflow_op(Tk2Op::H, [q]).unwrap().outputs_arr();
let [q] = b
.add_dataflow_op(Tk2Op::Rx, [q, angle])
.unwrap()
.outputs_arr();
b.add_dataflow_op(Tk2Op::QFree, [q]).unwrap();
let mut h = b.finish_hugr_with_outputs([], &REGISTRY).unwrap();

let lowered = lower_tk2_op(&mut h).unwrap();
assert_eq!(lowered.len(), 3);
println!("{}", h.mermaid_string());

assert_eq!(h.node_count(), 13); // dfg, input, output, alloc, phasedx, rz, free + 3x(float + load)
assert_eq!(lowered.len(), 4);
// dfg, input, output, alloc, phasedx, rz, phasedx, free + 4x(float + load)
assert_eq!(h.node_count(), 16);
assert_eq!(check_lowered(&h), Ok(()));
}
}

0 comments on commit 0a7d537

Please sign in to comment.