Skip to content

Commit

Permalink
PR feedback w/ improved catalog results behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Beck committed Jan 31, 2020
1 parent 0bf6eca commit c1af3ab
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 18 deletions.
42 changes: 35 additions & 7 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import Future # noqa - we use this for typing only
from contextlib import contextmanager
from datetime import datetime
from typing import (
Expand All @@ -22,6 +23,7 @@
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values
Expand Down Expand Up @@ -118,7 +120,7 @@ def _relation_name(rel: Optional[BaseRelation]) -> str:
return str(rel)


class SchemaSearchMap(Dict[InformationSchema, Set[str]]):
class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
"""A utility class to keep track of what information_schema tables to
search for what schemas
"""
Expand Down Expand Up @@ -1044,22 +1046,21 @@ def _get_one_catalog(
results = self._catalog_filter_table(table, manifest)
return results

def get_catalog(self, manifest: Manifest) -> agate.Table:
def get_catalog(
self, manifest: Manifest
) -> Tuple[agate.Table, List[Exception]]:
# snowflake is super slow. split it out into the specified threads
num_threads = self.config.threads
schema_map = self._get_cache_schemas(manifest)
catalogs: agate.Table = agate.Table(rows=[])

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [
executor.submit(self._get_one_catalog, info, schemas, manifest)
for info, schemas in schema_map.items() if len(schemas) > 0
]
for future in as_completed(futures):
catalog = future.result()
catalogs = agate.Table.merge([catalogs, catalog])
catalogs, exceptions = catch_as_completed(futures)

return catalogs
return catalogs, exceptions

def cancel_open_connections(self):
"""Cancel all open connections."""
Expand Down Expand Up @@ -1133,3 +1134,30 @@ def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
The second parameter is the value returned by pre_model_hook.
"""
pass


def catch_as_completed(
    futures  # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:
    """Wait for the given catalog futures and merge their results.

    :param futures: Futures that each resolve to an agate.Table of catalog
        rows (typed via comment because Future is not subscriptable at
        runtime on older Pythons).
    :return: A tuple of (merged catalog table, list of Exceptions raised by
        failed futures).
    :raises BaseException: any non-Exception error (notably
        KeyboardInterrupt from ctrl+c) is re-raised immediately instead of
        being collected.
    """
    catalogs: agate.Table = agate.Table(rows=[])
    exceptions: List[Exception] = []

    for future in as_completed(futures):
        exc = future.exception()
        if exc is None:
            catalog = future.result()
            catalogs = agate.Table.merge([catalogs, catalog])
        elif not isinstance(exc, Exception):
            # re-raise on ctrl+c and BaseException: KeyboardInterrupt
            # subclasses BaseException (not Exception), so this single
            # check covers both cases the original spelled out.
            raise exc
        else:
            warn_or_error(
                f'Encountered an error while generating catalog: {str(exc)}'
            )
            # exc is not None, derives from Exception, and isn't ctrl+c
            exceptions.append(exc)
    return catalogs, exceptions
1 change: 1 addition & 0 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,5 @@ def key(self) -> CatalogKey:
class CatalogResults(JsonSchemaMixin, Writable):
nodes: Dict[str, CatalogTable]
generated_at: datetime
errors: Optional[List[str]]
_compile_results: Optional[Any] = None
1 change: 1 addition & 0 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def from_result(
return cls(
nodes=base.nodes,
generated_at=base.generated_at,
errors=base.errors,
_compile_results=base._compile_results,
logs=logs,
tags=tags,
Expand Down
26 changes: 23 additions & 3 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from dbt.exceptions import InternalException
from dbt.include.global_project import DOCS_INDEX_FILE_PATH
from dbt.logger import GLOBAL_LOGGER as logger
import dbt.ui.printer
import dbt.utils
import dbt.compilation
Expand Down Expand Up @@ -194,7 +195,9 @@ def run(self):
dbt.ui.printer.print_timestamped_line(
'compile failed, cannot generate docs'
)
return CatalogResults({}, datetime.utcnow(), compile_results)
return CatalogResults(
{}, datetime.utcnow(), compile_results, None
)

shutil.copyfile(
DOCS_INDEX_FILE_PATH,
Expand All @@ -208,18 +211,24 @@ def run(self):
adapter = get_adapter(self.config)
with adapter.connection_named('generate_catalog'):
dbt.ui.printer.print_timestamped_line("Building catalog")
catalog_table = adapter.get_catalog(self.manifest)
catalog_table, exceptions = adapter.get_catalog(self.manifest)

catalog_data: List[PrimitiveDict] = [
dict(zip(catalog_table.column_names, map(_coerce_decimal, row)))
for row in catalog_table
]

catalog = Catalog(catalog_data)

errors: Optional[List[str]] = None
if exceptions:
errors = [str(e) for e in exceptions]

results = self.get_catalog_results(
nodes=catalog.make_unique_id_map(self.manifest),
generated_at=datetime.utcnow(),
compile_results=compile_results,
errors=errors,
)

path = os.path.join(self.config.target_path, CATALOG_FILENAME)
Expand All @@ -229,21 +238,32 @@ def run(self):
dbt.ui.printer.print_timestamped_line(
'Catalog written to {}'.format(os.path.abspath(path))
)

if exceptions:
logger.error(
'dbt encountered {} failure{} while writing the catalog'
.format(len(exceptions), (len(exceptions) != 1) * 's')
)

return results

def get_catalog_results(
self,
nodes: Dict[str, CatalogTable],
generated_at: datetime,
compile_results: Optional[Any]
compile_results: Optional[Any],
errors: Optional[List[str]]
) -> CatalogResults:
return CatalogResults(
nodes=nodes,
generated_at=generated_at,
_compile_results=compile_results,
errors=errors,
)

def interpret_results(self, results):
if results.errors:
return False
compile_results = results._compile_results
if compile_results is None:
return True
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/rpc/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ def set_args(self, params: RPCDocsGenerateParameters) -> None:
self.args.compile = params.compile

def get_catalog_results(
self, nodes, generated_at, compile_results
self, nodes, generated_at, compile_results, errors
) -> RemoteCatalogResults:
return RemoteCatalogResults(
nodes=nodes,
generated_at=datetime.utcnow(),
_compile_results=compile_results,
errors=errors,
logs=[],
)

Expand Down
2 changes: 1 addition & 1 deletion plugins/bigquery/dbt/include/bigquery/macros/catalog.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from {{ information_schema.replace(information_schema_view='SCHEMATA') }}
where (
{%- for schema in schemas -%}
schema_name = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
upper(schema_name) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)
),
Expand Down
2 changes: 1 addition & 1 deletion plugins/postgres/dbt/include/postgres/macros/catalog.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
where (
{%- for schema in schemas -%}
sch.nspname = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
upper(sch.nspname) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)
and not pg_is_other_temp_schema(sch.oid) -- not a temporary schema belonging to another session
Expand Down
4 changes: 2 additions & 2 deletions plugins/redshift/dbt/include/redshift/macros/catalog.sql
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@

where (
{%- for schema in schemas -%}
table_schema = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
upper(table_schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)

Expand Down Expand Up @@ -185,7 +185,7 @@
from svv_table_info
where (
{%- for schema in schemas -%}
schema = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
upper(schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)

Expand Down
2 changes: 1 addition & 1 deletion plugins/snowflake/dbt/include/snowflake/macros/catalog.sql
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
join columns using ("table_database", "table_schema", "table_name")
where (
{%- for schema in schemas -%}
"table_schema" = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
upper("table_schema") = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
{%- endfor -%}
)
order by "column_index"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{% macro get_catalog(information_schemas) %}
{% macro get_catalog(information_schema, schemas) %}
{% do exceptions.raise_compiler_error('rejected: no catalogs for you') %}
{% endmacro %}
1 change: 1 addition & 0 deletions test/unit/test_docs_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def generate_catalog_dict(self, columns):
result = generate.CatalogResults(
nodes=generate.Catalog(columns).make_unique_id_map(self.manifest),
generated_at=datetime.utcnow(),
errors=None,
)
return result.to_dict(omit_none=False)['nodes']

Expand Down
3 changes: 2 additions & 1 deletion test/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,12 @@ def test_get_catalog_various_schemas(self, mock_get_schemas, mock_execute):
mock_manifest.get_used_schemas.return_value = {('dbt', 'foo'),
('dbt', 'quux')}

catalog = self.adapter.get_catalog(mock_manifest)
catalog, exceptions = self.adapter.get_catalog(mock_manifest)
self.assertEqual(
set(map(tuple, catalog)),
{('dbt', 'foo', 'bar'), ('dbt', 'FOO', 'baz'), ('dbt', 'quux', 'bar')}
)
self.assertEqual(exceptions, [])


class TestConnectingPostgresAdapter(unittest.TestCase):
Expand Down
2 changes: 2 additions & 0 deletions third-party-stubs/agate/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class Table:
def from_object(cls, obj: Iterable[Dict[str, Any]], *, column_types: Optional['TypeTester'] = None) -> 'Table': ...
@classmethod
def from_csv(cls, path: Iterable[str], *, column_types: Optional['TypeTester'] = None) -> 'Table': ...
@classmethod
def merge(cls, tables: Iterable['Table']) -> 'Table': ...


class TypeTester:
Expand Down

0 comments on commit c1af3ab

Please sign in to comment.