Skip to content

Commit

Permalink
fix: improve normalization (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored May 27, 2024
1 parent 16085ef commit 04004e8
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ slow-test:
pytest -n auto tests

fast-test:
pytest -n auto -m "fast"
pytest -n auto tests/unit

local-test:
pytest -n auto -m "fast or local"
Expand Down
10 changes: 7 additions & 3 deletions sqlframe/base/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import flatten, is_iterable
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlframe.base.decorators import normalize
from sqlframe.base.types import DataType
from sqlframe.base.util import get_func_from_session
from sqlframe.base.util import get_func_from_session, quote_preserving_alias_or_name

if t.TYPE_CHECKING:
from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
Expand Down Expand Up @@ -237,7 +239,7 @@ def column_expression(self) -> t.Union[exp.Column, exp.Literal]:

@property
def alias_or_name(self) -> str:
return self.expression.alias_or_name
return quote_preserving_alias_or_name(self.expression) # type: ignore

@classmethod
def ensure_literal(cls, value) -> Column:
Expand Down Expand Up @@ -266,7 +268,9 @@ def alias(self, name: str) -> Column:
from sqlframe.base.session import _BaseSession

dialect = _BaseSession().input_dialect
alias: exp.Expression = exp.parse_identifier(name, dialect=dialect)
alias: exp.Expression = normalize_identifiers(
exp.parse_identifier(name, dialect=dialect), dialect=dialect
)
new_expression = exp.Alias(
this=self.column_expression,
alias=alias.this if isinstance(alias, exp.Column) else alias,
Expand Down
11 changes: 8 additions & 3 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer.qualify_columns import quote_identifiers

from sqlframe.base.decorators import normalize
from sqlframe.base.operations import Operation, operation
from sqlframe.base.transforms import replace_id_value
from sqlframe.base.util import (
get_func_from_session,
get_tables_from_expression_with_join,
quote_preserving_alias_or_name,
)

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -410,7 +412,7 @@ def _get_outer_select_columns(cls, item: exp.Expression) -> t.List[Column]:

outer_select = item.find(exp.Select)
if outer_select:
return [col(x.alias_or_name) for x in outer_select.expressions]
return [col(quote_preserving_alias_or_name(x)) for x in outer_select.expressions]
return []

def _create_hash_from_expression(self, expression: exp.Expression) -> str:
Expand Down Expand Up @@ -505,7 +507,9 @@ def sql(
self.session.catalog.add_table(
cache_table_name,
{
expression.alias_or_name: expression.type.sql(dialect=dialect)
quote_preserving_alias_or_name(expression): expression.type.sql(
dialect=dialect
)
if expression.type
else "UNKNOWN"
for expression in select_expression.expressions
Expand Down Expand Up @@ -688,7 +692,7 @@ def join(
join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
self_columns = self._get_outer_select_columns(join_expression)
other_columns = self._get_outer_select_columns(other_df.expression)
join_columns = self._ensure_list_of_columns(on)
join_columns = self._ensure_and_normalize_cols(on)
# Determines the join clause and select columns to be used passed on what type of columns were provided for
# the join. The columns returned changes based on how the on expression is provided.
if how != "cross":
Expand Down Expand Up @@ -1324,6 +1328,7 @@ def toPandas(self) -> pd.DataFrame:
assert sqls[-1] is not None
return self.session._fetchdf(sqls[-1])

@normalize("name")
def createOrReplaceTempView(self, name: str) -> None:
self.session.temp_views[name] = self.copy()._convert_leaf_to_cte()

Expand Down
32 changes: 17 additions & 15 deletions sqlframe/base/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,33 @@
if t.TYPE_CHECKING:
from sqlframe.base.catalog import _BaseCatalog

CALLING_CLASS = t.TypeVar("CALLING_CLASS")

def normalize(normalize_kwargs: t.List[str]) -> t.Callable[[t.Callable], t.Callable]:

def normalize(normalize_kwargs: t.Union[str, t.List[str]]) -> t.Callable[[t.Callable], t.Callable]:
"""
Decorator used around DataFrame methods to indicate what type of operation is being performed from the
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
included with the previous operation.
Ex: After a user does a join we want to allow them to select which columns for the different
tables that they want to carry through to the following operation. If we put that join in
a CTE preemptively then the user would not have a chance to select which column they want
in cases where there is overlap in names.
Decorator used to normalize identifiers in the kwargs of a method.
"""

def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseCatalog, *args, **kwargs) -> _BaseCatalog:
def wrapper(self: CALLING_CLASS, *args, **kwargs) -> CALLING_CLASS:
from sqlframe.base.session import _BaseSession

input_dialect = _BaseSession().input_dialect
kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
for kwarg in normalize_kwargs:
for kwarg in ensure_list(normalize_kwargs):
if kwarg in kwargs:
value = kwargs.get(kwarg)
if value:
expression = parse_one(value, dialect=self.session.input_dialect)
kwargs[kwarg] = normalize_identifiers(
expression, self.session.input_dialect
).sql(dialect=self.session.input_dialect)
expression = (
parse_one(value, dialect=input_dialect)
if isinstance(value, str)
else value
)
kwargs[kwarg] = normalize_identifiers(expression, input_dialect).sql(
dialect=input_dialect
)
return func(self, **kwargs)

wrapper.__wrapped__ = func # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions sqlframe/base/readerwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from sqlglot import exp
from sqlglot.helper import object_to_dict

from sqlframe.base.decorators import normalize

if sys.version_info >= (3, 11):
from typing import Self
else:
Expand Down Expand Up @@ -39,6 +41,7 @@ def __init__(self, spark: SESSION):
def session(self) -> SESSION:
return self._session

@normalize("tableName")
def table(self, tableName: str) -> DF:
if df := self.session.temp_views.get(tableName):
return df
Expand Down
1 change: 1 addition & 0 deletions sqlframe/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def _optimize(
self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
) -> exp.Expression:
dialect = dialect or self.output_dialect
normalize_identifiers(expression, dialect=self.input_dialect)
quote_identifiers_func(expression, dialect=dialect)
return optimize(expression, dialect=dialect, schema=self.catalog._schema)

Expand Down
21 changes: 20 additions & 1 deletion sqlframe/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
"""
from sqlframe.base import types

columns = list([x.replace("?column?", "unknown_column") for x in pandas_df.columns])
columns = list(
[
x.replace("?column?", f"unknown_column_{i}").replace("NULL", f"unknown_column_{i}")
for i, x in enumerate(pandas_df.columns)
]
)
d_types = list(pandas_df.dtypes)
p_schema = types.StructType(
[
Expand Down Expand Up @@ -249,3 +254,17 @@ def verify_pandas_installed():
raise ImportError(
"""Pandas is required for this functionality. `pip install "sqlframe[pandas]"` (also include your engine if needed) to install pandas."""
)


def quote_preserving_alias_or_name(col: t.Union[exp.Column, exp.Alias]) -> str:
from sqlframe.base.session import _BaseSession

if isinstance(col, exp.Alias):
col = col.args["alias"]
if isinstance(col, exp.Column):
col = col.copy()
col.set("table", None)
if isinstance(col, (exp.Identifier, exp.Column)):
return col.sql(dialect=_BaseSession().input_dialect)
# We may get things like `Null()` expression or maybe literals so we just return the alias or name in those cases
return col.alias_or_name
4 changes: 3 additions & 1 deletion sqlframe/snowflake/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def get_columns(self, table_name: str) -> t.Dict[str, exp.DataType]:
sql = f"SHOW COLUMNS IN TABLE {table.sql(dialect=self.session.input_dialect)}"
results = self.session._fetch_rows(sql)
return {
row["column_name"]: exp.DataType.build(
exp.column(row["column_name"], quoted=True).sql(
dialect=self.session.input_dialect
): exp.DataType.build(
json.loads(row["data_type"])["type"], dialect=self.session.input_dialect, udt=True
)
for row in results
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/engines/snowflake/test_snowflake_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ def test_session_from_config(cleanup_connector: SnowflakeConnection):
session = SnowflakeSession.builder.config("sqlframe.conn", cleanup_connector).getOrCreate()
columns = session.catalog.get_columns("sqlframe.db1.test_table")
assert columns == {
"COLA": exp.DataType.build("DECIMAL", dialect=session.output_dialect),
"COLB": exp.DataType.build("TEXT", dialect=session.output_dialect),
'"COLA"': exp.DataType.build("DECIMAL", dialect=session.output_dialect),
'"COLB"': exp.DataType.build("TEXT", dialect=session.output_dialect),
}
2 changes: 1 addition & 1 deletion tests/unit/standalone/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_sql_with_aggs(standalone_session: StandaloneSession, compare_sql: t.Cal
df = standalone_session.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
compare_sql(
df,
"WITH t26614157 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t38889420 AS (SELECT cola, colb FROM t26614157) SELECT cola, SUM(colb) FROM t38889420 GROUP BY cola",
"WITH t26614157 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t12367852 AS (SELECT `cola`, `colb` FROM t26614157) SELECT cola, SUM(colb) FROM t12367852 GROUP BY cola",
pretty=False,
optimize=False,
)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing as t

import pytest
from sqlglot import exp, parse_one

from sqlframe.base.util import quote_preserving_alias_or_name


@pytest.mark.parametrize(
"expression, expected",
[
("a", "a"),
("a AS b", "b"),
("`a`", "`a`"),
("`a` AS b", "b"),
("`a` AS `b`", "`b`"),
("`aB`", "`aB`"),
("`aB` AS c", "c"),
("`aB` AS `c`", "`c`"),
("`aB` AS `Cd`", "`Cd`"),
# We assume inputs have been normalized so `Cd` is returned as is instead of normalized `cd`
("`aB` AS Cd", "Cd"),
],
)
def test_quote_preserving_alias_or_name(expression: t.Union[exp.Column, exp.Alias], expected: str):
assert quote_preserving_alias_or_name(parse_one(expression, dialect="bigquery")) == expected # type: ignore

0 comments on commit 04004e8

Please sign in to comment.