Drop SerializableEllipsis in favor of single-element ANY_DATASET enum.
Turns out the Pydantic adapters were buggy in a hard-to-fix way, and
while the enum is a little more verbose, it's more self-describing.
TallJimbo committed Oct 25, 2024
1 parent 56cbe4a commit 1438993
Showing 16 changed files with 123 additions and 134 deletions.
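The enum itself is not among the hunks shown below (call sites import it from `queries.tree` as `qt.AnyDatasetType` / `qt.ANY_DATASET`), so the following is a minimal sketch of the single-element-enum pattern this commit adopts, not the project's actual definition; the model name and field are hypothetical.

import enum

import pydantic


class AnyDatasetType(enum.Enum):
    # Single-member enum standing in for the old ``...`` sentinel; unlike
    # EllipsisType, Pydantic understands enums natively.
    ANY_DATASET = "ANY_DATASET"


ANY_DATASET = AnyDatasetType.ANY_DATASET


class ExampleModel(pydantic.BaseModel):
    # Hypothetical model. Listing the enum before str mirrors the ordering
    # concern in the removed SerializableEllipsis docstring below: both
    # union members accept the string "ANY_DATASET", and the leftmost
    # successful match wins.
    dataset_type: AnyDatasetType | str


serialized = ExampleModel(dataset_type=ANY_DATASET).model_dump_json()
print(serialized)  # {"dataset_type":"ANY_DATASET"}
restored = ExampleModel.model_validate_json(serialized)
print(restored.dataset_type is ANY_DATASET)  # True: enum members are singletons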
32 changes: 18 additions & 14 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -37,7 +37,6 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Set
from contextlib import ExitStack
from types import EllipsisType
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

import sqlalchemy
@@ -95,7 +94,7 @@

_LOG = logging.getLogger(__name__)

_T = TypeVar("_T", bound=str | EllipsisType)
_T = TypeVar("_T", bound=str | qt.AnyDatasetType)


class DirectQueryDriver(QueryDriver):
@@ -445,7 +444,7 @@ def build_query(
final_columns: qt.ColumnSet,
*,
order_by: Iterable[qt.OrderExpression] = (),
find_first_dataset: str | EllipsisType | None = None,
find_first_dataset: str | qt.AnyDatasetType | None = None,
analyze_only: bool = False,
) -> QueryBuilder:
"""Convert a query description into a nearly-complete builder object
@@ -460,11 +459,11 @@ def build_query(
order_by : `~collections.abc.Iterable` [ \
`.queries.tree.OrderExpression` ], optional
Column expressions to sort by.
find_first_dataset : `str`, ``...``, or `None`, optional
find_first_dataset : `str`, ``ANY_DATASET``, or `None`, optional
Name of a dataset type for which only one result row for each data
ID should be returned, with the collections searched in order.
``...`` is used to represent the search for all dataset types with
a particular set of dimensions in ``tree.any_dataset``.
``ANY_DATASET`` is used to represent the search for all dataset
types with a particular set of dimensions in ``tree.any_dataset``.
analyze_only : `bool`, optional
If `True`, perform the initial analysis needed to construct the
builder, but do not call methods that build its SQL form. This can
@@ -694,7 +693,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
# Gather the filtered collection search path for each union dataset
# type.
collections_by_dataset_type = defaultdict[str, list[str]](list)
for collection_record, collection_summary in collection_analysis.summaries_by_dataset_type[...]:
for collection_record, collection_summary in collection_analysis.summaries_by_dataset_type[
qt.ANY_DATASET
]:
for dataset_type in collection_summary.dataset_types:
if dataset_type.dimensions == tree.any_dataset.dimensions:
collections_by_dataset_type[dataset_type.name].append(collection_record.name)
@@ -837,7 +838,7 @@ def apply_query_projection(
needs_dimension_distinct: bool,
needs_dataset_distinct: bool,
needs_validity_match_count: bool,
find_first_dataset: str | EllipsisType | None,
find_first_dataset: str | qt.AnyDatasetType | None,
order_by: Iterable[qt.OrderExpression],
) -> None:
"""Apply the "projection" stage of query construction to a single
@@ -905,7 +906,7 @@
# fields: if rows are unique over the 'unique_key' fields, then they're
# automatically unique over these 'derived_fields'. We just remember
# these as pairs of (logical_table, field) for now.
derived_fields: list[tuple[str | EllipsisType, str]] = []
derived_fields: list[tuple[str | qt.AnyDatasetType, str]] = []

# There are two reasons we might need an aggregate function:
# - to make sure temporal constraints and joins have resulted in at
Expand Down Expand Up @@ -952,7 +953,7 @@ def apply_query_projection(
# it's a find-first query.
for dataset_type, fields_for_dataset in projection_columns.dataset_fields.items():
dataset_search: ResolvedDatasetSearch[Any]
if dataset_type is ...:
if dataset_type is qt.ANY_DATASET:
assert union_datasets is not None
dataset_search = union_datasets
else:
@@ -1219,7 +1220,7 @@ def _resolve_dataset_search(
Parameters
----------
dataset_type_name : `str` or ``...``
dataset_type_name : `str` or ``ANY_DATASET``
Name of the dataset being searched for.
dataset_search : `.queries.tree.DatasetSearch`
Struct holding the dimensions and original collection search path.
@@ -1242,7 +1243,10 @@
result.messages.append("No datasets can be found because collection list is empty.")
for collection_record, collection_summary in collections:
rejected: bool = False
if result.name is not ... and result.name not in collection_summary.dataset_types.names:
if (
result.name is not qt.ANY_DATASET
and result.name not in collection_summary.dataset_types.names
):
result.messages.append(
f"No datasets of type {result.name!r} in collection {collection_record.name!r}."
)
@@ -1346,10 +1350,10 @@ def join_dataset_search(
else:
dataset_type = self.get_dataset_type(union_dataset_type_name)
assert (
... not in joins_builder.fields
qt.ANY_DATASET not in joins_builder.fields
), "Union dataset fields have unexpectedly already been joined in."
assert (
... not in joins_builder.timespans
qt.ANY_DATASET not in joins_builder.timespans
), "Union dataset timespan has unexpectedly already been joined in."

joins_builder.join(
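A note on the `is qt.ANY_DATASET` / `is not qt.ANY_DATASET` tests above: enum members, like `...`, are process-wide singletons, so the identity checks carry over unchanged. A self-contained illustration (the enum here is a stand-in for the real one in `queries.tree`):

import enum


class AnyDatasetType(enum.Enum):
    ANY_DATASET = "ANY_DATASET"


ANY_DATASET = AnyDatasetType.ANY_DATASET


def describe(dataset_type: str | AnyDatasetType) -> str:
    # Identity test, mirroring the branches in _driver.py above.
    if dataset_type is ANY_DATASET:
        return "the union of all dataset types with matching dimensions"
    return f"the single dataset type {dataset_type!r}"


print(describe(ANY_DATASET))  # the union of all dataset types with matching dimensions
print(describe("raw"))  # the single dataset type 'raw'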
13 changes: 6 additions & 7 deletions python/lsst/daf/butler/direct_query_driver/_query_analysis.py
@@ -36,7 +36,6 @@

import dataclasses
from collections.abc import Iterator, Mapping
from types import EllipsisType
from typing import TYPE_CHECKING, Generic, TypeVar

from ..dimensions import DimensionElement, DimensionGroup
@@ -205,16 +204,16 @@ class QueryCollectionAnalysis:
This includes CHAINED collections.
"""

calibration_dataset_types: set[str | EllipsisType] = dataclasses.field(default_factory=set)
calibration_dataset_types: set[str | qt.AnyDatasetType] = dataclasses.field(default_factory=set)
"""A set of the anmes of all calibration dataset types.
If ``...`` appears in the set, the dataset type union includes at least one
calibration dataset type.
If ``ANY_DATASET`` appears in the set, the dataset type union includes at
least one calibration dataset type.
"""

summaries_by_dataset_type: dict[str | EllipsisType, list[tuple[CollectionRecord, CollectionSummary]]] = (
dataclasses.field(default_factory=dict)
)
summaries_by_dataset_type: dict[
str | qt.AnyDatasetType, list[tuple[CollectionRecord, CollectionSummary]]
] = dataclasses.field(default_factory=dict)
"""Collection records and summaries, in search order, keyed by dataset type
name.
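The annotations above also rely on the sentinel being hashable, since it serves as a `set` element and a `dict` key; enum members satisfy that just as `...` did. A minimal illustration with made-up data:

import enum


class AnyDatasetType(enum.Enum):
    ANY_DATASET = "ANY_DATASET"


ANY_DATASET = AnyDatasetType.ANY_DATASET

# Collection summaries keyed by dataset type name, with the sentinel
# standing in for the dataset type union (values are hypothetical names).
summaries_by_dataset_type: dict[str | AnyDatasetType, list[str]] = {
    "raw": ["HSC/runs/a"],
    ANY_DATASET: ["HSC/runs/a", "HSC/runs/b"],
}
calibration_dataset_types: set[str | AnyDatasetType] = {"bias", ANY_DATASET}
print(ANY_DATASET in calibration_dataset_types)  # True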
25 changes: 12 additions & 13 deletions python/lsst/daf/butler/direct_query_driver/_query_builder.py
@@ -37,7 +37,6 @@
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterable, Set
from types import EllipsisType
from typing import TYPE_CHECKING, Literal, TypeVar, overload

import sqlalchemy
@@ -109,7 +108,7 @@ def __init__(
used to make post-projection rows unique.
"""

find_first_dataset: str | EllipsisType | None = None
find_first_dataset: str | qt.AnyDatasetType | None = None
"""If not `None`, this is a find-first query for this dataset.
This is set even if the find-first search is trivial because there is only
@@ -161,7 +160,7 @@ def analyze_projection(self) -> None:
self.needs_dimension_distinct = True

@abstractmethod
def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None:
def analyze_find_first(self, find_first_dataset: str | qt.AnyDatasetType) -> None:
"""Analyze the "find first" stage of query construction, in which a
Common Table Expression with PARTITION ON may be used to find the first
dataset for each data ID and dataset type in an ordered collection
@@ -345,7 +344,7 @@ def analyze_projection(self) -> None:
# because we have rows for datasets in multiple collections with the
# same data ID and dataset type.
for dataset_type in self.joins_analysis.columns.dataset_fields:
assert dataset_type is not ..., "Union dataset in non-dataset-union query."
assert dataset_type is not qt.ANY_DATASET, "Union dataset in non-dataset-union query."
if not self.projection_columns.dataset_fields[dataset_type]:
# The "joins"-stage query has one row for each collection for
# each data ID, but the projection-stage query just wants
@@ -358,13 +357,13 @@
# the collection_key column so we can use that as one of the DISTINCT
# or GROUP BY columns.
for dataset_type, fields_for_dataset in self.projection_columns.dataset_fields.items():
assert dataset_type is not ..., "Union dataset in non-dataset-union query."
assert dataset_type is not qt.ANY_DATASET, "Union dataset in non-dataset-union query."
if len(self.joins_analysis.datasets[dataset_type].collection_records) > 1:
fields_for_dataset.add("collection_key")

def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None:
def analyze_find_first(self, find_first_dataset: str | qt.AnyDatasetType) -> None:
# Docstring inherited.
assert find_first_dataset is not ..., "No dataset union in this query"
assert find_first_dataset is not qt.ANY_DATASET, "No dataset union in this query"
self.find_first = QueryFindFirstAnalysis(self.joins_analysis.datasets[find_first_dataset])
# If we're doing a find-first search and there's a calibration
# collection in play, we need to make sure the rows coming out of
@@ -556,7 +555,7 @@ def analyze_projection(self) -> None:
# same data ID and dataset type.
for dataset_type in self.joins_analysis.columns.dataset_fields:
if not self.projection_columns.dataset_fields[dataset_type]:
if dataset_type is ...:
if dataset_type is qt.ANY_DATASET:
for union_term in self.union_terms:
if len(union_term.datasets.collection_records) > 1:
union_term.needs_dataset_distinct = True
@@ -572,7 +571,7 @@
# the collection_key column so we can use that as one of the DISTINCT
# or GROUP BY columns.
for dataset_type, fields_for_dataset in self.projection_columns.dataset_fields.items():
if dataset_type is ...:
if dataset_type is qt.ANY_DATASET:
for union_term in self.union_terms:
# If there is more than one collection for one union term,
# we need to add collection_key to all of them to keep the
@@ -583,9 +582,9 @@
elif len(self.joins_analysis.datasets[dataset_type].collection_records) > 1:
fields_for_dataset.add("collection_key")

def analyze_find_first(self, find_first_dataset: str | EllipsisType) -> None:
def analyze_find_first(self, find_first_dataset: str | qt.AnyDatasetType) -> None:
# Docstring inherited.
if find_first_dataset is ...:
if find_first_dataset is qt.ANY_DATASET:
for union_term in self.union_terms:
union_term.find_first = QueryFindFirstAnalysis(union_term.datasets)
# If we're doing a find-first search and there's a calibration
@@ -625,7 +624,7 @@ def apply_joins(self, driver: DirectQueryDriver) -> None:
driver.join_dataset_search(
select_builder.joins,
union_term.datasets,
self.joins_analysis.columns.dataset_fields[...],
self.joins_analysis.columns.dataset_fields[qt.ANY_DATASET],
union_dataset_type_name=dataset_type_name,
)
union_term.select_builders.append(select_builder)
@@ -647,7 +646,7 @@ def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.Orde
needs_dimension_distinct=self.needs_dimension_distinct,
needs_dataset_distinct=union_term.needs_dataset_distinct,
needs_validity_match_count=union_term.needs_validity_match_count,
find_first_dataset=None if union_term.find_first is None else ...,
find_first_dataset=None if union_term.find_first is None else qt.ANY_DATASET,
order_by=order_by,
)

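The `_T = TypeVar("_T", bound=str | qt.AnyDatasetType)` change in `_driver.py` is what lets the single-dataset-type and union builder paths above share generic helpers. A simplified sketch of that pattern; the real `ResolvedDatasetSearch` carries richer fields than this stand-in:

import dataclasses
import enum
from typing import Generic, TypeVar


class AnyDatasetType(enum.Enum):
    ANY_DATASET = "ANY_DATASET"


_T = TypeVar("_T", bound=str | AnyDatasetType)


@dataclasses.dataclass
class ResolvedDatasetSearch(Generic[_T]):
    # Simplified stand-in: ``name`` is either a concrete dataset type name
    # or the ANY_DATASET sentinel for dataset-type-union queries.
    name: _T
    collection_records: list[str]


single = ResolvedDatasetSearch("raw", ["run1"])
union = ResolvedDatasetSearch(AnyDatasetType.ANY_DATASET, ["run1", "run2"])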
11 changes: 5 additions & 6 deletions python/lsst/daf/butler/direct_query_driver/_sql_builders.py
@@ -32,7 +32,6 @@
import dataclasses
import itertools
from collections.abc import Iterable, Sequence
from types import EllipsisType
from typing import TYPE_CHECKING, Any, ClassVar, Self

import sqlalchemy
@@ -148,7 +147,7 @@ def select(self, postprocessing: Postprocessing | None) -> sqlalchemy.Select:
for logical_table, field in self.columns:
name = self.columns.get_qualified_name(logical_table, field)
if field is None:
assert logical_table is not ...
assert logical_table is not qt.ANY_DATASET
sql_columns.append(self.joins.dimension_keys[logical_table][0].label(name))
else:
name = self.joins.db.name_shrinker.shrink(name)
@@ -319,8 +318,8 @@ class SqlColumns:
key (which should all have equal values for all result rows).
"""

fields: NonemptyMapping[str | EllipsisType, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field(
default_factory=lambda: NonemptyMapping(dict)
fields: NonemptyMapping[str | qt.AnyDatasetType, dict[str, sqlalchemy.ColumnElement[Any]]] = (
dataclasses.field(default_factory=lambda: NonemptyMapping(dict))
)
"""Mapping of columns that are neither dimension keys nor timespans.
@@ -329,7 +328,7 @@ class SqlColumns:
either a dimension element name or dataset type name.
"""

timespans: dict[str | EllipsisType, TimespanDatabaseRepresentation] = dataclasses.field(
timespans: dict[str | qt.AnyDatasetType, TimespanDatabaseRepresentation] = dataclasses.field(
default_factory=dict
)
"""Mapping of timespan columns.
@@ -408,7 +407,7 @@ def extract_columns(
for logical_table, field in columns:
name = columns.get_qualified_name(logical_table, field)
if field is None:
assert logical_table is not ...
assert logical_table is not qt.ANY_DATASET
self.dimension_keys[logical_table].append(column_collection[name])
else:
name = self.db.name_shrinker.shrink(name)
15 changes: 0 additions & 15 deletions python/lsst/daf/butler/pydantic_utils.py
@@ -32,7 +32,6 @@
"get_universe_from_context",
"SerializableRegion",
"SerializableTime",
"SerializableEllipsis",
)

from types import EllipsisType
@@ -336,17 +335,3 @@ def _deserialize_ellipsis(value: object, handler: pydantic.ValidatorFunctionWrap
if s == "...":
return ...
raise ValueError(f"String {s!r} is not '...'.")


SerializableEllipsis: TypeAlias = Annotated[
EllipsisType,
pydantic.GetPydanticSchema(lambda _, h: h(str)),
pydantic.WrapValidator(_deserialize_ellipsis),
pydantic.WrapSerializer(_serialize_ellipsis),
pydantic.WithJsonSchema({"const": "..."}),
]
"""A Pydantic-annotated version of the special ellipsis object (``...``).
The serialized form is just the string "...", and hence to participate in a
union with `str` correctly, this type must come first.
"""
3 changes: 2 additions & 1 deletion python/lsst/daf/butler/queries/_query.py
@@ -56,6 +56,7 @@
GeneralResultSpec,
)
from .tree import (
ANY_DATASET,
DatasetFieldName,
DatasetFieldReference,
DatasetSearch,
@@ -393,7 +394,7 @@ def general(
case DimensionFieldReference(element=element, field=field):
dimension_fields_dict.setdefault(element.name, set()).add(field)
case DatasetFieldReference(dataset_type=dataset_type, field=dataset_field):
if dataset_type is ...:
if dataset_type is ANY_DATASET:
raise InvalidQueryError("Dataset wildcard fields are not supported by Query.general.")
dataset_fields_dict.setdefault(dataset_type, set()).add(dataset_field)
case _:
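A closing aside on the `match` statement above: the diff keeps an `if dataset_type is ANY_DATASET` guard inside the `case` body, but enum members can also serve directly as value patterns. This is an alternative spelling for illustration, not what the commit does:

import enum


class AnyDatasetType(enum.Enum):
    ANY_DATASET = "ANY_DATASET"


def field_kind(dataset_type: str | AnyDatasetType) -> str:
    match dataset_type:
        case AnyDatasetType.ANY_DATASET:
            # Value pattern: compares by equality with the enum member.
            return "wildcard over all dataset types"
        case str() as name:
            return f"concrete dataset type {name!r}"
    raise TypeError(f"unexpected value: {dataset_type!r}")


print(field_kind(AnyDatasetType.ANY_DATASET))  # wildcard over all dataset types
print(field_kind("raw"))  # concrete dataset type 'raw'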