Databricks tests #218

Merged: 16 commits, Dec 15, 2024
Changes from 6 commits
3 changes: 3 additions & 0 deletions Makefile

@@ -22,6 +22,9 @@ duckdb-test:
 snowflake-test:
	pytest -n auto -m "snowflake"

+databricks-test:
+	pytest -n auto -m "databricks"
+
 style:
	pre-commit run --all-files
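The new target mirrors the existing per-engine targets. A sketch of how a test opts in (the marker name comes from the Makefile above; the databricks_session fixture is added in tests/common_fixtures.py below; the test body is illustrative only):

import pytest

@pytest.mark.databricks  # selected by: pytest -n auto -m "databricks"
def test_databricks_session_is_available(databricks_session):
    # Placeholder assertion; the PR's real tests live in the test suite.
    assert databricks_session is not None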
2 changes: 0 additions & 2 deletions docs/databricks.md

@@ -1,5 +1,3 @@
-from test import auth_type
-
 # Databricks (In Development)

 ## Installation
5 changes: 5 additions & 0 deletions sqlframe/base/function_alternatives.py

@@ -1220,6 +1220,11 @@ def get_json_object_cast_object(col: ColumnOrName, path: str) -> Column:
     return get_json_object(col_func(col).cast("variant"), path)


+def get_json_object_using_function(col: ColumnOrName, path: str) -> Column:
+    lit = get_func_from_session("lit")
+    return Column.invoke_anonymous_function(col, "GET_JSON_OBJECT", lit(path))
+
+
 def create_map_with_cast(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     from sqlframe.base.functions import create_map
20 changes: 10 additions & 10 deletions sqlframe/base/functions.py

@@ -2173,7 +2173,7 @@ def current_database() -> Column:
 current_schema = current_database


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def current_timezone() -> Column:
     return Column.invoke_anonymous_function(None, "current_timezone")
@@ -2261,7 +2261,7 @@ def get(col: ColumnOrName, index: t.Union[ColumnOrName, int]) -> Column:
     return Column.invoke_anonymous_function(col, "get", index)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def get_active_spark_context() -> SparkContext:
     """Raise RuntimeError if SparkContext is not initialized,
     otherwise, returns the active SparkContext."""
@@ -2778,7 +2778,7 @@ def isnotnull(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "isnotnull")


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def java_method(*cols: ColumnOrName) -> Column:
     """
     Calls a method with reflection.
@@ -3050,7 +3050,7 @@ def ln(col: ColumnOrName) -> Column:
     return Column.invoke_expression_over_column(col, expression.Ln)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def localtimestamp() -> Column:
     """
     Returns the current timestamp without time zone at the start of query evaluation
@@ -3080,7 +3080,7 @@ def localtimestamp() -> Column:
     return Column.invoke_anonymous_function(None, "localtimestamp")


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_dt_interval(
     days: t.Optional[ColumnOrName] = None,
     hours: t.Optional[ColumnOrName] = None,
@@ -3227,7 +3227,7 @@ def make_timestamp(
     )


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_timestamp_ltz(
     years: ColumnOrName,
     months: ColumnOrName,
@@ -3354,7 +3354,7 @@ def make_timestamp_ntz(
     )


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def make_ym_interval(
     years: t.Optional[ColumnOrName] = None,
     months: t.Optional[ColumnOrName] = None,
@@ -3922,7 +3922,7 @@ def printf(format: ColumnOrName, *cols: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(format, "printf", *cols)


-@meta(unsupported_engines=["*", "spark"])
+@meta(unsupported_engines=["*", "spark", "databricks"])
 def product(col: ColumnOrName) -> Column:
     """
     Aggregate function: returns the product of the values in a group.
@@ -3961,7 +3961,7 @@ def product(col: ColumnOrName) -> Column:
 reduce = aggregate


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def reflect(*cols: ColumnOrName) -> Column:
     """
     Calls a method with reflection.
@@ -5046,7 +5046,7 @@ def to_str(value: t.Any) -> t.Optional[str]:
     return str(value)


-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["*", "databricks"])
 def to_timestamp_ltz(
     timestamp: ColumnOrName,
     format: t.Optional[ColumnOrName] = None,
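All ten hunks above make the same change, so one reading (inferred from the diff itself, not from sqlframe documentation): @meta(unsupported_engines=...) tags functions an engine cannot run, with "*" acting as a wildcard, and functions that should stay unavailable on the new Databricks engine now name it explicitly next to the wildcard, just as product already did for "spark". A hypothetical, simplified sketch of how such a tag could work; this is NOT sqlframe's actual @meta implementation:

import typing as t

def meta(unsupported_engines: t.Union[str, t.List[str], None] = None):
    engines = (
        [unsupported_engines]
        if isinstance(unsupported_engines, str)
        else list(unsupported_engines or [])
    )

    def decorator(func):
        # Stored on the function; an engine's namespace builder could then skip
        # any function whose list names that engine (or the "*" wildcard).
        func.unsupported_engines = engines
        return func

    return decorator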
14 changes: 12 additions & 2 deletions sqlframe/databricks/catalog.py

@@ -26,7 +26,6 @@


 class DatabricksCatalog(
-    SetCurrentCatalogFromUseMixin["DatabricksSession", "DatabricksDataFrame"],
     GetCurrentCatalogFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
     GetCurrentDatabaseFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
     ListDatabasesFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame"],
@@ -38,6 +37,15 @@ class DatabricksCatalog(
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
     UPPERCASE_INFO_SCHEMA = True

+    def setCurrentCatalog(self, catalogName: str) -> None:
+        self.session._collect(
+            exp.Use(
+                kind=exp.Var(this=exp.to_identifier("CATALOG")),
+                this=exp.parse_identifier(catalogName, dialect=self.session.input_dialect),
+            ),
+            quote_identifiers=False,
+        )
+
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Function]:
@@ -106,7 +114,9 @@ def listFunctions(
         )
         functions = [
             Function(
-                name=normalize_string(x["function"], from_dialect="execution", to_dialect="output"),
+                name=normalize_string(
+                    x["function"].split(".")[-1], from_dialect="execution", to_dialect="output"
+                ),
                 catalog=normalize_string(
                     schema.catalog, from_dialect="execution", to_dialect="output"
                 ),
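Two things change in this file: setCurrentCatalog is now implemented directly by issuing a USE CATALOG statement instead of coming from SetCurrentCatalogFromUseMixin, and listFunctions keeps only the last dot-separated segment of the function name, apparently because Databricks can return it fully qualified (e.g. catalog.db1.add becomes add). A hedged, standalone sketch of the expression the new method builds (sqlglot only; the session._collect plumbing is elided and "my_catalog" is a placeholder):

from sqlglot import exp

# Same construction as setCurrentCatalog above, outside the session.
stmt = exp.Use(
    kind=exp.Var(this=exp.to_identifier("CATALOG")),
    this=exp.parse_identifier("my_catalog", dialect="databricks"),
)
print(stmt.sql(dialect="databricks"))  # expected (assumption): USE CATALOG my_catalog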
2 changes: 1 addition & 1 deletion sqlframe/databricks/dataframe.py

@@ -53,7 +53,7 @@ def _typed_columns(self) -> t.List[CatalogColumn]:
             CatalogColumn(
                 name=normalize_string(
                     row.col_name, from_dialect="execution", to_dialect="output"
-                ),
+                ).replace("`", ""),
                 dataType=normalize_string(
                     row.data_type,
                     from_dialect="execution",
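A note on the one-line change (the trigger is an assumption from context): DESCRIBE-style output on Databricks can carry backtick-quoted column names, and normalization preserves the quotes, so they are stripped after the fact. A toy illustration:

raw = "`id`"
print(raw.replace("`", ""))  # id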
1 change: 1 addition & 0 deletions sqlframe/databricks/functions.py

@@ -19,4 +19,5 @@
     arrays_overlap_renamed as arrays_overlap,
     _is_string_using_typeof_string_lcase as _is_string,
     try_element_at_zero_based as try_element_at,
+    get_json_object_using_function as get_json_object,
 )
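With this alias, get_json_object on the Databricks dialect resolves to the function-call variant added in function_alternatives.py above. A hedged usage sketch (placeholders throughout; the rendered SQL is an expectation, not captured output):

from sqlframe.databricks import functions as F

# "payload" and "$.user.id" are placeholder column/path values.
expr = F.get_json_object("payload", "$.user.id")
# Expected rendering (assumption): GET_JSON_OBJECT(payload, '$.user.id'),
# rather than the variant-cast form some other engines use.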
27 changes: 27 additions & 0 deletions tests/common_fixtures.py

@@ -12,6 +12,7 @@

 from sqlframe.base.session import _BaseSession
 from sqlframe.bigquery.session import BigQuerySession
+from sqlframe.databricks.session import DatabricksSession
 from sqlframe.duckdb.session import DuckDBSession
 from sqlframe.postgres.session import PostgresSession
 from sqlframe.redshift.session import RedshiftSession
@@ -22,6 +23,7 @@
 from sqlframe.standalone.session import StandaloneSession

 if t.TYPE_CHECKING:
+    from databricks.sql import Connection as DatabricksConnection
     from google.cloud.bigquery.dbapi.connection import (
         Connection as BigQueryConnection,
     )
@@ -231,6 +233,31 @@ def snowflake_session(snowflake_connection: SnowflakeConnection) -> SnowflakeSession:
     return session


+@pytest.fixture(scope="session")
+def databricks_connection() -> DatabricksConnection:
+    from databricks.sql import connect
+
+    conn = connect(
+        server_hostname=os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"],
+        http_path=os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"],
+        access_token=os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"],
+        auth_type="access_token",
+        catalog=os.environ["SQLFRAME_DATABRICKS_CATALOG"],
+        schema=os.environ["SQLFRAME_DATABRICKS_SCHEMA"],
+        _disable_pandas=True,
+    )
+    return conn
+
+
+@pytest.fixture
+def databricks_session(databricks_connection: DatabricksConnection) -> DatabricksSession:
+    session = DatabricksSession(databricks_connection)
+    session._execute("CREATE SCHEMA IF NOT EXISTS db1")
+    session._execute("CREATE TABLE IF NOT EXISTS db1.table1 (id INTEGER, name VARCHAR(100))")
+    session._execute("CREATE OR REPLACE FUNCTION db1.add(x INT, y INT) RETURNS INT RETURN x + y")
+    return session
+
+
 @pytest.fixture(scope="module")
 def _employee_data() -> EmployeeData:
     return [
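For anyone reproducing the test setup: the connection fixture reads five environment variables. A sketch with placeholder values (the variable names come from the fixture above; the values are invented):

import os

# Placeholder values only; supply your own workspace details.
os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"] = "dbc-xxxxxxxx.cloud.databricks.com"
os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"] = "/sql/1.0/warehouses/xxxxxxxx"
os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"] = "dapi-..."
os.environ["SQLFRAME_DATABRICKS_CATALOG"] = "my_catalog"
os.environ["SQLFRAME_DATABRICKS_SCHEMA"] = "my_schema"

With those set, make databricks-test runs the marked tests via pytest -n auto -m "databricks", as added in the Makefile above.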