From 6bba713700c096b3b33213c57cceea478bd57999 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 2 Nov 2023 14:16:45 -0700 Subject: [PATCH 01/11] Remove caching of collection summaries. Pre-fetching of collection summaries was quite expensive and we did not use those summaries very often. Removing the cache completely, now we query summaries each time but only for the collections that are actually used. --- .../datasets/byDimensions/_manager.py | 26 ++- .../datasets/byDimensions/summaries.py | 155 ++++++++---------- .../registry/interfaces/_collections.py | 6 +- .../butler/registry/interfaces/_datasets.py | 25 ++- .../registry/queries/_sql_query_backend.py | 3 +- 5 files changed, 123 insertions(+), 92 deletions(-) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index d52d337a2b..a0d0dfadfd 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -7,6 +7,7 @@ import logging import warnings from collections import defaultdict +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy @@ -207,7 +208,6 @@ def refresh(self) -> None: # Docstring inherited from DatasetRecordStorageManager. byName: dict[str, ByDimensionsDatasetRecordStorage] = {} byId: dict[int, ByDimensionsDatasetRecordStorage] = {} - dataset_types: dict[int, DatasetType] = {} c = self._static.dataset_type.columns with self._db.query(self._static.dataset_type.select()) as sql_result: sql_rows = sql_result.mappings().fetchall() @@ -255,10 +255,8 @@ def refresh(self) -> None: ) byName[datasetType.name] = storage byId[storage._dataset_type_id] = storage - dataset_types[row["id"]] = datasetType self._byName = byName self._byId = byId - self._summaries.refresh(dataset_types) def remove(self, name: str) -> None: # Docstring inherited from DatasetRecordStorageManager. @@ -496,9 +494,29 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: run=self._collections[row[self._collections.getRunForeignKeyName()]].name, ) + def _dataset_type_factory(self, dataset_type_id: int) -> DatasetType: + """Return dataset type given its ID.""" + return self._byId[dataset_type_id].datasetType + def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary: # Docstring inherited from DatasetRecordStorageManager. - return self._summaries.get(collection) + summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_factory) + return summaries[collection.key] + + def fetch_summaries( + self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None + ) -> Mapping[Any, CollectionSummary]: + # Docstring inherited from DatasetRecordStorageManager. + dataset_type_ids: list[int] | None = None + if dataset_types is not None: + dataset_type_ids = [] + for dataset_type in dataset_types: + if dataset_type.isComponent(): + dataset_type = dataset_type.makeCompositeDatasetType() + # Assume we know all possible names. 
+ dataset_type_id = self._byName[dataset_type.name]._dataset_type_id + dataset_type_ids.append(dataset_type_id) + return self._summaries.fetch_summaries(collections, dataset_type_ids, self._dataset_type_factory) _versions: list[VersionTuple] """Schema version for this class.""" diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py index 1356576376..73511ddd0c 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py @@ -31,7 +31,7 @@ __all__ = ("CollectionSummaryManager",) -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from typing import Any, Generic, TypeVar import sqlalchemy @@ -42,13 +42,13 @@ from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType from ...interfaces import ( - ChainedCollectionRecord, CollectionManager, CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext, ) +from ...wildcards import CollectionWildcard _T = TypeVar("_T") @@ -148,7 +148,6 @@ def __init__( self._collectionKeyName = collections.getCollectionForeignKeyName() self._dimensions = dimensions self._tables = tables - self._cache: dict[Any, CollectionSummary] = {} @classmethod def initialize( @@ -237,38 +236,53 @@ def update( self._tables.dimensions[dimension], *[{self._collectionKeyName: collection.key, dimension: v} for v in values], ) - # Update the in-memory cache, too. These changes will remain even if - # the database inserts above are rolled back by some later exception in - # the same transaction, but that's okay: we never promise that a - # CollectionSummary has _just_ the dataset types and governor dimension - # values that are actually present, only that it is guaranteed to - # contain any dataset types or governor dimension values that _may_ be - # present. - # That guarantee (and the possibility of rollbacks) means we can't get - # away with checking the cache before we try the database inserts, - # however; if someone had attempted to insert datasets of some dataset - # type previously, and that rolled back, and we're now trying to insert - # some more datasets of that same type, it would not be okay to skip - # the DB summary table insertions because we found entries in the - # in-memory cache. - self.get(collection).update(summary) - - def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None: - """Load all collection summary information from the database. + + def fetch_summaries( + self, + collections: Iterable[CollectionRecord], + dataset_type_ids: Iterable[int] | None, + dataset_type_factory: Callable[[int], DatasetType], + ) -> Mapping[Any, CollectionSummary]: + """Fetch collection summaries given their names and dataset types. Parameters ---------- - dataset_types : `~collections.abc.Mapping` [`int`, `DatasetType`] - Mapping of an `int` dataset_type_id value to `DatasetType` - instance. Summaries are only loaded for dataset types that appear - in this mapping. + collections : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records to query. + dataset_type_ids : `~collections.abc.Iterable` [`int`] + IDs of dataset types to include into returned summaries. If `None` + then all dataset types will be included. + dataset_type_factory : `Callable` + Method that returns `DatasetType` instance given its dataset type + ID. 
+ + Returns + ------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Collection summaries indexed by collection record key. This mapping + will also contain all nested non-chained collections of the chained + collections. """ + # Need to expand all chained collections first. + non_chains: list[CollectionRecord] = [] + chains: dict[CollectionRecord, list[CollectionRecord]] = {} + for collection in collections: + if collection.type is CollectionType.CHAINED: + children = self._collections.resolve_wildcard( + CollectionWildcard.from_names([collection.name]), + flatten_chains=True, + include_chains=False, + ) + non_chains += children + chains[collection] = children + else: + non_chains.append(collection) + # Set up the SQL query we'll use to fetch all of the summary # information at once. - columns = [ - self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName), - self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id"), - ] + coll_col = self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName) + dataset_type_id_col = self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id") + columns = [coll_col, dataset_type_id_col] fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType for dimension, table in self._tables.dimensions.items(): columns.append(table.columns[dimension.name].label(dimension.name)) @@ -280,7 +294,12 @@ def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None: ), isouter=True, ) + sql = sqlalchemy.sql.select(*columns).select_from(fromClause) + sql = sql.where(coll_col.in_([coll.key for coll in non_chains])) + if dataset_type_ids is not None: + sql = sql.where(dataset_type_id_col.in_(dataset_type_ids)) + # Run the query and construct CollectionSummary objects from the result # rows. This will never include CHAINED collections or collections # with no datasets. @@ -293,59 +312,29 @@ def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None: collectionKey = row[self._collectionKeyName] # dataset_type_id should also never be None/NULL; it's in the first # table we joined. - if datasetType := dataset_types.get(row["dataset_type_id"]): - # See if we have a summary already for this collection; if not, - # make one. - summary = summaries.get(collectionKey) - if summary is None: - summary = CollectionSummary() - summaries[collectionKey] = summary - # Update the dimensions with the values in this row that - # aren't None/NULL (many will be in general, because these - # enter the query via LEFT OUTER JOIN). - summary.dataset_types.add(datasetType) - for dimension in self._tables.dimensions: - value = row[dimension.name] - if value is not None: - summary.governors.setdefault(dimension.name, set()).add(value) - self._cache = summaries - - def get(self, collection: CollectionRecord) -> CollectionSummary: - """Return a summary for the given collection. + dataset_type = dataset_type_factory(row["dataset_type_id"]) + # See if we have a summary already for this collection; if not, + # make one. + summary = summaries.get(collectionKey) + if summary is None: + summary = CollectionSummary() + summaries[collectionKey] = summary + # Update the dimensions with the values in this row that + # aren't None/NULL (many will be in general, because these + # enter the query via LEFT OUTER JOIN). 
+ summary.dataset_types.add(dataset_type) + for dimension in self._tables.dimensions: + value = row[dimension.name] + if value is not None: + summary.governors.setdefault(dimension.name, set()).add(value) - Parameters - ---------- - collection : `CollectionRecord` - Record describing the collection for which a summary is to be - retrieved. + # Add empty summary for any missing collection. + for collection in non_chains: + if collection.key not in summaries: + summaries[collection.key] = CollectionSummary() - Returns - ------- - summary : `CollectionSummary` - Summary of the dataset types and governor dimension values in - this collection. - """ - summary = self._cache.get(collection.key) - if summary is None: - # When we load the summary information from the database, we don't - # create summaries for CHAINED collections; those are created here - # as needed, and *never* cached - we have no good way to update - # those summaries when some a new dataset is added to a child - # colletion. - if collection.type is CollectionType.CHAINED: - assert isinstance(collection, ChainedCollectionRecord) - child_summaries = [self.get(self._collections.find(child)) for child in collection.children] - if child_summaries: - summary = CollectionSummary.union(*child_summaries) - else: - summary = CollectionSummary() - else: - # Either this collection doesn't have any datasets yet, or the - # only datasets it has were created by some other process since - # the last call to refresh. We assume the former; the user is - # responsible for calling refresh if they want to read - # concurrently-written things. We do remember this in the - # cache. - summary = CollectionSummary() - self._cache[collection.key] = summary - return summary + # Merge children into their chains summaries. + for chain, children in chains.items(): + summaries[chain.key] = CollectionSummary.union(*(summaries[child.key] for child in children)) + + return summaries diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 2bc5bf30c6..7742faa20c 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -632,10 +632,10 @@ def resolve_wildcard( If `True` (default) recursively yield the child collections of `~CollectionType.CHAINED` collections. include_chains : `bool`, optional - If `False`, return records for `~CollectionType.CHAINED` + If `True`, return records for `~CollectionType.CHAINED` collections themselves. The default is the opposite of - ``flattenChains``: either return records for CHAINED collections or - their children, but not both. + ``flatten_chains``: either return records for CHAINED collections + or their children, but not both. 
Returns ------- diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 3424028804..84a5a735d4 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -32,7 +32,7 @@ __all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage") from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Set +from collections.abc import Iterable, Iterator, Mapping, Set from typing import TYPE_CHECKING, Any from lsst.daf.relation import Relation @@ -603,6 +603,29 @@ def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummar """ raise NotImplementedError() + @abstractmethod + def fetch_summaries( + self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None + ) -> Mapping[Any, CollectionSummary]: + """Fetch collection summaries given their names and dataset types. + + Parameters + ---------- + collections : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records to query. + dataset_types : `~collections.abc.Iterable` [`DatasetType`] or `None` + Dataset types to include into returned summaries. If `None` then + all dataset types will be included. + + Returns + ------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Collection summaries indexed by collection record key. This mapping + will also contain all nested non-chained collections of the chained + collections. + """ + raise NotImplementedError() + @abstractmethod def ingest_date_dtype(self) -> type: """Return type of the ``ingest_date`` column.""" diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py index db574dde44..fc5866e8ba 100644 --- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py @@ -121,6 +121,7 @@ def filter_dataset_collections( result: dict[DatasetType, list[CollectionRecord]] = { dataset_type: [] for dataset_type in dataset_types } + summaries = self._managers.datasets.fetch_summaries(collections, result.keys()) for dataset_type, filtered_collections in result.items(): for collection_record in collections: if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION: @@ -130,7 +131,7 @@ def filter_dataset_collections( f"in CALIBRATION collection {collection_record.name!r}." 
) else: - collection_summary = self._managers.datasets.getCollectionSummary(collection_record) + collection_summary = summaries[collection_record.key] if collection_summary.is_compatible_with( dataset_type, governor_constraints, From 780e66ec854f414bdc951f13f97cdc32a419eab5 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 2 Nov 2023 19:51:06 -0700 Subject: [PATCH 02/11] Get rid of the reverse chain collection cache --- .../daf/butler/registry/collections/_base.py | 17 +++++++++++++++++ .../registry/interfaces/_collections.py | 19 +------------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index f611cd630c..00403d2cd3 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -30,6 +30,7 @@ __all__ = () +import contextlib import itertools from abc import abstractmethod from collections import namedtuple @@ -545,3 +546,19 @@ def _removeCachedRecord(self, record: CollectionRecord) -> None: def _getByName(self, name: str) -> CollectionRecord | None: """Find collection record given collection name.""" raise NotImplementedError() + + def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: + # Docstring inherited from CollectionManager. + table = self._tables.collection_chain + sql = ( + sqlalchemy.sql.select(table.columns["parent"]) + .select_from(table) + .where(table.columns["child"] == key) + ) + with self._db.query(sql) as sql_result: + parent_keys = sql_result.scalars().all() + for key in parent_keys: + # TODO: Just in case cached records miss new parent collections. + # This is temporary, will replace with non-cached records soon. + with contextlib.suppress(KeyError): + yield cast(ChainedCollectionRecord, self._records[key]) diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 7742faa20c..625a02cf7e 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -36,7 +36,6 @@ ] from abc import abstractmethod -from collections import defaultdict from collections.abc import Iterator, Set from typing import TYPE_CHECKING, Any @@ -221,12 +220,6 @@ def update(self, manager: CollectionManager, children: tuple[str, ...], flatten: ) # Delegate to derived classes to do the database updates. self._update(manager, children) - # Update the reverse mapping (from child to parents) in the manager, - # by removing the old relationships and adding back in the new ones. - for old_child in self._children: - manager._parents_by_child[manager.find(old_child).key].discard(self.key) - for new_child in children: - manager._parents_by_child[manager.find(new_child).key].add(self.key) # Actually set this instances sequence of children. self._children = children @@ -246,13 +239,7 @@ def refresh(self, manager: CollectionManager) -> None: The object that manages this records instance and all records instances that may appear as its children. """ - # Clear out the old reverse mapping (from child to parents). - for child in self._children: - manager._parents_by_child[manager.find(child).key].discard(self.key) self._children = self._load(manager) - # Update the reverse mapping (from child to parents) in the manager. 
- for child in self._children: - manager._parents_by_child[manager.find(child).key].add(self.key) @abstractmethod def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None: @@ -315,7 +302,6 @@ class CollectionManager(VersionedExtension): def __init__(self, *, registry_schema_version: VersionTuple | None = None) -> None: super().__init__(registry_schema_version=registry_schema_version) - self._parents_by_child: defaultdict[Any, set[Any]] = defaultdict(set) @classmethod @abstractmethod @@ -682,7 +668,4 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: key Internal primary key value for the collection. """ - for parent_key in self._parents_by_child[key]: - result = self[parent_key] - assert isinstance(result, ChainedCollectionRecord) - yield result + raise NotImplementedError() From d8c3aed39e5c1327f57fa0b9cb8403e90c0d5e31 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Fri, 3 Nov 2023 11:27:34 -0700 Subject: [PATCH 03/11] Fetch all chained collection definitions in one query. `DefaultCollectionManager.refresh` now runs a single query to fetch full contents of collection_chain table. This removes update logic from collection record classes which became simple data classes now. --- .../daf/butler/registry/collections/_base.py | 294 ++++++------------ .../registry/interfaces/_collections.py | 252 ++++++--------- .../lsst/daf/butler/registry/sql_registry.py | 2 +- 3 files changed, 185 insertions(+), 363 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 00403d2cd3..95a6bfa950 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -33,14 +33,13 @@ import contextlib import itertools from abc import abstractmethod -from collections import namedtuple +from collections import defaultdict, namedtuple from collections.abc import Iterable, Iterator, Set -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import sqlalchemy -from ..._timespan import Timespan, TimespanDatabaseRepresentation -from ...dimensions import DimensionUniverse +from ..._timespan import TimespanDatabaseRepresentation from .._collection_type import CollectionType from .._exceptions import MissingCollectionError from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple @@ -158,150 +157,10 @@ def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) ) -class DefaultRunRecord(RunRecord): - """Default `RunRecord` implementation. - - This method assumes the same run table definition as produced by - `makeRunTableSpec` method. The only non-fixed name in the schema - is the PK column name, this needs to be passed in a constructor. - - Parameters - ---------- - db : `Database` - Registry database. - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. - name : `str` - Run collection name. - table : `sqlalchemy.schema.Table` - Table for run records. - idColumnName : `str` - Name of the identifying column in run table. - host : `str`, optional - Name of the host where run was produced. - timespan : `Timespan`, optional - Timespan for this run. 
- """ - - def __init__( - self, - db: Database, - key: Any, - name: str, - *, - table: sqlalchemy.schema.Table, - idColumnName: str, - host: str | None = None, - timespan: Timespan | None = None, - ): - super().__init__(key=key, name=name, type=CollectionType.RUN) - self._db = db - self._table = table - self._host = host - if timespan is None: - timespan = Timespan(begin=None, end=None) - self._timespan = timespan - self._idName = idColumnName - - def update(self, host: str | None = None, timespan: Timespan | None = None) -> None: - # Docstring inherited from RunRecord. - if timespan is None: - timespan = Timespan(begin=None, end=None) - row = { - self._idName: self.key, - "host": host, - } - self._db.getTimespanRepresentation().update(timespan, result=row) - count = self._db.update(self._table, {self._idName: self.key}, row) - if count != 1: - raise RuntimeError(f"Run update affected {count} records; expected exactly one.") - self._host = host - self._timespan = timespan - - @property - def host(self) -> str | None: - # Docstring inherited from RunRecord. - return self._host - - @property - def timespan(self) -> Timespan: - # Docstring inherited from RunRecord. - return self._timespan - - -class DefaultChainedCollectionRecord(ChainedCollectionRecord): - """Default `ChainedCollectionRecord` implementation. - - This method assumes the same chain table definition as produced by - `makeCollectionChainTableSpec` method. All column names in the table are - fixed and hard-coded in the methods. - - Parameters - ---------- - db : `Database` - Registry database. - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. - name : `str` - Collection name. - table : `sqlalchemy.schema.Table` - Table for chain relationship records. - universe : `DimensionUniverse` - Object managing all known dimensions. - """ - - def __init__( - self, - db: Database, - key: Any, - name: str, - *, - table: sqlalchemy.schema.Table, - universe: DimensionUniverse, - ): - super().__init__(key=key, name=name, universe=universe) - self._db = db - self._table = table - self._universe = universe - - def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None: - # Docstring inherited from ChainedCollectionRecord. - rows = [] - position = itertools.count() - for child in manager.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False): - rows.append( - { - "parent": self.key, - "child": child.key, - "position": next(position), - } - ) - with self._db.transaction(): - self._db.delete(self._table, ["parent"], {"parent": self.key}) - self._db.insert(self._table, *rows) - - def _load(self, manager: CollectionManager) -> tuple[str, ...]: - # Docstring inherited from ChainedCollectionRecord. - sql = ( - sqlalchemy.sql.select( - self._table.columns.child, - ) - .select_from(self._table) - .where(self._table.columns.parent == self.key) - .order_by(self._table.columns.position) - ) - with self._db.query(sql) as sql_result: - return tuple(manager[row[self._table.columns.child]].name for row in sql_result.mappings()) - - K = TypeVar("K") -class DefaultCollectionManager(Generic[K], CollectionManager): +class DefaultCollectionManager(CollectionManager[K]): """Default `CollectionManager` implementation. 
This implementation uses record classes defined in this module and is @@ -338,7 +197,7 @@ def __init__( self._db = db self._tables = tables self._collectionIdName = collectionIdName - self._records: dict[K, CollectionRecord] = {} # indexed by record ID + self._records: dict[K, CollectionRecord[K]] = {} # indexed by record ID self._dimensions = dimensions def refresh(self) -> None: @@ -346,54 +205,62 @@ def refresh(self) -> None: sql = sqlalchemy.sql.select( *(list(self._tables.collection.columns) + list(self._tables.run.columns)) ).select_from(self._tables.collection.join(self._tables.run, isouter=True)) + # Extract _all_ chain mappings as well + chain_sql = sqlalchemy.sql.select( + self._tables.collection_chain.columns["parent"], + self._tables.collection_chain.columns["position"], + self._tables.collection_chain.columns["child"], + ).select_from(self._tables.collection_chain) + + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + # Build all chain definitions. + chains_defs: dict[K, list[tuple[int, K]]] = defaultdict(list) + for row in chain_rows: + chains_defs[row["parent"]].append((row["position"], row["child"])) + # Put found records into a temporary instead of updating self._records # in place, for exception safety. - records = [] - chains = [] + records: list[CollectionRecord] = [] TimespanReprClass = self._db.getTimespanRepresentation() - with self._db.query(sql) as sql_result: - sql_rows = sql_result.mappings().fetchall() + id_to_name: dict[K, str] = {} + chained_ids: list[K] = [] for row in sql_rows: collection_id = row[self._tables.collection.columns[self._collectionIdName]] name = row[self._tables.collection.columns.name] + id_to_name[collection_id] = name type = CollectionType(row["type"]) record: CollectionRecord if type is CollectionType.RUN: - record = DefaultRunRecord( + record = RunRecord( key=collection_id, name=name, - db=self._db, - table=self._tables.run, - idColumnName=self._collectionIdName, host=row[self._tables.run.columns.host], timespan=TimespanReprClass.extract(row), ) + records.append(record) elif type is CollectionType.CHAINED: - record = DefaultChainedCollectionRecord( - db=self._db, - key=collection_id, - table=self._tables.collection_chain, - name=name, - universe=self._dimensions.universe, - ) - chains.append(record) + # Need to delay chained collection construction until all names + # are known. + chained_ids.append(collection_id) else: record = CollectionRecord(key=collection_id, name=name, type=type) + records.append(record) + + for chained_id in chained_ids: + children_names = [id_to_name[child_id] for _, child_id in sorted(chains_defs[chained_id])] + record = ChainedCollectionRecord( + key=chained_id, + name=id_to_name[chained_id], + children=children_names, + ) records.append(record) + self._setRecordCache(records) - for chain in chains: - try: - chain.refresh(self) - except MissingCollectionError: - # This indicates a race condition in which some other client - # created a new collection and added it as a child of this - # (pre-existing) chain between the time we fetched all - # collections and the time we queried for parent-child - # relationships. - # Because that's some other unrelated client, we shouldn't care - # about that parent collection anyway, so we just drop it on - # the floor (a manual refresh can be used to get it back). 
- self._removeCachedRecord(chain) def register( self, name: str, type: CollectionType, doc: str | None = None @@ -412,7 +279,7 @@ def register( assert isinstance(inserted_or_updated, bool) registered = inserted_or_updated assert row is not None - collection_id = row[self._collectionIdName] + collection_id = cast(K, row[self._collectionIdName]) if type is CollectionType.RUN: TimespanReprClass = self._db.getTimespanRepresentation() row, _ = self._db.sync( @@ -421,25 +288,20 @@ def register( returning=("host",) + TimespanReprClass.getFieldNames(), ) assert row is not None - record = DefaultRunRecord( - db=self._db, + record = RunRecord[K]( key=collection_id, name=name, - table=self._tables.run, - idColumnName=self._collectionIdName, host=row["host"], timespan=TimespanReprClass.extract(row), ) elif type is CollectionType.CHAINED: - record = DefaultChainedCollectionRecord( - db=self._db, + record = ChainedCollectionRecord[K]( key=collection_id, name=name, - table=self._tables.collection_chain, - universe=self._dimensions.universe, + children=[], ) else: - record = CollectionRecord(key=collection_id, name=name, type=type) + record = CollectionRecord[K](key=collection_id, name=name, type=type) self._addCachedRecord(record) return record, registered @@ -454,14 +316,14 @@ def remove(self, name: str) -> None: ) self._removeCachedRecord(record) - def find(self, name: str) -> CollectionRecord: + def find(self, name: str) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. result = self._getByName(name) if result is None: raise MissingCollectionError(f"No collection with name '{name}' found.") return result - def __getitem__(self, key: Any) -> CollectionRecord: + def __getitem__(self, key: Any) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. try: return self._records[key] @@ -476,13 +338,13 @@ def resolve_wildcard( done: set[str] | None = None, flatten_chains: bool = True, include_chains: bool | None = None, - ) -> list[CollectionRecord]: + ) -> list[CollectionRecord[K]]: # Docstring inherited if done is None: done = set() include_chains = include_chains if include_chains is not None else not flatten_chains - def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]: + def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]: if record.name in done: return if record.type in collection_types: @@ -491,12 +353,12 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect yield record if flatten_chains and record.type is CollectionType.CHAINED: done.add(record.name) - for name in cast(ChainedCollectionRecord, record).children: + for name in cast(ChainedCollectionRecord[K], record).children: # flake8 can't tell that we only delete this closure when # we're totally done with it. yield from resolve_nested(self.find(name), done) # noqa: F821 - result: list[CollectionRecord] = [] + result: list[CollectionRecord[K]] = [] if wildcard.patterns is ...: for record in self._records.values(): @@ -526,7 +388,7 @@ def setDocumentation(self, key: Any, doc: str | None) -> None: # Docstring inherited from CollectionManager. self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc}) - def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: + def _setRecordCache(self, records: Iterable[CollectionRecord[K]]) -> None: """Set internal record cache to contain given records, old cached records will be removed. 
""" @@ -534,20 +396,20 @@ def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: for record in records: self._records[record.key] = record - def _addCachedRecord(self, record: CollectionRecord) -> None: + def _addCachedRecord(self, record: CollectionRecord[K]) -> None: """Add single record to cache.""" self._records[record.key] = record - def _removeCachedRecord(self, record: CollectionRecord) -> None: + def _removeCachedRecord(self, record: CollectionRecord[K]) -> None: """Remove single record from cache.""" del self._records[record.key] @abstractmethod - def _getByName(self, name: str) -> CollectionRecord | None: + def _getByName(self, name: str) -> CollectionRecord[K] | None: """Find collection record given collection name.""" raise NotImplementedError() - def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: + def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]: # Docstring inherited from CollectionManager. table = self._tables.collection_chain sql = ( @@ -561,4 +423,42 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: # TODO: Just in case cached records miss new parent collections. # This is temporary, will replace with non-cached records soon. with contextlib.suppress(KeyError): - yield cast(ChainedCollectionRecord, self._records[key]) + yield cast(ChainedCollectionRecord[K], self._records[key]) + + def update_chain( + self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False + ) -> ChainedCollectionRecord[K]: + # Docstring inherited from CollectionManager. + children_as_wildcard = CollectionWildcard.from_names(children) + for record in self.resolve_wildcard( + children_as_wildcard, + flatten_chains=True, + include_chains=True, + collection_types={CollectionType.CHAINED}, + ): + if record == chain: + raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.") + if flatten: + children = tuple( + record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True) + ) + + rows = [] + position = itertools.count() + names = [] + for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False): + rows.append( + { + "parent": chain.key, + "child": child.key, + "position": next(position), + } + ) + names.append(child.name) + with self._db.transaction(): + self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key}) + self._db.insert(self._tables.collection_chain, *rows) + + record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names)) + self._addCachedRecord(record) + return record diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 625a02cf7e..3cce276b6b 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -36,11 +36,10 @@ ] from abc import abstractmethod -from collections.abc import Iterator, Set -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Iterator, Set +from typing import TYPE_CHECKING, Any, Generic, TypeVar from ..._timespan import Timespan -from ...dimensions import DimensionUniverse from .._collection_type import CollectionType from ..wildcards import CollectionWildcard from ._versioning import VersionedExtension, VersionTuple @@ -50,7 +49,10 @@ from ._dimensions import DimensionRecordStorageManager -class CollectionRecord: +_Key = TypeVar("_Key") + + 
+class CollectionRecord(Generic[_Key]): """A struct used to represent a collection in internal `Registry` APIs. User-facing code should always just use a `str` to represent collections. @@ -75,7 +77,7 @@ class CollectionRecord: participate in some subclass equality definition. """ - def __init__(self, key: Any, name: str, type: CollectionType): + def __init__(self, key: _Key, name: str, type: CollectionType): self.key = key self.name = name self.type = type @@ -85,7 +87,7 @@ def __init__(self, key: Any, name: str, type: CollectionType): """Name of the collection (`str`). """ - key: Any + key: _Key """The primary/foreign key value for this collection. """ @@ -110,184 +112,85 @@ def __str__(self) -> str: return self.name -class RunRecord(CollectionRecord): +class RunRecord(CollectionRecord[_Key]): """A subclass of `CollectionRecord` that adds execution information and an interface for updating it. - """ - @abstractmethod - def update(self, host: str | None = None, timespan: Timespan | None = None) -> None: - """Update the database record for this run with new execution - information. - - Values not provided will set to ``NULL`` in the database, not ignored. + Parameters + ---------- + key: `object` + Unique collection key. + name : `str` + Name of the collection. + host : `str`, optional + Name of the host or system on which this run was produced. + timespan: `Timespan`, optional + Begin and end timestamps for the period over which the run was + produced. + """ - Parameters - ---------- - host : `str`, optional - Name of the host or system on which this run was produced. - Detailed form to be set by higher-level convention; from the - `Registry` perspective, this is an entirely opaque value. - timespan : `Timespan`, optional - Begin and end timestamps for the period over which the run was - produced. `None`/``NULL`` values are interpreted as infinite - bounds. - """ - raise NotImplementedError() + host: str | None + """Name of the host or system on which this run was produced (`str` or + `None`). + """ - @property - @abstractmethod - def host(self) -> str | None: - """Return the name of the host or system on which this run was - produced (`str` or `None`). - """ - raise NotImplementedError() + timespan: Timespan + """Begin and end timestamps for the period over which the run was produced. + None`/``NULL`` values are interpreted as infinite bounds. + """ - @property - @abstractmethod - def timespan(self) -> Timespan: - """Begin and end timestamps for the period over which the run was - produced. `None`/``NULL`` values are interpreted as infinite - bounds. - """ - raise NotImplementedError() + def __init__( + self, + key: _Key, + name: str, + *, + host: str | None = None, + timespan: Timespan | None = None, + ): + super().__init__(key=key, name=name, type=CollectionType.RUN) + self.host = host + if timespan is None: + timespan = Timespan(begin=None, end=None) + self.timespan = timespan def __repr__(self) -> str: return f"RunRecord(key={self.key!r}, name={self.name!r})" -class ChainedCollectionRecord(CollectionRecord): +class ChainedCollectionRecord(CollectionRecord[_Key]): """A subclass of `CollectionRecord` that adds the list of child collections in a ``CHAINED`` collection. Parameters ---------- - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. + key: `object` + Unique collection key. name : `str` Name of the collection. 
+ children: Iterable[str], + Ordered sequence of names of child collections. """ - def __init__(self, key: Any, name: str, universe: DimensionUniverse): - super().__init__(key=key, name=name, type=CollectionType.CHAINED) - self._children: tuple[str, ...] = () - - @property - def children(self) -> tuple[str, ...]: - """The ordered search path of child collections that define this chain - (`tuple` [ `str` ]). - """ - return self._children - - def update(self, manager: CollectionManager, children: tuple[str, ...], flatten: bool) -> None: - """Redefine this chain to search the given child collections. - - This method should be used by all external code to set children. It - delegates to `_update`, which is what should be overridden by - subclasses. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - children : `tuple` [ `str` ] - A collection search path that should be resolved to set the child - collections of this chain. - flatten : `bool` - If `True`, recursively flatten out any nested - `~CollectionType.CHAINED` collections in ``children`` first. - - Raises - ------ - ValueError - Raised when the child collections contain a cycle. - """ - children_as_wildcard = CollectionWildcard.from_names(children) - for record in manager.resolve_wildcard( - children_as_wildcard, - flatten_chains=True, - include_chains=True, - collection_types={CollectionType.CHAINED}, - ): - if record == self: - raise ValueError(f"Cycle in collection chaining when defining '{self.name}'.") - if flatten: - children = tuple( - record.name for record in manager.resolve_wildcard(children_as_wildcard, flatten_chains=True) - ) - # Delegate to derived classes to do the database updates. - self._update(manager, children) - # Actually set this instances sequence of children. - self._children = children - - def refresh(self, manager: CollectionManager) -> None: - """Load children from the database, using the given manager to resolve - collection primary key values into records. - - This method exists to ensure that all collections that may appear in a - chain are known to the manager before any particular chain tries to - retrieve their records from it. `ChainedCollectionRecord` subclasses - can rely on it being called sometime after their own ``__init__`` to - finish construction. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - """ - self._children = self._load(manager) - - @abstractmethod - def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None: - """Protected implementation hook for `update`. - - This method should be implemented by subclasses to update the database - to reflect the children given. It should never be called by anything - other than `update`, which should be used by all external code. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - children : `tuple` [ `str` ] - A collection search path that should be resolved to set the child - collections of this chain. Guaranteed not to contain cycles. - """ - raise NotImplementedError() - - @abstractmethod - def _load(self, manager: CollectionManager) -> tuple[str, ...]: - """Protected implementation hook for `refresh`. 
- - This method should be implemented by subclasses to retrieve the chain's - child collections from the database and return them. It should never - be called by anything other than `refresh`, which should be used by all - external code. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. + children: tuple[str, ...] + """The ordered search path of child collections that define this chain + (`tuple` [ `str` ]). + """ - Returns - ------- - children : `tuple` [ `str` ] - The ordered sequence of collection names that defines the chained - collection. Guaranteed not to contain cycles. - """ - raise NotImplementedError() + def __init__( + self, + key: Any, + name: str, + *, + children: Iterable[str], + ): + super().__init__(key=key, name=name, type=CollectionType.CHAINED) + self.children = tuple(children) def __repr__(self) -> str: return f"ChainedCollectionRecord(key={self.key!r}, name={self.name!r}, children={self.children!r})" -class CollectionManager(VersionedExtension): +class CollectionManager(Generic[_Key], VersionedExtension): """An interface for managing the collections (including runs) in a `Registry`. @@ -467,7 +370,7 @@ def refresh(self) -> None: @abstractmethod def register( self, name: str, type: CollectionType, doc: str | None = None - ) -> tuple[CollectionRecord, bool]: + ) -> tuple[CollectionRecord[_Key], bool]: """Ensure that a collection of the given name and type are present in the layer this manager is associated with. @@ -533,7 +436,7 @@ def remove(self, name: str) -> None: raise NotImplementedError() @abstractmethod - def find(self, name: str) -> CollectionRecord: + def find(self, name: str) -> CollectionRecord[_Key]: """Return the collection record associated with the given name. Parameters @@ -562,7 +465,7 @@ def find(self, name: str) -> CollectionRecord: raise NotImplementedError() @abstractmethod - def __getitem__(self, key: Any) -> CollectionRecord: + def __getitem__(self, key: Any) -> CollectionRecord[_Key]: """Return the collection record associated with the given primary/foreign key value. @@ -600,7 +503,7 @@ def resolve_wildcard( done: set[str] | None = None, flatten_chains: bool = True, include_chains: bool | None = None, - ) -> list[CollectionRecord]: + ) -> list[CollectionRecord[_Key]]: """Iterate over collection records that match a wildcard. Parameters @@ -631,7 +534,7 @@ def resolve_wildcard( raise NotImplementedError() @abstractmethod - def getDocumentation(self, key: Any) -> str | None: + def getDocumentation(self, key: _Key) -> str | None: """Retrieve the documentation string for a collection. Parameters @@ -647,7 +550,7 @@ def getDocumentation(self, key: Any) -> str | None: raise NotImplementedError() @abstractmethod - def setDocumentation(self, key: Any, doc: str | None) -> None: + def setDocumentation(self, key: _Key, doc: str | None) -> None: """Set the documentation string for a collection. Parameters @@ -659,7 +562,8 @@ def setDocumentation(self, key: Any, doc: str | None) -> None: """ raise NotImplementedError() - def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: + @abstractmethod + def getParentChains(self, key: _Key) -> Iterator[ChainedCollectionRecord[_Key]]: """Find all CHAINED collections that directly contain the given collection. @@ -669,3 +573,21 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: Internal primary key value for the collection. 
""" raise NotImplementedError() + + @abstractmethod + def update_chain( + self, record: ChainedCollectionRecord[_Key], children: Iterable[str], flatten: bool = False + ) -> ChainedCollectionRecord[_Key]: + """Update chained collection composition. + + Parameters + ---------- + record : `ChainedCollectionRecord` + Chained collection record. + children : `~collections.abc.Iterable` [`str`] + Ordered names of children collections. + flatten : `bool`, optional + If `True`, recursively flatten out any nested + `~CollectionType.CHAINED` collections in ``children`` first. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 733f820941..538ffbc265 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -603,7 +603,7 @@ def setCollectionChain(self, parent: str, children: Any, *, flatten: bool = Fals assert isinstance(record, ChainedCollectionRecord) children = CollectionWildcard.from_expression(children).require_ordered() if children != record.children or flatten: - record.update(self._managers.collections, children, flatten=flatten) + self._managers.collections.update_chain(record, children, flatten=flatten) def getCollectionParentChains(self, collection: str) -> set[str]: """Return the CHAINED collections that directly contain the given one. From 0d879297c2637f6ae431593d96ce342042961c5a Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Fri, 3 Nov 2023 15:06:35 -0700 Subject: [PATCH 04/11] Reduce need for schema reflection for registry tables. Static tables do not really need schema verification because we rely on version numbers in butler_attributes. Verification and reflection may still be usefule for dynamic tables (tags/calibs) but we now delay it until the tables are actually used. --- .../datasets/byDimensions/_manager.py | 74 ++++++++++++------- .../datasets/byDimensions/_storage.py | 26 +++++-- .../butler/registry/interfaces/_database.py | 4 - 3 files changed, 67 insertions(+), 37 deletions(-) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index a0d0dfadfd..60e7b0f308 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -55,6 +55,36 @@ class MissingDatabaseTableError(RuntimeError): """Exception raised when a table is not found in a database.""" +class _ExistingTableFactory: + """Factory for `sqlalchemy.schema.Table` instances that returns already + existing table instance. + """ + + def __init__(self, table: sqlalchemy.schema.Table): + self._table = table + + def __call__(self) -> sqlalchemy.schema.Table: + return self._table + + +class _SpecTableFactory: + """Factory for `sqlalchemy.schema.Table` instances that builds table + instances using provided `ddl.TableSpec` definition and verifies that + table exists in the database. 
+ """ + + def __init__(self, db: Database, name: str, spec: ddl.TableSpec): + self._db = db + self._name = name + self._spec = spec + + def __call__(self) -> sqlalchemy.schema.Table: + table = self._db.getExistingTable(self._name, self._spec) + if table is None: + raise MissingDatabaseTableError(f"Table {self._name} is missing from database schema.") + return table + + class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager): """A manager class for datasets that uses one dataset-collection table for each group of dataset types that share the same dimensions. @@ -218,37 +248,24 @@ def refresh(self) -> None: datasetType = DatasetType( name, dimensions, row[c.storage_class], isCalibration=(calibTableName is not None) ) - tags = self._db.getExistingTable( - row[c.tag_association_table], - makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), - ) - if tags is None: - raise MissingDatabaseTableError( - f"Table {row[c.tag_association_table]} is missing from database schema." - ) + tags_spec = makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()) + tags_table_factory = _SpecTableFactory(self._db, row[c.tag_association_table], tags_spec) + calibs_table_factory = None if calibTableName is not None: - calibs = self._db.getExistingTable( - row[c.calibration_association_table], - makeCalibTableSpec( - datasetType, - type(self._collections), - self._db.getTimespanRepresentation(), - self.getIdColumnType(), - ), + calibs_spec = makeCalibTableSpec( + datasetType, + type(self._collections), + self._db.getTimespanRepresentation(), + self.getIdColumnType(), ) - if calibs is None: - raise MissingDatabaseTableError( - f"Table {row[c.calibration_association_table]} is missing from database schema." - ) - else: - calibs = None + calibs_table_factory = _SpecTableFactory(self._db, calibTableName, calibs_spec) storage = self._recordStorageType( db=self._db, datasetType=datasetType, static=self._static, summaries=self._summaries, - tags=tags, - calibs=calibs, + tags_table_factory=tags_table_factory, + calibs_table_factory=calibs_table_factory, dataset_type_id=row["id"], collections=self._collections, use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, @@ -302,6 +319,8 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool tagTableName, makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), ) + tags_table_factory = _ExistingTableFactory(tags) + calibs_table_factory = None if calibTableName is not None: calibs = self._db.ensureTableExists( calibTableName, @@ -312,8 +331,7 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool self.getIdColumnType(), ), ) - else: - calibs = None + calibs_table_factory = _ExistingTableFactory(calibs) row, inserted = self._db.sync( self._static.dataset_type, keys={"name": datasetType.name}, @@ -335,8 +353,8 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool datasetType=datasetType, static=self._static, summaries=self._summaries, - tags=tags, - calibs=calibs, + tags_table_factory=tags_table_factory, + calibs_table_factory=calibs_table_factory, dataset_type_id=row["id"], collections=self._collections, use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index 454702b5b8..b88b80a3c5 100644 --- 
a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -32,7 +32,7 @@ __all__ = ("ByDimensionsDatasetRecordStorage",) -from collections.abc import Iterable, Iterator, Sequence, Set +from collections.abc import Callable, Iterable, Iterator, Sequence, Set from datetime import datetime from typing import TYPE_CHECKING @@ -77,9 +77,9 @@ def __init__( collections: CollectionManager, static: StaticDatasetTablesTuple, summaries: CollectionSummaryManager, - tags: sqlalchemy.schema.Table, + tags_table_factory: Callable[[], sqlalchemy.schema.Table], use_astropy_ingest_date: bool, - calibs: sqlalchemy.schema.Table | None, + calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None, ): super().__init__(datasetType=datasetType) self._dataset_type_id = dataset_type_id @@ -87,10 +87,26 @@ def __init__( self._collections = collections self._static = static self._summaries = summaries - self._tags = tags - self._calibs = calibs + self._tags_table_factory = tags_table_factory + self._calibs_table_factory = calibs_table_factory self._runKeyColumn = collections.getRunForeignKeyName() self._use_astropy = use_astropy_ingest_date + self._tags_table: sqlalchemy.schema.Table | None = None + self._calibs_table: sqlalchemy.schema.Table | None = None + + @property + def _tags(self) -> sqlalchemy.schema.Table: + if self._tags_table is None: + self._tags_table = self._tags_table_factory() + return self._tags_table + + @property + def _calibs(self) -> sqlalchemy.schema.Table | None: + if self._calibs_table is None: + if self._calibs_table_factory is None: + return None + self._calibs_table = self._calibs_table_factory() + return self._calibs_table def delete(self, datasets: Iterable[DatasetRef]) -> None: # Docstring inherited from DatasetRecordStorage. diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 438bf55613..61bc9e2440 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -140,10 +140,6 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table: relationships. """ name = self._db._mangleTableName(name) - if name in self._tableNames: - _checkExistingTableDefinition( - name, spec, self._inspector.get_columns(name, schema=self._db.namespace) - ) metadata = self._db._metadata assert metadata is not None, "Guaranteed by context manager that returns this object." table = self._db._convertTableSpec(name, spec, metadata) From 8481829eb56f1a6d8decf43235bf189bb3234825 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Sun, 5 Nov 2023 21:23:40 -0800 Subject: [PATCH 05/11] Implement delayed population of collection cache. 
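
The collection record cache in `DefaultCollectionManager` now starts out
empty: `refresh()` only resets it, individual records are fetched from the
database the first time they are looked up, and a full fetch happens only
when a wildcard lookup forces it. A condensed sketch of the pattern, with
hypothetical names that only mirror the implementation below:

    # Illustrative sketch of delayed cache population (not the actual code).
    class LazyRecordCache:
        def __init__(self, fetch_by_names, fetch_all):
            self._records = {}   # records fetched so far, keyed by name
            self._full = False   # True once everything has been loaded
            self._fetch_by_names = fetch_by_names
            self._fetch_all = fetch_all

        def find(self, name):
            # Fetch a single record on demand and remember it.
            if name not in self._records:
                fetched = self._fetch_by_names([name])
                if not fetched:
                    raise KeyError(name)
                self._records[name] = fetched[0]
            return self._records[name]

        def all(self):
            # Only wildcard/pattern lookups pay for the full query.
            if not self._full:
                self._records = {r.name: r for r in self._fetch_all()}
                self._full = True
            return list(self._records.values())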
--- .../daf/butler/registry/collections/_base.py | 154 +++++++++--------- .../butler/registry/collections/nameKey.py | 115 ++++++++++++- .../registry/collections/synthIntKey.py | 132 ++++++++++++++- 3 files changed, 308 insertions(+), 93 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 95a6bfa950..0e78a76794 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -30,10 +30,9 @@ __all__ = () -import contextlib import itertools from abc import abstractmethod -from collections import defaultdict, namedtuple +from collections import namedtuple from collections.abc import Iterable, Iterator, Set from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -199,72 +198,24 @@ def __init__( self._collectionIdName = collectionIdName self._records: dict[K, CollectionRecord[K]] = {} # indexed by record ID self._dimensions = dimensions + self._full_fetch = False # True if cache contains everything. def refresh(self) -> None: # Docstring inherited from CollectionManager. - sql = sqlalchemy.sql.select( - *(list(self._tables.collection.columns) + list(self._tables.run.columns)) - ).select_from(self._tables.collection.join(self._tables.run, isouter=True)) - # Extract _all_ chain mappings as well - chain_sql = sqlalchemy.sql.select( - self._tables.collection_chain.columns["parent"], - self._tables.collection_chain.columns["position"], - self._tables.collection_chain.columns["child"], - ).select_from(self._tables.collection_chain) + # We just reset the cache here but do not retrieve any records. + self._full_fetch = False + self._setRecordCache([]) - with self._db.transaction(): - with self._db.query(sql) as sql_result: - sql_rows = sql_result.mappings().fetchall() - with self._db.query(chain_sql) as sql_result: - chain_rows = sql_result.mappings().fetchall() - - # Build all chain definitions. - chains_defs: dict[K, list[tuple[int, K]]] = defaultdict(list) - for row in chain_rows: - chains_defs[row["parent"]].append((row["position"], row["child"])) - - # Put found records into a temporary instead of updating self._records - # in place, for exception safety. - records: list[CollectionRecord] = [] - TimespanReprClass = self._db.getTimespanRepresentation() - id_to_name: dict[K, str] = {} - chained_ids: list[K] = [] - for row in sql_rows: - collection_id = row[self._tables.collection.columns[self._collectionIdName]] - name = row[self._tables.collection.columns.name] - id_to_name[collection_id] = name - type = CollectionType(row["type"]) - record: CollectionRecord - if type is CollectionType.RUN: - record = RunRecord( - key=collection_id, - name=name, - host=row[self._tables.run.columns.host], - timespan=TimespanReprClass.extract(row), - ) - records.append(record) - elif type is CollectionType.CHAINED: - # Need to delay chained collection construction until all names - # are known. 
- chained_ids.append(collection_id) - else: - record = CollectionRecord(key=collection_id, name=name, type=type) - records.append(record) - - for chained_id in chained_ids: - children_names = [id_to_name[child_id] for _, child_id in sorted(chains_defs[chained_id])] - record = ChainedCollectionRecord( - key=chained_id, - name=id_to_name[chained_id], - children=children_names, - ) - records.append(record) - - self._setRecordCache(records) + def _fetch_all(self) -> None: + """Retrieve all records into cache if not done so yet.""" + if not self._full_fetch: + records = self._fetch_by_key(None) + self._setRecordCache(records) + self._full_fetch = True def register( self, name: str, type: CollectionType, doc: str | None = None - ) -> tuple[CollectionRecord, bool]: + ) -> tuple[CollectionRecord[K], bool]: # Docstring inherited from CollectionManager. registered = False record = self._getByName(name) @@ -323,12 +274,31 @@ def find(self, name: str) -> CollectionRecord[K]: raise MissingCollectionError(f"No collection with name '{name}' found.") return result + def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]: + """Return multiple records given their names.""" + names = list(names) + # To protect against potential races in cache updates. + records = {} + for name in names: + records[name] = self._get_cached_name(name) + fetch_names = [name for name, record in records.items() if record is None] + for record in self._fetch_by_name(fetch_names): + records[record.name] = record + missing_names = [name for name, record in records.items() if record is None] + if len(missing_names) == 1: + raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.") + elif len(missing_names) > 1: + raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.") + return [cast(CollectionRecord[K], records[name]) for name in names] + def __getitem__(self, key: Any) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. - try: - return self._records[key] - except KeyError as err: - raise MissingCollectionError(f"Collection with key '{key}' not found.") from err + if (record := self._records.get(key)) is not None: + return record + if records := self._fetch_by_key([key]): + return records[0] + else: + raise MissingCollectionError(f"Collection with key '{key}' not found.") def resolve_wildcard( self, @@ -353,20 +323,25 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect yield record if flatten_chains and record.type is CollectionType.CHAINED: done.add(record.name) - for name in cast(ChainedCollectionRecord[K], record).children: + for child in self._find_many(cast(ChainedCollectionRecord[K], record).children): # flake8 can't tell that we only delete this closure when # we're totally done with it. - yield from resolve_nested(self.find(name), done) # noqa: F821 + yield from resolve_nested(child, done) # noqa: F821 result: list[CollectionRecord[K]] = [] + # If we have wildcard or ellipsis we need to read everything in memory. 
+        if wildcard.patterns:
+            self._fetch_all()
+
         if wildcard.patterns is ...:
             for record in self._records.values():
                 result.extend(resolve_nested(record, done))
             del resolve_nested
             return result
-        for name in wildcard.strings:
-            result.extend(resolve_nested(self.find(name), done))
+        if wildcard.strings:
+            for record in self._find_many(wildcard.strings):
+                result.extend(resolve_nested(record, done))
         if wildcard.patterns:
             for record in self._records.values():
                 if any(p.fullmatch(record.name) for p in wildcard.patterns):
@@ -374,7 +349,7 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect
         del resolve_nested
         return result
 
-    def getDocumentation(self, key: Any) -> str | None:
+    def getDocumentation(self, key: K) -> str | None:
         # Docstring inherited from CollectionManager.
         sql = (
             sqlalchemy.sql.select(self._tables.collection.columns.doc)
@@ -384,7 +359,7 @@ def getDocumentation(self, key: Any) -> str | None:
         with self._db.query(sql) as sql_result:
             return sql_result.scalar()
 
-    def setDocumentation(self, key: Any, doc: str | None) -> None:
+    def setDocumentation(self, key: K, doc: str | None) -> None:
         # Docstring inherited from CollectionManager.
         self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})
@@ -404,12 +379,33 @@ def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
         """Remove single record from cache."""
         del self._records[record.key]
 
-    @abstractmethod
     def _getByName(self, name: str) -> CollectionRecord[K] | None:
         """Find collection record given collection name."""
+        if (record := self._get_cached_name(name)) is not None:
+            return record
+        records = self._fetch_by_name([name])
+        for record in records:
+            self._addCachedRecord(record)
+        return records[0] if records else None
+
+    @abstractmethod
+    def _get_cached_name(self, name: str) -> CollectionRecord[K] | None:
+        """Find cached collection record given its name."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
+        """Fetch collection records from the database given their names."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
+        """Fetch collection records from the database given their keys, or
+        fetch all collections if the argument is `None`.
+        """
         raise NotImplementedError()
 
-    def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]:
+    def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]:
         # Docstring inherited from CollectionManager.
         table = self._tables.collection_chain
         sql = (
@@ -419,11 +415,13 @@ def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord[K]]:
         )
         with self._db.query(sql) as sql_result:
             parent_keys = sql_result.scalars().all()
-        for key in parent_keys:
-            # TODO: Just in case cached records miss new parent collections.
-            # This is temporary, will replace with non-cached records soon.
-            with contextlib.suppress(KeyError):
-                yield cast(ChainedCollectionRecord[K], self._records[key])
+        # TODO: It would be more efficient to write a single query that both
+        # finds parents and all their children, but for now we do not care
+        # much about efficiency. Also the only client of this method does not
+        # need full records, only parent collection names, maybe we should
+        # change this method to return names instead.
+ for record in self._fetch_by_key(parent_keys): + yield cast(ChainedCollectionRecord[K], record) def update_chain( self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index e5e635e61c..7336558cc9 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -26,16 +26,17 @@ # along with this program. If not, see . from __future__ import annotations -from ... import ddl - __all__ = ["NameKeyCollectionManager"] +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy +from ... import ddl from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -44,7 +45,7 @@ ) if TYPE_CHECKING: - from ..interfaces import CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext + from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext _KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True) @@ -68,7 +69,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class NameKeyCollectionManager(DefaultCollectionManager): +class NameKeyCollectionManager(DefaultCollectionManager[str]): """A `CollectionManager` implementation that uses collection names for primary/foreign keys and aggressively loads all collection/run records in the database into memory. @@ -152,10 +153,110 @@ def addRunForeignKey( ) return copy - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. + def _get_cached_name(self, name: str) -> CollectionRecord[str] | None: + # Docstring inherited from base class. return self._records.get(name) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + return self._fetch_by_key(names) + + def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from( + self._tables.collection.join(self._tables.run, isouter=True) + ) + + chain_sql = sqlalchemy.sql.select( + self._tables.collection_chain.columns["parent"], + self._tables.collection_chain.columns["position"], + self._tables.collection_chain.columns["child"], + ) + + records: list[CollectionRecord[str]] = [] + # We want to keep transactions as short as possible. When we fetch + # everything we want to quickly fetch things into memory and finish + # transaction. When we fetch just few records we need to process result + # of the first query before we can run the second one. 
+        if collection_ids is not None:
+            sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
+            with self._db.transaction():
+                with self._db.query(sql) as sql_result:
+                    sql_rows = sql_result.mappings().fetchall()
+
+                records, chained_ids = self._rows_to_records(sql_rows)
+
+                if chained_ids:
+                    # Retrieve chained collection compositions.
+                    chain_sql = chain_sql.where(
+                        self._tables.collection_chain.columns["parent"].in_(chained_ids)
+                    )
+                    with self._db.query(chain_sql) as sql_result:
+                        chain_rows = sql_result.mappings().fetchall()
+
+                    records += self._rows_to_chains(chain_rows, chained_ids)
+
+        else:
+            with self._db.transaction():
+                with self._db.query(sql) as sql_result:
+                    sql_rows = sql_result.mappings().fetchall()
+                with self._db.query(chain_sql) as sql_result:
+                    chain_rows = sql_result.mappings().fetchall()
+
+            records, chained_ids = self._rows_to_records(sql_rows)
+            records += self._rows_to_chains(chain_rows, chained_ids)
+
+        return records
+
+    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
+        """Convert rows returned from collection query to a list of records
+        and a list of chained collection names.
+        """
+        records: list[CollectionRecord[str]] = []
+        TimespanReprClass = self._db.getTimespanRepresentation()
+        chained_ids: list[str] = []
+        for row in rows:
+            name = row[self._tables.collection.columns.name]
+            type = CollectionType(row["type"])
+            record: CollectionRecord[str]
+            if type is CollectionType.RUN:
+                record = RunRecord[str](
+                    key=name,
+                    name=name,
+                    host=row[self._tables.run.columns.host],
+                    timespan=TimespanReprClass.extract(row),
+                )
+                records.append(record)
+            elif type is CollectionType.CHAINED:
+                # Need to delay chained collection construction until we
+                # fetch their children's names.
+                chained_ids.append(name)
+            else:
+                record = CollectionRecord[str](key=name, name=name, type=type)
+                records.append(record)
+
+        return records, chained_ids
+
+    def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
+        """Convert rows returned from collection chain query to a list of
+        records.
+        """
+        chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
+        for row in rows:
+            chains_defs[row["parent"]].append((row["position"], row["child"]))
+
+        records: list[CollectionRecord[str]] = []
+        for name, children in chains_defs.items():
+            children_names = [child for _, child in sorted(children)]
+            record = ChainedCollectionRecord[str](
+                key=name,
+                name=name,
+                children=children_names,
+            )
+            records.append(record)
+
+        return records
+
     @classmethod
     def currentVersions(cls) -> list[VersionTuple]:
         # Docstring inherited from VersionedExtension.
diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index 8e49140c8d..2e1bf5f758 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -30,13 +30,14 @@ __all__ = ["SynthIntKeyCollectionManager"] -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import CollectionRecord, VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -73,7 +74,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class SynthIntKeyCollectionManager(DefaultCollectionManager): +class SynthIntKeyCollectionManager(DefaultCollectionManager[int]): """A `CollectionManager` implementation that uses synthetic primary key (auto-incremented integer) for collections table. @@ -184,7 +185,7 @@ def addRunForeignKey( ) return copy - def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: + def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None: """Set internal record cache to contain given records, old cached records will be removed. """ @@ -194,20 +195,135 @@ def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: self._records[record.key] = record self._nameCache[record.name] = record - def _addCachedRecord(self, record: CollectionRecord) -> None: + def _addCachedRecord(self, record: CollectionRecord[int]) -> None: """Add single record to cache.""" self._records[record.key] = record self._nameCache[record.name] = record - def _removeCachedRecord(self, record: CollectionRecord) -> None: + def _removeCachedRecord(self, record: CollectionRecord[int]) -> None: """Remove single record from cache.""" del self._records[record.key] del self._nameCache[record.name] - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. + def _get_cached_name(self, name: str) -> CollectionRecord[int] | None: + # Docstring inherited from base class. return self._nameCache.get(name) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch("name", names) + + def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch(self._collectionIdName, collection_ids) + + def _fetch( + self, column_name: str, collections: Iterable[int | str] | None + ) -> list[CollectionRecord[int]]: + collection_chain = self._tables.collection_chain + collection = self._tables.collection + sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from( + collection.join(self._tables.run, isouter=True) + ) + + chain_sql = ( + sqlalchemy.sql.select( + collection_chain.columns["parent"], + collection_chain.columns["position"], + collection.columns["name"].label("child_name"), + ) + .select_from(collection_chain) + .join( + collection, + onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName], + ) + ) + + records: list[CollectionRecord[int]] = [] + # We want to keep transactions as short as possible. 
When we fetch
+        # everything we want to quickly fetch things into memory and finish the
+        # transaction. When we fetch just a few records we need to process the
+        # first query before we can run the second one.
+        if collections is not None:
+            sql = sql.where(collection.columns[column_name].in_(collections))
+            with self._db.transaction():
+                with self._db.query(sql) as sql_result:
+                    sql_rows = sql_result.mappings().fetchall()
+
+                records, chained_ids = self._rows_to_records(sql_rows)
+
+                if chained_ids:
+                    chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))
+
+                    with self._db.query(chain_sql) as sql_result:
+                        chain_rows = sql_result.mappings().fetchall()
+
+                    records += self._rows_to_chains(chain_rows, chained_ids)
+
+        else:
+            with self._db.transaction():
+                with self._db.query(sql) as sql_result:
+                    sql_rows = sql_result.mappings().fetchall()
+                with self._db.query(chain_sql) as sql_result:
+                    chain_rows = sql_result.mappings().fetchall()
+
+            records, chained_ids = self._rows_to_records(sql_rows)
+            records += self._rows_to_chains(chain_rows, chained_ids)
+
+        return records
+
+    def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
+        """Convert rows returned from collection query to a list of records
+        and a dict mapping chained collection IDs to their names.
+        """
+        records: list[CollectionRecord[int]] = []
+        chained_ids: dict[int, str] = {}
+        TimespanReprClass = self._db.getTimespanRepresentation()
+        for row in rows:
+            key: int = row[self._collectionIdName]
+            name: str = row[self._tables.collection.columns.name]
+            type = CollectionType(row["type"])
+            record: CollectionRecord[int]
+            if type is CollectionType.RUN:
+                record = RunRecord[int](
+                    key=key,
+                    name=name,
+                    host=row[self._tables.run.columns.host],
+                    timespan=TimespanReprClass.extract(row),
+                )
+                records.append(record)
+            elif type is CollectionType.CHAINED:
+                # Need to delay chained collection construction until we
+                # fetch their children's names.
+                chained_ids[key] = name
+            else:
+                record = CollectionRecord[int](key=key, name=name, type=type)
+                records.append(record)
+        return records, chained_ids
+
+    def _rows_to_chains(
+        self, rows: Iterable[Mapping], chained_ids: dict[int, str]
+    ) -> list[CollectionRecord[int]]:
+        """Convert rows returned from collection chain query to a list of
+        records.
+        """
+        chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
+        for row in rows:
+            chains_defs[row["parent"]].append((row["position"], row["child_name"]))
+
+        records: list[CollectionRecord[int]] = []
+        for key, children in chains_defs.items():
+            name = chained_ids[key]
+            children_names = [child for _, child in sorted(children)]
+            record = ChainedCollectionRecord[int](
+                key=key,
+                name=name,
+                children=children_names,
+            )
+            records.append(record)
+
+        return records
+
     @classmethod
     def currentVersions(cls) -> list[VersionTuple]:
         # Docstring inherited from VersionedExtension.
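
Both key flavours assemble chained collections the same way: the chain query yields (parent, position, child) rows, which the _rows_to_chains helpers group by parent and order by position. The following standalone snippet illustrates just that grouping step with made-up collection names; it is not part of the patch.

rows = [
    {"parent": "chain-A", "position": 1, "child": "run-2"},
    {"parent": "chain-A", "position": 0, "child": "run-1"},
    {"parent": "chain-B", "position": 0, "child": "run-3"},
]
chains_defs: dict[str, list[tuple[int, str]]] = {"chain-A": [], "chain-B": []}
for row in rows:
    chains_defs[row["parent"]].append((row["position"], row["child"]))
children = {parent: [child for _, child in sorted(pairs)] for parent, pairs in chains_defs.items()}
assert children == {"chain-A": ["run-1", "run-2"], "chain-B": ["run-3"]}
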
From b1d4cb8efbb88bfe4a8b15e6a87b1bd68135a2d3 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 7 Nov 2023 15:09:43 -0800 Subject: [PATCH 06/11] Make CollectionManager.getParentChains return names instead of records --- .../daf/butler/registry/collections/_base.py | 18 ------------------ .../daf/butler/registry/collections/nameKey.py | 12 ++++++++++++ .../butler/registry/collections/synthIntKey.py | 14 ++++++++++++++ .../butler/registry/interfaces/_collections.py | 11 ++++++++--- .../lsst/daf/butler/registry/sql_registry.py | 7 +------ 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 0e78a76794..1c3587d1ab 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -405,24 +405,6 @@ def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRe """ raise NotImplementedError() - def getParentChains(self, key: K) -> Iterator[ChainedCollectionRecord[K]]: - # Docstring inherited from CollectionManager. - table = self._tables.collection_chain - sql = ( - sqlalchemy.sql.select(table.columns["parent"]) - .select_from(table) - .where(table.columns["child"] == key) - ) - with self._db.query(sql) as sql_result: - parent_keys = sql_result.scalars().all() - # TODO: It would be more efficient to write a single query that both - # finds parents and all their children, but for now we do not care - # much about efficiency. Also the only client of this method does not - # need full records, only parent collection names, maybe we should - # change this method to return names instead. - for record in self._fetch_by_key(parent_keys): - yield cast(ChainedCollectionRecord[K], record) - def update_chain( self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False ) -> ChainedCollectionRecord[K]: diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index 7336558cc9..d8c50cce2f 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -153,6 +153,18 @@ def addRunForeignKey( ) return copy + def getParentChains(self, key: str) -> set[str]: + # Docstring inherited from CollectionManager. + table = self._tables.collection_chain + sql = ( + sqlalchemy.sql.select(table.columns["parent"]) + .select_from(table) + .where(table.columns["child"] == key) + ) + with self._db.query(sql) as sql_result: + parent_names = set(sql_result.scalars().all()) + return parent_names + def _get_cached_name(self, name: str) -> CollectionRecord[str] | None: # Docstring inherited from base class. return self._records.get(name) diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index 2e1bf5f758..d2edcaae88 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -185,6 +185,20 @@ def addRunForeignKey( ) return copy + def getParentChains(self, key: int) -> set[str]: + # Docstring inherited from CollectionManager. 
+ chain = self._tables.collection_chain + collection = self._tables.collection + sql = ( + sqlalchemy.sql.select(collection.columns["name"]) + .select_from(collection) + .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"]) + .where(chain.columns["child"] == key) + ) + with self._db.query(sql) as sql_result: + parent_names = set(sql_result.scalars().all()) + return parent_names + def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None: """Set internal record cache to contain given records, old cached records will be removed. diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 3cce276b6b..c07b894adc 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -36,7 +36,7 @@ ] from abc import abstractmethod -from collections.abc import Iterable, Iterator, Set +from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Any, Generic, TypeVar from ..._timespan import Timespan @@ -563,14 +563,19 @@ def setDocumentation(self, key: _Key, doc: str | None) -> None: raise NotImplementedError() @abstractmethod - def getParentChains(self, key: _Key) -> Iterator[ChainedCollectionRecord[_Key]]: - """Find all CHAINED collections that directly contain the given + def getParentChains(self, key: _Key) -> set[str]: + """Find all CHAINED collection names that directly contain the given collection. Parameters ---------- key Internal primary key value for the collection. + + Returns + ------- + names : `set` [`str`] + Parent collection names. """ raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 538ffbc265..5e03938b78 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -618,12 +618,7 @@ def getCollectionParentChains(self, collection: str) -> set[str]: chains : `set` of `str` Set of `~CollectionType.CHAINED` collection names. """ - return { - record.name - for record in self._managers.collections.getParentChains( - self._managers.collections.find(collection).key - ) - } + return self._managers.collections.getParentChains(self._managers.collections.find(collection).key) def getCollectionDocumentation(self, collection: str) -> str | None: """Retrieve the documentation string for a collection. From 5833c91f55f4be34daf795f115f44fdbc45a7085 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Wed, 8 Nov 2023 16:13:10 -0800 Subject: [PATCH 07/11] Remove dataset type caching from datasets manager. 
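
Dataset type definitions are now read from the dataset_type table on each lookup instead of being served from the pre-loaded _byName/_byId dictionaries. The sketch below shows the per-lookup shape this moves to; it is illustrative only, with a plain dict standing in for the table and made-up record values.

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class DatasetTypeRecord:
    """Stand-in for one row of the dataset_type table."""

    name: str
    dataset_type_id: int
    tag_table_name: str
    calib_table_name: str | None


_TABLE = {
    "raw": DatasetTypeRecord("raw", 1, "dataset_tags_00000001", None),
}


def find(name: str) -> DatasetTypeRecord | None:
    # One query per lookup; no manager-level cache of all dataset types.
    return _TABLE.get(name)


assert find("raw") is not None
assert find("calexp") is None
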
--- .../datasets/byDimensions/_manager.py | 235 +++++++++--------- .../datasets/byDimensions/summaries.py | 37 ++- .../butler/registry/interfaces/_datasets.py | 4 +- .../lsst/daf/butler/registry/sql_registry.py | 3 +- 4 files changed, 142 insertions(+), 137 deletions(-) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index 60e7b0f308..88e20c67f2 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -4,6 +4,7 @@ __all__ = ("ByDimensionsDatasetRecordStorageManagerUUID",) +import dataclasses import logging import warnings from collections import defaultdict @@ -55,16 +56,14 @@ class MissingDatabaseTableError(RuntimeError): """Exception raised when a table is not found in a database.""" -class _ExistingTableFactory: - """Factory for `sqlalchemy.schema.Table` instances that returns already - existing table instance. - """ - - def __init__(self, table: sqlalchemy.schema.Table): - self._table = table +@dataclasses.dataclass +class _DatasetTypeRecord: + """Contents of a single dataset type record.""" - def __call__(self) -> sqlalchemy.schema.Table: - return self._table + dataset_type: DatasetType + dataset_type_id: int + tag_table_name: str + calib_table_name: str | None class _SpecTableFactory: @@ -139,8 +138,6 @@ def __init__( self._dimensions = dimensions self._static = static self._summaries = summaries - self._byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - self._byId: dict[int, ByDimensionsDatasetRecordStorage] = {} @classmethod def initialize( @@ -162,6 +159,7 @@ def initialize( context, collections=collections, dimensions=dimensions, + dataset_type_table=static.dataset_type, ) return cls( db=db, @@ -236,44 +234,33 @@ def addDatasetForeignKey( def refresh(self) -> None: # Docstring inherited from DatasetRecordStorageManager. 
- byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - byId: dict[int, ByDimensionsDatasetRecordStorage] = {} - c = self._static.dataset_type.columns - with self._db.query(self._static.dataset_type.select()) as sql_result: - sql_rows = sql_result.mappings().fetchall() - for row in sql_rows: - name = row[c.name] - dimensions = self._dimensions.loadDimensionGraph(row[c.dimensions_key]) - calibTableName = row[c.calibration_association_table] - datasetType = DatasetType( - name, dimensions, row[c.storage_class], isCalibration=(calibTableName is not None) - ) - tags_spec = makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()) - tags_table_factory = _SpecTableFactory(self._db, row[c.tag_association_table], tags_spec) - calibs_table_factory = None - if calibTableName is not None: - calibs_spec = makeCalibTableSpec( - datasetType, - type(self._collections), - self._db.getTimespanRepresentation(), - self.getIdColumnType(), - ) - calibs_table_factory = _SpecTableFactory(self._db, calibTableName, calibs_spec) - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags_table_factory=tags_table_factory, - calibs_table_factory=calibs_table_factory, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, + pass + + def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage: + """Create storage instance for a dataset type record.""" + tags_spec = makeTagTableSpec(record.dataset_type, type(self._collections), self.getIdColumnType()) + tags_table_factory = _SpecTableFactory(self._db, record.tag_table_name, tags_spec) + calibs_table_factory = None + if record.calib_table_name is not None: + calibs_spec = makeCalibTableSpec( + record.dataset_type, + type(self._collections), + self._db.getTimespanRepresentation(), + self.getIdColumnType(), ) - byName[datasetType.name] = storage - byId[storage._dataset_type_id] = storage - self._byName = byName - self._byId = byId + calibs_table_factory = _SpecTableFactory(self._db, record.calib_table_name, calibs_spec) + storage = self._recordStorageType( + db=self._db, + datasetType=record.dataset_type, + static=self._static, + summaries=self._summaries, + tags_table_factory=tags_table_factory, + calibs_table_factory=calibs_table_factory, + dataset_type_id=record.dataset_type_id, + collections=self._collections, + use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, + ) + return storage def remove(self, name: str) -> None: # Docstring inherited from DatasetRecordStorageManager. @@ -296,33 +283,28 @@ def remove(self, name: str) -> None: def find(self, name: str) -> DatasetRecordStorage | None: # Docstring inherited from DatasetRecordStorageManager. - return self._byName.get(name) + record = self._fetch_dataset_type_record(name) + return self._make_storage(record) if record is not None else None - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: # Docstring inherited from DatasetRecordStorageManager. if datasetType.isComponent(): raise ValueError( f"Component dataset types can not be stored in registry. 
Rejecting {datasetType.name}" ) - storage = self._byName.get(datasetType.name) - if storage is None: + record = self._fetch_dataset_type_record(datasetType.name) + if record is None: dimensionsKey = self._dimensions.saveDimensionGraph(datasetType.dimensions) tagTableName = makeTagTableName(datasetType, dimensionsKey) - calibTableName = ( - makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None - ) - # The order is important here, we want to create tables first and - # only register them if this operation is successful. We cannot - # wrap it into a transaction because database class assumes that - # DDL is not transaction safe in general. - tags = self._db.ensureTableExists( + self._db.ensureTableExists( tagTableName, makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), ) - tags_table_factory = _ExistingTableFactory(tags) - calibs_table_factory = None + calibTableName = ( + makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None + ) if calibTableName is not None: - calibs = self._db.ensureTableExists( + self._db.ensureTableExists( calibTableName, makeCalibTableSpec( datasetType, @@ -331,8 +313,7 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool self.getIdColumnType(), ), ) - calibs_table_factory = _ExistingTableFactory(calibs) - row, inserted = self._db.sync( + _, inserted = self._db.sync( self._static.dataset_type, keys={"name": datasetType.name}, compared={ @@ -347,28 +328,17 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool }, returning=["id", "tag_association_table"], ) - assert row is not None - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags_table_factory=tags_table_factory, - calibs_table_factory=calibs_table_factory, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, - ) - self._byName[datasetType.name] = storage - self._byId[storage._dataset_type_id] = storage else: - if datasetType != storage.datasetType: + if datasetType != record.dataset_type: raise ConflictingDefinitionError( f"Given dataset type {datasetType} is inconsistent " - f"with database definition {storage.datasetType}." + f"with database definition {record.dataset_type}." ) inserted = False - return storage, bool(inserted) + # TODO: We return storage instance from this method, but the only + # client that uses this method ignores it. Maybe we should drop it + # and avoid making storage instance above. + return bool(inserted) def resolve_wildcard( self, @@ -422,15 +392,13 @@ def resolve_wildcard( raise TypeError( "Universal wildcard '...' is not permitted for dataset types in this context." 
) - for storage in self._byName.values(): - result[storage.datasetType].add(None) + for datasetType in self._fetch_dataset_types(): + result[datasetType].add(None) if components: try: - result[storage.datasetType].update( - storage.datasetType.storageClass.allComponents().keys() - ) + result[datasetType].update(datasetType.storageClass.allComponents().keys()) if ( - storage.datasetType.storageClass.allComponents() + datasetType.storageClass.allComponents() and not already_warned and components_deprecated ): @@ -442,7 +410,7 @@ def resolve_wildcard( already_warned = True except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results.", ) elif wildcard.patterns: @@ -454,29 +422,28 @@ def resolve_wildcard( FutureWarning, stacklevel=find_outside_stacklevel("lsst.daf.butler"), ) - for storage in self._byName.values(): - if any(p.fullmatch(storage.datasetType.name) for p in wildcard.patterns): - result[storage.datasetType].add(None) + dataset_types = self._fetch_dataset_types() + for datasetType in dataset_types: + if any(p.fullmatch(datasetType.name) for p in wildcard.patterns): + result[datasetType].add(None) if components is not False: - for storage in self._byName.values(): - if components is None and storage.datasetType in result: + for datasetType in dataset_types: + if components is None and datasetType in result: continue try: - components_for_parent = storage.datasetType.storageClass.allComponents().keys() + components_for_parent = datasetType.storageClass.allComponents().keys() except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results." ) continue for component_name in components_for_parent: if any( - p.fullmatch( - DatasetType.nameWithComponent(storage.datasetType.name, component_name) - ) + p.fullmatch(DatasetType.nameWithComponent(datasetType.name, component_name)) for p in wildcard.patterns ): - result[storage.datasetType].add(component_name) + result[datasetType].add(component_name) if not already_warned and components_deprecated: warnings.warn( deprecation_message, @@ -492,49 +459,77 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: sqlalchemy.sql.select( self._static.dataset.columns.dataset_type_id, self._static.dataset.columns[self._collections.getRunForeignKeyName()], + *self._static.dataset_type.columns, ) .select_from(self._static.dataset) + .join(self._static.dataset_type) .where(self._static.dataset.columns.id == id) ) with self._db.query(sql) as sql_result: row = sql_result.mappings().fetchone() if row is None: return None - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - if recordsForType is None: - self.refresh() - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - assert recordsForType is not None, "Should be guaranteed by foreign key constraints." 
+        storage = self._make_storage(self._record_from_row(row))
         return DatasetRef(
-            recordsForType.datasetType,
-            dataId=recordsForType.getDataId(id=id),
+            storage.datasetType,
+            dataId=storage.getDataId(id=id),
             id=id,
             run=self._collections[row[self._collections.getRunForeignKeyName()]].name,
         )
 
-    def _dataset_type_factory(self, dataset_type_id: int) -> DatasetType:
-        """Return dataset type given its ID."""
-        return self._byId[dataset_type_id].datasetType
+    def _fetch_dataset_type_record(self, name: str) -> _DatasetTypeRecord | None:
+        """Retrieve the record for a single dataset type, given its name.
+
+        Returns
+        -------
+        dataset_type_record : `_DatasetTypeRecord` or `None`
+            Record for the named dataset type, or `None` if it is not defined.
+        """
+        c = self._static.dataset_type.columns
+        stmt = self._static.dataset_type.select().where(c.name == name)
+        with self._db.query(stmt) as sql_result:
+            row = sql_result.mappings().one_or_none()
+        if row is None:
+            return None
+        else:
+            return self._record_from_row(row)
+
+    def _record_from_row(self, row: Mapping) -> _DatasetTypeRecord:
+        name = row["name"]
+        dimensions = self._dimensions.loadDimensionGraph(row["dimensions_key"])
+        calibTableName = row["calibration_association_table"]
+        datasetType = DatasetType(
+            name, dimensions, row["storage_class"], isCalibration=(calibTableName is not None)
+        )
+        return _DatasetTypeRecord(
+            dataset_type=datasetType,
+            dataset_type_id=row["id"],
+            tag_table_name=row["tag_association_table"],
+            calib_table_name=calibTableName,
+        )
+
+    def _dataset_type_from_row(self, row: Mapping) -> DatasetType:
+        return self._record_from_row(row).dataset_type
+
+    def _fetch_dataset_types(self) -> list[DatasetType]:
+        """Fetch list of all defined dataset types."""
+        with self._db.query(self._static.dataset_type.select()) as sql_result:
+            sql_rows = sql_result.mappings().fetchall()
+        return [self._record_from_row(row).dataset_type for row in sql_rows]
 
     def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary:
         # Docstring inherited from DatasetRecordStorageManager.
-        summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_factory)
+        summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_from_row)
         return summaries[collection.key]
 
     def fetch_summaries(
         self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None
     ) -> Mapping[Any, CollectionSummary]:
         # Docstring inherited from DatasetRecordStorageManager.
-        dataset_type_ids: list[int] | None = None
+        dataset_type_names: Iterable[str] | None = None
         if dataset_types is not None:
-            dataset_type_ids = []
-            for dataset_type in dataset_types:
-                if dataset_type.isComponent():
-                    dataset_type = dataset_type.makeCompositeDatasetType()
-                # Assume we know all possible names.
- dataset_type_id = self._byName[dataset_type.name]._dataset_type_id - dataset_type_ids.append(dataset_type_id) - return self._summaries.fetch_summaries(collections, dataset_type_ids, self._dataset_type_factory) + dataset_type_names = set(dataset_type.name for dataset_type in dataset_types) + return self._summaries.fetch_summaries(collections, dataset_type_names, self._dataset_type_from_row) _versions: list[VersionTuple] """Schema version for this class.""" diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py index 73511ddd0c..41687cb9c2 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py @@ -133,6 +133,8 @@ class CollectionSummaryManager: Manager object for the dimensions in this `Registry`. tables : `CollectionSummaryTables` Struct containing the tables that hold collection summaries. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. """ def __init__( @@ -142,12 +144,14 @@ def __init__( collections: CollectionManager, dimensions: DimensionRecordStorageManager, tables: CollectionSummaryTables[sqlalchemy.schema.Table], + dataset_type_table: sqlalchemy.schema.Table, ): self._db = db self._collections = collections self._collectionKeyName = collections.getCollectionForeignKeyName() self._dimensions = dimensions self._tables = tables + self._dataset_type_table = dataset_type_table @classmethod def initialize( @@ -157,6 +161,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + dataset_type_table: sqlalchemy.schema.Table, ) -> CollectionSummaryManager: """Create all summary tables (or check that they have been created), returning an object to manage them. @@ -172,6 +177,8 @@ def initialize( Manager object for the collections in this `Registry`. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. Returns ------- @@ -193,6 +200,7 @@ def initialize( collections=collections, dimensions=dimensions, tables=tables, + dataset_type_table=dataset_type_table, ) def update( @@ -240,8 +248,8 @@ def update( def fetch_summaries( self, collections: Iterable[CollectionRecord], - dataset_type_ids: Iterable[int] | None, - dataset_type_factory: Callable[[int], DatasetType], + dataset_type_names: Iterable[str] | None, + dataset_type_factory: Callable[[sqlalchemy.engine.RowMapping], DatasetType], ) -> Mapping[Any, CollectionSummary]: """Fetch collection summaries given their names and dataset types. @@ -249,12 +257,12 @@ def fetch_summaries( ---------- collections : `~collections.abc.Iterable` [`CollectionRecord`] Collection records to query. - dataset_type_ids : `~collections.abc.Iterable` [`int`] - IDs of dataset types to include into returned summaries. If `None` - then all dataset types will be included. + dataset_type_names : `~collections.abc.Iterable` [`str`] + Names of dataset types to include into returned summaries. If + `None` then all dataset types will be included. dataset_type_factory : `Callable` - Method that returns `DatasetType` instance given its dataset type - ID. + Method that takes a table row and make `DatasetType` instance out + of it. Returns ------- @@ -282,8 +290,10 @@ def fetch_summaries( # information at once. 
coll_col = self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName) dataset_type_id_col = self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id") - columns = [coll_col, dataset_type_id_col] - fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType + columns = [coll_col, dataset_type_id_col] + list(self._dataset_type_table.columns) + fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType.join( + self._dataset_type_table + ) for dimension, table in self._tables.dimensions.items(): columns.append(table.columns[dimension.name].label(dimension.name)) fromClause = fromClause.join( @@ -297,8 +307,8 @@ def fetch_summaries( sql = sqlalchemy.sql.select(*columns).select_from(fromClause) sql = sql.where(coll_col.in_([coll.key for coll in non_chains])) - if dataset_type_ids is not None: - sql = sql.where(dataset_type_id_col.in_(dataset_type_ids)) + if dataset_type_names is not None: + sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names)) # Run the query and construct CollectionSummary objects from the result # rows. This will never include CHAINED collections or collections @@ -306,13 +316,16 @@ def fetch_summaries( summaries: dict[Any, CollectionSummary] = {} with self._db.query(sql) as sql_result: sql_rows = sql_result.mappings().fetchall() + dataset_type_ids: dict[int, DatasetType] = {} for row in sql_rows: # Collection key should never be None/NULL; it's what we join on. # Extract that and then turn it into a collection name. collectionKey = row[self._collectionKeyName] # dataset_type_id should also never be None/NULL; it's in the first # table we joined. - dataset_type = dataset_type_factory(row["dataset_type_id"]) + dataset_type_id = row["dataset_type_id"] + if (dataset_type := dataset_type_ids.get(dataset_type_id)) is None: + dataset_type_ids[dataset_type_id] = dataset_type = dataset_type_factory(row) # See if we have a summary already for this collection; if not, # make one. summary = summaries.get(collectionKey) diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 84a5a735d4..8bafb02274 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -487,7 +487,7 @@ def find(self, name: str) -> DatasetRecordStorage | None: raise NotImplementedError() @abstractmethod - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: """Ensure that this `Registry` can hold records for the given `DatasetType`, creating new tables as necessary. @@ -499,8 +499,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool Returns ------- - records : `DatasetRecordStorage` - The object representing the records for the given dataset type. inserted : `bool` `True` if the dataset type did not exist in the registry before. diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 5e03938b78..07eccf6ef8 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -697,8 +697,7 @@ def registerDatasetType(self, datasetType: DatasetType) -> bool: This method cannot be called within transactions, as it needs to be able to perform its own transaction to be concurrent. 
""" - _, inserted = self._managers.datasets.register(datasetType) - return inserted + return self._managers.datasets.register(datasetType) def removeDatasetType(self, name: str | tuple[str, ...]) -> None: """Remove the named `DatasetType` from the registry. From d174d63c4b867bd6bd5099378be3e4ea05d5070f Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Sun, 12 Nov 2023 19:52:00 -0800 Subject: [PATCH 08/11] Add special context manager to enable caching in registry. Adds special cache classes for collection and summary records and an additional structure that holds caches. New registry method is a context manager that enables caches temporarily for the duration of that context. --- python/lsst/daf/butler/_registry_shim.py | 6 + .../daf/butler/registry/_caching_context.py | 64 +++++++ .../registry/_collection_record_cache.py | 165 ++++++++++++++++++ .../registry/_collection_summary_cache.py | 86 +++++++++ python/lsst/daf/butler/registry/_registry.py | 6 + .../daf/butler/registry/collections/_base.py | 83 +++++---- .../butler/registry/collections/nameKey.py | 7 +- .../registry/collections/synthIntKey.py | 59 +------ .../datasets/byDimensions/_manager.py | 3 + .../datasets/byDimensions/summaries.py | 30 +++- .../registry/interfaces/_collections.py | 4 + .../butler/registry/interfaces/_datasets.py | 4 + python/lsst/daf/butler/registry/managers.py | 16 +- .../lsst/daf/butler/registry/sql_registry.py | 7 + 14 files changed, 434 insertions(+), 106 deletions(-) create mode 100644 python/lsst/daf/butler/registry/_caching_context.py create mode 100644 python/lsst/daf/butler/registry/_collection_record_cache.py create mode 100644 python/lsst/daf/butler/registry/_collection_summary_cache.py diff --git a/python/lsst/daf/butler/_registry_shim.py b/python/lsst/daf/butler/_registry_shim.py index 67f50a16e1..2cce959eba 100644 --- a/python/lsst/daf/butler/_registry_shim.py +++ b/python/lsst/daf/butler/_registry_shim.py @@ -102,6 +102,12 @@ def refresh(self) -> None: # Docstring inherited from a base class. self._registry.refresh() + @contextlib.contextmanager + def caching_context(self) -> Iterator[None]: + # Docstring inherited from a base class. + with self._registry.caching_context(): + yield + @contextlib.contextmanager def transaction(self, *, savepoint: bool = False) -> Iterator[None]: # Docstring inherited from a base class. diff --git a/python/lsst/daf/butler/registry/_caching_context.py b/python/lsst/daf/butler/registry/_caching_context.py new file mode 100644 index 0000000000..f217f21a98 --- /dev/null +++ b/python/lsst/daf/butler/registry/_caching_context.py @@ -0,0 +1,64 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ["CachingContext"] + +from ._collection_record_cache import CollectionRecordCache +from ._collection_summary_cache import CollectionSummaryCache + + +class CachingContext: + """Collection of caches for various types of records retrieved from + database. + + Notes + ----- + Caching is usually disabled for most of the record types, but it can be + explicitly and temporarily enabled in some context (e.g. quantum graph + building) using Registry method. This class is a collection of cache + instances which will be `None` when caching is disabled. Instance of this + class is passed to the relevant managers that can use it to query or + populate caches when caching is enabled. + """ + + collection_records: CollectionRecordCache | None = None + """Cache for collection records (`CollectionRecordCache`).""" + + collection_summaries: CollectionSummaryCache | None = None + """Cache for collection summary records (`CollectionSummaryCache`).""" + + def enable(self) -> None: + """Enable caches, initializes all caches.""" + self.collection_records = CollectionRecordCache() + self.collection_summaries = CollectionSummaryCache() + + def disable(self) -> None: + """Disable caches, sets all caches to `None`.""" + self.collection_records = None + self.collection_summaries = None diff --git a/python/lsst/daf/butler/registry/_collection_record_cache.py b/python/lsst/daf/butler/registry/_collection_record_cache.py new file mode 100644 index 0000000000..da00bb6a35 --- /dev/null +++ b/python/lsst/daf/butler/registry/_collection_record_cache.py @@ -0,0 +1,165 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +__all__ = ("CollectionRecordCache",) + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .interfaces import CollectionRecord + + +class CollectionRecordCache: + """Cache for collection records. + + Notes + ----- + This class stores collection records and can retrieve them using either + collection name or collection key. One complication is that key type can be + either collection name or a distinct integer value. To optimize storage + when the key is the same as collection name, this class only stores key to + record mapping when key is of a non-string type. + + In come contexts (e.g. ``resolve_wildcard``) a full list of collections is + needed. To signify that cache content can be used in such contexts, cache + defines special ``full`` flag that needs to be set by client. + """ + + def __init__(self) -> None: + self._by_name: dict[str, CollectionRecord] = {} + # This dict is only used for records whose key type is not str. + self._by_key: dict[Any, CollectionRecord] = {} + self._full = False + + @property + def full(self) -> bool: + """`True` if cache holds all known collection records (`bool`).""" + return self._full + + def add(self, record: CollectionRecord) -> None: + """Add one record to the cache. + + Parameters + ---------- + record : `CollectionRecord` + Collection record, replaces any existing record with the same name + or key. + """ + # In case we replace same record name with different key, find the + # existing record and drop its key. + if (old_record := self._by_name.get(record.name)) is not None: + self._by_key.pop(old_record.key) + if (old_record := self._by_key.get(record.key)) is not None: + self._by_name.pop(old_record.name) + self._by_name[record.name] = record + if not isinstance(record.key, str): + self._by_key[record.key] = record + + def set(self, records: Iterable[CollectionRecord], *, full: bool = False) -> None: + """Replace cache contents with the new set of records. + + Parameters + ---------- + records : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records. + full : `bool` + If `True` then ``records`` contain all known collection records. + """ + self.clear() + for record in records: + self._by_name[record.name] = record + if not isinstance(record.key, str): + self._by_key[record.key] = record + self._full = full + + def clear(self) -> None: + """Remove all records from the cache.""" + self._by_name = {} + self._by_key = {} + self._full = False + + def discard(self, record: CollectionRecord) -> None: + """Remove single record from the cache. + + Parameters + ---------- + record : `CollectionRecord` + Collection record to remove. + """ + self._by_name.pop(record.name, None) + if not isinstance(record.key, str): + self._by_key.pop(record.key, None) + + def get_by_name(self, name: str) -> CollectionRecord | None: + """Return collection record given its name. + + Parameters + ---------- + name : `str` + Collection name. + + Returns + ------- + record : `CollectionRecord` or `None` + Collection record, `None` is returned if the name is not in the + cache. + """ + return self._by_name.get(name) + + def get_by_key(self, key: Any) -> CollectionRecord | None: + """Return collection record given its key. + + Parameters + ---------- + key : `Any` + Collection key. + + Returns + ------- + record : `CollectionRecord` or `None` + Collection record, `None` is returned if the key is not in the + cache. 
+ """ + if isinstance(key, str): + return self._by_name.get(key) + return self._by_key.get(key) + + def records(self) -> Iterator[CollectionRecord]: + """Return iterator for the set of records in the cache, can only be + used if `full` is true. + + Raises + ------ + RuntimeError + Raised if ``self.full`` is `False`. + """ + if not self._full: + raise RuntimeError("cannot call records() if cache is not full") + return iter(self._by_name.values()) diff --git a/python/lsst/daf/butler/registry/_collection_summary_cache.py b/python/lsst/daf/butler/registry/_collection_summary_cache.py new file mode 100644 index 0000000000..ed5b2f2fa2 --- /dev/null +++ b/python/lsst/daf/butler/registry/_collection_summary_cache.py @@ -0,0 +1,86 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("CollectionSummaryCache",) + +from collections.abc import Iterable, Mapping +from typing import Any + +from ._collection_summary import CollectionSummary + + +class CollectionSummaryCache: + """Cache for collection summaries. + + Notes + ----- + This class stores `CollectionSummary` records indexed by collection keys. + For cache to be usable the records that are given to `update` method have + to include all dataset types, i.e. the query that produces records should + not be constrained by dataset type. + """ + + def __init__(self) -> None: + self._cache: dict[Any, CollectionSummary] = {} + + def update(self, summaries: Mapping[Any, CollectionSummary]) -> None: + """Add records to the cache. + + Parameters + ---------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Summary records indexed by collection key, records must include all + dataset types. + """ + self._cache.update(summaries) + + def find_summaries(self, keys: Iterable[Any]) -> tuple[dict[Any, CollectionSummary], set[Any]]: + """Return summary records given a set of keys. + + Parameters + ---------- + keys : `~collections.abc.Iterable` [`Any`] + Sequence of collection keys. + + Returns + ------- + summaries : `dict` [`Any`, `CollectionSummary`] + Dictionary of summaries indexed by collection keys, includes + records found in the cache. + missing_keys : `set` [`Any`] + Collection keys that are not present in the cache. 
+ """ + found = {} + not_found = set() + for key in keys: + if (summary := self._cache.get(key)) is not None: + found[key] = summary + else: + not_found.add(key) + return found, not_found diff --git a/python/lsst/daf/butler/registry/_registry.py b/python/lsst/daf/butler/registry/_registry.py index 2f0cb3231d..444a67eb54 100644 --- a/python/lsst/daf/butler/registry/_registry.py +++ b/python/lsst/daf/butler/registry/_registry.py @@ -118,6 +118,12 @@ def refresh(self) -> None: """ raise NotImplementedError() + @contextlib.contextmanager + @abstractmethod + def caching_context(self) -> Iterator[None]: + """Context manager that enables caching.""" + raise NotImplementedError() + @contextlib.contextmanager @abstractmethod def transaction(self, *, savepoint: bool = False) -> Iterator[None]: diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 1c3587d1ab..f39ee607e9 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -45,6 +45,7 @@ from ..wildcards import CollectionWildcard if TYPE_CHECKING: + from .._caching_context import CachingContext from ..interfaces import Database, DimensionRecordStorageManager @@ -190,28 +191,30 @@ def __init__( collectionIdName: str, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ): super().__init__(registry_schema_version=registry_schema_version) self._db = db self._tables = tables self._collectionIdName = collectionIdName - self._records: dict[K, CollectionRecord[K]] = {} # indexed by record ID self._dimensions = dimensions - self._full_fetch = False # True if cache contains everything. + self._caching_context = caching_context def refresh(self) -> None: # Docstring inherited from CollectionManager. - # We just reset the cache here but do not retrieve any records. - self._full_fetch = False - self._setRecordCache([]) + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.clear() - def _fetch_all(self) -> None: + def _fetch_all(self) -> list[CollectionRecord[K]]: """Retrieve all records into cache if not done so yet.""" - if not self._full_fetch: - records = self._fetch_by_key(None) - self._setRecordCache(records) - self._full_fetch = True + if self._caching_context.collection_records is not None: + if self._caching_context.collection_records.full: + return list(self._caching_context.collection_records.records()) + records = self._fetch_by_key(None) + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.set(records, full=True) + return records def register( self, name: str, type: CollectionType, doc: str | None = None @@ -278,12 +281,18 @@ def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]: """Return multiple records given their names.""" names = list(names) # To protect against potential races in cache updates. 
- records = {} - for name in names: - records[name] = self._get_cached_name(name) - fetch_names = [name for name, record in records.items() if record is None] - for record in self._fetch_by_name(fetch_names): - records[record.name] = record + records: dict[str, CollectionRecord | None] = {} + if self._caching_context.collection_records is not None: + for name in names: + records[name] = self._caching_context.collection_records.get_by_name(name) + fetch_names = [name for name, record in records.items() if record is None] + else: + fetch_names = list(names) + records = {name: None for name in fetch_names} + if fetch_names: + for record in self._fetch_by_name(fetch_names): + records[record.name] = record + self._addCachedRecord(record) missing_names = [name for name, record in records.items() if record is None] if len(missing_names) == 1: raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.") @@ -293,10 +302,14 @@ def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]: def __getitem__(self, key: Any) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. - if (record := self._records.get(key)) is not None: - return record + if self._caching_context.collection_records is not None: + if (record := self._caching_context.collection_records.get_by_key(key)) is not None: + return record if records := self._fetch_by_key([key]): - return records[0] + record = records[0] + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.add(record) + return record else: raise MissingCollectionError(f"Collection with key '{key}' not found.") @@ -330,12 +343,8 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect result: list[CollectionRecord[K]] = [] - # If we have wildcard or ellipsis we need to read everything in memory. - if wildcard.patterns: - self._fetch_all() - if wildcard.patterns is ...: - for record in self._records.values(): + for record in self._fetch_all(): result.extend(resolve_nested(record, done)) del resolve_nested return result @@ -343,7 +352,7 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect for record in self._find_many(wildcard.strings): result.extend(resolve_nested(record, done)) if wildcard.patterns: - for record in self._records.values(): + for record in self._fetch_all(): if any(p.fullmatch(record.name) for p in wildcard.patterns): result.extend(resolve_nested(record, done)) del resolve_nested @@ -363,36 +372,26 @@ def setDocumentation(self, key: K, doc: str | None) -> None: # Docstring inherited from CollectionManager. self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc}) - def _setRecordCache(self, records: Iterable[CollectionRecord[K]]) -> None: - """Set internal record cache to contain given records, - old cached records will be removed. 
- """ - self._records = {} - for record in records: - self._records[record.key] = record - def _addCachedRecord(self, record: CollectionRecord[K]) -> None: """Add single record to cache.""" - self._records[record.key] = record + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.add(record) def _removeCachedRecord(self, record: CollectionRecord[K]) -> None: """Remove single record from cache.""" - del self._records[record.key] + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.discard(record) def _getByName(self, name: str) -> CollectionRecord[K] | None: """Find collection record given collection name.""" - if (record := self._get_cached_name(name)) is not None: - return record + if self._caching_context.collection_records is not None: + if (record := self._caching_context.collection_records.get_by_name(name)) is not None: + return record records = self._fetch_by_name([name]) for record in records: self._addCachedRecord(record) return records[0] if records else None - @abstractmethod - def _get_cached_name(self, name: str) -> CollectionRecord[K] | None: - """Find cached collection record given its name.""" - raise NotImplementedError() - @abstractmethod def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]: """Fetch collection record from database given its name.""" diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index d8c50cce2f..9fc2e32271 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -45,6 +45,7 @@ ) if TYPE_CHECKING: + from .._caching_context import CachingContext from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext @@ -86,6 +87,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> NameKeyCollectionManager: # Docstring inherited from CollectionManager. @@ -94,6 +96,7 @@ def initialize( tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore collectionIdName="name", dimensions=dimensions, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -165,10 +168,6 @@ def getParentChains(self, key: str) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names - def _get_cached_name(self, name: str) -> CollectionRecord[str] | None: - # Docstring inherited from base class. - return self._records.get(name) - def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: # Docstring inherited from base class. 
return self._fetch_by_key(names) diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index d2edcaae88..3b7dbb1de4 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -46,6 +46,7 @@ ) if TYPE_CHECKING: + from .._caching_context import CachingContext from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext @@ -77,40 +78,8 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> class SynthIntKeyCollectionManager(DefaultCollectionManager[int]): """A `CollectionManager` implementation that uses synthetic primary key (auto-incremented integer) for collections table. - - Most of the logic, including caching policy, is implemented in the base - class, this class only adds customizations specific to this particular - table schema. - - Parameters - ---------- - db : `Database` - Interface to the underlying database engine and namespace. - tables : `NamedTuple` - Named tuple of SQLAlchemy table objects. - collectionIdName : `str` - Name of the column in collections table that identifies it (PK). - dimensions : `DimensionRecordStorageManager` - Manager object for the dimensions in this `Registry`. """ - def __init__( - self, - db: Database, - tables: CollectionTablesTuple, - collectionIdName: str, - dimensions: DimensionRecordStorageManager, - registry_schema_version: VersionTuple | None = None, - ): - super().__init__( - db=db, - tables=tables, - collectionIdName=collectionIdName, - dimensions=dimensions, - registry_schema_version=registry_schema_version, - ) - self._nameCache: dict[str, CollectionRecord] = {} # indexed by collection name - @classmethod def initialize( cls, @@ -118,6 +87,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> SynthIntKeyCollectionManager: # Docstring inherited from CollectionManager. @@ -126,6 +96,7 @@ def initialize( tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore collectionIdName="collection_id", dimensions=dimensions, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -199,30 +170,6 @@ def getParentChains(self, key: int) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names - def _setRecordCache(self, records: Iterable[CollectionRecord[int]]) -> None: - """Set internal record cache to contain given records, - old cached records will be removed. - """ - self._records = {} - self._nameCache = {} - for record in records: - self._records[record.key] = record - self._nameCache[record.name] = record - - def _addCachedRecord(self, record: CollectionRecord[int]) -> None: - """Add single record to cache.""" - self._records[record.key] = record - self._nameCache[record.name] = record - - def _removeCachedRecord(self, record: CollectionRecord[int]) -> None: - """Remove single record from cache.""" - del self._records[record.key] - del self._nameCache[record.name] - - def _get_cached_name(self, name: str) -> CollectionRecord[int] | None: - # Docstring inherited from base class. - return self._nameCache.get(name) - def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: # Docstring inherited from base class. 
return self._fetch("name", names) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index 88e20c67f2..1cd5eeac63 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -32,6 +32,7 @@ ) if TYPE_CHECKING: + from ..._caching_context import CachingContext from ...interfaces import ( CollectionManager, CollectionRecord, @@ -147,6 +148,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> DatasetRecordStorageManager: # Docstring inherited from DatasetRecordStorageManager. @@ -160,6 +162,7 @@ def initialize( collections=collections, dimensions=dimensions, dataset_type_table=static.dataset_type, + caching_context=caching_context, ) return cls( db=db, diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py index 41687cb9c2..d051b4b38d 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py @@ -39,6 +39,7 @@ from ...._dataset_type import DatasetType from ...._named import NamedKeyDict, NamedKeyMapping from ....dimensions import GovernorDimension, addDimensionForeignKey +from ..._caching_context import CachingContext from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType from ...interfaces import ( @@ -135,6 +136,8 @@ class CollectionSummaryManager: Struct containing the tables that hold collection summaries. dataset_type_table : `sqlalchemy.schema.Table` Table containing dataset type definitions. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. """ def __init__( @@ -145,6 +148,7 @@ def __init__( dimensions: DimensionRecordStorageManager, tables: CollectionSummaryTables[sqlalchemy.schema.Table], dataset_type_table: sqlalchemy.schema.Table, + caching_context: CachingContext, ): self._db = db self._collections = collections @@ -152,6 +156,7 @@ def __init__( self._dimensions = dimensions self._tables = tables self._dataset_type_table = dataset_type_table + self._caching_context = caching_context @classmethod def initialize( @@ -162,6 +167,7 @@ def initialize( collections: CollectionManager, dimensions: DimensionRecordStorageManager, dataset_type_table: sqlalchemy.schema.Table, + caching_context: CachingContext, ) -> CollectionSummaryManager: """Create all summary tables (or check that they have been created), returning an object to manage them. @@ -179,6 +185,8 @@ def initialize( Manager object for the dimensions in this `Registry`. dataset_type_table : `sqlalchemy.schema.Table` Table containing dataset type definitions. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. Returns ------- @@ -201,6 +209,7 @@ def initialize( dimensions=dimensions, tables=tables, dataset_type_table=dataset_type_table, + caching_context=caching_context, ) def update( @@ -271,6 +280,17 @@ def fetch_summaries( will also contain all nested non-chained collections of the chained collections. """ + summaries: dict[Any, CollectionSummary] = {} + # Check what we have in cache first. 
+ if self._caching_context.collection_summaries is not None: + summaries, missing_keys = self._caching_context.collection_summaries.find_summaries( + [record.key for record in collections] + ) + if not missing_keys: + return summaries + else: + collections = [record for record in collections if record.key in missing_keys] + # Need to expand all chained collections first. non_chains: list[CollectionRecord] = [] chains: dict[CollectionRecord, list[CollectionRecord]] = {} @@ -307,13 +327,14 @@ def fetch_summaries( sql = sqlalchemy.sql.select(*columns).select_from(fromClause) sql = sql.where(coll_col.in_([coll.key for coll in non_chains])) - if dataset_type_names is not None: - sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names)) + # For caching we need to fetch complete summaries. + if self._caching_context.collection_summaries is None: + if dataset_type_names is not None: + sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names)) # Run the query and construct CollectionSummary objects from the result # rows. This will never include CHAINED collections or collections # with no datasets. - summaries: dict[Any, CollectionSummary] = {} with self._db.query(sql) as sql_result: sql_rows = sql_result.mappings().fetchall() dataset_type_ids: dict[int, DatasetType] = {} @@ -350,4 +371,7 @@ def fetch_summaries( for chain, children in chains.items(): summaries[chain.key] = CollectionSummary.union(*(summaries[child.key] for child in children)) + if self._caching_context.collection_summaries is not None: + self._caching_context.collection_summaries.update(summaries) + return summaries diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index c07b894adc..837ede94e0 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -45,6 +45,7 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from .._caching_context import CachingContext from ._database import Database, StaticTablesContext from ._dimensions import DimensionRecordStorageManager @@ -214,6 +215,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> CollectionManager: """Construct an instance of the manager. @@ -228,6 +230,8 @@ def initialize( implemented with this manager. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. registry_schema_version : `VersionTuple` or `None` Schema version of this extension as defined in registry. 
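The summary lookup added to CollectionSummaryManager.fetch_summaries above follows a cache-aside pattern: summaries already present in the per-context cache are returned as-is, only the missing collection keys are queried, and the cache is then backfilled with complete summaries (never constrained by dataset type, so they stay valid for any later lookup). The sketch below illustrates that flow outside the manager code; it is not part of the patch, and ``query_summaries_from_db`` is a hypothetical stand-in for the SQL query that returns `CollectionSummary` objects keyed by collection key.

    from collections.abc import Callable, Iterable
    from typing import Any

    from lsst.daf.butler.registry._collection_summary_cache import CollectionSummaryCache


    def get_summaries(
        keys: Iterable[Any],
        cache: CollectionSummaryCache | None,
        query_summaries_from_db: Callable[[set[Any]], dict[Any, Any]],
    ) -> dict[Any, Any]:
        # Consult the cache first; with caching disabled everything is missing.
        if cache is not None:
            found, missing = cache.find_summaries(keys)
            if not missing:
                return found
        else:
            found, missing = {}, set(keys)
        # Query the database only for collections that were not cached, then
        # backfill the cache so the next lookup can skip the query entirely.
        fetched = query_summaries_from_db(missing)
        if cache is not None:
            cache.update(fetched)
        found.update(fetched)
        return found

Keeping the cached entries unconstrained by dataset type is what makes a single enabled caching context sufficient to answer repeated summary queries for arbitrary dataset-type subsets without touching the database again.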
diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 8bafb02274..ac459f3c8b 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -45,6 +45,7 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from .._caching_context import CachingContext from .._collection_summary import CollectionSummary from ..queries import SqlQueryContext from ._collections import CollectionManager, CollectionRecord, RunRecord @@ -329,6 +330,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> DatasetRecordStorageManager: """Construct an instance of the manager. @@ -344,6 +346,8 @@ def initialize( Manager object for the collections in this `Registry`. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. registry_schema_version : `VersionTuple` or `None` Schema version of this extension as defined in registry. diff --git a/python/lsst/daf/butler/registry/managers.py b/python/lsst/daf/butler/registry/managers.py index 1d80fcde51..48702c2fc9 100644 --- a/python/lsst/daf/butler/registry/managers.py +++ b/python/lsst/daf/butler/registry/managers.py @@ -45,6 +45,7 @@ from .._column_type_info import ColumnTypeInfo from .._config import Config from ..dimensions import DimensionConfig, DimensionUniverse +from ._caching_context import CachingContext from ._config import RegistryConfig from .interfaces import ( ButlerAttributeManager, @@ -353,6 +354,11 @@ class RegistryManagerInstances( and registry instances, including the dimension universe. """ + caching_context: CachingContext + """Object containing caches for for various information generated by + managers. + """ + @classmethod def initialize( cls, @@ -361,6 +367,7 @@ def initialize( *, types: RegistryManagerTypes, universe: DimensionUniverse, + caching_context: CachingContext | None = None, ) -> RegistryManagerInstances: """Construct manager instances from their types and an existing database connection. @@ -383,6 +390,8 @@ def initialize( instances : `RegistryManagerInstances` Struct containing manager instances. """ + if caching_context is None: + caching_context = CachingContext() dummy_table = ddl.TableSpec(fields=()) kwargs: dict[str, Any] = {} schema_versions = types.schema_versions @@ -396,6 +405,7 @@ def initialize( database, context, dimensions=kwargs["dimensions"], + caching_context=caching_context, registry_schema_version=schema_versions.get("collections"), ) datasets = types.datasets.initialize( @@ -404,6 +414,7 @@ def initialize( collections=kwargs["collections"], dimensions=kwargs["dimensions"], registry_schema_version=schema_versions.get("datasets"), + caching_context=caching_context, ) kwargs["datasets"] = datasets kwargs["opaque"] = types.opaque.initialize( @@ -440,6 +451,7 @@ def initialize( run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False), ingest_date_dtype=datasets.ingest_date_dtype(), ) + kwargs["caching_context"] = caching_context return cls(**kwargs) def as_dict(self) -> Mapping[str, VersionedExtension]: @@ -453,7 +465,9 @@ def as_dict(self) -> Mapping[str, VersionedExtension]: manager instance. Only existing managers are returned. 
""" instances = { - f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != "column_types" + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name not in ("column_types", "caching_context") } return {key: value for key, value in instances.items() if value is not None} diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 07eccf6ef8..63c22a97d7 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -324,6 +324,13 @@ def refresh(self) -> None: with self._db.transaction(): self._managers.refresh() + @contextlib.contextmanager + def caching_context(self) -> Iterator[None]: + """Context manager that enables caching.""" + self._managers.caching_context.enable() + yield + self._managers.caching_context.disable() + @contextlib.contextmanager def transaction(self, *, savepoint: bool = False) -> Iterator[None]: """Return a context manager that represents a transaction.""" From 254b87a1efda452146df346ff01c789e0b015b8c Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 13 Nov 2023 14:22:01 -0800 Subject: [PATCH 09/11] Re-enable dataset type caching using new class DatasetTypeCache. Unlike collection caches, dataset type cache is always on, this helps to reduce number of queries in `pipetask run` without the need to explicitly enable caching in multiple places. --- .../daf/butler/registry/_caching_context.py | 15 ++ .../butler/registry/_dataset_type_cache.py | 162 ++++++++++++++++++ .../datasets/byDimensions/_manager.py | 66 ++++++- 3 files changed, 235 insertions(+), 8 deletions(-) create mode 100644 python/lsst/daf/butler/registry/_dataset_type_cache.py diff --git a/python/lsst/daf/butler/registry/_caching_context.py b/python/lsst/daf/butler/registry/_caching_context.py index f217f21a98..2674a54f69 100644 --- a/python/lsst/daf/butler/registry/_caching_context.py +++ b/python/lsst/daf/butler/registry/_caching_context.py @@ -29,8 +29,14 @@ __all__ = ["CachingContext"] +from typing import TYPE_CHECKING + from ._collection_record_cache import CollectionRecordCache from ._collection_summary_cache import CollectionSummaryCache +from ._dataset_type_cache import DatasetTypeCache + +if TYPE_CHECKING: + from .interfaces import DatasetRecordStorage class CachingContext: @@ -45,6 +51,9 @@ class CachingContext: instances which will be `None` when caching is disabled. Instance of this class is passed to the relevant managers that can use it to query or populate caches when caching is enabled. + + Dataset type cache is always enabled for now, this avoids the need for + explicitly enabling caching in pipetask executors. 
""" collection_records: CollectionRecordCache | None = None @@ -53,6 +62,12 @@ class is passed to the relevant managers that can use it to query or collection_summaries: CollectionSummaryCache | None = None """Cache for collection summary records (`CollectionSummaryCache`).""" + dataset_types: DatasetTypeCache[DatasetRecordStorage] + """Cache for dataset types, never disabled (`DatasetTypeCache`).""" + + def __init__(self) -> None: + self.dataset_types = DatasetTypeCache() + def enable(self) -> None: """Enable caches, initializes all caches.""" self.collection_records = CollectionRecordCache() diff --git a/python/lsst/daf/butler/registry/_dataset_type_cache.py b/python/lsst/daf/butler/registry/_dataset_type_cache.py new file mode 100644 index 0000000000..3f1665dfa3 --- /dev/null +++ b/python/lsst/daf/butler/registry/_dataset_type_cache.py @@ -0,0 +1,162 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DatasetTypeCache",) + +from collections.abc import Iterable, Iterator +from typing import Generic, TypeVar + +from .._dataset_type import DatasetType + +_T = TypeVar("_T") + + +class DatasetTypeCache(Generic[_T]): + """Cache for dataset types. + + Notes + ----- + This class caches mapping of dataset type name to a corresponding + `DatasetType` instance. Registry manager also needs to cache corresponding + "storage" instance, so this class allows storing additional opaque object + along with the dataset type. + + In come contexts (e.g. ``resolve_wildcard``) a full list of dataset types + is needed. To signify that cache content can be used in such contexts, + cache defines special ``full`` flag that needs to be set by client. + """ + + def __init__(self) -> None: + self._cache: dict[str, tuple[DatasetType, _T | None]] = {} + self._full = False + + @property + def full(self) -> bool: + """`True` if cache holds all known dataset types (`bool`).""" + return self._full + + def add(self, dataset_type: DatasetType, extra: _T | None = None) -> None: + """Add one record to the cache. + + Parameters + ---------- + dataset_type : `DatasetType` + Dataset type, replaces any existing dataset type with the same + name. + extra : `Any`, optional + Additional opaque object stored with this dataset type. 
+ """ + self._cache[dataset_type.name] = (dataset_type, extra) + + def set(self, data: Iterable[DatasetType | tuple[DatasetType, _T | None]], *, full: bool = False) -> None: + """Replace cache contents with the new set of dataset types. + + Parameters + ---------- + data : `~collections.abc.Iterable` + Sequence of `DatasetType` instances or tuples of `DatasetType` and + an extra opaque object. + full : `bool` + If `True` then ``data`` contains all known dataset types. + """ + self.clear() + for item in data: + if isinstance(item, DatasetType): + item = (item, None) + self._cache[item[0].name] = item + self._full = full + + def clear(self) -> None: + """Remove everything from the cache.""" + self._cache = {} + self._full = False + + def discard(self, name: str) -> None: + """Remove named dataset type from the cache. + + Parameters + ---------- + name : `str` + Name of the dataset type to remove. + """ + self._cache.pop(name, None) + + def get(self, name: str) -> tuple[DatasetType | None, _T | None]: + """Return cached info given dataset type name. + + Parameters + ---------- + name : `str` + Dataset type name. + + Returns + ------- + dataset_type : `DatasetType` or `None` + Cached dataset type, `None` is returned if the name is not in the + cache. + extra : `Any` or `None` + Cached opaque data, `None` is returned if the name is not in the + cache or no extra info was stored for this dataset type. + """ + item = self._cache.get(name) + if item is None: + return (None, None) + return item + + def get_dataset_type(self, name: str) -> DatasetType | None: + """Return dataset type given its name. + + Parameters + ---------- + name : `str` + Dataset type name. + + Returns + ------- + dataset_type : `DatasetType` or `None` + Cached dataset type, `None` is returned if the name is not in the + cache. + """ + item = self._cache.get(name) + if item is None: + return None + return item[0] + + def items(self) -> Iterator[tuple[DatasetType, _T | None]]: + """Return iterator for the set of items in the cache, can only be + used if `full` is true. + + Raises + ------ + RuntimeError + Raised if ``self.full`` is `False`. + """ + if not self._full: + raise RuntimeError("cannot call items() if cache is not full") + return iter(self._cache.values()) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index 1cd5eeac63..692d8585a3 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -121,6 +121,8 @@ class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager): tables used by this class. summaries : `CollectionSummaryManager` Structure containing tables that summarize the contents of collections. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. 
""" def __init__( @@ -131,6 +133,7 @@ def __init__( dimensions: DimensionRecordStorageManager, static: StaticDatasetTablesTuple, summaries: CollectionSummaryManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ): super().__init__(registry_schema_version=registry_schema_version) @@ -139,6 +142,7 @@ def __init__( self._dimensions = dimensions self._static = static self._summaries = summaries + self._caching_context = caching_context @classmethod def initialize( @@ -170,6 +174,7 @@ def initialize( dimensions=dimensions, static=static, summaries=summaries, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -237,7 +242,8 @@ def addDatasetForeignKey( def refresh(self) -> None: # Docstring inherited from DatasetRecordStorageManager. - pass + if self._caching_context.dataset_types is not None: + self._caching_context.dataset_types.clear() def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage: """Create storage instance for a dataset type record.""" @@ -286,8 +292,28 @@ def remove(self, name: str) -> None: def find(self, name: str) -> DatasetRecordStorage | None: # Docstring inherited from DatasetRecordStorageManager. + if self._caching_context.dataset_types is not None: + _, storage = self._caching_context.dataset_types.get(name) + if storage is not None: + return storage + else: + # On the first cache miss populate the cache with complete list + # of dataset types (if it was not done yet). + if not self._caching_context.dataset_types.full: + self._fetch_dataset_types() + # Try again + _, storage = self._caching_context.dataset_types.get(name) + if self._caching_context.dataset_types.full: + # If not in cache then dataset type is not defined. + return storage record = self._fetch_dataset_type_record(name) - return self._make_storage(record) if record is not None else None + if record is not None: + storage = self._make_storage(record) + if self._caching_context.dataset_types is not None: + self._caching_context.dataset_types.add(storage.datasetType, storage) + return storage + else: + return None def register(self, datasetType: DatasetType) -> bool: # Docstring inherited from DatasetRecordStorageManager. @@ -316,7 +342,7 @@ def register(self, datasetType: DatasetType) -> bool: self.getIdColumnType(), ), ) - _, inserted = self._db.sync( + row, inserted = self._db.sync( self._static.dataset_type, keys={"name": datasetType.name}, compared={ @@ -331,6 +357,16 @@ def register(self, datasetType: DatasetType) -> bool: }, returning=["id", "tag_association_table"], ) + # Make sure that cache is updated + if self._caching_context.dataset_types is not None and row is not None: + record = _DatasetTypeRecord( + dataset_type=datasetType, + dataset_type_id=row["id"], + tag_table_name=tagTableName, + calib_table_name=calibTableName, + ) + storage = self._make_storage(record) + self._caching_context.dataset_types.add(datasetType, storage) else: if datasetType != record.dataset_type: raise ConflictingDefinitionError( @@ -338,9 +374,7 @@ def register(self, datasetType: DatasetType) -> bool: f"with database definition {record.dataset_type}." ) inserted = False - # TODO: We return storage instance from this method, but the only - # client that uses this method ignores it. Maybe we should drop it - # and avoid making storage instance above. 
+ return bool(inserted) def resolve_wildcard( @@ -472,7 +506,15 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: row = sql_result.mappings().fetchone() if row is None: return None - storage = self._make_storage(self._record_from_row(row)) + record = self._record_from_row(row) + storage: DatasetRecordStorage | None = None + if self._caching_context.dataset_types is not None: + _, storage = self._caching_context.dataset_types.get(record.dataset_type.name) + if storage is None: + storage = self._make_storage(record) + if self._caching_context.dataset_types is not None: + self._caching_context.dataset_types.add(storage.datasetType, storage) + assert isinstance(storage, ByDimensionsDatasetRecordStorage), "Not expected storage class" return DatasetRef( storage.datasetType, dataId=storage.getDataId(id=id), @@ -516,9 +558,17 @@ def _dataset_type_from_row(self, row: Mapping) -> DatasetType: def _fetch_dataset_types(self) -> list[DatasetType]: """Fetch list of all defined dataset types.""" + if self._caching_context.dataset_types is not None: + if self._caching_context.dataset_types.full: + return [dataset_type for dataset_type, _ in self._caching_context.dataset_types.items()] with self._db.query(self._static.dataset_type.select()) as sql_result: sql_rows = sql_result.mappings().fetchall() - return [self._record_from_row(row).dataset_type for row in sql_rows] + records = [self._record_from_row(row) for row in sql_rows] + # Cache everything and specify that cache is complete. + if self._caching_context.dataset_types is not None: + cache_data = [(record.dataset_type, self._make_storage(record)) for record in records] + self._caching_context.dataset_types.set(cache_data, full=True) + return [record.dataset_type for record in records] def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary: # Docstring inherited from DatasetRecordStorageManager. From eaa968d4d697e109dbd6a4b4365062f48ad3b31e Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 13 Nov 2023 15:21:17 -0800 Subject: [PATCH 10/11] Typing fixes for mypy 1.7 --- python/lsst/daf/butler/_named.py | 4 ++-- python/lsst/daf/butler/persistence_context.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/lsst/daf/butler/_named.py b/python/lsst/daf/butler/_named.py index 55850645bd..e3ff851b72 100644 --- a/python/lsst/daf/butler/_named.py +++ b/python/lsst/daf/butler/_named.py @@ -266,7 +266,7 @@ def freeze(self) -> NamedKeyMapping[K, V]: to a new variable (and considering any previous references invalidated) should allow for more accurate static type checking. """ - if not isinstance(self._dict, MappingProxyType): + if not isinstance(self._dict, MappingProxyType): # type: ignore[unreachable] self._dict = MappingProxyType(self._dict) # type: ignore return self @@ -578,7 +578,7 @@ def freeze(self) -> NamedValueAbstractSet[K]: to a new variable (and considering any previous references invalidated) should allow for more accurate static type checking. 
""" - if not isinstance(self._mapping, MappingProxyType): + if not isinstance(self._mapping, MappingProxyType): # type: ignore[unreachable] self._mapping = MappingProxyType(self._mapping) # type: ignore return self diff --git a/python/lsst/daf/butler/persistence_context.py b/python/lsst/daf/butler/persistence_context.py index b366564d45..8830fd9550 100644 --- a/python/lsst/daf/butler/persistence_context.py +++ b/python/lsst/daf/butler/persistence_context.py @@ -33,7 +33,7 @@ import uuid from collections.abc import Callable, Hashable from contextvars import Context, ContextVar, Token, copy_context -from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast +from typing import TYPE_CHECKING, ParamSpec, TypeVar if TYPE_CHECKING: from ._dataset_ref import DatasetRef @@ -198,4 +198,4 @@ def run(self, function: Callable[_Q, _T], *args: _Q.args, **kwargs: _Q.kwargs) - # cast the result as we know this is exactly what the return type will # be. result = self._ctx.run(self._functionRunner, function, *args, **kwargs) # type: ignore - return cast(_T, result) + return result From 922359ce8e266e057c40a3cd47089f66ac0f6b74 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 14 Nov 2023 10:19:16 -0800 Subject: [PATCH 11/11] Add `Butler._caching_context` method. Registry shiim now redirects to this method instead of Registry method. RemoteButler raises NotImplementedError but may do something non-trivial later when we know how caching is going to work with client/server. --- python/lsst/daf/butler/_butler.py | 5 +++++ python/lsst/daf/butler/_registry_shim.py | 6 ++---- python/lsst/daf/butler/direct_butler.py | 4 ++++ python/lsst/daf/butler/registry/_registry.py | 3 +-- python/lsst/daf/butler/remote_butler/_remote_butler.py | 6 ++++++ 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index f83a4ad347..937ff8e6c9 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -482,6 +482,11 @@ def get_known_repos(cls) -> set[str]: """ return ButlerRepoIndex.get_known_repos() + @abstractmethod + def _caching_context(self) -> AbstractContextManager[None]: + """Context manager that enables caching.""" + raise NotImplementedError() + @abstractmethod def transaction(self) -> AbstractContextManager[None]: """Context manager supporting `Butler` transactions. diff --git a/python/lsst/daf/butler/_registry_shim.py b/python/lsst/daf/butler/_registry_shim.py index 2cce959eba..4d2653abe0 100644 --- a/python/lsst/daf/butler/_registry_shim.py +++ b/python/lsst/daf/butler/_registry_shim.py @@ -102,11 +102,9 @@ def refresh(self) -> None: # Docstring inherited from a base class. self._registry.refresh() - @contextlib.contextmanager - def caching_context(self) -> Iterator[None]: + def caching_context(self) -> contextlib.AbstractContextManager[None]: # Docstring inherited from a base class. - with self._registry.caching_context(): - yield + return self._butler._caching_context() @contextlib.contextmanager def transaction(self, *, savepoint: bool = False) -> Iterator[None]: diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py index 6b70ecb1e1..80bcf5d1ae 100644 --- a/python/lsst/daf/butler/direct_butler.py +++ b/python/lsst/daf/butler/direct_butler.py @@ -299,6 +299,10 @@ def isWriteable(self) -> bool: # Docstring inherited. 
return self._registry.isWriteable() + def _caching_context(self) -> contextlib.AbstractContextManager[None]: + """Context manager that enables caching.""" + return self._registry.caching_context() + @contextlib.contextmanager def transaction(self) -> Iterator[None]: """Context manager supporting `Butler` transactions. diff --git a/python/lsst/daf/butler/registry/_registry.py b/python/lsst/daf/butler/registry/_registry.py index 444a67eb54..21e651314c 100644 --- a/python/lsst/daf/butler/registry/_registry.py +++ b/python/lsst/daf/butler/registry/_registry.py @@ -118,9 +118,8 @@ def refresh(self) -> None: """ raise NotImplementedError() - @contextlib.contextmanager @abstractmethod - def caching_context(self) -> Iterator[None]: + def caching_context(self) -> contextlib.AbstractContextManager[None]: """Context manager that enables caching.""" raise NotImplementedError() diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index 841735a28b..1930461ec0 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -158,6 +158,12 @@ def _simplify_dataId( # Assume we can treat it as a dict. return SerializedDataCoordinate(dataId=data_id) + def _caching_context(self) -> AbstractContextManager[None]: + # Docstring inherited. + # Not implemented for now, will have to think whether this needs to + # do something on client side and/or remote side. + raise NotImplementedError() + def transaction(self) -> AbstractContextManager[None]: """Will always raise NotImplementedError. Transactions are not supported by RemoteButler.