From 62a0bf87327f24486386315b776534409259aab6 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 19 Jun 2020 15:10:59 -0600 Subject: [PATCH] store package info on the factory instead of globally adapter_types -> plugins Use factory.packages instead of global PACKAGES --- core/dbt/adapters/factory.py | 76 ++++++++++++++++--- core/dbt/config/runtime.py | 6 +- core/dbt/context/configured.py | 13 +++- core/dbt/context/providers.py | 6 +- core/dbt/contracts/graph/manifest.py | 12 ++- core/dbt/parser/manifest.py | 18 ++--- .../test_concurrent_transaction.py | 3 +- test/unit/test_bigquery_adapter.py | 3 +- test/unit/test_context.py | 17 +++-- test/unit/test_graph.py | 9 ++- test/unit/test_postgres_adapter.py | 5 +- test/unit/test_redshift_adapter.py | 9 ++- test/unit/test_snowflake_adapter.py | 3 +- test/unit/utils.py | 10 ++- 14 files changed, 139 insertions(+), 51 deletions(-) diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index 8cc8d20ae39..307abba2bc9 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -1,9 +1,13 @@ import threading +from pathlib import Path from importlib import import_module -from typing import Type, Dict, Any +from typing import Type, Dict, Any, List, Set, Optional -from dbt.exceptions import RuntimeException -from dbt.include.global_project import PACKAGES +from dbt.exceptions import RuntimeException, InternalException +from dbt.include.global_project import ( + PACKAGE_PATH as GLOBAL_PROJECT_PATH, + PROJECT_NAME as GLOBAL_PROJECT_NAME, +) from dbt.logger import GLOBAL_LOGGER as logger from dbt.contracts.connection import Credentials, AdapterRequiredConfig @@ -24,18 +28,25 @@ class AdpaterContainer: def __init__(self): self.lock = threading.Lock() self.adapters: Dict[str, Adapter] = {} - self.adapter_types: Dict[str, Type[Adapter]] = {} + self.plugins: Dict[str, AdapterPlugin] = {} + # map package names to their include paths + self.packages: Dict[str, Path] = { + GLOBAL_PROJECT_NAME: Path(GLOBAL_PROJECT_PATH), + } - def get_adapter_class_by_name(self, name: str) -> Type[Adapter]: + def get_plugin_by_name(self, name: str) -> AdapterPlugin: with self.lock: - if name in self.adapter_types: - return self.adapter_types[name] - - names = ", ".join(self.adapter_types.keys()) + if name in self.plugins: + return self.plugins[name] + names = ", ".join(self.plugins.keys()) message = f"Invalid adapter type {name}! Must be one of {names}" raise RuntimeException(message) + def get_adapter_class_by_name(self, name: str) -> Type[Adapter]: + plugin = self.get_plugin_by_name(name) + return plugin.adapter + def get_relation_class_by_name(self, name: str) -> Type[BaseRelation]: adapter = self.get_adapter_class_by_name(name) return adapter.Relation @@ -47,7 +58,7 @@ def get_config_class_by_name( return adapter.AdapterSpecificConfigs def load_plugin(self, name: str) -> Type[Credentials]: - # this doesn't need a lock: in the worst case we'll overwrite PACKAGES + # this doesn't need a lock: in the worst case we'll overwrite packages # and adapter_type entries with the same value, as they're all # singletons try: @@ -74,9 +85,9 @@ def load_plugin(self, name: str) -> Type[Credentials]: with self.lock: # things do hold the lock to iterate over it so we need it to add - self.adapter_types[name] = plugin.adapter + self.plugins[name] = plugin - PACKAGES[plugin.project_name] = plugin.include_path + self.packages[plugin.project_name] = Path(plugin.include_path) for dep in plugin.dependencies: self.load_plugin(dep) @@ -114,6 +125,39 @@ def cleanup_connections(self): for adapter in self.adapters.values(): adapter.cleanup_connections() + def get_adapter_package_names(self, name: Optional[str]) -> Set[str]: + if name is None: + return list(self.packages) + package_names: Set[str] = {GLOBAL_PROJECT_NAME} + plugin_names: Set[str] = {name} + while plugin_names: + plugin_name = plugin_names.pop() + try: + plugin = self.plugins[plugin_name] + except KeyError: + raise InternalException( + f'No plugin found for {plugin_name}' + ) from None + package_names.add(plugin.adapter.type()) + if plugin.dependencies is None: + continue + for dep in plugin.dependencies: + if dep not in package_names: + plugin_names.add(dep) + return package_names + + def get_include_paths(self, name: Optional[str]) -> List[Path]: + paths = [] + for package_name in self.get_adapter_package_names(name): + try: + path = self.packages[package_name] + except KeyError: + raise InternalException( + f'No internal package listing found for {package_name}' + ) + paths.append(path) + return paths + FACTORY: AdpaterContainer = AdpaterContainer() @@ -153,3 +197,11 @@ def get_relation_class_by_name(name: str) -> Type[BaseRelation]: def load_plugin(name: str) -> Type[Credentials]: return FACTORY.load_plugin(name) + + +def get_include_paths(name: Optional[str]) -> List[Path]: + return FACTORY.get_include_paths(name) + + +def get_adapter_package_names(name: Optional[str]) -> Set[str]: + return FACTORY.get_adapter_package_names(name) diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index bc212878775..bad39fa8b6c 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -12,7 +12,7 @@ from .project import Project from .renderer import DbtProjectYamlRenderer, ProfileRenderer from dbt import tracking -from dbt.adapters.factory import get_relation_class_by_name +from dbt.adapters.factory import get_relation_class_by_name, get_include_paths from dbt.helper_types import FQNPath, PathSet from dbt.context.base import generate_base_context from dbt.context.target import generate_target_context @@ -31,7 +31,6 @@ warn_or_error, raise_compiler_error ) -from dbt.include.global_project import PACKAGES from dbt.legacy_config_updater import ConfigUpdater from hologram import ValidationError @@ -323,8 +322,9 @@ def warn_for_unused_resource_config_paths( def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']: if self.dependencies is None: all_projects = {self.project_name: self} + internal_packages = get_include_paths(self.credentials.type) project_paths = itertools.chain( - map(Path, PACKAGES.values()), + internal_packages, self._get_project_directories() ) for project_name, project in self.load_projects(project_paths): diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index a0a1510fdf8..c658240eace 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -1,10 +1,9 @@ -from typing import Any, Dict, Iterable, Union, Optional +from typing import Any, Dict, Iterable, Union, Optional, Set from dbt.clients.jinja import MacroGenerator, MacroStack from dbt.contracts.connection import AdapterRequiredConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedMacro -from dbt.include.global_project import PACKAGES from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME from dbt.node_types import NodeType from dbt.utils import MultiDict @@ -93,10 +92,12 @@ def __init__( root_package: str, search_package: str, thread_ctx: MacroStack, + internal_packages: Set[str], node: Optional[Any] = None, ) -> None: self.root_package = root_package self.search_package = search_package + self.internal_packages = internal_packages self.globals: FlatNamespace = {} self.locals: FlatNamespace = {} self.packages: Dict[str, FlatNamespace] = {} @@ -110,7 +111,7 @@ def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]): ) # put plugin macros into the root namespace - if macro.package_name in PACKAGES: + if macro.package_name in self.internal_packages: namespace: str = GLOBAL_PROJECT_NAME else: namespace = macro.package_name @@ -161,10 +162,16 @@ def __init__( self.macro_stack = MacroStack() def _get_namespace(self): + # avoid an import loop + from dbt.adapters.factory import get_adapter_package_names + internal_packages = get_adapter_package_names( + self.config.credentials.type + ) return MacroNamespace( self.config.project_name, self.search_package, self.macro_stack, + internal_packages, None, ) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 5c59591217b..e85be69d93b 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -8,7 +8,7 @@ from dbt.adapters.base.column import Column -from dbt.adapters.factory import get_adapter +from dbt.adapters.factory import get_adapter, get_adapter_package_names from dbt.clients import agate_helper from dbt.clients.jinja import get_rendered from dbt.config import RuntimeConfig, Project @@ -573,10 +573,14 @@ def __init__( self.db_wrapper = self.provider.DatabaseWrapper(self.adapter) def _get_namespace(self): + internal_packages = get_adapter_package_names( + self.config.credentials.type + ) return MacroNamespace( self.config.project_name, self.search_package, self.macro_stack, + internal_packages, self.model, ) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 2df6c0e8dec..42d9faa6030 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -25,7 +25,6 @@ warn_or_error, raise_invalid_patch ) from dbt.helper_types import PathSet -from dbt.include.global_project import PACKAGES from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType from dbt import deprecations @@ -468,10 +467,12 @@ def last(self) -> Optional[ParsedMacro]: return self[-1].macro -def _get_locality(macro: ParsedMacro, root_project_name: str) -> Locality: +def _get_locality( + macro: ParsedMacro, root_project_name: str, internal_packages: Set[str] +) -> Locality: if macro.package_name == root_project_name: return Locality.Root - elif macro.package_name in PACKAGES: + elif macro.package_name in internal_packages: return Locality.Core else: return Locality.Imported @@ -647,12 +648,15 @@ def _find_macros_by_name( ) -> CandidateList: """Find macros by their name. """ + # avoid an import cycle + from dbt.adapters.factory import get_adapter_package_names candidates: CandidateList = CandidateList() + packages = get_adapter_package_names(self.metadata.adapter_type) for unique_id, macro in self.macros.items(): if macro.name != name: continue candidate = MacroCandidate( - locality=_get_locality(macro, root_project_name), + locality=_get_locality(macro, root_project_name, packages), macro=macro, ) if filter is None or filter(candidate): diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 60ef395fdbd..341e9de8a08 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -9,9 +9,11 @@ import dbt.flags from dbt import deprecations -from dbt.adapters.factory import get_relation_class_by_name +from dbt.adapters.factory import ( + get_relation_class_by_name, + get_adapter_package_names, +) from dbt.helper_types import PathSet -from dbt.include.global_project import PACKAGES from dbt.logger import GLOBAL_LOGGER as logger, DbtProcessState from dbt.node_types import NodeType from dbt.clients.jinja import get_rendered @@ -119,13 +121,14 @@ def _load_macros( ) -> None: projects = self.all_projects if internal_manifest is not None: + # skip internal packages + packages = get_adapter_package_names(self.root_project.credentials.type) projects = { - k: v for k, v in self.all_projects.items() if k not in PACKAGES + k: v for k, v in self.all_projects.items() if k not in packages } self.results.macros.update(internal_manifest.macros) self.results.files.update(internal_manifest.files) - # TODO: go back to skipping the internal manifest during macro parsing for project in projects.values(): parser = MacroParser(self.results, project) for path in parser.search(): @@ -416,10 +419,6 @@ def _check_manifest(manifest: Manifest, config: RuntimeConfig) -> None: _warn_for_unused_resource_config_paths(manifest, config) -def internal_project_names(): - return iter(PACKAGES.values()) - - def _load_projects(config, paths): for path in paths: try: @@ -626,7 +625,8 @@ def process_node( def load_internal_projects(config): - return dict(_load_projects(config, internal_project_names())) + project_names = get_adapter_package_names(config.credentials.type) + return dict(_load_projects(config, project_names)) def load_internal_manifest(config: RuntimeConfig) -> Manifest: diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index 69d2a2af903..c1051bb584b 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -4,7 +4,8 @@ def get_adapter_standalone(config): - cls = FACTORY.adapter_types[config.credentials.type] + plugin = FACTORY.plugins[config.credentials.type] + cls = plugin.adapter return cls(config) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 4a8ac7ae303..71c3c95d3b8 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -12,6 +12,7 @@ from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery import BigQueryRelation +from dbt.adapters.bigquery import Plugin as BigQueryPlugin from dbt.adapters.bigquery.relation import BigQueryInformationSchema from dbt.adapters.bigquery.connections import BigQueryConnectionManager from dbt.adapters.base.query_headers import MacroQueryStringSetter @@ -94,7 +95,7 @@ def get_adapter(self, target): self.mock_query_header_add = self.qh_patch.start() self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q) - inject_adapter(adapter) + inject_adapter(adapter, BigQueryPlugin) return adapter diff --git a/test/unit/test_context.py b/test/unit/test_context.py index d8d25c92d48..c8131fdf9d7 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -5,7 +5,7 @@ import pytest -# make sure 'postgres' is in PACKAGES +# make sure 'postgres' is available from dbt.adapters import postgres # noqa from dbt.adapters.base import AdapterConfig from dbt.clients.jinja import MacroStack @@ -349,6 +349,13 @@ def get_adapter(): yield patch +@pytest.fixture +def get_include_paths(): + with mock.patch.object(providers, 'get_include_paths') as patch: + patch.return_value = [] + yield patch + + @pytest.fixture def config(): return config_from_parts_or_dicts(PROJECT_DATA, PROFILE_DATA) @@ -367,7 +374,7 @@ def test_query_header_context(config, manifest): assert_has_keys(REQUIRED_QUERY_HEADER_KEYS, MAYBE_KEYS, ctx) -def test_macro_parse_context(config, manifest, get_adapter): +def test_macro_parse_context(config, manifest, get_adapter, get_include_paths): ctx = providers.generate_parser_macro( macro=manifest.macros['macro.root.macro_a'], config=config, @@ -377,7 +384,7 @@ def test_macro_parse_context(config, manifest, get_adapter): assert_has_keys(REQUIRED_MACRO_KEYS, MAYBE_KEYS, ctx) -def test_macro_runtime_context(config, manifest, get_adapter): +def test_macro_runtime_context(config, manifest, get_adapter, get_include_paths): ctx = providers.generate_runtime_macro( macro=manifest.macros['macro.root.macro_a'], config=config, @@ -387,7 +394,7 @@ def test_macro_runtime_context(config, manifest, get_adapter): assert_has_keys(REQUIRED_MACRO_KEYS, MAYBE_KEYS, ctx) -def test_model_parse_context(config, manifest, get_adapter): +def test_model_parse_context(config, manifest, get_adapter, get_include_paths): ctx = providers.generate_parser_model( model=mock_model(), config=config, @@ -397,7 +404,7 @@ def test_model_parse_context(config, manifest, get_adapter): assert_has_keys(REQUIRED_MODEL_KEYS, MAYBE_KEYS, ctx) -def test_model_runtime_context(config, manifest, get_adapter): +def test_model_runtime_context(config, manifest, get_adapter, get_include_paths): ctx = providers.generate_runtime_model( model=mock_model(), config=config, diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 1707ea753b0..7c1e9ca3cd8 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -2,6 +2,8 @@ import unittest from unittest.mock import MagicMock, patch +from dbt.adapters.postgres import Plugin as PostgresPlugin +from dbt.adapters.factory import reset_adapters import dbt.clients.system import dbt.compilation import dbt.exceptions @@ -24,7 +26,7 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa -from .utils import config_from_parts_or_dicts, generate_name_macros, MockMacro +from .utils import config_from_parts_or_dicts, generate_name_macros, MockMacro, inject_plugin class GraphTest(unittest.TestCase): @@ -39,7 +41,7 @@ def tearDown(self): self.mock_hook_constructor.stop() self.load_patch.stop() self.load_source_file_patcher.stop() - # self.relation_update_patcher.stop() + reset_adapters() def setUp(self): dbt.flags.STRICT_MODE = True @@ -104,8 +106,6 @@ def _mock_parse_result(config, all_projects): self.mock_source_file = self.load_source_file_patcher.start() self.mock_source_file.side_effect = lambda path: [n for n in self.mock_models if n.path == path][0] - # self.relation_update_patcher = patch.object(RelationUpdate, '_relation_components', lambda: []) - # self.mock_relation_update = self.relation_update_patcher.start() self.internal_manifest = Manifest.from_macros(macros={ n.unique_id: n for n in generate_name_macros('test_models_compile') }) @@ -131,6 +131,7 @@ def create_hook_patcher(cls, results, project, relative_dirs, extension): self.mock_filesystem_constructor.side_effect = create_filesystem_searcher self.mock_hook_constructor = self.hook_patcher.start() self.mock_hook_constructor.side_effect = create_hook_patcher + inject_plugin(PostgresPlugin) def get_config(self, extra_cfg=None): if extra_cfg is None: diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 3b1ef63567e..6c5b3b0eb12 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -8,6 +8,7 @@ from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.adapters.postgres import PostgresAdapter +from dbt.adapters.postgres import Plugin as PostgresPlugin from dbt.clients import agate_helper from dbt.exceptions import ValidationException, DbtConfigError from dbt.logger import GLOBAL_LOGGER as logger # noqa @@ -50,7 +51,7 @@ def setUp(self): def adapter(self): if self._adapter is None: self._adapter = PostgresAdapter(self.config) - inject_adapter(self._adapter) + inject_adapter(self._adapter, PostgresPlugin) return self._adapter @mock.patch('dbt.adapters.postgres.connections.psycopg2') @@ -283,7 +284,7 @@ def setUp(self): self.mock_query_header_add = self.qh_patch.start() self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q) self.adapter.acquire_connection() - inject_adapter(self.adapter) + inject_adapter(self.adapter, PostgresPlugin) self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result') self.mock_parse_result = self.load_patch.start() diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index f2801de8e3d..91fa30e8fd5 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -1,4 +1,3 @@ -import string import unittest from unittest import mock import agate @@ -6,12 +5,15 @@ import dbt.adapters # noqa import dbt.flags as flags -from dbt.adapters.redshift import RedshiftAdapter +from dbt.adapters.redshift import ( + RedshiftAdapter, + Plugin as RedshiftPlugin, +) from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa -from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions +from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter @classmethod @@ -60,6 +62,7 @@ def setUp(self): def adapter(self): if self._adapter is None: self._adapter = RedshiftAdapter(self.config) + inject_adapter(self._adapter, RedshiftPlugin) return self._adapter def test_implicit_database_conn(self): diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 7e41feeb589..9d0c60705f3 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -6,6 +6,7 @@ import dbt.flags as flags from dbt.adapters.snowflake import SnowflakeAdapter +from dbt.adapters.snowflake import Plugin as SnowflakePlugin from dbt.adapters.snowflake.column import SnowflakeColumn from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.clients import agate_helper @@ -72,7 +73,7 @@ def setUp(self): self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q) self.adapter.acquire_connection() - inject_adapter(self.adapter) + inject_adapter(self.adapter, SnowflakePlugin) def tearDown(self): # we want a unique self.handle every time. diff --git a/test/unit/utils.py b/test/unit/utils.py index d27793afd87..a3f198ab7c4 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -103,14 +103,20 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): ) -def inject_adapter(value): +def inject_plugin(plugin): + from dbt.adapters.factory import FACTORY + key = plugin.adapter.type() + FACTORY.plugins[key] = plugin + + +def inject_adapter(value, plugin): """Inject the given adapter into the adapter factory, so your hand-crafted artisanal adapter will be available from get_adapter() as if dbt loaded it. """ + inject_plugin(plugin) from dbt.adapters.factory import FACTORY key = value.type() FACTORY.adapters[key] = value - FACTORY.adapter_types[key] = type(value) class ContractTestCase(TestCase):