Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow immutable borrow to access QuantumCircuit.parameters #12918

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ impl CircuitData {
/// Get a (cached) sorted list of the Python-space `Parameter` instances tracked by this circuit
/// data's parameter table.
#[getter]
pub fn get_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
pub fn get_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
self.param_table.py_parameters(py)
}

Expand Down
129 changes: 77 additions & 52 deletions crates/circuit/src/parameter_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use std::cell::OnceCell;

use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet};
use thiserror::Error;
Expand Down Expand Up @@ -123,18 +125,17 @@ pub struct ParameterTable {
by_name: HashMap<PyBackedStr, ParameterUuid>,
/// Additional information on any `ParameterVector` instances that have elements in the circuit.
vectors: HashMap<VectorUuid, VectorInfo>,
/// Sort order of the parameters. This is lexicographical for most parameters, except elements
/// of a `ParameterVector` are sorted within the vector by numerical index. We calculate this
/// on demand and cache it; an empty `order` implies it is not currently calculated. We don't
/// use `Option<Vec>` so we can re-use the allocation for partial parameter bindings.
/// Cache of the sort order of the parameters. This is lexicographical for most parameters,
/// except elements of a `ParameterVector` are sorted within the vector by numerical index. We
/// calculate this on demand and cache it.
///
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
order: Vec<ParameterUuid>,
/// Any method that adds or removes a parameter needs to invalidate this.
order_cache: OnceCell<Vec<ParameterUuid>>,
/// Cache of a Python-space list of the parameter objects, in order. We only generate this
/// specifically when asked.
///
/// Any method that adds or a removes a parameter is responsible for invalidating this cache.
py_parameters: Option<Py<PyList>>,
/// Any method that adds or removes a parameter needs to invalidate this.
py_parameters_cache: OnceCell<Py<PyList>>,
}

impl ParameterTable {
Expand Down Expand Up @@ -194,8 +195,6 @@ impl ParameterTable {
None
};
self.by_name.insert(name.clone(), uuid);
self.order.clear();
self.py_parameters = None;
let mut uses = HashSet::new();
if let Some(usage) = usage {
uses.insert_unique_unchecked(usage);
Expand All @@ -206,6 +205,7 @@ impl ParameterTable {
element,
object: param_ob.clone().unbind(),
});
self.invalidate_cache();
}
}
Ok(uuid)
Expand All @@ -226,43 +226,39 @@ impl ParameterTable {
}

/// Get the (maybe cached) Python list of the sorted `Parameter` objects.
pub fn py_parameters<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
if let Some(py_parameters) = self.py_parameters.as_ref() {
return py_parameters.clone_ref(py).into_bound(py);
}
self.ensure_sorted();
let out = PyList::new_bound(
py,
self.order
.iter()
.map(|uuid| self.by_uuid[uuid].object.clone_ref(py).into_bound(py)),
);
self.py_parameters = Some(out.clone().unbind());
out
/// Return the Python list of tracked `Parameter` objects in sorted order.
///
/// The list is built at most once per cache lifetime: the first call populates
/// `py_parameters_cache` (and, if needed, `order_cache`), and later calls hand back a new
/// reference to that same list.  Only a shared borrow of the table is required.
pub fn py_parameters<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
    let cached_list = self.py_parameters_cache.get_or_init(|| {
        // The sort order has its own cache; compute it only if nobody has asked for it yet.
        let order = self.order_cache.get_or_init(|| self.sorted_order());
        let objects = order
            .iter()
            .map(|uuid| self.by_uuid[uuid].object.bind(py).clone());
        PyList::new_bound(py, objects).unbind()
    });
    cached_list.bind(py).clone()
}

/// Get a Python set of all tracked `Parameter` objects.
pub fn py_parameters_unsorted<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PySet>> {
PySet::new_bound(py, self.by_uuid.values().map(|info| &info.object))
}

/// Ensure that the `order` field is populated and sorted.
fn ensure_sorted(&mut self) {
// If `order` is already populated, it's sorted; it's the responsibility of the methods of
// this struct that mutate it to invalidate the cache.
if !self.order.is_empty() {
return;
}
self.order.reserve(self.by_uuid.len());
self.order.extend(self.by_uuid.keys());
self.order.sort_unstable_by_key(|uuid| {
/// Get the sorted order of the `ParameterTable`. This does not access the cache.
fn sorted_order(&self) -> Vec<ParameterUuid> {
let mut out = self.by_uuid.keys().copied().collect::<Vec<_>>();
out.sort_unstable_by_key(|uuid| {
let info = &self.by_uuid[uuid];
if let Some(vec) = info.element.as_ref() {
(&self.vectors[&vec.vector_uuid].name, vec.index)
} else {
(&info.name, 0)
}
})
});
out
}

/// Add a use of a parameter to the table.
Expand Down Expand Up @@ -305,9 +301,8 @@ impl ParameterTable {
vec_entry.remove_entry();
}
}
self.order.clear();
self.py_parameters = None;
entry.remove_entry();
self.invalidate_cache();
}
Ok(())
}
Expand All @@ -332,26 +327,28 @@ impl ParameterTable {
(vector_info.refcount > 0).then_some(vector_info)
});
}
self.order.clear();
self.py_parameters = None;
self.invalidate_cache();
Ok(info.uses)
}

/// Clear this table, yielding the Python parameter objects and their uses in sorted order.
///
/// The clearing effect is eager and not dependent on the iteration.
pub fn drain_ordered(
&'_ mut self,
) -> impl Iterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> + '_ {
self.ensure_sorted();
&mut self,
) -> impl ExactSizeIterator<Item = (Py<PyAny>, HashSet<ParameterUse>)> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason to use an abstract return type here, or can we just return ParameterTableDrain?

In the case of Vec::drain, a public Drain struct gets returned explicitly. That seems like a useful pattern for the general case since we can then implement additional traits for the iterator if needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly I'd done it because I didn't want ParameterTableDrain to exist at all, so I kept the type private. The original version of this function had the iterator type unnameable, because I could borrow everything out of the struct without needing a manual implementation. The Rust stdlib exports all its iterator type objects possibly in part because anonymous impl Trait in return position only turned up in 1.26, so there was no alternative before that.

I can make it a public type if you think it's better that way, but I'm not sure I agree with the justification "can implement additional traits [...] if needed" in this case - if we need them, then we can name the type and make it public, but at the moment, we've no need.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it's fine as is, given that we can swap in an explicit struct if needed.

(Though I'd be curious if Rust would have used abstract return types for this kind of thing in the standard lib if they'd been around from the beginning.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't know if they would have either. The anonymous types don't make for the prettiest reading of documentation, tbf, though equally, all the structs named Iter just mean you have to go to a mostly empty separate docs page to find which of the Iterator extension traits it implements.

let order = self
.order_cache
.take()
.unwrap_or_else(|| self.sorted_order());
let by_uuid = ::std::mem::take(&mut self.by_uuid);
self.by_name.clear();
self.vectors.clear();
self.py_parameters = None;
self.order.drain(..).map(|uuid| {
let info = self
.by_uuid
.remove(&uuid)
.expect("tracked UUIDs should be consistent");
(info.object, info.uses)
})
self.py_parameters_cache.take();
ParameterTableDrain {
order: order.into_iter(),
by_uuid,
}
}

/// Empty this `ParameterTable` of all its contents. This does not affect the capacities of the
Expand All @@ -360,8 +357,12 @@ impl ParameterTable {
self.by_uuid.clear();
self.by_name.clear();
self.vectors.clear();
self.order.clear();
self.py_parameters = None;
self.invalidate_cache();
}

fn invalidate_cache(&mut self) {
self.order_cache.take();
self.py_parameters_cache.take();
}

/// Expose the tracked data for a given parameter as directly as possible to Python space.
Expand Down Expand Up @@ -396,9 +397,33 @@ impl ParameterTable {
visit.call(&info.object)?
}
// We don't need to / can't visit the `PyBackedStr` stores.
if let Some(list) = self.py_parameters.as_ref() {
if let Some(list) = self.py_parameters_cache.get() {
visit.call(list)?
}
Ok(())
}
}

/// Owning iterator over the contents of a drained parameter table.
///
/// `order` supplies the UUIDs in the sequence they should be yielded, and `by_uuid` holds the
/// per-parameter records that are moved out one at a time as iteration proceeds.
struct ParameterTableDrain {
    order: ::std::vec::IntoIter<ParameterUuid>,
    by_uuid: HashMap<ParameterUuid, ParameterInfo>,
}

impl Iterator for ParameterTableDrain {
    type Item = (Py<PyAny>, HashSet<ParameterUse>);

    /// Yield the Python object and the set of uses for the next UUID in `order`.
    ///
    /// # Panics
    ///
    /// Panics if a UUID from `order` has no matching entry in `by_uuid`.
    fn next(&mut self) -> Option<Self::Item> {
        let uuid = self.order.next()?;
        let info = self
            .by_uuid
            .remove(&uuid)
            .expect("tracked UUIDs should be consistent");
        Some((info.object, info.uses))
    }

    /// The number of remaining items is exactly the number of UUIDs left in `order`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.order.len();
        (remaining, Some(remaining))
    }
}

// Both marker traits follow from `order`: a `vec::IntoIter` knows its exact remaining length
// and keeps returning `None` once exhausted.
impl ExactSizeIterator for ParameterTableDrain {}
impl ::std::iter::FusedIterator for ParameterTableDrain {}
Loading