Actually refactor to move off info schema #896

Closed · wants to merge 1 commit
136 changes: 136 additions & 0 deletions dbt/adapters/databricks/behaviors/metadata.py
@@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Optional, cast

from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.contracts.relation import RelationType
from dbt.adapters.databricks.connections import DatabricksConnectionManager
from dbt.adapters.databricks.relation import KEY_TABLE_PROVIDER, is_hive_metastore
from dbt.adapters.databricks.utils import handle_missing_objects
from dbt.adapters.spark.impl import KEY_TABLE_OWNER
from dbt.adapters.sql.impl import SQLAdapter

CURRENT_CATALOG_MACRO_NAME = "current_catalog"
USE_CATALOG_MACRO_NAME = "use_catalog"
GET_CATALOG_MACRO_NAME = "get_catalog"
SHOW_TABLES_MACRO_NAME = "show_tables"
SHOW_VIEWS_MACRO_NAME = "show_views"


class MetadataBehavior(ABC):
    @classmethod
    @abstractmethod
    def list_relations_without_caching(
        cls, adapter: SQLAdapter, schema_relation: BaseRelation
    ) -> list[BaseRelation]:
        pass


class DefaultMetadataBehavior(MetadataBehavior):
    @classmethod
    def list_relations_without_caching(
        cls, adapter: SQLAdapter, schema_relation: BaseRelation
    ) -> list[BaseRelation]:
        empty: list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]] = []
        results = handle_missing_objects(
            lambda: cls._get_relations_without_caching(adapter, schema_relation), empty
        )

        relations = []
        for row in results:
            name, kind, file_format, owner = row
            metadata = None
            if file_format:
                metadata = {KEY_TABLE_OWNER: owner, KEY_TABLE_PROVIDER: file_format}
            relations.append(
                adapter.Relation.create(
                    database=schema_relation.database,
                    schema=schema_relation.schema,
                    identifier=name,
                    type=adapter.Relation.get_relation_type(kind),
                    metadata=metadata,
                )
            )

        return relations

    @classmethod
    def _get_relations_without_caching(
        cls, adapter: SQLAdapter, relation: BaseRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        if is_hive_metastore(relation.database):
            return cls._get_hive_relations(adapter, relation)
        return cls._get_uc_relations(adapter, relation)

    @staticmethod
    def _get_uc_relations(
        adapter: SQLAdapter, relation: BaseRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        kwargs = {"relation": relation}
        results = adapter.execute_macro("get_uc_tables", kwargs=kwargs)
        return [
            (row["table_name"], row["table_type"], row["file_format"], row["table_owner"])
            for row in results
        ]

    @classmethod
    def _get_hive_relations(
        cls, adapter: SQLAdapter, relation: BaseRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        kwargs = {"relation": relation}
        connection_manager = cast(DatabricksConnectionManager, adapter.connections)

        new_rows: list[tuple[str, Optional[str]]]
        if all([relation.database, relation.schema]):
            tables = connection_manager.list_tables(
                database=relation.database,  # type: ignore[arg-type]
                schema=relation.schema,  # type: ignore[arg-type]
            )

            new_rows = []
            for row in tables:
                # list_tables returns TABLE_TYPE as "view" for materialized views and
                # streaming tables alike; when it is empty, set the type to None so it
                # is resolved against SHOW VIEWS below.
                kind = row["TABLE_TYPE"].lower() if row["TABLE_TYPE"] else None
                new_rows.append((row["TABLE_NAME"], kind))

        else:
            tables = adapter.execute_macro(SHOW_TABLES_MACRO_NAME, kwargs=kwargs)
            new_rows = [(row["tableName"], None) for row in tables]

        # Resolve any rows whose table type is still unknown.
        if any(not row[1] for row in new_rows):
            with cls._catalog(adapter, relation.database):
                views = adapter.execute_macro(SHOW_VIEWS_MACRO_NAME, kwargs=kwargs)
                view_names = set(views.columns["viewName"].values())  # type: ignore[attr-defined]
                new_rows = [
                    (row[0], str(RelationType.View if row[0] in view_names else RelationType.Table))
                    for row in new_rows
                ]

        return [(row[0], row[1], None, None) for row in new_rows]

    @staticmethod
    @contextmanager
    def _catalog(adapter: SQLAdapter, catalog: Optional[str]) -> Iterator[None]:
        """
        A context manager that runs the enclosed operation in the specified catalog
        and switches back to the original catalog afterwards.

        If `catalog` is None, the operation runs in the current catalog.
        """
        current_catalog: Optional[str] = None
        try:
            if catalog is not None:
                current_catalog = adapter.execute_macro(CURRENT_CATALOG_MACRO_NAME)[0][0]
                if current_catalog is not None:
                    if current_catalog != catalog:
                        adapter.execute_macro(USE_CATALOG_MACRO_NAME, kwargs=dict(catalog=catalog))
                else:
                    current_catalog = None
            yield
        finally:
            if current_catalog is not None:
                adapter.execute_macro(USE_CATALOG_MACRO_NAME, kwargs=dict(catalog=current_catalog))
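
The save/switch/restore pattern in `_catalog` above can be exercised without a live connection. Below is a minimal, slightly simplified sketch of that flow; the `StubAdapter` is hypothetical (not part of this change) and merely records the macro calls the context manager would issue:

from contextlib import contextmanager
from typing import Iterator, Optional


class StubAdapter:
    """Hypothetical stand-in that records macro calls instead of querying Databricks."""

    def __init__(self, current: str) -> None:
        self.current = current
        self.calls: list[str] = []

    def execute_macro(self, name: str, kwargs: Optional[dict] = None) -> list:
        self.calls.append(name)
        if name == "current_catalog":
            return [[self.current]]
        if name == "use_catalog" and kwargs:
            self.current = kwargs["catalog"]
        return []


@contextmanager
def catalog(adapter: StubAdapter, target: Optional[str]) -> Iterator[None]:
    # Same save/switch/restore flow as DefaultMetadataBehavior._catalog above.
    saved: Optional[str] = None
    try:
        if target is not None:
            saved = adapter.execute_macro("current_catalog")[0][0]
            if saved == target:
                saved = None  # already in the target catalog; nothing to restore
            else:
                adapter.execute_macro("use_catalog", kwargs={"catalog": target})
        yield
    finally:
        if saved is not None:
            adapter.execute_macro("use_catalog", kwargs={"catalog": saved})


stub = StubAdapter(current="main")
with catalog(stub, "hive_metastore"):
    assert stub.current == "hive_metastore"  # e.g. SHOW VIEWS would run here
assert stub.current == "main"  # original catalog restored on exit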
84 changes: 5 additions & 79 deletions dbt/adapters/databricks/impl.py
@@ -30,6 +30,7 @@
    GetColumnsByDescribe,
    GetColumnsByInformationSchema,
)
from dbt.adapters.databricks.behaviors.metadata import DefaultMetadataBehavior, MetadataBehavior
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import (
    DatabricksConnectionManager,
@@ -43,7 +44,6 @@
    WorkflowPythonJobHelper,
)
from dbt.adapters.databricks.relation import (
    KEY_TABLE_PROVIDER,
    DatabricksRelation,
    DatabricksRelationType,
)
@@ -171,13 +171,15 @@ class DatabricksAdapter(SparkAdapter):
    )

    get_column_behavior: GetColumnsBehavior
    metadata_behavior: MetadataBehavior

    def __init__(self, config: Any, mp_context: SpawnContext) -> None:
        super().__init__(config, mp_context)

        # dbt doesn't propagate flags for certain workflows like dbt debug, so this
        # requires an additional guard
        self.get_column_behavior = GetColumnsByDescribe()
        self.metadata_behavior = DefaultMetadataBehavior()
        try:
            if self.behavior.use_info_schema_for_columns.no_warn:  # type: ignore[attr-defined]
                self.get_column_behavior = GetColumnsByInformationSchema()
@@ -295,84 +297,8 @@ def execute(
        if staging_table is not None:
            self.drop_relation(staging_table)

    def list_relations_without_caching(  # type: ignore[override]
        self, schema_relation: DatabricksRelation
    ) -> list[DatabricksRelation]:
        empty: list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]] = []
        results = handle_missing_objects(
            lambda: self.get_relations_without_caching(schema_relation), empty
        )

        relations = []
        for row in results:
            name, kind, file_format, owner = row
            metadata = None
            if file_format:
                metadata = {KEY_TABLE_OWNER: owner, KEY_TABLE_PROVIDER: file_format}
            relations.append(
                self.Relation.create(
                    database=schema_relation.database,
                    schema=schema_relation.schema,
                    identifier=name,
                    type=self.Relation.get_relation_type(kind),
                    metadata=metadata,
                )
            )

        return relations

    def get_relations_without_caching(
        self, relation: DatabricksRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        if relation.is_hive_metastore():
            return self._get_hive_relations(relation)
        return self._get_uc_relations(relation)

    def _get_uc_relations(
        self, relation: DatabricksRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        kwargs = {"relation": relation}
        results = self.execute_macro("get_uc_tables", kwargs=kwargs)
        return [
            (row["table_name"], row["table_type"], row["file_format"], row["table_owner"])
            for row in results
        ]

    def _get_hive_relations(
        self, relation: DatabricksRelation
    ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]:
        kwargs = {"relation": relation}

        new_rows: list[tuple[str, Optional[str]]]
        if all([relation.database, relation.schema]):
            tables = self.connections.list_tables(
                database=relation.database,  # type: ignore[arg-type]
                schema=relation.schema,  # type: ignore[arg-type]
            )

            new_rows = []
            for row in tables:
                # list_tables returns TABLE_TYPE as view for both materialized views and for
                # streaming tables. Set type to "" in this case and it will be resolved below.
                type = row["TABLE_TYPE"].lower() if row["TABLE_TYPE"] else None
                row = (row["TABLE_NAME"], type)
                new_rows.append(row)

        else:
            tables = self.execute_macro(SHOW_TABLES_MACRO_NAME, kwargs=kwargs)
            new_rows = [(row["tableName"], None) for row in tables]

        # if there are any table types to be resolved
        if any(not row[1] for row in new_rows):
            with self._catalog(relation.database):
                views = self.execute_macro(SHOW_VIEWS_MACRO_NAME, kwargs=kwargs)
                view_names = set(views.columns["viewName"].values())  # type: ignore[attr-defined]
                new_rows = [
                    (row[0], str(RelationType.View if row[0] in view_names else RelationType.Table))
                    for row in new_rows
                ]

        return [(row[0], row[1], None, None) for row in new_rows]

    def list_relations_without_caching(self, schema_relation: BaseRelation) -> list[BaseRelation]:
        return self.metadata_behavior.list_relations_without_caching(self, schema_relation)

    @available.parse(lambda *a, **k: [])
    def get_column_schema_from_query(self, sql: str) -> list[DatabricksColumn]:
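
With `list_relations_without_caching` now delegating to a pluggable `MetadataBehavior`, alternative strategies can be slotted in the same way `GetColumnsByInformationSchema` is selected in `__init__` above. A minimal sketch of such a subclass follows; it is hypothetical and not part of this PR:

from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.databricks.behaviors.metadata import MetadataBehavior
from dbt.adapters.sql.impl import SQLAdapter


class InfoSchemaMetadataBehavior(MetadataBehavior):
    """Hypothetical strategy that would read information_schema instead of SHOW commands."""

    @classmethod
    def list_relations_without_caching(
        cls, adapter: SQLAdapter, schema_relation: BaseRelation
    ) -> list[BaseRelation]:
        # A real implementation would execute a macro against
        # information_schema.tables and map its rows to relations, mirroring
        # DefaultMetadataBehavior.list_relations_without_caching.
        raise NotImplementedError

Selection of such a behavior could then follow the same behavior-flag guard pattern used for `get_column_behavior`.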
38 changes: 1 addition & 37 deletions tests/unit/test_adapter.py
@@ -13,7 +13,7 @@
    CATALOG_KEY_IN_SESSION_PROPERTIES,
)
from dbt.adapters.databricks.impl import get_identifier_list_string
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
from dbt.adapters.databricks.relation import DatabricksRelation
from dbt.adapters.databricks.utils import check_not_found_error
from dbt.config import RuntimeConfig
from tests.unit.utils import config_from_parts_or_dicts
@@ -342,42 +342,6 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
        assert connection.credentials.token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
        assert connection.credentials.schema == "analytics"

    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
    def test_list_relations_without_caching__no_relations(self, _):
        with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
            mocked.return_value = []
            adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
            assert adapter.list_relations("database", "schema") == []

    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
    def test_list_relations_without_caching__some_relations(self, _):
        with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
            mocked.return_value = [("name", "table", "hudi", "owner")]
            adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
            relations = adapter.list_relations("database", "schema")
            assert len(relations) == 1
            relation = relations[0]
            assert relation.identifier == "name"
            assert relation.database == "database"
            assert relation.schema == "schema"
            assert relation.type == DatabricksRelationType.Table
            assert relation.owner == "owner"
            assert relation.is_hudi

    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
    def test_list_relations_without_caching__hive_relation(self, _):
        with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
            mocked.return_value = [("name", "table", None, None)]
            adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
            relations = adapter.list_relations("database", "schema")
            assert len(relations) == 1
            relation = relations[0]
            assert relation.identifier == "name"
            assert relation.database == "database"
            assert relation.schema == "schema"
            assert relation.type == DatabricksRelationType.Table
            assert not relation.has_information()

    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
    def test_get_schema_for_catalog__no_columns(self, _):
        with patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info:
53 changes: 53 additions & 0 deletions tests/unit/test_metadata.py
@@ -0,0 +1,53 @@
from unittest.mock import Mock, patch

import pytest

from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.databricks.behaviors.metadata import DefaultMetadataBehavior
from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType


class TestDefaultMetadataBehavior:
    @pytest.fixture
    def behavior(self):
        return DefaultMetadataBehavior()

    @pytest.fixture
    def relation(self):
        return BaseRelation.create("database", "schema")

    @pytest.fixture
    def adapter(self):
        a = Mock()
        a.Relation = DatabricksRelation
        return a

    def test_list_relations_without_caching__no_relations(self, behavior, relation):
        with patch.object(DefaultMetadataBehavior, "_get_relations_without_caching") as mocked:
            mocked.return_value = []
            assert behavior.list_relations_without_caching(Mock(), relation) == []

    def test_list_relations_without_caching__some_relations(self, behavior, relation, adapter):
        with patch.object(DefaultMetadataBehavior, "_get_relations_without_caching") as mocked:
            mocked.return_value = [("name", "table", "hudi", "owner")]
            relations = behavior.list_relations_without_caching(adapter, relation)
            assert len(relations) == 1
            rrelation = relations[0]
            assert rrelation.identifier == "name"
            assert rrelation.database == "database"
            assert rrelation.schema == "schema"
            assert rrelation.type == DatabricksRelationType.Table
            assert rrelation.owner == "owner"
            assert rrelation.is_hudi

    def test_list_relations_without_caching__hive_relation(self, behavior, relation, adapter):
        with patch.object(DefaultMetadataBehavior, "_get_relations_without_caching") as mocked:
            mocked.return_value = [("name", "table", None, None)]
            relations = behavior.list_relations_without_caching(adapter, relation)
            assert len(relations) == 1
            rrelation = relations[0]
            assert rrelation.identifier == "name"
            assert rrelation.database == "database"
            assert rrelation.schema == "schema"
            assert rrelation.type == DatabricksRelationType.Table
            assert not rrelation.has_information()