diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 428966d7f8e..96188220312 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -2,7 +2,13 @@ from dataclasses import dataclass, field from itertools import chain from typing import ( - List, Dict, Any, Optional, TypeVar, Union, Mapping, + List, + Dict, + Any, + Optional, + TypeVar, + Union, + Mapping, ) from typing_extensions import Protocol, runtime_checkable @@ -184,10 +190,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str): installed = get_installed_version() if not versions_compatible(*dbt_version): msg = IMPOSSIBLE_VERSION_ERROR.format( - package=project_name, - version_spec=[ - x.to_version_string() for x in dbt_version - ] + package=project_name, version_spec=[x.to_version_string() for x in dbt_version] ) raise DbtProjectError(msg) @@ -195,9 +198,7 @@ def validate_version(dbt_version: List[VersionSpecifier], project_name: str): msg = INVALID_VERSION_ERROR.format( package=project_name, installed=installed.to_version_string(), - version_spec=[ - x.to_version_string() for x in dbt_version - ] + version_spec=[x.to_version_string() for x in dbt_version], ) raise DbtProjectError(msg) @@ -206,8 +207,8 @@ def _get_required_version( project_dict: Dict[str, Any], verify_version: bool, ) -> List[VersionSpecifier]: - dbt_raw_version: Union[List[str], str] = '>=0.0.0' - required = project_dict.get('require-dbt-version') + dbt_raw_version: Union[List[str], str] = ">=0.0.0" + required = project_dict.get("require-dbt-version") if required is not None: dbt_raw_version = required @@ -218,26 +219,20 @@ def _get_required_version( if verify_version: # no name is also an error that we want to raise - if 'name' not in project_dict: + if "name" not in project_dict: raise DbtProjectError( 'Required "name" field not present in project', ) - validate_version(dbt_version, project_dict['name']) + validate_version(dbt_version, project_dict["name"]) return dbt_version @dataclass class RenderComponents: - project_dict: Dict[str, Any] = field( - metadata=dict(description='The project dictionary') - ) - packages_dict: Dict[str, Any] = field( - metadata=dict(description='The packages dictionary') - ) - selectors_dict: Dict[str, Any] = field( - metadata=dict(description='The selectors dictionary') - ) + project_dict: Dict[str, Any] = field(metadata=dict(description="The project dictionary")) + packages_dict: Dict[str, Any] = field(metadata=dict(description="The packages dictionary")) + selectors_dict: Dict[str, Any] = field(metadata=dict(description="The selectors dictionary")) @dataclass @@ -283,13 +278,13 @@ def get_rendered( ) # Called by 'collect_parts' in RuntimeConfig - def render(self, renderer: DbtProjectYamlRenderer) -> 'Project': + def render(self, renderer: DbtProjectYamlRenderer) -> "Project": try: rendered = self.get_rendered(renderer) return self.create_project(rendered) except DbtProjectError as exc: if exc.path is None: - exc.path = os.path.join(self.project_root, 'dbt_project.yml') + exc.path = os.path.join(self.project_root, "dbt_project.yml") raise def check_config_path(self, project_dict, deprecated_path, exp_path): @@ -307,7 +302,7 @@ def check_config_path(self, project_dict, deprecated_path, exp_path): deprecated_path=deprecated_path, exp_path=exp_path) - def create_project(self, rendered: RenderComponents) -> 'Project': + def create_project(self, rendered: RenderComponents) -> "Project": unrendered = RenderComponents( project_dict=self.project_dict, packages_dict=self.packages_dict, @@ -460,10 +455,9 @@ def from_dicts( *, verify_version: bool = False, ): - """Construct a partial project from its constituent dicts. - """ - project_name = project_dict.get('name') - profile_name = project_dict.get('profile') + """Construct a partial project from its constituent dicts.""" + project_name = project_dict.get("name") + profile_name = project_dict.get("profile") return cls( profile_name=profile_name, @@ -478,14 +472,14 @@ def from_dicts( @classmethod def from_project_root( cls, project_root: str, *, verify_version: bool = False - ) -> 'PartialProject': + ) -> "PartialProject": project_root = os.path.normpath(project_root) project_dict = _raw_project_from(project_root) - config_version = project_dict.get('config-version', 1) + config_version = project_dict.get("config-version", 1) if config_version != 2: raise DbtProjectError( - f'Invalid config version: {config_version}, expected 2', - path=os.path.join(project_root, 'dbt_project.yml') + f"Invalid config version: {config_version}, expected 2", + path=os.path.join(project_root, "dbt_project.yml"), ) packages_dict = package_data_from_root(project_root) @@ -502,15 +496,10 @@ def from_project_root( class VarProvider: """Var providers are tied to a particular Project.""" - def __init__( - self, - vars: Dict[str, Dict[str, Any]] - ) -> None: + def __init__(self, vars: Dict[str, Dict[str, Any]]) -> None: self.vars = vars - def vars_for( - self, node: IsFQNResource, adapter_type: str - ) -> Mapping[str, Any]: + def vars_for(self, node: IsFQNResource, adapter_type: str) -> Mapping[str, Any]: # in v2, vars are only either project or globally scoped merged = MultiDict([self.vars]) merged.add(self.vars.get(node.package_name, {})) @@ -563,7 +552,10 @@ class Project: def all_source_paths(self) -> List[str]: return _all_source_paths( self.model_paths, self.seed_paths, self.snapshot_paths, - self.analysis_paths, self.macro_paths + self.seed_paths, + self.snapshot_paths, + self.analysis_paths, + self.macro_paths, ) @property @@ -638,9 +630,7 @@ def validate(self): raise DbtProjectError(validator_error_message(e)) from e @classmethod - def partial_load( - cls, project_root: str, *, verify_version: bool = False - ) -> PartialProject: + def partial_load(cls, project_root: str, *, verify_version: bool = False) -> PartialProject: return PartialProject.from_project_root( project_root, verify_version=verify_version, diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 053375902a2..43840fadc1d 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -73,15 +73,15 @@ class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]): def __init__(self): super().__init__() - self[('on-run-start',)] = _list_if_none_or_string - self[('on-run-end',)] = _list_if_none_or_string + self[("on-run-start",)] = _list_if_none_or_string + self[("on-run-end",)] = _list_if_none_or_string - for k in ('models', 'seeds', 'snapshots'): + for k in ("models", "seeds", "snapshots"): self[(k,)] = _dict_if_none - self[(k, 'vars')] = _dict_if_none - self[(k, 'pre-hook')] = _list_if_none_or_string - self[(k, 'post-hook')] = _list_if_none_or_string - self[('seeds', 'column_types')] = _dict_if_none + self[(k, "vars")] = _dict_if_none + self[(k, "pre-hook")] = _list_if_none_or_string + self[(k, "post-hook")] = _list_if_none_or_string + self[("seeds", "column_types")] = _dict_if_none def postprocess(self, value: Any, key: Keypath) -> Any: if key in self: @@ -128,7 +128,7 @@ def render_project( ) -> Dict[str, Any]: """Render the project and insert the project root after rendering.""" rendered_project = self.render_data(project) - rendered_project['project-root'] = project_root + rendered_project["project-root"] = project_root return rendered_project def render_packages(self, packages: Dict[str, Any]):