diff --git a/doc/changes/DM-46347.bugfix.md b/doc/changes/DM-46347.bugfix.md new file mode 100644 index 0000000000..6946a78b90 --- /dev/null +++ b/doc/changes/DM-46347.bugfix.md @@ -0,0 +1 @@ +Fixed an issue where default data IDs were not constraining query results in the new query system. diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index bfc85ac143..fe3547c7c0 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -68,6 +68,7 @@ from ..registry.managers import RegistryManagerInstances from ..registry.wildcards import CollectionWildcard from ._postprocessing import Postprocessing +from ._predicate_constraints_summary import PredicateConstraintsSummary from ._query_builder import QueryBuilder, QueryJoiner from ._query_plan import ( QueryFindFirstPlan, @@ -927,24 +928,24 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query tree.get_joined_dimension_groups(), calibration_dataset_types, ) - result = QueryJoinsPlan(predicate=predicate, columns=builder.columns) + + # Extract the data ID implied by the predicate; we can use the governor + # dimensions in that to constrain the collections we search for + # datasets later. + predicate_constraints = PredicateConstraintsSummary(predicate) + # Use the default data ID to apply additional constraints where needed. + predicate_constraints.apply_default_data_id(self._default_data_id, tree.dimensions) + predicate = predicate_constraints.predicate + + result = QueryJoinsPlan( + predicate=predicate, + columns=builder.columns, + messages=predicate_constraints.messages, + ) + # Add columns required by postprocessing. builder.postprocessing.gather_columns_required(result.columns) - # We also check that the predicate doesn't reference any dimensions - # without constraining their governor dimensions, since that's a - # particularly easy mistake to make and it's almost never intentional. - # We also allow the registry data ID values to provide governor values. - where_governors: set[str] = set() - result.predicate.gather_governors(where_governors) - for governor in where_governors: - if governor not in result.constraint_data_id and governor not in result.governors_referenced: - if governor in self._default_data_id.dimensions: - result.constraint_data_id[governor] = self._default_data_id[governor] - else: - raise InvalidQueryError( - f"Query 'where' expression references a dimension dependent on {governor} without " - "constraining it directly." - ) + # Add materializations, which can also bring in more postprocessing. for m_key, m_dimensions in tree.materializations.items(): m_state = self._materializations[m_key] @@ -969,7 +970,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query resolved_dataset_search = self._resolve_dataset_search( dataset_type_name, dataset_search, - result.constraint_data_id, + predicate_constraints.constraint_data_id, summaries_by_dataset_type[dataset_type_name], ) result.datasets[dataset_type_name] = resolved_dataset_search diff --git a/python/lsst/daf/butler/direct_query_driver/_predicate_constraints_summary.py b/python/lsst/daf/butler/direct_query_driver/_predicate_constraints_summary.py new file mode 100644 index 0000000000..764e5100ec --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_predicate_constraints_summary.py @@ -0,0 +1,218 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import Any + +from .._exceptions import InvalidQueryError +from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse +from ..queries import tree as qt +from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor + + +class PredicateConstraintsSummary: + """Summarizes information about the constraints on data ID values implied + by a Predicate. + + Parameters + ---------- + predicate : `Predicate` + Predicate to summarize. + """ + + predicate: qt.Predicate + """The predicate examined by this summary.""" + + constraint_data_id: dict[str, DataIdValue] + """Data ID values that will be identical in all result rows due to query + constraints. + """ + + messages: list[str] + """Diagnostic messages that report reasons the query may not return any + rows. + """ + + def __init__(self, predicate: qt.Predicate) -> None: + self.predicate = predicate + self.constraint_data_id = {} + self.messages = [] + # Governor dimensions referenced directly in the predicate, but not + # necessarily constrained to the same value in all logic branches. + self._governors_referenced: set[str] = set() + + self.predicate.visit( + _DataIdExtractionVisitor(self.constraint_data_id, self.messages, self._governors_referenced) + ) + + def apply_default_data_id( + self, default_data_id: DataCoordinate, query_dimensions: DimensionGroup + ) -> None: + """Augment the predicate and summary by adding missing constraints for + governor dimensions using a default data ID. + + Parameters + ---------- + default_data_id : `DataCoordinate` + Data ID values that will be used to constrain the query if governor + dimensions have not already been constrained by the predicate. + + query_dimensions : `DimensionGroup` + The set of dimensions returned in result rows from the query. + """ + # Find governor dimensions required by the predicate. + # If these are not constrained by the predicate or the default data ID, + # we will raise an exception. + where_governors: set[str] = set() + self.predicate.gather_governors(where_governors) + + # Add in governor dimensions that are returned in result rows. + # We constrain these using a default data ID if one is available, + # but it's not an error to omit the constraint. + governors_used_by_query = where_governors | query_dimensions.governors + + # For each governor dimension needed by the query, add a constraint + # from the default data ID if the existing predicate does not + # constrain it. + for governor in governors_used_by_query: + if governor not in self.constraint_data_id and governor not in self._governors_referenced: + if governor in default_data_id.dimensions: + data_id_value = default_data_id[governor] + self.constraint_data_id[governor] = data_id_value + self._governors_referenced.add(governor) + self.predicate = self.predicate.logical_and( + _create_data_id_predicate(governor, data_id_value, query_dimensions.universe) + ) + elif governor in where_governors: + # Check that the predicate doesn't reference any dimensions + # without constraining their governor dimensions, since + # that's a particularly easy mistake to make and it's + # almost never intentional. + raise InvalidQueryError( + f"Query 'where' expression references a dimension dependent on {governor} without " + "constraining it directly." + ) + + +def _create_data_id_predicate( + dimension_name: str, value: DataIdValue, universe: DimensionUniverse +) -> qt.Predicate: + """Create a Predicate that tests whether the given dimension primary key is + equal to the given literal value. + """ + dimension = universe.dimensions[dimension_name] + return qt.Predicate.compare( + qt.DimensionKeyReference(dimension=dimension), "==", qt.make_column_literal(value) + ) + + +class _DataIdExtractionVisitor( + SimplePredicateVisitor, + ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]], +): + """A column-expression visitor that extracts quality constraints on + dimensions that are not OR'd with anything else. + + Parameters + ---------- + data_id : `dict` + Dictionary to populate in place. + messages : `list` [ `str` ] + List of diagnostic messages to populate in place. + governor_references : `set` [ `str` ] + Set of the names of governor dimension names that were referenced + directly. This includes dimensions that were constrained to different + values in different logic branches, and hence not included in + ``data_id``. + """ + + def __init__(self, data_id: dict[str, DataIdValue], messages: list[str], governor_references: set[str]): + self.data_id = data_id + self.messages = messages + self.governor_references = governor_references + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> None: + k_a, v_a = a.visit(self) + k_b, v_b = b.visit(self) + if flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + return None + if flags & PredicateVisitFlags.INVERTED: + if operator == "!=": + operator = "==" + else: + return None + if operator != "==": + return None + if k_a is not None and v_b is not None: + key = k_a + value = v_b + elif k_b is not None and v_a is not None: + key = k_b + value = v_a + else: + return None + if (old := self.data_id.setdefault(key, value)) != value: + self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.") + return None + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]: + expression.a.visit(self) + expression.b.visit(self) + return None, None + + def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]: + expression.operand.visit(self) + return None, None + + def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]: + return None, expression.get_literal_value() + + def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]: + if expression.dimension.governor is expression.dimension: + self.governor_references.add(expression.dimension.name) + return expression.dimension.name, None + + def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]: + if ( + expression.element.governor is expression.element + and expression.field in expression.element.alternate_keys.names + ): + self.governor_references.add(expression.element.name) + return None, None + + def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]: + return None, None + + def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]: + raise AssertionError("No Reversed expressions in predicates.") diff --git a/python/lsst/daf/butler/direct_query_driver/_query_plan.py b/python/lsst/daf/butler/direct_query_driver/_query_plan.py index ced6c1131c..7fa5fa1c74 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_plan.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_plan.py @@ -37,11 +37,9 @@ import dataclasses from collections.abc import Iterator -from typing import Any -from ..dimensions import DataIdValue, DimensionElement, DimensionGroup +from ..dimensions import DimensionElement, DimensionGroup from ..queries import tree as qt -from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor from ..registry.interfaces import CollectionRecord @@ -105,29 +103,13 @@ class QueryJoinsPlan: ) """Data coordinate uploads to join into the query.""" - constraint_data_id: dict[str, DataIdValue] = dataclasses.field(default_factory=dict) - """A data ID that must be consistent with all result rows, extracted from - `predicate` at construction. - """ - messages: list[str] = dataclasses.field(default_factory=list) """Diagnostic messages that report reasons the query may not return any rows. """ - governors_referenced: set[str] = dataclasses.field(default_factory=set) - """Governor dimensions referenced directly in the predicate, but not - necessarily constrained to the same value in all logic branches. - """ - def __post_init__(self) -> None: self.predicate.gather_required_columns(self.columns) - # Extract the data ID implied by the predicate; we can use the governor - # dimensions in that to constrain the collections we search for - # datasets later. - self.predicate.visit( - _DataIdExtractionVisitor(self.constraint_data_id, self.messages, self.governors_referenced) - ) def iter_mandatory(self) -> Iterator[DimensionElement]: """Return an iterator over the dimension elements that must be joined @@ -296,90 +278,3 @@ class QueryPlan: fields added directly to `QueryBuilder.special`, which may also be added to the SELECT clause. """ - - -class _DataIdExtractionVisitor( - SimplePredicateVisitor, - ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]], -): - """A column-expression visitor that extracts quality constraints on - dimensions that are not OR'd with anything else. - - Parameters - ---------- - data_id : `dict` - Dictionary to populate in place. - messages : `list` [ `str` ] - List of diagnostic messages to populate in place. - governor_references : `set` [ `str` ] - Set of the names of governor dimension names that were referenced - directly. This includes dimensions that were constrained to different - values in different logic branches, and hence not included in - ``data_id``. - """ - - def __init__(self, data_id: dict[str, DataIdValue], messages: list[str], governor_references: set[str]): - self.data_id = data_id - self.messages = messages - self.governor_references = governor_references - - def visit_comparison( - self, - a: qt.ColumnExpression, - operator: qt.ComparisonOperator, - b: qt.ColumnExpression, - flags: PredicateVisitFlags, - ) -> None: - k_a, v_a = a.visit(self) - k_b, v_b = b.visit(self) - if flags & PredicateVisitFlags.HAS_OR_SIBLINGS: - return None - if flags & PredicateVisitFlags.INVERTED: - if operator == "!=": - operator = "==" - else: - return None - if operator != "==": - return None - if k_a is not None and v_b is not None: - key = k_a - value = v_b - elif k_b is not None and v_a is not None: - key = k_b - value = v_a - else: - return None - if (old := self.data_id.setdefault(key, value)) != value: - self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.") - return None - - def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]: - expression.a.visit(self) - expression.b.visit(self) - return None, None - - def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]: - expression.operand.visit(self) - return None, None - - def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]: - return None, expression.get_literal_value() - - def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]: - if expression.dimension.governor is expression.dimension: - self.governor_references.add(expression.dimension.name) - return expression.dimension.name, None - - def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]: - if ( - expression.element.governor is expression.element - and expression.field in expression.element.alternate_keys.names - ): - self.governor_references.add(expression.element.name) - return None, None - - def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]: - return None, None - - def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]: - raise AssertionError("No Reversed expressions in predicates.") diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 970b42a770..9178a1d0bd 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -1892,6 +1892,78 @@ def test_multiple_instrument_queries(self) -> None: [DataCoordinate.standardize(instrument="Cam1", universe=butler.dimensions)], ) + def test_default_data_id(self) -> None: + butler = self.make_butler("base.yaml") + butler.registry.insertDimensionData("instrument", {"name": "Cam2"}) + butler.registry.insertDimensionData( + "physical_filter", {"instrument": "Cam2", "name": "Cam2-G", "band": "g"} + ) + + # With no default data ID, queries should return results for all + # instruments. + result = butler.query_dimension_records("physical_filter") + names = [x.name for x in result] + self.assertCountEqual(names, ["Cam1-G", "Cam1-R1", "Cam1-R2", "Cam2-G"]) + + result = butler.query_dimension_records("physical_filter", where="band='g'") + names = [x.name for x in result] + self.assertCountEqual(names, ["Cam1-G", "Cam2-G"]) + + # When there is no default data ID and a where clause references + # something depending on instrument, it throws an error as a + # sanity check. + # In this case, 'instrument' is not part of the dimensions returned by + # the query, so there is extra logic needed to detect the need for the + # default data ID. + with self.assertRaisesRegex( + InvalidQueryError, + "Query 'where' expression references a dimension dependent on instrument" + " without constraining it directly.", + ): + butler.query_data_ids(["band"], where="physical_filter='Cam1-G'") + + # Override the default data ID to specify a default instrument for + # subsequent tests. + butler.registry.defaults = RegistryDefaults(instrument="Cam1") + + # When a where clause references something depending on instrument, use + # the default data ID to constrain the instrument. + # In this case, 'instrument' is not part of the dimensions returned by + # the query, so there is extra logic needed to detect the need for the + # default data ID. + data_ids = butler.query_data_ids(["band"], where="physical_filter='Cam1-G'") + self.assertEqual([x["band"] for x in data_ids], ["g"]) + # Default data ID instrument=Cam1 does not match Cam2, so there are no + # results. + data_ids = butler.query_data_ids(["band"], where="physical_filter='Cam2-G'", explain=False) + self.assertEqual(data_ids, []) + # Overriding the default lets us get the results. + data_ids = butler.query_data_ids(["band"], where="instrument='Cam2' and physical_filter='Cam2-G'") + self.assertEqual([x["band"] for x in data_ids], ["g"]) + + # Query for a dimension that depends on instrument should pull in the + # default data ID instrument="Cam1" to constrain results. + result = butler.query_dimension_records("physical_filter") + names = [x.name for x in result] + self.assertCountEqual(names, ["Cam1-G", "Cam1-R1", "Cam1-R2"]) + + # Query for a dimension that depends on instrument should pull in the + # default data ID instrument="Cam1" to constrain results, if the where + # clause does not explicitly specify instrument. + result = butler.query_dimension_records("physical_filter", where="band='g'") + names = [x.name for x in result] + self.assertEqual(names, ["Cam1-G"]) + + # Queries that specify instrument explicitly in the where clause + # should ignore the default data ID. + result = butler.query_dimension_records("physical_filter", where="instrument='Cam2'") + names = [x.name for x in result] + self.assertCountEqual(names, ["Cam2-G"]) + + result = butler.query_dimension_records("physical_filter", where="instrument IN ('Cam2')") + names = [x.name for x in result] + self.assertCountEqual(names, ["Cam2-G"]) + def _get_exposure_ids_from_dimension_records(dimension_records: Iterable[DimensionRecord]) -> list[int]: output = []