diff --git a/Makefile b/Makefile
index b4fbd4c..bbb1cc8 100644
--- a/Makefile
+++ b/Makefile
@@ -22,6 +22,9 @@ duckdb-test:
 snowflake-test:
	pytest -n auto -m "snowflake"

+databricks-test:
+	pytest -n auto -m "databricks"
+
 style:
	pre-commit run --all-files
diff --git a/docs/databricks.md b/docs/databricks.md
index eff6a8b..2ab042a 100644
--- a/docs/databricks.md
+++ b/docs/databricks.md
@@ -1,5 +1,3 @@
-from test import auth_type
-
 # Databricks (In Development)

 ## Installation
diff --git a/sqlframe/base/function_alternatives.py b/sqlframe/base/function_alternatives.py
index 3c7bdad..43ac72f 100644
--- a/sqlframe/base/function_alternatives.py
+++ b/sqlframe/base/function_alternatives.py
@@ -1220,6 +1220,11 @@ def get_json_object_cast_object(col: ColumnOrName, path: str) -> Column:
     return get_json_object(col_func(col).cast("variant"), path)


+def get_json_object_using_function(col: ColumnOrName, path: str) -> Column:
+    lit = get_func_from_session("lit")
+    return Column.invoke_anonymous_function(col, "GET_JSON_OBJECT", lit(path))
+
+
 def create_map_with_cast(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     from sqlframe.base.functions import create_map
diff --git a/sqlframe/base/functions.py b/sqlframe/base/functions.py
index f6c52f2..6800b2e 100644
--- a/sqlframe/base/functions.py
+++ b/sqlframe/base/functions.py
@@ -2173,7 +2173,7 @@ def current_database() -> Column:
 current_schema = current_database


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def current_timezone() -> Column:
     return Column.invoke_anonymous_function(None, "current_timezone")
@@ -2261,7 +2261,7 @@ def get(col: ColumnOrName, index: t.Union[ColumnOrName, int]) -> Column:
     return Column.invoke_anonymous_function(col, "get", index)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def get_active_spark_context() -> SparkContext:
     """Raise RuntimeError if SparkContext is not initialized, otherwise, returns the active SparkContext."""
@@ -2778,7 +2778,7 @@ def isnotnull(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "isnotnull")


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def java_method(*cols: ColumnOrName) -> Column:
     """
     Calls a method with reflection.
@@ -3050,7 +3050,7 @@ def ln(col: ColumnOrName) -> Column:
     return Column.invoke_expression_over_column(col, expression.Ln)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def localtimestamp() -> Column:
     """
     Returns the current timestamp without time zone at the start of query evaluation
@@ -3080,7 +3080,7 @@ def localtimestamp() -> Column:
     return Column.invoke_anonymous_function(None, "localtimestamp")


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_dt_interval(
     days: t.Optional[ColumnOrName] = None,
     hours: t.Optional[ColumnOrName] = None,
@@ -3227,7 +3227,7 @@ def make_timestamp(
     )


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_timestamp_ltz(
     years: ColumnOrName,
     months: ColumnOrName,
@@ -3354,7 +3354,7 @@ def make_timestamp_ntz(
     )


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_ym_interval(
     years: t.Optional[ColumnOrName] = None,
     months: t.Optional[ColumnOrName] = None,
@@ -3922,7 +3922,7 @@ def printf(format: ColumnOrName, *cols: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(format, "printf", *cols)


-@meta(unsupported_engines=["*", "spark"])
+@meta(unsupported_engines=["*", "spark", "databricks"])
 def product(col: ColumnOrName) -> Column:
     """
     Aggregate function: returns the product of the values in a group.
@@ -3961,7 +3961,7 @@ def product(col: ColumnOrName) -> Column:
 reduce = aggregate


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def reflect(*cols: ColumnOrName) -> Column:
     """
     Calls a method with reflection.
@@ -5046,7 +5046,7 @@ def to_str(value: t.Any) -> t.Optional[str]:
     return str(value)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def to_timestamp_ltz(
     timestamp: ColumnOrName,
     format: t.Optional[ColumnOrName] = None,
diff --git a/sqlframe/databricks/catalog.py b/sqlframe/databricks/catalog.py
index 0c4a884..0347291 100644
--- a/sqlframe/databricks/catalog.py
+++ b/sqlframe/databricks/catalog.py
@@ -26,7 +26,6 @@ class DatabricksCatalog(
-    SetCurrentCatalogFromUseMixin["DatabricksSession", "DatabricksDataFrame"],
     GetCurrentCatalogFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
     GetCurrentDatabaseFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
     ListDatabasesFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame"],
@@ -38,6 +37,15 @@ class DatabricksCatalog(
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
     UPPERCASE_INFO_SCHEMA = True

+    def setCurrentCatalog(self, catalogName: str) -> None:
+        self.session._collect(
+            exp.Use(
+                kind=exp.Var(this=exp.to_identifier("CATALOG")),
+                this=exp.parse_identifier(catalogName, dialect=self.session.input_dialect),
+            ),
+            quote_identifiers=False,
+        )
+
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Function]:
@@ -106,7 +114,9 @@ def listFunctions(
         )
         functions = [
             Function(
-                name=normalize_string(x["function"], from_dialect="execution", to_dialect="output"),
+                name=normalize_string(
+                    x["function"].split(".")[-1], from_dialect="execution", to_dialect="output"
+                ),
                 catalog=normalize_string(
                     schema.catalog, from_dialect="execution", to_dialect="output"
                 ),
diff --git a/sqlframe/databricks/dataframe.py b/sqlframe/databricks/dataframe.py
index bf460c3..eba9704 100644
--- a/sqlframe/databricks/dataframe.py
+++ b/sqlframe/databricks/dataframe.py
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import logging
-import sys
 import typing as t

 from sqlframe.base.catalog import Column as CatalogColumn
@@ -52,7 +51,9 @@ def _typed_columns(self) -> t.List[CatalogColumn]:
             columns.append(
                 CatalogColumn(
                     name=normalize_string(
-                        row.col_name, from_dialect="execution", to_dialect="output"
+                        row.col_name,
+                        from_dialect="execution",
+                        to_dialect="output",
                     ),
                     dataType=normalize_string(
                         row.data_type,
diff --git a/sqlframe/databricks/functions.py b/sqlframe/databricks/functions.py
index 0a53063..2b613c8 100644
--- a/sqlframe/databricks/functions.py
+++ b/sqlframe/databricks/functions.py
@@ -19,4 +19,5 @@
     arrays_overlap_renamed as arrays_overlap,
     _is_string_using_typeof_string_lcase as _is_string,
     try_element_at_zero_based as try_element_at,
+    get_json_object_using_function as get_json_object,
 )
diff --git a/sqlframe/databricks/session.py b/sqlframe/databricks/session.py
index 2dd728f..70ed035 100644
--- a/sqlframe/databricks/session.py
+++ b/sqlframe/databricks/session.py
@@ -44,7 +44,20 @@ def __init__(
         from databricks import sql

         if not hasattr(self, "_conn"):
-            super().__init__(conn or sql.connect(server_hostname, http_path, access_token))
+            super().__init__(
+                conn or sql.connect(server_hostname, http_path, access_token, disable_pandas=True)
+            )
+
+    @classmethod
+    def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
+        if (
+            value
+            and isinstance(value, list)
+            and all(isinstance(item, tuple) for item in value)
+            and all(len(item) == 2 for item in value)
+        ):
+            return dict(value)
+        return None

     class Builder(_BaseSession.Builder):
         DEFAULT_EXECUTION_DIALECT = "databricks"
diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py
index e61dbbe..e5c8aba 100644
--- a/tests/common_fixtures.py
+++ b/tests/common_fixtures.py
@@ -12,6 +12,7 @@
 from sqlframe.base.session import _BaseSession
 from sqlframe.bigquery.session import BigQuerySession
+from sqlframe.databricks.session import DatabricksSession
 from sqlframe.duckdb.session import DuckDBSession
 from sqlframe.postgres.session import PostgresSession
 from sqlframe.redshift.session import RedshiftSession
@@ -22,6 +23,7 @@
 from sqlframe.standalone.session import StandaloneSession

 if t.TYPE_CHECKING:
+    from databricks.sql import Connection as DatabricksConnection
     from google.cloud.bigquery.dbapi.connection import (
         Connection as BigQueryConnection,
     )
@@ -231,6 +233,31 @@ def snowflake_session(snowflake_connection: SnowflakeConnection) -> SnowflakeSession:
     return session


+@pytest.fixture(scope="session")
+def databricks_connection() -> DatabricksConnection:
+    from databricks.sql import connect
+
+    conn = connect(
+        server_hostname=os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"],
+        http_path=os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"],
+        access_token=os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"],
+        auth_type="access_token",
+        catalog=os.environ["SQLFRAME_DATABRICKS_CATALOG"],
+        schema=os.environ["SQLFRAME_DATABRICKS_SCHEMA"],
+        _disable_pandas=True,
+    )
+    return conn
+
+
+@pytest.fixture
+def databricks_session(databricks_connection: DatabricksConnection) -> DatabricksSession:
+    session = DatabricksSession(databricks_connection)
+    session._execute("CREATE SCHEMA IF NOT EXISTS db1")
+    session._execute("CREATE TABLE IF NOT EXISTS db1.table1 (id INTEGER, name VARCHAR(100))")
+    session._execute("CREATE OR REPLACE FUNCTION db1.add(x INT, y INT) RETURNS INT RETURN x + y")
+    return session
+
+
 @pytest.fixture(scope="module")
 def _employee_data() -> EmployeeData:
     return [
diff --git a/tests/integration/engines/databricks/__init__.py
b/tests/integration/engines/databricks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/engines/databricks/test_databricks_catalog.py b/tests/integration/engines/databricks/test_databricks_catalog.py new file mode 100644 index 0000000..606acfc --- /dev/null +++ b/tests/integration/engines/databricks/test_databricks_catalog.py @@ -0,0 +1,338 @@ +import typing as t + +import pytest + +from sqlframe.base.catalog import CatalogMetadata, Column, Database, Function, Table +from sqlframe.databricks.session import DatabricksSession + +pytest_plugins = ["tests.integration.fixtures"] +pytestmark = [ + pytest.mark.databricks, + pytest.mark.xdist_group("databricks_tests"), +] + + +@pytest.fixture +def reset_catalog(databricks_session: DatabricksSession) -> t.Iterator[None]: + yield + databricks_session.catalog.setCurrentCatalog("sqlframe") + databricks_session.catalog.setCurrentDatabase("db1") + + +@pytest.fixture +def reset_database(databricks_session: DatabricksSession) -> t.Iterator[None]: + yield + databricks_session.catalog.setCurrentDatabase("db1") + + +def test_current_catalog(databricks_session: DatabricksSession): + assert databricks_session.catalog.currentCatalog() == "sqlframe" + + +def test_set_current_catalog(databricks_session: DatabricksSession, reset_catalog): + assert databricks_session.catalog.currentCatalog() == "sqlframe" + databricks_session.catalog.setCurrentCatalog("catalog1") + assert databricks_session.catalog.currentCatalog() == "catalog1" + + +def test_list_catalogs(databricks_session: DatabricksSession): + assert sorted(databricks_session.catalog.listCatalogs(), key=lambda x: x.name) == [ + CatalogMetadata(name="sqlframe", description=None) + ] + + +def test_current_database(databricks_session: DatabricksSession): + assert databricks_session.catalog.currentDatabase() == "db1" + + +def test_set_current_database(databricks_session: DatabricksSession, reset_database): + assert databricks_session.catalog.currentDatabase() == "db1" + databricks_session.catalog.setCurrentDatabase("default") + assert databricks_session.catalog.currentDatabase() == "default" + + +def test_list_databases(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listDatabases(), key=lambda x: (x.catalog, x.name) + ) == [ + Database(name="db1", catalog="sqlframe", description=None, locationUri=""), + Database(name="default", catalog="sqlframe", description=None, locationUri=""), + Database(name="information_schema", catalog="sqlframe", description=None, locationUri=""), + ] + + +def test_list_databases_pattern(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listDatabases("db*"), key=lambda x: (x.catalog, x.name) + ) == [ + Database(name="db1", catalog="sqlframe", description=None, locationUri=""), + ] + + +def test_get_database_no_match(databricks_session: DatabricksSession): + with pytest.raises(ValueError): + assert databricks_session.catalog.getDatabase("nonexistent") + + +def test_get_database_name_only(databricks_session: DatabricksSession): + assert databricks_session.catalog.getDatabase("db1") == Database( + name="db1", catalog="sqlframe", description=None, locationUri="" + ) + + +def test_get_database_name_and_catalog(databricks_session: DatabricksSession): + assert databricks_session.catalog.getDatabase("sqlframe.db1") == Database( + name="db1", catalog="sqlframe", description=None, locationUri="" + ) + + +def test_database_exists_does_exist(databricks_session: DatabricksSession): + assert 
databricks_session.catalog.databaseExists("db1") is True + + +def test_database_exists_does_not_exist(databricks_session: DatabricksSession): + assert databricks_session.catalog.databaseExists("nonexistent") is False + + +def test_list_tables_no_args(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listTables(), key=lambda x: (x.catalog, x.database, x.name) + ) == [ + Table( + name="table1", + catalog="sqlframe", + namespace=["db1"], + description=None, + tableType="MANAGED", + isTemporary=False, + ) + ] + + +def test_list_tables_db_no_catalog(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listTables("db1"), key=lambda x: (x.catalog, x.database, x.name) + ) == [ + Table( + name="table1", + catalog="sqlframe", + namespace=["db1"], + description=None, + tableType="MANAGED", + isTemporary=False, + ) + ] + + +def test_list_tables_db_and_catalog(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listTables("sqlframe.db1"), + key=lambda x: (x.catalog, x.database, x.name), + ) == [ + Table( + name="table1", + catalog="sqlframe", + namespace=["db1"], + description=None, + tableType="MANAGED", + isTemporary=False, + ) + ] + + +def test_list_tables_pattern(databricks_session: DatabricksSession): + assert Table( + name="table1", + catalog="sqlframe", + namespace=["db1"], + description=None, + tableType="MANAGED", + isTemporary=False, + ) in databricks_session.catalog.listTables(pattern="tab*") + + +def test_get_table(databricks_session: DatabricksSession): + assert databricks_session.catalog.getTable("sqlframe.db1.table1") == Table( + name="table1", + catalog="sqlframe", + namespace=["db1"], + description=None, + tableType="MANAGED", + isTemporary=False, + ) + + +def test_get_table_not_exists(databricks_session: DatabricksSession): + with pytest.raises(ValueError): + assert databricks_session.catalog.getTable("dev.db1.nonexistent") + + +def test_list_functions(databricks_session: DatabricksSession): + assert databricks_session.catalog.listFunctions() == [ + Function( + name="add", + catalog="sqlframe", + namespace=["db1"], + description=None, + className="", + isTemporary=False, + ) + ] + + +def test_list_functions_pattern(databricks_session: DatabricksSession): + assert databricks_session.catalog.listFunctions(dbName="db1", pattern="ad*") == [ + Function( + name="add", + catalog="sqlframe", + namespace=["db1"], + description=None, + className="", + isTemporary=False, + ) + ] + + +def test_function_exists_does_exist(databricks_session: DatabricksSession): + assert databricks_session.catalog.functionExists("add", dbName="sqlframe.db1") is True + + +def test_function_exists_does_not_exist(databricks_session: DatabricksSession): + assert databricks_session.catalog.functionExists("nonexistent") is False + + +def test_get_function_exists(databricks_session: DatabricksSession): + assert databricks_session.catalog.getFunction("sqlframe.db1.add") == Function( + name="add", + catalog="sqlframe", + namespace=["db1"], + description=None, + className="", + isTemporary=False, + ) + + +def test_get_function_not_exists(databricks_session: DatabricksSession): + with pytest.raises(ValueError): + assert databricks_session.catalog.getFunction("sqlframe.db1.nonexistent") + + +def test_list_columns(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listColumns("sqlframe.db1.table1"), key=lambda x: x.name + ) == [ + Column( + name="id", + description=None, + dataType="INT", + 
nullable=True, + isPartition=False, + isBucket=False, + ), + Column( + name="name", + description=None, + dataType="VARCHAR(100)", + nullable=True, + isPartition=False, + isBucket=False, + ), + ] + + +def test_list_columns_use_db_name(databricks_session: DatabricksSession): + assert sorted( + databricks_session.catalog.listColumns("table1", dbName="sqlframe.db1"), + key=lambda x: x.name, + ) == [ + Column( + name="id", + description=None, + dataType="INT", + nullable=True, + isPartition=False, + isBucket=False, + ), + Column( + name="name", + description=None, + dataType="VARCHAR(100)", + nullable=True, + isPartition=False, + isBucket=False, + ), + ] + + +def test_table_exists_table_name_only(databricks_session: DatabricksSession): + assert databricks_session.catalog.tableExists("sqlframe.db1.table1") is True + + +def test_table_exists_table_name_and_db_name(databricks_session: DatabricksSession): + assert databricks_session.catalog.tableExists("table1", dbName="sqlframe.db1") is True + + +def test_table_not_exists(databricks_session: DatabricksSession): + assert databricks_session.catalog.tableExists("nonexistent") is False + + +def test_create_external_table(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.createExternalTable( + "table1", "sqlframe.default", "path/to/table" + ) + + +def test_create_table(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.createTable("table1", "sqlframe.default") + + +def test_drop_temp_view(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.dropTempView("view1") + + +def test_drop_global_temp_view(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.dropGlobalTempView("view1") + + +def test_register_function(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.registerFunction("function1", lambda x: x) + + +def test_is_cached(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.isCached("table1") + + +def test_cache_table(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.cacheTable("table1") + + +def test_uncache_table(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.uncacheTable("table1") + + +def test_clear_cache(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.clearCache() + + +def test_refresh_table(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.refreshTable("table1") + + +def test_recover_partitions(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.recoverPartitions("table1") + + +def test_refresh_by_path(databricks_session: DatabricksSession): + with pytest.raises(NotImplementedError): + databricks_session.catalog.refreshByPath("path/to/table") diff --git a/tests/integration/engines/databricks/test_databricks_dataframe.py b/tests/integration/engines/databricks/test_databricks_dataframe.py new file mode 100644 index 0000000..e942c9a --- /dev/null +++ b/tests/integration/engines/databricks/test_databricks_dataframe.py @@ -0,0 +1,169 @@ +import datetime + +import pytest + +from 
sqlframe.base import types +from sqlframe.databricks import DatabricksDataFrame, DatabricksSession + +pytest_plugins = ["tests.integration.fixtures"] +pytestmark = [ + pytest.mark.databricks, + pytest.mark.xdist_group("databricks_tests"), +] + + +@pytest.fixture() +def databricks_datatypes(databricks_session: DatabricksSession) -> DatabricksDataFrame: + return databricks_session.createDataFrame( + [ + ( + 1, + 2.0, + "foo", + {"a": 1}, + [types.Row(a=1, b=2)], + [1, 2, 3], + types.Row(a=1), + datetime.date(2022, 1, 1), + datetime.datetime(2022, 1, 1, 0, 0, 0), + datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc), + True, + ) + ], + [ + "bigint_col", + "double_col", + "string_col", + "map_col", + "array>", + "array_col", + "struct_col", + "date_col", + "timestamp_col", + "timestamptz_col", + "boolean_col", + ], + ) + + +def test_print_schema_basic(databricks_employee: DatabricksDataFrame, capsys): + databricks_employee.printSchema() + captured = capsys.readouterr() + assert ( + captured.out.strip() + == """ +root + |-- employee_id: int (nullable = true) + |-- fname: string (nullable = true) + |-- lname: string (nullable = true) + |-- age: int (nullable = true) + |-- store_id: int (nullable = true)""".strip() + ) + + +def test_print_schema_nested(databricks_datatypes: DatabricksDataFrame, capsys): + databricks_datatypes.printSchema() + captured = capsys.readouterr() + assert ( + captured.out.strip() + == """ +root + |-- bigint_col: bigint (nullable = true) + |-- double_col: double (nullable = true) + |-- string_col: string (nullable = true) + |-- `map_col`: map (nullable = true) + | |-- key: string (nullable = true) + | |-- value: bigint (nullable = true) + |-- `array>`: array> (nullable = true) + | |-- element: struct (nullable = true) + | | |-- a: bigint (nullable = true) + | | |-- b: bigint (nullable = true) + |-- `array_col`: array (nullable = true) + | |-- element: bigint (nullable = true) + |-- `struct_col`: struct (nullable = true) + | |-- a: bigint (nullable = true) + |-- date_col: date (nullable = true) + |-- timestamp_col: timestamp (nullable = true) + |-- timestamptz_col: timestamp (nullable = true) + |-- boolean_col: boolean (nullable = true)""".strip() + ) + + +def test_schema(databricks_employee: DatabricksDataFrame): + assert databricks_employee.schema == types.StructType( + [ + types.StructField( + "employee_id", + types.IntegerType(), + ), + types.StructField( + "fname", + types.StringType(), + ), + types.StructField( + "lname", + types.StringType(), + ), + types.StructField( + "age", + types.IntegerType(), + ), + types.StructField( + "store_id", + types.IntegerType(), + ), + ] + ) + + +def test_schema_nested(databricks_datatypes: DatabricksDataFrame): + assert isinstance(databricks_datatypes.schema, types.StructType) + struct_fields = list(databricks_datatypes.schema) + assert len(struct_fields) == 11 + assert struct_fields[0].name == "bigint_col" + assert struct_fields[0].dataType == types.LongType() + assert struct_fields[1].name == "double_col" + assert struct_fields[1].dataType == types.DoubleType() + assert struct_fields[2].name == "string_col" + assert struct_fields[2].dataType == types.StringType() + assert struct_fields[3].name == "`map_col`" + assert struct_fields[3].dataType == types.MapType( + types.StringType(), + types.LongType(), + ) + assert struct_fields[4].name == "`array>`" + assert struct_fields[4].dataType == types.ArrayType( + types.StructType( + [ + types.StructField( + "a", + types.LongType(), + ), + types.StructField( + "b", + 
types.LongType(), + ), + ] + ), + ) + assert struct_fields[5].name == "`array_col`" + assert struct_fields[5].dataType == types.ArrayType( + types.LongType(), + ) + assert struct_fields[6].name == "`struct_col`" + assert struct_fields[6].dataType == types.StructType( + [ + types.StructField( + "a", + types.LongType(), + ), + ] + ) + assert struct_fields[7].name == "date_col" + assert struct_fields[7].dataType == types.DateType() + assert struct_fields[8].name == "timestamp_col" + assert struct_fields[8].dataType == types.TimestampType() + assert struct_fields[9].name == "timestamptz_col" + assert struct_fields[9].dataType == types.TimestampType() + assert struct_fields[10].name == "boolean_col" + assert struct_fields[10].dataType == types.BooleanType() diff --git a/tests/integration/engines/databricks/test_databricks_session.py b/tests/integration/engines/databricks/test_databricks_session.py new file mode 100644 index 0000000..5dba5c5 --- /dev/null +++ b/tests/integration/engines/databricks/test_databricks_session.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import os +import typing as t + +import pytest +from sqlglot import exp + +from sqlframe.databricks.session import DatabricksSession + +if t.TYPE_CHECKING: + from databricks.sql import Connection as DatabricksConnection + +pytest_plugins = ["tests.common_fixtures"] +pytestmark = [ + pytest.mark.databricks, + pytest.mark.xdist_group("databricks_tests"), +] + + +@pytest.fixture +def cleanup_connector() -> t.Iterator[DatabricksConnection]: + from databricks.sql import connect + + conn = connect( + server_hostname=os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"], + http_path=os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"], + access_token=os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"], + auth_type="access_token", + catalog=os.environ["SQLFRAME_DATABRICKS_CATALOG"], + schema=os.environ["SQLFRAME_DATABRICKS_SCHEMA"], + ) + conn.cursor().execute("CREATE SCHEMA IF NOT EXISTS db1") + conn.cursor().execute("DROP TABLE IF EXISTS db1.test_table") + conn.cursor().execute("CREATE TABLE IF NOT EXISTS db1.test_table (cola INT, colb STRING)") + + yield conn + conn.cursor().execute("DROP TABLE IF EXISTS db1.test_table") + + +def test_session_from_config(cleanup_connector: DatabricksConnection): + session = DatabricksSession.builder.config("sqlframe.conn", cleanup_connector).getOrCreate() + columns = session.catalog.get_columns("sqlframe.db1.test_table") + assert columns == { + "cola": exp.DataType.build("INT", dialect=session.output_dialect), + "colb": exp.DataType.build("STRING", dialect=session.output_dialect), + } diff --git a/tests/integration/engines/test_engine_session.py b/tests/integration/engines/test_engine_session.py index 828dc4f..5cb2cea 100644 --- a/tests/integration/engines/test_engine_session.py +++ b/tests/integration/engines/test_engine_session.py @@ -21,9 +21,12 @@ def cleanup_session(get_session: t.Callable[[], _BaseSession]) -> t.Iterator[_Ba def test_session(cleanup_session: _BaseSession): session = cleanup_session session._execute("DROP TABLE IF EXISTS test_table") + sql = "CREATE TABLE test_table (cola INT, colb STRING, `col with space` STRING)" + if session.execution_dialect == Dialect.get_or_raise("databricks"): + sql += " TBLPROPERTIES('delta.columnMapping.mode' = 'name');" session._collect( parse_one( - "CREATE TABLE test_table (cola INT, colb STRING, `col with space` STRING)", + sql, dialect="spark", ) ) diff --git a/tests/integration/engines/test_int_functions.py b/tests/integration/engines/test_int_functions.py 
index 10ef32e..be6d5c8 100644 --- a/tests/integration/engines/test_int_functions.py +++ b/tests/integration/engines/test_int_functions.py @@ -18,6 +18,7 @@ get_func_from_session as get_func_from_session_without_fallback, ) from sqlframe.bigquery import BigQuerySession +from sqlframe.databricks import DatabricksSession from sqlframe.duckdb import DuckDBCatalog, DuckDBSession from sqlframe.postgres import PostgresDataFrame, PostgresSession from sqlframe.snowflake import SnowflakeSession @@ -144,6 +145,9 @@ def test_lit(get_session_and_func, arg, expected): if isinstance(session, SnowflakeSession): if isinstance(arg, Row): pytest.skip("Snowflake doesn't support literal row types") + if isinstance(session, DatabricksSession): + if isinstance(arg, datetime.datetime) and arg.tzinfo is None: + expected = expected.replace(tzinfo=datetime.timezone.utc) if isinstance(session, DuckDBSession): if isinstance(arg, dict): expected = Row(**expected) @@ -193,7 +197,7 @@ def test_typeof(get_session_and_func, get_types, arg, expected): if isinstance(session, PySparkSession) else dialect_to_string(session.execution_dialect) ) - if isinstance(session, (SparkSession, PySparkSession)): + if isinstance(session, (SparkSession, PySparkSession, DatabricksSession)): if expected == "timestamptz": expected = "timestamp" if isinstance(session, DuckDBSession): @@ -239,7 +243,15 @@ def test_alias(get_session_and_func): assert df.select(col("employee_id").alias("test")).first().__fields__[0] == "test" space_result = df.select(col("employee_id").alias("A Space In New Name")).first().__fields__[0] if isinstance( - session, (DuckDBSession, BigQuerySession, PostgresSession, SnowflakeSession, SparkSession) + session, + ( + DuckDBSession, + BigQuerySession, + PostgresSession, + SnowflakeSession, + SparkSession, + DatabricksSession, + ), ): assert space_result == "`a space in new name`" else: @@ -1429,7 +1441,10 @@ def test_to_timestamp(get_session_and_func): session, to_timestamp = get_session_and_func("to_timestamp") df = session.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) result = df.select(to_timestamp(df.t).alias("dt")).first()[0] - assert result == datetime.datetime(1997, 2, 28, 10, 30) + if isinstance(session, DatabricksSession): + assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc) + else: + assert result == datetime.datetime(1997, 2, 28, 10, 30) result = df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).first()[0] if isinstance(session, (BigQuerySession, DuckDBSession)): assert result == datetime.datetime( @@ -1440,7 +1455,7 @@ def test_to_timestamp(get_session_and_func): 30, tzinfo=datetime.timezone.utc if isinstance(session, BigQuerySession) else None, ) - elif isinstance(session, PostgresSession): + elif isinstance(session, (PostgresSession, DatabricksSession)): assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc) elif isinstance(session, SnowflakeSession): assert result == datetime.datetime( @@ -1467,14 +1482,18 @@ def test_trunc(get_session_and_func): def test_date_trunc(get_session_and_func): session, date_trunc = get_session_and_func("date_trunc") df = session.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) - assert df.select(date_trunc("year", df.t).alias("year")).first()[0] == datetime.datetime( + assert df.select(date_trunc("year", df.t).alias("year")).first()[0].replace( + tzinfo=None + ) == datetime.datetime( 1997, 1, 1, 0, 0, ) - assert df.select(date_trunc("month", df.t).alias("month")).first()[0] == 
datetime.datetime( + assert df.select(date_trunc("month", df.t).alias("month")).first()[0].replace( + tzinfo=None + ) == datetime.datetime( 1997, 2, 1, @@ -1498,7 +1517,10 @@ def test_last_day(get_session_and_func): def test_from_unixtime(get_session_and_func): session, from_unixtime = get_session_and_func("from_unixtime") df = session.createDataFrame([(1428476400,)], ["unix_time"]) - if isinstance(session, (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession)): + if isinstance( + session, + (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession, DatabricksSession), + ): expected = "2015-04-08 07:00:00" else: expected = "2015-04-08 00:00:00" @@ -1509,7 +1531,10 @@ def test_unix_timestamp(get_session_and_func): session, unix_timestamp = get_session_and_func("unix_timestamp") df = session.createDataFrame([("2015-04-08",)], ["dt"]) result = df.select(unix_timestamp("dt", "yyyy-MM-dd").alias("unix_time")).first()[0] - if isinstance(session, (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession)): + if isinstance( + session, + (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession, DatabricksSession), + ): assert result == 1428451200 else: assert result == 1428476400 @@ -1518,29 +1543,32 @@ def test_unix_timestamp(get_session_and_func): def test_from_utc_timestamp(get_session_and_func): session, from_utc_timestamp = get_session_and_func("from_utc_timestamp") df = session.createDataFrame([("1997-02-28 10:30:00", "JST")], ["ts", "tz"]) - assert df.select(from_utc_timestamp(df.ts, "PST").alias("local_time")).first()[ - 0 - ] == datetime.datetime(1997, 2, 28, 2, 30) - assert df.select(from_utc_timestamp(df.ts, df.tz).alias("local_time")).first()[ - 0 - ] == datetime.datetime(1997, 2, 28, 19, 30) + assert df.select(from_utc_timestamp(df.ts, "PST").alias("local_time")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1997, 2, 28, 2, 30) + assert df.select(from_utc_timestamp(df.ts, df.tz).alias("local_time")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1997, 2, 28, 19, 30) def test_to_utc_timestamp(get_session_and_func): session, to_utc_timestamp = get_session_and_func("to_utc_timestamp") df = session.createDataFrame([("1997-02-28 10:30:00", "JST")], ["ts", "tz"]) - assert df.select(to_utc_timestamp(df.ts, "PST").alias("utc_time")).first()[ - 0 - ] == datetime.datetime(1997, 2, 28, 18, 30) - assert df.select(to_utc_timestamp(df.ts, df.tz).alias("utc_time")).first()[ - 0 - ] == datetime.datetime(1997, 2, 28, 1, 30) + assert df.select(to_utc_timestamp(df.ts, "PST").alias("utc_time")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1997, 2, 28, 18, 30) + assert df.select(to_utc_timestamp(df.ts, df.tz).alias("utc_time")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1997, 2, 28, 1, 30) def test_timestamp_seconds(get_session_and_func): session, timestamp_seconds = get_session_and_func("timestamp_seconds") df = session.createDataFrame([(1230219000,)], ["unix_time"]) - if isinstance(session, (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession)): + if isinstance( + session, + (BigQuerySession, DuckDBSession, PostgresSession, SnowflakeSession, DatabricksSession), + ): expected = datetime.datetime(2008, 12, 25, 15, 30, 00) else: expected = datetime.datetime(2008, 12, 25, 7, 30) @@ -1553,11 +1581,14 @@ def test_timestamp_seconds(get_session_and_func): def test_window(get_session_and_func, get_func): session, window = get_session_and_func("window") sum = get_func("sum", session) + col = 
get_func("col", session) df = session.createDataFrame([(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]).toDF("date", "val") w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + # SQLFrame does not support the syntax used in the example so the "col" function was used instead. + # https://spark.apache.org/docs/3.4.0/api/python/reference/pyspark.sql/api/pyspark.sql.functions.window.html result = w.select( - w.window.start.cast("string").alias("start"), - w.window.end.cast("string").alias("end"), + col("window.start").cast("string").alias("start"), + col("window.end").cast("string").alias("end"), "sum", ).collect() assert result == [ @@ -1568,20 +1599,23 @@ def test_window(get_session_and_func, get_func): def test_session_window(get_session_and_func, get_func): session, session_window = get_session_and_func("session_window") sum = get_func("sum", session) + col = get_func("col", session) lit = get_func("lit", session) df = session.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum")) + # SQLFrame does not support the syntax used in the example so the "col" function was used instead. + # https://spark.apache.org/docs/3.4.0/api/python/reference/pyspark.sql/api/pyspark.sql.functions.session_window.html assert w.select( - w.session_window.start.cast("string").alias("start"), - w.session_window.end.cast("string").alias("end"), + col("session_window.start").cast("string").alias("start"), + col("session_window.end").cast("string").alias("end"), "sum", ).collect() == [ Row(start="2016-03-11 09:00:07", end="2016-03-11 09:00:12", sum=1), ] w = df.groupBy(session_window("date", lit("5 seconds"))).agg(sum("val").alias("sum")) assert w.select( - w.session_window.start.cast("string").alias("start"), - w.session_window.end.cast("string").alias("end"), + col("session_window.start").cast("string").alias("start"), + col("session_window.end").cast("string").alias("end"), "sum", ).collect() == [Row(start="2016-03-11 09:00:07", end="2016-03-11 09:00:12", sum=1)] @@ -2338,6 +2372,40 @@ def test_explode_outer(get_session_and_func, get_func): Row(id=2, an_array=[], key=None, value=None), Row(id=3, an_array=None, key=None, value=None), ] + elif isinstance(session, DatabricksSession): + df = ( + session.range(1) + .select( + lit(1).alias("id"), + lit(["foo", "bar"]).alias("an_array"), + lit({"x": 1.0}).alias("a_map"), + ) + .union( + session.range(1).select( + lit(2).alias("id"), + lit([]).alias("an_array"), + lit({}).alias("a_map"), + ) + ) + .union( + session.range(1).select( + lit(3).alias("id"), + lit(None).alias("an_array"), + lit(None).alias("a_map"), + ) + ) + ) + assert df.select("id", "a_map", explode_outer("an_array")).collect() == [ + Row(id=1, a_map={"x": Decimal("1.0")}, col="foo"), + Row(id=1, a_map={"x": Decimal("1.0")}, col="bar"), + Row(id=2, a_map=[], col=None), + Row(id=3, a_map=None, col=None), + ] + assert df.select("id", "an_array", explode_outer("a_map")).collect() == [ + Row(id=1, an_array=["foo", "bar"], key="x", value=1.0), + Row(id=2, an_array=[], key=None, value=None), + Row(id=3, an_array=None, key=None, value=None), + ] else: df = ( session.range(1) @@ -2909,9 +2977,19 @@ def test_zip_with(get_session_and_func, get_func): def test_transform_keys(get_session_and_func, get_func): session, transform_keys = get_session_and_func("transform_keys") upper = get_func("upper", session) + concat_ws = get_func("concat_ws", session) + lit = get_func("lit", session) df = 
session.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data")) - row = df.select(transform_keys("data", lambda k, _: upper(k)).alias("data_upper")).head() - assert sorted(row["data_upper"].items()) == [("BAR", 2.0), ("FOO", -2.0)] + if isinstance(session, DatabricksSession): + row = df.select( + transform_keys("data", lambda k, _: concat_ws("_", k, lit("a"))).alias("data_upper") + ).head() + expected = [("bar_a", 2.0), ("foo_a", -2.0)] + assert sorted(row["data_upper"].items()) == expected + else: + row = df.select(transform_keys("data", lambda k, _: upper(k)).alias("data_upper")).head() + expected = [("BAR", 2.0), ("FOO", -2.0)] + assert sorted(row["data_upper"].items()) == expected def test_transform_values(get_session_and_func, get_func): @@ -2923,7 +3001,11 @@ def test_transform_values(get_session_and_func, get_func): "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v) ).alias("new_data") ).head() - assert sorted(row["new_data"].items()) == [("IT", 20.0), ("OPS", 34.0), ("SALES", 2.0)] + if isinstance(session, (SparkSession, PySparkSession)): + expected = [("IT", 20.0), ("OPS", 34.0), ("SALES", 2.0)] + else: + expected = [("it", 20.0), ("ops", 34.0), ("sales", 2.0)] + assert sorted(row["new_data"].items()) == expected def test_map_filter(get_session_and_func, get_func): @@ -2942,7 +3024,11 @@ def test_map_zip_with(get_session_and_func, get_func): row = df.select( map_zip_with("base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data") ).head() - assert sorted(row["updated_data"].items()) == [("IT", 48.0), ("SALES", 16.8)] + if isinstance(session, (SparkSession, PySparkSession)): + expected = [("IT", 48.0), ("SALES", 16.8)] + else: + expected = [("it", 48.0), ("sales", 16.8)] + assert sorted(row["updated_data"].items()) == expected def test_nullif(get_session_and_func): @@ -3153,7 +3239,7 @@ def test_try_to_number(get_session_and_func, get_func): lit = get_func("lit", session) df = session.createDataFrame([("$78.12",)], ["e"]) actual = df.select(try_to_number(df.e, lit("$99.99")).alias("r")).first()[0] - if isinstance(session, SparkSession): + if isinstance(session, (SparkSession, DatabricksSession)): expected = 78.12 else: expected = Decimal("78.12") @@ -3388,7 +3474,7 @@ def test_approx_percentile(get_session_and_func, get_func): assert df.select(approx_percentile("value", [0.25, 0.5, 0.75], 1000000)).collect() == [ Row(value=[0.7264430125286507, 9.98975299938167, 19.335304783039014]) ] - assert df.groupBy("key").agg(approx_percentile("value", 0.5, 1000000)).collect() == [ + assert sorted(df.groupBy("key").agg(approx_percentile("value", 0.5, 1000000)).collect()) == [ Row(key=0, value=-0.03519435193070876), Row(key=1, value=9.990389751837329), Row(key=2, value=19.967859769284075), @@ -3491,6 +3577,8 @@ def test_convert_timezone(get_session_and_func, get_func): expected = datetime.datetime(2015, 4, 7, 16, 0, tzinfo=datetime.timezone.utc) elif isinstance(session, SnowflakeSession): expected = datetime.datetime(2015, 4, 8, 15, 0, tzinfo=pytz.FixedOffset(480)) + elif isinstance(session, DatabricksSession): + expected = datetime.datetime(2015, 4, 8, 8, 0) else: expected = datetime.datetime(2015, 4, 8, 15, 0) assert df.select(convert_timezone(None, lit("Asia/Hong_Kong"), "dt").alias("ts")).collect() == [ @@ -3532,17 +3620,20 @@ def test_current_user(get_session_and_func, get_func): def test_current_catalog(get_session_and_func, get_func): session, current_catalog = get_session_and_func("current_catalog") - assert 
session.range(1).select(current_catalog()).first()[0] == "spark_catalog" + if isinstance(session, DatabricksSession): + assert session.range(1).select(current_catalog()).first()[0] == "sqlframe" + else: + assert session.range(1).select(current_catalog()).first()[0] == "spark_catalog" def test_current_database(get_session_and_func, get_func): session, current_database = get_session_and_func("current_database") - assert session.range(1).select(current_database()).first()[0] == "db1" + assert session.range(1).select(current_database()).first()[0] in ("db1", "default", "public") def test_current_schema(get_session_and_func, get_func): session, current_schema = get_session_and_func("current_schema") - assert session.range(1).select(current_schema()).first()[0] == "db1" + assert session.range(1).select(current_schema()).first()[0] in ("db1", "default", "public") def test_current_timezone(get_session_and_func, get_func): @@ -3766,9 +3857,22 @@ def test_histogram_numeric(get_session_and_func, get_func): session, histogram_numeric = get_session_and_func("histogram_numeric") lit = get_func("lit", session) df = session.createDataFrame([("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"]) - assert df.select(histogram_numeric("c2", lit(5))).collect() == [ - Row(value=[Row(x=1, y=1.0), Row(x=2, y=2.0), Row(x=3, y=1.0), Row(x=8, y=1.0)]) - ] + if isinstance(session, DatabricksSession): + assert df.select(histogram_numeric("c2", lit(5))).collect() == [ + Row( + value=[ + Row(x=1, y=1.0), + Row(x=2, y=1.0), + Row(x=2, y=1.0), + Row(x=3, y=1.0), + Row(x=8, y=1.0), + ] + ) + ] + else: + assert df.select(histogram_numeric("c2", lit(5))).collect() == [ + Row(value=[Row(x=1, y=1.0), Row(x=2, y=2.0), Row(x=3, y=1.0), Row(x=8, y=1.0)]) + ] def test_hll_sketch_agg(get_session_and_func, get_func): @@ -3993,12 +4097,24 @@ def test_make_timestamp(get_session_and_func, get_func): [[2014, 12, 28, 6, 30, 45.887, "CET"]], ["year", "month", "day", "hour", "min", "sec", "timezone"], ) - assert df.select( - make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone).alias("r") - ).first()[0] == datetime.datetime(2014, 12, 27, 21, 30, 45, 887000) - assert df.select( - make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec).alias("r") - ).first()[0] == datetime.datetime(2014, 12, 28, 6, 30, 45, 887000) + if isinstance(session, DatabricksSession): + assert df.select( + make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone).alias( + "r" + ) + ).first()[0].replace(tzinfo=None) == datetime.datetime(2014, 12, 28, 5, 30, 45, 887000) + assert df.select( + make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec).alias("r") + ).first()[0].replace(tzinfo=None) == datetime.datetime(2014, 12, 28, 6, 30, 45, 887000) + else: + assert df.select( + make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone).alias( + "r" + ) + ).first()[0] == datetime.datetime(2014, 12, 27, 21, 30, 45, 887000) + assert df.select( + make_timestamp(df.year, df.month, df.day, df.hour, df.min, df.sec).alias("r") + ).first()[0] == datetime.datetime(2014, 12, 28, 6, 30, 45, 887000) def test_make_timestamp_ltz(get_session_and_func, get_func): @@ -4474,7 +4590,7 @@ def test_regr_intercept(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_intercept("y", "x")).first()[0] == -0.04961745990969568 + assert math.isclose(df.select(regr_intercept("y", 
"x")).first()[0], -0.04961745990969568) def test_regr_r2(get_session_and_func, get_func): @@ -4484,7 +4600,7 @@ def test_regr_r2(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_r2("y", "x")).first()[0] == 0.9851908293645436 + assert math.isclose(df.select(regr_r2("y", "x")).first()[0], 0.9851908293645436) def test_regr_slope(get_session_and_func, get_func): @@ -4494,7 +4610,7 @@ def test_regr_slope(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_slope("y", "x")).first()[0] == 10.040390844891048 + assert math.isclose(df.select(regr_slope("y", "x")).first()[0], 10.040390844891048) def test_regr_sxx(get_session_and_func, get_func): @@ -4504,7 +4620,7 @@ def test_regr_sxx(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_sxx("y", "x")).first()[0] == 666.9989999999996 + assert math.isclose(df.select(regr_sxx("y", "x")).first()[0], 666.9989999999996) def test_regr_sxy(get_session_and_func, get_func): @@ -4514,7 +4630,7 @@ def test_regr_sxy(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_sxy("y", "x")).first()[0] == 6696.93065315148 + assert math.isclose(df.select(regr_sxy("y", "x")).first()[0], 6696.93065315148) def test_regr_syy(get_session_and_func, get_func): @@ -4524,7 +4640,7 @@ def test_regr_syy(get_session_and_func, get_func): x = (col("id") % 3).alias("x") y = (randn(42) + x * 10).alias("y") df = session.range(0, 1000, 1, 1).select(x, y) - assert df.select(regr_syy("y", "x")).first()[0] == 68250.53503811295 + assert math.isclose(df.select(regr_syy("y", "x")).first()[0], 68250.53503811295) def test_replace(get_session_and_func, get_func): @@ -4710,17 +4826,27 @@ def test_substr(get_session_and_func, get_func): def test_timestamp_micros(get_session_and_func, get_func): session, timestamp_micros = get_session_and_func("timestamp_micros") time_df = session.createDataFrame([(1230219000,)], ["unix_time"]) - assert time_df.select(timestamp_micros(time_df.unix_time).alias("ts")).first()[ - 0 - ] == datetime.datetime(1969, 12, 31, 16, 20, 30, 219000) + if isinstance(session, DatabricksSession): + assert time_df.select(timestamp_micros(time_df.unix_time).alias("ts")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1970, 1, 1, 0, 20, 30, 219000) + else: + assert time_df.select(timestamp_micros(time_df.unix_time).alias("ts")).first()[ + 0 + ] == datetime.datetime(1969, 12, 31, 16, 20, 30, 219000) def test_timestamp_millis(get_session_and_func, get_func): session, timestamp_millis = get_session_and_func("timestamp_millis") time_df = session.createDataFrame([(1230219000,)], ["unix_time"]) - assert time_df.select(timestamp_millis(time_df.unix_time).alias("ts")).first()[ - 0 - ] == datetime.datetime(1970, 1, 14, 21, 43, 39) + if isinstance(session, DatabricksSession): + assert time_df.select(timestamp_millis(time_df.unix_time).alias("ts")).first()[0].replace( + tzinfo=None + ) == datetime.datetime(1970, 1, 15, 5, 43, 39) + else: + assert time_df.select(timestamp_millis(time_df.unix_time).alias("ts")).first()[ + 0 + ] == datetime.datetime(1970, 1, 14, 21, 43, 39) def test_to_char(get_session_and_func, get_func): 
@@ -4772,7 +4898,7 @@ def test_to_unix_timestamp(get_session_and_func, get_func): lit = get_func("lit", session) df = session.createDataFrame([("2016-04-08",)], ["e"]) result = df.select(to_unix_timestamp(df.e, lit("yyyy-MM-dd")).alias("r")).first()[0] - if isinstance(session, DuckDBSession): + if isinstance(session, (DuckDBSession, DatabricksSession)): assert result == 1460073600.0 else: assert result == 1460098800 @@ -4866,12 +4992,12 @@ def test_try_to_timestamp(get_session_and_func, get_func): lit = get_func("lit", session) df = session.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) result = df.select(try_to_timestamp(df.t).alias("dt")).first()[0] - if isinstance(session, BigQuerySession): + if isinstance(session, (BigQuerySession, DatabricksSession)): assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc) else: assert result == datetime.datetime(1997, 2, 28, 10, 30) result = df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).first()[0] - if isinstance(session, BigQuerySession): + if isinstance(session, (BigQuerySession, DatabricksSession)): assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc) else: assert result == datetime.datetime(1997, 2, 28, 10, 30) @@ -4894,21 +5020,30 @@ def test_unix_micros(get_session_and_func, get_func): session, unix_micros = get_session_and_func("unix_micros") to_timestamp = get_func("to_timestamp", session) df = session.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) - assert df.select(unix_micros(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400000000 + if isinstance(session, DatabricksSession): + assert df.select(unix_micros(to_timestamp(df.t)).alias("n")).first()[0] == 1437559200000000 + else: + assert df.select(unix_micros(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400000000 def test_unix_millis(get_session_and_func, get_func): session, unix_millis = get_session_and_func("unix_millis") to_timestamp = get_func("to_timestamp", session) df = session.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) - assert df.select(unix_millis(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400000 + if isinstance(session, DatabricksSession): + assert df.select(unix_millis(to_timestamp(df.t)).alias("n")).first()[0] == 1437559200000 + else: + assert df.select(unix_millis(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400000 def test_unix_seconds(get_session_and_func, get_func): session, unix_seconds = get_session_and_func("unix_seconds") to_timestamp = get_func("to_timestamp", session) df = session.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) - assert df.select(unix_seconds(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400 + if isinstance(session, DatabricksSession): + assert df.select(unix_seconds(to_timestamp(df.t)).alias("n")).first()[0] == 1437559200 + else: + assert df.select(unix_seconds(to_timestamp(df.t)).alias("n")).first()[0] == 1437584400 def test_url_decode(get_session_and_func, get_func): @@ -4954,12 +5089,15 @@ def test_window_time(get_session_and_func, get_func): session, window_time = get_session_and_func("window_time") window = get_func("window", session) sum = get_func("sum", session) + col = get_func("col", session) df = session.createDataFrame( [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], ).toDF("date", "val") w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + # SQLFrame does not support the syntax used in the example so the "col" function was used instead. 
+ # https://spark.apache.org/docs/3.4.0/api/python/reference/pyspark.sql/api/pyspark.sql.functions.window_time.html assert w.select( - w.window.end.cast("string").alias("end"), + col("window.end").cast("string").alias("end"), window_time(w.window).cast("string").alias("window_time"), "sum", ).collect() == [Row(end="2016-03-11 09:00:10", window_time="2016-03-11 09:00:09.999999", sum=1)] diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 8cc29a9..151949b 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -23,6 +23,9 @@ from sqlframe.bigquery import types as BigQueryTypes from sqlframe.bigquery.dataframe import BigQueryDataFrame from sqlframe.bigquery.session import BigQuerySession +from sqlframe.databricks import types as DatabricksTypes +from sqlframe.databricks.dataframe import DatabricksDataFrame +from sqlframe.databricks.session import DatabricksSession from sqlframe.duckdb import types as DuckDBTypes from sqlframe.duckdb.dataframe import DuckDBDataFrame from sqlframe.duckdb.session import DuckDBSession @@ -97,6 +100,15 @@ pytest.mark.xdist_group("snowflake_tests"), ], ), + pytest.param( + "databricks", + marks=[ + pytest.mark.databricks, + pytest.mark.remote, + # Set xdist group in order to serialize tests + pytest.mark.xdist_group("databricks_tests"), + ], + ), pytest.param( "spark", marks=[ @@ -518,6 +530,56 @@ def snowflake_district( return df +@pytest.fixture +def databricks_employee( + databricks_session: DatabricksSession, _employee_data: EmployeeData +) -> DatabricksDataFrame: + databricks_employee_schema = DatabricksTypes.StructType( + [ + DatabricksTypes.StructField("employee_id", DatabricksTypes.IntegerType(), False), + DatabricksTypes.StructField("fname", DatabricksTypes.StringType(), False), + DatabricksTypes.StructField("lname", DatabricksTypes.StringType(), False), + DatabricksTypes.StructField("age", DatabricksTypes.IntegerType(), False), + DatabricksTypes.StructField("store_id", DatabricksTypes.IntegerType(), False), + ] + ) + df = databricks_session.createDataFrame(data=_employee_data, schema=databricks_employee_schema) + df.createOrReplaceTempView("employee") + return df + + +@pytest.fixture +def databricks_store( + databricks_session: DatabricksSession, _store_data: StoreData +) -> DatabricksDataFrame: + databricks_store_schema = DatabricksTypes.StructType( + [ + DatabricksTypes.StructField("store_id", DatabricksTypes.IntegerType(), False), + DatabricksTypes.StructField("store_name", DatabricksTypes.StringType(), False), + DatabricksTypes.StructField("district_id", DatabricksTypes.IntegerType(), False), + DatabricksTypes.StructField("num_sales", DatabricksTypes.IntegerType(), False), + ] + ) + df = databricks_session.createDataFrame(data=_store_data, schema=databricks_store_schema) + df.createOrReplaceTempView("store") + return df + + +@pytest.fixture +def databricks_district( + databricks_session: DatabricksSession, _district_data: DistrictData +) -> DatabricksDataFrame: + databricks_district_schema = DatabricksTypes.StructType( + [ + DatabricksTypes.StructField("district_id", DatabricksTypes.IntegerType(), False), + DatabricksTypes.StructField("district_name", DatabricksTypes.StringType(), False), + ] + ) + df = databricks_session.createDataFrame(data=_district_data, schema=databricks_district_schema) + df.createOrReplaceTempView("district") + return df + + @pytest.fixture def compare_frames(pyspark_session: PySparkSession) -> t.Callable: def _make_function( @@ -713,3 +775,11 @@ def _is_spark() -> bool: 
return request.node.name.endswith("[spark]") return _is_spark + + +@pytest.fixture +def is_databricks(request: FixtureRequest) -> t.Callable: + def _is_databricks() -> bool: + return request.node.name.endswith("[databricks]") + + return _is_databricks diff --git a/tests/unit/databricks/__init__.py b/tests/unit/databricks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/databricks/test_activate.py b/tests/unit/databricks/test_activate.py new file mode 100644 index 0000000..76f1ea7 --- /dev/null +++ b/tests/unit/databricks/test_activate.py @@ -0,0 +1,52 @@ +import pytest + +from sqlframe import activate +from sqlframe.databricks import Column as DatabricksColumn +from sqlframe.databricks import ( + DatabricksCatalog, + DatabricksDataFrame, + DatabricksDataFrameNaFunctions, + DatabricksDataFrameReader, + DatabricksDataFrameStatFunctions, + DatabricksDataFrameWriter, + DatabricksGroupedData, + DatabricksSession, + DatabricksUDFRegistration, +) +from sqlframe.databricks import Row as DatabricksRow +from sqlframe.databricks import Window as DatabricksWindow +from sqlframe.databricks import WindowSpec as DatabricksWindowSpec +from sqlframe.databricks import functions as DatabricksF +from sqlframe.databricks import types as DatabricksTypes + + +@pytest.mark.forked +def test_activate_databricks(check_pyspark_imports): + check_pyspark_imports( + "databricks", + sqlf_session=DatabricksSession, + sqlf_catalog=DatabricksCatalog, + sqlf_column=DatabricksColumn, + sqlf_dataframe=DatabricksDataFrame, + sqlf_grouped_data=DatabricksGroupedData, + sqlf_window=DatabricksWindow, + sqlf_window_spec=DatabricksWindowSpec, + sqlf_functions=DatabricksF, + sqlf_types=DatabricksTypes, + sqlf_udf_registration=DatabricksUDFRegistration, + sqlf_dataframe_reader=DatabricksDataFrameReader, + sqlf_dataframe_writer=DatabricksDataFrameWriter, + sqlf_dataframe_na_functions=DatabricksDataFrameNaFunctions, + sqlf_dataframe_stat_functions=DatabricksDataFrameStatFunctions, + sqlf_row=DatabricksRow, + ) + + +@pytest.mark.forked +def test_activate_databricks_default_dataset(): + activate("databricks", config={"default_dataset": "sqlframe.sqlframe_test"}) + from pyspark.sql import SparkSession + + assert SparkSession == DatabricksSession + spark = SparkSession.builder.appName("test").getOrCreate() + assert spark.default_dataset == "sqlframe.sqlframe_test"
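
Taken together, the diff wires Databricks into SQLFrame's session, catalog, and function layers and adds the matching unit and integration tests. The sketch below is illustrative only: it mirrors the connection pattern from the `databricks_connection` fixture and the `sqlframe.conn` builder pattern from `test_databricks_session.py`. The `SQLFRAME_DATABRICKS_*` environment variables, the `sqlframe` catalog name, and the sample data are assumptions borrowed from the tests, not requirements of the library.

```python
import os

from databricks.sql import connect

from sqlframe.databricks import DatabricksSession
from sqlframe.databricks import functions as F

# Connection parameters follow the test fixtures; any databricks-sql-connector
# connection object can be handed to the builder via the "sqlframe.conn" key.
conn = connect(
    server_hostname=os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"],
    http_path=os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"],
    access_token=os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"],
)
session = DatabricksSession.builder.config("sqlframe.conn", conn).getOrCreate()

# setCurrentCatalog now issues a `USE CATALOG <name>` statement (see catalog.py above);
# "sqlframe" is the catalog the integration tests use, not a default.
session.catalog.setCurrentCatalog("sqlframe")

df = session.createDataFrame([(1, '{"a": 1}'), (2, '{"a": 2}')], ["id", "payload"])
# get_json_object is now routed through Databricks' GET_JSON_OBJECT function.
df.select("id", F.get_json_object("payload", "$.a").alias("a")).show()
```

Running the sketch requires a reachable Databricks SQL warehouse; the schema, table, and function setup performed by the fixtures (`CREATE SCHEMA IF NOT EXISTS db1`, etc.) is only needed for the test suite, not for this snippet.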