Add "union dataset" to query model and interface classes.
TallJimbo committed Oct 22, 2024
1 parent d8dbb9e commit 4818ac7
Showing 8 changed files with 98 additions and 43 deletions.
2 changes: 2 additions & 0 deletions python/lsst/daf/butler/queries/_query.py
@@ -393,6 +393,8 @@ def general(
case DimensionFieldReference(element=element, field=field):
dimension_fields_dict.setdefault(element.name, set()).add(field)
case DatasetFieldReference(dataset_type=dataset_type, field=dataset_field):
if dataset_type is ...:
raise InvalidQueryError("Dataset wildcard fields are not supported by Query.general.")
dataset_fields_dict.setdefault(dataset_type, set()).add(dataset_field)
case _:
raise TypeError(f"Unexpected type of identifier ({name}): {identifier}")
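Throughout this commit, Python's `...` literal (the `Ellipsis` singleton) is a sentinel meaning "any dataset type". `Query.general` rejects it because general results must name concrete dataset types. A minimal sketch of the sentinel check used above, with a stand-in for the butler's `InvalidQueryError`:

```python
from types import EllipsisType


class InvalidQueryError(RuntimeError):
    """Stand-in for the butler's InvalidQueryError."""


def require_concrete_dataset_type(dataset_type: str | EllipsisType) -> str:
    # `is ...` is an identity test against the Ellipsis singleton, so it
    # cleanly separates the wildcard from any real dataset type name.
    if dataset_type is ...:
        raise InvalidQueryError(
            "Dataset wildcard fields are not supported by Query.general."
        )
    return dataset_type
```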
34 changes: 20 additions & 14 deletions python/lsst/daf/butler/queries/overlaps.py
@@ -31,6 +31,7 @@

import itertools
from collections.abc import Hashable, Iterable, Mapping, Sequence, Set
from types import EllipsisType
from typing import Generic, Literal, TypeVar, cast

from lsst.sphgeom import Region
@@ -114,7 +115,7 @@ class CalibrationTemporalEndpoint(TopologicalRelationshipEndpoint):
Parameters
----------
dataset_type_name : `str`
dataset_type_name : `str` or ``...``
Name of the dataset type.
Notes
@@ -127,12 +128,12 @@ class CalibrationTemporalEndpoint(TopologicalRelationshipEndpoint):
the same family).
"""

def __init__(self, dataset_type_name: str):
def __init__(self, dataset_type_name: str | EllipsisType):
self.dataset_type_name = dataset_type_name

@property
def name(self) -> str:
return self.dataset_type_name
return self.dataset_type_name if self.dataset_type_name is not ... else "<calibrations>"

@property
def topology(self) -> Mapping[TopologicalSpace, TopologicalFamily]:
@@ -147,18 +148,21 @@ class CalibrationTemporalFamily(TopologicalFamily):
Parameters
----------
dataset_type_name : `str`
dataset_type_name : `str` or ``...``
Name of the dataset type.
"""

def __init__(self, dataset_type_name: str):
super().__init__(dataset_type_name, TopologicalSpace.TEMPORAL)
def __init__(self, dataset_type_name: str | EllipsisType):
super().__init__(
dataset_type_name if dataset_type_name is not ... else "<calibrations>", TopologicalSpace.TEMPORAL
)
self.dataset_type_name = dataset_type_name

def choose(self, dimensions: DimensionGroup) -> CalibrationTemporalEndpoint:
return CalibrationTemporalEndpoint(self.name)
return CalibrationTemporalEndpoint(self.dataset_type_name)

def make_column_reference(self, endpoint: TopologicalRelationshipEndpoint) -> tree.DatasetFieldReference:
return tree.DatasetFieldReference(dataset_type=endpoint.name, field="timespan")
return tree.DatasetFieldReference(dataset_type=self.dataset_type_name, field="timespan")


class OverlapsVisitor(SimplePredicateVisitor):
@@ -181,7 +185,7 @@ class OverlapsVisitor(SimplePredicateVisitor):
implementations that want to rewrite the predicate at the same time.
"""

def __init__(self, dimensions: DimensionGroup, calibration_dataset_types: Set[str]):
def __init__(self, dimensions: DimensionGroup, calibration_dataset_types: Set[str | EllipsisType]):
self.dimensions = dimensions
self._spatial_connections = _NaiveDisjointSet(self.dimensions.spatial)
temporal_families: list[TopologicalFamily] = [
@@ -477,7 +481,7 @@ def visit_temporal_dimension_join(
return None

def visit_validity_range_dimension_join(
self, a: str, b: DimensionElement, flags: PredicateVisitFlags
self, a: str | EllipsisType, b: DimensionElement, flags: PredicateVisitFlags
) -> tree.Predicate | None:
"""Handle a temporal overlap comparison between two dimension elements.
@@ -487,7 +491,7 @@
Parameters
----------
a : `str`
a : `str` or ``...``
Name of a calibration dataset type.
b : `DimensionElement`
The dimension element to join the dataset validity range to.
@@ -504,7 +508,9 @@ def visit_validity_range_dimension_join(
self._temporal_connections.merge(CalibrationTemporalFamily(a), cast(TopologicalFamily, b.temporal))
return None

def visit_validity_range_join(self, a: str, b: str, flags: PredicateVisitFlags) -> tree.Predicate | None:
def visit_validity_range_join(
self, a: str | EllipsisType, b: str | EllipsisType, flags: PredicateVisitFlags
) -> tree.Predicate | None:
"""Handle a temporal overlap comparison between two dimension elements.
The default implementation updates the set of known temporal
@@ -513,9 +519,9 @@ def visit_validity_range_join(self, a: str, b: str, flags: PredicateVisitFlags)
Parameters
----------
a : `str`
a : `str` or ``...``
Name of a calibration dataset type.
b : `DimensionElement`
b : `str` or ``...``
Another calibration dataset type to join to.
flags : `tree.PredicateLeafFlags`
Information about where this overlap comparison appears in the
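The pattern above gives the wildcard a stable, printable family name while keeping the raw `str | EllipsisType` value for `choose` and `make_column_reference`. A condensed, self-contained sketch of the naming rule:

```python
from types import EllipsisType


def calibration_family_name(dataset_type_name: str | EllipsisType) -> str:
    # The wildcard gets a placeholder that no real dataset type name
    # should collide with.
    return dataset_type_name if dataset_type_name is not ... else "<calibrations>"


assert calibration_family_name("bias") == "bias"
assert calibration_family_name(...) == "<calibrations>"
```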
3 changes: 2 additions & 1 deletion python/lsst/daf/butler/queries/result_specs.py
@@ -36,6 +36,7 @@

from abc import ABC, abstractmethod
from collections.abc import Mapping
from types import EllipsisType
from typing import Annotated, Literal, TypeAlias, cast

import pydantic
@@ -96,7 +97,7 @@ def validate_tree(self, tree: QueryTree) -> None:
)

@property
def find_first_dataset(self) -> str | None:
def find_first_dataset(self) -> str | EllipsisType | None:
"""The dataset type for which find-first resolution is required, if
any.
"""
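With the widened return type, callers of `find_first_dataset` now distinguish three cases; a hedged sketch (the helper name is hypothetical):

```python
from types import EllipsisType


def describe_find_first(dataset_type: str | EllipsisType | None) -> str:
    # Hypothetical helper illustrating the three-way contract.
    if dataset_type is None:
        return "no find-first resolution requested"
    if dataset_type is ...:
        return "find-first over the union of all matching dataset types"
    return f"find-first for dataset type {dataset_type!r}"
```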
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/queries/tree/_column_reference.py
@@ -36,6 +36,7 @@
from ..._exceptions import InvalidQueryError
from ...column_spec import ColumnType
from ...dimensions import Dimension, DimensionElement
from ...pydantic_utils import SerializableEllipsis
from ._base import ColumnExpressionBase, DatasetFieldName

if TYPE_CHECKING:
@@ -141,8 +142,8 @@ class DatasetFieldReference(ColumnExpressionBase):

is_column_reference: ClassVar[bool] = True

dataset_type: str
"""Name of the dataset type."""
dataset_type: SerializableEllipsis | str
"""Name of the dataset type, or ``...`` to match any dataset type."""

field: DatasetFieldName
"""Name of the field (i.e. column) in the dataset's logical table."""
@@ -173,7 +174,10 @@ def column_type(self) -> ColumnType:
raise AssertionError(f"Invalid field {self.field!r} for dataset.")

def __str__(self) -> str:
return f"{self.dataset_type}.{self.field}"
if self.dataset_type is ...:
return self.field
else:
return f"{self.dataset_type}.{self.field}"

def visit(self, visitor: ColumnExpressionVisitor[_T]) -> _T:
# Docstring inherited.
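Assuming a validated model, the new `__str__` renders a named reference as `bias.timespan` but the wildcard as just `timespan`; a standalone restatement of that formatting rule:

```python
from types import EllipsisType


def dataset_field_str(dataset_type: str | EllipsisType, field: str) -> str:
    # Mirrors DatasetFieldReference.__str__: the wildcard is omitted rather
    # than rendered, since "Ellipsis.timespan" would be misleading.
    return field if dataset_type is ... else f"{dataset_type}.{field}"


assert dataset_field_str("bias", "timespan") == "bias.timespan"
assert dataset_field_str(..., "timespan") == "timespan"
```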
37 changes: 20 additions & 17 deletions python/lsst/daf/butler/queries/tree/_column_set.py
@@ -30,7 +30,8 @@
__all__ = ("ColumnSet", "ColumnOrder", "ResultColumn")

from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
from typing import NamedTuple
from types import EllipsisType
from typing import NamedTuple, cast

from ... import column_spec
from ...dimensions import DataIdValue, DimensionGroup
@@ -64,7 +65,7 @@ def __init__(self, dimensions: DimensionGroup) -> None:
self._dimensions = dimensions
self._removed_dimension_keys: set[str] = set()
self._dimension_fields: dict[str, set[str]] = {name: set() for name in dimensions.elements}
self._dataset_fields = NonemptyMapping[str, set[str]](set)
self._dataset_fields = NonemptyMapping[str | EllipsisType, set[str]](set)

@property
def dimensions(self) -> DimensionGroup:
@@ -82,7 +83,7 @@ def dimension_fields(self) -> Mapping[str, set[str]]:
return self._dimension_fields

@property
def dataset_fields(self) -> NonemptyMapping[str, set[str]]:
def dataset_fields(self) -> NonemptyMapping[str | EllipsisType, set[str]]:
"""Dataset fields included in the set, grouped by dataset type name.
The keys of this mapping are just those that actually have nonempty
@@ -269,20 +270,20 @@ def get_column_order(self) -> ColumnOrder:
# We sort dataset types and their fields lexicographically just to keep
# our queries from having any dependence on set-iteration order.
dataset_fields: list[ResultColumn] = []
for dataset_type in sorted(self._dataset_fields):
for dataset_type in sorted(self._dataset_fields, key=str):  # key=str maps ... to "Ellipsis" so the mixed keys sort deterministically
for field in sorted(self._dataset_fields[dataset_type]):
dataset_fields.append(ResultColumn(dataset_type, field))

return ColumnOrder(dimension_names, dimension_elements, dataset_fields)

def is_timespan(self, logical_table: str, field: str | None) -> bool:
def is_timespan(self, logical_table: EllipsisType | str, field: str | None) -> bool:
"""Test whether the given column is a timespan.
Parameters
----------
logical_table : `str`
logical_table : `str` or ``...``
Name of the dimension element or dataset type the column belongs
to.
to. ``...`` is used to represent any dataset type.
field : `str` or `None`
Column within the logical table, or `None` for dimension key
columns.
@@ -295,14 +296,14 @@ def is_timespan(self, logical_table: str, field: str | None) -> bool:
return field == "timespan"

@staticmethod
def get_qualified_name(logical_table: str, field: str | None) -> str:
def get_qualified_name(logical_table: EllipsisType | str, field: str | None) -> str:
"""Return string that should be used to fully identify a column.
Parameters
----------
logical_table : `str`
logical_table : `str` or ``...``
Name of the dimension element or dataset type the column belongs
to.
to. ``...`` is used to represent any dataset type.
field : `str` or `None`
Column within the logical table, or `None` for dimension key
columns.
@@ -312,16 +313,16 @@ def get_qualified_name(logical_table: str, field: str | None) -> str:
name : `str`
Fully-qualified name.
"""
return logical_table if field is None else f"{logical_table}:{field}"
return str(logical_table) if field is None else f"{logical_table}:{field}"
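Under this definition the wildcard renders via `str`, so qualified names come out as follows (a self-contained restatement of the method above):

```python
def get_qualified_name(logical_table, field):
    return str(logical_table) if field is None else f"{logical_table}:{field}"


assert get_qualified_name("visit", None) == "visit"
assert get_qualified_name("bias", "timespan") == "bias:timespan"
assert get_qualified_name(..., "timespan") == "Ellipsis:timespan"
```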

def get_column_spec(self, logical_table: str, field: str | None) -> column_spec.ColumnSpec:
def get_column_spec(self, logical_table: EllipsisType | str, field: str | None) -> column_spec.ColumnSpec:
"""Return a complete description of a column.
Parameters
----------
logical_table : `str`
logical_table : `str` or ``...``
Name of the dimension element or dataset type the column belongs
to.
to. ``...`` is used to represent any dataset type.
field : `str` or `None`
Column within the logical table, or `None` for dimension key
columns.
@@ -333,10 +334,12 @@ def get_column_spec(self, logical_table: str, field: str | None) -> column_spec.
"""
qualified_name = self.get_qualified_name(logical_table, field)
if field is None:
assert logical_table is not ...
return self._dimensions.universe.dimensions[logical_table].primary_key.model_copy(
update=dict(name=qualified_name)
)
if logical_table in self._dimension_fields:
assert logical_table is not ...
return (
self._dimensions.universe[logical_table]
.schema.all[field]
@@ -369,15 +372,15 @@ def _get_dimension_keys(self) -> Set[str]:
class ResultColumn(NamedTuple):
"""Defines a column that can be output from a query."""

logical_table: str
logical_table: EllipsisType | str
"""Dimension element name or dataset type name."""

field: str | None
"""Column associated with the dimension element or dataset type, or `None`
if it is a dimension key column."""

def __str__(self) -> str:
return self.logical_table if self.field is None else f"{self.logical_table}.{self.field}"
return str(self.logical_table) if self.field is None else f"{self.logical_table}.{self.field}"


class ColumnOrder:
Expand Down Expand Up @@ -417,7 +420,7 @@ def dimension_key_names(self) -> list[str]:
"""Return the names of the dimension key columns included in result
rows, in the order they appear in the row.
"""
return [column.logical_table for column in self._dimension_keys]
return [cast(str, column.logical_table) for column in self._dimension_keys]

def extract_dimension_key_columns(self, row: Sequence[DataIdValue]) -> Sequence[DataIdValue]:
"""Given a full result row, return just the dimension key columns.
45 changes: 41 additions & 4 deletions python/lsst/daf/butler/queries/tree/_query_tree.py
@@ -37,7 +37,8 @@
)

import uuid
from collections.abc import Mapping
from collections.abc import Iterator, Mapping
from types import EllipsisType
from typing import TypeAlias, final

import pydantic
@@ -127,6 +128,11 @@ class QueryTree(QueryTreeBase):
datasets: Mapping[str, DatasetSearch] = pydantic.Field(default_factory=dict)
"""Dataset searches that have been joined into the query."""

any_dataset: DatasetSearch | None = pydantic.Field(default=None)
"""A special optional dataset search for all dataset types with a
particular set of dimensions.
"""

data_coordinate_uploads: Mapping[DataCoordinateUploadKey, DimensionGroup] = pydantic.Field(
default_factory=dict
)
@@ -141,6 +147,11 @@
predicate: Predicate = Predicate.from_bool(True)
"""Boolean expression trees whose logical AND defines a row filter."""

def iter_all_dataset_searches(self) -> Iterator[tuple[str | EllipsisType, DatasetSearch]]:
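"""Iterate over all dataset searches joined into the query, including the special any-dataset search (keyed by ``...``) if present."""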
yield from self.datasets.items()
if self.any_dataset is not None:
yield (..., self.any_dataset)
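A hedged sketch of consuming that iterator uniformly, assuming `tree` is a populated `QueryTree`:

```python
for dataset_type, search in tree.iter_all_dataset_searches():
    label = "<any>" if dataset_type is ... else dataset_type
    print(label, search.dimensions)
```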

def get_joined_dimension_groups(self) -> frozenset[DimensionGroup]:
"""Return a set of the dimension groups of all data coordinate uploads,
dataset searches, and materializations.
@@ -149,6 +160,8 @@ def get_joined_dimension_groups(self) -> frozenset[DimensionGroup]:
result.update(self.materializations.values())
for dataset_spec in self.datasets.values():
result.add(dataset_spec.dimensions)
if self.any_dataset is not None:
result.add(self.any_dataset.dimensions)
return frozenset(result)

def join_dimensions(self, dimensions: DimensionGroup) -> QueryTree:
@@ -242,6 +255,28 @@ def join_dataset(self, dataset_type: str, search: DatasetSearch) -> QueryTree:
update=dict(dimensions=self.dimensions | search.dimensions, datasets=datasets)
)

def join_any_dataset(self, search: DatasetSearch) -> QueryTree:
"""Return a new tree that joins in a search for any dataset type with
the given dimensions.
Parameters
----------
search : `DatasetSearch`
Struct containing the collection search path and dimensions.
Returns
-------
result : `QueryTree`
A new tree that joins in the dataset search.
"""
if self.any_dataset is not None:
assert self.any_dataset == search, "Dataset search should be new or the same."
return self
else:
return self.model_copy(
update=dict(dimensions=self.dimensions | search.dimensions, any_dataset=search)
)
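Usage mirrors `join_dataset`, minus the dataset type name; a sketch assuming `tree` and `search` were built elsewhere:

```python
new_tree = tree.join_any_dataset(search)
assert new_tree.any_dataset == search
# Joining the identical search again is a no-op that returns self:
assert new_tree.join_any_dataset(search) is new_tree
```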

def where(self, *terms: Predicate) -> QueryTree:
"""Return a new tree that adds row filtering via a boolean column
expression.
@@ -275,10 +310,12 @@ def where(self, *terms: Predicate) -> QueryTree:
for where_term in terms:
where_term.gather_required_columns(columns)
predicate = predicate.logical_and(where_term)
if not (columns.dataset_fields.keys() <= self.datasets.keys()):
missing_dataset_types = columns.dataset_fields.keys() - self.datasets.keys()
if self.any_dataset is not None:
missing_dataset_types.discard(...)
if missing_dataset_types:
raise InvalidQueryError(
f"Cannot reference dataset type(s) {columns.dataset_fields.keys() - self.datasets.keys()} "
"that have not been joined."
f"Cannot reference dataset type(s) {missing_dataset_types} that have not been joined."
)
return self.model_copy(update=dict(dimensions=columns.dimensions, predicate=predicate))
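The revised validation treats the wildcard as satisfied whenever the union search has been joined; a distilled sketch of the set arithmetic:

```python
referenced = {"bias", ...}      # dataset types referenced by predicate columns
joined = {"bias"}               # keys of tree.datasets
missing = referenced - joined   # -> {Ellipsis}
any_dataset_joined = True       # stands in for `tree.any_dataset is not None`
if any_dataset_joined:
    missing.discard(...)        # the union search covers the wildcard
assert not missing              # otherwise: InvalidQueryError
```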
