Skip to content

Commit

Permalink
fix timestamp after sqlglot change
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed Dec 16, 2024
1 parent c222b60 commit 7b9bda4
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 176 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"pyarrow>=10,<19",
"pyspark>=2,<3.6",
"pytest>=8.2.0,<8.4",
"pytest-forked",
"pytest-postgresql>=6,<7",
"pytest-xdist>=3.6,<3.7",
"pre-commit>=3.5;python_version=='3.8'",
Expand Down
33 changes: 33 additions & 0 deletions sqlframe/base/function_alternatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,39 @@ def first_always_ignore_nulls(col: ColumnOrName, ignorenulls: t.Optional[bool] =
return first(col)


def to_timestamp_with_time_zone(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
from sqlframe.base.session import _BaseSession

if format is not None:
return Column.invoke_expression_over_column(
col, expression.StrToTime, format=_BaseSession().format_time(format)
)

return Column.ensure_col(col).cast("timestamp with time zone", dialect="postgres")


def to_timestamp_tz(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
from sqlframe.base.session import _BaseSession

if format is not None:
return Column.invoke_expression_over_column(
col, expression.StrToTime, format=_BaseSession().format_time(format)
)

return Column.ensure_col(col).cast("timestamptz", dialect="duckdb")


def to_timestamp_just_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
from sqlframe.base.session import _BaseSession

if format is not None:
return Column.invoke_expression_over_column(
col, expression.StrToTime, format=_BaseSession().format_time(format)
)

return Column.ensure_col(col).cast("datetime", dialect="bigquery")


def bitwise_not_from_bitnot(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "BITNOT")

Expand Down
2 changes: 1 addition & 1 deletion sqlframe/base/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
col, expression.StrToTime, format=_BaseSession().format_time(format)
)

return Column.ensure_col(col).cast("timestamp")
return Column.ensure_col(col).cast("timestampltz")


@meta()
Expand Down
2 changes: 2 additions & 0 deletions sqlframe/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ def _to_value(cls, value: t.Any) -> t.Any:
return cls._to_row(list(value.keys()), list(value.values()))
elif isinstance(value, (list, set, tuple)) and value:
return [cls._to_value(x) for x in value]
elif isinstance(value, datetime.datetime):
return value.replace(tzinfo=None)
return value

@classmethod
Expand Down
1 change: 1 addition & 0 deletions sqlframe/bigquery/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
_is_string_using_typeof_string as _is_string,
array_append_using_array_cat as array_append,
endswith_with_underscore as endswith,
to_timestamp_just_timestamp as to_timestamp,
)


Expand Down
1 change: 1 addition & 0 deletions sqlframe/duckdb/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@
endswith_with_underscore as endswith,
last_day_with_cast as last_day,
regexp_replace_global_option as regexp_replace,
to_timestamp_tz as to_timestamp,
)
1 change: 1 addition & 0 deletions sqlframe/duckdb/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ from sqlframe.base.function_alternatives import ( # noqa
try_element_at_zero_based as try_element_at,
to_unix_timestamp_include_default_format as to_unix_timestamp,
regexp_replace_global_option as regexp_replace,
to_timestamp_tz as to_timestamp,
)
from sqlframe.base.functions import (
abs as abs,
Expand Down
1 change: 1 addition & 0 deletions sqlframe/postgres/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@
endswith_using_like as endswith,
last_day_with_cast as last_day,
regexp_replace_global_option as regexp_replace,
to_timestamp_with_time_zone as to_timestamp,
)
3 changes: 2 additions & 1 deletion tests/common_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def pyspark_session(tmp_path_factory, gen_tpcds: t.List[Path]) -> PySparkSession
.config("spark.sql.warehouse.dir", data_dir)
.config("spark.driver.extraJavaOptions", f"-Dderby.system.home={derby_dir}")
.config("spark.sql.shuffle.partitions", 1)
.config("spark.sql.session.timeZone", "America/Los_Angeles")
.config("spark.sql.session.timeZone", "UTC")
.master("local[1]")
.appName("Unit-tests")
.getOrCreate()
Expand Down Expand Up @@ -225,6 +225,7 @@ def snowflake_connection() -> SnowflakeConnection:
@pytest.fixture
def snowflake_session(snowflake_connection: SnowflakeConnection) -> SnowflakeSession:
session = SnowflakeSession(snowflake_connection)
session._execute("ALTER SESSION SET TIMEZONE = 'UTC'")
session._execute("CREATE SCHEMA IF NOT EXISTS db1")
session._execute("CREATE TABLE IF NOT EXISTS db1.table1 (id INTEGER, name VARCHAR(100))")
session._execute(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def pytest_collection_modifyitems(items, *args, **kwargs):
def set_tz():
import os

os.environ["TZ"] = "US/Pacific"
os.environ["TZ"] = "UTC"
time.tzset()
yield
del os.environ["TZ"]
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/engines/snowflake/test_snowflake_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_print_schema_basic(snowflake_employee: SnowflakeDataFrame, capsys):
== """
root
|-- employee_id: decimal(38, 0) (nullable = true)
|-- fname: string (nullable = true)
|-- lname: string (nullable = true)
|-- fname: varchar(16777216) (nullable = true)
|-- lname: varchar(16777216) (nullable = true)
|-- age: decimal(38, 0) (nullable = true)
|-- store_id: decimal(38, 0) (nullable = true)""".strip()
)
Expand All @@ -70,9 +70,9 @@ def test_print_schema_nested(snowflake_datatypes: SnowflakeDataFrame, capsys):
root
|-- bigint_col: decimal(38, 0) (nullable = true)
|-- double_col: float (nullable = true)
|-- string_col: string (nullable = true)
|-- map_string_bigint__col: map<string, decimal(38, 0)> (nullable = true)
| |-- key: string (nullable = true)
|-- string_col: varchar(16777216) (nullable = true)
|-- map_string_bigint__col: map<varchar(16777216), decimal(38, 0)> (nullable = true)
| |-- key: varchar(16777216) (nullable = true)
| |-- value: decimal(38, 0) (nullable = true)
|-- array_struct_a_bigint_b_bigint__: array<object<a decimal(38, 0), b decimal(38, 0)>> (nullable = true)
| |-- element: object<a decimal(38, 0), b decimal(38, 0)> (nullable = true)
Expand All @@ -83,7 +83,7 @@ def test_print_schema_nested(snowflake_datatypes: SnowflakeDataFrame, capsys):
|-- struct_a_bigint__col: object<a decimal(38, 0)> (nullable = true)
| |-- a: decimal(38, 0) (nullable = true)
|-- date_col: date (nullable = true)
|-- timestamp_col: timestamp_ntz (nullable = true)
|-- timestamp_col: timestamp (nullable = true)
|-- timestamptz_col: timestamp (nullable = true)
|-- boolean_col: boolean (nullable = true)""".strip()
)
Expand All @@ -96,9 +96,9 @@ def test_schema(snowflake_employee: SnowflakeDataFrame):
assert struct_fields[0].name == "employee_id"
assert struct_fields[0].dataType == types.DecimalType(38, 0)
assert struct_fields[1].name == "fname"
assert struct_fields[1].dataType == types.StringType()
assert struct_fields[1].dataType == types.VarcharType(16777216)
assert struct_fields[2].name == "lname"
assert struct_fields[2].dataType == types.StringType()
assert struct_fields[2].dataType == types.VarcharType(16777216)
assert struct_fields[3].name == "age"
assert struct_fields[3].dataType == types.DecimalType(38, 0)
assert struct_fields[4].name == "store_id"
Expand All @@ -114,10 +114,10 @@ def test_schema_nested(snowflake_datatypes: SnowflakeDataFrame):
assert struct_fields[1].name == "double_col"
assert struct_fields[1].dataType == types.FloatType()
assert struct_fields[2].name == "string_col"
assert struct_fields[2].dataType == types.StringType()
assert struct_fields[2].dataType == types.VarcharType(16777216)
assert struct_fields[3].name == "map_string_bigint__col"
assert struct_fields[3].dataType == types.MapType(
types.StringType(),
types.VarcharType(16777216),
types.DecimalType(38, 0),
)
assert struct_fields[4].name == "array_struct_a_bigint_b_bigint__"
Expand Down
Loading

0 comments on commit 7b9bda4

Please sign in to comment.