Patch summary (free text ahead of the first "diff --git" header).

superset/security/manager.py
  * Adds freeze_metric(), which serializes a metric with
    json.dumps(..., sort_keys=True) so metrics can be compared as set members.
  * query_context_modified() now returns False instead of True when form_data
    or the stored chart is None (the in-code comment labels this case
    "native filter requests").
  * Every frozenset(metric.items()) comparison key is replaced by
    freeze_metric(metric).

tests/unit_tests/security/manager_test.py
  * Imports query_context_modified and adds four tests: unmodified request,
    tampered metrics, native filter request (slice_ is None), and a mixed
    chart whose metric dicts contain a nested "column" dict and differ only
    in key order.

NOTE(review): after this change a request with no form_data or no stored
chart skips the metric comparison entirely. Confirm that this is acceptable
for guest users before merging.

diff --git a/superset/security/manager.py b/superset/security/manager.py
index 94719cb524e70..a5324314334ac 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -68,6 +68,7 @@
     GuestTokenUser,
     GuestUser,
 )
+from superset.superset_typing import Metric
 from superset.utils.core import (
     DatasourceName,
     DatasourceType,
@@ -144,6 +145,13 @@ def __init__(self, **kwargs: Any) -> None:
 RoleModelView.related_views = []
 
 
+def freeze_metric(metric: Metric) -> str:
+    """
+    Used to compare metric sets.
+    """
+    return json.dumps(metric, sort_keys=True)
+
+
 def query_context_modified(query_context: "QueryContext") -> bool:
     """
     Check if a query context has been modified.
@@ -154,9 +162,9 @@ def query_context_modified(query_context: "QueryContext") -> bool:
     form_data = query_context.form_data
     stored_chart = query_context.slice_
 
-    # sanity checks
+    # native filter requests
     if form_data is None or stored_chart is None:
-        return True
+        return False
 
     # cannot request a different chart
     if form_data.get("slice_id") != stored_chart.id:
@@ -164,11 +172,10 @@ def query_context_modified(query_context: "QueryContext") -> bool:
 
     # compare form_data
     requested_metrics = {
-        frozenset(metric.items()) if isinstance(metric, dict) else metric
-        for metric in form_data.get("metrics") or []
+        freeze_metric(metric) for metric in form_data.get("metrics") or []
     }
     stored_metrics = {
-        frozenset(metric.items()) if isinstance(metric, dict) else metric
+        freeze_metric(metric)
         for metric in stored_chart.params_dict.get("metrics") or []
     }
     if not requested_metrics.issubset(stored_metrics):
@@ -176,7 +183,7 @@ def query_context_modified(query_context: "QueryContext") -> bool:
 
     # compare queries in query_context
     queries_metrics = {
-        frozenset(metric.items()) if isinstance(metric, dict) else metric
+        freeze_metric(metric)
         for query in query_context.queries
         for metric in query.metrics or []
     }
@@ -185,10 +192,7 @@ def query_context_modified(query_context: "QueryContext") -> bool:
         stored_query_context = json.loads(cast(str, stored_chart.query_context))
         for query in stored_query_context.get("queries") or []:
             stored_metrics.update(
-                {
-                    frozenset(metric.items()) if isinstance(metric, dict) else metric
-                    for metric in query.get("metrics") or []
-                }
+                {freeze_metric(metric) for metric in query.get("metrics") or []}
             )
 
     return not queries_metrics.issubset(stored_metrics)
diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py
index 22ec0dda4a7a4..7ed32b0abbc3e 100644
--- a/tests/unit_tests/security/manager_test.py
+++ b/tests/unit_tests/security/manager_test.py
@@ -26,7 +26,7 @@
 from superset.exceptions import SupersetSecurityException
 from superset.extensions import appbuilder
 from superset.models.slice import Slice
-from superset.security.manager import SupersetSecurityManager
+from superset.security.manager import query_context_modified, SupersetSecurityManager
 from superset.sql_parse import Table
 from superset.superset_typing import AdhocMetric
 from superset.utils.core import override_user
@@ -414,3 +414,120 @@ def test_raise_for_access_chart_owner(
         sm.raise_for_access(
             chart=slice,
         )
+
+
+def test_query_context_modified(
+    mocker: MockFixture,
+    stored_metrics: list[AdhocMetric],
+) -> None:
+    """
+    Test the `query_context_modified` function.
+
+    The function is used to ensure guest users are not modifying the request payload on
+    embedded dashboard, preventing users from modifying it to access metrics different
+    from the ones stored in dashboard charts.
+    """
+    query_context = mocker.MagicMock()
+    query_context.slice_.id = 42
+    query_context.slice_.query_context = None
+    query_context.slice_.params_dict = {
+        "metrics": stored_metrics,
+    }
+
+    query_context.form_data = {
+        "slice_id": 42,
+        "metrics": stored_metrics,
+    }
+    query_context.queries = [QueryObject(metrics=stored_metrics)]  # type: ignore
+    assert not query_context_modified(query_context)
+
+
+def test_query_context_modified_tampered(
+    mocker: MockFixture,
+    stored_metrics: list[AdhocMetric],
+) -> None:
+    """
+    Test the `query_context_modified` function when the request is tampered with.
+
+    The function is used to ensure guest users are not modifying the request payload on
+    embedded dashboard, preventing users from modifying it to access metrics different
+    from the ones stored in dashboard charts.
+    """
+    query_context = mocker.MagicMock()
+    query_context.slice_.id = 42
+    query_context.slice_.query_context = None
+    query_context.slice_.params_dict = {
+        "metrics": stored_metrics,
+    }
+
+    tampered_metrics = [
+        {
+            "column": None,
+            "expressionType": "SQL",
+            "hasCustomLabel": False,
+            "label": "COUNT(*) + 2",
+            "sqlExpression": "COUNT(*) + 2",
+        }
+    ]
+
+    query_context.form_data = {
+        "slice_id": 42,
+        "metrics": tampered_metrics,
+    }
+    query_context.queries = [QueryObject(metrics=tampered_metrics)]  # type: ignore
+    assert query_context_modified(query_context)
+
+
+def test_query_context_modified_native_filter(mocker: MockFixture) -> None:
+    """
+    Test the `query_context_modified` function with a native filter request.
+
+    A native filter request has no chart (slice) associated with it.
+    """
+    query_context = mocker.MagicMock()
+    query_context.slice_ = None
+
+    assert not query_context_modified(query_context)
+
+
+def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None:
+    """
+    Test the `query_context_modified` function for a mixed chart request.
+
+    The metrics in the mixed chart are a nested dictionary (due to `columns`), and need
+    to be serialized to JSON with the keys sorted in order to compare the request
+    metrics with the chart metrics.
+    """
+    stored_metrics = [
+        {
+            "optionName": "metric_vgops097wej_g8uff99zhk7",
+            "label": "AVG(num)",
+            "expressionType": "SIMPLE",
+            "column": {"column_name": "num", "type": "BIGINT(20)"},
+            "aggregate": "AVG",
+        }
+    ]
+    # different order (remember, dicts have order!)
+    requested_metrics = [
+        {
+            "aggregate": "AVG",
+            "column": {"column_name": "num", "type": "BIGINT(20)"},
+            "expressionType": "SIMPLE",
+            "label": "AVG(num)",
+            "optionName": "metric_vgops097wej_g8uff99zhk7",
+        }
+    ]
+
+    query_context = mocker.MagicMock()
+    query_context.slice_.id = 42
+    query_context.slice_.query_context = None
+    query_context.slice_.params_dict = {
+        "metrics": stored_metrics,
+    }
+
+    query_context.form_data = {
+        "slice_id": 42,
+        "metrics": requested_metrics,
+    }
+    query_context.queries = [QueryObject(metrics=requested_metrics)]  # type: ignore
+    assert not query_context_modified(query_context)