Pull the owner from the DESCRIBE EXTENDED #39

Merged · 18 commits · Mar 16, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ env/
*.pyc
__pycache__
.tox/
.idea/
build/
dist/
dbt-integration-tests
121 changes: 72 additions & 49 deletions dbt/adapters/spark/impl.py
@@ -1,15 +1,22 @@
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.spark import SparkRelation
from dbt.adapters.spark import SparkConnectionManager
import dbt.exceptions
from typing import List, Dict

from dbt.logger import GLOBAL_LOGGER as logger
import agate
import dbt.exceptions
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.manifest import Manifest
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.adapters.spark import SparkConnectionManager
from dbt.adapters.spark import SparkRelation
from dbt.adapters.spark.relation import SparkColumn

LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
GET_RELATION_TYPE_MACRO_NAME = 'spark_get_relation_type'
DROP_RELATION_MACRO_NAME = 'drop_relation'
FETCH_TBLPROPERTIES_MACRO_NAME = 'spark_fetch_tblproperties'
GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'

KEY_TABLE_OWNER = 'Owner'


class SparkAdapter(SQLAdapter):
@@ -48,7 +55,7 @@ def get_relation_type(self, relation, model_name=None):
# Override that creates macros without a known type - adapter macros that
# require a type will dynamically check at query-time
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
model_name=None) -> List:
kwargs = {'information_schema': information_schema, 'schema': schema}
try:
results = self.execute_macro(
@@ -112,53 +119,69 @@ def drop_relation(self, relation, model_name=None):
kwargs={'relation': relation}
)

def get_catalog(self, manifest):
@staticmethod
def find_table_information_separator(rows: List[dict]) -> int:
pos = 0
for row in rows:
if not row['col_name'] or row['col_name'].startswith('#'):
break
pos += 1
return pos

def parse_describe_extended(
self,
relation: Relation,
raw_rows: List[agate.Row]
) -> List[SparkColumn]:
# Convert the Row to a dict
dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows]
# Find the separator between the rows and the metadata provided
# by the DESCRIBE TABLE EXTENDED statement
pos = SparkAdapter.find_table_information_separator(dict_rows)

# Remove rows that start with a hash, they are comments
rows = [
row for row in raw_rows[0:pos]
if not row['col_name'].startswith('#')
]
metadata = {
col['col_name']: col['data_type'] for col in raw_rows[pos + 1:]
}
return [SparkColumn(
relation.database,
relation.schema,
relation.name,
relation.type,
metadata.get(KEY_TABLE_OWNER),
column['col_name'],
idx,
column['data_type']
) for idx, column in enumerate(rows)]

def get_columns_in_relation(self,
relation: Relation) -> List[SparkColumn]:
rows: List[agate.Row] = super().get_columns_in_relation(relation)
return self.parse_describe_extended(relation, rows)

def get_properties(self, relation: Relation) -> Dict[str, str]:
properties = self.execute_macro(
FETCH_TBLPROPERTIES_MACRO_NAME,
kwargs={'relation': relation}
)
return {key: value for (key, value) in properties}

def get_catalog(self, manifest: Manifest) -> agate.Table:
schemas = manifest.get_used_schemas()

column_names = (
'table_database',
'table_schema',
'table_name',
'table_type',
'table_comment',
'table_owner',
'column_name',
'column_index',
'column_type',
'column_comment',
)
def to_dict(d: any) -> Dict:
return d.__dict__

columns = []
for (database_name, schema_name) in schemas:
relations = self.list_relations(database_name, schema_name)
for relation in relations:
logger.debug("Getting table schema for relation {}".format(relation)) # noqa
table_columns = self.get_columns_in_relation(relation)
rel_type = self.get_relation_type(relation)

for column_index, column in enumerate(table_columns):
# Fixes for pseudocolumns with no type
if column.name in (
'# Partition Information',
'# col_name'
):
continue
elif column.dtype is None:
continue

column_data = (
relation.database,
relation.schema,
relation.name,
rel_type,
None,
None,
column.name,
column_index,
column.data_type,
None,
)
column_dict = dict(zip(column_names, column_data))
columns.append(column_dict)

return dbt.clients.agate_helper.table_from_data(columns, column_names)
logger.debug("Getting table schema for relation {}", relation)
columns += list(
map(to_dict, self.get_columns_in_relation(relation))
)
return agate.Table.from_object(columns)
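For context on the core change: DESCRIBE TABLE EXTENDED returns the column rows first, then a separator row (empty or starting with '#'), then key/value metadata rows such as Owner. The new parse_describe_extended splits on that separator so the table owner can be attached to every column. A minimal standalone sketch of the same idea, using hypothetical sample rows rather than the adapter's agate rows:

from typing import Dict, List, Tuple

# Hypothetical (col_name, data_type) pairs mimicking DESCRIBE TABLE EXTENDED output.
SAMPLE_ROWS: List[Tuple[str, str]] = [
    ('col1', 'decimal(22,0)'),
    ('dt', 'date'),
    ('# Partition Information', 'data_type'),
    ('# col_name', 'data_type'),
    ('dt', 'date'),
    ('', ''),
    ('# Detailed Table Information', ''),
    ('Owner', 'root'),
    ('Type', 'MANAGED'),
]

def split_describe_output(rows: List[Tuple[str, str]]):
    # Rows above the first empty/'#' marker are the real columns;
    # rows below it are table-level metadata (Owner, Type, ...).
    pos = 0
    for name, _ in rows:
        if not name or name.startswith('#'):
            break
        pos += 1
    columns = rows[:pos]
    metadata: Dict[str, str] = {
        name: value
        for name, value in rows[pos + 1:]
        if name and not name.startswith('#')
    }
    return columns, metadata

columns, metadata = split_describe_output(SAMPLE_ROWS)
print(columns)                 # [('col1', 'decimal(22,0)'), ('dt', 'date')]
print(metadata.get('Owner'))   # root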
30 changes: 29 additions & 1 deletion dbt/adapters/spark/relation.py
@@ -1,4 +1,4 @@
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.base.relation import BaseRelation, Column


class SparkRelation(BaseRelation):
@@ -45,3 +45,31 @@ class SparkRelation(BaseRelation):
'required': ['metadata', 'type', 'path', 'include_policy',
'quote_policy', 'quote_character', 'dbt_created']
}


class SparkColumn(Column):

def __init__(self,
table_database: str,
table_schema: str,
table_name: str,
table_type: str,
table_owner: str,
column_name: str,
column_index: int,
column_type: str):
super(SparkColumn, self).__init__(column_name, column_type)
self.table_database = table_database
self.table_schema = table_schema
self.table_name = table_name
self.table_type = table_type
self.table_owner = table_owner
self.column_name = column_name
self.column_index = column_index

@property
def quoted(self):
return '`{}`'.format(self.column)

def __repr__(self):
return "<SparkColumn {}, {}>".format(self.name, self.data_type)
13 changes: 8 additions & 5 deletions dbt/include/spark/macros/adapters.sql
@@ -95,12 +95,9 @@

{% macro spark__get_columns_in_relation(relation) -%}
{% call statement('get_columns_in_relation', fetch_result=True) %}
describe {{ relation }}
describe extended {{ relation }}
{% endcall %}

{% set table = load_result('get_columns_in_relation').table %}
{{ return(sql_convert_columns_in_relation(table)) }}

{% do return(load_result('get_columns_in_relation').table) %}
{% endmacro %}


@@ -149,6 +146,12 @@
{% endif %}
{%- endmacro %}

{% macro spark_fetch_tblproperties(relation) -%}
{% call statement('list_properties', fetch_result=True) -%}
SHOW TBLPROPERTIES {{ relation }}
{% endcall %}
{% do return(load_result('list_properties').table) %}
{%- endmacro %}

{% macro spark__rename_relation(from_relation, to_relation) -%}
{% call statement('rename_relation') -%}
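As a side note on the new spark_fetch_tblproperties macro: SHOW TBLPROPERTIES comes back as a two-column key/value table, which get_properties (in impl.py above) flattens into a dict. A standalone sketch of that flattening, with hypothetical property names and values:

import agate

# Hypothetical SHOW TBLPROPERTIES result: two columns, one row per property.
properties_table = agate.Table(
    rows=[('created_by', 'dbt'), ('delta.minReaderVersion', '1')],
    column_names=['key', 'value'],
)

properties = {key: value for (key, value) in properties_table.rows}
print(properties['created_by'])  # dbt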
122 changes: 103 additions & 19 deletions test/unit/test_adapter.py
@@ -1,12 +1,13 @@
import mock
import unittest
import dbt.adapters

import dbt.flags as flags
import mock
from agate import Row
from dbt.adapters.base import BaseRelation
from pyhive import hive
from dbt.adapters.spark import SparkAdapter
import agate

from .utils import config_from_parts_or_dicts, inject_adapter
from dbt.adapters.spark import SparkAdapter
from .utils import config_from_parts_or_dicts


class TestSparkAdapter(unittest.TestCase):
@@ -29,13 +30,13 @@ def get_target_http(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'http',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 443,
'token': 'abc123',
'cluster': '01234-23423-coffeetime',
'type': 'spark',
'method': 'http',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 443,
'token': 'abc123',
'cluster': '01234-23423-coffeetime',
}
},
'target': 'test'
@@ -45,12 +46,12 @@ def get_target_thrift(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt'
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt'
}
},
'target': 'test'
@@ -60,7 +61,6 @@ def test_http_connection(self):
config = self.get_target_http(self.project_cfg)
adapter = SparkAdapter(config)


def hive_http_connect(thrift_transport):
self.assertEqual(thrift_transport.scheme, 'https')
self.assertEqual(thrift_transport.port, 443)
@@ -87,3 +87,87 @@ def hive_thrift_connect(host, port, username):

self.assertEqual(connection.state, 'open')
self.assertNotEqual(connection.handle, None)

def test_parse_relation(self):
self.maxDiff = None
rel_type = 'table'

relation = BaseRelation.create(
database='default_database',
schema='default_schema',
identifier='mytable',
type=rel_type
)

# Mimics the output of Spark with a DESCRIBE TABLE EXTENDED
plain_rows = [
('col1', 'decimal(22,0)'),
('col2', 'string',),
('dt', 'date'),
('# Partition Information', 'data_type'),
('# col_name', 'data_type'),
('dt', 'date'),
(None, None),
('# Detailed Table Information', None),
('Database', relation.database),
('Owner', 'root'),
('Created Time', 'Wed Feb 04 18:15:00 UTC 1815'),
('Last Access', 'Wed May 20 19:25:00 UTC 1925'),
('Type', 'MANAGED'),
('Provider', 'delta'),
('Location', '/mnt/vo'),
('Serde Library', 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'),
('InputFormat', 'org.apache.hadoop.mapred.SequenceFileInputFormat'),
('OutputFormat', 'org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat'),
('Partition Provider', 'Catalog')
]

input_cols = [Row(keys=['col_name', 'data_type'], values=r) for r in plain_rows]

config = self.get_target_http(self.project_cfg)
rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
self.assertEqual(len(rows), 3)
self.assertEqual(rows[0].__dict__, {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_owner': 'root',
'column': 'col1',
'column_name': 'col1',
'column_index': 0,
'dtype': 'decimal(22,0)',
'numeric_scale': None,
'numeric_precision': None,
'char_size': None
})

self.assertEqual(rows[1].__dict__, {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_owner': 'root',
'column': 'col2',
'column_name': 'col2',
'column_index': 1,
'dtype': 'string',
'numeric_scale': None,
'numeric_precision': None,
'char_size': None
})

self.assertEqual(rows[2].__dict__, {
'table_database': relation.database,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_owner': 'root',
'column': 'dt',
'column_name': 'dt',
'column_index': 2,
'dtype': 'date',
'numeric_scale': None,
'numeric_precision': None,
'char_size': None
})