Skip to content

Commit

Permalink
feat: apply post processing to chart data (apache#15843)
Browse files Browse the repository at this point in the history
* feat: apply post processing to chart data

* Fix tests and lint

* Fix lint

* trigger tests
  • Loading branch information
betodealmeida authored Jul 26, 2021
1 parent cf591f8 commit 1df76b6
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 5 deletions.
44 changes: 40 additions & 4 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.post_processing import post_processors
from superset.charts.schemas import (
CHART_SCHEMAS,
ChartPostSchema,
Expand Down Expand Up @@ -481,9 +482,25 @@ def bulk_delete(self, **kwargs: Any) -> Response:
except ChartBulkDeleteFailedError as ex:
return self.response_422(message=str(ex))

def send_chart_response(self, result: Dict[Any, Any]) -> Response:
def send_chart_response(
self,
result: Dict[Any, Any],
viz_type: Optional[str] = None,
form_data: Optional[Dict[str, Any]] = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format

# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, e.g., the pivot table.
if (
result_type == ChartDataResultType.POST_PROCESSED
and viz_type in post_processors
):
post_process = post_processors[viz_type]
result = post_process(result, form_data)

if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file
if not security_manager.can_access("can_csv", "Superset"):
Expand All @@ -506,7 +523,11 @@ def send_chart_response(self, result: Dict[Any, Any]) -> Response:
return self.response_400(message=f"Unsupported result_format: {result_format}")

def get_data_response(
self, command: ChartDataCommand, force_cached: bool = False
self,
command: ChartDataCommand,
force_cached: bool = False,
viz_type: Optional[str] = None,
form_data: Optional[Dict[str, Any]] = None,
) -> Response:
try:
result = command.run(force_cached=force_cached)
Expand All @@ -515,7 +536,7 @@ def get_data_response(
except ChartDataQueryFailedError as exc:
return self.response_400(message=exc.message)

return self.send_chart_response(result)
return self.send_chart_response(result, viz_type, form_data)

@expose("/<int:pk>/data/", methods=["GET"])
@protect()
Expand Down Expand Up @@ -544,6 +565,11 @@ def get_data(self, pk: int) -> Response:
description: The format in which the data should be returned
schema:
type: string
- in: query
name: type
description: The type in which the data should be returned
schema:
type: string
responses:
200:
description: Query result
Expand Down Expand Up @@ -580,9 +606,12 @@ def get_data(self, pk: int) -> Response:
)
)

# override saved query context
json_body["result_format"] = request.args.get(
"format", ChartDataResultFormat.JSON
)
json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)

try:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
Expand All @@ -604,7 +633,14 @@ def get_data(self, pk: int) -> Response:
):
return self._run_async(command)

return self.get_data_response(command)
try:
form_data = json.loads(chart.params)
except (TypeError, json.decoder.JSONDecodeError):
form_data = {}

return self.get_data_response(
command, viz_type=chart.viz_type, form_data=form_data
)

@expose("/data", methods=["POST"])
@protect()
Expand Down
100 changes: 100 additions & 0 deletions superset/charts/post_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Functions to reproduce the post-processing of data on text charts.
Some text-based charts (pivot tables and t-test table) perform
post-processing of the data in JavaScript. When sending the data
to users in reports we want to show the same data they would see
on Explore.
In order to do that, we reproduce the post-processing in Python
for these chart types.
"""

from typing import Any, Callable, Dict, Optional, Union

import pandas as pd

from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name


def pivot_table(
    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
) -> Dict[Any, Any]:
    """
    Reproduce the frontend "pivot_table" post-processing in Python.

    The pivot table chart pivots its data in JavaScript; when sending
    reports we re-create the same pivoted view server-side so users see
    the data as rendered in Explore.

    :param result: query payload with a ``"queries"`` list; each query's
        ``"data"``, ``"colnames"``, ``"coltypes"`` and ``"rowcount"`` are
        replaced in place with the pivoted values.
    :param form_data: chart form data; must contain a ``"metrics"`` key,
        and may contain ``"groupby"``, ``"columns"``, ``"granularity"``,
        ``"pandas_aggfunc"``, ``"transpose_pivot"``, ``"pivot_margins"``
        and ``"combine_metric"``.
    :return: the mutated ``result`` payload.
    """
    # Normalize once, outside the loop (was re-evaluated per query).
    form_data = form_data or {}

    for query in result["queries"]:
        df = pd.DataFrame(query["data"])

        # Drop the temporal column when no time grain is selected.
        if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
            del df[DTTM_ALIAS]

        metrics = [get_metric_name(m) for m in form_data["metrics"]]
        aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
        for metric in metrics:
            aggfunc = form_data.get("pandas_aggfunc") or "sum"
            if pd.api.types.is_numeric_dtype(df[metric]):
                if aggfunc == "sum":
                    # min_count=1 keeps an all-NaN group as NaN instead of 0.
                    aggfunc = lambda x: x.sum(min_count=1)  # noqa: E731
            elif aggfunc not in {"min", "max"}:
                # Non-numeric columns only support min/max; fall back to max.
                aggfunc = "max"
            aggfuncs[metric] = aggfunc

        groupby = form_data.get("groupby") or []
        columns = form_data.get("columns") or []
        if form_data.get("transpose_pivot"):
            groupby, columns = columns, groupby

        df = df.pivot_table(
            index=groupby,
            columns=columns,
            values=metrics,
            aggfunc=aggfuncs,
            margins=form_data.get("pivot_margins"),
        )

        # Re-order the columns adhering to the metric ordering.
        df = df[metrics]

        # Display metrics side by side with each column
        if form_data.get("combine_metric"):
            df = df.stack(0).unstack().reindex(level=-1, columns=metrics)

        # Flatten the MultiIndex column names into single strings.
        df.columns = [" ".join(column) for column in df.columns]

        # Re-arrange the pivoted frame into a list of row dicts.
        # NOTE(review): with more than one groupby column the index is a
        # MultiIndex and ``df.index.name`` is None, producing a ``None``
        # key — confirm callers only pass a single groupby here.
        rows = []
        for i in df.index:
            row = {col: df[col][i] for col in df.columns}
            row[df.index.name] = i
            rows.append(row)

        query["data"] = rows
        query["colnames"] = list(df.columns)
        query["coltypes"] = extract_dataframe_dtypes(df)
        query["rowcount"] = len(df.index)

    return result


# Registry mapping a chart's viz_type to the function that reproduces its
# frontend post-processing server-side; charts absent from this mapping are
# returned as-is.
post_processors: Dict[
    str, Callable[[Dict[Any, Any], Optional[Dict[str, Any]]], Dict[Any, Any]]
] = {
    "pivot_table": pivot_table,
}
6 changes: 5 additions & 1 deletion superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def _get_results(
ChartDataResultType.SAMPLES: _get_samples,
ChartDataResultType.FULL: _get_full,
ChartDataResultType.RESULTS: _get_results,
# for requests for post-processed data we return the full results,
# and post-process it later where we have the chart context, since
# post-processing is unique to each visualization type
ChartDataResultType.POST_PROCESSED: _get_full,
}


Expand All @@ -180,5 +184,5 @@ def get_query_results(
if result_func:
return result_func(query_context, query_obj, force_cached)
raise QueryObjectValidationError(
_("Invalid result type: %(result_type)", result_type=result_type)
_("Invalid result type: %(result_type)s", result_type=result_type)
)
1 change: 1 addition & 0 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class ChartDataResultType(str, Enum):
RESULTS = "results"
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"


class DatasourceDict(TypedDict):
Expand Down
2 changes: 2 additions & 0 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,7 @@ def test_chart_data_async_cached_sync_response(self):

class QueryContext:
result_format = ChartDataResultFormat.JSON
result_type = utils.ChartDataResultType.FULL

cmd_run_val = {
"query_context": QueryContext(),
Expand All @@ -1508,6 +1509,7 @@ class QueryContext:
ChartDataCommand, "run", return_value=cmd_run_val
) as patched_run:
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.FULL
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
Expand Down

0 comments on commit 1df76b6

Please sign in to comment.