Enable create or replace sql syntax #125

Merged · 11 commits · Dec 31, 2020
5 changes: 4 additions & 1 deletion dbt/adapters/spark/column.py
@@ -2,6 +2,7 @@
 from typing import TypeVar, Optional, Dict, Any

 from dbt.adapters.base.column import Column
+from hologram import JsonDict

 Self = TypeVar('Self', bound='SparkColumn')

@@ -54,7 +55,9 @@ def convert_table_stats(raw_stats: Optional[str]) -> Dict[str, Any]:
             table_stats[f'stats:{key}:include'] = True
         return table_stats

-    def to_dict(self, omit_none=False):
+    def to_dict(
+        self, omit_none: bool = True, validate: bool = False
+    ) -> JsonDict:
         original_dict = super().to_dict(omit_none=omit_none)
         # If there are stats, merge them into the root of the dict
         original_stats = original_dict.pop('table_stats')
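To illustrate the `to_dict` change above: a minimal sketch of the stats merge it performs, assuming `table_stats` holds the flattened dict that `convert_table_stats` builds (the specific keys below are illustrative, not captured output):

# stand-in for a serialized SparkColumn; only the merge logic is real
original_dict = {
    'column': 'id',
    'dtype': 'bigint',
    'table_stats': {
        'stats:bytes:label': 'bytes',
        'stats:bytes:value': 1024,
        'stats:bytes:include': True,
    },
}
# pop the nested stats and merge them into the root of the dict
stats = original_dict.pop('table_stats')
original_dict.update(stats)
assert original_dict['stats:bytes:value'] == 1024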
8 changes: 6 additions & 2 deletions dbt/adapters/spark/impl.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Optional, List, Dict, Any, Union, Iterable
 import agate
+from dbt.contracts.relation import RelationType

 import dbt
 import dbt.exceptions
@@ -131,11 +132,14 @@ def list_relations_without_caching(
                     f'got {len(row)} values, expected 4'
                 )
             _schema, name, _, information = row
-            rel_type = ('view' if 'Type: VIEW' in information else 'table')
+            rel_type = RelationType.View \
+                if 'Type: VIEW' in information else RelationType.Table
+            is_delta = 'Provider: delta' in information
             relation = self.Relation.create(
                 schema=_schema,
                 identifier=name,
-                type=rel_type
+                type=rel_type,
+                is_delta=is_delta
             )
             relations.append(relation)

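A quick sketch of the detection this hunk adds; `information` is a hand-written stand-in for the metadata blob in the fourth column of the row unpacked above, not captured Spark output, and plain strings stand in for the RelationType enum:

# stand-in metadata text; the adapter reads this from the catalog row
information = (
    'Database: analytics\n'
    'Table: my_table\n'
    'Type: MANAGED\n'
    'Provider: delta\n'
)

rel_type = 'view' if 'Type: VIEW' in information else 'table'
is_delta = 'Provider: delta' in information
assert (rel_type, is_delta) == ('table', True)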
3 changes: 3 additions & 0 deletions dbt/adapters/spark/relation.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from dataclasses import dataclass

 from dbt.adapters.base.relation import BaseRelation, Policy
@@ -23,6 +25,7 @@ class SparkRelation(BaseRelation):
     quote_policy: SparkQuotePolicy = SparkQuotePolicy()
     include_policy: SparkIncludePolicy = SparkIncludePolicy()
     quote_character: str = '`'
+    is_delta: Optional[bool] = None

     def __post_init__(self):
         if self.database != self.schema and self.database:
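Why `Optional[bool]` rather than plain `bool`: the `None` default keeps "delta status unknown" distinct from an explicit `False`, for relations that weren't built from the catalog scan. A simplified stand-in, not the real SparkRelation:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Rel:
    identifier: str
    is_delta: Optional[bool] = None

assert Rel('my_table').is_delta is None                    # not resolved yet
assert Rel('my_table', is_delta=False).is_delta is False   # known non-delta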
6 changes: 5 additions & 1 deletion dbt/include/spark/macros/adapters.sql
@@ -77,7 +77,11 @@
   {% if temporary -%}
     {{ create_temporary_view(relation, sql) }}
   {%- else -%}
-    create table {{ relation }}
+    {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %}
+      create or replace table {{ relation }}
+    {% else %}
+      create table {{ relation }}
+    {% endif %}
     {{ file_format_clause() }}
     {{ partition_cols(label="partitioned by") }}
     {{ clustered_cols(label="clustered by") }}
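For clarity, the two render paths of `spark__create_table_as`, sketched as a plain function rather than the macro itself; the expected strings match what the updated unit tests at the bottom of this diff assert. Keying the branch on `file_format` appears motivated by Delta being the one format here with an atomic `create or replace table`:

from typing import Optional

def create_table_sql(relation: str, sql: str, file_format: Optional[str] = None) -> str:
    # sketch of the macro's branch, not the macro itself
    head = 'create or replace table' if file_format == 'delta' else 'create table'
    using = f' using {file_format}' if file_format else ''
    return f'{head} {relation}{using} as {sql}'

assert create_table_sql('my_table', 'select 1') == \
    'create table my_table as select 1'
assert create_table_sql('my_table', 'select 1', 'delta') == \
    'create or replace table my_table using delta as select 1'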
21 changes: 10 additions & 11 deletions dbt/include/spark/macros/materializations/incremental.sql
@@ -24,7 +24,7 @@
   {% do return(file_format) %}
 {% endmacro %}

-{% macro dbt_spark_validate_get_incremental_strategy(file_format) %}
+{% macro dbt_spark_validate_get_incremental_strategy(relation) %}
   {#-- Find and validate the incremental strategy #}
   {%- set strategy = config.get("incremental_strategy", default="insert_overwrite") -%}

@@ -41,23 +41,22 @@
   {% if strategy not in ['merge', 'insert_overwrite'] %}
     {% do exceptions.raise_compiler_error(invalid_strategy_msg) %}
   {%-else %}
-    {% if strategy == 'merge' and file_format != 'delta' %}
+    {% if strategy == 'merge' and not relation.is_delta %}
       {% do exceptions.raise_compiler_error(invalid_merge_msg) %}
     {% endif %}
   {% endif %}

   {% do return(strategy) %}
 {% endmacro %}

-{% macro dbt_spark_validate_merge(file_format) %}
+{% macro dbt_spark_validate_merge(relation) %}
   {% set invalid_file_format_msg -%}
     You can only choose the 'merge' incremental_strategy when file_format is set to 'delta'
   {%- endset %}

-  {% if file_format != 'delta' %}
+  {% if not relation.is_delta %}
     {% do exceptions.raise_compiler_error(invalid_file_format_msg) %}
   {% endif %}
-
 {% endmacro %}

@@ -84,20 +83,20 @@


 {% materialization incremental, adapter='spark' -%}
+  {% set target_relation = this %}
+  {% set existing_relation = load_relation(this) %}
+  {% set tmp_relation = make_temp_relation(this) %}
+
   {#-- Validate early so we don't run SQL if the file_format is invalid --#}
   {% set file_format = dbt_spark_validate_get_file_format() -%}
   {#-- Validate early so we don't run SQL if the strategy is invalid --#}
-  {% set strategy = dbt_spark_validate_get_incremental_strategy(file_format) -%}
+  {% set strategy = dbt_spark_validate_get_incremental_strategy(target_relation) -%}

   {%- set full_refresh_mode = (flags.FULL_REFRESH == True) -%}

-  {% set target_relation = this %}
-  {% set existing_relation = load_relation(this) %}
-  {% set tmp_relation = make_temp_relation(this) %}
-
   {% if strategy == 'merge' %}
     {%- set unique_key = config.require('unique_key') -%}
-    {% do dbt_spark_validate_merge(file_format) %}
+    {% do dbt_spark_validate_merge(target_relation) %}
   {% endif %}

   {% if config.get('partition_by') %}
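Restating the rule the two validation macros now enforce, in plain Python; `ValueError` and the message strings stand in for dbt's compiler error:

def validate_strategy(strategy: str, is_delta: bool) -> str:
    # sketch of dbt_spark_validate_get_incremental_strategy's rule
    if strategy not in ('merge', 'insert_overwrite'):
        raise ValueError(f'invalid incremental strategy: {strategy}')
    if strategy == 'merge' and not is_delta:
        raise ValueError("You can only choose the 'merge' incremental_strategy "
                         "when file_format is set to 'delta'")
    return strategy

assert validate_strategy('merge', is_delta=True) == 'merge'
assert validate_strategy('insert_overwrite', is_delta=False) == 'insert_overwrite'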
14 changes: 7 additions & 7 deletions dbt/include/spark/macros/materializations/snapshot.sql
@@ -80,21 +80,21 @@
     Invalid file format: {{ file_format }}
     Snapshot functionality requires file_format be set to 'delta'
   {%- endset %}

+  {% set target_relation_exists, target_relation = get_or_create_relation(
+          database=none,
+          schema=model.schema,
+          identifier=target_table,
+          type='table') -%}
+
-  {%- if file_format != 'delta' -%}
+  {%- if not target_relation_exists.is_delta -%}
     {% do exceptions.raise_compiler_error(invalid_format_msg) %}
   {% endif %}

   {% if not adapter.check_schema_exists(model.database, model.schema) %}
     {% do create_schema(model.database, model.schema) %}
   {% endif %}

-  {% set target_relation_exists, target_relation = get_or_create_relation(
-          database=none,
-          schema=model.schema,
-          identifier=target_table,
-          type='table') -%}
-
   {%- if not target_relation.is_table -%}
     {% do exceptions.relation_wrong_type(target_relation, 'table') %}
   {%- endif -%}
4 changes: 3 additions & 1 deletion dbt/include/spark/macros/materializations/table.sql
@@ -11,7 +11,9 @@
   {{ run_hooks(pre_hooks) }}

   -- setup: if the target relation already exists, drop it
-  {% if old_relation -%}
+  -- in case the existing and future table are both delta, we want to do a
+  -- create or replace table instead of dropping, so we don't have the table unavailable
+  {% if old_relation and not (old_relation.is_delta and config.get('file_format', validator=validation.any[basestring]) == 'delta') -%}
     {{ adapter.drop_relation(old_relation) }}
   {%- endif %}

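The new guard packs three cases into one expression; a small truth-table sketch, with `None` standing in for a missing `old_relation` (an encoding chosen for this sketch, not the macro's actual representation):

from typing import Optional

def should_drop_first(old_is_delta: Optional[bool], new_file_format: str) -> bool:
    # old_is_delta is None when there is no existing relation at all
    exists = old_is_delta is not None
    return exists and not (old_is_delta and new_file_format == 'delta')

assert should_drop_first(None, 'delta') is False     # nothing to drop
assert should_drop_first(True, 'delta') is False     # replaced in place, never unavailable
assert should_drop_first(True, 'parquet') is True    # delta -> non-delta: drop first
assert should_drop_first(False, 'delta') is True     # non-delta -> delta: drop first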
71 changes: 28 additions & 43 deletions test/unit/test_macros.py
@@ -8,104 +8,86 @@ class TestSparkMacros(unittest.TestCase):

     def setUp(self):
         self.jinja_env = Environment(loader=FileSystemLoader('dbt/include/spark/macros'),
-                                     extensions=['jinja2.ext.do',])
+                                     extensions=['jinja2.ext.do', ])

         self.config = {}

-        self.default_context = {}
-        self.default_context['validation'] = mock.Mock()
-        self.default_context['model'] = mock.Mock()
-        self.default_context['exceptions'] = mock.Mock()
-        self.default_context['config'] = mock.Mock()
+        self.default_context = {
+            'validation': mock.Mock(),
+            'model': mock.Mock(),
+            'exceptions': mock.Mock(),
+            'config': mock.Mock()
+        }
         self.default_context['config'].get = lambda key, default=None, **kwargs: self.config.get(key, default)


     def __get_template(self, template_filename):
         return self.jinja_env.get_template(template_filename, globals=self.default_context)


     def __run_macro(self, template, name, temporary, relation, sql):
         self.default_context['model'].alias = relation
         value = getattr(template.module, name)(temporary, relation, sql)
         return re.sub(r'\s\s+', ' ', value)


     def test_macros_load(self):
         self.jinja_env.get_template('adapters.sql')


     def test_macros_create_table_as(self):
         template = self.__get_template('adapters.sql')
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()

-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table as select 1")
-
+        self.assertEqual(sql, "create table my_table as select 1")

     def test_macros_create_table_as_file_format(self):
         template = self.__get_template('adapters.sql')


         self.config['file_format'] = 'delta'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table using delta as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create or replace table my_table using delta as select 1")

     def test_macros_create_table_as_partition(self):
         template = self.__get_template('adapters.sql')


         self.config['partition_by'] = 'partition_1'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table partitioned by (partition_1) as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create table my_table partitioned by (partition_1) as select 1")

     def test_macros_create_table_as_partitions(self):
         template = self.__get_template('adapters.sql')


         self.config['partition_by'] = ['partition_1', 'partition_2']
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql,
                          "create table my_table partitioned by (partition_1,partition_2) as select 1")


     def test_macros_create_table_as_cluster(self):
         template = self.__get_template('adapters.sql')


         self.config['clustered_by'] = 'cluster_1'
         self.config['buckets'] = '1'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table clustered by (cluster_1) into 1 buckets as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create table my_table clustered by (cluster_1) into 1 buckets as select 1")

     def test_macros_create_table_as_clusters(self):
         template = self.__get_template('adapters.sql')


         self.config['clustered_by'] = ['cluster_1', 'cluster_2']
         self.config['buckets'] = '1'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table clustered by (cluster_1,cluster_2) into 1 buckets as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create table my_table clustered by (cluster_1,cluster_2) into 1 buckets as select 1")

     def test_macros_create_table_as_location(self):
         template = self.__get_template('adapters.sql')


         self.config['location_root'] = '/mnt/root'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table location '/mnt/root/my_table' as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create table my_table location '/mnt/root/my_table' as select 1")

     def test_macros_create_table_as_comment(self):
         template = self.__get_template('adapters.sql')


         self.config['persist_docs'] = {'relation': True}
         self.default_context['model'].description = 'Description Test'
-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table comment 'Description Test' as select 1")
-
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(sql, "create table my_table comment 'Description Test' as select 1")

     def test_macros_create_table_as_all(self):
         template = self.__get_template('adapters.sql')
@@ -118,5 +100,8 @@ def test_macros_create_table_as_all(self):
         self.config['persist_docs'] = {'relation': True}
         self.default_context['model'].description = 'Description Test'

-        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
-                         "create table my_table using delta partitioned by (partition_1,partition_2) clustered by (cluster_1,cluster_2) into 1 buckets location '/mnt/root/my_table' comment 'Description Test' as select 1")
+        sql = self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1').strip()
+        self.assertEqual(
+            sql,
+            "create or replace table my_table using delta partitioned by (partition_1,partition_2) clustered by (cluster_1,cluster_2) into 1 buckets location '/mnt/root/my_table' comment 'Description Test' as select 1"
+        )