Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: angle type no longer parametric. #577

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,6 @@ impl CustomSignatureFunc for Tk1Signature {
}
}

/// Angle type with given log denominator.
pub fn angle_custom_type(log_denom: u8) -> CustomType {
angle::angle_custom_type(&TKET2_EXTENSION, angle::type_arg(log_denom))
}

/// Name of tket 2 extension.
pub const TKET2_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand Down
242 changes: 81 additions & 161 deletions tket2/src/extension/angle.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
use std::{cmp::max, num::NonZeroU64};

use hugr::extension::prelude::USIZE_T;
use hugr::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use hugr::extension::ExtensionSet;
use hugr::ops::constant::{downcast_equal_consts, CustomConst};
use hugr::types::PolyFuncTypeRV;
use hugr::type_row;
use hugr::{
extension::{prelude::ERROR_TYPE, SignatureError, SignatureFromArgs, TypeDef},
types::{
type_param::{TypeArgError, TypeParam},
ConstTypeError, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound,
},
types::{ConstTypeError, CustomType, Signature, Type, TypeBound},
Extension,
};
use itertools::Itertools;
use smol_str::SmolStr;
use std::f64::consts::TAU;
use strum::{EnumIter, EnumString, IntoStaticStr};

use super::TKET2_EXTENSION_ID;

/// Identifier for the angle type.
const ANGLE_TYPE_ID: SmolStr = SmolStr::new_inline("angle");
/// Dyadic rational angle type (as [CustomType])
pub const ANGLE_CUSTOM_TYPE: CustomType =
CustomType::new_simple(ANGLE_TYPE_ID, TKET2_EXTENSION_ID, TypeBound::Copyable);

pub(super) fn angle_custom_type(extension: &Extension, log_denom_arg: TypeArg) -> CustomType {
angle_def(extension).instantiate([log_denom_arg]).unwrap()
}

fn angle_type(log_denom: u8) -> Type {
Type::new_extension(super::angle_custom_type(log_denom))
}
/// Type representing an angle that is a dyadic rational multiple of π (as [Type])
pub const ANGLE_TYPE: Type = Type::new_extension(ANGLE_CUSTOM_TYPE);

/// The largest permitted log-denominator.
pub const LOG_DENOM_MAX: u8 = 53;
Expand All @@ -35,27 +29,6 @@ const fn is_valid_log_denom(n: u8) -> bool {
n <= LOG_DENOM_MAX
}

/// Type parameter for the log-denominator of an angle.
pub const LOG_DENOM_TYPE_PARAM: TypeParam =
TypeParam::bounded_nat(NonZeroU64::MIN.saturating_add(LOG_DENOM_MAX as u64));

/// Get the log-denominator of the specified type argument or error if the argument is invalid.
fn get_log_denom(arg: &TypeArg) -> Result<u8, TypeArgError> {
match arg {
TypeArg::BoundedNat { n } if is_valid_log_denom(*n as u8) => Ok(*n as u8),
_ => Err(TypeArgError::TypeMismatch {
arg: arg.clone(),
param: LOG_DENOM_TYPE_PARAM,
}),
}
}

pub(super) const fn type_arg(log_denom: u8) -> TypeArg {
TypeArg::BoundedNat {
n: log_denom as u64,
}
}

/// An angle
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstAngle {
Expand Down Expand Up @@ -118,7 +91,7 @@ impl CustomConst for ConstAngle {
}

fn get_type(&self) -> Type {
super::angle_custom_type(self.log_denom).into()
ANGLE_TYPE
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
Expand All @@ -129,136 +102,90 @@ impl CustomConst for ConstAngle {
}
}

/// Collect a vector into an array.
fn collect_array<const N: usize, T: std::fmt::Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
#[non_exhaustive]
/// Angle operations
pub enum AngleOp {
/// Truncate an angle to one with a lower log-denominator with the nearest value, rounding down in [0, 2π) if necessary
atrunc,
/// Addition of angles
aadd,
/// Subtraction of the second angle from the first
asub,
/// Negation of an angle
aneg,
}

fn abinop_sig() -> impl SignatureFromArgs {
struct BinOp;
const PARAMS: &[TypeParam] = &[LOG_DENOM_TYPE_PARAM];
impl MakeOpDef for AngleOp {
fn from_def(
op_def: &hugr::extension::OpDef,
) -> Result<Self, hugr::extension::simple_op::OpLoadError>
where
Self: Sized,
{
hugr::extension::simple_op::try_from_name(op_def.name(), &TKET2_EXTENSION_ID)
}

impl SignatureFromArgs for BinOp {
fn compute_signature(
&self,
arg_values: &[TypeArg],
) -> Result<PolyFuncTypeRV, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_denom(arg0)?;
let n: u8 = get_log_denom(arg1)?;
let l: u8 = max(m, n);
let poly_func: PolyFuncType =
Signature::new(vec![angle_type(m), angle_type(n)], vec![angle_type(l)]).into();
Ok(poly_func.into())
fn signature(&self) -> hugr::extension::SignatureFunc {
match self {
AngleOp::atrunc => {
Signature::new(type_row![ANGLE_TYPE, USIZE_T], type_row![ANGLE_TYPE])
}
AngleOp::aadd | AngleOp::asub => {
Signature::new(type_row![ANGLE_TYPE, ANGLE_TYPE], type_row![ANGLE_TYPE])
}
AngleOp::aneg => Signature::new_endo(type_row![ANGLE_TYPE]),
}
.into()
}

fn static_params(&self) -> &[TypeParam] {
PARAMS
}
fn description(&self) -> String {
match self {
AngleOp::atrunc => "truncate an angle to one with a lower log-denominator with the nearest value, rounding down in [0, 2π) if necessary",
AngleOp::aadd => "addition of angles",
AngleOp::asub => "subtraction of the second angle from the first",
AngleOp::aneg => "negation of an angle",
}.to_owned()
}

BinOp
}
fn extension(&self) -> hugr::extension::ExtensionId {
TKET2_EXTENSION_ID
}

fn angle_def(extension: &Extension) -> &TypeDef {
extension.get_type(&ANGLE_TYPE_ID).unwrap()
// TODO constant folding
// https://github.com/CQCL/tket2/issues/405
}

fn generic_angle_type(var_id: usize, angle_type_def: &TypeDef) -> Type {
Type::new_extension(
angle_type_def
.instantiate(vec![TypeArg::new_var_use(var_id, LOG_DENOM_TYPE_PARAM)])
.unwrap(),
)
impl MakeRegisteredOp for AngleOp {
fn extension_id(&self) -> hugr::extension::ExtensionId {
TKET2_EXTENSION_ID
}

fn registry<'s, 'r: 's>(&'s self) -> &'r hugr::extension::ExtensionRegistry {
&super::REGISTRY
}
}

pub(super) fn add_to_extension(extension: &mut Extension) {
let angle_type_def = extension
extension
.add_type(
ANGLE_TYPE_ID,
vec![LOG_DENOM_TYPE_PARAM],
"angle value with a given log-denominator".to_owned(),
vec![],
"angle type expressed as dyadic rational multiples of 2π".to_owned(),
TypeBound::Copyable.into(),
)
.unwrap()
.clone();

extension
.add_op(
"atrunc".into(),
"truncate an angle to one with a lower log-denominator with the same value, rounding \
down in [0, 2π) if necessary"
.to_owned(),
PolyFuncType::new(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
// atrunc_sig(extension).unwrap(),
Signature::new(
vec![generic_angle_type(0, &angle_type_def)],
vec![generic_angle_type(1, &angle_type_def)],
),
),
)
.unwrap();

extension
.add_op(
"aconvert".into(),
"convert an angle to one with another log-denominator having the same value, if \
possible, otherwise return an error"
.to_owned(),
PolyFuncType::new(
vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM],
Signature::new(
vec![generic_angle_type(0, &angle_type_def)],
vec![Type::new_sum([
generic_angle_type(1, &angle_type_def),
ERROR_TYPE,
])],
),
),
)
.unwrap();

extension
.add_op("aadd".into(), "addition of angles".to_owned(), abinop_sig())
.unwrap();

extension
.add_op(
"asub".into(),
"subtraction of the second angle from the first".to_owned(),
abinop_sig(),
)
.unwrap();

extension
.add_op(
"aneg".into(),
"negation of an angle".to_owned(),
PolyFuncType::new(
vec![LOG_DENOM_TYPE_PARAM],
Signature::new_endo(vec![generic_angle_type(0, &angle_type_def)]),
),
)
.unwrap();
AngleOp::load_all_ops(extension).expect("add fail");
}

#[cfg(test)]
mod test {
use super::*;
use crate::extension::angle_custom_type;
use hugr::types::TypeArg;

#[test]
fn test_angle_log_denoms() {
let type_arg_53 = TypeArg::BoundedNat { n: 53 };
assert_eq!(get_log_denom(&type_arg_53).unwrap(), 53);
use hugr::ops::OpType;
use strum::IntoEnumIterator;

let type_arg_54 = TypeArg::BoundedNat { n: 54 };
assert!(matches!(
get_log_denom(&type_arg_54),
Err(TypeArgError::TypeMismatch { .. })
));
}
use super::*;

#[test]
fn test_angle_consts() {
Expand All @@ -269,8 +196,7 @@ mod test {
assert_ne!(const_a32_7, const_a32_8);
assert_eq!(const_a32_7, ConstAngle::new(5, 7).unwrap());

assert_eq!(const_a32_7.get_type(), angle_custom_type(5).into());
assert_ne!(const_a32_7.get_type(), angle_custom_type(6).into());
assert_eq!(const_a32_7.get_type(), ANGLE_TYPE);
assert!(matches!(
ConstAngle::new(3, 256),
Err(ConstTypeError::CustomCheckFail(_))
Expand All @@ -290,20 +216,14 @@ mod test {

assert_eq!(const_a32_8.name(), "a(2π*8/2^6)");
}
#[test]
fn test_binop_sig() {
let binop_sig = abinop_sig();

let sig = binop_sig
.compute_signature(&[type_arg(23), type_arg(42)])
.unwrap();

let poly_type: PolyFuncType =
Signature::new(vec![angle_type(23), angle_type(42)], vec![angle_type(42)]).into();
assert_eq!(sig, poly_type.into());

assert!(binop_sig
.compute_signature(&[type_arg(23), type_arg(89)])
.is_err());
#[test]
fn test_ops() {
let ops = AngleOp::iter().collect::<Vec<_>>();
assert_eq!(ops.len(), 4);
for op in ops {
let optype: OpType = op.into();
assert_eq!(optype.cast(), Some(op));
}
}
}
Loading