Skip to content

Commit

Permalink
Add support for adding vars to the dagcircuit
Browse files Browse the repository at this point in the history
  • Loading branch information
mtreinish committed Jul 17, 2024
1 parent 6be2108 commit 6d5fc14
Showing 1 changed file with 96 additions and 2 deletions.
98 changes: 96 additions & 2 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ pub(crate) enum NodeType {
QubitOut(Qubit),
ClbitIn(Clbit),
ClbitOut(Clbit),
VarIn(PyObject),
VarOut(PyObject),
Operation(PackedInstruction),
}

Expand Down Expand Up @@ -269,7 +271,7 @@ pub struct DAGCircuit {
// Python modules we need to frequently access (for now).
control_flow_module: PyControlFlowModule,
circuit_module: PyCircuitModule,
vars_info: HashMap<String, PyObject>,
vars_info: HashMap<String, DAGVarInfo>,
vars_by_type: [Py<PySet>; 3],
}

Expand Down Expand Up @@ -440,12 +442,21 @@ struct BitLocations {
registers: Py<PyList>,
}

#[derive(Copy, Clone, Debug)]
enum DAGVarType {
INPUT = 0,
CAPTURE = 1,
DECLARE = 2,
}

#[derive(Clone, Debug)]
struct DAGVarInfo {
var: PyObject,
type_: DAGVarType,
in_node: NodeIndex,
out_node: NodeIndex,
}

#[pymethods]
impl DAGCircuit {
#[new]
Expand Down Expand Up @@ -1664,6 +1675,12 @@ def _format(operand):
false,
)?;
}
NodeType::VarIn(var) => {
todo!()
}
NodeType::VarOut(var) => {
todo!()
}
NodeType::QubitOut(_) | NodeType::ClbitOut(_) => (),
}
}
Expand Down Expand Up @@ -3705,6 +3722,34 @@ def _format(operand):
Ok(PyString::new_bound(py, std::str::from_utf8(&buffer)?))
}

fn add_input_var(&mut self, py: Python, var: &Bound<PyAny>) -> PyResult<()> {
if !self.vars_by_type[DAGVarType::CAPTURE as usize]
.bind(py)
.is_empty()
{
return Err(DAGCircuitError::new_err(
"cannot add inputs to a circuit with captures",
));
}
self.add_var(py, var, DAGVarType::INPUT)
}

fn add_captured_var(&mut self, py: Python, var: &Bound<PyAny>) -> PyResult<()> {
if !self.vars_by_type[DAGVarType::INPUT as usize]
.bind(py)
.is_empty()
{
return Err(DAGCircuitError::new_err(
"cannot add captures to a circuit with inputs",
));
}
self.add_var(py, var, DAGVarType::CAPTURE)
}

fn add_declared_var(&mut self, var: &Bound<PyAny>) -> PyResult<()> {
self.add_var(var.py(), var, DAGVarType::DECLARE)
}

#[getter]
fn num_vars(&self) -> usize {
self.vars_info.len()
Expand Down Expand Up @@ -3736,7 +3781,7 @@ def _format(operand):
let raw_name = var.getattr("name")?;
let var_name: String = raw_name.extract()?;
match self.vars_info.get(&var_name) {
Some(var_in_dag) => Ok(var_in_dag.is(var)),
Some(var_in_dag) => Ok(var_in_dag.var.is(var)),
None => Ok(false),
}
}
Expand Down Expand Up @@ -4431,6 +4476,12 @@ impl DAGCircuit {
)?
.into_any()
}
NodeType::VarIn(var) => {
Py::new(py, DAGInNode::new(py, id, var.clone_ref(py)))?.into_any()
}
NodeType::VarOut(var) => {
Py::new(py, DAGOutNode::new(py, id, var.clone_ref(py)))?.into_any()
}
};
Ok(dag_node)
}
Expand Down Expand Up @@ -4555,4 +4606,47 @@ impl DAGCircuit {
self.remove_op_node(node);
Ok(out_map)
}

fn add_var(&mut self, py: Python, var: &Bound<PyAny>, type_: DAGVarType) -> PyResult<()> {
// The setup of the initial graph structure between an "in" and an "out" node is the same as
// the bit-related `_add_wire`, but this logically needs to do different bookkeeping around
// tracking the properties
if !var.getattr("standalone")?.extract::<bool>()? {
return Err(DAGCircuitError::new_err(
"cannot add variables that wrap `Clbit` or `ClassicalRegister` instances",
));
}
let var_name: String = var.getattr("name")?.extract::<String>()?;
if let Some(previous) = self.vars_info.get(&var_name) {
if var.eq(previous.var.clone_ref(py))? {
return Err(DAGCircuitError::new_err(
"var is already present in circuit",
));
}
return Err(DAGCircuitError::new_err(
"Can not add var as its name shadows an existing var",
));
}
let in_node = NodeType::VarIn(var.clone().unbind());
let out_node = NodeType::VarOut(var.clone().unbind());
let in_index = self.dag.add_node(in_node);
let out_index = self.dag.add_node(out_node);
self.dag
.add_edge(in_index, out_index, Wire::Var(var.clone().unbind()));
self.var_input_map.insert(var.clone().unbind(), in_index);
self.var_output_map.insert(var.clone().unbind(), out_index);
self.vars_by_type[type_ as usize]
.bind(py)
.add(var.clone().unbind())?;
self.vars_info.insert(
var_name,
DAGVarInfo {
var: var.clone().unbind(),
type_,
in_node: in_index,
out_node: out_index,
},
);
Ok(())
}
}

0 comments on commit 6d5fc14

Please sign in to comment.