From 80e777823bc4b2f71acec0fd2f0507e82f64dd22 Mon Sep 17 00:00:00 2001 From: Sumedh Sakdeo <773250+sumedhsakdeo@users.noreply.github.com> Date: Tue, 21 Aug 2018 13:45:42 -0700 Subject: [PATCH] Field names in big query can contain only alphanumeric and underscore (#5641) * Field names in big query can contain only alphanumeric and underscore * bad quote * better place for mutating labels * lint * bug fix thanks to mistercrunch * lint * lint again --- .gitignore | 1 + superset/db_engine_specs.py | 16 ++++++++++++++++ superset/viz.py | 8 ++++++-- tests/viz_tests.py | 16 +++++++++++++++- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 11929a9b7e828..66bdf28da8dbc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +*.swp yarn-error.log _modules superset/assets/coverage/* diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index fe408ce3433e6..13eb69502bf51 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -426,6 +426,10 @@ def align_df_col_names_with_form_data(df, fd): return df.rename(index=str, columns=rename_cols) + @staticmethod + def mutate_expression_label(label): + return label + class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -1414,6 +1418,18 @@ def fetch_data(cls, cursor, limit): data = [r.values() for r in data] return data + @staticmethod + def mutate_expression_label(label): + mutated_label = re.sub('[^\w]+', '_', label) + if not re.match('^[a-zA-Z_]+.*', mutated_label): + raise SupersetTemplateException('BigQuery field_name used is invalid {}, ' + 'should start with a letter or ' + 'underscore'.format(mutated_label)) + if len(mutated_label) > 128: + raise SupersetTemplateException('BigQuery field_name {}, should be atmost ' + '128 characters'.format(mutated_label)) + return mutated_label + @classmethod def _get_fields(cls, cols): """ diff --git a/superset/viz.py b/superset/viz.py index 4c2490c51866d..9113b038d04cb 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -108,7 +108,10 @@ def process_metrics(self): if not isinstance(val, list): val = [val] for o in val: - self.metric_dict[self.get_metric_label(o)] = o + label = self.get_metric_label(o) + if isinstance(o, dict): + o['label'] = label + self.metric_dict[label] = o # Cast to list needed to return serializable object in py3 self.all_metrics = list(self.metric_dict.values()) @@ -118,7 +121,8 @@ def get_metric_label(self, metric): if isinstance(metric, string_types): return metric if isinstance(metric, dict): - return metric.get('label') + return self.datasource.database.db_engine_spec.mutate_expression_label( + metric.get('label')) @staticmethod def handle_js_int_overflow(data): diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 3cc2f8677cc02..ccca0263afb11 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -122,6 +122,20 @@ def test_cache_timeout(self): class TableVizTestCase(unittest.TestCase): + + class DBEngineSpecMock: + @staticmethod + def mutate_expression_label(label): + return label + + class DatabaseMock: + def __init__(self): + self.db_engine_spec = TableVizTestCase.DBEngineSpecMock() + + class DatasourceMock: + def __init__(self): + self.database = TableVizTestCase.DatabaseMock() + def test_get_data_applies_percentage(self): form_data = { 'percent_metrics': [{ @@ -137,7 +151,7 @@ def test_get_data_applies_percentage(self): 'column': {'column_name': 'value1', 'type': 'DOUBLE'}, }, 'count', 'avg__C'], } - datasource = Mock() + datasource = TableVizTestCase.DatasourceMock() raw = {} raw['SUM(value1)'] = [15, 20, 25, 40] raw['avg__B'] = [10, 20, 5, 15]