From 8219418e5eacc16002cf8582649b893c8247ce7f Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Mon, 28 Oct 2024 11:53:29 -0400 Subject: [PATCH] Move Var to DAGCircuit. --- crates/circuit/src/dag_circuit.rs | 63 ++++++++++++++++++++++++++++--- crates/circuit/src/lib.rs | 36 ------------------ 2 files changed, 57 insertions(+), 42 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index bc2aeaa11b88..1f4bb2d0f02b 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -29,7 +29,7 @@ use crate::interner::{Interned, Interner}; use crate::operations::{Operation, OperationRef, Param, PyInstruction, StandardGate}; use crate::packed_instruction::{PackedInstruction, PackedOperation}; use crate::rustworkx_core_vnext::isomorphism; -use crate::{BitType, Clbit, Qubit, TupleLikeArg, Var}; +use crate::{BitType, Clbit, Qubit, TupleLikeArg}; use hashbrown::{HashMap, HashSet}; use indexmap::IndexMap; @@ -73,6 +73,47 @@ use std::cell::OnceCell; static CONTROL_FLOW_OP_NAMES: [&str; 4] = ["for_loop", "while_loop", "if_else", "switch_case"]; static SEMANTIC_EQ_SYMMETRIC: [&str; 4] = ["barrier", "swap", "break_loop", "continue_loop"]; +/// An opaque key type that identifies a variable within a [DAGCircuit]. +/// +/// When a new variable is added to the DAG, it is associated internally +/// with one of these keys. When enumerating DAG nodes and edges, you can +/// retrieve the associated variable instance via [DAGCircuit::get_var]. +/// +/// These keys are [Eq], but this is semantically valid only for keys +/// from the same [DAGCircuit] instance. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct Var(BitType); + +impl Var { + /// Construct a new [Var] object from a usize. if you have a u32 you can + /// create a [Var] object directly with `Var(0u32)`. This will panic + /// if the `usize` index exceeds `u32::MAX`. + #[inline(always)] + fn new(index: usize) -> Self { + Var(index + .try_into() + .unwrap_or_else(|_| panic!("Index value '{}' exceeds the maximum bit width!", index))) + } + + /// Get the index of the [Var] + #[inline(always)] + fn index(&self) -> usize { + self.0 as usize + } +} + +impl From for Var { + fn from(value: BitType) -> Self { + Var(value) + } +} + +impl From for BitType { + fn from(value: Var) -> Self { + value.0 + } +} + #[derive(Clone, Debug)] pub enum NodeType { QubitIn(Qubit), @@ -4684,7 +4725,8 @@ def _format(operand): "cannot add inputs to a circuit with captures", )); } - self.add_var(py, var, DAGVarType::Input) + self.add_var(py, var, DAGVarType::Input)?; + Ok(()) } /// Add a captured variable to the circuit. @@ -4700,7 +4742,8 @@ def _format(operand): "cannot add captures to a circuit with inputs", )); } - self.add_var(py, var, DAGVarType::Capture) + self.add_var(py, var, DAGVarType::Capture)?; + Ok(()) } /// Add a declared local variable to the circuit. @@ -4708,7 +4751,8 @@ def _format(operand): /// Args: /// var: the variable to add. fn add_declared_var(&mut self, py: Python, var: &Bound) -> PyResult<()> { - self.add_var(py, var, DAGVarType::Declare) + self.add_var(py, var, DAGVarType::Declare)?; + Ok(()) } /// Total number of classical variables tracked by the circuit. @@ -6148,7 +6192,14 @@ impl DAGCircuit { Ok(out_map) } - fn add_var(&mut self, py: Python, var: &Bound, type_: DAGVarType) -> PyResult<()> { + /// Retrieve a variable given its unique [Var] key within the DAG. + /// + /// The provided [Var] must be from this [DAGCircuit]. + pub fn get_var(&self, py: Python, var: Var) -> Option> { + self.vars.get(var).map(|v| v.bind(py).clone()) + } + + fn add_var(&mut self, py: Python, var: &Bound, type_: DAGVarType) -> PyResult { // The setup of the initial graph structure between an "in" and an "out" node is the same as // the bit-related `_add_wire`, but this logically needs to do different bookkeeping around // tracking the properties @@ -6181,7 +6232,7 @@ impl DAGCircuit { out_node: out_index, }, ); - Ok(()) + Ok(var_idx) } fn check_op_addition(&self, py: Python, inst: &PackedInstruction) -> PyResult<()> { diff --git a/crates/circuit/src/lib.rs b/crates/circuit/src/lib.rs index 4955475e3507..8705d52d1aa7 100644 --- a/crates/circuit/src/lib.rs +++ b/crates/circuit/src/lib.rs @@ -79,30 +79,6 @@ impl Clbit { } } -// Note: Var is meant to be opaque outside of this crate, i.e. -// users have no business creating them directly and should instead -// get them from the containing circuit. -#[derive(Copy, Clone, Debug, Hash, Ord, PartialOrd, Eq, PartialEq)] -pub struct Var(pub(crate) BitType); - -impl Var { - /// Construct a new [Var] object from a usize. if you have a u32 you can - /// create a [Var] object directly with `Var(0u32)`. This will panic - /// if the `usize` index exceeds `u32::MAX`. - #[inline(always)] - pub fn new(index: usize) -> Self { - Var(index - .try_into() - .unwrap_or_else(|_| panic!("Index value '{}' exceeds the maximum bit width!", index))) - } - - /// Get the index of the [Var] - #[inline(always)] - pub fn index(&self) -> usize { - self.0 as usize - } -} - pub struct TupleLikeArg<'py> { value: Bound<'py, PyTuple>, } @@ -146,18 +122,6 @@ impl From for BitType { } } -impl From for Var { - fn from(value: BitType) -> Self { - Var(value) - } -} - -impl From for BitType { - fn from(value: Var) -> Self { - value.0 - } -} - pub fn circuit(m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?;