diff --git a/superset/jinja_context.py b/superset/jinja_context.py index b0e29505a0d0b..713bee777c051 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -27,7 +27,8 @@ import dateutil from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ -from jinja2 import DebugUndefined, Environment +from jinja2 import DebugUndefined, Environment, nodes +from jinja2.nodes import Call, Node from jinja2.sandbox import SandboxedEnvironment from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql.expression import bindparam @@ -888,6 +889,26 @@ def get_dataset_id_from_context(metric_key: str) -> int: raise SupersetTemplateException(exc_message) +def has_metric_macro(template_string: str, env: Environment) -> bool: + """ + Checks if a template string contains a metric macro. + + >>> has_metric_macro("{{ metric('my_metric') }}") + True + + """ + ast = env.parse(template_string) + + def visit_node(node: Node) -> bool: + return ( + isinstance(node, Call) + and isinstance(node.node, nodes.Name) + and node.node.name == "metric" + ) or any(visit_node(child) for child in node.iter_child_nodes()) + + return visit_node(ast) + + def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: """ Given a metric key, returns its syntax. @@ -908,16 +929,32 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: dataset = DatasetDAO.find_by_id(dataset_id) if not dataset: raise DatasetNotFoundError(f"Dataset ID {dataset_id} not found.") + metrics: dict[str, str] = { metric.metric_name: metric.expression for metric in dataset.metrics } - dataset_name = dataset.table_name - if metric := metrics.get(metric_key): - return metric - raise SupersetTemplateException( - _( - "Metric ``%(metric_name)s`` not found in %(dataset_name)s.", - metric_name=metric_key, - dataset_name=dataset_name, + if metric_key not in metrics: + raise SupersetTemplateException( + _( + "Metric ``%(metric_name)s`` not found in %(dataset_name)s.", + metric_name=metric_key, + dataset_name=dataset.table_name, + ) ) - ) + + definition = metrics[metric_key] + + env = SandboxedEnvironment(undefined=DebugUndefined) + context = {"metric": partial(safe_proxy, metric_macro)} + while has_metric_macro(definition, env): + old_definition = definition + template = env.from_string(definition) + try: + definition = template.render(context) + except RecursionError as ex: + raise SupersetTemplateException("Cyclic metric macro detected") from ex + + if definition == old_definition: + break + + return definition diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index c17c066b9dbdc..faba80812870a 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -544,6 +544,99 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None: mock_get_form_data.assert_not_called() +def test_metric_macro_recursive(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is recursive. + """ + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="a", expression="COUNT(*)"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + assert metric_macro("c", 1) == "COUNT(*)" + + +def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is compound. + """ + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="a", expression="SUM(*)"), + SqlMetric(metric_name="b", expression="COUNT(*)"), + SqlMetric( + metric_name="c", + expression="{{ metric('a') }} / {{ metric('b') }}", + ), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + assert metric_macro("c", 1) == "SUM(*) / COUNT(*)" + + +def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is cyclic. + + In this case it should stop, and not go into an infinite loop. + """ + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="a", expression="{{ metric('c') }}"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("c", 1) + assert str(excinfo.value) == "Cyclic metric macro detected" + + +def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is cyclic. + + In this case it should stop, and not go into an infinite loop. + """ + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="a", expression="{{ metric('a') }}"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, + ) + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("a", 1) + assert str(excinfo.value) == "Cyclic metric macro detected" + + def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None: """ Test the ``metric_macro`` when passing a dataset ID and an invalid key.