diff --git a/crates/circuit/src/circuit_data.rs b/crates/circuit/src/circuit_data.rs index dd2618f03f9a..dc83c9dd2487 100644 --- a/crates/circuit/src/circuit_data.rs +++ b/crates/circuit/src/circuit_data.rs @@ -343,7 +343,7 @@ impl CircuitData { /// Get a (cached) sorted list of the Python-space `Parameter` instances tracked by this circuit /// data's parameter table. #[getter] - pub fn get_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> { + pub fn get_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> { self.param_table.py_parameters(py) } diff --git a/crates/circuit/src/parameter_table.rs b/crates/circuit/src/parameter_table.rs index 8825fbd71772..8fae607e0b90 100644 --- a/crates/circuit/src/parameter_table.rs +++ b/crates/circuit/src/parameter_table.rs @@ -10,6 +10,8 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. +use std::cell::OnceCell; + use hashbrown::hash_map::Entry; use hashbrown::{HashMap, HashSet}; use thiserror::Error; @@ -123,18 +125,17 @@ pub struct ParameterTable { by_name: HashMap, /// Additional information on any `ParameterVector` instances that have elements in the circuit. vectors: HashMap, - /// Sort order of the parameters. This is lexicographical for most parameters, except elements - /// of a `ParameterVector` are sorted within the vector by numerical index. We calculate this - /// on demand and cache it; an empty `order` implies it is not currently calculated. We don't - /// use `Option` so we can re-use the allocation for partial parameter bindings. + /// Cache of the sort order of the parameters. This is lexicographical for most parameters, + /// except elements of a `ParameterVector` are sorted within the vector by numerical index. We + /// calculate this on demand and cache it. /// - /// Any method that adds or a removes a parameter is responsible for invalidating this cache. - order: Vec, + /// Any method that adds or removes a parameter needs to invalidate this. + order_cache: OnceCell>, /// Cache of a Python-space list of the parameter objects, in order. We only generate this /// specifically when asked. /// - /// Any method that adds or a removes a parameter is responsible for invalidating this cache. - py_parameters: Option>, + /// Any method that adds or removes a parameter needs to invalidate this. + py_parameters_cache: OnceCell>, } impl ParameterTable { @@ -194,8 +195,6 @@ impl ParameterTable { None }; self.by_name.insert(name.clone(), uuid); - self.order.clear(); - self.py_parameters = None; let mut uses = HashSet::new(); if let Some(usage) = usage { uses.insert_unique_unchecked(usage); @@ -206,6 +205,7 @@ impl ParameterTable { element, object: param_ob.clone().unbind(), }); + self.invalidate_cache(); } } Ok(uuid) @@ -226,19 +226,20 @@ impl ParameterTable { } /// Get the (maybe cached) Python list of the sorted `Parameter` objects. - pub fn py_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> { - if let Some(py_parameters) = self.py_parameters.as_ref() { - return py_parameters.clone_ref(py).into_bound(py); - } - self.ensure_sorted(); - let out = PyList::new_bound( - py, - self.order - .iter() - .map(|uuid| self.by_uuid[uuid].object.clone_ref(py).into_bound(py)), - ); - self.py_parameters = Some(out.clone().unbind()); - out + pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> { + self.py_parameters_cache + .get_or_init(|| { + PyList::new_bound( + py, + self.order_cache + .get_or_init(|| self.sorted_order()) + .iter() + .map(|uuid| self.by_uuid[uuid].object.bind(py).clone()), + ) + .unbind() + }) + .bind(py) + .clone() } /// Get a Python set of all tracked `Parameter` objects. @@ -246,23 +247,18 @@ impl ParameterTable { PySet::new_bound(py, self.by_uuid.values().map(|info| &info.object)) } - /// Ensure that the `order` field is populated and sorted. - fn ensure_sorted(&mut self) { - // If `order` is already populated, it's sorted; it's the responsibility of the methods of - // this struct that mutate it to invalidate the cache. - if !self.order.is_empty() { - return; - } - self.order.reserve(self.by_uuid.len()); - self.order.extend(self.by_uuid.keys()); - self.order.sort_unstable_by_key(|uuid| { + /// Get the sorted order of the `ParameterTable`. This does not access the cache. + fn sorted_order(&self) -> Vec { + let mut out = self.by_uuid.keys().copied().collect::>(); + out.sort_unstable_by_key(|uuid| { let info = &self.by_uuid[uuid]; if let Some(vec) = info.element.as_ref() { (&self.vectors[&vec.vector_uuid].name, vec.index) } else { (&info.name, 0) } - }) + }); + out } /// Add a use of a parameter to the table. @@ -305,9 +301,8 @@ impl ParameterTable { vec_entry.remove_entry(); } } - self.order.clear(); - self.py_parameters = None; entry.remove_entry(); + self.invalidate_cache(); } Ok(()) } @@ -332,26 +327,28 @@ impl ParameterTable { (vector_info.refcount > 0).then_some(vector_info) }); } - self.order.clear(); - self.py_parameters = None; + self.invalidate_cache(); Ok(info.uses) } /// Clear this table, yielding the Python parameter objects and their uses in sorted order. + /// + /// The clearing effect is eager and not dependent on the iteration. pub fn drain_ordered( - &'_ mut self, - ) -> impl Iterator, HashSet)> + '_ { - self.ensure_sorted(); + &mut self, + ) -> impl ExactSizeIterator, HashSet)> { + let order = self + .order_cache + .take() + .unwrap_or_else(|| self.sorted_order()); + let by_uuid = ::std::mem::take(&mut self.by_uuid); self.by_name.clear(); self.vectors.clear(); - self.py_parameters = None; - self.order.drain(..).map(|uuid| { - let info = self - .by_uuid - .remove(&uuid) - .expect("tracked UUIDs should be consistent"); - (info.object, info.uses) - }) + self.py_parameters_cache.take(); + ParameterTableDrain { + order: order.into_iter(), + by_uuid, + } } /// Empty this `ParameterTable` of all its contents. This does not affect the capacities of the @@ -360,8 +357,12 @@ impl ParameterTable { self.by_uuid.clear(); self.by_name.clear(); self.vectors.clear(); - self.order.clear(); - self.py_parameters = None; + self.invalidate_cache(); + } + + fn invalidate_cache(&mut self) { + self.order_cache.take(); + self.py_parameters_cache.take(); } /// Expose the tracked data for a given parameter as directly as possible to Python space. @@ -396,9 +397,33 @@ impl ParameterTable { visit.call(&info.object)? } // We don't need to / can't visit the `PyBackedStr` stores. - if let Some(list) = self.py_parameters.as_ref() { + if let Some(list) = self.py_parameters_cache.get() { visit.call(list)? } Ok(()) } } + +struct ParameterTableDrain { + order: ::std::vec::IntoIter, + by_uuid: HashMap, +} +impl Iterator for ParameterTableDrain { + type Item = (Py, HashSet); + + fn next(&mut self) -> Option { + self.order.next().map(|uuid| { + let info = self + .by_uuid + .remove(&uuid) + .expect("tracked UUIDs should be consistent"); + (info.object, info.uses) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.order.size_hint() + } +} +impl ExactSizeIterator for ParameterTableDrain {} +impl ::std::iter::FusedIterator for ParameterTableDrain {}