Skip to content

Commit

Permalink
[PECO-1286] Add tests for complex types in query results (#293)
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
  • Loading branch information
Jesse authored Nov 29, 2023
1 parent a1bb6f9 commit aaaf047
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
session_configuration: Dict[str, Any] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
_use_arrow_native_complex_types: Optional[bool] = True,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -152,8 +153,13 @@ def read(self) -> Optional[OAuthToken]:
experimental_oauth_persistence=DevOnlyFilePersistence("~/dev-oauth.json")
)
```
:param _use_arrow_native_complex_types: `bool`, optional
Controls whether a complex type field value is returned as a string or as a native Arrow type. Defaults to True.
When True:
MAP is returned as List[Tuple[str, Any]]
STRUCT is returned as Dict[str, Any]
ARRAY is returned as numpy.ndarray
When False, complex types are returned as a strings. These are generally deserializable as JSON.
"""

# Internal arguments in **kwargs:
Expand Down Expand Up @@ -184,9 +190,6 @@ def read(self) -> Optional[OAuthToken]:
# _disable_pandas
# In case the deserialisation through pandas causes any issues, it can be disabled with
# this flag.
# _use_arrow_native_complex_types
# DBR will return native Arrow types for structs, arrays and maps instead of Arrow strings
# (True by default)
# _use_arrow_native_decimals
# Databricks runtime will return native Arrow types for decimals instead of Arrow strings
# (True by default)
Expand Down Expand Up @@ -225,6 +228,7 @@ def read(self) -> Optional[OAuthToken]:
http_path,
(http_headers or []) + base_headers,
auth_provider,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)

Expand Down
63 changes: 63 additions & 0 deletions tests/e2e/test_complex_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

import pytest
from numpy import ndarray

from tests.e2e.test_driver import PySQLPytestTestCase


class TestComplexTypes(PySQLPytestTestCase):
@pytest.fixture(scope="class")
def table_fixture(self):
"""A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table"""

with self.cursor() as cursor:
# Create the table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
array_col ARRAY<STRING>,
map_col MAP<STRING, INTEGER>,
struct_col STRUCT<field1: STRING, field2: INTEGER>
)
"""
)
# Insert a record
cursor.execute(
"""
INSERT INTO pysql_test_complex_types_table
VALUES (
ARRAY('a', 'b', 'c'),
MAP('a', 1, 'b', 2, 'c', 3),
NAMED_STRUCT('field1', 'a', 'field2', 1)
)
"""
)
yield
# Clean up the table after the test
cursor.execute("DROP TABLE IF EXISTS pysql_test_complex_types_table")

@pytest.mark.parametrize(
"field,expected_type",
[("array_col", ndarray), ("map_col", list), ("struct_col", dict)],
)
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
"""Confirms the return types of a complex type field when reading as arrow"""

with self.cursor() as cursor:
result = cursor.execute(
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], expected_type)

@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
def test_read_complex_types_as_string(self, field, table_fixture):
"""Confirms the return type of a complex type that is returned as a string"""
with self.cursor(
extra_params={"_use_arrow_native_complex_types": False}
) as cursor:
result = cursor.execute(
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], str)

0 comments on commit aaaf047

Please sign in to comment.