Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass mp_context to adapter factory #9275

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231212-154842.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: pass mp_context to adapter factory as argument instead of import
time: 2023-12-12T15:48:42.866175-08:00
custom:
Author: colin-rogers-dbt
Issue: "9025"
17 changes: 11 additions & 6 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from multiprocessing.context import SpawnContext

import threading
import traceback
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type
from typing import Any, Dict, List, Optional, Set, Type, Union

from dbt.adapters.base.plugin import AdapterPlugin
from dbt.adapters.protocol import AdapterConfig, AdapterProtocol, RelationProtocol
Expand All @@ -14,7 +16,6 @@
from dbt.adapters.include.global_project import PACKAGE_PATH as GLOBAL_PROJECT_PATH
from dbt.adapters.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.common.semver import VersionSpecifier
from dbt.mp_context import get_mp_context

Adapter = AdapterProtocol

Expand Down Expand Up @@ -88,7 +89,9 @@ def load_plugin(self, name: str) -> Type[Credentials]:

return plugin.credentials

def register_adapter(self, config: AdapterRequiredConfig) -> None:
def register_adapter(
self, config: AdapterRequiredConfig, mp_context: Union[SpawnContext, SpawnContext]
) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version
Expand All @@ -103,7 +106,7 @@ def register_adapter(self, config: AdapterRequiredConfig) -> None:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config, get_mp_context()) # type: ignore
adapter: Adapter = adapter_type(config, mp_context) # type: ignore
self.adapters[adapter_name] = adapter

def lookup_adapter(self, adapter_name: str) -> Adapter:
Expand Down Expand Up @@ -173,8 +176,10 @@ def get_adapter_constraint_support(self, name: Optional[str]) -> List[str]:
FACTORY: AdapterContainer = AdapterContainer()


def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config)
def register_adapter(
config: AdapterRequiredConfig, mp_context: Union[SpawnContext, SpawnContext]
) -> None:
FACTORY.register_adapter(config, mp_context)


def get_adapter(config: AdapterRequiredConfig):
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dbt.tracking
from dbt.common.invocation import reset_invocation_id
from dbt.mp_context import get_mp_context
from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management, register_adapter, get_adapter
from dbt.flags import set_flags, get_flag_dict
Expand Down Expand Up @@ -274,7 +275,7 @@ def wrapper(*args, **kwargs):
raise DbtProjectError("profile, project, and runtime_config required for manifest")

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)

Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dbt.links import ProfileConfigDocs
from dbt.common.ui import green, red
from dbt.common.events.format import pluralize
from dbt.mp_context import get_mp_context
from dbt.version import get_installed_version

from dbt.task.base import BaseTask, get_nearest_project_dir
Expand Down Expand Up @@ -443,7 +444,7 @@ def test_configuration(self, profile_status_msg, project_status_msg):
@staticmethod
def attempt_connection(profile) -> Optional[str]:
"""Return a string containing the error message, or None if there was no error."""
register_adapter(profile)
register_adapter(profile, get_mp_context())
adapter = get_adapter(profile)
try:
with adapter.connection_named("debug"):
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import yaml

from dbt.mp_context import get_mp_context
from dbt.parser.manifest import ManifestLoader
from dbt.common.exceptions import CompilationError, DbtDatabaseError
from dbt.context.providers import generate_runtime_macro_context
Expand Down Expand Up @@ -287,7 +288,7 @@ def adapter(
)
flags.set_from_args(args, {})
runtime_config = RuntimeConfig.from_args(args)
register_adapter(runtime_config)
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
# We only need the base macros, not macros from dependencies, and don't want
# to run 'dbt deps' here.
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dbt.contracts.graph.manifest import MacroManifest, ManifestStateCheck
from dbt.graph import NodeSelector, parse_difference
from dbt.events.logging import setup_event_logger
from dbt.mp_context import get_mp_context

try:
from queue import Empty
Expand Down Expand Up @@ -153,7 +154,7 @@ def use_models(self, models):

def load_manifest(self, config):
inject_plugin(PostgresPlugin)
register_adapter(config)
register_adapter(config, get_mp_context())
loader = dbt.parser.manifest.ManifestLoader(config, {config.project_name: config})
loader.manifest.macros = self.macro_manifest.macros
loader.load()
Expand Down