From 377f7546b81f9ebde65fbfd4ef4f383340a6d6a5 Mon Sep 17 00:00:00 2001 From: Kamil Gabryjelski Date: Fri, 6 Oct 2023 18:47:00 +0200 Subject: [PATCH] fix: Apply normalization to all dttm columns (#25147) --- superset/common/query_context_factory.py | 1 + superset/common/query_context_processor.py | 5 +- superset/common/query_object_factory.py | 66 +++++++++++++- .../integration_tests/query_context_tests.py | 8 +- .../common/test_query_object_factory.py | 90 ++++++++++++++++++- 5 files changed, 160 insertions(+), 10 deletions(-) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index d6510ccd9a434..4fd0de7856bfe 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -185,6 +185,7 @@ def _apply_granularity( filter for filter in query_object.filter if filter["col"] != filter_to_remove + or filter["op"] != "TEMPORAL_RANGE" ] def _apply_filters(self, query_object: QueryObject) -> None: diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 5a0468b671b39..dcf19c0c321fa 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -282,10 +282,11 @@ def _get_timestamp_format( datasource = self._qc_datasource labels = tuple( label - for label in [ + for label in { *get_base_axis_labels(query_object.columns), + *[col for col in query_object.columns or [] if isinstance(col, str)], query_object.granularity, - ] + } if datasource # Query datasource didn't support `get_column` and hasattr(datasource, "get_column") diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index a2732ae5537f1..33393e88a6ad9 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -16,17 +16,24 @@ # under the License. from __future__ import annotations +from datetime import datetime from typing import Any, TYPE_CHECKING from superset.common.chart_data import ChartDataResultType from superset.common.query_object import QueryObject from superset.common.utils.time_range_utils import get_since_until_from_time_range -from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType +from superset.utils.core import ( + apply_max_row_limit, + DatasourceDict, + DatasourceType, + FilterOperator, + QueryObjectFilterClause, +) if TYPE_CHECKING: from sqlalchemy.orm import sessionmaker - from superset.connectors.base.models import BaseDatasource + from superset.connectors.base.models import BaseColumn, BaseDatasource from superset.daos.datasource import DatasourceDAO @@ -66,6 +73,10 @@ def create( # pylint: disable=too-many-arguments ) kwargs["from_dttm"] = from_dttm kwargs["to_dttm"] = to_dttm + if datasource_model_instance and kwargs.get("filters", []): + kwargs["filters"] = self._process_filters( + datasource_model_instance, kwargs["filters"] + ) return QueryObject( datasource=datasource_model_instance, extras=extras, @@ -102,3 +113,54 @@ def _process_row_limit( # light version of the view.utils.core # import view.utils require application context # Todo: move it and the view.utils.core to utils package + + def _process_filters( + self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause] + ) -> list[QueryObjectFilterClause]: + def get_dttm_filter_value( + value: Any, col: BaseColumn, date_format: str + ) -> int | str: + if not isinstance(value, int): + return value + if date_format in {"epoch_ms", "epoch_s"}: + if date_format == "epoch_s": + value = str(value) + else: + value = str(value * 1000) + else: + dttm = datetime.utcfromtimestamp(value / 1000) + value = dttm.strftime(date_format) + + if col.type in col.num_types: + value = int(value) + return value + + for query_filter in query_filters: + if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE: + continue + filter_col = query_filter.get("col") + if not isinstance(filter_col, str): + continue + column = datasource.get_column(filter_col) + if not column: + continue + filter_value = query_filter.get("val") + + date_format = column.python_date_format + if not date_format and datasource.db_extra: + date_format = datasource.db_extra.get( + "python_date_format_by_column_name", {} + ).get(column.column_name) + + if column.is_dttm and date_format: + if isinstance(filter_value, list): + query_filter["val"] = [ + get_dttm_filter_value(value, column, date_format) + for value in filter_value + ] + else: + query_filter["val"] = get_dttm_filter_value( + filter_value, column, date_format + ) + + return query_filters diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 8c2082d1c4b12..00a98b2c21d93 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset): query_object = qc.queries[0] df = qc.get_df_payload(query_object)["df"] - if query_object.datasource.database.backend == "sqlite": - # sqlite returns string as timestamp column - assert df["time column with spaces"][0] == "2002-01-03 00:00:00" - assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00" - else: + + # sqlite doesn't have timestamp columns + if query_object.datasource.database.backend != "sqlite": assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03" assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01" diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py index 02304828dca82..4e8fadfe3e993 100644 --- a/tests/unit_tests/common/test_query_object_factory.py +++ b/tests/unit_tests/common/test_query_object_factory.py @@ -43,9 +43,45 @@ def session_factory() -> Mock: return Mock() +class SimpleDatasetColumn: + def __init__(self, col_params: dict[str, Any]): + self.__dict__.update(col_params) + + +TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"] +TEMPORAL_COLUMNS = { + TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn( + { + "column_name": TEMPORAL_COLUMN_NAMES[0], + "is_dttm": True, + "python_date_format": None, + "type": "string", + "num_types": ["BIGINT"], + } + ), + TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn( + { + "column_name": TEMPORAL_COLUMN_NAMES[1], + "type": "BIGINT", + "is_dttm": True, + "python_date_format": "%Y", + "num_types": ["BIGINT"], + } + ), +} + + @fixture def connector_registry() -> Mock: - return Mock(spec=["get_datasource"]) + datasource_dao_mock = Mock(spec=["get_datasource"]) + datasource_dao_mock.get_datasource.return_value = Mock() + datasource_dao_mock.get_datasource().get_column = Mock( + side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name] + if col_name in TEMPORAL_COLUMN_NAMES + else Mock() + ) + datasource_dao_mock.get_datasource().db_extra = None + return datasource_dao_mock def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int: @@ -112,3 +148,55 @@ def test_query_context_null_post_processing_op( raw_query_context["result_type"], **raw_query_object ) assert query_object.post_processing == [] + + def test_query_context_no_python_date_format_filters( + self, + query_object_factory: QueryObjectFactory, + raw_query_context: dict[str, Any], + ): + raw_query_object = raw_query_context["queries"][0] + raw_query_object["filters"].append( + {"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000} + ) + query_object = query_object_factory.create( + raw_query_context["result_type"], + raw_query_context["datasource"], + **raw_query_object + ) + assert query_object.filter[3]["val"] == 315532800000 + + def test_query_context_python_date_format_filters( + self, + query_object_factory: QueryObjectFactory, + raw_query_context: dict[str, Any], + ): + raw_query_object = raw_query_context["queries"][0] + raw_query_object["filters"].append( + {"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000} + ) + query_object = query_object_factory.create( + raw_query_context["result_type"], + raw_query_context["datasource"], + **raw_query_object + ) + assert query_object.filter[3]["val"] == 1980 + + def test_query_context_python_date_format_filters_list_of_values( + self, + query_object_factory: QueryObjectFactory, + raw_query_context: dict[str, Any], + ): + raw_query_object = raw_query_context["queries"][0] + raw_query_object["filters"].append( + { + "col": TEMPORAL_COLUMN_NAMES[1], + "op": "==", + "val": [315532800000, 631152000000], + } + ) + query_object = query_object_factory.create( + raw_query_context["result_type"], + raw_query_context["datasource"], + **raw_query_object + ) + assert query_object.filter[3]["val"] == [1980, 1990]