From 1df76b62bb4ce35c5adf82676d96f11f4559efbb Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 26 Jul 2021 10:58:59 -0700 Subject: [PATCH] feat: apply post processing to chart data (#15843) * feat: apply post processing to chart data * Fix tests and lint * Fix lint * trigger tests --- superset/charts/api.py | 44 ++++++++- superset/charts/post_processing.py | 100 ++++++++++++++++++++ superset/common/query_actions.py | 6 +- superset/utils/core.py | 1 + tests/integration_tests/charts/api_tests.py | 2 + 5 files changed, 148 insertions(+), 5 deletions(-) create mode 100644 superset/charts/post_processing.py diff --git a/superset/charts/api.py b/superset/charts/api.py index 0d394cd51857c..98dd509acccdf 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -52,6 +52,7 @@ from superset.charts.commands.update import UpdateChartCommand from superset.charts.dao import ChartDAO from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter +from superset.charts.post_processing import post_processors from superset.charts.schemas import ( CHART_SCHEMAS, ChartPostSchema, @@ -481,9 +482,25 @@ def bulk_delete(self, **kwargs: Any) -> Response: except ChartBulkDeleteFailedError as ex: return self.response_422(message=str(ex)) - def send_chart_response(self, result: Dict[Any, Any]) -> Response: + def send_chart_response( + self, + result: Dict[Any, Any], + viz_type: Optional[str] = None, + form_data: Optional[Dict[str, Any]] = None, + ) -> Response: + result_type = result["query_context"].result_type result_format = result["query_context"].result_format + # Post-process the data so it matches the data presented in the chart. + # This is needed for sending reports based on text charts that do the + # post-processing of data, eg, the pivot table. 
+
+        if (
+            result_type == ChartDataResultType.POST_PROCESSED
+            and viz_type in post_processors
+        ):
+            post_process = post_processors[viz_type]
+            result = post_process(result, form_data)
+
         if result_format == ChartDataResultFormat.CSV:
             # Verify user has permission to export CSV file
             if not security_manager.can_access("can_csv", "Superset"):
@@ -506,7 +523,11 @@ def send_chart_response(self, result: Dict[Any, Any]) -> Response:
         return self.response_400(message=f"Unsupported result_format: {result_format}")
 
     def get_data_response(
-        self, command: ChartDataCommand, force_cached: bool = False
+        self,
+        command: ChartDataCommand,
+        force_cached: bool = False,
+        viz_type: Optional[str] = None,
+        form_data: Optional[Dict[str, Any]] = None,
     ) -> Response:
         try:
             result = command.run(force_cached=force_cached)
@@ -515,7 +536,7 @@ def get_data_response(
         except ChartDataQueryFailedError as exc:
             return self.response_400(message=exc.message)
 
-        return self.send_chart_response(result)
+        return self.send_chart_response(result, viz_type, form_data)
 
     @expose("/<pk>/data/", methods=["GET"])
     @protect()
@@ -544,6 +565,11 @@ def get_data(self, pk: int) -> Response:
               description: The format in which the data should be returned
             schema:
               type: string
+          - in: query
+            name: type
+            description: The type in which the data should be returned
+            schema:
+              type: string
           responses:
             200:
               description: Query result
@@ -580,9 +606,12 @@ def get_data(self, pk: int) -> Response:
                 )
             )
 
+        # override saved query context
         json_body["result_format"] = request.args.get(
             "format", ChartDataResultFormat.JSON
         )
+        json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
+
         try:
             command = ChartDataCommand()
             query_context = command.set_query_context(json_body)
@@ -604,7 +633,14 @@ def get_data(self, pk: int) -> Response:
         ):
             return self._run_async(command)
 
-        return self.get_data_response(command)
+        try:
+            form_data = json.loads(chart.params)
+        except (TypeError, json.decoder.JSONDecodeError):
+            form_data = {}
+
+        
return self.get_data_response( + command, viz_type=chart.viz_type, form_data=form_data + ) @expose("/data", methods=["POST"]) @protect() diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py new file mode 100644 index 0000000000000..68e08fd2068ba --- /dev/null +++ b/superset/charts/post_processing.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Functions to reproduce the post-processing of data on text charts. + +Some text-based charts (pivot tables and t-test table) perform +post-processing of the data in Javascript. When sending the data +to users in reports we want to show the same data they would see +on Explore. + +In order to do that, we reproduce the post-processing in Python +for these chart types. +""" + +from typing import Any, Callable, Dict, Optional, Union + +import pandas as pd + +from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name + + +def pivot_table( + result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None +) -> Dict[Any, Any]: + """ + Pivot table. 
+ """ + for query in result["queries"]: + data = query["data"] + df = pd.DataFrame(data) + form_data = form_data or {} + + if form_data.get("granularity") == "all" and DTTM_ALIAS in df: + del df[DTTM_ALIAS] + + metrics = [get_metric_name(m) for m in form_data["metrics"]] + aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {} + for metric in metrics: + aggfunc = form_data.get("pandas_aggfunc") or "sum" + if pd.api.types.is_numeric_dtype(df[metric]): + if aggfunc == "sum": + aggfunc = lambda x: x.sum(min_count=1) + elif aggfunc not in {"min", "max"}: + aggfunc = "max" + aggfuncs[metric] = aggfunc + + groupby = form_data.get("groupby") or [] + columns = form_data.get("columns") or [] + if form_data.get("transpose_pivot"): + groupby, columns = columns, groupby + + df = df.pivot_table( + index=groupby, + columns=columns, + values=metrics, + aggfunc=aggfuncs, + margins=form_data.get("pivot_margins"), + ) + + # Re-order the columns adhering to the metric ordering. + df = df[metrics] + + # Display metrics side by side with each column + if form_data.get("combine_metric"): + df = df.stack(0).unstack().reindex(level=-1, columns=metrics) + + # flatten column names + df.columns = [" ".join(column) for column in df.columns] + + # re-arrange data into a list of dicts + data = [] + for i in df.index: + row = {col: df[col][i] for col in df.columns} + row[df.index.name] = i + data.append(row) + query["data"] = data + query["colnames"] = list(df.columns) + query["coltypes"] = extract_dataframe_dtypes(df) + query["rowcount"] = len(df.index) + + return result + + +post_processors = { + "pivot_table": pivot_table, +} diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index dbd73c065abff..d0ef7d3f26575 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -157,6 +157,10 @@ def _get_results( ChartDataResultType.SAMPLES: _get_samples, ChartDataResultType.FULL: _get_full, ChartDataResultType.RESULTS: _get_results, + # 
for requests for post-processed data we return the full results, + # and post-process it later where we have the chart context, since + # post-processing is unique to each visualization type + ChartDataResultType.POST_PROCESSED: _get_full, } @@ -180,5 +184,5 @@ def get_query_results( if result_func: return result_func(query_context, query_obj, force_cached) raise QueryObjectValidationError( - _("Invalid result type: %(result_type)", result_type=result_type) + _("Invalid result type: %(result_type)s", result_type=result_type) ) diff --git a/superset/utils/core.py b/superset/utils/core.py index 936e191bc93a1..4371640cc5058 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -178,6 +178,7 @@ class ChartDataResultType(str, Enum): RESULTS = "results" SAMPLES = "samples" TIMEGRAINS = "timegrains" + POST_PROCESSED = "post_processed" class DatasourceDict(TypedDict): diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index b99dbe6db2883..74d0fa200e378 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1498,6 +1498,7 @@ def test_chart_data_async_cached_sync_response(self): class QueryContext: result_format = ChartDataResultFormat.JSON + result_type = utils.ChartDataResultType.FULL cmd_run_val = { "query_context": QueryContext(), @@ -1508,6 +1509,7 @@ class QueryContext: ChartDataCommand, "run", return_value=cmd_run_val ) as patched_run: request_payload = get_query_context("birth_names") + request_payload["result_type"] = utils.ChartDataResultType.FULL rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8"))