Skip to content

Commit

Permalink
Seed rewrite (dbt-labs#618)
Browse files Browse the repository at this point in the history
* loader for seed data files

* Functioning rework of seed task

* Make CompilerRunner fns private and impl. SeedRunner.compile

Trying to distinguish between the public/private interface for this
class. And the SeedRunner doesn't need the functionality in the compile
function, it just needs a compile function to exist for use in the
compilation process.

* Test changes and fixes

* make the DB setup script usable locally

* convert simple copy test to use seeed

* Fixes to get Snowflake working

* New seed flag and make it non-destructive by default

* Convert update SQL script to another seed

* cleanup

* implement bigquery csv load

* context handling of StringIO

* Better typing

* strip seeder and csvkit dependency

* update bigquery to use new data typing and to fix unicode issue

* update seed test

* fix abstract functions in base adapter

* support time type

* try pinning crypto, pyopenssl versions

* remove unnecessary version pins

* insert all at once, rather than one query per row

* do not quote field names on creation

* bad

* quiet down parsedatetime logger

* pep8

* UI updates + node conformity for seed nodes

* add seed to list of resource types, cleanup

* show option for CSVs

* typo

* pep8

* move agate import to avoid strange warnings

* deprecation warning for --drop-existing

* quote column names in seed files

* revert quoting change (breaks Snowflake). Hush warnings
  • Loading branch information
b-ryan authored and drewbanin committed Feb 10, 2018
1 parent 76098ea commit 0372fef
Show file tree
Hide file tree
Showing 32 changed files with 825 additions and 429 deletions.
56 changes: 54 additions & 2 deletions dbt/adapters/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class BigQueryAdapter(PostgresAdapter):
def handle_error(cls, error, message, sql):
logger.debug(message.format(sql=sql))
logger.debug(error)
error_msg = "\n".join([error['message'] for error in error.errors])
error_msg = "\n".join(
[item['message'] for item in error.errors])

raise dbt.exceptions.DatabaseException(error_msg)

@classmethod
Expand Down Expand Up @@ -372,7 +374,8 @@ def warning_on_hooks(cls, hook_type):
dbt.ui.printer.COLOR_FG_YELLOW)

@classmethod
def add_query(cls, profile, sql, model_name=None, auto_begin=True):
def add_query(cls, profile, sql, model_name=None, auto_begin=True,
bindings=None):
if model_name in ['on-run-start', 'on-run-end']:
cls.warning_on_hooks(model_name)
else:
Expand All @@ -395,3 +398,52 @@ def quote_schema_and_table(cls, profile, schema, table, model_name=None):
return '{}.{}.{}'.format(cls.quote(project),
cls.quote(schema),
cls.quote(table))

@classmethod
def convert_text_type(cls, agate_table, col_idx):
return "string"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
import agate
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float64" if decimals else "int64"

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
return "bool"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
return "datetime"

@classmethod
def create_csv_table(cls, profile, schema, table_name, agate_table):
pass

@classmethod
def reset_csv_table(cls, profile, schema, table_name, agate_table,
full_refresh=False):
cls.drop(profile, schema, table_name, "table")

@classmethod
def _agate_to_schema(cls, agate_table):
bq_schema = []
for idx, col_name in enumerate(agate_table.column_names):
type_ = cls.convert_agate_type(agate_table, idx)
bq_schema.append(
google.cloud.bigquery.SchemaField(col_name, type_))
return bq_schema

@classmethod
def load_csv_rows(cls, profile, schema, table_name, agate_table):
bq_schema = cls._agate_to_schema(agate_table)
dataset = cls.get_dataset(profile, schema, None)
table = dataset.table(table_name, schema=bq_schema)
conn = cls.get_connection(profile, None)
client = conn.get('handle')
with open(agate_table.original_abspath, "rb") as f:
job = table.upload_from_file(f, "CSV", rewind=True,
client=client, skip_leading_rows=1)
with cls.exception_handler(profile, "LOAD TABLE"):
cls.poll_until_job_completes(job, cls.get_timeout(conn))
85 changes: 82 additions & 3 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ def cancel_connection(cls, project, connection):
raise dbt.exceptions.NotImplementedException(
'`cancel_connection` is not implemented for this adapter!')

@classmethod
def create_csv_table(cls, profile, schema, table_name, agate_table):
raise dbt.exceptions.NotImplementedException(
'`create_csv_table` is not implemented for this adapter!')

@classmethod
def reset_csv_table(cls, profile, schema, table_name, agate_table,
full_refresh=False):
raise dbt.exceptions.NotImplementedException(
'`reset_csv_table` is not implemented for this adapter!')

@classmethod
def load_csv_rows(cls, profile, schema, table_name, agate_table):
raise dbt.exceptions.NotImplementedException(
'`load_csv_rows` is not implemented for this adapter!')

###
# FUNCTIONS THAT SHOULD BE ABSTRACT
###
Expand Down Expand Up @@ -507,7 +523,8 @@ def close(cls, connection):
return connection

@classmethod
def add_query(cls, profile, sql, model_name=None, auto_begin=True):
def add_query(cls, profile, sql, model_name=None, auto_begin=True,
bindings=None):
connection = cls.get_connection(profile, model_name)
connection_name = connection.get('name')

Expand All @@ -522,7 +539,7 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True):
pre = time.time()

cursor = connection.get('handle').cursor()
cursor.execute(sql)
cursor.execute(sql, bindings)

logger.debug("SQL status: %s in %0.2f seconds",
cls.get_status(cursor), (time.time() - pre))
Expand Down Expand Up @@ -603,9 +620,71 @@ def already_exists(cls, profile, schema, table, model_name=None):

@classmethod
def quote(cls, identifier):
return '"{}"'.format(identifier)
return '"{}"'.format(identifier.replace('"', '""'))

@classmethod
def quote_schema_and_table(cls, profile, schema, table, model_name=None):
return '{}.{}'.format(cls.quote(schema),
cls.quote(table))

@classmethod
def handle_csv_table(cls, profile, schema, table_name, agate_table,
full_refresh=False):
existing = cls.query_for_existing(profile, schema)
existing_type = existing.get(table_name)
if existing_type and existing_type != "table":
raise dbt.exceptions.RuntimeException(
"Cannot seed to '{}', it is a view".format(table_name))
if existing_type:
cls.reset_csv_table(profile, schema, table_name, agate_table,
full_refresh=full_refresh)
else:
cls.create_csv_table(profile, schema, table_name, agate_table)
cls.load_csv_rows(profile, schema, table_name, agate_table)
cls.commit_if_has_connection(profile, None)

@classmethod
def convert_text_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_text_type` is not implemented for this adapter!')

@classmethod
def convert_number_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_number_type` is not implemented for this adapter!')

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_boolean_type` is not implemented for this adapter!')

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_datetime_type` is not implemented for this adapter!')

@classmethod
def convert_date_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_date_type` is not implemented for this adapter!')

@classmethod
def convert_time_type(cls, agate_table, col_idx):
raise dbt.exceptions.NotImplementedException(
'`convert_time_type` is not implemented for this adapter!')

@classmethod
def convert_agate_type(cls, agate_table, col_idx):
import agate
agate_type = agate_table.column_types[col_idx]
conversions = [
(agate.Text, cls.convert_text_type),
(agate.Number, cls.convert_number_type),
(agate.Boolean, cls.convert_boolean_type),
(agate.DateTime, cls.convert_datetime_type),
(agate.Date, cls.convert_date_type),
(agate.TimeDelta, cls.convert_time_type),
]
for agate_cls, func in conversions:
if isinstance(agate_type, agate_cls):
return func(agate_table, col_idx)
71 changes: 71 additions & 0 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dbt.adapters.default
import dbt.compat
import dbt.exceptions
from dbt.utils import max_digits

from dbt.logger import GLOBAL_LOGGER as logger

Expand Down Expand Up @@ -165,3 +166,73 @@ def cancel_connection(cls, profile, connection):
res = cursor.fetchone()

logger.debug("Cancel query '{}': {}".format(connection_name, res))

@classmethod
def convert_text_type(cls, agate_table, col_idx):
return "text"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
import agate
column = agate_table.columns[col_idx]
precision = max_digits(column.values_without_nulls())
# agate uses the term Precision but in this context, it is really the
# scale - ie. the number of decimal places
scale = agate_table.aggregate(agate.MaxPrecision(col_idx))
if not scale:
return "integer"
return "numeric({}, {})".format(precision, scale)

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
return "boolean"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
return "timestamp without time zone"

@classmethod
def convert_date_type(cls, agate_table, col_idx):
return "date"

@classmethod
def convert_time_type(cls, agate_table, col_idx):
return "time"

@classmethod
def create_csv_table(cls, profile, schema, table_name, agate_table):
col_sqls = []
for idx, col_name in enumerate(agate_table.column_names):
type_ = cls.convert_agate_type(agate_table, idx)
col_sqls.append('{} {}'.format(col_name, type_))
sql = 'create table "{}"."{}" ({})'.format(schema, table_name,
", ".join(col_sqls))
return cls.add_query(profile, sql)

@classmethod
def reset_csv_table(cls, profile, schema, table_name, agate_table,
full_refresh=False):
if full_refresh:
cls.drop_table(profile, schema, table_name, None)
cls.create_csv_table(profile, schema, table_name, agate_table)
else:
cls.truncate(profile, schema, table_name)

@classmethod
def load_csv_rows(cls, profile, schema, table_name, agate_table):
bindings = []
placeholders = []
cols_sql = ", ".join(c for c in agate_table.column_names)

for row in agate_table.rows:
bindings += row
placeholders.append("({})".format(
", ".join("%s" for _ in agate_table.column_names)))

sql = ('insert into {}.{} ({}) values {}'
.format(cls.quote(schema),
cls.quote(table_name),
cols_sql,
",\n".join(placeholders)))

cls.add_query(profile, sql, bindings=bindings)
11 changes: 11 additions & 0 deletions dbt/adapters/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ def drop(cls, profile, schema, relation, relation_type, model_name=None):

finally:
drop_lock.release()

@classmethod
def convert_text_type(cls, agate_table, col_idx):
column = agate_table.columns[col_idx]
lens = (len(d.encode("utf-8")) for d in column.values_without_nulls())
max_len = max(lens) if lens else 64
return "varchar({})".format(max_len)

@classmethod
def convert_time_type(cls, agate_table, col_idx):
return "varchar(24)"
10 changes: 8 additions & 2 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def check_schema_exists(cls, profile, schema, model_name=None):

@classmethod
def add_query(cls, profile, sql, model_name=None, auto_begin=True,
select_schema=True):
select_schema=True, bindings=None):
# snowflake only allows one query per api call.
queries = sql.strip().split(";")
cursor = None
Expand All @@ -193,6 +193,11 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True,
model_name,
auto_begin)

if bindings:
# The snowflake connector is more strict than, eg., psycopg2 -
# which allows any iterable thing to be passed as a binding.
bindings = tuple(bindings)

for individual_query in queries:
# hack -- after the last ';', remove comments and don't run
# empty queries. this avoids using exceptions as flow control,
Expand All @@ -205,7 +210,8 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True,
continue

connection, cursor = super(PostgresAdapter, cls).add_query(
profile, individual_query, model_name, auto_begin)
profile, individual_query, model_name, auto_begin,
bindings=bindings)

return connection, cursor

Expand Down
11 changes: 2 additions & 9 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,10 @@ def print_compile_stats(stats):
NodeType.Analysis: 'analyses',
NodeType.Macro: 'macros',
NodeType.Operation: 'operations',
NodeType.Seed: 'seed files',
}

results = {
NodeType.Model: 0,
NodeType.Test: 0,
NodeType.Archive: 0,
NodeType.Analysis: 0,
NodeType.Macro: 0,
NodeType.Operation: 0,
}

results = {k: 0 for k in names.keys()}
results.update(stats)

stat_line = ", ".join(
Expand Down
4 changes: 4 additions & 0 deletions dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from voluptuous import Schema, Required, All, Any, Length, ALLOW_EXTRA
from voluptuous import Optional

import dbt.exceptions

Expand Down Expand Up @@ -43,6 +44,9 @@
Required('empty'): bool,
Required('config'): config_contract,
Required('tags'): All(set),

# For csv files
Optional('agate_table'): object,
})

parsed_nodes_contract = Schema({
Expand Down
3 changes: 2 additions & 1 deletion dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
Required('resource_type'): Any(NodeType.Model,
NodeType.Test,
NodeType.Analysis,
NodeType.Operation)
NodeType.Operation,
NodeType.Seed)
})

unparsed_nodes_contract = Schema([unparsed_node_contract])
Expand Down
8 changes: 8 additions & 0 deletions dbt/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def show(self, *args, **kwargs):
# removed (in favor of 'target') in DBT version 0.7.0"""


class SeedDropExistingDeprecation(DBTDeprecation):
name = 'drop-existing'
description = """The --drop-existing argument has been deprecated. Please
use --full-refresh instead. The --drop-existing option will be removed in a
future version of dbt."""


def warn(name, *args, **kwargs):
if name not in deprecations:
# this should (hopefully) never happen
Expand All @@ -37,6 +44,7 @@ def warn(name, *args, **kwargs):
active_deprecations = set()

deprecations_list = [
SeedDropExistingDeprecation()
]

deprecations = {d.name: d for d in deprecations_list}
Expand Down
Loading

0 comments on commit 0372fef

Please sign in to comment.