diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index 6becd4551bf..5129d60d646 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -1,5 +1,6 @@
 import abc
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import Future  # noqa - we use this for typing only
 from contextlib import contextmanager
 from datetime import datetime
 from typing import (
@@ -22,6 +23,7 @@
 from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.graph.parsed import ParsedSeedNode
+from dbt.exceptions import warn_or_error
 from dbt.node_types import NodeType
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.utils import filter_null_values
@@ -118,7 +120,7 @@ def _relation_name(rel: Optional[BaseRelation]) -> str:
     return str(rel)


-class SchemaSearchMap(Dict[InformationSchema, Set[str]]):
+class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
     """A utility class to keep track of what information_schema tables to
     search for what schemas
     """
@@ -1044,22 +1046,21 @@ def _get_one_catalog(
         results = self._catalog_filter_table(table, manifest)
         return results

-    def get_catalog(self, manifest: Manifest) -> agate.Table:
+    def get_catalog(
+        self, manifest: Manifest
+    ) -> Tuple[agate.Table, List[Exception]]:
         # snowflake is super slow. split it out into the specified threads
         num_threads = self.config.threads
         schema_map = self._get_cache_schemas(manifest)
-        catalogs: agate.Table = agate.Table(rows=[])

         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             futures = [
                 executor.submit(self._get_one_catalog, info, schemas, manifest)
                 for info, schemas in schema_map.items()
                 if len(schemas) > 0
             ]
-            for future in as_completed(futures):
-                catalog = future.result()
-                catalogs = agate.Table.merge([catalogs, catalog])
+            catalogs, exceptions = catch_as_completed(futures)

-        return catalogs
+        return catalogs, exceptions

     def cancel_open_connections(self):
         """Cancel all open connections."""
@@ -1133,3 +1134,30 @@ def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
         The second parameter is the value returned by pre_mdoel_hook.
         """
         pass
+
+
+def catch_as_completed(
+    futures  # typing: List[Future[agate.Table]]
+) -> Tuple[agate.Table, List[Exception]]:
+
+    catalogs: agate.Table = agate.Table(rows=[])
+    exceptions: List[Exception] = []
+
+    for future in as_completed(futures):
+        exc = future.exception()
+        # we want to re-raise on ctrl+c and BaseException
+        if exc is None:
+            catalog = future.result()
+            catalogs = agate.Table.merge([catalogs, catalog])
+        elif (
+            isinstance(exc, KeyboardInterrupt) or
+            not isinstance(exc, Exception)
+        ):
+            raise exc
+        else:
+            warn_or_error(
+                f'Encountered an error while generating catalog: {str(exc)}'
+            )
+            # exc is not None, derives from Exception, and isn't ctrl+c
+            exceptions.append(exc)
+    return catalogs, exceptions
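Not part of the patch: a minimal sketch of how `catch_as_completed` behaves from a caller's point of view. The `fetch` worker below is a hypothetical stand-in for `BaseAdapter._get_one_catalog`; the failing schema is collected as an exception while the successful one still merges into the catalog.

    from concurrent.futures import ThreadPoolExecutor
    import agate
    from dbt.adapters.base.impl import catch_as_completed

    def fetch(name: str) -> agate.Table:
        # hypothetical worker standing in for _get_one_catalog
        if name == 'bad_schema':
            raise RuntimeError(f'no catalog for {name}')
        return agate.Table.from_object([{'table_schema': name}])

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(fetch, s) for s in ('analytics', 'bad_schema')]
        catalogs, exceptions = catch_as_completed(futures)

    assert len(exceptions) == 1     # the RuntimeError was collected, not raised
    assert len(catalogs.rows) == 1  # the good schema still made it into the table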
""" pass + + +def catch_as_completed( + futures # typing: List[Future[agate.Table]] +) -> Tuple[agate.Table, List[Exception]]: + + catalogs: agate.Table = agate.Table(rows=[]) + exceptions: List[Exception] = [] + + for future in as_completed(futures): + exc = future.exception() + # we want to re-raise on ctrl+c and BaseException + if exc is None: + catalog = future.result() + catalogs = agate.Table.merge([catalogs, catalog]) + elif ( + isinstance(exc, KeyboardInterrupt) or + not isinstance(exc, Exception) + ): + raise exc + else: + warn_or_error( + f'Encountered an error while generating catalog: {str(exc)}' + ) + # exc is not None, derives from Exception, and isn't ctrl+c + exceptions.append(exc) + return catalogs, exceptions diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 69226debd8a..ea42ad81ede 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -291,4 +291,5 @@ def key(self) -> CatalogKey: class CatalogResults(JsonSchemaMixin, Writable): nodes: Dict[str, CatalogTable] generated_at: datetime + errors: Optional[List[str]] _compile_results: Optional[Any] = None diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 23b152f5d34..06e41e84c64 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -510,6 +510,7 @@ def from_result( return cls( nodes=base.nodes, generated_at=base.generated_at, + errors=base.errors, _compile_results=base._compile_results, logs=logs, tags=tags, diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 97c03a1589a..00e1a052cfe 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -14,6 +14,7 @@ ) from dbt.exceptions import InternalException from dbt.include.global_project import DOCS_INDEX_FILE_PATH +from dbt.logger import GLOBAL_LOGGER as logger import dbt.ui.printer import dbt.utils import dbt.compilation @@ -194,7 +195,9 @@ def run(self): dbt.ui.printer.print_timestamped_line( 'compile failed, cannot generate docs' ) - return CatalogResults({}, datetime.utcnow(), compile_results) + return CatalogResults( + {}, datetime.utcnow(), compile_results, None + ) shutil.copyfile( DOCS_INDEX_FILE_PATH, @@ -208,7 +211,7 @@ def run(self): adapter = get_adapter(self.config) with adapter.connection_named('generate_catalog'): dbt.ui.printer.print_timestamped_line("Building catalog") - catalog_table = adapter.get_catalog(self.manifest) + catalog_table, exceptions = adapter.get_catalog(self.manifest) catalog_data: List[PrimitiveDict] = [ dict(zip(catalog_table.column_names, map(_coerce_decimal, row))) @@ -216,10 +219,16 @@ def run(self): ] catalog = Catalog(catalog_data) + + errors: Optional[List[str]] = None + if exceptions: + errors = [str(e) for e in exceptions] + results = self.get_catalog_results( nodes=catalog.make_unique_id_map(self.manifest), generated_at=datetime.utcnow(), compile_results=compile_results, + errors=errors, ) path = os.path.join(self.config.target_path, CATALOG_FILENAME) @@ -229,21 +238,32 @@ def run(self): dbt.ui.printer.print_timestamped_line( 'Catalog written to {}'.format(os.path.abspath(path)) ) + + if exceptions: + logger.error( + 'dbt encountered {} failure{} while writing the catalog' + .format(len(exceptions), (len(exceptions) == 1) * 's') + ) + return results def get_catalog_results( self, nodes: Dict[str, CatalogTable], generated_at: datetime, - compile_results: Optional[Any] + compile_results: Optional[Any], + errors: Optional[List[str]] ) -> CatalogResults: return CatalogResults( nodes=nodes, 
diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py
index 0449f306847..d88f46b5371 100644
--- a/core/dbt/task/rpc/project_commands.py
+++ b/core/dbt/task/rpc/project_commands.py
@@ -110,12 +110,13 @@ def set_args(self, params: RPCDocsGenerateParameters) -> None:
         self.args.compile = params.compile

     def get_catalog_results(
-        self, nodes, generated_at, compile_results
+        self, nodes, generated_at, compile_results, errors
    ) -> RemoteCatalogResults:
         return RemoteCatalogResults(
             nodes=nodes,
             generated_at=datetime.utcnow(),
             _compile_results=compile_results,
+            errors=errors,
             logs=[],
         )
diff --git a/plugins/bigquery/dbt/include/bigquery/macros/catalog.sql b/plugins/bigquery/dbt/include/bigquery/macros/catalog.sql
index 1f468f0b339..d41b03604f4 100644
--- a/plugins/bigquery/dbt/include/bigquery/macros/catalog.sql
+++ b/plugins/bigquery/dbt/include/bigquery/macros/catalog.sql
@@ -17,7 +17,7 @@
        from {{ information_schema.replace(information_schema_view='SCHEMATA') }}
        where (
            {%- for schema in schemas -%}
-              schema_name = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
+              upper(schema_name) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
            {%- endfor -%}
        )
    ),
diff --git a/plugins/postgres/dbt/include/postgres/macros/catalog.sql b/plugins/postgres/dbt/include/postgres/macros/catalog.sql
index 7887a434275..377fdbd9ae9 100644
--- a/plugins/postgres/dbt/include/postgres/macros/catalog.sql
+++ b/plugins/postgres/dbt/include/postgres/macros/catalog.sql
@@ -30,7 +30,7 @@
    where (
        {%- for schema in schemas -%}
-          sch.nspname = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
+          upper(sch.nspname) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
        {%- endfor -%}
    )
    and not pg_is_other_temp_schema(sch.oid) -- not a temporary schema belonging to another session
diff --git a/plugins/redshift/dbt/include/redshift/macros/catalog.sql b/plugins/redshift/dbt/include/redshift/macros/catalog.sql
index b8239c977f6..83935f67a35 100644
--- a/plugins/redshift/dbt/include/redshift/macros/catalog.sql
+++ b/plugins/redshift/dbt/include/redshift/macros/catalog.sql
@@ -96,7 +96,7 @@
    where (
        {%- for schema in schemas -%}
-          table_schema = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
+          upper(table_schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
        {%- endfor -%}
    )

@@ -185,7 +185,7 @@
    from svv_table_info
    where (
        {%- for schema in schemas -%}
-          schema = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
+          upper(schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
        {%- endfor -%}
    )

diff --git a/plugins/snowflake/dbt/include/snowflake/macros/catalog.sql b/plugins/snowflake/dbt/include/snowflake/macros/catalog.sql
index fe8fbf8ad90..1d796a3251d 100644
--- a/plugins/snowflake/dbt/include/snowflake/macros/catalog.sql
+++ b/plugins/snowflake/dbt/include/snowflake/macros/catalog.sql
@@ -51,7 +51,7 @@
    join columns using ("table_database", "table_schema", "table_name")
    where (
        {%- for schema in schemas -%}
-          "table_schema" = '{{ schema }}'{%- if not loop.last %} or {% endif -%}
+          upper("table_schema") = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%}
        {%- endfor -%}
    )
    order by "column_index"
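All four adapter macros make the same change: the schema filter becomes case-insensitive, so catalogs are found even when the warehouse folds identifier case (Snowflake, for example, stores unquoted names upper-cased). A hypothetical Python rendering of the predicate the Jinja loop emits; `render_schema_filter` is illustrative only, not a dbt API:

    def render_schema_filter(column: str, schemas) -> str:
        # mirrors the Jinja loop: upper(<column>) = upper('<schema>') or ...
        return ' or '.join(
            f"upper({column}) = upper('{schema}')" for schema in schemas
        )

    print(render_schema_filter('table_schema', ['Foo', 'BAR']))
    # upper(table_schema) = upper('Foo') or upper(table_schema) = upper('BAR')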
diff --git a/test/integration/029_docs_generate_tests/fail_macros/failure.sql b/test/integration/029_docs_generate_tests/fail_macros/failure.sql
index 56c01eb7025..f0519ed245a 100644
--- a/test/integration/029_docs_generate_tests/fail_macros/failure.sql
+++ b/test/integration/029_docs_generate_tests/fail_macros/failure.sql
@@ -1,3 +1,3 @@
-{% macro get_catalog(information_schemas) %}
+{% macro get_catalog(information_schema, schemas) %}
 {% do exceptions.raise_compiler_error('rejected: no catalogs for you') %}
 {% endmacro %}
diff --git a/test/unit/test_docs_generate.py b/test/unit/test_docs_generate.py
index ede163b5d81..79c60eab37f 100644
--- a/test/unit/test_docs_generate.py
+++ b/test/unit/test_docs_generate.py
@@ -29,6 +29,7 @@ def generate_catalog_dict(self, columns):
         result = generate.CatalogResults(
             nodes=generate.Catalog(columns).make_unique_id_map(self.manifest),
             generated_at=datetime.utcnow(),
+            errors=None,
         )
         return result.to_dict(omit_none=False)['nodes']
diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py
index 13b450a3218..16aac9e8c42 100644
--- a/test/unit/test_postgres_adapter.py
+++ b/test/unit/test_postgres_adapter.py
@@ -199,11 +199,12 @@ def test_get_catalog_various_schemas(self, mock_get_schemas, mock_execute):
         mock_manifest.get_used_schemas.return_value = {('dbt', 'foo'),
                                                        ('dbt', 'quux')}

-        catalog = self.adapter.get_catalog(mock_manifest)
+        catalog, exceptions = self.adapter.get_catalog(mock_manifest)
         self.assertEqual(
             set(map(tuple, catalog)),
             {('dbt', 'foo', 'bar'), ('dbt', 'FOO', 'baz'), ('dbt', 'quux', 'bar')}
         )
+        self.assertEqual(exceptions, [])


 class TestConnectingPostgresAdapter(unittest.TestCase):
diff --git a/third-party-stubs/agate/__init__.pyi b/third-party-stubs/agate/__init__.pyi
index 7f21d3badf7..2d1f01f1020 100644
--- a/third-party-stubs/agate/__init__.pyi
+++ b/third-party-stubs/agate/__init__.pyi
@@ -57,6 +57,8 @@ class Table:
     def from_object(cls, obj: Iterable[Dict[str, Any]], *, column_types: Optional['TypeTester'] = None) -> 'Table': ...
     @classmethod
     def from_csv(cls, path: Iterable[str], *, column_types: Optional['TypeTester'] = None) -> 'Table': ...
+    @classmethod
+    def merge(cls, tables: Iterable['Table']) -> 'Table': ...

 class TypeTester:
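One further sketch, not among the PR's tests: `catch_as_completed` should re-raise BaseExceptions such as ctrl+c instead of collecting them, which can be checked directly with a pre-failed future.

    import unittest
    from concurrent.futures import Future
    from dbt.adapters.base.impl import catch_as_completed

    class TestCatchAsCompleted(unittest.TestCase):
        def test_keyboard_interrupt_reraises(self):
            # a finished future carrying a KeyboardInterrupt must propagate it
            future: Future = Future()
            future.set_exception(KeyboardInterrupt())
            with self.assertRaises(KeyboardInterrupt):
                catch_as_completed([future])

    if __name__ == '__main__':
        unittest.main()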