Skip to content

Commit

Permalink
fix!: remove TryFrom for extension ops use cast (#592)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: TryFrom implementations for extension op structs
removed, use `cast`
  • Loading branch information
ss2165 authored Sep 5, 2024
1 parent 7591c08 commit 5ca29af
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 77 deletions.
12 changes: 0 additions & 12 deletions tket2-hseries/src/extension/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,6 @@ impl TryFrom<&OpType> for FutureOpDef {
}
}

impl TryFrom<&OpType> for FutureOp {
type Error = OpLoadError;

fn try_from(value: &OpType) -> Result<Self, Self::Error> {
Self::from_op(
value
.as_extension_op()
.ok_or(OpLoadError::NotMember(value.name().into()))?,
)
}
}

/// An extension trait for [Dataflow] providing methods to add "tket2.futures"
/// operations.
pub trait FutureOpBuilder: Dataflow {
Expand Down
14 changes: 1 addition & 13 deletions tket2-hseries/src/extension/hseries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use hugr::{
builder::{BuildError, Dataflow},
extension::{
prelude::{BOOL_T, QB_T},
simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError},
simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp},
ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, Version, PRELUDE,
},
ops::{NamedOp as _, OpType},
std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_TYPES, FLOAT64_TYPE},
type_row,
types::Signature,
Expand Down Expand Up @@ -115,17 +114,6 @@ impl MakeRegisteredOp for HSeriesOp {
}
}

impl TryFrom<&OpType> for HSeriesOp {
type Error = OpLoadError;
fn try_from(value: &OpType) -> Result<Self, Self::Error> {
Self::from_op(
value
.as_extension_op()
.ok_or(OpLoadError::NotMember(value.name().into()))?,
)
}
}

/// An extension trait for [Dataflow] providing methods to add
/// "tket2.hseries" operations.
pub trait HSeriesOpBuilder: Dataflow {
Expand Down
15 changes: 2 additions & 13 deletions tket2-hseries/src/extension/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,17 +380,6 @@ impl TryFrom<&OpType> for ResultOpDef {
}
}

impl TryFrom<&OpType> for ResultOp {
type Error = OpLoadError;

fn try_from(value: &OpType) -> Result<Self, Self::Error> {
let Some(ext) = value.as_extension_op() else {
Err(OpLoadError::NotMember(value.name().into()))?
};
Self::from_extension_op(ext)
}
}

/// An extension trait for [Dataflow] providing methods to add "tket2.result"
/// operations.
pub trait ResultOpBuilder: Dataflow {
Expand Down Expand Up @@ -464,15 +453,15 @@ pub(crate) mod test {
let op_t: OpType = op.clone().to_extension_op().unwrap().into();
let def_op: ResultOpDef = (&op_t).try_into().unwrap();
assert_eq!(op.result_op, def_op);
let new_op: ResultOp = (&op_t).try_into().unwrap();
let new_op: ResultOp = op_t.cast().unwrap();
assert_eq!(&new_op, op);

let op = op.clone().array_op(ARR_SIZE);
let op_t: OpType = op.clone().to_extension_op().unwrap().into();
let def_op: ResultOpDef = (&op_t).try_into().unwrap();

assert_eq!(op.result_op, def_op);
let new_op: ResultOp = (&op_t).try_into().unwrap();
let new_op: ResultOp = op_t.cast().unwrap();
assert_eq!(&new_op, &op);
}
let [b, f, i, u, a_b, a_f, a_i, a_u] = func_builder.input_wires_arr();
Expand Down
13 changes: 6 additions & 7 deletions tket2-hseries/src/lazify_measure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ impl LazifyMeasurePass {
let mut state =
State::new(
hugr.nodes()
.filter_map(|n| match hugr.get_optype(n).try_into() {
Ok(Tk2Op::Measure) => Some(WorkItem::ReplaceMeasure(n)),
.filter_map(|n| match hugr.get_optype(n).cast() {
Some(Tk2Op::Measure) => Some(WorkItem::ReplaceMeasure(n)),
_ => None,
}),
);
Expand Down Expand Up @@ -157,7 +157,7 @@ fn simple_replace_measure(
node: Node,
) -> (HashSet<(Node, IncomingPort)>, SimpleReplacement) {
assert!(
hugr.get_optype(node).try_into() == Ok(Tk2Op::Measure),
hugr.get_optype(node).cast() == Some(Tk2Op::Measure),
"{:?}",
hugr.get_optype(node)
);
Expand Down Expand Up @@ -216,7 +216,6 @@ impl WorkItem {

#[cfg(test)]
mod test {
use cool_asserts::assert_matches;

use hugr::{
extension::{ExtensionRegistry, EMPTY_REG, PRELUDE},
Expand Down Expand Up @@ -261,12 +260,12 @@ mod test {
let mut num_lazy_measure = 0;
for n in hugr.nodes() {
let ot = hugr.get_optype(n);
if let Ok(FutureOpDef::Read) = ot.try_into() {
if let Some(FutureOpDef::Read) = ot.cast() {
num_read += 1;
} else if let Ok(HSeriesOp::LazyMeasure) = ot.try_into() {
} else if let Some(HSeriesOp::LazyMeasure) = ot.cast() {
num_lazy_measure += 1;
} else {
assert_matches!(Tk2Op::try_from(ot), Err(_))
assert_eq!(ot.cast::<Tk2Op>(), None)
}
}

Expand Down
4 changes: 2 additions & 2 deletions tket2-hseries/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ impl HSeriesPass {
.run_validated_pass(hugr, registry, |hugr, _| {
force_order(hugr, hugr.root(), |hugr, node| {
let optype = hugr.get_optype(node);
if Tk2Op::try_from(optype).is_ok() || HSeriesOp::try_from(optype).is_ok() {
if optype.cast::<Tk2Op>().is_some() || optype.cast::<HSeriesOp>().is_some() {
// quantum ops are lifted as early as possible
-1
} else if let Ok(FutureOpDef::Read) = hugr.get_optype(node).try_into() {
} else if let Some(FutureOpDef::Read) = hugr.get_optype(node).cast() {
// read ops are sunk as late as possible
1
} else {
Expand Down
13 changes: 7 additions & 6 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::borrow::{Borrow, Cow};
use hugr::builder::{CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::ops::handle::NodeHandle;
use hugr::ops::{ExtensionOp, OpType};
use hugr::ops::{ExtensionOp, NamedOp, OpType};
use hugr::types::Type;
use itertools::Itertools;
use pyo3::exceptions::{PyAttributeError, PyValueError};
Expand Down Expand Up @@ -134,11 +134,12 @@ impl Tk2Circuit {
pub fn circuit_cost<'py>(&self, cost_fn: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
let py = cost_fn.py();
let cost_fn = |op: &OpType| -> PyResult<PyCircuitCost> {
let tk2_op: Tk2Op = op.try_into().map_err(|e| {
PyErr::new::<PyValueError, _>(format!(
"Could not convert circuit operation to a `Tk2Op`: {e}"
))
})?;
let Some(tk2_op) = op.cast::<Tk2Op>() else {
let op_name = op.name();
return Err(PyErr::new::<PyValueError, _>(format!(
"Could not convert circuit operation to a `Tk2Op`: {op_name}"
)));
};
let tk2_py_op = PyTk2Op::from(tk2_op);
let cost = cost_fn.call1((tk2_py_op,))?;
Ok(PyCircuitCost {
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/circuit/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ pub fn is_cx(op: &OpType) -> bool {

/// Returns true if the operation is a quantum operation.
pub fn is_quantum(op: &OpType) -> bool {
let Ok(op): Result<Tk2Op, _> = op.try_into() else {
let Some(op): Option<Tk2Op> = op.cast() else {
return false;
};
op.is_quantum()
Expand Down
17 changes: 1 addition & 16 deletions tket2/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use hugr::ops::NamedOp;
use hugr::{
extension::{
prelude::{BOOL_T, QB_T},
simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp},
simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp},
ExtensionId, OpDef, SignatureFunc,
},
ops::OpType,
Expand Down Expand Up @@ -205,21 +205,6 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<String> {
}
}

impl TryFrom<&OpType> for Tk2Op {
type Error = NotTk2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
{
match op {
OpType::ExtensionOp(ext) => Tk2Op::from_extension_op(ext).ok(),
OpType::OpaqueOp(opaque) => try_from_name(&opaque.name(), &EXTENSION_ID).ok(),
_ => None,
}
.ok_or_else(|| NotTk2Op { op: op.clone() })
}
}
}

#[cfg(test)]
pub(crate) mod test {

Expand Down
7 changes: 3 additions & 4 deletions tket2/src/passes/commutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ fn load_slices(circ: &Circuit<impl HugrView>) -> SliceVec {

/// check if node is one we want to put in to a slice.
fn is_slice_op(h: &impl HugrView, node: Node) -> bool {
let op: Result<Tk2Op, _> = h.get_optype(node).try_into();
op.is_ok()
h.get_optype(node).cast::<Tk2Op>().is_some()
}

/// Starting from starting_index, work back along slices to check for the
Expand Down Expand Up @@ -156,12 +155,12 @@ fn commutes_at_slice(

let port = command.port_of_qb(q, Direction::Incoming)?;

let op: Tk2Op = circ.hugr().get_optype(command.node()).try_into().ok()?;
let op: Tk2Op = circ.hugr().get_optype(command.node()).cast()?;
// TODO: if not tk2op, might still have serialized commutation data we
// can use.
let pauli = commutation_on_port(&op.qubit_commutation(), port)?;

let other_op: Tk2Op = circ.hugr().get_optype(other_com.node()).try_into().ok()?;
let other_op: Tk2Op = circ.hugr().get_optype(other_com.node()).cast()?;
let other_pauli = commutation_on_port(
&other_op.qubit_commutation(),
other_com.port_of_qb(q, Direction::Outgoing)?,
Expand Down
4 changes: 1 addition & 3 deletions tket2/src/serialize/pytket/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ use hugr::ops::OpType;
use hugr::IncomingPort;
use tket_json_rs::circuit_json;

use crate::Tk2Op;

use self::native::NativeOp;
use self::serialised::OpaqueTk1Op;
use super::OpConvertError;
Expand All @@ -40,7 +38,7 @@ impl Tk1Op {
///
/// Returns an error if the operation is not supported by the TKET1 serialization.
pub fn try_from_optype(op: OpType) -> Result<Option<Self>, OpConvertError> {
if let Ok(tk2op) = Tk2Op::try_from(&op) {
if let Some(tk2op) = op.cast() {
let native = NativeOp::try_from_tk2op(tk2op)
.ok_or_else(|| OpConvertError::UnsupportedOpSerialization(op))?;
// Skip serialisation for some special cases.
Expand Down

0 comments on commit 5ca29af

Please sign in to comment.