From bc852e265366288be1ee7f957e73e127d9a7b762 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 2 Nov 2021 15:29:18 -0700 Subject: [PATCH] feat: improve logic in is_select --- superset/sql_parse.py | 19 +++++++++-- tests/unit_tests/sql_parse_tests.py | 49 ++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8b173d9b82543..bd1411f43bb4c 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -29,7 +29,7 @@ Token, TokenList, ) -from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace +from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace from sqlparse.utils import imt RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} @@ -133,7 +133,22 @@ def limit(self) -> Optional[int]: def is_select(self) -> bool: # make sure we strip comments; prevents a bug with coments in the CTE parsed = sqlparse.parse(self.strip_comments()) - return parsed[0].get_type() == "SELECT" + if parsed[0].get_type() == "SELECT": + return True + + if parsed[0].get_type() != "UNKNOWN": + return False + + # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed, + # and no DDL is allowed + if any(token.ttype == DDL for token in parsed[0]) or any( + token.ttype == DML and token.value != "SELECT" for token in parsed[0] + ): + return False + + return any( + token.ttype == DML and token.value == "SELECT" for token in parsed[0] + ) def is_valid_ctas(self) -> bool: parsed = sqlparse.parse(self.strip_comments()) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index a6e41e131cdc6..927ca1c096abf 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -15,10 +15,17 @@ # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name + +import sqlparse + from superset.sql_parse import ParsedQuery -def test_cte_with_comments(): +def test_cte_with_comments_is_select(): + """ + Some CTES with comments are not correctly identified as SELECTS. + """ sql = ParsedQuery( """WITH blah AS (SELECT * FROM core_dev.manager_team), @@ -44,3 +51,43 @@ def test_cte_with_comments(): INNER JOIN blah2 ON blah2.team_id = blah.team_id""" ) assert sql.is_select() + + +def test_cte_is_select(): + """ + Some CTEs are not correctly identified as SELECTS. + """ + # `AS(` gets parsed as a function + sql = ParsedQuery( + """WITH foo AS( +SELECT + FLOOR(__time TO WEEK) AS "week", + name, + COUNT(DISTINCT user_id) AS "unique_users" +FROM "druid"."my_table" +GROUP BY 1,2 +) +SELECT + f.week, + f.name, + f.unique_users +FROM foo f""" + ) + assert sql.is_select() + + +def test_unknown_select(): + """ + Test that `is_select` works when sqlparse fails to identify the type. + """ + sql = "WITH foo AS(SELECT 1) SELECT 1" + assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN" + assert ParsedQuery(sql).is_select() + + sql = "WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)" + assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN" + assert not ParsedQuery(sql).is_select() + + sql = "WITH foo AS(SELECT 1) DELETE FROM my_table" + assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN" + assert not ParsedQuery(sql).is_select()