Skip to content

Commit

Permalink
store package info on the factory instead of globally
Browse files Browse the repository at this point in the history
adapter_types -> plugins
Use factory.packages instead of global PACKAGES
  • Loading branch information
Jacob Beck committed Jun 25, 2020
1 parent 24ae8b4 commit 62a0bf8
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 51 deletions.
76 changes: 64 additions & 12 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import threading
from pathlib import Path
from importlib import import_module
from typing import Type, Dict, Any
from typing import Type, Dict, Any, List, Set, Optional

from dbt.exceptions import RuntimeException
from dbt.include.global_project import PACKAGES
from dbt.exceptions import RuntimeException, InternalException
from dbt.include.global_project import (
PACKAGE_PATH as GLOBAL_PROJECT_PATH,
PROJECT_NAME as GLOBAL_PROJECT_NAME,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.connection import Credentials, AdapterRequiredConfig

Expand All @@ -24,18 +28,25 @@ class AdpaterContainer:
def __init__(self):
self.lock = threading.Lock()
self.adapters: Dict[str, Adapter] = {}
self.adapter_types: Dict[str, Type[Adapter]] = {}
self.plugins: Dict[str, AdapterPlugin] = {}
# map package names to their include paths
self.packages: Dict[str, Path] = {
GLOBAL_PROJECT_NAME: Path(GLOBAL_PROJECT_PATH),
}

def get_adapter_class_by_name(self, name: str) -> Type[Adapter]:
def get_plugin_by_name(self, name: str) -> AdapterPlugin:
with self.lock:
if name in self.adapter_types:
return self.adapter_types[name]

names = ", ".join(self.adapter_types.keys())
if name in self.plugins:
return self.plugins[name]
names = ", ".join(self.plugins.keys())

message = f"Invalid adapter type {name}! Must be one of {names}"
raise RuntimeException(message)

def get_adapter_class_by_name(self, name: str) -> Type[Adapter]:
plugin = self.get_plugin_by_name(name)
return plugin.adapter

def get_relation_class_by_name(self, name: str) -> Type[BaseRelation]:
adapter = self.get_adapter_class_by_name(name)
return adapter.Relation
Expand All @@ -47,7 +58,7 @@ def get_config_class_by_name(
return adapter.AdapterSpecificConfigs

def load_plugin(self, name: str) -> Type[Credentials]:
# this doesn't need a lock: in the worst case we'll overwrite PACKAGES
# this doesn't need a lock: in the worst case we'll overwrite packages
# and adapter_type entries with the same value, as they're all
# singletons
try:
Expand All @@ -74,9 +85,9 @@ def load_plugin(self, name: str) -> Type[Credentials]:

with self.lock:
# things do hold the lock to iterate over it so we need it to add
self.adapter_types[name] = plugin.adapter
self.plugins[name] = plugin

PACKAGES[plugin.project_name] = plugin.include_path
self.packages[plugin.project_name] = Path(plugin.include_path)

for dep in plugin.dependencies:
self.load_plugin(dep)
Expand Down Expand Up @@ -114,6 +125,39 @@ def cleanup_connections(self):
for adapter in self.adapters.values():
adapter.cleanup_connections()

def get_adapter_package_names(self, name: Optional[str]) -> Set[str]:
if name is None:
return list(self.packages)
package_names: Set[str] = {GLOBAL_PROJECT_NAME}
plugin_names: Set[str] = {name}
while plugin_names:
plugin_name = plugin_names.pop()
try:
plugin = self.plugins[plugin_name]
except KeyError:
raise InternalException(
f'No plugin found for {plugin_name}'
) from None
package_names.add(plugin.adapter.type())
if plugin.dependencies is None:
continue
for dep in plugin.dependencies:
if dep not in package_names:
plugin_names.add(dep)
return package_names

def get_include_paths(self, name: Optional[str]) -> List[Path]:
paths = []
for package_name in self.get_adapter_package_names(name):
try:
path = self.packages[package_name]
except KeyError:
raise InternalException(
f'No internal package listing found for {package_name}'
)
paths.append(path)
return paths


FACTORY: AdpaterContainer = AdpaterContainer()

Expand Down Expand Up @@ -153,3 +197,11 @@ def get_relation_class_by_name(name: str) -> Type[BaseRelation]:

def load_plugin(name: str) -> Type[Credentials]:
return FACTORY.load_plugin(name)


def get_include_paths(name: Optional[str]) -> List[Path]:
return FACTORY.get_include_paths(name)


def get_adapter_package_names(name: Optional[str]) -> Set[str]:
return FACTORY.get_adapter_package_names(name)
6 changes: 3 additions & 3 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .project import Project
from .renderer import DbtProjectYamlRenderer, ProfileRenderer
from dbt import tracking
from dbt.adapters.factory import get_relation_class_by_name
from dbt.adapters.factory import get_relation_class_by_name, get_include_paths
from dbt.helper_types import FQNPath, PathSet
from dbt.context.base import generate_base_context
from dbt.context.target import generate_target_context
Expand All @@ -31,7 +31,6 @@
warn_or_error,
raise_compiler_error
)
from dbt.include.global_project import PACKAGES
from dbt.legacy_config_updater import ConfigUpdater

from hologram import ValidationError
Expand Down Expand Up @@ -323,8 +322,9 @@ def warn_for_unused_resource_config_paths(
def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
if self.dependencies is None:
all_projects = {self.project_name: self}
internal_packages = get_include_paths(self.credentials.type)
project_paths = itertools.chain(
map(Path, PACKAGES.values()),
internal_packages,
self._get_project_directories()
)
for project_name, project in self.load_projects(project_paths):
Expand Down
13 changes: 10 additions & 3 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any, Dict, Iterable, Union, Optional
from typing import Any, Dict, Iterable, Union, Optional, Set

from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.include.global_project import PACKAGES
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.node_types import NodeType
from dbt.utils import MultiDict
Expand Down Expand Up @@ -93,10 +92,12 @@ def __init__(
root_package: str,
search_package: str,
thread_ctx: MacroStack,
internal_packages: Set[str],
node: Optional[Any] = None,
) -> None:
self.root_package = root_package
self.search_package = search_package
self.internal_packages = internal_packages
self.globals: FlatNamespace = {}
self.locals: FlatNamespace = {}
self.packages: Dict[str, FlatNamespace] = {}
Expand All @@ -110,7 +111,7 @@ def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
)

# put plugin macros into the root namespace
if macro.package_name in PACKAGES:
if macro.package_name in self.internal_packages:
namespace: str = GLOBAL_PROJECT_NAME
else:
namespace = macro.package_name
Expand Down Expand Up @@ -161,10 +162,16 @@ def __init__(
self.macro_stack = MacroStack()

def _get_namespace(self):
# avoid an import loop
from dbt.adapters.factory import get_adapter_package_names
internal_packages = get_adapter_package_names(
self.config.credentials.type
)
return MacroNamespace(
self.config.project_name,
self.search_package,
self.macro_stack,
internal_packages,
None,
)

Expand Down
6 changes: 5 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter
from dbt.adapters.factory import get_adapter, get_adapter_package_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered
from dbt.config import RuntimeConfig, Project
Expand Down Expand Up @@ -573,10 +573,14 @@ def __init__(
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter)

def _get_namespace(self):
internal_packages = get_adapter_package_names(
self.config.credentials.type
)
return MacroNamespace(
self.config.project_name,
self.search_package,
self.macro_stack,
internal_packages,
self.model,
)

Expand Down
12 changes: 8 additions & 4 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
warn_or_error, raise_invalid_patch
)
from dbt.helper_types import PathSet
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt import deprecations
Expand Down Expand Up @@ -468,10 +467,12 @@ def last(self) -> Optional[ParsedMacro]:
return self[-1].macro


def _get_locality(macro: ParsedMacro, root_project_name: str) -> Locality:
def _get_locality(
macro: ParsedMacro, root_project_name: str, internal_packages: Set[str]
) -> Locality:
if macro.package_name == root_project_name:
return Locality.Root
elif macro.package_name in PACKAGES:
elif macro.package_name in internal_packages:
return Locality.Core
else:
return Locality.Imported
Expand Down Expand Up @@ -647,12 +648,15 @@ def _find_macros_by_name(
) -> CandidateList:
"""Find macros by their name.
"""
# avoid an import cycle
from dbt.adapters.factory import get_adapter_package_names
candidates: CandidateList = CandidateList()
packages = get_adapter_package_names(self.metadata.adapter_type)
for unique_id, macro in self.macros.items():
if macro.name != name:
continue
candidate = MacroCandidate(
locality=_get_locality(macro, root_project_name),
locality=_get_locality(macro, root_project_name, packages),
macro=macro,
)
if filter is None or filter(candidate):
Expand Down
18 changes: 9 additions & 9 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import dbt.flags

from dbt import deprecations
from dbt.adapters.factory import get_relation_class_by_name
from dbt.adapters.factory import (
get_relation_class_by_name,
get_adapter_package_names,
)
from dbt.helper_types import PathSet
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger, DbtProcessState
from dbt.node_types import NodeType
from dbt.clients.jinja import get_rendered
Expand Down Expand Up @@ -119,13 +121,14 @@ def _load_macros(
) -> None:
projects = self.all_projects
if internal_manifest is not None:
# skip internal packages
packages = get_adapter_package_names(self.root_project.credentials.type)
projects = {
k: v for k, v in self.all_projects.items() if k not in PACKAGES
k: v for k, v in self.all_projects.items() if k not in packages
}
self.results.macros.update(internal_manifest.macros)
self.results.files.update(internal_manifest.files)

# TODO: go back to skipping the internal manifest during macro parsing
for project in projects.values():
parser = MacroParser(self.results, project)
for path in parser.search():
Expand Down Expand Up @@ -416,10 +419,6 @@ def _check_manifest(manifest: Manifest, config: RuntimeConfig) -> None:
_warn_for_unused_resource_config_paths(manifest, config)


def internal_project_names():
return iter(PACKAGES.values())


def _load_projects(config, paths):
for path in paths:
try:
Expand Down Expand Up @@ -626,7 +625,8 @@ def process_node(


def load_internal_projects(config):
return dict(_load_projects(config, internal_project_names()))
project_names = get_adapter_package_names(config.credentials.type)
return dict(_load_projects(config, project_names))


def load_internal_manifest(config: RuntimeConfig) -> Manifest:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


def get_adapter_standalone(config):
cls = FACTORY.adapter_types[config.credentials.type]
plugin = FACTORY.plugins[config.credentials.type]
cls = plugin.adapter
return cls(config)


Expand Down
3 changes: 2 additions & 1 deletion test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery import BigQueryAdapter
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery import Plugin as BigQueryPlugin
from dbt.adapters.bigquery.relation import BigQueryInformationSchema
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
from dbt.adapters.base.query_headers import MacroQueryStringSetter
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_adapter(self, target):
self.mock_query_header_add = self.qh_patch.start()
self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)

inject_adapter(adapter)
inject_adapter(adapter, BigQueryPlugin)
return adapter


Expand Down
Loading

0 comments on commit 62a0bf8

Please sign in to comment.