diff --git a/README.md b/README.md
index 821b3ce7d..a3339b2fb 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,25 @@ The following configurations can be supplied to models run with the dbt-spark pl
 | Option | Description | Required? | Example |
 |---------|----------------------------------------------------|-------------------------|--------------------------|
 | file_format | The file format to use when creating tables | Optional | `parquet` |
-
+| location_root | The root directory in which the created table stores its data; the table alias is appended to this path. | Optional | `/mnt/root` |
+| partition_by | Partition the created table by the specified columns. A directory is created for each partition. | Optional | `partition_1` |
+| clustered_by | Each partition in the created table is split into a fixed number of buckets by the specified columns. | Optional | `cluster_1` |
+| buckets | The number of buckets to create while clustering. | Required if `clustered_by` is specified | `8` |
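+
+A minimal example that combines these options (the model body and source
+table below are illustrative):
+
+```sql
+{{ config(
+    file_format='parquet',
+    location_root='/mnt/root',
+    partition_by=['partition_1'],
+    clustered_by=['cluster_1'],
+    buckets=8
+) }}
+
+select * from my_source_table
+```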
 
 **Incremental Models**
 
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 30c85820d..ba4527968 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -16,6 +16,10 @@ class SparkAdapter(SQLAdapter):
     ConnectionManager = SparkConnectionManager
     Relation = SparkRelation
 
+    AdapterSpecificConfigs = frozenset({"file_format", "location_root",
+                                        "partition_by", "clustered_by",
+                                        "buckets"})
+
     @classmethod
     def date_function(cls):
         return 'CURRENT_TIMESTAMP()'
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index c4616ca8c..ebf328085 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -5,6 +5,7 @@
   {{ sql }}
 {% endmacro %}
 
+
 {% macro file_format_clause() %}
   {%- set file_format = config.get('file_format', validator=validation.any[basestring]) -%}
   {%- if file_format is not none %}
@@ -12,6 +13,29 @@
   {%- endif %}
 {%- endmacro -%}
 
+
+{% macro location_clause() %}
+  {%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%}
+  {%- set identifier = model['alias'] -%}
+  {%- if location_root is not none %}
+    location '{{ location_root }}/{{ identifier }}'
+  {%- endif %}
+{%- endmacro -%}
+
+
+{% macro comment_clause() %}
+  {%- set raw_persist_docs = config.get('persist_docs', {}) -%}
+
+  {%- if raw_persist_docs is mapping -%}
+    {%- set raw_relation = raw_persist_docs.get('relation', false) -%}
+    {%- if raw_relation -%}
+      comment '{{ model.description }}'
+    {% endif %}
+  {%- else -%}
+    {{ exceptions.raise_compiler_error("Invalid value provided for 'persist_docs'. Expected dict but got value: " ~ raw_persist_docs) }}
+  {% endif %}
+{%- endmacro -%}
+
 {% macro partition_cols(label, required=false) %}
   {%- set cols = config.get('partition_by', validator=validation.any[list, basestring]) -%}
   {%- if cols is not none %}
@@ -27,6 +51,24 @@
   {%- endif %}
 {%- endmacro -%}
 
+
+{% macro clustered_cols(label, required=false) %}
+  {%- set cols = config.get('clustered_by', validator=validation.any[list, basestring]) -%}
+  {%- set buckets = config.get('buckets', validator=validation.any[int]) -%}
+  {%- if (cols is not none) and (buckets is not none) %}
+    {%- if cols is string -%}
+      {%- set cols = [cols] -%}
+    {%- endif -%}
+    {{ label }} (
+    {%- for item in cols -%}
+      {{ item }}
+      {%- if not loop.last -%},{%- endif -%}
+    {%- endfor -%}
+    ) into {{ buckets }} buckets
+  {%- endif %}
+{%- endmacro -%}
+
+
 {% macro spark__create_table_as(temporary, relation, sql) -%}
   {% if temporary -%}
     {{ spark_create_temporary_view(relation, sql) }}
@@ -34,16 +76,39 @@
   {%- else %}
     create table {{ relation }}
     {{ file_format_clause() }}
     {{ partition_cols(label="partitioned by") }}
+    {{ clustered_cols(label="clustered by") }}
+    {{ location_clause() }}
+    {{ comment_clause() }}
     as
       {{ sql }}
   {%- endif %}
{%- endmacro -%}
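+
+{#
+  Example of the SQL this macro renders (illustrative values), for a model
+  configured with file_format='parquet', partition_by='partition_1',
+  clustered_by='cluster_1', buckets=8, location_root='/mnt/root' and
+  persist_docs={'relation': true}:
+
+    create table my_table
+    using parquet
+    partitioned by (partition_1)
+    clustered by (cluster_1) into 8 buckets
+    location '/mnt/root/my_table'
+    comment 'My model description'
+    as
+      select ...
+#}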
+
 
 {% macro spark__create_view_as(relation, sql) -%}
-  create view {{ relation }} as
+  create view {{ relation }}
+  {{ comment_clause() }}
+  as
     {{ sql }}
 {% endmacro %}
+
 {% macro spark__get_columns_in_relation(relation) -%}
   {% call statement('get_columns_in_relation', fetch_result=True) %}
     describe {{ relation }}
diff --git a/test/unit/test_macros.py b/test/unit/test_macros.py
new file mode 100644
index 000000000..eb8852ed0
--- /dev/null
+++ b/test/unit/test_macros.py
@@ -0,0 +1,112 @@
+import re
+import unittest
+from unittest import mock
+
+from jinja2 import Environment, FileSystemLoader
+
+
+class TestSparkMacros(unittest.TestCase):
+
+    def setUp(self):
+        self.jinja_env = Environment(loader=FileSystemLoader('dbt/include/spark/macros'),
+                                     extensions=['jinja2.ext.do'])
+
+        self.config = {}
+
+        self.default_context = {}
+        self.default_context['validation'] = mock.Mock()
+        self.default_context['model'] = mock.Mock()
+        self.default_context['exceptions'] = mock.Mock()
+        self.default_context['config'] = mock.Mock()
+        self.default_context['config'].get = lambda key, default=None, **kwargs: self.config.get(key, default)
+
+    def __get_template(self, template_filename):
+        return self.jinja_env.get_template(template_filename, globals=self.default_context)
+
+    def __run_macro(self, template, name, temporary, relation, sql):
+        self.default_context['model'].alias = relation
+        value = getattr(template.module, name)(temporary, relation, sql)
+        # Collapse whitespace runs and trim so the rendered SQL can be
+        # compared against a single-line expected string.
+        return re.sub(r'\s\s+', ' ', value).strip()
+
+    def test_macros_load(self):
+        self.jinja_env.get_template('adapters.sql')
+
+    def test_macros_create_table_as(self):
+        template = self.__get_template('adapters.sql')
+
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table as select 1")
+
+    def test_macros_create_table_as_file_format(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['file_format'] = 'delta'
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table using delta as select 1")
+
+    def test_macros_create_table_as_partition(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['partition_by'] = 'partition_1'
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table partitioned by (partition_1) as select 1")
+
+    def test_macros_create_table_as_partitions(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['partition_by'] = ['partition_1', 'partition_2']
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table partitioned by (partition_1,partition_2) as select 1")
+
+    def test_macros_create_table_as_cluster(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['clustered_by'] = 'cluster_1'
+        self.config['buckets'] = 1
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table clustered by (cluster_1) into 1 buckets as select 1")
+
+    def test_macros_create_table_as_clusters(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['clustered_by'] = ['cluster_1', 'cluster_2']
+        self.config['buckets'] = 1
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table clustered by (cluster_1,cluster_2) into 1 buckets as select 1")
+
+    def test_macros_create_table_as_location(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['location_root'] = '/mnt/root'
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table location '/mnt/root/my_table' as select 1")
+
+    def test_macros_create_table_as_comment(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['persist_docs'] = {'relation': True}
+        self.default_context['model'].description = 'Description Test'
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table comment 'Description Test' as select 1")
+
+    def test_macros_create_table_as_all(self):
+        template = self.__get_template('adapters.sql')
+
+        self.config['file_format'] = 'delta'
+        self.config['location_root'] = '/mnt/root'
+        self.config['partition_by'] = ['partition_1', 'partition_2']
+        self.config['clustered_by'] = ['cluster_1', 'cluster_2']
+        self.config['buckets'] = 1
+        self.config['persist_docs'] = {'relation': True}
+        self.default_context['model'].description = 'Description Test'
+
+        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+                         "create table my_table using delta partitioned by (partition_1,partition_2) clustered by (cluster_1,cluster_2) into 1 buckets location '/mnt/root/my_table' comment 'Description Test' as select 1")
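+
+
+if __name__ == '__main__':
+    # Convenience entry point so this module can also be run directly
+    # (python test/unit/test_macros.py) in addition to pytest.
+    unittest.main()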