diff --git a/python/lsst/daf/butler/direct_query_driver/__init__.py b/python/lsst/daf/butler/direct_query_driver/__init__.py index f717d4d3ad..ab7c7faf7d 100644 --- a/python/lsst/daf/butler/direct_query_driver/__init__.py +++ b/python/lsst/daf/butler/direct_query_driver/__init__.py @@ -27,4 +27,4 @@ from ._driver import DirectQueryDriver from ._postprocessing import Postprocessing -from ._query_builder import QueryBuilder, QueryJoiner +from ._sql_builders import SqlJoinsBuilder, SqlSelectBuilder diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index 56835849da..acf27a2ae6 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -27,17 +27,18 @@ from __future__ import annotations -import uuid - __all__ = ("DirectQueryDriver",) import dataclasses import itertools import logging import sys +import uuid +from collections import defaultdict from collections.abc import Iterable, Iterator, Mapping, Set from contextlib import ExitStack -from typing import TYPE_CHECKING, Any, cast, overload +from types import EllipsisType +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload import sqlalchemy @@ -69,14 +70,14 @@ from ..registry.wildcards import CollectionWildcard from ._postprocessing import Postprocessing from ._predicate_constraints_summary import PredicateConstraintsSummary -from ._query_builder import QueryBuilder, QueryJoiner -from ._query_plan import ( - QueryFindFirstPlan, - QueryJoinsPlan, - QueryPlan, - QueryProjectionPlan, +from ._query_analysis import ( + QueryCollectionAnalysis, + QueryFindFirstAnalysis, + QueryJoinsAnalysis, + QueryTreeAnalysis, ResolvedDatasetSearch, ) +from ._query_builder import QueryBuilder, SingleSelectQueryBuilder, UnionQueryBuilder from ._result_page_converter import ( DataCoordinateResultPageConverter, DatasetRefResultPageConverter, @@ -85,6 +86,7 @@ ResultPageConverter, ResultPageConverterContext, ) +from ._sql_builders import SqlJoinsBuilder, SqlSelectBuilder, make_table_spec from ._sql_column_visitor import SqlColumnVisitor if TYPE_CHECKING: @@ -93,6 +95,8 @@ _LOG = logging.getLogger(__name__) +_T = TypeVar("_T", bound=str | EllipsisType) + class DirectQueryDriver(QueryDriver): """The `QueryDriver` implementation for `DirectButler`. @@ -209,27 +213,26 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul # Docstring inherited. 
if self._exit_stack is None: raise RuntimeError("QueryDriver context must be entered before queries can be executed.") - plan = self.build_query( + builder = self.build_query( tree, final_columns=result_spec.get_result_columns(), order_by=result_spec.order_by, find_first_dataset=result_spec.find_first_dataset, ) - builder = plan.builder - sql_select = builder.select(plan.postprocessing) + sql_select, sql_columns = builder.finish_select() if result_spec.order_by: - visitor = SqlColumnVisitor(builder.joiner, self) + visitor = SqlColumnVisitor(sql_columns, self) sql_select = sql_select.order_by(*[visitor.expect_scalar(term) for term in result_spec.order_by]) if result_spec.limit is not None: - if plan.postprocessing: - plan.postprocessing.limit = result_spec.limit + if builder.postprocessing: + builder.postprocessing.limit = result_spec.limit else: sql_select = sql_select.limit(result_spec.limit) - if plan.postprocessing.limit is not None: + if builder.postprocessing.limit is not None: # We might want to fetch many fewer rows than the default page # size if we have to implement limit in postprocessing. raw_page_size = min( - self._postprocessing_filter_factor * plan.postprocessing.limit, + self._postprocessing_filter_factor * builder.postprocessing.limit, self._raw_page_size, ) else: @@ -239,9 +242,9 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul cursor = _Cursor( self.db, sql_select, - postprocessing=plan.postprocessing, + postprocessing=builder.postprocessing, raw_page_size=raw_page_size, - page_converter=self._create_result_page_converter(result_spec, builder), + page_converter=self._create_result_page_converter(result_spec, builder.final_columns), ) # Since this function isn't a context manager and the caller could stop # iterating before we retrieve all the results, we have to track open @@ -263,10 +266,10 @@ def _read_results(self, cursor: _Cursor) -> Iterator[ResultPage]: self._cursors.discard(cursor) cursor.close() - def _create_result_page_converter(self, spec: ResultSpec, builder: QueryBuilder) -> ResultPageConverter: + def _create_result_page_converter(self, spec: ResultSpec, columns: qt.ColumnSet) -> ResultPageConverter: context = ResultPageConverterContext( db=self.db, - column_order=builder.columns.get_column_order(), + column_order=columns.get_column_order(), dimension_record_cache=self._dimension_record_cache, ) match spec: @@ -294,7 +297,6 @@ def materialize( if self._exit_stack is None: raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") plan = self.build_query(tree, qt.ColumnSet(dimensions)) - builder = plan.builder # Current implementation ignores 'datasets' aside from remembering # them, because figuring out what to put in the temporary table for # them is tricky, especially if calibration collections are involved. @@ -308,9 +310,9 @@ def materialize( # search is straightforward and definitely well-indexed, and not much # (if at all) worse than joining back in on a materialized UUID. 
# - sql_select = builder.select(plan.postprocessing) + sql_select, _ = plan.finish_select(return_columns=False) table = self._exit_stack.enter_context( - self.db.temporary_table(builder.make_table_spec(plan.postprocessing)) + self.db.temporary_table(make_table_spec(plan.final_columns, self.db, plan.postprocessing)) ) self.db.insert(table, select=sql_select) if key is None: @@ -337,10 +339,12 @@ def upload_data_coordinates( if not columns: table_spec.fields.add( ddl.FieldSpec( - QueryBuilder.EMPTY_COLUMNS_NAME, dtype=QueryBuilder.EMPTY_COLUMNS_TYPE, nullable=True + SqlSelectBuilder.EMPTY_COLUMNS_NAME, + dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE, + nullable=True, ) ) - dict_rows = [{QueryBuilder.EMPTY_COLUMNS_NAME: None}] + dict_rows = [{SqlSelectBuilder.EMPTY_COLUMNS_NAME: None}] else: dict_rows = [dict(zip(dimensions.required, values)) for values in rows] from_clause: sqlalchemy.FromClause @@ -364,31 +368,31 @@ def count( ) -> int: # Docstring inherited. columns = result_spec.get_result_columns() - plan = self.build_query(tree, columns, find_first_dataset=result_spec.find_first_dataset) - builder = plan.builder - if not all(d.collection_records for d in plan.joins.datasets.values()): + builder = self.build_query(tree, columns, find_first_dataset=result_spec.find_first_dataset) + if not all(d.collection_records for d in builder.joins_analysis.datasets.values()): return 0 + # No need to do a similar check on the union dataset searches: those + # are only constructed for collections that actually contain matching + # dataset types, so they always have collection records. if not exact: - plan.postprocessing = Postprocessing() - if plan.postprocessing: + builder.postprocessing = Postprocessing() + if builder.postprocessing: if not discard: raise InvalidQueryError("Cannot count query rows exactly without discarding them.") - sql_select = builder.select(plan.postprocessing) - plan.postprocessing.limit = result_spec.limit + sql_select, _ = builder.finish_select(return_columns=False) + builder.postprocessing.limit = result_spec.limit n = 0 with self.db.query(sql_select.execution_options(yield_per=self._raw_page_size)) as results: - for _ in plan.postprocessing.apply(results): + for _ in builder.postprocessing.apply(results): n += 1 return n - # If the query has DISTINCT or GROUP BY, nest it in a subquery so we - # count deduplicated rows. - builder = builder.nested(postprocessing=plan.postprocessing) + # If the query has DISTINCT, GROUP BY, or UNION [ALL], nest it in a + # subquery so we count deduplicated rows. + select_builder = builder.finish_nested() # Replace the columns of the query with just COUNT(*). - builder.columns = qt.ColumnSet(self._universe.empty) + select_builder.columns = qt.ColumnSet(self._universe.empty) count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count() - builder.joiner.special["_ROWCOUNT"] = count_func + select_builder.joins.special["_ROWCOUNT"] = count_func # Render and run the query. - sql_select = builder.select(plan.postprocessing) + sql_select = select_builder.select(builder.postprocessing) with self.db.query(sql_select) as result: count = cast(int, result.scalar()) if result_spec.limit is not None: @@ -397,31 +401,30 @@ def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: # Docstring inherited.
- plan = self.build_query(tree, qt.ColumnSet(tree.dimensions)) - builder = plan.builder - if not all(d.collection_records for d in plan.joins.datasets.values()): + builder = self.build_query(tree, qt.ColumnSet(tree.dimensions)) + if not all(d.collection_records for d in builder.joins_analysis.datasets.values()): return False if not execute: if exact: raise InvalidQueryError("Cannot obtain exact result for 'any' without executing.") return True - if plan.postprocessing and exact: - sql_select = builder.select(plan.postprocessing) + if builder.postprocessing and exact: + sql_select, _ = builder.finish_select(return_columns=False) with self.db.query( sql_select.execution_options(yield_per=self._postprocessing_filter_factor) ) as result: - for _ in plan.postprocessing.apply(result): + for _ in builder.postprocessing.apply(result): return True return False - sql_select = builder.select(plan.postprocessing).limit(1) - with self.db.query(sql_select) as result: + sql_select, _ = builder.finish_select() + with self.db.query(sql_select.limit(1)) as result: return result.first() is not None def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]: # Docstring inherited. - plan = self.analyze_query(tree, qt.ColumnSet(tree.dimensions)) - if plan.joins.messages or not execute: - return plan.joins.messages + plan = self.build_query(tree, qt.ColumnSet(tree.dimensions), analyze_only=True) + if plan.joins_analysis.messages or not execute: + return plan.joins_analysis.messages # TODO: guess at ways to split up query that might fail or succeed if # run separately, execute them with LIMIT 1 and report the results. return [] @@ -440,10 +443,13 @@ def build_query( self, tree: qt.QueryTree, final_columns: qt.ColumnSet, + *, order_by: Iterable[qt.OrderExpression] = (), - find_first_dataset: str | None = None, - ) -> QueryPlan: - """Convert a query description into a mostly-completed `QueryBuilder`. + find_first_dataset: str | EllipsisType | None = None, + analyze_only: bool = False, + ) -> QueryBuilder: + """Convert a query description into a nearly-complete builder object + for the SQL version of that query. Parameters ---------- @@ -454,140 +460,304 @@ order_by : `~collections.abc.Iterable` [ \ `.queries.tree.OrderExpression` ], optional Column expressions to sort by. - find_first_dataset : `str` or `None`, optional + find_first_dataset : `str`, ``...``, or `None`, optional Name of a dataset type for which only one result row for each data ID should be returned, with the collections searched in order. + ``...`` is used to represent the search for all dataset types with + a particular set of dimensions in ``tree.any_dataset``. + analyze_only : `bool`, optional + If `True`, perform the initial analysis needed to construct the + builder, but do not call methods that build its SQL form. This can + be useful for obtaining diagnostic information about the query that + would be generated. Returns ------- - plan : `QueryPlan` - Plan used to transform the query into SQL, including a builder - object that can be used to create a SQL SELECT via its - `~QueryBuilder.select` method and a `Postprocessing` object that - describes work to be done after executing the query. - """ - # See the QueryPlan docs for an overview of what these stages of query - # construction do.
- plan = self.analyze_query(tree, final_columns, order_by, find_first_dataset) - self.apply_query_joins(plan) - self.apply_query_projection(plan, order_by) - self.apply_query_find_first(plan) - plan.builder.columns = plan.final_columns - return plan - - def analyze_query( - self, - tree: qt.QueryTree, - final_columns: qt.ColumnSet, - order_by: Iterable[qt.OrderExpression] = (), - find_first_dataset: str | None = None, - ) -> QueryPlan: - """Construct a plan for building a query and initialize a builder. - - Parameters - ---------- - tree : `.queries.tree.QueryTree` - Description of the joins and row filters in the query. - final_columns : `.queries.tree.ColumnSet` - Final output columns that should be emitted by the SQL query. - order_by : `~collections.abc.Iterable` [ \ - `.queries.tree.OrderExpression` ], optional - Column expressions to sort by. - find_first_dataset : `str` or `None`, optional - Name of a dataset type for which only one result row for each data - ID should be returned, with the collections searched in order. - - Returns - ------- - plan : `QueryPlan` - Plan used to transform the query into SQL, including a builder - object that can be used to create a SQL SELECT via its - `~QueryBuilder.select` method and a `Postprocessing` object that - describes work to be done after executing the query. + builder : `QueryBuilder` + An object that contains an analysis of the query's columns and + general structure, SQLAlchemy representations of most of the + constructed query, and a description of Python-side postprocessing + to be performed after executing it. Callers should generally only + have to call `finish_select` or `finish_nested`; all ``apply`` + methods will have already been called. + + Notes + ----- + The SQL queries produced by this driver are built in three steps: + + - First we "analyze" the entire query, fetching extra information as + needed to identify the tables we'll join in, the columns we'll have + at each level of what may be a nested SELECT, and the degree to which + we'll have to use GROUP BY / DISTINCT or window functions to obtain + just the rows we want. This process initializes a `QueryBuilder` + instance, but does not call any of its ``apply*`` methods to actually + build the SQLAlchemy version of the query, and it is all that is done + when ``analyze_only=True``. + - In the next step we call `QueryBuilder` ``apply*`` methods to mutate + and/or replace the `QueryBuilder` object's nested `SqlSelectBuilder` + instances to reflect that analysis, building the SQL from the inside + out (subqueries before their parents). This is also done by the + `build_query` method by default. + - The returned builder can be used by calling either `finish_select` + (to get the full executable query) or `finish_nested` (to wrap + that query as a subquery, if needed, in order to run related + queries, like ``COUNT(*)`` or ``LIMIT 1`` checks). + + Within both the analysis and application steps, we further split the + query *structure* and building process into three stages, mapping + roughly to levels of (potential) subquery nesting: + + - In the "joins" stage, we join all the tables we need columns or + constraints from and apply the predicate as a WHERE clause. All + non-calculated columns are included in the query at this stage. + - In the optional "projection" stage, we apply a GROUP BY or DISTINCT + to reduce the set of columns and/or rows, and in some cases add new + calculated columns.
+ - In the optional "find first" stage, we use a common table expression + with PARTITION ON to search for datasets in a sequence of collections + in the order those collections were provided. + + In addition, there are two different overall structures for butler + queries, modeled as two different `QueryBuilder` subclasses. + + - `SingleSelectQueryBuilder` produces a single possibly-nested SELECT + query structured directly according to the stages above. This is + used for almost all queries - specifically whenever + `QueryTree.any_dataset` is `None`. + - `UnionQueryBuilder` implements queries for datasets of multiple types + with the same dimensions as a UNION ALL combination of SELECTs, in + which each component SELECT has the joins/projection/find-first + structure. This cannot be implemented as a sequence of + `SingleSelectQueryBuilder` objets, however, since the need for the + terms to share a single `Postprocessing` object and column list means + that they are not independent. + + The `build_query` method delegates to methods of the `QueryBuilder` + object when "applying" to let them handle these differences. These + often call back to other methods on the driver (so code shared by both + builders mostly lives in the driver class). """ - # The fact that this method returns both a QueryPlan and an initial - # QueryBuilder (rather than just a QueryPlan) is a tradeoff that lets - # DimensionRecordStorageManager.process_query_overlaps (which is called - # by the `_analyze_query_tree` call below) pull out overlap expressions - # from the predicate at the same time it turns them into SQL table - # joins (in the builder). - joins_plan, builder, postprocessing = self._analyze_query_tree(tree) - + # Analyze the dimensions, dataset searches, and other join operands + # that will go into the query. This also initializes a + # SqlSelectBuilder and Postprocessing with spatial/temporal constraints + # potentially transformed by the dimensions manager (but none of the + # rest of the analysis reflected in that SqlSelectBuilder). + query_tree_analysis = self._analyze_query_tree(tree) # The "projection" columns differ from the final columns by not # omitting any dimension keys (this keeps queries for different result # types more similar during construction), including any columns needed # only by order_by terms, and including the collection key if we need # it for GROUP BY or DISTINCT. - projection_plan = QueryProjectionPlan(final_columns.copy(), find_first_dataset=find_first_dataset) - projection_plan.columns.restore_dimension_keys() + projection_columns = final_columns.copy() + projection_columns.restore_dimension_keys() for term in order_by: - term.gather_required_columns(projection_plan.columns) - # The projection gets interesting if it does not have all of the - # dimension keys or dataset fields of the "joins" stage, because that - # means it needs to do a GROUP BY or DISTINCT ON to get unique rows. - if projection_plan.columns.dimensions != joins_plan.columns.dimensions: - assert projection_plan.columns.dimensions.issubset(joins_plan.columns.dimensions) - # We're going from a larger set of dimensions to a smaller set, - # that means we'll be doing a SELECT DISTINCT [ON] or GROUP BY. 
- projection_plan.needs_dimension_distinct = True - for dataset_type, fields_for_dataset in joins_plan.columns.dataset_fields.items(): - if not projection_plan.columns.dataset_fields[dataset_type]: - # The "joins"-stage query has one row for each collection for - # each data ID, but the projection-stage query just wants - # one row for each data ID. - if len(joins_plan.datasets[dataset_type].collection_records) > 1: - projection_plan.needs_dataset_distinct = True - break - # If there are any dataset fields being propagated through that - # projection and there is more than one collection, we need to - # include the collection_key column so we can use that as one of - # the DISTINCT or GROUP BY columns. - for dataset_type, fields_for_dataset in projection_plan.columns.dataset_fields.items(): - if len(joins_plan.datasets[dataset_type].collection_records) > 1: - fields_for_dataset.add("collection_key") - + term.gather_required_columns(projection_columns) + # There are two kinds of query builders: simple SELECTS and UNIONs + # over dataset types. + builder: QueryBuilder + if tree.any_dataset is not None: + builder = UnionQueryBuilder( + query_tree_analysis, + union_dataset_dimensions=tree.any_dataset.dimensions, + projection_columns=projection_columns, + final_columns=final_columns, + ) + else: + builder = SingleSelectQueryBuilder( + tree_analysis=query_tree_analysis, + projection_columns=projection_columns, + final_columns=final_columns, + ) + # Finish setting up the projection part of the builder. + builder.analyze_projection() # The joins-stage query also needs to include all columns needed by the # downstream projection query. Note that this: # - never adds new dimensions to the joins stage (since those are # always a superset of the projection-stage dimensions); - # - does not affect our determination of - # projection_plan.needs_dataset_distinct, because any dataset fields - # being added to the joins stage here are already in the projection. - joins_plan.columns.update(projection_plan.columns) - - find_first_plan = None + # - does not affect our previous determination of + # needs_dataset_distinct, because any dataset fields being added to + # the joins stage here are already in the projection. + builder.joins_analysis.columns.update(builder.projection_columns) + # Set up the find-first part of the builder. if find_first_dataset is not None: - find_first_plan = QueryFindFirstPlan(joins_plan.datasets[find_first_dataset]) - # If we're doing a find-first search and there's a calibration - # collection in play, we need to make sure the rows coming out of - # the base query have only one timespan for each data ID + - # collection, and we can only do that with a GROUP BY and COUNT - # that we inspect in postprocessing. - if find_first_plan.search.is_calibration_search: - postprocessing.check_validity_match_count = True - plan = QueryPlan( - joins=joins_plan, - projection=projection_plan, - find_first=find_first_plan, - final_columns=final_columns, - builder=builder, + builder.analyze_find_first(find_first_dataset) + # At this point, analysis is complete, and we can proceed to making + # the select_builder(s) reflect that analysis. + if not analyze_only: + builder.apply_joins(self) + builder.apply_projection(self, order_by) + builder.apply_find_first(self) + return builder + + def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: + """Analyze a `.queries.tree.QueryTree` as the first step in building + a SQL query. 
+ + Parameters + ---------- + tree : `.queries.tree.QueryTree` + Description of the joins and row filters in the query. + + Returns + ------- + tree_analysis : `QueryTreeAnalysis` + Struct containing additional information needed to build the joins + stage of a query. + + Notes + ----- + See `build_query` for the complete picture of how SQL queries are + constructed. This method is the very first step for all queries. + + The fact that the returned `QueryTreeAnalysis` holds both the joins + analysis and an initial SqlSelectBuilder (rather than just the + analysis) is a tradeoff that lets + DimensionRecordStorageManager.process_query_overlaps (which this + method calls) pull out overlap expressions from the predicate at the + same time it turns them into SQL table joins (in the builder). + """ + # Fetch the records and summaries for any collections we might be + # searching for datasets and organize them for the kind of lookups + # we'll do later. + collection_analysis = self._analyze_collections(tree) + # Delegate to the dimensions manager to rewrite the predicate and start + # a SqlSelectBuilder to cover any spatial overlap joins or constraints. + # We'll return that SqlSelectBuilder (or copies of it) at the end. + ( + predicate, + select_builder, + postprocessing, + ) = self.managers.dimensions.process_query_overlaps( + tree.dimensions, + tree.predicate, + tree.get_joined_dimension_groups(), + collection_analysis.calibration_dataset_types, + ) + # Extract the data ID implied by the predicate; we can use the governor + # dimensions in that to constrain the collections we search for + # datasets later. + predicate_constraints = PredicateConstraintsSummary(predicate) + # Use the default data ID to apply additional constraints where needed. + predicate_constraints.apply_default_data_id(self._default_data_id, tree.dimensions) + predicate = predicate_constraints.predicate + # Initialize the joins analysis we'll return at the end of the method. + joins = QueryJoinsAnalysis(predicate=predicate, columns=select_builder.columns) + joins.messages.extend(predicate_constraints.messages) + # Add columns required by postprocessing. + postprocessing.gather_columns_required(joins.columns) + # Add materializations, which can also bring in more postprocessing. + for m_key, m_dimensions in tree.materializations.items(): + m_state = self._materializations[m_key] + joins.materializations[m_key] = m_dimensions + # When a query is materialized, the new tree has an empty + # (trivially true) predicate because the original was used to make + # the materialized rows. But the original postprocessing isn't + # executed when the materialization happens, so we have to include + # it here. + postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering) + postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering) + # Add data coordinate uploads. + joins.data_coordinate_uploads.update(tree.data_coordinate_uploads) + # Add dataset_searches and filter out collections that don't have the + # right dataset type or governor dimensions. We re-resolve dataset + # searches now that we have a constraint data ID.
+ for dataset_type_name, dataset_search in tree.datasets.items(): + resolved_dataset_search = self._resolve_dataset_search( + dataset_type_name, + dataset_search, + predicate_constraints.constraint_data_id, + collection_analysis.summaries_by_dataset_type[dataset_type_name], + ) + if resolved_dataset_search.dimensions != self.get_dataset_type(dataset_type_name).dimensions: + # This is really for server-side defensiveness; it's hard to + # imagine the query getting different dimensions for a dataset + # type in two calls to the same query driver. + raise InvalidQueryError( + f"Incorrect dimensions {resolved_dataset_search.dimensions} for dataset " + f"{dataset_type_name!r} in query " + f"(vs. {self.get_dataset_type(dataset_type_name).dimensions})." + ) + joins.datasets[dataset_type_name] = resolved_dataset_search + if not resolved_dataset_search.collection_records: + joins.messages.append( + f"Search for dataset type {resolved_dataset_search.name!r} in " + f"{list(dataset_search.collections)} is doomed to fail." + ) + joins.messages.extend(resolved_dataset_search.messages) + # Process the special any_dataset search, if there is one. This entails + # making a modified copy of the plan for each distinct post-filtering + # collection search path. + if tree.any_dataset is None: + return QueryTreeAnalysis( + joins, union_datasets=[], initial_select_builder=select_builder, postprocessing=postprocessing + ) + # Gather the filtered collection search path for each union dataset + # type. + collections_by_dataset_type = defaultdict[str, list[str]](list) + for collection_record, collection_summary in collection_analysis.summaries_by_dataset_type[...]: + for dataset_type in collection_summary.dataset_types: + if dataset_type.dimensions == tree.any_dataset.dimensions: + collections_by_dataset_type[dataset_type.name].append(collection_record.name) + # Reverse the lookup order on the mapping we just made to group + # dataset types by their collection search path. Each such group + # yields an output plan. + dataset_searches_by_collections: dict[tuple[str, ...], ResolvedDatasetSearch[list[str]]] = {} + for dataset_type_name, collection_path in collections_by_dataset_type.items(): + key = tuple(collection_path) + if (resolved_search := dataset_searches_by_collections.get(key)) is None: + resolved_search = ResolvedDatasetSearch[list[str]]( + [], + dimensions=tree.any_dataset.dimensions, + collection_records=[ + collection_analysis.collection_records[collection_name] + for collection_name in collection_path + ], + messages=[], + ) + resolved_search.is_calibration_search = any( + r.type is CollectionType.CALIBRATION for r in resolved_search.collection_records + ) + dataset_searches_by_collections[key] = resolved_search + resolved_search.name.append(dataset_type_name) + return QueryTreeAnalysis( + joins, + union_datasets=list(dataset_searches_by_collections.values()), + initial_select_builder=select_builder, postprocessing=postprocessing, ) - return plan - def apply_query_joins(self, plan: QueryPlan) -> None: - """Modify the builder inside a `QueryPlan` to include all tables and - other FROM and WHERE clause terms needed. + def apply_initial_query_joins( + self, + select_builder: SqlSelectBuilder, + joins_analysis: QueryJoinsAnalysis, + union_dataset_dimensions: DimensionGroup | None, + ) -> None: + """Apply most of the "joins" stage of query construction to a single + SELECT. + + This method is expected to be invoked by `QueryBuilder.apply_joins` + implementations. 
It handles all tables and subqueries in the FROM + clause, except: + + - the `QueryTree.any_dataset` search (handled by `UnionQueryBuilder` + directly); + - joins of dimension tables needed only for their keys (handled by + `apply_missing_dimension_joins`). Parameters ---------- - plan : `QueryPlan` - `QueryPlan` to modify in-place. + select_builder : `SqlSelectBuilder` + Low-level SQL builder for a single SELECT term, modified in place. + joins_analysis : `QueryJoinsAnalysis` + Information about the joins stage of query construction. + union_dataset_dimensions : `DimensionGroup` or `None` + Dimensions of the union dataset types, or `None` if this is not + a union dataset query. """ # Process data coordinate upload joins. - for upload_key, upload_dimensions in plan.joins.data_coordinate_uploads.items(): - plan.builder.joiner.join( - QueryJoiner(self.db, self._upload_tables[upload_key]).extract_dimensions( + for upload_key, upload_dimensions in joins_analysis.data_coordinate_uploads.items(): + select_builder.joins.join( + SqlJoinsBuilder(db=self.db, from_clause=self._upload_tables[upload_key]).extract_dimensions( upload_dimensions.required ) ) @@ -596,34 +766,52 @@ def apply_query_joins(self, plan: QueryPlan) -> None: # can be dropped if they are only present to provide a constraint on # data IDs, since that's already embedded in a materialization. materialized_datasets: set[str] = set() - for materialization_key, materialization_dimensions in plan.joins.materializations.items(): + for materialization_key, materialization_dimensions in joins_analysis.materializations.items(): materialized_datasets.update( self._join_materialization( - plan.builder.joiner, materialization_key, materialization_dimensions + select_builder.joins, materialization_key, materialization_dimensions ) ) - # Process dataset joins. - for dataset_search in plan.joins.datasets.values(): - self._join_dataset_search( - plan.builder.joiner, + # Process dataset joins (not including any union dataset). + for dataset_search in joins_analysis.datasets.values(): + self.join_dataset_search( + select_builder.joins, dataset_search, - plan.joins.columns.dataset_fields[dataset_search.name], + joins_analysis.columns.dataset_fields[dataset_search.name], ) # Join in dimension element tables that we know we need relationships # or columns from. - for element in plan.joins.iter_mandatory(): - plan.builder.joiner.join( - self.managers.dimensions.make_query_joiner( - element, plan.joins.columns.dimension_fields[element.name] + for element in joins_analysis.iter_mandatory(union_dataset_dimensions): + select_builder.joins.join( + self.managers.dimensions.make_joins_builder( + element, joins_analysis.columns.dimension_fields[element.name] ) ) - # See if any dimension keys are still missing, and if so join in their - # tables. Note that we know there are no fields needed from these. - while not (plan.builder.joiner.dimension_keys.keys() >= plan.joins.columns.dimensions.names): - # Look for opportunities to join in multiple dimensions via single - # table, to reduce the total number of tables joined in. + + def apply_missing_dimension_joins( + self, select_builder: SqlSelectBuilder, joins_analysis: QueryJoinsAnalysis + ) -> None: + """Apply dimension-table joins to a single SQL SELECT builder to ensure + the full set of desired dimension keys is present. + + This method is expected to be invoked by `QueryBuilder.apply_joins` + implementations.
+ + Parameters + ---------- + select_builder : `SqlSelectBuilder` + Low-level SQL builder for a single SELECT term, modified in place. + joins_analysis : `QueryJoinsAnalysis` + Information about the joins stage of query construction. + """ + # See if any dimension keys are still missing, and if so join in + # their tables. Note that we know there are no fields needed from + # these. + while not (select_builder.joins.dimension_keys.keys() >= joins_analysis.columns.dimensions.names): + # Look for opportunities to join in multiple dimensions via + # single table, to reduce the total number of tables joined in. missing_dimension_names = ( - plan.joins.columns.dimensions.names - plan.builder.joiner.dimension_keys.keys() + joins_analysis.columns.dimensions.names - select_builder.joins.dimension_keys.keys() ) best = self._universe[ max( @@ -631,29 +819,76 @@ def apply_query_joins(self, plan: QueryPlan) -> None: key=lambda name: len(self._universe[name].dimensions.names & missing_dimension_names), ) ] - plan.builder.joiner.join(self.managers.dimensions.make_query_joiner(best, frozenset())) - # Add the WHERE clause to the joiner. - plan.builder.joiner.where(plan.joins.predicate.visit(SqlColumnVisitor(plan.builder.joiner, self))) + to_join = self.managers.dimensions.make_joins_builder(best, frozenset()) + select_builder.joins.join(to_join) + # Add the WHERE clause to the builder. + select_builder.joins.where( + joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self)) + ) def apply_query_projection( self, - plan: QueryPlan, + select_builder: SqlSelectBuilder, + postprocessing: Postprocessing, + *, + join_datasets: Mapping[str, ResolvedDatasetSearch[str]], + union_datasets: ResolvedDatasetSearch[list[str]] | None, + projection_columns: qt.ColumnSet, + needs_dimension_distinct: bool, + needs_dataset_distinct: bool, + needs_validity_match_count: bool, + find_first_dataset: str | EllipsisType | None, order_by: Iterable[qt.OrderExpression], ) -> None: - """Modify `QueryBuilder` to reflect the "projection" stage of query - construction, which can involve a GROUP BY or DISTINCT [ON] clause - that enforces uniqueness. + """Apply the "projection" stage of query construction to a single + SQL SELECT builder. + + This method is expected to be invoked by + `QueryBuilder.apply_projection` implementations. Parameters ---------- - plan : `QueryPlan` - `QueryPlan` to modify in-place. + select_builder : `SqlSelectBuilder` + Low-level SQL builder for a single SELECT term, modified in place. + postprocessing : `Postprocessing` + Description of query processing that happens in Python after the + query is executed by the database. + join_datasets : `~collections.abc.Mapping` [ `str`, \ + `ResolvedDatasetSearch` [ `str` ] ] + Information about regular (non-union) dataset searches joined into + the query. + union_datasets : `ResolvedDatasetSearch` [ `list` [ `str` ] ] or `None` + Information about a search for datasets of multiple types with the + same dimensions and the same post-filtering collection search path. + projection_columns : `.queries.tree.ColumnSet` + Columns to include in this projection stage of the query. + needs_dimension_distinct : `bool` + Whether this query needs a GROUP BY or DISTINCT to filter out rows + where the only differences are dimension values that are not being + returned to the user. + needs_dataset_distinct : `bool` + Whether this query needs a GROUP BY or DISTINCT to filter out rows + that correspond to entries for different collections.
+ needs_validity_match_count : `bool` + Whether this query needs a COUNT column to track the number of + datasets for each data ID and dataset type. If this is `False` but + ``postprocessing.check_validity_match_count`` is `True`, a dummy + count column that is just "1" should be added, because the check + is needed only for some other SELECT term in a UNION ALL. + find_first_dataset : `str` or ``...`` or `None` + Name of the dataset type that will need a find-first stage. + Ellipsis ``...`` is used when the union datasets need a find-first + search, while `None` is used to represent both ``find_first=False`` + and the case when ``find_first=True`` but only one collection has + survived filtering. order_by : `~collections.abc.Iterable` [ \ `.queries.tree.OrderExpression` ] Order by clause associated with the query. """ - plan.builder.columns = plan.projection.columns - if not plan.postprocessing and not plan.postprocessing.check_validity_match_count: + select_builder.columns = projection_columns + if not needs_dimension_distinct and not needs_dataset_distinct and not needs_validity_match_count: + if postprocessing.check_validity_match_count: + select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = sqlalchemy.literal(1) # Rows are already unique; nothing else to do in this method. return # This method generates either a SELECT DISTINCT [ON] or a SELECT with # GROUP BY. @@ -662,15 +897,15 @@ def apply_query_projection( # Dimension key columns form at least most of our GROUP BY or DISTINCT # ON clause. unique_keys: list[sqlalchemy.ColumnElement[Any]] = [ - plan.builder.joiner.dimension_keys[k][0] - for k in plan.projection.columns.dimensions.data_coordinate_keys + select_builder.joins.dimension_keys[k][0] + for k in projection_columns.dimensions.data_coordinate_keys ] # Many of our fields derive their uniqueness from the unique_key # fields: if rows are unique over the 'unique_key' fields, then they're # automatically unique over these 'derived_fields'. We just remember # these as pairs of (logical_table, field) for now. - derived_fields: list[tuple[str, str]] = [] + derived_fields: list[tuple[str | EllipsisType, str]] = [] # There are two reasons we might need an aggregate function: # - to make sure temporal constraints and joins have resulted in at @@ -682,14 +917,17 @@ # visit_detector_region.region, but the output rows don't have # detector, just visit - so we compute the union of the # visit_detector region over all matched detectors). - if plan.postprocessing.check_validity_match_count: - plan.builder.joiner.special[plan.postprocessing.VALIDITY_MATCH_COUNT] = ( - sqlalchemy.func.count().label(plan.postprocessing.VALIDITY_MATCH_COUNT) - ) - have_aggregates = True + if postprocessing.check_validity_match_count: + if needs_validity_match_count: + select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = ( + sqlalchemy.func.count().label(postprocessing.VALIDITY_MATCH_COUNT) + ) + have_aggregates = True + else: + select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = sqlalchemy.literal(1) - for element in plan.postprocessing.iter_missing(plan.projection.columns): - if element.name in plan.projection.columns.dimensions.elements: + for element in postprocessing.iter_missing(projection_columns): + if element.name in projection_columns.dimensions.elements: # The region associated with dimension keys returned by the # query are derived fields, since there is only one region # associated with each dimension key value.
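(Reviewer aside, not part of the patch: the hunks above and below implement the projection stage's uniqueness strategy. The following minimal Python sketch condenses that decision logic; the helper name and the duck-typed select_builder argument are illustrative, not code from this diff.)

from typing import Any

import sqlalchemy


def choose_uniqueness_strategy(
    select_builder: Any,  # illustrative stand-in for SqlSelectBuilder
    unique_keys: list[sqlalchemy.ColumnElement[Any]],
    derived_fields: list[tuple[str, str]],
    have_aggregates: bool,
    db_has_distinct_on: bool,
    order_by_is_empty: bool,
) -> None:
    if not have_aggregates and not derived_fields:
        # Every selected column participates in uniqueness: plain DISTINCT.
        select_builder.distinct = True
    elif not have_aggregates and db_has_distinct_on and order_by_is_empty:
        # Postgres-style SELECT DISTINCT ON (key, ...); avoided when there
        # is an ORDER BY because of Postgres's leading-expression rule.
        select_builder.distinct = tuple(unique_keys)
    else:
        # GROUP BY the key columns; derived fields must then be wrapped in
        # an aggregate such as ANY_VALUE or appended to the GROUP BY clause.
        select_builder.group_by = tuple(unique_keys)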
@@ -700,33 +938,37 @@ def apply_query_projection( # regions. When that happens, we want to apply an aggregate # function to them that computes the union of the regions that # are grouped together. - plan.builder.joiner.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( - plan.builder.joiner.fields[element.name]["region"] + select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( + select_builder.joins.fields[element.name]["region"] ) have_aggregates = True # All dimension record fields are derived fields. - for element_name, fields_for_element in plan.projection.columns.dimension_fields.items(): + for element_name, fields_for_element in projection_columns.dimension_fields.items(): for element_field in fields_for_element: derived_fields.append((element_name, element_field)) # Some dataset fields are derived fields and some are unique keys, and # it depends on the kinds of collection(s) we're searching and whether # it's a find-first query. - for dataset_type, fields_for_dataset in plan.projection.columns.dataset_fields.items(): + for dataset_type, fields_for_dataset in projection_columns.dataset_fields.items(): + dataset_search: ResolvedDatasetSearch[Any] + if dataset_type is ...: + assert union_datasets is not None + dataset_search = union_datasets + else: + dataset_search = join_datasets[dataset_type] for dataset_field in fields_for_dataset: if dataset_field == "collection_key": # If the collection_key field is present, it's needed for # uniqueness if we're looking in more than one collection. # If not, it's a derived field. - if len(plan.joins.datasets[dataset_type].collection_records) > 1: - unique_keys.append(plan.builder.joiner.fields[dataset_type]["collection_key"]) + if len(dataset_search.collection_records) > 1: + unique_keys.append(select_builder.joins.fields[dataset_type]["collection_key"]) else: derived_fields.append((dataset_type, "collection_key")) - elif dataset_field == "timespan" and plan.joins.datasets[dataset_type].is_calibration_search: - # If we're doing a non-find-first query against a - # CALIBRATION collection, the timespan is also a unique - # key... - if dataset_type == plan.projection.find_first_dataset: + elif dataset_field == "timespan" and dataset_search.is_calibration_search: + # The timespan is also a unique key... + if dataset_type == find_first_dataset: # ...unless we're doing a find-first search on this # dataset, in which case we need to use ANY_VALUE on # the timespan and check that _VALIDITY_MATCH_COUNT @@ -736,22 +978,22 @@ def apply_query_projection( # clauses and JOINs. if not self.db.has_any_aggregate: raise NotImplementedError( - f"Cannot generate query that returns {dataset_type}.timespan after a " - "find-first search, because this a database does not support the ANY_VALUE " + f"Cannot generate query that returns timespan for {dataset_type!r} after a " + "find-first search, because this database does not support the ANY_VALUE " "aggregate function (or equivalent)." ) - plan.builder.joiner.timespans[dataset_type] = plan.builder.joiner.timespans[ + select_builder.joins.timespans[dataset_type] = select_builder.joins.timespans[ dataset_type ].apply_any_aggregate(self.db.apply_any_aggregate) else: - unique_keys.extend(plan.builder.joiner.timespans[dataset_type].flatten()) + unique_keys.extend(select_builder.joins.timespans[dataset_type].flatten()) else: # Other dataset fields derive their uniqueness from key # fields. 
derived_fields.append((dataset_type, dataset_field)) if not have_aggregates and not derived_fields: # SELECT DISTINCT is sufficient. - plan.builder.distinct = True + select_builder.distinct = True # With DISTINCT ON, Postgres requires that the leftmost parts of the # ORDER BY match the DISTINCT ON expressions. It's somewhat tricky to # enforce that, so instead we just don't use DISTINCT ON if ORDER BY is @@ -759,45 +1001,63 @@ def apply_query_projection( # restriction. elif not have_aggregates and self.db.has_distinct_on and len(list(order_by)) == 0: # SELECT DISTINCT ON is sufficient and supported by this database. - plan.builder.distinct = tuple(unique_keys) + select_builder.distinct = tuple(unique_keys) else: # GROUP BY is the only option. if derived_fields: if self.db.has_any_aggregate: for logical_table, field in derived_fields: if field == "timespan": - plan.builder.joiner.timespans[logical_table] = plan.builder.joiner.timespans[ + select_builder.joins.timespans[logical_table] = select_builder.joins.timespans[ logical_table ].apply_any_aggregate(self.db.apply_any_aggregate) else: - plan.builder.joiner.fields[logical_table][field] = self.db.apply_any_aggregate( - plan.builder.joiner.fields[logical_table][field] + select_builder.joins.fields[logical_table][field] = self.db.apply_any_aggregate( + select_builder.joins.fields[logical_table][field] ) else: _LOG.warning( "Adding %d fields to GROUP BY because this database backend does not support the " "ANY_VALUE aggregate function (or equivalent). This may result in a poor query " - "plan. Materializing the query first sometimes avoids this problem.", + "plan. Materializing the query first sometimes avoids this problem. This warning " + "can be ignored unless query performance is a problem.", len(derived_fields), ) for logical_table, field in derived_fields: if field == "timespan": - unique_keys.extend(plan.builder.joiner.timespans[logical_table].flatten()) + unique_keys.extend(select_builder.joins.timespans[logical_table].flatten()) else: - unique_keys.append(plan.builder.joiner.fields[logical_table][field]) - plan.builder.group_by = tuple(unique_keys) + unique_keys.append(select_builder.joins.fields[logical_table][field]) + select_builder.group_by = tuple(unique_keys) - def apply_query_find_first(self, plan: QueryPlan) -> None: - """Modify an under-construction SQL query to return only one row for - each data ID, searching collections in order. + def apply_query_find_first( + self, + select_builder: SqlSelectBuilder, + postprocessing: Postprocessing, + find_first_analysis: QueryFindFirstAnalysis, + ) -> SqlSelectBuilder: + """Apply the "find first" stage of query construction to a single + SQL SELECT builder. + + This method is expected to be invoked by + `QueryBuilder.apply_find_first` implementations. Parameters ---------- - plan : `QueryPlan` - `QueryPlan` to modify in-place. + select_builder : `SqlSelectBuilder` + Low-level SQL builder for a single SELECT term, consumed on return. + postprocessing : `Postprocessing` + Description of query processing that happens in Python after the + query is executed by the database. + find_first_analysis : `QueryFindFirstAnalysis` + Information about the find-first stage gathered during the analysis + phase of query construction. + + Returns + ------- + select_builder : `SqlSelectBuilder` + Low-level SQL builder that includes the find-first logic. 
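(Reviewer aside on the find-first stage implemented below: the comment in the next hunk sketches the generated SQL. For reference, here is a minimal, runnable SQLAlchemy sketch of the same rank-and-filter window idiom, using a hypothetical table and column names rather than this driver's schema.)

import sqlalchemy

metadata = sqlalchemy.MetaData()
# Hypothetical base table: one row per (data ID, collection) match.
dataset = sqlalchemy.Table(
    "dataset",
    metadata,
    sqlalchemy.Column("data_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_rank", sqlalchemy.Integer),
)
# Rank each data ID's matches by collection search order...
rownum = (
    sqlalchemy.func.row_number()
    .over(partition_by=dataset.c.data_id, order_by=dataset.c.collection_rank)
    .label("rownum")
)
ranked = sqlalchemy.select(dataset.c.data_id, rownum).subquery("ranked")
# ...and keep only the first-ranked match for each data ID.
find_first = sqlalchemy.select(ranked.c.data_id).where(ranked.c.rownum == 1)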
""" - if not plan.find_first: - return # The query we're building looks like this: # # WITH {dst}_base AS ( @@ -817,9 +1077,9 @@ def apply_query_find_first(self, plan: QueryPlan) -> None: # WHERE # {dst}_window.rownum = 1; # - # The outermost SELECT will be represented by the QueryBuilder we - # return. The QueryBuilder we're given corresponds to the Common Table - # Expression (CTE) at the top. + # The outermost SELECT will be represented by the SqlSelectBuilder we + # return. The SqlSelectBuilder we're given corresponds to the Common + # Table Expression (CTE) at the top. # # For SQLite only, we could use a much simpler GROUP BY instead, # because it extends the standard to do exactly what we want when MIN @@ -827,33 +1087,34 @@ def apply_query_find_first(self, plan: QueryPlan) -> None: # (https://www.sqlite.org/quirks.html). But since that doesn't work # with PostgreSQL it doesn't help us. # - plan.builder = plan.builder.nested(cte=True, force=True, postprocessing=plan.postprocessing) + select_builder = select_builder.nested(cte=True, force=True, postprocessing=postprocessing) # We start by filling out the "window" SELECT statement... partition_by = [ - plan.builder.joiner.dimension_keys[d][0] for d in plan.builder.columns.dimensions.required + select_builder.joins.dimension_keys[d][0] for d in select_builder.columns.dimensions.required ] rank_sql_column = sqlalchemy.case( - {record.key: n for n, record in enumerate(plan.find_first.search.collection_records)}, - value=plan.builder.joiner.fields[plan.find_first.dataset_type]["collection_key"], + {record.key: n for n, record in enumerate(find_first_analysis.search.collection_records)}, + value=select_builder.joins.fields[find_first_analysis.dataset_type]["collection_key"], ) if partition_by: - plan.builder.joiner.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( + select_builder.joins.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( partition_by=partition_by, order_by=rank_sql_column ) else: - plan.builder.joiner.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( + select_builder.joins.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( order_by=rank_sql_column ) # ... and then turn that into a subquery with a constraint on rownum. - plan.builder = plan.builder.nested(force=True, postprocessing=plan.postprocessing) + select_builder = select_builder.nested(force=True, postprocessing=postprocessing) # We can now add the WHERE constraint on rownum into the outer query. - plan.builder.joiner.where(plan.builder.joiner.special["_ROWNUM"] == 1) + select_builder.joins.where(select_builder.joins.special["_ROWNUM"] == 1) # Don't propagate _ROWNUM into downstream queries. - del plan.builder.joiner.special["_ROWNUM"] + del select_builder.joins.special["_ROWNUM"] + return select_builder - def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, QueryBuilder, Postprocessing]: - """Start constructing a plan for building a query from a - `.queries.tree.QueryTree`. + def _analyze_collections(self, tree: qt.QueryTree) -> QueryCollectionAnalysis: + """Fetch and organize information about all collections appearing in a + query. Parameters ---------- @@ -862,16 +1123,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query Returns ------- - plan : `QueryJoinsPlan` - Initial component of the plan relevant for the "joins" stage, - including all joins and columns needed by ``tree``. Additional - columns will be added to this plan later. 
- builder : `QueryBuilder` - Builder object initialized with overlap joins and constraints - potentially included, with the remainder still present in - `QueryJoinPlans.predicate`. - postprocessing : `Postprocessing` - Struct representing post-query processing to be done in Python. + collection_analysis : `QueryCollectionAnalysis` + Struct containing collection records and summaries, organized + for later access by dataset type. """ # Retrieve collection information for all collections in a tree. collection_names = set( @@ -879,6 +1133,8 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query dataset_search.collections for dataset_search in tree.datasets.values() ) ) + if tree.any_dataset is not None: + collection_names.update(tree.any_dataset.collections) collection_records = { record.name: record for record in self.managers.collections.resolve_wildcard( @@ -889,84 +1145,25 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query record for record in collection_records.values() if record.type is not CollectionType.CHAINED ] # Fetch summaries for a subset of dataset types. - dataset_types = [self.get_dataset_type(dataset_type_name) for dataset_type_name in tree.datasets] - summaries = self.managers.datasets.fetch_summaries(non_chain_records, dataset_types) + if tree.any_dataset is not None: + summaries = self.managers.datasets.fetch_summaries(non_chain_records, dataset_types=None) + else: + dataset_types = [self.get_dataset_type(dataset_type_name) for dataset_type_name in tree.datasets] + summaries = self.managers.datasets.fetch_summaries(non_chain_records, dataset_types) + result = QueryCollectionAnalysis(collection_records=collection_records) # Do a preliminary resolution for dataset searches to identify any # calibration lookups that might participate in temporal joins. - calibration_dataset_types: set[str] = set() - summaries_by_dataset_type: dict[str, list[tuple[CollectionRecord, CollectionSummary]]] = {} - for dataset_type_name, dataset_search in tree.datasets.items(): + for dataset_type_name, dataset_search in tree.iter_all_dataset_searches(): collection_summaries = self._filter_collections( dataset_search.collections, collection_records, summaries ) - summaries_by_dataset_type[dataset_type_name] = collection_summaries + result.summaries_by_dataset_type[dataset_type_name] = collection_summaries resolved_dataset_search = self._resolve_dataset_search( dataset_type_name, dataset_search, {}, collection_summaries ) if resolved_dataset_search.is_calibration_search: - calibration_dataset_types.add(dataset_type_name) - # Delegate to the dimensions manager to rewrite the predicate and start - # a QueryBuilder to cover any spatial overlap joins or constraints. - # We'll return that QueryBuilder at the end. - ( - predicate, - builder, - postprocessing, - ) = self.managers.dimensions.process_query_overlaps( - tree.dimensions, - tree.predicate, - tree.get_joined_dimension_groups(), - calibration_dataset_types, - ) - - # Extract the data ID implied by the predicate; we can use the governor - # dimensions in that to constrain the collections we search for - # datasets later. - predicate_constraints = PredicateConstraintsSummary(predicate) - # Use the default data ID to apply additional constraints where needed. 
- predicate_constraints.apply_default_data_id(self._default_data_id, tree.dimensions) - predicate = predicate_constraints.predicate - - result = QueryJoinsPlan( - predicate=predicate, - columns=builder.columns, - messages=predicate_constraints.messages, - ) - - # Add columns required by postprocessing. - postprocessing.gather_columns_required(result.columns) - - # Add materializations, which can also bring in more postprocessing. - for m_key, m_dimensions in tree.materializations.items(): - m_state = self._materializations[m_key] - result.materializations[m_key] = m_dimensions - # When a query is materialized, the new tree has an empty - # (trivially true) predicate because the original was used to make - # the materialized rows. But the original postprocessing isn't - # executed when the materialization happens, so we have to include - # it here. - postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering) - postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering) - # Add data coordinate uploads. - result.data_coordinate_uploads.update(tree.data_coordinate_uploads) - # Add dataset_searches and filter out collections that don't have the - # right dataset type or governor dimensions. We re-resolve dataset - # searches now that we have a constraint data ID. - for dataset_type_name, dataset_search in tree.datasets.items(): - resolved_dataset_search = self._resolve_dataset_search( - dataset_type_name, - dataset_search, - predicate_constraints.constraint_data_id, - summaries_by_dataset_type[dataset_type_name], - ) - result.datasets[dataset_type_name] = resolved_dataset_search - if not resolved_dataset_search.collection_records: - result.messages.append( - f"Search for dataset type {dataset_type_name!r} in " - f"{list(dataset_search.collections)} is doomed to fail." - ) - result.messages.extend(resolved_dataset_search.messages) - return result, builder, postprocessing + result.calibration_dataset_types.add(dataset_type_name) + return result def _filter_collections( self, @@ -1012,17 +1209,17 @@ def recurse(names: Iterable[str]) -> Iterator[tuple[CollectionRecord, Collection def _resolve_dataset_search( self, - dataset_type_name: str, + dataset_type_name: _T, dataset_search: qt.DatasetSearch, constraint_data_id: Mapping[str, DataIdValue], collections: list[tuple[CollectionRecord, CollectionSummary]], - ) -> ResolvedDatasetSearch: + ) -> ResolvedDatasetSearch[_T]: """Resolve the collections that should actually be searched for datasets of a particular type. Parameters ---------- - dataset_type_name : `str` + dataset_type_name : `str` or ``...`` Name of the dataset being searched for. dataset_search : `.queries.tree.DatasetSearch` Struct holding the dimensions and original collection search path. @@ -1045,7 +1242,7 @@ def _resolve_dataset_search( result.messages.append("No datasets can be found because collection list is empty.") for collection_record, collection_summary in collections: rejected: bool = False - if result.name not in collection_summary.dataset_types.names: + if result.name is not ... and result.name not in collection_summary.dataset_types.names: result.messages.append( f"No datasets of type {result.name!r} in collection {collection_record.name!r}." 
) @@ -1061,19 +1258,11 @@ def _resolve_dataset_search( if collection_record.type is CollectionType.CALIBRATION: result.is_calibration_search = True result.collection_records.append(collection_record) - if result.dimensions != self.get_dataset_type(dataset_type_name).dimensions: - # This is really for server-side defensiveness; it's hard to - # imagine the query getting different dimensions for a dataset - # type in two calls to the same query driver. - raise InvalidQueryError( - f"Incorrect dimensions {result.dimensions} for dataset {dataset_type_name} " - f"in query (vs. {self.get_dataset_type(dataset_type_name).dimensions})." - ) return result def _join_materialization( self, - joiner: QueryJoiner, + joins_builder: SqlJoinsBuilder, key: qt.MaterializationKey, dimensions: DimensionGroup, ) -> frozenset[str]: @@ -1081,8 +1270,8 @@ def _join_materialization( Parameters ---------- - joiner : `QueryJoiner` - Component of a `QueryBuilder` that holds the FROM and WHERE + joins_builder : `SqlJoinsBuilder` + Component of a `SqlSelectBuilder` that holds the FROM and WHERE clauses. This will be modified in-place on return. key : `.queries.tree.MaterializationKey` Unique identifier created for this materialization when it was @@ -1098,38 +1287,77 @@ def _join_materialization( """ columns = qt.ColumnSet(dimensions) m_state = self._materializations[key] - joiner.join(QueryJoiner(self.db, m_state.table).extract_columns(columns, m_state.postprocessing)) + joins_builder.join( + SqlJoinsBuilder(db=self.db, from_clause=m_state.table).extract_columns( + columns, m_state.postprocessing + ) + ) return m_state.datasets - def _join_dataset_search( + @overload + def join_dataset_search( self, - joiner: QueryJoiner, - resolved_search: ResolvedDatasetSearch, + joins_builder: SqlJoinsBuilder, + resolved_search: ResolvedDatasetSearch[list[str]], fields: Set[str], + union_dataset_type_name: str, + ) -> None: ... + + @overload + def join_dataset_search( + self, + joins_builder: SqlJoinsBuilder, + resolved_search: ResolvedDatasetSearch[str], + fields: Set[str], + ) -> None: ... + + def join_dataset_search( + self, + joins_builder: SqlJoinsBuilder, + resolved_search: ResolvedDatasetSearch[Any], + fields: Set[str], + union_dataset_type_name: str | None = None, ) -> None: """Join a dataset search into an under-construction query. Parameters ---------- - joiner : `QueryJoiner` - Component of a `QueryBuilder` that holds the FROM and WHERE + joins_builder : `SqlJoinsBuilder` + Component of a `SqlSelectBuilder` that holds the FROM and WHERE clauses. This will be modified in-place on return. resolved_search : `ResolvedDatasetSearch` Struct that describes the dataset type and collections. fields : `~collections.abc.Set` [ `str` ] Dataset fields to include. + union_dataset_type_name : `str`, optional + Dataset type name to use when `resolved_search` represents multiple + union datasets. """ - # The next two asserts will need to be dropped (and the implications + # The asserts below will need to be dropped (and the implications # dealt with instead) if materializations start having dataset fields. - assert ( - resolved_search.name not in joiner.fields - ), "Dataset fields have unexpectedly already been joined in." - assert ( - resolved_search.name not in joiner.timespans - ), "Dataset timespan has unexpectedly already been joined in." 
- joiner.join( - self.managers.datasets.make_query_joiner( - self.get_dataset_type(resolved_search.name), resolved_search.collection_records, fields + if union_dataset_type_name is None: + dataset_type = self.get_dataset_type(cast(str, resolved_search.name)) + assert ( + dataset_type.name not in joins_builder.fields + ), "Dataset fields have unexpectedly already been joined in." + assert ( + dataset_type.name not in joins_builder.timespans + ), "Dataset timespan has unexpectedly already been joined in." + else: + dataset_type = self.get_dataset_type(union_dataset_type_name) + assert ( + ... not in joins_builder.fields + ), "Union dataset fields have unexpectedly already been joined in." + assert ( + ... not in joins_builder.timespans + ), "Union dataset timespan has unexpectedly already been joined in." + + joins_builder.join( + self.managers.datasets.make_joins_builder( + dataset_type, + resolved_search.collection_records, + fields, + is_union=(union_dataset_type_name is not None), ) ) diff --git a/python/lsst/daf/butler/direct_query_driver/_query_plan.py b/python/lsst/daf/butler/direct_query_driver/_query_analysis.py similarity index 52% rename from python/lsst/daf/butler/direct_query_driver/_query_plan.py rename to python/lsst/daf/butler/direct_query_driver/_query_analysis.py index f78d1462e0..1898b8fbae 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_plan.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_analysis.py @@ -28,34 +28,37 @@ from __future__ import annotations __all__ = ( - "QueryPlan", - "QueryJoinsPlan", - "QueryProjectionPlan", - "QueryFindFirstPlan", + "QueryJoinsAnalysis", + "QueryFindFirstAnalysis", "ResolvedDatasetSearch", + "QueryCollectionAnalysis", ) import dataclasses -from collections.abc import Iterator -from typing import TYPE_CHECKING +from collections.abc import Iterator, Mapping +from types import EllipsisType +from typing import TYPE_CHECKING, Generic, TypeVar from ..dimensions import DimensionElement, DimensionGroup from ..queries import tree as qt +from ..registry import CollectionSummary from ..registry.interfaces import CollectionRecord if TYPE_CHECKING: from ._postprocessing import Postprocessing - from ._query_builder import QueryBuilder + from ._sql_builders import SqlSelectBuilder + +_T = TypeVar("_T") @dataclasses.dataclass -class ResolvedDatasetSearch: +class ResolvedDatasetSearch(Generic[_T]): """A struct describing a dataset search joined into a query, after resolving its collection search path. """ - name: str - """Name of the dataset type.""" + name: _T + """Name or names of the dataset type(s).""" dimensions: DimensionGroup """Dimensions of the dataset type.""" @@ -76,15 +79,17 @@ class ResolvedDatasetSearch: `~CollectionType.CALIBRATION` collection, `False` otherwise. Since only calibration datasets can be present in - `~CollectionType.CALIBRATION` collections, this also + `~CollectionType.CALIBRATION` collections, this also indicates that the + dataset type is a calibration. """ @dataclasses.dataclass -class QueryJoinsPlan: +class QueryJoinsAnalysis: """A struct describing the "joins" section of a butler query. - See `QueryPlan` and `QueryPlan.joins` for additional information. + See `DirectQueryDriver.build_query` for an overview of how queries are + transformed into SQL, and the role this object plays in that. 
""" predicate: qt.Predicate @@ -100,7 +105,7 @@ class QueryJoinsPlan: materializations: dict[qt.MaterializationKey, DimensionGroup] = dataclasses.field(default_factory=dict) """Materializations to join into the query.""" - datasets: dict[str, ResolvedDatasetSearch] = dataclasses.field(default_factory=dict) + datasets: dict[str, ResolvedDatasetSearch[str]] = dataclasses.field(default_factory=dict) """Dataset searches to join into the query.""" data_coordinate_uploads: dict[qt.DataCoordinateUploadKey, DimensionGroup] = dataclasses.field( @@ -116,7 +121,7 @@ class QueryJoinsPlan: def __post_init__(self) -> None: self.predicate.gather_required_columns(self.columns) - def iter_mandatory(self) -> Iterator[DimensionElement]: + def iter_mandatory(self, union_dataset_dimensions: DimensionGroup | None) -> Iterator[DimensionElement]: """Return an iterator over the dimension elements that must be joined into the query. @@ -124,6 +129,12 @@ def iter_mandatory(self) -> Iterator[DimensionElement]: relationships that result rows must be consistent with. They do not necessarily include all dimension keys in `columns`, since each of those can typically be included in a query in multiple different ways. + + Parameters + ---------- + union_dataset_dimensions : `DimensionGroup` or `None` + Dimensions of the union dataset types, or `None` if this is not + a union dataset query. """ for element_name in self.columns.dimensions.elements: element = self.columns.dimensions.universe[element_name] @@ -148,6 +159,11 @@ def iter_mandatory(self) -> Iterator[DimensionElement]: for dataset_spec in self.datasets.values() ): continue + if ( + union_dataset_dimensions is not None + and element.minimal_group.names <= union_dataset_dimensions.required + ): + continue # Materializations have all key columns for their dimensions. if any( element in materialization_dimensions.names @@ -158,133 +174,77 @@ def iter_mandatory(self) -> Iterator[DimensionElement]: @dataclasses.dataclass -class QueryProjectionPlan: - """A struct describing the "projection" stage of a butler query. - - This struct evaluates to `True` in boolean contexts if either - `needs_dimension_distinct` or `needs_dataset_distinct` are `True`. In - other cases the projection is effectively a no-op, because the - "joins"-stage rows are already unique. - - See `QueryPlan` and `QueryPlan.projection` for additional information. - """ - - columns: qt.ColumnSet - """The columns present in the query after the projection is applied. +class QueryFindFirstAnalysis(Generic[_T]): + """A struct describing the "find-first" stage of a butler query. - This is always a subset of `QueryJoinsPlan.columns`. + See `DirectQueryDriver.build_query` for an overview of how queries are + transformed into SQL, and the role this object plays in that. """ - needs_dimension_distinct: bool = False - """If `True`, the projection's dimensions do not include all dimensions in - the "joins" stage, and hence a SELECT DISTINCT [ON] or GROUP BY must be - used to make post-projection rows unique. - """ + search: ResolvedDatasetSearch[_T] + """Information about the dataset type or types being searched for.""" - needs_dataset_distinct: bool = False - """If `True`, the projection columns do not include collection-specific - dataset fields that were present in the "joins" stage, and hence a SELECT - DISTINCT [ON] or GROUP BY must be added to make post-projection rows - unique. 
- """ + @property + def dataset_type(self) -> _T: + """Name(s) of the dataset type(s).""" + return self.search.name def __bool__(self) -> bool: - return self.needs_dimension_distinct or self.needs_dataset_distinct + return len(self.search.collection_records) > 1 - find_first_dataset: str | None = None - """If not `None`, this is a find-first query for this dataset. - This is set even if the find-first search is trivial because there is only - one resolved collection. +@dataclasses.dataclass +class QueryCollectionAnalysis: + """A struct containing information about all of the collections that appear + in a butler query. """ + collection_records: Mapping[str, CollectionRecord] + """All collection records, keyed by collection name. -@dataclasses.dataclass -class QueryFindFirstPlan: - """A struct describing the "find-first" stage of a butler query. - - See `QueryPlan` and `QueryPlan.find_first` for additional information. + This includes CHAINED collections. """ - search: ResolvedDatasetSearch - """Information about the dataset being searched for.""" - - @property - def dataset_type(self) -> str: - """Name of the dataset type.""" - return self.search.name + calibration_dataset_types: set[str | EllipsisType] = dataclasses.field(default_factory=set) + """A set of the anmes of all calibration dataset types. - def __bool__(self) -> bool: - return len(self.search.collection_records) > 1 + If ``...`` appears in the set, the dataset type union includes at least one + calibration dataset type. + """ + summaries_by_dataset_type: dict[str | EllipsisType, list[tuple[CollectionRecord, CollectionSummary]]] = ( + dataclasses.field(default_factory=dict) + ) + """Collection records and summaries, in search order, keyed by dataset type + name. -@dataclasses.dataclass -class QueryPlan: - """A struct that aggregates information about a complete butler query. - - Notes - ----- - Butler queries are transformed into a combination of SQL and Python-side - postprocessing in three stages, with each corresponding to an attributes of - this class and a method of `DirectQueryDriver` - - - In the `joins` stage (`~DirectQueryDriver.apply_query_joins`), we define - the main SQL FROM and WHERE clauses, by joining all tables needed to - bring in any columns, or constrain the keys of its rows. - - - In the `projection` stage (`~DirectQueryDriver.apply_query_projection`), - we select only the columns needed for the query's result rows (including - columns needed only by postprocessing and ORDER BY, as well those needed - by the objects returned to users). If the result rows are not naturally - unique given what went into the query in the "joins" stage, the - projection involves a SELECT DISTINCT [ON] or GROUP BY to make them - unique, and in a few rare cases uses aggregate functions with GROUP BY. - - - In the `find_first` stage (`~DirectQueryDriver.apply_query_find_first`), - we use a window function (PARTITION BY) subquery to find only the first - dataset in the collection search path for each data ID. This stage does - nothing if there is no find-first dataset search, or if the search is - trivial because there is only one collection. - - In `DirectQueryDriver.build_query`, a `QueryPlan` instance is constructed - via `DirectQueryDriver.analyze_query`, which also returns an initial - `QueryBuilder`. 
After this point the plans are considered frozen, and the - nested plan attributes are then passed to each of the corresponding - `DirectQueryDriver` methods along with the builder, which is mutated (and - occasionally replaced) into the complete SQL/postprocessing form of the - query. + CHAINED collections are flattened out in the nested lists. Lists have been + filtered to be consistent with the dataset types in the summaries, but not + necessarily the governor dimensions in the summaries. """ - joins: QueryJoinsPlan - """Description of the "joins" stage of query construction.""" - - projection: QueryProjectionPlan - """Description of the "projection" stage of query construction.""" - find_first: QueryFindFirstPlan | None - """Description of the "find_first" stage of query construction. +@dataclasses.dataclass +class QueryTreeAnalysis: + """A struct aggregating all analysis results derived from the query tree. - This attribute is `None` if there is no find-first search at all, and - `False` in boolean contexts if the search is trivial because there is only - one collection after the collections have been resolved. + See `DirectQueryDriver.build_query` for an overview of how queries are + transformed into SQL, and the role this object plays in that. """ - final_columns: qt.ColumnSet - """The columns included in the SELECT clause of the complete SQL query - that is actually executed. - - This is a subset of `QueryProjectionPlan.columns` that differs only in - columns used by the `find_first` stage or an ORDER BY expression. + joins: QueryJoinsAnalysis + """Analysis of the "joins" stage, including all joins and columns needed by + ``tree``. Additional columns will be added to this plan later. + """ - Like all other `.queries.tree.ColumnSet` attributes, it does not include - fields added directly to `QueryBuilder.special`, which may also be added - to the SELECT clause. + union_datasets: list[ResolvedDatasetSearch[list[str]]] + """Resolved dataset searches that expand `QueryTree.any_dataset` out + into groups of dataset types with the same collection search path. """ - builder: QueryBuilder - """Under-construction SQL query associated with this plan.""" + initial_select_builder: SqlSelectBuilder + """In-progress SQL query builder, initialized with just spatial and + temporal overlaps.""" postprocessing: Postprocessing - """Struct representing post-query processing in Python, which may require - additional columns in the query results. - """ + """Struct representing post-query processing to be done in Python.""" diff --git a/python/lsst/daf/butler/direct_query_driver/_query_builder.py b/python/lsst/daf/butler/direct_query_driver/_query_builder.py index 17aecea547..9efc1b7649 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_builder.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_builder.py @@ -27,536 +27,684 @@ from __future__ import annotations -__all__ = ("QueryJoiner", "QueryBuilder") +__all__ = ( + "QueryBuilder", + "SingleSelectQueryBuilder", + "UnionQueryBuilder", + "UnionQueryBuilderTerm", +) import dataclasses -import itertools -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, ClassVar +from abc import ABC, abstractmethod +from collections.abc import Iterable, Set +from types import EllipsisType +from typing import TYPE_CHECKING, Literal, TypeVar, overload import sqlalchemy -from .. 
import ddl
-from ..nonempty_mapping import NonemptyMapping
+from ..dimensions import DimensionGroup
 from ..queries import tree as qt
-from ._postprocessing import Postprocessing
+from ..registry.interfaces import Database
+from ._query_analysis import (
+    QueryFindFirstAnalysis,
+    QueryJoinsAnalysis,
+    QueryTreeAnalysis,
+    ResolvedDatasetSearch,
+)
+from ._sql_builders import SqlColumns, SqlJoinsBuilder, SqlSelectBuilder
 
 if TYPE_CHECKING:
-    from ..registry.interfaces import Database
-    from ..timespan_database_representation import TimespanDatabaseRepresentation
-
-
-@dataclasses.dataclass
-class QueryBuilder:
-    """A struct used to represent an under-construction SQL SELECT query.
-
-    This object's methods frequently "consume" ``self``, by either returning
-    it after modification or returning related copy that may share state with
-    the original. Users should be careful never to use consumed instances, and
-    are recommended to reuse the same variable name to make that hard to do
-    accidentally.
+    from ._driver import DirectQueryDriver
+    from ._postprocessing import Postprocessing
+
+_T = TypeVar("_T")
+
+
+class QueryBuilder(ABC):
+    """An abstract base class for objects that transform query descriptions
+    into SQL and `Postprocessing`.
+
+    See `DirectQueryDriver.build_query` for an overview of query construction,
+    including the role this class plays in it.
+
+    Parameters
+    ----------
+    tree_analysis : `QueryTreeAnalysis`
+        Result of the initial analysis of most of the query description.
+        This argument is considered consumed, because nested attributes will
+        be referenced and may be modified in place in the future.
+    projection_columns : `.queries.tree.ColumnSet`
+        Columns to include in the query's "projection" stage, where a GROUP BY
+        or DISTINCT may be performed.
+    final_columns : `.queries.tree.ColumnSet`
+        Columns to include in the final query.
     """
 
-    joiner: QueryJoiner
-    """Struct representing the SQL FROM and WHERE clauses, as well as the
-    columns *available* to the query (but not necessarily in the SELECT
-    clause).
+    def __init__(
+        self,
+        tree_analysis: QueryTreeAnalysis,
+        *,
+        projection_columns: qt.ColumnSet,
+        final_columns: qt.ColumnSet,
+    ):
+        self.joins_analysis = tree_analysis.joins
+        self.postprocessing = tree_analysis.postprocessing
+        self.projection_columns = projection_columns
+        self.final_columns = final_columns
+        self.needs_dimension_distinct = False
+        self.find_first_dataset = None
+
+    joins_analysis: QueryJoinsAnalysis
+    """Description of the "joins" stage of query construction."""
+
+    projection_columns: qt.ColumnSet
+    """The columns present in the query after the projection is applied.
+
+    This is always a subset of `QueryJoinsAnalysis.columns`.
     """
 
-    columns: qt.ColumnSet
-    """Columns to include the SELECT clause.
-
-    This does not include columns required only by `postprocessing` and columns
-    in `QueryJoiner.special`, which are also always included in the SELECT
-    clause.
+    needs_dimension_distinct: bool = False
+    """If `True`, the projection's dimensions do not include all dimensions in
+    the "joins" stage, and hence a SELECT DISTINCT [ON] or GROUP BY must be
+    used to make post-projection rows unique.
     """
 
-    distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = ()
-    """A representation of a DISTINCT or DISTINCT ON clause.
+    find_first_dataset: str | EllipsisType | None = None
+    """If not `None`, this is a find-first query for this dataset.
 
-    If `True`, this represents a SELECT DISTINCT.  If a non-empty sequence,
-    this represents a SELECT DISTINCT ON. 
If `False` or an empty sequence,
-    there is no DISTINCT clause.
+    This is set even if the find-first search is trivial because there is only
+    one resolved collection.
     """
 
-    group_by: Sequence[sqlalchemy.ColumnElement[Any]] = ()
-    """A representation of a GROUP BY clause.
+    final_columns: qt.ColumnSet
+    """The columns included in the SELECT clause of the complete SQL query
+    that is actually executed.
 
-    If not-empty, a GROUP BY clause with these columns is added.  This
-    generally requires that every `sqlalchemy.ColumnElement` held in the nested
-    `joiner` that is part of `columns` must either be part of `group_by` or
-    hold an aggregate function.
-    """
+    This is a subset of `projection_columns` that differs only in
+    columns used by the `find_first` stage or an ORDER BY expression.
 
-    EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
-    """Name of the column added to a SQL SELECT clause in order to construct
-    queries that have no real columns.
+    Like all other `.queries.tree.ColumnSet` attributes, it does not include
+    fields added directly to `SqlSelectBuilder.special`, which may also be
+    added to the SELECT clause.
     """
 
-    EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean
-    """Type of the column added to a SQL SELECT clause in order to construct
-    queries that have no real columns.
+    postprocessing: Postprocessing
+    """Struct representing post-query processing in Python, which may require
+    additional columns in the query results.
     """
 
-    @classmethod
-    def handle_empty_columns(
-        cls, columns: list[sqlalchemy.sql.ColumnElement]
-    ) -> list[sqlalchemy.ColumnElement]:
-        """Handle the edge case where a SELECT statement has no columns, by
-        adding a literal column that should be ignored.
+    @abstractmethod
+    def analyze_projection(self) -> None:
+        """Analyze the "projection" stage of query construction, in which the
+        query may be nested in a GROUP BY or DISTINCT subquery in order to
+        ensure rows do not have duplicates.
+
+        This modifies the builder in place, and should be called immediately
+        after construction.
+
+        Notes
+        -----
+        Implementations should delegate to `super` to set
+        `needs_dimension_distinct`, but generally need to provide additional
+        logic to determine whether a GROUP BY or DISTINCT will be needed for
+        other reasons (e.g. duplication due to dataset searches over multiple
+        collections).
+        """
+        # The projection gets interesting if it does not have all of the
+        # dimension keys or dataset fields of the "joins" stage, because that
+        # means it needs to do a GROUP BY or DISTINCT ON to get unique rows.
+        # Subclass implementations handle the check for dataset fields.
+        if self.projection_columns.dimensions != self.joins_analysis.columns.dimensions:
+            assert self.projection_columns.dimensions.issubset(self.joins_analysis.columns.dimensions)
+            # We're going from a larger set of dimensions to a smaller set;
+            # that means we'll be doing a SELECT DISTINCT [ON] or GROUP BY.
+            self.needs_dimension_distinct = True
+
+    @abstractmethod
+    def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None:
+        """Analyze the "find first" stage of query construction, in which a
+        common table expression with a PARTITION BY window function may be
+        used to find the first dataset for each data ID and dataset type in an
+        ordered collection sequence.
+
+        This modifies the builder in place, and should be called immediately
+        after `analyze_projection`.
 
         Parameters
         ----------
-        columns : `list` [ `sqlalchemy.ColumnElement` ]
-            List of SQLAlchemy column objects. 
This may have no elements when - this method is called, and will always have at least one element - when it returns. - - Returns - ------- - columns : `list` [ `sqlalchemy.ColumnElement` ] - The same list that was passed in, after any modification. + find_first_dataset : `str` or ``...`` + Name of the dataset type that needs a find-first search. ``...`` + is used to indicate the dataset types in a union dataset query. """ - if not columns: - columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME)) - return columns + raise NotImplementedError() - def select(self, postprocessing: Postprocessing | None) -> sqlalchemy.Select: - """Transform this builder into a SQLAlchemy representation of a SELECT - query. + @abstractmethod + def apply_joins(self, driver: DirectQueryDriver) -> None: + """Translate the "joins" stage of the query to SQL. - Parameters - ---------- - postprocessing : `Postprocessing` - Struct representing post-query processing in Python, which may - require additional columns in the query results. - - Returns - ------- - select : `sqlalchemy.Select` - SQLAlchemy SELECT statement. - """ - assert not (self.distinct and self.group_by), "At most one of distinct and group_by can be set." - sql_columns: list[sqlalchemy.ColumnElement[Any]] = [] - for logical_table, field in self.columns: - name = self.columns.get_qualified_name(logical_table, field) - if field is None: - sql_columns.append(self.joiner.dimension_keys[logical_table][0].label(name)) - else: - name = self.joiner.db.name_shrinker.shrink(name) - if self.columns.is_timespan(logical_table, field): - sql_columns.extend(self.joiner.timespans[logical_table].flatten(name)) - else: - sql_columns.append(self.joiner.fields[logical_table][field].label(name)) - if postprocessing is not None: - for element in postprocessing.iter_missing(self.columns): - sql_columns.append( - self.joiner.fields[element.name]["region"].label( - self.joiner.db.name_shrinker.shrink( - self.columns.get_qualified_name(element.name, "region") - ) - ) - ) - for label, sql_column in self.joiner.special.items(): - sql_columns.append(sql_column.label(label)) - self.handle_empty_columns(sql_columns) - result = sqlalchemy.select(*sql_columns) - if self.joiner.from_clause is not None: - result = result.select_from(self.joiner.from_clause) - if self.distinct is True: - result = result.distinct() - elif self.distinct: - result = result.distinct(*self.distinct) - if self.group_by: - result = result.group_by(*self.group_by) - if self.joiner.where_terms: - result = result.where(*self.joiner.where_terms) - return result - - def join(self, other: QueryJoiner) -> QueryBuilder: - """Join tables, subqueries, and WHERE clauses from another query into - this one, in place. + This modifies the builder in place. It is the first step in the + "apply" phase, and should be called after `analyze_find_first` finishes + the "analysis" phase (if more than analysis is needed). Parameters ---------- - other : `QueryJoiner` - Object holding the FROM and WHERE clauses to add to this one. - JOIN ON clauses are generated via the dimension keys in common. - - Returns - ------- - self : `QueryBuilder` - This `QueryBuilder` instance (never a copy); returned to enable - method-chaining. + driver : `DirectQueryDriver` + Driver that invoked this builder and may be called back into for + lower-level SQL generation operations. 
""" - self.joiner.join(other) - return self + raise NotImplementedError() - def to_joiner( - self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None - ) -> QueryJoiner: - """Convert this builder into a `QueryJoiner`, nesting it in a subquery - or common table expression only if needed to apply DISTINCT or GROUP BY - clauses. + @abstractmethod + def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: + """Translate the "projection" stage of the query to SQL. - This method consumes ``self``. + This modifies the builder in place. It is the second step in the + "apply" phase, after `apply_joins`. Parameters ---------- - cte : `bool`, optional - If `True`, nest via a common table expression instead of a - subquery. - force : `bool`, optional - If `True`, nest via a subquery or common table expression even if - there is no DISTINCT or GROUP BY. - postprocessing : `Postprocessing` - Struct representing post-query processing in Python, which may - require additional columns in the query results. - - Returns - ------- - joiner : `QueryJoiner` - QueryJoiner` with at least all columns in `columns` available. - This may or may not be the `joiner` attribute of this object. + driver : `DirectQueryDriver` + Driver that invoked this builder and may be called back into for + lower-level SQL generation operations. + order_by : `~collections.abc.Iterable` [ \ + `.queries.tree.OrderExpression` ] + Column expression used to order the query rows. """ - if force or self.distinct or self.group_by: - sql_from_clause = ( - self.select(postprocessing).cte() if cte else self.select(postprocessing).subquery() - ) - return QueryJoiner(self.joiner.db, sql_from_clause).extract_columns( - self.columns, special=self.joiner.special.keys() - ) - return self.joiner + raise NotImplementedError() - def nested( - self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None - ) -> QueryBuilder: - """Convert this builder into a `QueryBuiler` that is guaranteed to have - no DISTINCT or GROUP BY, nesting it in a subquery or common table - expression only if needed to apply any current DISTINCT or GROUP BY - clauses. + @abstractmethod + def apply_find_first(self, driver: DirectQueryDriver) -> None: + """Transform the "find first" stage of the query to SQL. - This method consumes ``self``. + This modifies the builder in place. It is the third and final step in + the "apply" phase, after "apply_projection". Parameters ---------- - cte : `bool`, optional - If `True`, nest via a common table expression instead of a - subquery. - force : `bool`, optional - If `True`, nest via a subquery or common table expression even if - there is no DISTINCT or GROUP BY. - postprocessing : `Postprocessing` - Struct representing post-query processing in Python, which may - require additional columns in the query results. - - Returns - ------- - builder : `QueryBuilder` - `QueryBuilder` with at least all columns in `columns` available. - This may or may not be the `builder` attribute of this object. + driver : `DirectQueryDriver` + Driver that invoked this builder and may be called back into for + lower-level SQL generation operations. """ - return QueryBuilder( - self.to_joiner(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns - ) + raise NotImplementedError() + + @overload + def finish_select( + self, return_columns: Literal[True] = True + ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, SqlColumns]: ... 
-    def union_subquery(
-        self, others: Iterable[QueryBuilder], postprocessing: Postprocessing | None = None
-    ) -> QueryJoiner:
-        """Combine this builder with others to make a SELECT UNION subquery.
+    @overload
+    def finish_select(
+        self, return_columns: Literal[False]
+    ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, None]: ...
+
+    @abstractmethod
+    def finish_select(
+        self, return_columns: bool = True
+    ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, SqlColumns | None]:
+        """Finish translating the query into executable SQL.
 
         Parameters
         ----------
-        others : `~collections.abc.Iterable` [ `QueryBuilder` ]
-            Other query builders to union with.  Their `columns` attributes
-            must be the same as those of ``self``.
-        postprocessing : `Postprocessing`
-            Struct representing post-query processing in Python, which may
-            require additional columns in the query results.
+        return_columns : `bool`
+            If `True`, return a structure that organizes the SQLAlchemy
+            column objects available to the query.
 
         Returns
         -------
-        joiner : `QueryJoiner`
-            `QueryJoiner` with at least all columns in `columns` available.
-            This may or may not be the `joiner` attribute of this object.
+        sql_select : `sqlalchemy.Select` or `sqlalchemy.CompoundSelect`
+            A SELECT [UNION ALL] SQL query.
+        sql_columns : `SqlColumns` or `None`
+            The columns available to the query (including any available to
+            an ORDER BY clause, not just those in the SELECT clause, in
+            contexts where those are not the same).  May be `None` (but is not
+            guaranteed to be) if ``return_columns=False``.
         """
-        select0 = self.select(postprocessing)
-        other_selects = [other.select(postprocessing) for other in others]
-        return QueryJoiner(
-            self.joiner.db,
-            from_clause=select0.union(*other_selects).subquery(),
-        ).extract_columns(self.columns, postprocessing)
+        raise NotImplementedError()
 
-    def make_table_spec(self, postprocessing: Postprocessing | None) -> ddl.TableSpec:
-        """Make a specification that can be used to create a table to store
-        this query's outputs.
+    @abstractmethod
+    def finish_nested(self, cte: bool = False) -> SqlSelectBuilder:
+        """Finish translating the query into SQL that can be used as a
+        subquery.
 
         Parameters
         ----------
-        postprocessing : `Postprocessing`
-            Struct representing post-query processing in Python, which may
-            require additional columns in the query results.
+        cte : `bool`, optional
+            If `True`, nest the query in a common table expression (i.e. SQL
+            WITH statement) instead of a subquery.
 
         Returns
         -------
-        spec : `.ddl.TableSpec`
-            Table specification for this query's result columns (including
-            those from `postprocessing` and `QueryJoiner.special`).
+        select_builder : `SqlSelectBuilder`
+            A builder object that maps to a single SELECT statement.  This may
+            directly hold the original query with no subquery or CTE if that
+            query was a single SELECT with no GROUP BY or DISTINCT; in either
+            case it is guaranteed that modifying this builder's result columns
+            and transforming it into a SELECT will not change the number of
+            rows.
         """
-        assert not self.joiner.special, "special columns not supported in make_table_spec"
-        results = ddl.TableSpec(
-            [
-                self.columns.get_column_spec(logical_table, field).to_sql_spec(
-                    name_shrinker=self.joiner.db.name_shrinker
-                )
-                for logical_table, field in self.columns
-            ]
+        raise NotImplementedError()
+
+
+class SingleSelectQueryBuilder(QueryBuilder):
+    """An implementation of `QueryBuilder` for queries that are structured as
+    a single SELECT (i.e. not a union). 
+
+    See `DirectQueryDriver.build_query` for an overview of query construction,
+    including the role this class plays in it.  This builder is used for most
+    butler queries, for which `.queries.tree.QueryTree.any_dataset` is `None`.
+
+    Parameters
+    ----------
+    tree_analysis : `QueryTreeAnalysis`
+        Result of the initial analysis of most of the query description.
+        This argument is considered consumed, because nested attributes will
+        be referenced and may be modified in place in the future.
+    projection_columns : `.queries.tree.ColumnSet`
+        Columns to include in the query's "projection" stage, where a GROUP BY
+        or DISTINCT may be performed.
+    final_columns : `.queries.tree.ColumnSet`
+        Columns to include in the final query.
+    """
+
+    def __init__(
+        self,
+        tree_analysis: QueryTreeAnalysis,
+        *,
+        projection_columns: qt.ColumnSet,
+        final_columns: qt.ColumnSet,
+    ) -> None:
+        super().__init__(
+            tree_analysis=tree_analysis,
+            projection_columns=projection_columns,
+            final_columns=final_columns,
         )
-        if postprocessing:
-            for element in postprocessing.iter_missing(self.columns):
-                results.fields.add(
-                    ddl.FieldSpec.for_region(
-                        self.joiner.db.name_shrinker.shrink(
-                            self.columns.get_qualified_name(element.name, "region")
-                        )
-                    )
-                )
-        if not results.fields:
-            results.fields.add(ddl.FieldSpec(name=self.EMPTY_COLUMNS_NAME, dtype=self.EMPTY_COLUMNS_TYPE))
-        return results
+        assert not tree_analysis.union_datasets, "UnionQueryBuilder should be used instead."
+        self._select_builder = tree_analysis.initial_select_builder
+        self.find_first = None
+        self.needs_dataset_distinct = False
+
+    needs_dataset_distinct: bool = False
+    """If `True`, the projection columns do not include collection-specific
+    dataset fields that were present in the "joins" stage, and hence a SELECT
+    DISTINCT [ON] or GROUP BY must be added to make post-projection rows
+    unique.
+    """
 
+    find_first: QueryFindFirstAnalysis[str] | None = None
+    """Description of the "find_first" stage of query construction.
 
-@dataclasses.dataclass
-class QueryJoiner:
-    """A struct used to represent the FROM and WHERE clauses of an
-    under-construction SQL SELECT query.
-
-    This object's methods frequently "consume" ``self``, by either returning
-    it after modification or returning related copy that may share state with
-    the original. Users should be careful never to use consumed instances, and
-    are recommended to reuse the same variable name to make that hard to do
-    accidentally.
+    This attribute is `None` if there is no find-first search at all, and
+    `False` in boolean contexts if the search is trivial because there is only
+    one collection after the collections have been resolved.
     """
 
-    db: Database
-    """Object that abstracts over the database engine."""
+    def analyze_projection(self) -> None:
+        # Docstring inherited.
+        super().analyze_projection()
+        # See if we need to do a DISTINCT [ON] or GROUP BY to get unique rows
+        # because we have rows for datasets in multiple collections with the
+        # same data ID and dataset type.
+        for dataset_type in self.joins_analysis.columns.dataset_fields:
+            assert dataset_type is not ..., "Union dataset in non-dataset-union query."
+            if not self.projection_columns.dataset_fields[dataset_type]:
+                # The "joins"-stage query has one row for each collection for
+                # each data ID, but the projection-stage query just wants
+                # one row for each data ID. 
+ if len(self.joins_analysis.datasets[dataset_type].collection_records) > 1: + self.needs_dataset_distinct = True + break + # If there are any dataset fields being propagated through the + # projection and there is more than one collection, we need to include + # the collection_key column so we can use that as one of the DISTINCT + # or GROUP BY columns. + for dataset_type, fields_for_dataset in self.projection_columns.dataset_fields.items(): + assert dataset_type is not ..., "Union dataset in non-dataset-union query." + if len(self.joins_analysis.datasets[dataset_type].collection_records) > 1: + fields_for_dataset.add("collection_key") + + def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None: + # Docstring inherited. + assert find_first_dataset is not ..., "No dataset union in this query" + self.find_first = QueryFindFirstAnalysis(self.joins_analysis.datasets[find_first_dataset]) + # If we're doing a find-first search and there's a calibration + # collection in play, we need to make sure the rows coming out of + # the base query have only one timespan for each data ID + + # collection, and we can only do that with a GROUP BY and COUNT + # that we inspect in postprocessing. + if self.find_first.search.is_calibration_search: + self.postprocessing.check_validity_match_count = True + + def apply_joins(self, driver: DirectQueryDriver) -> None: + # Docstring inherited. + driver.apply_initial_query_joins( + self._select_builder, self.joins_analysis, union_dataset_dimensions=None + ) + driver.apply_missing_dimension_joins(self._select_builder, self.joins_analysis) + + def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: + # Docstring inherited. + driver.apply_query_projection( + self._select_builder, + self.postprocessing, + join_datasets=self.joins_analysis.datasets, + union_datasets=None, + projection_columns=self.projection_columns, + needs_dimension_distinct=self.needs_dimension_distinct, + needs_dataset_distinct=self.needs_dataset_distinct, + needs_validity_match_count=self.postprocessing.check_validity_match_count, + find_first_dataset=None if self.find_first is None else self.find_first.search.name, + order_by=order_by, + ) - from_clause: sqlalchemy.FromClause | None = None - """SQLAlchemy representation of the FROM clause. + def apply_find_first(self, driver: DirectQueryDriver) -> None: + # Docstring inherited. + if not self.find_first: + return + self._select_builder = driver.apply_query_find_first( + self._select_builder, self.postprocessing, self.find_first + ) - This is initialized to `None` but in almost all cases is immediately - replaced. - """ + # The overloads in the base class seem to keep MyPy from recognizing the + # return type as covariant. + def finish_select( # type: ignore + self, + return_columns: bool = True, + ) -> tuple[sqlalchemy.Select, SqlColumns]: + # Docstring inherited. + self._select_builder.columns = self.final_columns + return self._select_builder.select(self.postprocessing), self._select_builder.joins - where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list) - """Sequence of WHERE clause terms to be combined with AND.""" + def finish_nested(self, cte: bool = False) -> SqlSelectBuilder: + # Docstring inherited. 
+        self._select_builder.columns = self.final_columns
+        return self._select_builder.nested(cte=cte, postprocessing=self.postprocessing)
 
-    dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field(
-        default_factory=lambda: NonemptyMapping(list)
-    )
-    """Mapping of dimension keys included in the FROM clause.
 
-    Nested lists correspond to different tables that have the same dimension
-    key (which should all have equal values for all result rows).
+@dataclasses.dataclass
+class UnionQueryBuilderTerm:
+    """A helper struct that holds state for `UnionQueryBuilder` that
+    corresponds to a set of dataset types with the same post-filtering
+    collection sequence.
     """
 
-    fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field(
-        default_factory=lambda: NonemptyMapping(dict)
-    )
-    """Mapping of columns that are neither dimension keys nor timespans.
+    select_builders: list[SqlSelectBuilder]
+    """Under-construction SQL queries associated with this term, to be unioned
+    together when complete.
 
-    Inner and outer keys correspond to the "logical table" and "field" pairs
-    that result from iterating over `~.queries.tree.ColumnSet`, with the former
-    either a dimension element name or dataset type name.
+    Each select builder corresponds to a different dataset type and a single
+    SELECT; note that this means a `UnionQueryBuilderTerm` does not map 1-1
+    with a SELECT in the final UNION; it maps to a set of extremely similar
+    SELECTs that differ only in the dataset type name injected into each
+    SELECT at the end.
     """
 
-    timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict)
-    """Mapping of timespan columns.
-
-    Keys are "logical tables" - dimension element names or dataset type names.
+    datasets: ResolvedDatasetSearch[list[str]]
+    """Searches for datasets of different types to be joined into the rest of
+    the query, with the results (after projection and find-first) unioned
+    together.
+
+    The dataset types in a single `UnionQueryBuilderTerm` have the exact same
+    post-filtering collection search path, and hence the exact same query
+    plan, aside from the dataset type used to generate their dataset subquery.
+    Dataset types that have the same dimensions but do not have the same
+    post-filtering collection search path go in different
+    `UnionQueryBuilderTerm` instances, which still contribute to the same
+    UNION [ALL] query.  Dataset types with different dimensions cannot go in
+    the same SQL query at all.
     """
 
-    special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
-    """Special columns that are available from the FROM clause and
-    automatically included in the SELECT clause when this joiner is nested
-    within a `QueryBuilder`.
-
-    These columns are not part of the dimension universe and are not associated
-    with a dataset.  They are never returned to users, even if they may be
-    included in raw SQL results.
+    needs_dataset_distinct: bool = False
+    """If `True`, the projection columns do not include collection-specific
+    dataset fields that were present in the "joins" stage, and hence a SELECT
+    DISTINCT [ON] or GROUP BY must be added to make post-projection rows
+    unique.
     """
 
-    def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> QueryJoiner:
-        """Add dimension key columns from `from_clause` into `dimension_keys`. 
- - Parameters - ---------- - dimensions : `~collections.abc.Iterable` [ `str` ] - Names of dimensions to include, assuming that their names in - `sql_columns` are just the dimension names. - **kwargs : `str` - Additional dimensions to include, with the names in `sql_columns` - as keys and the actual dimension names as values. + needs_validity_match_count: bool = False + """Whether this query needs a validity match column for postprocessing + to check. - Returns - ------- - self : `QueryJoiner` - This `QueryJoiner` instance (never a copy). Provided to enable - method chaining. - """ - assert self.from_clause is not None, "Cannot extract columns with no FROM clause." - for dimension_name in dimensions: - self.dimension_keys[dimension_name].append(self.from_clause.columns[dimension_name]) - for k, v in kwargs.items(): - self.dimension_keys[v].append(self.from_clause.columns[k]) - return self - - def extract_columns( - self, - columns: qt.ColumnSet, - postprocessing: Postprocessing | None = None, - special: Iterable[str] = (), - ) -> QueryJoiner: - """Add columns from `from_clause` into `dimension_keys`. - - Parameters - ---------- - columns : `.queries.tree.ColumnSet` - Columns to include, assuming that - `.queries.tree.ColumnSet.get_qualified_name` corresponds to the - name used in `sql_columns` (after name shrinking). - postprocessing : `Postprocessing`, optional - Postprocessing object whose needed columns should also be included. - special : `~collections.abc.Iterable` [ `str` ], optional - Additional special columns to extract. - - Returns - ------- - self : `QueryJoiner` - This `QueryJoiner` instance (never a copy). Provided to enable - method chaining. - """ - assert self.from_clause is not None, "Cannot extract columns with no FROM clause." - for logical_table, field in columns: - name = columns.get_qualified_name(logical_table, field) - if field is None: - self.dimension_keys[logical_table].append(self.from_clause.columns[name]) - else: - name = self.db.name_shrinker.shrink(name) - if columns.is_timespan(logical_table, field): - self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns( - self.from_clause.columns, name - ) - else: - self.fields[logical_table][field] = self.from_clause.columns[name] - if postprocessing is not None: - for element in postprocessing.iter_missing(columns): - self.fields[element.name]["region"] = self.from_clause.columns[ - self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region")) - ] - if postprocessing.check_validity_match_count: - self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.from_clause.columns[ - postprocessing.VALIDITY_MATCH_COUNT - ] - for name in special: - self.special[name] = self.from_clause.columns[name] - return self - - def join(self, other: QueryJoiner) -> QueryJoiner: - """Combine this `QueryJoiner` with another via an INNER JOIN on - dimension keys. - - This method consumes ``self``. + This can be `False` even if `Postprocessing.check_validity_match_count` is + `True`, indicating that some other term in the union needs the column and + hence this term just needs a dummy column (with "1" as the value). + """ - Parameters - ---------- - other : `QueryJoiner` - Other joiner to combine with this one. + find_first: QueryFindFirstAnalysis[list[str]] | None = None + """Description of the "find_first" stage of query construction. 
-
-        Returns
-        -------
-        joined : `QueryJoiner`
-            A `QueryJoiner` with all columns present in either operand, with
-            its `from_clause` representing a SQL INNER JOIN where the dimension
-            key columns common to both operands are constrained to be equal.
-            If either operand does not have `from_clause`, the other's is used.
-            The `where_terms` of the two operands are concatenated,
-            representing a logical AND (with no attempt at deduplication).
-        """
-        join_on: list[sqlalchemy.ColumnElement] = []
-        for dimension_name in other.dimension_keys.keys():
-            if dimension_name in self.dimension_keys:
-                for column1, column2 in itertools.product(
-                    self.dimension_keys[dimension_name], other.dimension_keys[dimension_name]
-                ):
-                    join_on.append(column1 == column2)
-            self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name])
-        if self.from_clause is None:
-            self.from_clause = other.from_clause
-        elif other.from_clause is not None:
-            join_on_sql: sqlalchemy.ColumnElement[bool]
-            match len(join_on):
-                case 0:
-                    join_on_sql = sqlalchemy.true()
-                case 1:
-                    (join_on_sql,) = join_on
-                case _:
-                    join_on_sql = sqlalchemy.and_(*join_on)
-            self.from_clause = self.from_clause.join(other.from_clause, onclause=join_on_sql)
-        for logical_table, fields in other.fields.items():
-            self.fields[logical_table].update(fields)
-        self.timespans.update(other.timespans)
-        self.special.update(other.special)
-        self.where_terms += other.where_terms
-        return self
-
-    def where(self, *args: sqlalchemy.ColumnElement[bool]) -> QueryJoiner:
-        """Add a WHERE clause term.
+    This attribute is `None` if there is no find-first search at all, and
+    `False` in boolean contexts if the search is trivial because there is only
+    one collection after the collections have been resolved.
+    """
 
-        Parameters
-        ----------
-        *args : `sqlalchemy.ColumnElement`
-            SQL boolean column expressions to be combined with AND.
 
-        Returns
-        -------
-        self : `QueryJoiner`
-            This `QueryJoiner` instance (never a copy).  Provided to enable
-            method chaining.
-        """
-        self.where_terms.extend(args)
-        return self
+class UnionQueryBuilder(QueryBuilder):
+    """An implementation of `QueryBuilder` for queries that are structured as
+    a UNION ALL with one SELECT for each dataset type.
+
+    See `DirectQueryDriver.build_query` for an overview of query construction,
+    including the role this class plays in it.  This builder is used for
+    special butler queries where `.queries.tree.QueryTree.any_dataset` is not
+    `None`.
+
+    Parameters
+    ----------
+    tree_analysis : `QueryTreeAnalysis`
+        Result of the initial analysis of most of the query description.
+        This argument is considered consumed, because nested attributes will
+        be referenced and may be modified in place in the future.
+    projection_columns : `.queries.tree.ColumnSet`
+        Columns to include in the query's "projection" stage, where a GROUP BY
+        or DISTINCT may be performed.
+    final_columns : `.queries.tree.ColumnSet`
+        Columns to include in the final query.
+    union_dataset_dimensions : `DimensionGroup`
+        Dimensions of the dataset types that comprise the union.
+
+    Notes
+    -----
+    `UnionQueryBuilder` can be in one of two states:
+
+    - During the "analysis" phase and at the beginning of the "apply" phase,
+      it has a single initial `SqlSelectBuilder`, because all union terms are
+      identical at this stage.  The `UnionQueryBuilderTerm.select_builders`
+      lists are empty.
+    - Within `apply_joins`, this single `SqlSelectBuilder` is copied to
+      populate the per-dataset type `SqlSelectBuilder` instances in the
+      `UnionQueryBuilderTerm.select_builders` lists. 
+    """
+
+    def __init__(
+        self,
+        tree_analysis: QueryTreeAnalysis,
+        *,
+        projection_columns: qt.ColumnSet,
+        final_columns: qt.ColumnSet,
+        union_dataset_dimensions: DimensionGroup,
+    ):
+        super().__init__(
+            tree_analysis=tree_analysis,
+            projection_columns=projection_columns,
+            final_columns=final_columns,
+        )
+        self._initial_select_builder: SqlSelectBuilder | None = tree_analysis.initial_select_builder
+        self.union_dataset_dimensions = union_dataset_dimensions
+        self.union_terms = [
+            UnionQueryBuilderTerm(select_builders=[], datasets=datasets)
+            for datasets in tree_analysis.union_datasets
+        ]
+
+    @property
+    def db(self) -> Database:
+        """The database object associated with the nested select builders."""
+        if self._initial_select_builder is not None:
+            return self._initial_select_builder.joins.db
+        else:
+            return self.union_terms[0].select_builders[0].joins.db
+
+    @property
+    def special(self) -> Set[str]:
+        """The special columns associated with the nested select builders."""
+        if self._initial_select_builder is not None:
+            return self._initial_select_builder.joins.special.keys()
+        else:
+            return self.union_terms[0].select_builders[0].joins.special.keys()
+
+    def analyze_projection(self) -> None:
+        # Docstring inherited.
+        super().analyze_projection()
+        # See if we need to do a DISTINCT [ON] or GROUP BY to get unique rows
+        # because we have rows for datasets in multiple collections with the
+        # same data ID and dataset type.
+        for dataset_type in self.joins_analysis.columns.dataset_fields:
+            if not self.projection_columns.dataset_fields[dataset_type]:
+                if dataset_type is ...:
+                    for union_term in self.union_terms:
+                        if len(union_term.datasets.collection_records) > 1:
+                            union_term.needs_dataset_distinct = True
+                elif len(self.joins_analysis.datasets[dataset_type].collection_records) > 1:
+                    # If a dataset being joined into all union terms has
+                    # multiple collections, needs_dataset_distinct is true
+                    # for all union terms and we can exit the loop early.
+                    for union_term in self.union_terms:
+                        union_term.needs_dataset_distinct = True
+                    break
+        # If there are any dataset fields being propagated through the
+        # projection and there is more than one collection, we need to include
+        # the collection_key column so we can use that as one of the DISTINCT
+        # or GROUP BY columns.
+        for dataset_type, fields_for_dataset in self.projection_columns.dataset_fields.items():
+            if dataset_type is ...:
+                for union_term in self.union_terms:
+                    # If there is more than one collection for one union term,
+                    # we need to add collection_key to all of them to keep the
+                    # SELECT columns uniform.
+                    if len(union_term.datasets.collection_records) > 1:
+                        fields_for_dataset.add("collection_key")
+                        break
+            elif len(self.joins_analysis.datasets[dataset_type].collection_records) > 1:
+                fields_for_dataset.add("collection_key")
+
+    def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None:
+        # Docstring inherited. 
+ if find_first_dataset is ...: + for union_term in self.union_terms: + union_term.find_first = QueryFindFirstAnalysis(union_term.datasets) + # If we're doing a find-first search and there's a calibration + # collection in play, we need to make sure the rows coming out + # of the base query have only one timespan for each data ID + + # collection, and we can only do that with a GROUP BY and COUNT + # that we inspect in postprocessing. + # Because the postprocessing is applied to the full query, all + # union terms will need this column, even if only one populates + # it with a nontrivial value. + if union_term.find_first.search.is_calibration_search: + self.postprocessing.check_validity_match_count = True + union_term.needs_validity_match_count = True + else: + # The query system machinery should actually be able to handle this + # case without too much difficulty (we just put the same + # find_first plan in each union term), but the result doesn't seem + # like it'd be useful, so it's better not to have to maintain that + # logic branch. + raise NotImplementedError( + f"Additional dataset search {find_first_dataset!r} can only be joined into a " + "union dataset query as a constraint in data IDs, not as a find-first result." + ) - This method consumes ``self``. + def apply_joins(self, driver: DirectQueryDriver) -> None: + # Docstring inherited. + assert self._initial_select_builder is not None + driver.apply_initial_query_joins( + self._initial_select_builder, self.joins_analysis, self.union_dataset_dimensions + ) + # Join in the union datasets. This makes one copy of the initial + # select builder for each dataset type, and hence from here on we have + # to repeat whatever we do to all select builders. + for union_term in self.union_terms: + for dataset_type_name in union_term.datasets.name: + select_builder = self._initial_select_builder.copy() + driver.join_dataset_search( + select_builder.joins, + union_term.datasets, + self.joins_analysis.columns.dataset_fields[...], + union_dataset_type_name=dataset_type_name, + ) + union_term.select_builders.append(select_builder) + self._initial_select_builder = None + for union_term in self.union_terms: + for select_builder in union_term.select_builders: + driver.apply_missing_dimension_joins(select_builder, self.joins_analysis) + + def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: + # Docstring inherited. + for union_term in self.union_terms: + for builder in union_term.select_builders: + driver.apply_query_projection( + builder, + self.postprocessing, + join_datasets=self.joins_analysis.datasets, + union_datasets=union_term.datasets, + projection_columns=self.projection_columns, + needs_dimension_distinct=self.needs_dimension_distinct, + needs_dataset_distinct=union_term.needs_dataset_distinct, + needs_validity_match_count=union_term.needs_validity_match_count, + find_first_dataset=None if union_term.find_first is None else ..., + order_by=order_by, + ) - Parameters - ---------- - columns : `~.queries.tree.ColumnSet` - Regular columns to include in the SELECT clause. - distinct : `bool` or `~collections.abc.Sequence` [ \ - `sqlalchemy.ColumnElement` ], optional - Specification of the DISTINCT clause (see `QueryBuilder.distinct`). - group_by : `~collections.abc.Sequence` [ \ - `sqlalchemy.ColumnElement` ], optional - Specification of the GROUP BY clause (see `QueryBuilder.group_by`). + def apply_find_first(self, driver: DirectQueryDriver) -> None: + # Docstring inherited. 
+ for union_term in self.union_terms: + if not union_term.find_first: + continue + union_term.select_builders = [ + driver.apply_query_find_first(builder, self.postprocessing, union_term.find_first) + for builder in union_term.select_builders + ] - Returns - ------- - builder : `QueryBuilder` - New query builder. - """ - return QueryBuilder( - self, - columns, - distinct=distinct if type(distinct) is bool else tuple(distinct), - group_by=tuple(group_by), + @overload + def finish_select( + self, return_columns: Literal[True] = True + ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, SqlColumns]: ... + + @overload + def finish_select( + self, return_columns: Literal[False] + ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, None]: ... + + def finish_select( + self, return_columns: bool = True + ) -> tuple[sqlalchemy.CompoundSelect | sqlalchemy.Select, SqlColumns | None]: + # Docstring inherited. + terms: list[sqlalchemy.Select] = [] + for union_term in self.union_terms: + for dataset_type_name, select_builder in zip( + union_term.datasets.name, union_term.select_builders + ): + select_builder.columns = self.final_columns + select_builder.joins.special["_DATASET_TYPE_NAME"] = sqlalchemy.literal(dataset_type_name) + terms.append(select_builder.select(self.postprocessing)) + sql: sqlalchemy.Select | sqlalchemy.CompoundSelect = ( + sqlalchemy.union_all(*terms) if len(terms) > 1 else terms[0] ) + columns: SqlColumns | None = None + if return_columns: + columns = SqlColumns( + db=self.db, + ) + columns.extract_columns( + self.final_columns, + self.postprocessing, + self.special, + column_collection=sql.selected_columns, + ) + return sql, columns + + def finish_nested(self, cte: bool = False) -> SqlSelectBuilder: + # Docstring inherited. + sql_select, _ = self.finish_select(return_columns=False) + from_clause = sql_select.cte() if cte else sql_select.subquery() + joins_builder = SqlJoinsBuilder( + db=self.db, + from_clause=from_clause, + ).extract_columns(self.final_columns, self.postprocessing) + return SqlSelectBuilder(joins_builder, columns=self.final_columns) diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py new file mode 100644 index 0000000000..c180a58e4e --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py @@ -0,0 +1,678 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("SqlJoinsBuilder", "SqlSelectBuilder", "SqlColumns", "make_table_spec")
+
+import dataclasses
+import itertools
+from collections.abc import Iterable, Sequence
+from types import EllipsisType
+from typing import TYPE_CHECKING, Any, ClassVar, Self
+
+import sqlalchemy
+
+from .. import ddl
+from ..nonempty_mapping import NonemptyMapping
+from ..queries import tree as qt
+from ._postprocessing import Postprocessing
+
+if TYPE_CHECKING:
+    from ..registry.interfaces import Database
+    from ..timespan_database_representation import TimespanDatabaseRepresentation
+
+
+@dataclasses.dataclass
+class SqlSelectBuilder:
+    """A struct used to represent an under-construction SQL SELECT query.
+
+    This object's methods frequently "consume" ``self``, by either returning
+    it after modification or returning a related copy that may share state
+    with the original. Users should be careful never to use consumed
+    instances, and are recommended to reuse the same variable name to make
+    that hard to do accidentally.
+    """
+
+    joins: SqlJoinsBuilder
+    """Struct representing the SQL FROM and WHERE clauses, as well as the
+    columns *available* to the query (but not necessarily in the SELECT
+    clause).
+    """
+
+    columns: qt.ColumnSet
+    """Columns to include in the SELECT clause.
+
+    This does not include columns required only by `Postprocessing` or the
+    columns in `SqlJoinsBuilder.special`, though both are always included in
+    the SELECT clause as well.
+    """
+
+    distinct: bool | tuple[sqlalchemy.ColumnElement[Any], ...] = ()
+    """A representation of a DISTINCT or DISTINCT ON clause.
+
+    If `True`, this represents a SELECT DISTINCT. If a non-empty sequence,
+    this represents a SELECT DISTINCT ON. If `False` or an empty sequence,
+    there is no DISTINCT clause.
+    """
+
+    group_by: tuple[sqlalchemy.ColumnElement[Any], ...] = ()
+    """A representation of a GROUP BY clause.
+
+    If not empty, a GROUP BY clause with these columns is added. This
+    generally requires that every `sqlalchemy.ColumnElement` held by the
+    nested `joins` builder that is part of `columns` either be part of
+    `group_by` or hold an aggregate function.
+    """
+
+    EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
+    """Name of the column added to a SQL SELECT clause in order to construct
+    queries that have no real columns.
+    """
+
+    EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean
+    """Type of the column added to a SQL SELECT clause in order to construct
+    queries that have no real columns.
+    """
+
+    def copy(self) -> SqlSelectBuilder:
+        """Return a copy that can be safely mutated without affecting the
+        original.
+        """
+        return dataclasses.replace(self, joins=self.joins.copy(), columns=self.columns.copy())
+
+    @classmethod
+    def handle_empty_columns(
+        cls, columns: list[sqlalchemy.sql.ColumnElement]
+    ) -> list[sqlalchemy.ColumnElement]:
+        """Handle the edge case where a SELECT statement has no columns, by
+        adding a literal column that should be ignored.
+
+        Parameters
+        ----------
+        columns : `list` [ `sqlalchemy.ColumnElement` ]
+            List of SQLAlchemy column objects. This may have no elements when
+            this method is called, and will always have at least one element
+            when it returns.
+
+        Returns
+        -------
+        columns : `list` [ `sqlalchemy.ColumnElement` ]
+            The same list that was passed in, after any modification.
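
The `copy` method above leans on `dataclasses.replace`, explicitly passing fresh copies of the mutable members; any field left to `replace` alone would be shared with the original. A self-contained illustration of the pitfall this avoids:

```python
import dataclasses


@dataclasses.dataclass
class Builder:
    where_terms: list[str] = dataclasses.field(default_factory=list)

    def copy(self) -> "Builder":
        # dataclasses.replace alone would reuse the same list object;
        # replacing the field with its own copy keeps instances independent.
        return dataclasses.replace(self, where_terms=self.where_terms.copy())


a = Builder(where_terms=["x > 1"])
b = a.copy()
b.where_terms.append("y < 2")
assert a.where_terms == ["x > 1"]  # the original is unaffected
```
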
+ """ + if not columns: + columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME)) + return columns + + def select(self, postprocessing: Postprocessing | None) -> sqlalchemy.Select: + """Transform this builder into a SQLAlchemy representation of a SELECT + query. + + Parameters + ---------- + postprocessing : `Postprocessing` + Struct representing post-query processing in Python, which may + require additional columns in the query results. + + Returns + ------- + select : `sqlalchemy.Select` + SQLAlchemy SELECT statement. + """ + assert not (self.distinct and self.group_by), "At most one of distinct and group_by can be set." + sql_columns: list[sqlalchemy.ColumnElement[Any]] = [] + for logical_table, field in self.columns: + name = self.columns.get_qualified_name(logical_table, field) + if field is None: + assert logical_table is not ... + sql_columns.append(self.joins.dimension_keys[logical_table][0].label(name)) + else: + name = self.joins.db.name_shrinker.shrink(name) + if self.columns.is_timespan(logical_table, field): + sql_columns.extend(self.joins.timespans[logical_table].flatten(name)) + else: + sql_columns.append(self.joins.fields[logical_table][field].label(name)) + if postprocessing is not None: + for element in postprocessing.iter_missing(self.columns): + sql_columns.append( + self.joins.fields[element.name]["region"].label( + self.joins.db.name_shrinker.shrink( + self.columns.get_qualified_name(element.name, "region") + ) + ) + ) + for label, sql_column in self.joins.special.items(): + sql_columns.append(sql_column.label(label)) + self.handle_empty_columns(sql_columns) + result = sqlalchemy.select(*sql_columns) + if self.joins.from_clause is not None: + result = result.select_from(self.joins.from_clause) + if self.distinct is True: + result = result.distinct() + elif self.distinct: + result = result.distinct(*self.distinct) + if self.group_by: + result = result.group_by(*self.group_by) + if self.joins.where_terms: + result = result.where(*self.joins.where_terms) + return result + + def join(self, other: SqlJoinsBuilder) -> SqlSelectBuilder: + """Join tables, subqueries, and WHERE clauses from another query into + this one, in place. + + Parameters + ---------- + other : `SqlJoinsBuilder` + Object holding the FROM and WHERE clauses to add to this one. + JOIN ON clauses are generated via the dimension keys in common. + + Returns + ------- + self : `SqlSelectBuilder` + This `SqlSelectBuilder` instance (never a copy); returned to enable + method-chaining. + """ + self.joins.join(other) + return self + + def into_from_builder( + self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None + ) -> SqlJoinsBuilder: + """Convert this builder into a `SqlJoinsBuilder`, nesting it in a + subquery or common table expression only if needed to apply DISTINCT or + GROUP BY clauses. + + This method consumes ``self``. + + Parameters + ---------- + cte : `bool`, optional + If `True`, nest via a common table expression instead of a + subquery. + force : `bool`, optional + If `True`, nest via a subquery or common table expression even if + there is no DISTINCT or GROUP BY. + postprocessing : `Postprocessing` + Struct representing post-query processing in Python, which may + require additional columns in the query results. + + Returns + ------- + joins_builder : `SqlJoinsBuilder` + SqlJoinsBuilder` with at least all columns in `columns` available. + This may or may not be the `joins` attribute of this object. 
+ """ + if force or self.distinct or self.group_by: + sql_from_clause = ( + self.select(postprocessing).cte() if cte else self.select(postprocessing).subquery() + ) + return SqlJoinsBuilder(db=self.joins.db, from_clause=sql_from_clause).extract_columns( + self.columns, special=self.joins.special.keys() + ) + return self.joins + + def nested( + self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None + ) -> SqlSelectBuilder: + """Convert this builder into a `SqlSelectBuilder` that is guaranteed to + have no DISTINCT or GROUP BY, nesting it in a subquery or common table + expression only if needed to apply any current DISTINCT or GROUP BY + clauses. + + This method consumes ``self``. + + Parameters + ---------- + cte : `bool`, optional + If `True`, nest via a common table expression instead of a + subquery. + force : `bool`, optional + If `True`, nest via a subquery or common table expression even if + there is no DISTINCT or GROUP BY. + postprocessing : `Postprocessing` + Struct representing post-query processing in Python, which may + require additional columns in the query results. + + Returns + ------- + builder : `SqlSelectBuilder` + `SqlSelectBuilder` with at least all columns in `columns` + available. This may or may not be the `builder` attribute of this + object. + """ + return SqlSelectBuilder( + self.into_from_builder(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns + ) + + def union_subquery( + self, others: Iterable[SqlSelectBuilder], postprocessing: Postprocessing | None = None + ) -> SqlJoinsBuilder: + """Combine this builder with others to make a SELECT UNION subquery. + + Parameters + ---------- + others : `~collections.abc.Iterable` [ `SqlSelectBuilder` ] + Other query builders to union with. Their `columns` attributes + must be the same as those of ``self``. + postprocessing : `Postprocessing` + Struct representing post-query processing in Python, which may + require additional columns in the query results. + + Returns + ------- + joins_builder : `SqlJoinsBuilder` + `SqlJoinsBuilder` with at least all columns in `columns` available. + This may or may not be the `joins` attribute of this object. + """ + select0 = self.select(postprocessing) + other_selects = [other.select(postprocessing) for other in others] + return SqlJoinsBuilder( + db=self.joins.db, + from_clause=select0.union(*other_selects).subquery(), + ).extract_columns(self.columns, postprocessing) + + +@dataclasses.dataclass(kw_only=True) +class SqlColumns: + """A struct that holds SQLAlchemy columns objects for a query, categorized + by type. + + This class mostly serves as a base class for `SqlJoinsBuilder`, but unlike + `SqlJoinsBuilder` it is capable of representing columns in a compound + SELECT (i.e. UNION or UNION ALL) clause, not just a FROM clause. + """ + + db: Database + """Object that abstracts over the database engine.""" + + dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(list) + ) + """Mapping of dimension keys included in the FROM clause. + + Nested lists correspond to different tables that have the same dimension + key (which should all have equal values for all result rows). + """ + + fields: NonemptyMapping[str | EllipsisType, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(dict) + ) + """Mapping of columns that are neither dimension keys nor timespans. 
+
+    Outer and inner keys correspond to the "logical table" and "field" pairs
+    that result from iterating over `~.queries.tree.ColumnSet`, with the
+    former either a dimension element name or a dataset type name.
+    """
+
+    timespans: dict[str | EllipsisType, TimespanDatabaseRepresentation] = dataclasses.field(
+        default_factory=dict
+    )
+    """Mapping of timespan columns.
+
+    Keys are "logical tables" - dimension element names or dataset type names.
+    """
+
+    special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
+    """Special columns that are available from the FROM clause and
+    automatically included in the SELECT clause when this join builder is
+    nested within a `SqlSelectBuilder`.
+
+    These columns are not part of the dimension universe and are not
+    associated with a dataset. They are never returned to users, even though
+    they may be included in raw SQL results.
+    """
+
+    def extract_dimensions(
+        self, dimensions: Iterable[str], *, column_collection: sqlalchemy.ColumnCollection, **kwargs: str
+    ) -> Self:
+        """Add dimension key columns from ``column_collection`` into
+        `dimension_keys`.
+
+        Parameters
+        ----------
+        dimensions : `~collections.abc.Iterable` [ `str` ]
+            Names of dimensions to include, assuming that their names in
+            ``column_collection`` are just the dimension names.
+        column_collection : `sqlalchemy.ColumnCollection`
+            SQLAlchemy column collection to extract from.
+        **kwargs : `str`
+            Additional dimensions to include, with the names in
+            ``column_collection`` as keys and the actual dimension names as
+            values.
+
+        Returns
+        -------
+        self : `SqlColumns`
+            This `SqlColumns` instance (never a copy). Provided to enable
+            method chaining.
+        """
+        for dimension_name in dimensions:
+            self.dimension_keys[dimension_name].append(column_collection[dimension_name])
+        for k, v in kwargs.items():
+            self.dimension_keys[v].append(column_collection[k])
+        return self
+
+    def extract_columns(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing | None = None,
+        special: Iterable[str] = (),
+        *,
+        column_collection: sqlalchemy.ColumnCollection,
+    ) -> Self:
+        """Add columns from ``column_collection`` into this object's column
+        attributes.
+
+        Parameters
+        ----------
+        columns : `.queries.tree.ColumnSet`
+            Columns to include, assuming that
+            `.queries.tree.ColumnSet.get_qualified_name` corresponds to the
+            name used in ``column_collection`` (after name shrinking).
+        postprocessing : `Postprocessing`, optional
+            Postprocessing object whose needed columns should also be
+            included.
+        special : `~collections.abc.Iterable` [ `str` ], optional
+            Additional special columns to extract.
+        column_collection : `sqlalchemy.ColumnCollection`
+            SQLAlchemy column collection to extract from.
+
+        Returns
+        -------
+        self : `SqlColumns`
+            This `SqlColumns` instance (never a copy). Provided to enable
+            method chaining.
+        """
+        for logical_table, field in columns:
+            name = columns.get_qualified_name(logical_table, field)
+            if field is None:
+                assert logical_table is not ...
+                self.dimension_keys[logical_table].append(column_collection[name])
+            else:
+                name = self.db.name_shrinker.shrink(name)
+                if columns.is_timespan(logical_table, field):
+                    self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
+                        column_collection, name
+                    )
+                else:
+                    self.fields[logical_table][field] = column_collection[name]
+        if postprocessing is not None:
+            for element in postprocessing.iter_missing(columns):
+                self.fields[element.name]["region"] = column_collection[
+                    self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
+                ]
+            if postprocessing.check_validity_match_count:
+                self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[
+                    postprocessing.VALIDITY_MATCH_COUNT
+                ]
+        for name in special:
+            self.special[name] = column_collection[name]
+        return self
+
+
+@dataclasses.dataclass(kw_only=True)
+class SqlJoinsBuilder(SqlColumns):
+    """A struct used to represent the FROM and WHERE clauses of an
+    under-construction SQL SELECT query.
+
+    This object's methods frequently "consume" ``self``, by either returning
+    it after modification or returning a related copy that may share state
+    with the original. Users should be careful never to use consumed
+    instances, and are recommended to reuse the same variable name to make
+    that hard to do accidentally.
+    """
+
+    from_clause: sqlalchemy.FromClause | None = None
+    """SQLAlchemy representation of the FROM clause.
+
+    This is initialized to `None` but in almost all cases is immediately
+    replaced.
+    """
+
+    where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list)
+    """Sequence of WHERE clause terms to be combined with AND."""
+
+    def copy(self) -> SqlJoinsBuilder:
+        """Return a copy that can be safely mutated without affecting the
+        original.
+        """
+        return dataclasses.replace(
+            self,
+            where_terms=self.where_terms.copy(),
+            dimension_keys=self.dimension_keys.copy(),
+            fields=self.fields.copy(),
+            timespans=self.timespans.copy(),
+            special=self.special.copy(),
+        )
+
+    def extract_dimensions(
+        self,
+        dimensions: Iterable[str],
+        *,
+        column_collection: sqlalchemy.ColumnCollection | None = None,
+        **kwargs: str,
+    ) -> Self:
+        """Add dimension key columns from `from_clause` into `dimension_keys`.
+
+        Parameters
+        ----------
+        dimensions : `~collections.abc.Iterable` [ `str` ]
+            Names of dimensions to include, assuming that their names in
+            ``column_collection`` are just the dimension names.
+        column_collection : `sqlalchemy.ColumnCollection`, optional
+            SQLAlchemy column collection to extract from. Defaults to
+            ``self.from_clause.columns``.
+        **kwargs : `str`
+            Additional dimensions to include, with the names in
+            ``column_collection`` as keys and the actual dimension names as
+            values.
+
+        Returns
+        -------
+        self : `SqlJoinsBuilder`
+            This `SqlJoinsBuilder` instance (never a copy). Provided to
+            enable method chaining.
+        """
+        if column_collection is None:
+            assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
+            column_collection = self.from_clause.columns
+        return super().extract_dimensions(dimensions, column_collection=column_collection, **kwargs)
+
+    def extract_columns(
+        self,
+        columns: qt.ColumnSet,
+        postprocessing: Postprocessing | None = None,
+        special: Iterable[str] = (),
+        *,
+        column_collection: sqlalchemy.ColumnCollection | None = None,
+    ) -> Self:
+        """Add columns from `from_clause` into this object's column
+        attributes.
+
+        Parameters
+        ----------
+        columns : `.queries.tree.ColumnSet`
+            Columns to include, assuming that
+            `.queries.tree.ColumnSet.get_qualified_name` corresponds to the
+            name used in ``column_collection`` (after name shrinking).
+        postprocessing : `Postprocessing`, optional
+            Postprocessing object whose needed columns should also be
+            included.
+        special : `~collections.abc.Iterable` [ `str` ], optional
+            Additional special columns to extract.
+        column_collection : `sqlalchemy.ColumnCollection`, optional
+            SQLAlchemy column collection to extract from. Defaults to
+            ``self.from_clause.columns``.
+
+        Returns
+        -------
+        self : `SqlJoinsBuilder`
+            This `SqlJoinsBuilder` instance (never a copy). Provided to
+            enable method chaining.
+        """
+        if column_collection is None:
+            assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
+            column_collection = self.from_clause.columns
+        return super().extract_columns(columns, postprocessing, special, column_collection=column_collection)
+
+    def join(self, other: SqlJoinsBuilder) -> SqlJoinsBuilder:
+        """Combine this `SqlJoinsBuilder` with another via an INNER JOIN on
+        dimension keys.
+
+        This method consumes ``self``.
+
+        Parameters
+        ----------
+        other : `SqlJoinsBuilder`
+            Other join builder to combine with this one.
+
+        Returns
+        -------
+        joined : `SqlJoinsBuilder`
+            A `SqlJoinsBuilder` with all columns present in either operand,
+            with its `from_clause` representing a SQL INNER JOIN where the
+            dimension key columns common to both operands are constrained to
+            be equal. If either operand does not have a `from_clause`, the
+            other's is used. The `where_terms` of the two operands are
+            concatenated, representing a logical AND (with no attempt at
+            deduplication).
+        """
+        join_on: list[sqlalchemy.ColumnElement] = []
+        for dimension_name in other.dimension_keys.keys():
+            if dimension_name in self.dimension_keys:
+                for column1, column2 in itertools.product(
+                    self.dimension_keys[dimension_name], other.dimension_keys[dimension_name]
+                ):
+                    join_on.append(column1 == column2)
+            self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name])
+        if self.from_clause is None:
+            self.from_clause = other.from_clause
+        elif other.from_clause is not None:
+            join_on_sql: sqlalchemy.ColumnElement[bool]
+            match len(join_on):
+                case 0:
+                    join_on_sql = sqlalchemy.true()
+                case 1:
+                    (join_on_sql,) = join_on
+                case _:
+                    join_on_sql = sqlalchemy.and_(*join_on)
+            self.from_clause = self.from_clause.join(other.from_clause, onclause=join_on_sql)
+        for logical_table, fields in other.fields.items():
+            self.fields[logical_table].update(fields)
+        self.timespans.update(other.timespans)
+        self.special.update(other.special)
+        self.where_terms += other.where_terms
+        return self
+
+    def where(self, *args: sqlalchemy.ColumnElement[bool]) -> SqlJoinsBuilder:
+        """Add a WHERE clause term.
+
+        Parameters
+        ----------
+        *args : `sqlalchemy.ColumnElement`
+            SQL boolean column expressions to be combined with AND.
+
+        Returns
+        -------
+        self : `SqlJoinsBuilder`
+            This `SqlJoinsBuilder` instance (never a copy). Provided to
+            enable method chaining.
+        """
+        self.where_terms.extend(args)
+        return self
+
+    def to_select_builder(
+        self,
+        columns: qt.ColumnSet,
+        distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = (),
+        group_by: Sequence[sqlalchemy.ColumnElement[Any]] = (),
+    ) -> SqlSelectBuilder:
+        """Convert this join builder into a `SqlSelectBuilder` by providing
+        SELECT clause columns and optional DISTINCT or GROUP BY clauses.
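
The `join` method above derives the ON clause from whichever dimension keys the two operands share, falling back to TRUE (an unconstrained join) when there are none. The same logic in a self-contained SQLAlchemy sketch:

```python
import itertools

import sqlalchemy

metadata = sqlalchemy.MetaData()
a = sqlalchemy.Table("a", metadata, sqlalchemy.Column("visit", sqlalchemy.Integer))
b = sqlalchemy.Table("b", metadata, sqlalchemy.Column("visit", sqlalchemy.Integer))

# Dimension-key columns recorded by each operand, as in SqlColumns.
self_keys = {"visit": [a.c.visit]}
other_keys = {"visit": [b.c.visit]}

join_on = [
    c1 == c2
    for name in other_keys
    if name in self_keys
    for c1, c2 in itertools.product(self_keys[name], other_keys[name])
]
match len(join_on):
    case 0:
        onclause = sqlalchemy.true()  # no shared keys: unconstrained join
    case 1:
        (onclause,) = join_on
    case _:
        onclause = sqlalchemy.and_(*join_on)
print(sqlalchemy.select(a.c.visit).select_from(a.join(b, onclause=onclause)))
```
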
+
+        This method consumes ``self``.
+
+        Parameters
+        ----------
+        columns : `~.queries.tree.ColumnSet`
+            Regular columns to include in the SELECT clause.
+        distinct : `bool` or `~collections.abc.Sequence` [ \
+                `sqlalchemy.ColumnElement` ], optional
+            Specification of the DISTINCT clause (see
+            `SqlSelectBuilder.distinct`).
+        group_by : `~collections.abc.Sequence` [ \
+                `sqlalchemy.ColumnElement` ], optional
+            Specification of the GROUP BY clause (see
+            `SqlSelectBuilder.group_by`).
+
+        Returns
+        -------
+        builder : `SqlSelectBuilder`
+            New query builder.
+        """
+        return SqlSelectBuilder(
+            self,
+            columns,
+            distinct=distinct if type(distinct) is bool else tuple(distinct),
+            group_by=tuple(group_by),
+        )
+
+
+def make_table_spec(
+    columns: qt.ColumnSet, db: Database, postprocessing: Postprocessing | None
+) -> ddl.TableSpec:
+    """Make a specification that can be used to create a table to store
+    this query's outputs.
+
+    Parameters
+    ----------
+    columns : `lsst.daf.butler.queries.tree.ColumnSet`
+        Columns to include in the table.
+    db : `Database`
+        Database engine and connection abstraction.
+    postprocessing : `Postprocessing`
+        Struct representing post-query processing in Python, which may
+        require additional columns in the query results.
+
+    Returns
+    -------
+    spec : `.ddl.TableSpec`
+        Table specification for this query's result columns (including
+        those from `postprocessing` and `SqlJoinsBuilder.special`).
+    """
+    results = ddl.TableSpec(
+        [
+            columns.get_column_spec(logical_table, field).to_sql_spec(name_shrinker=db.name_shrinker)
+            for logical_table, field in columns
+        ]
+    )
+    if postprocessing:
+        for element in postprocessing.iter_missing(columns):
+            results.fields.add(
+                ddl.FieldSpec.for_region(
+                    db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
+                )
+            )
+    if not results.fields:
+        results.fields.add(
+            ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
+        )
+    return results
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
index 45666b360e..634679824f 100644
--- a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
@@ -43,7 +43,7 @@
 if TYPE_CHECKING:
     from ._driver import DirectQueryDriver
-    from ._query_builder import QueryJoiner
+    from ._sql_builders import SqlColumns
 
 
 class SqlColumnVisitor(
@@ -57,16 +57,16 @@ class SqlColumnVisitor(
 
     Parameters
     ----------
-    joiner : `QueryJoiner`
-        `QueryJoiner` that provides SQL columns for column-reference
+    columns : `SqlColumns`
+        `SqlColumns` that provides SQL columns for column-reference
         expressions.
     driver : `QueryDriver`
         Driver used to construct nested queries for "in query" predicates.
     """
 
-    def __init__(self, joiner: QueryJoiner, driver: DirectQueryDriver):
+    def __init__(self, columns: SqlColumns, driver: DirectQueryDriver):
         self._driver = driver
-        self._joiner = joiner
+        self._columns = columns
 
     def visit_literal(
         self, expression: qt.ColumnLiteral
@@ -82,23 +82,23 @@ def visit_dimension_key_reference(
         self, expression: qt.DimensionKeyReference
     ) -> sqlalchemy.ColumnElement[int | str]:
         # Docstring inherited.
- return self._joiner.dimension_keys[expression.dimension.name][0] + return self._columns.dimension_keys[expression.dimension.name][0] def visit_dimension_field_reference( self, expression: qt.DimensionFieldReference ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: # Docstring inherited. if expression.column_type == "timespan": - return self._joiner.timespans[expression.element.name] - return self._joiner.fields[expression.element.name][expression.field] + return self._columns.timespans[expression.element.name] + return self._columns.fields[expression.element.name][expression.field] def visit_dataset_field_reference( self, expression: qt.DatasetFieldReference ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: # Docstring inherited. if expression.column_type == "timespan": - return self._joiner.timespans[expression.dataset_type] - return self._joiner.fields[expression.dataset_type][expression.field] + return self._columns.timespans[expression.dataset_type] + return self._columns.fields[expression.dataset_type][expression.field] def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]: # Docstring inherited. @@ -235,16 +235,16 @@ def visit_in_query_tree( # Docstring inherited. columns = qt.ColumnSet(self._driver.universe.empty) column.gather_required_columns(columns) - plan = self._driver.build_query(query_tree, columns) - builder = plan.builder - if plan.postprocessing: + builder = self._driver.build_query(query_tree, columns) + if builder.postprocessing: raise NotImplementedError( "Right-hand side subquery in IN expression would require postprocessing." ) - subquery_visitor = SqlColumnVisitor(builder.joiner, self._driver) - builder.joiner.special["_MEMBER"] = subquery_visitor.expect_scalar(column) - builder.columns = qt.ColumnSet(self._driver.universe.empty) - subquery_select = builder.select(plan.postprocessing) + select_builder = builder.finish_nested() + subquery_visitor = SqlColumnVisitor(select_builder.joins, self._driver) + select_builder.joins.special["_MEMBER"] = subquery_visitor.expect_scalar(column) + select_builder.columns = qt.ColumnSet(self._driver.universe.empty) + subquery_select = select_builder.select(postprocessing=None) sql_member = self.expect_scalar(member) return sql_member.in_(subquery_select) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index f656f724e0..47fe6f4556 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -6,6 +6,7 @@ import datetime import logging from collections.abc import Iterable, Mapping, Sequence, Set +from types import EllipsisType from typing import TYPE_CHECKING, Any, ClassVar, cast import astropy.time @@ -22,7 +23,7 @@ from ...._exceptions_legacy import DatasetTypeError from ...._timespan import Timespan from ....dimensions import DataCoordinate, DimensionGroup, DimensionUniverse -from ....direct_query_driver import QueryBuilder, QueryJoiner # new query system, server+direct only +from ....direct_query_driver import SqlJoinsBuilder, SqlSelectBuilder # new query system, server+direct only from ....queries import tree as qt # new query system, both clients + server from ..._caching_context import CachingContext, GenericCachingContext from ..._collection_summary import CollectionSummary @@ -1366,9 +1367,13 @@ def _finish_single_relation( ) return leaf - def 
make_query_joiner( - self, dataset_type: DatasetType, collections: Sequence[CollectionRecord], fields: Set[str] - ) -> QueryJoiner: + def make_joins_builder( + self, + dataset_type: DatasetType, + collections: Sequence[CollectionRecord], + fields: Set[str], + is_union: bool = False, + ) -> SqlJoinsBuilder: if (storage := self._find_storage(dataset_type.name)) is None: raise MissingDatasetTypeError(f"Dataset type {dataset_type.name!r} has not been registered.") # This method largely mimics `make_relation`, but it uses the new query @@ -1408,8 +1413,9 @@ def make_query_joiner( # FOREIGN KEY (and its index) are defined only on dataset_id. columns = qt.ColumnSet(dataset_type.dimensions) columns.drop_implied_dimension_keys() - columns.dataset_fields[dataset_type.name].update(fields) - tags_builder: QueryBuilder | None = None + fields_key: str | EllipsisType = ... if is_union else dataset_type.name + columns.dataset_fields[fields_key].update(fields) + tags_builder: SqlSelectBuilder | None = None if collection_types != {CollectionType.CALIBRATION}: # We'll need a subquery for the tags table if any of the given # collections are not a CALIBRATION collection. This intentionally @@ -1417,33 +1423,37 @@ def make_query_joiner( # create a dummy subquery that we know will fail. # We give the table an alias because it might appear multiple times # in the same query, for different dataset types. - tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)) + tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)).alias( + f"{dataset_type.name}_tags{'_union' if is_union else ''}" + ) tags_builder = self._finish_query_builder( storage, - QueryJoiner(self._db, tags_table.alias(f"{dataset_type.name}_tags")).to_builder(columns), + SqlJoinsBuilder(db=self._db, from_clause=tags_table).to_select_builder(columns), [record for record in collections if record.type is not CollectionType.CALIBRATION], fields, + fields_key, ) if "timespan" in fields: - tags_builder.joiner.timespans[dataset_type.name] = ( - self._db.getTimespanRepresentation().fromLiteral(None) + tags_builder.joins.timespans[fields_key] = self._db.getTimespanRepresentation().fromLiteral( + None ) - calibs_builder: QueryBuilder | None = None + calibs_builder: SqlSelectBuilder | None = None if CollectionType.CALIBRATION in collection_types: # If at least one collection is a CALIBRATION collection, we'll # need a subquery for the calibs table, and could include the # timespan as a result or constraint. calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)).alias( - f"{dataset_type.name}_calibs" + f"{dataset_type.name}_calibs{'_union' if is_union else ''}" ) calibs_builder = self._finish_query_builder( storage, - QueryJoiner(self._db, calibs_table).to_builder(columns), + SqlJoinsBuilder(db=self._db, from_clause=calibs_table).to_select_builder(columns), [record for record in collections if record.type is CollectionType.CALIBRATION], fields, + fields_key, ) if "timespan" in fields: - calibs_builder.joiner.timespans[dataset_type.name] = ( + calibs_builder.joins.timespans[fields_key] = ( self._db.getTimespanRepresentation().from_columns(calibs_table.columns) ) @@ -1455,39 +1465,40 @@ def make_query_joiner( # Need a UNION subquery. 
return tags_builder.union_subquery([calibs_builder])
             else:
-                return tags_builder.to_joiner(postprocessing=None)
+                return tags_builder.into_from_builder(postprocessing=None)
         elif calibs_builder is not None:
-            return calibs_builder.to_joiner(postprocessing=None)
+            return calibs_builder.into_from_builder(postprocessing=None)
         else:
             raise AssertionError("Branch should be unreachable.")
 
     def _finish_query_builder(
         self,
         storage: _DatasetRecordStorage,
-        sql_projection: QueryBuilder,
+        sql_projection: SqlSelectBuilder,
         collections: Sequence[CollectionRecord],
         fields: Set[str],
-    ) -> QueryBuilder:
+        fields_key: str | EllipsisType,
+    ) -> SqlSelectBuilder:
         # This method plays the same role as _finish_single_relation in the new
         # query system. It is called exactly one or two times by
         # make_joins_builder, just as _finish_single_relation is called exactly
         # one or two times by make_relation. See make_joins_builder comments for
         # what's different.
-        assert sql_projection.joiner.from_clause is not None
+        assert sql_projection.joins.from_clause is not None
         run_collections_only = all(record.type is CollectionType.RUN for record in collections)
-        sql_projection.joiner.where(
-            sql_projection.joiner.from_clause.c.dataset_type_id == storage.dataset_type_id
+        sql_projection.joins.where(
+            sql_projection.joins.from_clause.c.dataset_type_id == storage.dataset_type_id
         )
-        dataset_id_col = sql_projection.joiner.from_clause.c.dataset_id
-        collection_col = sql_projection.joiner.from_clause.c[self._collections.getCollectionForeignKeyName()]
-        fields_provided = sql_projection.joiner.fields[storage.dataset_type.name]
+        dataset_id_col = sql_projection.joins.from_clause.c.dataset_id
+        collection_col = sql_projection.joins.from_clause.c[self._collections.getCollectionForeignKeyName()]
+        fields_provided = sql_projection.joins.fields[fields_key]
         # We always constrain and optionally retrieve the collection(s) via the
         # tags/calibs table.
         if "collection_key" in fields:
-            sql_projection.joiner.fields[storage.dataset_type.name]["collection_key"] = collection_col
+            sql_projection.joins.fields[fields_key]["collection_key"] = collection_col
         if len(collections) == 1:
             only_collection_record = collections[0]
-            sql_projection.joiner.where(collection_col == only_collection_record.key)
+            sql_projection.joins.where(collection_col == only_collection_record.key)
             if "collection" in fields:
                 fields_provided["collection"] = sqlalchemy.literal(only_collection_record.name).cast(
                     # This cast is necessary to ensure that Postgres knows the
@@ -1497,11 +1508,11 @@
                 )
         elif not collections:
-            sql_projection.joiner.where(sqlalchemy.literal(False))
+            sql_projection.joins.where(sqlalchemy.literal(False))
             if "collection" in fields:
                 fields_provided["collection"] = sqlalchemy.literal("NO COLLECTIONS")
         else:
-            sql_projection.joiner.where(collection_col.in_([collection.key for collection in collections]))
+            sql_projection.joins.where(collection_col.in_([collection.key for collection in collections]))
             if "collection" in fields:
                 # Avoid a join to the collection table to get the name by using
                 # a CASE statement. The SQL will be a bit more verbose but
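
The CASE trick in the comment above avoids joining the collection table just to recover names; the keys already present in the tags/calibs rows are mapped directly. A sketch with hypothetical (key, name) pairs standing in for `CollectionRecord` objects:

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table("tags", metadata, sqlalchemy.Column("collection_id", sqlalchemy.Integer))

records = [(1, "DM-12345/run1"), (2, "DM-12345/run2")]  # hypothetical
collection_name = sqlalchemy.case(
    {key: name for key, name in records},
    value=tags.c.collection_id,
).label("collection")
print(sqlalchemy.select(tags.c.collection_id, collection_name))
```
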
-        sql_projection.joiner.extract_dimensions(storage.dataset_type.dimensions.required)
+        sql_projection.joins.extract_dimensions(storage.dataset_type.dimensions.required)
         # We can always get the dataset_id from the tags/calibs table, even if we
         # could also get it from the 'static' dataset table.
         if "dataset_id" in fields:
@@ -1558,25 +1569,25 @@
             need_static_table = True
         if need_static_table:
             # If we need the static table, join it in via dataset_id. We don't
-            # use QueryJoiner.join because we're joining on dataset ID, not
+            # use SqlJoinsBuilder.join because we're joining on dataset ID, not
             # dimensions.
-            sql_projection.joiner.from_clause = sql_projection.joiner.from_clause.join(
+            sql_projection.joins.from_clause = sql_projection.joins.from_clause.join(
                 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.c.id)
             )
             # Also constrain dataset_type_id in static table in case that helps
             # generate a better plan. We could also include this in the JOIN ON
             # clause, but my guess is that that's a good idea IFF it's in the
             # foreign key, and right now it isn't.
-            sql_projection.joiner.where(self._static.dataset.c.dataset_type_id == storage.dataset_type_id)
+            sql_projection.joins.where(self._static.dataset.c.dataset_type_id == storage.dataset_type_id)
         if need_collection_table:
             # Join the collection table to look up the RUN collection name
             # associated with the dataset.
             (
                 fields_provided["run"],
-                sql_projection.joiner.from_clause,
+                sql_projection.joins.from_clause,
             ) = self._collections.lookup_name_sql(
                 self._static.dataset.c[self._run_key_column],
-                sql_projection.joiner.from_clause,
+                sql_projection.joins.from_clause,
             )
 
         sql_projection.distinct = (
diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py
index 28a2ca7d6b..41d618b4c2 100644
--- a/python/lsst/daf/butler/registry/dimensions/static.py
+++ b/python/lsst/daf/butler/registry/dimensions/static.py
@@ -59,8 +59,8 @@
 from ...dimensions.record_cache import DimensionRecordCache
 from ...direct_query_driver import (  # Future query system (direct,server).
     Postprocessing,
-    QueryBuilder,
-    QueryJoiner,
+    SqlJoinsBuilder,
+    SqlSelectBuilder,
 )
 from ...queries import tree as qt  # Future query system (direct,client,server)
 from ...queries.overlaps import OverlapsVisitor
@@ -458,18 +458,18 @@
         )
         return overlaps, needs_refinement
 
-    def make_query_joiner(self, element: DimensionElement, fields: Set[str]) -> QueryJoiner:
+    def make_joins_builder(self, element: DimensionElement, fields: Set[str]) -> SqlJoinsBuilder:
         if element.implied_union_target is not None:
             assert not fields, "Dimensions with implied-union storage never have fields."
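
The static-table join above deliberately bypasses `SqlJoinsBuilder.join`, because the ON clause is a dataset-ID equality rather than shared dimension keys. In plain SQLAlchemy terms (toy tables, not the real registry schema):

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table("tags", metadata, sqlalchemy.Column("dataset_id", sqlalchemy.String))
dataset = sqlalchemy.Table(
    "dataset",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String),
    sqlalchemy.Column("dataset_type_id", sqlalchemy.Integer),
)

# Join on the dataset ID explicitly instead of letting dimension-key
# matching generate the ON clause.
from_clause = tags.join(dataset, onclause=tags.c.dataset_id == dataset.c.id)
stmt = (
    sqlalchemy.select(tags.c.dataset_id)
    .select_from(from_clause)
    # Redundant with the join, but may help the planner (see comment above).
    .where(dataset.c.dataset_type_id == 42)  # hypothetical type ID
)
print(stmt)
```
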
- return QueryBuilder( - self.make_query_joiner(element.implied_union_target, fields), + return SqlSelectBuilder( + self.make_joins_builder(element.implied_union_target, fields), columns=qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True, - ).to_joiner(postprocessing=None) + ).into_from_builder(postprocessing=None) if not element.has_own_table: raise NotImplementedError(f"Cannot join dimension element {element} with no table.") table = self._tables[element.name] - result = QueryJoiner(self._db, table) + result = SqlJoinsBuilder(db=self._db, from_clause=table) for dimension_name, column_name in zip(element.required.names, element.schema.required.names): result.dimension_keys[dimension_name].append(table.columns[column_name]) result.extract_dimensions(element.implied.names) @@ -488,7 +488,7 @@ def process_query_overlaps( predicate: qt.Predicate, join_operands: Iterable[DimensionGroup], calibration_dataset_types: Set[str | EllipsisType], - ) -> tuple[qt.Predicate, QueryBuilder, Postprocessing]: + ) -> tuple[qt.Predicate, SqlSelectBuilder, Postprocessing]: overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor( self._db, dimensions, calibration_dataset_types, self._overlap_tables ) @@ -1028,7 +1028,7 @@ def __init__( overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]], ): super().__init__(dimensions, calibration_dataset_types) - self.builder: QueryBuilder = QueryJoiner(db).to_builder(qt.ColumnSet(dimensions)) + self.builder: SqlSelectBuilder = SqlJoinsBuilder(db=db).to_select_builder(qt.ColumnSet(dimensions)) self.postprocessing = Postprocessing() self.common_skypix = dimensions.universe.commonSkyPix self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables @@ -1074,16 +1074,16 @@ def visit_spatial_constraint( # table that embeds the SQL WHERE clause we want and then # projects out that dimension (with SELECT DISTINCT, to # avoid introducing duplicate rows into the larger query). - joiner = self._make_common_skypix_overlap_joiner(element) + joins_builder = self._make_common_skypix_overlap_joins_builder(element) sql_where_or: list[sqlalchemy.ColumnElement[bool]] = [] - sql_skypix_col = joiner.dimension_keys[self.common_skypix.name][0] + sql_skypix_col = joins_builder.dimension_keys[self.common_skypix.name][0] for begin, end in self.common_skypix.pixelization.envelope(region): sql_where_or.append(sqlalchemy.and_(sql_skypix_col >= begin, sql_skypix_col < end)) - joiner.where(sqlalchemy.or_(*sql_where_or)) + joins_builder.where(sqlalchemy.or_(*sql_where_or)) self.builder.join( - joiner.to_builder( + joins_builder.to_select_builder( qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True - ).to_joiner(postprocessing=self.postprocessing) + ).into_from_builder(postprocessing=None) ) # Short circuit here since the SQL WHERE clause has already # been embedded in the subquery. @@ -1142,13 +1142,13 @@ def visit_spatial_join( # index column with SELECT DISTINCT. 
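
`visit_spatial_constraint` above turns a region into envelope ranges over the common skypix index and ORs the ranges together in the WHERE clause. The range-to-predicate step, reduced to a runnable sketch with hypothetical index ranges:

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
overlap = sqlalchemy.Table(
    "skypix_overlap", metadata, sqlalchemy.Column("skypix_index", sqlalchemy.BigInteger)
)

envelope = [(1024, 1040), (2048, 2050)]  # hypothetical pixelization ranges
sql_where_or = [
    sqlalchemy.and_(overlap.c.skypix_index >= begin, overlap.c.skypix_index < end)
    for begin, end in envelope
]
stmt = (
    sqlalchemy.select(overlap.c.skypix_index)
    .distinct()  # projects out duplicates, as in the builder code
    .where(sqlalchemy.or_(*sql_where_or))
)
print(stmt)
```
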
self.builder.join( - self._make_common_skypix_overlap_joiner(a) - .join(self._make_common_skypix_overlap_joiner(b)) - .to_builder( + self._make_common_skypix_overlap_joins_builder(a) + .join(self._make_common_skypix_overlap_joins_builder(b)) + .to_select_builder( qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(), distinct=True, ) - .to_joiner(postprocessing=self.postprocessing) + .into_from_builder(postprocessing=None) ) # In both cases we add postprocessing to check that the regions # really do overlap, since overlapping the same common skypix @@ -1160,13 +1160,13 @@ def visit_spatial_join( def _join_common_skypix_overlap(self, element: DatabaseDimensionElement) -> None: if element not in self.common_skypix_overlaps_done: - self.builder.join(self._make_common_skypix_overlap_joiner(element)) + self.builder.join(self._make_common_skypix_overlap_joins_builder(element)) self.common_skypix_overlaps_done.add(element) - def _make_common_skypix_overlap_joiner(self, element: DatabaseDimensionElement) -> QueryJoiner: + def _make_common_skypix_overlap_joins_builder(self, element: DatabaseDimensionElement) -> SqlJoinsBuilder: _, overlap_table = self.overlap_tables[element.name] return ( - QueryJoiner(self.builder.joiner.db, overlap_table) + SqlJoinsBuilder(db=self.builder.joins.db, from_clause=overlap_table) .extract_dimensions(element.required.names, skypix_index=self.common_skypix.name) .where( sqlalchemy.and_( diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 2ee7d7b265..2ace8b61ad 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -45,7 +45,7 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: - from ...direct_query_driver import QueryJoiner # new query system, server+direct only + from ...direct_query_driver import SqlJoinsBuilder # new query system, server+direct only from .._caching_context import CachingContext from .._collection_summary import CollectionSummary from ..queries import SqlQueryContext # old registry query system @@ -627,11 +627,15 @@ def make_relation( raise NotImplementedError() @abstractmethod - def make_query_joiner( - self, dataset_type: DatasetType, collections: Sequence[CollectionRecord], fields: Set[str] - ) -> QueryJoiner: - """Make a `..direct_query_driver.QueryJoiner` that represents a search - for datasets of this type. + def make_joins_builder( + self, + dataset_type: DatasetType, + collections: Sequence[CollectionRecord], + fields: Set[str], + is_union: bool = False, + ) -> SqlJoinsBuilder: + """Make a `..direct_query_driver.SqlJoinsBuilder` that represents a + search for datasets of this type. Parameters ---------- @@ -641,7 +645,7 @@ def make_query_joiner( Collections to search, in order, after filtering out collections with no datasets of this type via collection summaries. fields : `~collections.abc.Set` [ `str` ] - Names of fields to make available in the joiner. Options include: + Names of fields to make available in the builder. Options include: - ``dataset_id`` (UUID) - ``run`` (collection name, `str`) @@ -652,10 +656,15 @@ def make_query_joiner( Dimension keys for the dataset type's required dimensions are always included. + is_union : `bool`, optional + If `True`, this search is being joined in as part of one term in + a union over all dataset types. 
This causes fields to be added to
+            the builder via the special ``...`` instead of the dataset type
+            name.
 
         Returns
         -------
-        joiner : `..direct_query_driver.QueryJoiner`
+        builder : `..direct_query_driver.SqlJoinsBuilder`
             A query-construction object representing a table or subquery.
         """
         raise NotImplementedError()
diff --git a/python/lsst/daf/butler/registry/interfaces/_dimensions.py b/python/lsst/daf/butler/registry/interfaces/_dimensions.py
index 81b0854d5b..dda6d9b6d6 100644
--- a/python/lsst/daf/butler/registry/interfaces/_dimensions.py
+++ b/python/lsst/daf/butler/registry/interfaces/_dimensions.py
@@ -49,8 +49,8 @@
 if TYPE_CHECKING:
     from ...direct_query_driver import (  # Future query system (direct,server).
         Postprocessing,
-        QueryBuilder,
-        QueryJoiner,
+        SqlJoinsBuilder,
+        SqlSelectBuilder,
     )
     from ...queries.tree import Predicate  # Future query system (direct,client,server).
     from .. import queries  # Old Registry.query* system.
@@ -368,8 +368,8 @@
         raise NotImplementedError()
 
     @abstractmethod
-    def make_query_joiner(self, element: DimensionElement, fields: Set[str]) -> QueryJoiner:
-        """Make a `..direct_query_driver.QueryJoiner` that represents a
+    def make_joins_builder(self, element: DimensionElement, fields: Set[str]) -> SqlJoinsBuilder:
+        """Make a `..direct_query_driver.SqlJoinsBuilder` that represents a
         dimension element table.
 
         Parameters
@@ -377,14 +377,14 @@
         element : `DimensionElement`
             Dimension element the table corresponds to.
         fields : `~collections.abc.Set` [ `str` ]
-            Names of fields to make available in the joiner. These can be any
+            Names of fields to make available in the builder. These can be any
             metadata or alternate key field in the element's schema, including
             the special ``region`` and ``timespan`` fields. Dimension keys in
             the element's schema are always included.
 
         Returns
         -------
-        joiner : `..direct_query_driver.QueryJoiner`
+        builder : `..direct_query_driver.SqlJoinsBuilder`
             A query-construction object representing a table or subquery. This
             is guaranteed to have rows that are unique over dimension keys and
             all possible key values for this dimension, so joining in a
@@ -403,7 +403,7 @@
         predicate: Predicate,
         join_operands: Iterable[DimensionGroup],
         calibration_dataset_types: Set[str | EllipsisType],
-    ) -> tuple[Predicate, QueryBuilder, Postprocessing]:
+    ) -> tuple[Predicate, SqlSelectBuilder, Postprocessing]:
         """Process a query's WHERE predicate and dimensions to handle spatial
         and temporal overlaps.
 
@@ -432,7 +432,7 @@
             behavior of the filter while possibly rewriting overlap expressions
             that have been partially moved into ``builder`` as some combination
             of new nested predicates, joins, and postprocessing.
-        builder : `..direct_query_driver.QueryBuilder`
+        builder : `..direct_query_driver.SqlSelectBuilder`
             A query-construction helper object that includes any initial joins
             and postprocessing needed to handle overlap expressions extracted
             from the original predicate.