Skip to content

Commit

Permalink
Add unrendered configs to project
Browse files Browse the repository at this point in the history
automatic commit by git-black, original commits:
  5e71a2a
  • Loading branch information
Jacob Beck authored and iknox-fa committed Feb 8, 2022
1 parent 7a4c2e5 commit 75f01fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 51 deletions.
76 changes: 33 additions & 43 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from dataclasses import dataclass, field
from itertools import chain
from typing import (
List, Dict, Any, Optional, TypeVar, Union, Mapping,
List,
Dict,
Any,
Optional,
TypeVar,
Union,
Mapping,
)
from typing_extensions import Protocol, runtime_checkable

Expand Down Expand Up @@ -184,20 +190,15 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str):
installed = get_installed_version()
if not versions_compatible(*dbt_version):
msg = IMPOSSIBLE_VERSION_ERROR.format(
package=project_name,
version_spec=[
x.to_version_string() for x in dbt_version
]
package=project_name, version_spec=[x.to_version_string() for x in dbt_version]
)
raise DbtProjectError(msg)

if not versions_compatible(installed, *dbt_version):
msg = INVALID_VERSION_ERROR.format(
package=project_name,
installed=installed.to_version_string(),
version_spec=[
x.to_version_string() for x in dbt_version
]
version_spec=[x.to_version_string() for x in dbt_version],
)
raise DbtProjectError(msg)

Expand All @@ -206,8 +207,8 @@ def _get_required_version(
project_dict: Dict[str, Any],
verify_version: bool,
) -> List[VersionSpecifier]:
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
required = project_dict.get('require-dbt-version')
dbt_raw_version: Union[List[str], str] = ">=0.0.0"
required = project_dict.get("require-dbt-version")
if required is not None:
dbt_raw_version = required

Expand All @@ -218,26 +219,20 @@ def _get_required_version(

if verify_version:
# no name is also an error that we want to raise
if 'name' not in project_dict:
if "name" not in project_dict:
raise DbtProjectError(
'Required "name" field not present in project',
)
validate_version(dbt_version, project_dict['name'])
validate_version(dbt_version, project_dict["name"])

return dbt_version


@dataclass
class RenderComponents:
project_dict: Dict[str, Any] = field(
metadata=dict(description='The project dictionary')
)
packages_dict: Dict[str, Any] = field(
metadata=dict(description='The packages dictionary')
)
selectors_dict: Dict[str, Any] = field(
metadata=dict(description='The selectors dictionary')
)
project_dict: Dict[str, Any] = field(metadata=dict(description="The project dictionary"))
packages_dict: Dict[str, Any] = field(metadata=dict(description="The packages dictionary"))
selectors_dict: Dict[str, Any] = field(metadata=dict(description="The selectors dictionary"))


@dataclass
Expand Down Expand Up @@ -283,13 +278,13 @@ def get_rendered(
)

# Called by 'collect_parts' in RuntimeConfig
def render(self, renderer: DbtProjectYamlRenderer) -> 'Project':
def render(self, renderer: DbtProjectYamlRenderer) -> "Project":
try:
rendered = self.get_rendered(renderer)
return self.create_project(rendered)
except DbtProjectError as exc:
if exc.path is None:
exc.path = os.path.join(self.project_root, 'dbt_project.yml')
exc.path = os.path.join(self.project_root, "dbt_project.yml")
raise

def check_config_path(self, project_dict, deprecated_path, exp_path):
Expand All @@ -307,7 +302,7 @@ def check_config_path(self, project_dict, deprecated_path, exp_path):
deprecated_path=deprecated_path,
exp_path=exp_path)

def create_project(self, rendered: RenderComponents) -> 'Project':
def create_project(self, rendered: RenderComponents) -> "Project":
unrendered = RenderComponents(
project_dict=self.project_dict,
packages_dict=self.packages_dict,
Expand Down Expand Up @@ -460,10 +455,9 @@ def from_dicts(
*,
verify_version: bool = False,
):
"""Construct a partial project from its constituent dicts.
"""
project_name = project_dict.get('name')
profile_name = project_dict.get('profile')
"""Construct a partial project from its constituent dicts."""
project_name = project_dict.get("name")
profile_name = project_dict.get("profile")

return cls(
profile_name=profile_name,
Expand All @@ -478,14 +472,14 @@ def from_dicts(
@classmethod
def from_project_root(
cls, project_root: str, *, verify_version: bool = False
) -> 'PartialProject':
) -> "PartialProject":
project_root = os.path.normpath(project_root)
project_dict = _raw_project_from(project_root)
config_version = project_dict.get('config-version', 1)
config_version = project_dict.get("config-version", 1)
if config_version != 2:
raise DbtProjectError(
f'Invalid config version: {config_version}, expected 2',
path=os.path.join(project_root, 'dbt_project.yml')
f"Invalid config version: {config_version}, expected 2",
path=os.path.join(project_root, "dbt_project.yml"),
)

packages_dict = package_data_from_root(project_root)
Expand All @@ -502,15 +496,10 @@ def from_project_root(
class VarProvider:
"""Var providers are tied to a particular Project."""

def __init__(
self,
vars: Dict[str, Dict[str, Any]]
) -> None:
def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None:
self.vars = vars

def vars_for(
self, node: IsFQNResource, adapter_type: str
) -> Mapping[str, Any]:
def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]:
# in v2, vars are only either project or globally scoped
merged = MultiDict([self.vars])
merged.add(self.vars.get(node.package_name, {}))
Expand Down Expand Up @@ -563,7 +552,10 @@ class Project:
def all_source_paths(self) -> List[str]:
return _all_source_paths(
self.model_paths, self.seed_paths, self.snapshot_paths,
self.analysis_paths, self.macro_paths
self.seed_paths,
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
)

@property
Expand Down Expand Up @@ -638,9 +630,7 @@ def validate(self):
raise DbtProjectError(validator_error_message(e)) from e

@classmethod
def partial_load(
cls, project_root: str, *, verify_version: bool = False
) -> PartialProject:
def partial_load(cls, project_root: str, *, verify_version: bool = False) -> PartialProject:
return PartialProject.from_project_root(
project_root,
verify_version=verify_version,
Expand Down
16 changes: 8 additions & 8 deletions core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
def __init__(self):
super().__init__()

self[('on-run-start',)] = _list_if_none_or_string
self[('on-run-end',)] = _list_if_none_or_string
self[("on-run-start",)] = _list_if_none_or_string
self[("on-run-end",)] = _list_if_none_or_string

for k in ('models', 'seeds', 'snapshots'):
for k in ("models", "seeds", "snapshots"):
self[(k,)] = _dict_if_none
self[(k, 'vars')] = _dict_if_none
self[(k, 'pre-hook')] = _list_if_none_or_string
self[(k, 'post-hook')] = _list_if_none_or_string
self[('seeds', 'column_types')] = _dict_if_none
self[(k, "vars")] = _dict_if_none
self[(k, "pre-hook")] = _list_if_none_or_string
self[(k, "post-hook")] = _list_if_none_or_string
self[("seeds", "column_types")] = _dict_if_none

def postprocess(self, value: Any, key: Keypath) -> Any:
if key in self:
Expand Down Expand Up @@ -128,7 +128,7 @@ def render_project(
) -> Dict[str, Any]:
"""Render the project and insert the project root after rendering."""
rendered_project = self.render_data(project)
rendered_project['project-root'] = project_root
rendered_project["project-root"] = project_root
return rendered_project

def render_packages(self, packages: Dict[str, Any]):
Expand Down

0 comments on commit 75f01fe

Please sign in to comment.