Skip to content

Commit

Permalink
merging w feature branch
Browse files Browse the repository at this point in the history
  • Loading branch information
iknox-fa committed Dec 4, 2022
2 parents bcfe70f + e91863d commit d4bd1eb
Show file tree
Hide file tree
Showing 18 changed files with 250 additions and 1,197 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20221129-183239.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Click CLI Flags work with UserConfig
time: 2022-11-29T18:32:39.068035-05:00
custom:
Author: michelleark
Issue: "6327"
PR: "6266"
11 changes: 11 additions & 0 deletions core/dbt/cli/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# import the command you wish to run
from dbt.cli.main import run
from dbt.cli.flags import Flags

# use the command you wish to invoke
ctx = run.make_context(
"context_name_goes_here", ["--project-dir", "../dbt_projects/dbt_build_test_project"]
)

# add items to the context obj, these are required
ctx.obj = {}
42 changes: 33 additions & 9 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set

from click import Context, get_current_context
from click.core import ParameterSource

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig

if os.name != "nt":
# https://bugs.python.org/issue41567
Expand All @@ -15,12 +20,12 @@

@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None) -> None:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

if ctx is None:
ctx = get_current_context()

def assign_params(ctx):
def assign_params(ctx, params_assigned_from_default):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
Expand All @@ -29,29 +34,48 @@ def assign_params(ctx):
if hasattr(self, param_name):
raise Exception(f"Duplicate flag names found in click command: {param_name}")
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
if ctx.parent:
assign_params(ctx.parent)
assign_params(ctx.parent, params_assigned_from_default)

assign_params(ctx)
params_assigned_from_default = set() # type: Set[str]
assign_params(ctx, params_assigned_from_default)

# Get the invoked command flags
if hasattr(ctx, "invoked_subcommand") and ctx.invoked_subcommand is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), ctx.invoked_subcommand)
invoked_subcommand_name = (
ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None
)
if invoked_subcommand_name is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(invoked_subcommand_ctx)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)

if not user_config:
profiles_dir = getattr(self, "PROFILES_DIR", None)
user_config = read_user_config(profiles_dir) if profiles_dir else None

# Overwrite default assignments with user config if available
if user_config:
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
)

# Hard coded flags
object.__setattr__(self, "WHICH", ctx.info_name)
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Support console DO NOT TRACK initiave
object.__setattr__(
self,
"ANONYMOUS_USAGE_STATS",
False
if os.getenv("DO_NOT_TRACK", "").lower() in (1, "t", "true", "y", "yes")
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes")
else True,
)

Expand Down
10 changes: 9 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dbt.profiler import profiler
from dbt.task.run import RunTask
from dbt.tracking import initialize_from_flags, track_run
from dbt.config.runtime import load_project
from dbt.task.deps import DepsTask


def cli_runner():
Expand Down Expand Up @@ -240,7 +242,13 @@ def debug(ctx, **kwargs):
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
flags = Flags()
click.echo(f"`{inspect.stack()[0][3]}` called\n flags: {flags}")
project = ctx.obj["project"]

task = DepsTask.from_project(project, flags.VARS)

results = task.run()
success = task.interpret_results(results)
return results, success


# dbt init
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@
envvar=None,
help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'",
type=YAML(),
default="{}",
)

version = click.option(
Expand Down
11 changes: 9 additions & 2 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
Project as ProjectContract,
SemverString,
)
from dbt.contracts.project import PackageConfig
from dbt.contracts.project import PackageConfig, ProjectPackageMetadata
from dbt.dataclass_schema import ValidationError
from .renderer import DbtProjectYamlRenderer
from .renderer import DbtProjectYamlRenderer, PackageRenderer
from .selectors import (
selector_config_from_data,
selector_data_from_root,
Expand Down Expand Up @@ -289,6 +289,13 @@ def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
exc.path = os.path.join(self.project_root, "dbt_project.yml")
raise

def render_package_metadata(self, renderer: PackageRenderer) -> ProjectPackageMetadata:
packages_data = renderer.render_data(self.packages_dict)
packages_config = package_config_from_data(packages_data)
if not self.project_name:
raise DbtProjectError(DbtProjectError("Package dbt_project.yml must have a name!"))
return ProjectPackageMetadata(self.project_name, packages_config.packages)

def check_config_path(self, project_dict, deprecated_path, exp_path):
if deprecated_path in project_dict:
if exp_path in project_dict:
Expand Down
11 changes: 7 additions & 4 deletions core/dbt/deps/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List, Optional

from dbt.clients import git, system
from dbt.config import Project
from dbt.config.project import PartialProject, Project
from dbt.config.renderer import PackageRenderer
from dbt.contracts.project import (
ProjectPackageMetadata,
GitPackage,
Expand Down Expand Up @@ -89,7 +90,9 @@ def _checkout(self):
raise
return os.path.join(get_downloads_path(), dir_)

def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
def _fetch_metadata(
self, project: Project, renderer: PackageRenderer
) -> ProjectPackageMetadata:
path = self._checkout()

if self.unpinned_msg() and self.warn_unpinned:
Expand All @@ -100,8 +103,8 @@ def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata:
),
log_fmt=ui.yellow("WARNING: {}"),
)
loaded = Project.from_project_root(path, renderer)
return ProjectPackageMetadata.from_project(loaded)
partial = PartialProject.from_project_root(path)
return partial.render_package_metadata(renderer)

def install(self, project, renderer):
dest_path = self.get_installation_path(project, renderer)
Expand Down
10 changes: 7 additions & 3 deletions core/dbt/deps/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
)
from dbt.events.functions import fire_event
from dbt.events.types import DepsCreatingLocalSymlink, DepsSymlinkNotAvailable
from dbt.config.project import PartialProject, Project
from dbt.config.renderer import PackageRenderer


class LocalPackageMixin:
Expand Down Expand Up @@ -39,9 +41,11 @@ def resolve_path(self, project):
project.project_root,
)

def _fetch_metadata(self, project, renderer):
loaded = project.from_project_root(self.resolve_path(project), renderer)
return ProjectPackageMetadata.from_project(loaded)
def _fetch_metadata(
self, project: Project, renderer: PackageRenderer
) -> ProjectPackageMetadata:
partial = PartialProject.from_project_root(self.resolve_path(project))
return partial.render_package_metadata(renderer)

def install(self, project, renderer):
src_path = self.resolve_path(project)
Expand Down
25 changes: 14 additions & 11 deletions core/dbt/deps/resolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from typing import Dict, List, NoReturn, Union, Type, Iterator, Set
from typing import Dict, List, NoReturn, Union, Type, Iterator, Set, Any

from dbt.exceptions import raise_dependency_error, InternalException

from dbt.config import Project, RuntimeConfig
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.config import Project
from dbt.config.renderer import PackageRenderer
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps.local import LocalUnpinnedPackage
from dbt.deps.git import GitUnpinnedPackage
Expand Down Expand Up @@ -94,19 +94,19 @@ def __iter__(self) -> Iterator[UnpinnedPackage]:

def _check_for_duplicate_project_names(
final_deps: List[PinnedPackage],
config: Project,
renderer: DbtProjectYamlRenderer,
project: Project,
renderer: PackageRenderer,
):
seen: Set[str] = set()
for package in final_deps:
project_name = package.get_project_name(config, renderer)
project_name = package.get_project_name(project, renderer)
if project_name in seen:
raise_dependency_error(
f'Found duplicate project "{project_name}". This occurs when '
"a dependency has the same project name as some other "
"dependency."
)
elif project_name == config.project_name:
elif project_name == project.project_name:
raise_dependency_error(
"Found a dependency with the same name as the root project "
f'"{project_name}". Package names must be unique in a project.'
Expand All @@ -116,21 +116,24 @@ def _check_for_duplicate_project_names(


def resolve_packages(
packages: List[PackageContract], config: RuntimeConfig
packages: List[PackageContract],
project: Project,
cli_vars: Dict[str, Any],
) -> List[PinnedPackage]:
pending = PackageListing.from_contracts(packages)
final = PackageListing()
renderer = DbtProjectYamlRenderer(config, config.cli_vars)

renderer = PackageRenderer(cli_vars)

while pending:
next_pending = PackageListing()
# resolve the dependency in question
for package in pending:
final.incorporate(package)
target = final[package].resolved().fetch_metadata(config, renderer)
target = final[package].resolved().fetch_metadata(project, renderer)
next_pending.update_from(target.packages)
pending = next_pending

resolved = final.resolved()
_check_for_duplicate_project_names(resolved, config, renderer)
_check_for_duplicate_project_names(resolved, project, renderer)
return resolved
Loading

0 comments on commit d4bd1eb

Please sign in to comment.