Skip to content

Commit

Permalink
Improve table matching
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 11, 2022
1 parent e816857 commit 00598fa
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
15 changes: 10 additions & 5 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)


def matches_table_name(token: Token, table: str) -> bool:
def matches_table_name(candidate: Token, table: str) -> bool:
"""
Returns if the token represents a reference to the table.
Expand All @@ -538,19 +538,24 @@ def matches_table_name(token: Token, table: str) -> bool:
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
"""
candidate = token.value
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
if left != right:
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False

return True


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
"""
Update a statement inpalce applying an RLS associated with a given table.
Update a statement inplace applying an RLS associated with a given table.
"""
# make sure the identifier has the table name
add_table_name(rls, table)
Expand Down
21 changes: 19 additions & 2 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
add_table_name,
has_table_query,
insert_rls,
matches_table_name,
ParsedQuery,
strip_comments_from_sql,
Table,
Expand Down Expand Up @@ -1206,6 +1207,7 @@ def test_sqlparse_issue_652():
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
Expand Down Expand Up @@ -1393,7 +1395,7 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(sql, table, rls, expected) -> None:
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
Expand All @@ -1411,7 +1413,22 @@ def test_insert_rls(sql, table, rls, expected) -> None:
("false", "users", "false"),
],
)
def test_add_table_name(rls, table, expected) -> None:
def test_add_table_name(rls: str, table: str, expected: str) -> None:
condition = sqlparse.parse(rls)[0]
add_table_name(condition, table)
assert str(condition) == expected


@pytest.mark.parametrize(
"candidate,table,expected",
[
("table", "table", True),
("schema.table", "table", True),
("table", "schema.table", True),
('schema."my table"', '"my table"', True),
('schema."my.table"', '"my.table"', True),
],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
token = sqlparse.parse(candidate)[0].tokens[0]
assert matches_table_name(token, table) == expected

0 comments on commit 00598fa

Please sign in to comment.