Pull the owner from the DESCRIBE EXTENDED #39

Merged (18 commits, Mar 16, 2020)
Changes from 4 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ env/
*.pyc
__pycache__
.tox/
.idea/
build/
dist/
dbt-integration-tests
133 changes: 88 additions & 45 deletions dbt/adapters/spark/impl.py
@@ -1,21 +1,38 @@
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.spark import SparkRelation
from dbt.adapters.spark import SparkConnectionManager
import dbt.exceptions
from typing import List, Dict

from dbt.logger import GLOBAL_LOGGER as logger
import agate
import dbt.exceptions
from agate import Column
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.manifest import Manifest
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.adapters.spark import SparkConnectionManager
from dbt.adapters.spark import SparkRelation

LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
GET_RELATION_TYPE_MACRO_NAME = 'spark_get_relation_type'
DROP_RELATION_MACRO_NAME = 'drop_relation'
FETCH_TBLPROPERTIES_MACRO_NAME = 'spark_fetch_tblproperties'


class SparkAdapter(SQLAdapter):
ConnectionManager = SparkConnectionManager
Relation = SparkRelation

column_names = (
'table_database',
'table_schema',
'table_name',
'table_type',
'table_comment',
'table_owner',
'column_name',
'column_index',
'column_type',
'column_comment',
)

@classmethod
def date_function(cls):
return 'CURRENT_TIMESTAMP()'
@@ -56,7 +73,7 @@ def get_relation_type(self, relation, model_name=None):
# Override that creates macros without a known type - adapter macros that
# require a type will dynamically check at query-time
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
model_name=None) -> List[Relation]:
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
@@ -90,53 +107,79 @@ def drop_relation(self, relation, model_name=None):
kwargs={'relation': relation}
)

def get_catalog(self, manifest):
schemas = manifest.get_used_schemas()
@staticmethod
def _parse_relation(relation: Relation,
table_columns: List[Column],
rel_type: str,
properties: Dict[str, str] = None) -> List[dict]:
properties = properties or {}
table_owner_key = 'Owner'

# First check if it is present in the properties
table_owner = properties.get(table_owner_key)

found_detailed_table_marker = False
for column in table_columns:
if column.name == '# Detailed Table Information':
found_detailed_table_marker = True

# In case there is another column with the name Owner
if not found_detailed_table_marker:
continue

column_names = (
'table_database',
'table_schema',
'table_name',
'table_type',
'table_comment',
'table_owner',
'column_name',
'column_index',
'column_type',
'column_comment',
if not table_owner and column.name == table_owner_key:
table_owner = column.data_type

columns = []
for column in table_columns:
# Fixes for pseudocolumns with no type
if column.name in {
'# Partition Information',
'# col_name',
''
}:
continue
elif column.name == '# Detailed Table Information':
# Loop until the detailed table information
break
elif column.data_type is None:
continue

column_data = (
relation.database,
relation.schema,
relation.name,
rel_type,
None,
table_owner,
column.name,
len(columns),
column.data_type,
None
)
column_dict = dict(zip(SparkAdapter.column_names, column_data))
columns.append(column_dict)

return columns

def get_properties(self, relation: Relation) -> Dict[str, str]:
properties = self.execute_macro(
FETCH_TBLPROPERTIES_MACRO_NAME,
kwargs={'relation': relation}
)
return {key: value for (key, value) in properties}

def get_catalog(self, manifest: Manifest):
schemas = manifest.get_used_schemas()

columns = []
for (database_name, schema_name) in schemas:
relations = self.list_relations(database_name, schema_name)
for relation in relations:
properties = self.get_properties(relation)
logger.debug("Getting table schema for relation {}".format(relation)) # noqa
table_columns = self.get_columns_in_relation(relation)
rel_type = self.get_relation_type(relation)
columns += self._parse_relation(relation, table_columns, rel_type, properties)

for column_index, column in enumerate(table_columns):
# Fixes for pseudocolumns with no type
if column.name in (
'# Partition Information',
'# col_name'
):
continue
elif column.dtype is None:
continue

column_data = (
relation.database,
relation.schema,
relation.name,
rel_type,
None,
None,
column.name,
column_index,
column.data_type,
None,
)
column_dict = dict(zip(column_names, column_data))
columns.append(column_dict)

return dbt.clients.agate_helper.table_from_data(columns, column_names)
return dbt.clients.agate_helper.table_from_data(columns, SparkAdapter.column_names)
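
For readers scanning the diff: the new _parse_relation resolves the table owner with a simple precedence. An Owner entry coming from SHOW TBLPROPERTIES wins; failing that, the Owner row that appears after the "# Detailed Table Information" marker in the describe extended output is used, so a data column that merely happens to be named Owner is not mistaken for the table owner. Below is a minimal standalone sketch of that precedence; the resolve_owner helper and the sample rows are illustrative only and are not part of the PR.

from typing import Dict, List, Optional, Tuple

def resolve_owner(describe_rows: List[Tuple[str, str]],
                  properties: Optional[Dict[str, str]] = None) -> Optional[str]:
    properties = properties or {}
    # An Owner entry from TBLPROPERTIES takes precedence when present
    owner = properties.get('Owner')
    in_detail_section = False
    for name, value in describe_rows:
        if name == '# Detailed Table Information':
            in_detail_section = True
        # Only trust an 'Owner' row inside the detailed-information block
        if in_detail_section and not owner and name == 'Owner':
            owner = value
    return owner

rows = [('col1', 'string'), ('# Detailed Table Information', ''), ('Owner', 'root')]
assert resolve_owner(rows) == 'root'
assert resolve_owner(rows, {'Owner': 'Fokko'}) == 'Fokko'
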
8 changes: 7 additions & 1 deletion dbt/include/spark/macros/adapters.sql
@@ -46,7 +46,7 @@

{% macro spark__get_columns_in_relation(relation) -%}
{% call statement('get_columns_in_relation', fetch_result=True) %}
describe {{ relation }}
describe extended {{ relation }}
{% endcall %}

{% set table = load_result('get_columns_in_relation').table %}
@@ -89,6 +89,12 @@
{% endif %}
{%- endmacro %}

{% macro spark_fetch_tblproperties(relation) -%}
{% call statement('list_properties', fetch_result=True) -%}
SHOW TBLPROPERTIES {{ relation }}
{% endcall %}
{% do return(load_result('list_properties').table) %}
{%- endmacro %}

{% macro spark__rename_relation(from_relation, to_relation) -%}
{% call statement('rename_relation') -%}
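
The new spark_fetch_tblproperties macro runs SHOW TBLPROPERTIES, which Spark returns as two-column (key, value) rows; get_properties in impl.py then folds the fetched result table into a plain dict. A rough sketch of that conversion, assuming the fetched result iterates as (key, value) pairs and using made-up sample rows:

# Sketch only: mirrors the dict comprehension in get_properties.
# The sample rows are illustrative; real output depends on the table.
sample_rows = [
    ('Owner', 'Fokko'),
    ('delta.minReaderVersion', '1'),
]
properties = {key: value for (key, value) in sample_rows}
assert properties == {'Owner': 'Fokko', 'delta.minReaderVersion': '1'}
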
164 changes: 145 additions & 19 deletions test/unit/test_adapter.py
@@ -1,12 +1,14 @@
import dbt.flags as flags
import mock
import unittest
import dbt.adapters
import dbt.flags as flags
from collections import namedtuple

from agate import Column, MappedSequence
from dbt.adapters.base import BaseRelation
from pyhive import hive
from dbt.adapters.spark import SparkAdapter
import agate

from .utils import config_from_parts_or_dicts, inject_adapter
from dbt.adapters.spark import SparkAdapter
from .utils import config_from_parts_or_dicts


class TestSparkAdapter(unittest.TestCase):
@@ -29,13 +31,13 @@ def get_target_http(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'http',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 443,
'token': 'abc123',
'cluster': '01234-23423-coffeetime',
'type': 'spark',
'method': 'http',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 443,
'token': 'abc123',
'cluster': '01234-23423-coffeetime',
}
},
'target': 'test'
@@ -45,12 +47,12 @@ def get_target_thrift(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt'
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt'
}
},
'target': 'test'
@@ -60,7 +62,6 @@ def test_http_connection(self):
config = self.get_target_http(self.project_cfg)
adapter = SparkAdapter(config)


def hive_http_connect(thrift_transport):
self.assertEqual(thrift_transport.scheme, 'https')
self.assertEqual(thrift_transport.port, 443)
@@ -87,3 +88,128 @@ def hive_thrift_connect(host, port, username):

self.assertEqual(connection.state, 'open')
self.assertNotEqual(connection.handle, None)

def test_parse_relation(self):
rel_type = 'table'

relation = BaseRelation.create(
database='default_database',
schema='default_schema',
identifier='mytable',
type=rel_type
)

# Mimics the output of Spark with a DESCRIBE TABLE EXTENDED
plain_rows = [
('col1', 'decimal(22,0)'),
('col2', 'string',),
('# Partition Information', 'data_type'),
('# col_name', 'data_type'),
('dt', 'date'),
('', ''),
('# Detailed Table Information', ''),
('Database', relation.database),
('Owner', 'root'),
('Created Time', 'Wed Feb 04 18:15:00 UTC 1815'),
('Last Access', 'Wed May 20 19:25:00 UTC 1925'),
('Type', 'MANAGED'),
('Provider', 'delta'),
('Location', '/mnt/vo'),
('Serde Library', 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'),
('InputFormat', 'org.apache.hadoop.mapred.SequenceFileInputFormat'),
('OutputFormat', 'org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat'),
('Partition Provider', 'Catalog')
]

input_cols = [Column(index=None, name=r[0], data_type=r[1], rows=MappedSequence(
keys=['col_name', 'data_type'],
values=r
)) for r in plain_rows]

rows = SparkAdapter._parse_relation(relation, input_cols, rel_type)
self.assertEqual(len(rows), 3)
self.assertEqual(rows[0], {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_comment': None,
'table_owner': 'root',
'column_name': 'col1',
'column_index': 0,
'column_type': 'decimal(22,0)',
'column_comment': None
})

self.assertEqual(rows[1], {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_comment': None,
'table_owner': 'root',
'column_name': 'col2',
'column_index': 1,
'column_type': 'string',
'column_comment': None
})

self.assertEqual(rows[2], {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_comment': None,
'table_owner': 'root',
'column_name': 'dt',
'column_index': 2,
'column_type': 'date',
'column_comment': None
})

def test_parse_relation_with_properties(self):
rel_type = 'table'

relation = BaseRelation.create(
database='default_database',
schema='default_schema',
identifier='mytable',
type=rel_type
)

# Mimics the output of Spark with a DESCRIBE TABLE EXTENDED
plain_rows = [
('col1', 'decimal(19,25)'),
('', ''),
('# Detailed Table Information', ''),
('Database', relation.database),
('Owner', 'root'),
('Created Time', 'Wed Feb 04 18:15:00 UTC 1815'),
('Last Access', 'Wed May 20 19:25:00 UTC 1925'),
('Type', 'MANAGED'),
('Provider', 'delta'),
('Location', '/mnt/vo'),
('Serde Library', 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'),
('InputFormat', 'org.apache.hadoop.mapred.SequenceFileInputFormat'),
('OutputFormat', 'org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat'),
('Partition Provider', 'Catalog')
]

input_cols = [Column(index=None, name=r[0], data_type=r[1], rows=MappedSequence(
keys=['col_name', 'data_type'],
values=r
)) for r in plain_rows]

rows = SparkAdapter._parse_relation(relation, input_cols, rel_type, {'Owner': 'Fokko'})
self.assertEqual(rows[0], {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_comment': None,
'table_owner': 'Fokko',
'column_name': 'col1',
'column_index': 0,
'column_type': 'decimal(19,25)',
'column_comment': None
})