diff --git a/CHANGELOG.md b/CHANGELOG.md index ae5745c88c3..26164461b38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Contributors: ### Under the hood - Bump werkzeug upper bound dependency to ` str: return cls.ConnectionManager.TYPE @property - def _macro_manifest(self) -> Manifest: + def _macro_manifest(self) -> MacroManifest: if self._macro_manifest_lazy is None: return self.load_macro_manifest() return self._macro_manifest_lazy - def check_macro_manifest(self) -> Optional[Manifest]: + def check_macro_manifest(self) -> Optional[MacroManifest]: """Return the internal manifest (used for executing macros) if it's been initialized, otherwise return None. """ return self._macro_manifest_lazy - def load_macro_manifest(self) -> Manifest: + def load_macro_manifest(self) -> MacroManifest: if self._macro_manifest_lazy is None: # avoid a circular import from dbt.parser.manifest import load_macro_manifest diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index ba85a0ee38b..d8905615f0a 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -21,8 +21,8 @@ @dataclass(frozen=True, eq=False, repr=False) class BaseRelation(FakeAPIObject, Hashable): - type: Optional[RelationType] path: Path + type: Optional[RelationType] = None quote_character: str = '"' include_policy: Policy = Policy() quote_policy: Policy = Policy() diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index a9a038355cb..f6f7993436a 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -231,6 +231,7 @@ def get_macro(self): template = self.get_template() # make the module. previously we set both vars and local, but that's # redundant: They both end up in the same place + # make_module is in jinja2.environment. It returns a TemplateModule module = template.make_module(vars=self.context, shared=False) macro = module.__dict__[get_dbt_macro_name(name)] module.__dict__.update(self.context) @@ -244,6 +245,7 @@ def exception_handler(self) -> Iterator[None]: raise_compiler_error(str(e)) def call_macro(self, *args, **kwargs): + # called from __call__ methods if self.context is None: raise InternalException( 'Context is still None in call_macro!' 
@@ -306,8 +308,10 @@ def exception_handler(self) -> Iterator[None]: e.stack.append(self.macro) raise e + # This adds the macro's unique id to the node's 'depends_on' @contextmanager def track_call(self): + # This is only called from __call__ if self.stack is None or self.node is None: yield else: @@ -322,6 +326,7 @@ def track_call(self): finally: self.stack.pop(unique_id) + # this makes MacroGenerator objects callable like functions def __call__(self, *args, **kwargs): with self.track_call(): return self.call_macro(*args, **kwargs) diff --git a/core/dbt/clients/system.py b/core/dbt/clients/system.py index f93a5d939d9..28f8101badb 100644 --- a/core/dbt/clients/system.py +++ b/core/dbt/clients/system.py @@ -438,7 +438,9 @@ def run_cmd( return out, err -def download(url: str, path: str, timeout: Union[float, tuple] = None) -> None: +def download( + url: str, path: str, timeout: Optional[Union[float, tuple]] = None +) -> None: path = convert_path(path) connection_timeout = timeout or float(os.getenv('DBT_HTTP_TIMEOUT', 10)) response = requests.get(url, timeout=connection_timeout) diff --git a/core/dbt/clients/yaml_helper.py b/core/dbt/clients/yaml_helper.py index 904d8aa6043..d97ec85b038 100644 --- a/core/dbt/clients/yaml_helper.py +++ b/core/dbt/clients/yaml_helper.py @@ -1,16 +1,19 @@ -from typing import Any - import dbt.exceptions import yaml import yaml.scanner # the C version is faster, but it doesn't always exist -YamlLoader: Any try: - from yaml import CSafeLoader as YamlLoader + from yaml import ( + CLoader as Loader, + CSafeLoader as SafeLoader, + CDumper as Dumper + ) except ImportError: - from yaml import SafeLoader as YamlLoader + from yaml import ( # type: ignore # noqa: F401 + Loader, SafeLoader, Dumper + ) YAML_ERROR_MESSAGE = """ @@ -54,7 +57,7 @@ def contextualized_yaml_error(raw_contents, error): def safe_load(contents): - return yaml.load(contents, Loader=YamlLoader) + return yaml.load(contents, Loader=SafeLoader) def load_yaml_text(contents): diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 49c576ba218..256198929ff 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, Tuple import os -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError from dbt.clients.system import load_file_contents from dbt.clients.yaml_helper import load_yaml_text @@ -75,6 +75,7 @@ def read_user_config(directory: str) -> UserConfig: if profile: user_cfg = coerce_dict_str(profile.get('config', {})) if user_cfg is not None: + UserConfig.validate(user_cfg) return UserConfig.from_dict(user_cfg) except (RuntimeException, ValidationError): pass @@ -137,10 +138,10 @@ def __eq__(self, other: object) -> bool: def validate(self): try: if self.credentials: - self.credentials.to_dict(validate=True) - ProfileConfig.from_dict( - self.to_profile_info(serialize_credentials=True) - ) + dct = self.credentials.to_dict() + self.credentials.validate(dct) + dct = self.to_profile_info(serialize_credentials=True) + ProfileConfig.validate(dct) except ValidationError as exc: raise DbtProfileError(validator_error_message(exc)) from exc @@ -160,7 +161,9 @@ def _credentials_from_profile( typename = profile.pop('type') try: cls = load_plugin(typename) - credentials = cls.from_dict(profile) + data = cls.translate_aliases(profile) + cls.validate(data) + credentials = cls.from_dict(data) except (RuntimeException, ValidationError) as e: msg = str(e) if isinstance(e, RuntimeException) else e.message 
raise DbtProfileError( @@ -233,6 +236,7 @@ def from_credentials( """ if user_cfg is None: user_cfg = {} + UserConfig.validate(user_cfg) config = UserConfig.from_dict(user_cfg) profile = cls( diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 1de8756a127..fe9ca743ce5 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -26,15 +26,12 @@ from dbt.utils import MultiDict from dbt.node_types import NodeType from dbt.config.selectors import SelectorDict - from dbt.contracts.project import ( Project as ProjectContract, SemverString, ) from dbt.contracts.project import PackageConfig - -from hologram import ValidationError - +from dbt.dataclass_schema import ValidationError from .renderer import DbtProjectYamlRenderer from .selectors import ( selector_config_from_data, @@ -101,6 +98,7 @@ def package_config_from_data(packages_data: Dict[str, Any]): packages_data = {'packages': []} try: + PackageConfig.validate(packages_data) packages = PackageConfig.from_dict(packages_data) except ValidationError as e: raise DbtProjectError( @@ -306,7 +304,10 @@ def create_project(self, rendered: RenderComponents) -> 'Project': ) try: - cfg = ProjectContract.from_dict(rendered.project_dict) + ProjectContract.validate(rendered.project_dict) + cfg = ProjectContract.from_dict( + rendered.project_dict + ) except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e # name/version are required in the Project definition, so we can assume @@ -586,7 +587,7 @@ def to_project_config(self, with_packages=False): def validate(self): try: - ProjectContract.from_dict(self.to_project_config()) + ProjectContract.validate(self.to_project_config()) except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index f433ac411fd..ea3c028e506 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -33,7 +33,7 @@ raise_compiler_error ) -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError def _project_quoting_dict( @@ -174,7 +174,7 @@ def validate(self): :raises DbtProjectError: If the configuration fails validation. 
""" try: - Configuration.from_dict(self.serialize()) + Configuration.validate(self.serialize()) except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e @@ -391,7 +391,7 @@ def __getattribute__(self, name): f"'UnsetConfig' object has no attribute {name}" ) - def to_dict(self): + def __post_serialize__(self, dct, options=None): return {} diff --git a/core/dbt/config/selectors.py b/core/dbt/config/selectors.py index 7b888732196..272a62edba4 100644 --- a/core/dbt/config/selectors.py +++ b/core/dbt/config/selectors.py @@ -1,8 +1,9 @@ from pathlib import Path from typing import Dict, Any -import yaml - -from hologram import ValidationError +from dbt.clients.yaml_helper import ( # noqa: F401 + yaml, Loader, Dumper, load_yaml_text +) +from dbt.dataclass_schema import ValidationError from .renderer import SelectorRenderer @@ -11,7 +12,6 @@ path_exists, resolve_path_from_base, ) -from dbt.clients.yaml_helper import load_yaml_text from dbt.contracts.selection import SelectorFile from dbt.exceptions import DbtSelectorsError, RuntimeException from dbt.graph import parse_from_selectors_definition, SelectionSpec @@ -30,9 +30,11 @@ class SelectorConfig(Dict[str, SelectionSpec]): + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig': + def selectors_from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig': try: + SelectorFile.validate(data) selector_file = SelectorFile.from_dict(data) selectors = parse_from_selectors_definition(selector_file) except ValidationError as exc: @@ -66,7 +68,7 @@ def render_from_dict( f'Could not render selector data: {exc}', result_type='invalid_selector', ) from exc - return cls.from_dict(rendered) + return cls.selectors_from_dict(rendered) @classmethod def from_path( @@ -107,7 +109,7 @@ def selector_config_from_data( selectors_data = {'selectors': []} try: - selectors = SelectorConfig.from_dict(selectors_data) + selectors = SelectorConfig.selectors_from_dict(selectors_data) except ValidationError as e: raise DbtSelectorsError( MALFORMED_SELECTOR_ERROR.format(error=str(e.message)), diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index f04ea33cb12..ca23ecdaad1 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -7,13 +7,14 @@ from dbt import flags from dbt import tracking from dbt.clients.jinja import undefined_error, get_rendered -from dbt.clients import yaml_helper +from dbt.clients.yaml_helper import ( # noqa: F401 + yaml, safe_load, SafeLoader, Loader, Dumper +) from dbt.contracts.graph.compiled import CompiledResource from dbt.exceptions import raise_compiler_error, MacroReturn from dbt.logger import GLOBAL_LOGGER as logger from dbt.version import __version__ as dbt_version -import yaml # These modules are added to the context. 
Consider alternative # approaches which will extend well to potentially many modules import pytz @@ -172,6 +173,7 @@ def generate_builtins(self): builtins[key] = value return builtins + # no dbtClassMixin so this is not an actual override def to_dict(self): self._ctx['context'] = self._ctx builtins = self.generate_builtins() @@ -394,7 +396,7 @@ def fromyaml(value: str, default: Any = None) -> Any: -- ["good"] """ try: - return yaml_helper.safe_load(value) + return safe_load(value) except (AttributeError, ValueError, yaml.YAMLError): return default diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py index 62d0ffe1440..80467ea70d3 100644 --- a/core/dbt/context/context_config.py +++ b/core/dbt/context/context_config.py @@ -165,7 +165,7 @@ def initial_result(self, resource_type: NodeType, base: bool) -> C: # Calculate the defaults. We don't want to validate the defaults, # because it might be invalid in the case of required config members # (such as on snapshots!) - result = config_cls.from_dict({}, validate=False) + result = config_cls.from_dict({}) return result def _update_from_config( diff --git a/core/dbt/context/macro_resolver.py b/core/dbt/context/macro_resolver.py new file mode 100644 index 00000000000..aae83185337 --- /dev/null +++ b/core/dbt/context/macro_resolver.py @@ -0,0 +1,153 @@ +from typing import ( + Dict, MutableMapping, Optional +) +from dbt.contracts.graph.parsed import ParsedMacro +from dbt.exceptions import raise_duplicate_macro_name, raise_compiler_error +from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME +from dbt.clients.jinja import MacroGenerator + +MacroNamespace = Dict[str, ParsedMacro] + + +# This class builds the MacroResolver by adding macros +# to various categories for finding macros in the right order, +# so that higher precedence macros are found first. +# This functionality is also provided by the MacroNamespace, +# but the intention is to eventually replace that class. +# This enables us to get the macro unique_id without +# processing every macro in the project. +class MacroResolver: + def __init__( + self, + macros: MutableMapping[str, ParsedMacro], + root_project_name: str, + internal_package_names, + ) -> None: + self.root_project_name = root_project_name + self.macros = macros + # internal package names come from get_adapter_package_names + self.internal_package_names = internal_package_names + + # To be filled in from macros. + self.internal_packages: Dict[str, MacroNamespace] = {} + self.packages: Dict[str, MacroNamespace] = {} + self.root_package_macros: MacroNamespace = {} + + # add the macros to internal_packages, packages, and root packages + self.add_macros() + self._build_internal_packages_namespace() + self._build_macros_by_name() + + def _build_internal_packages_namespace(self): + # Iterate in reverse-order and overwrite: the packages that are first + # in the list are the ones we want to "win".
+ self.internal_packages_namespace: MacroNamespace = {} + for pkg in reversed(self.internal_package_names): + if pkg in self.internal_packages: + # Turn the internal packages into a flat namespace + self.internal_packages_namespace.update( + self.internal_packages[pkg]) + + def _build_macros_by_name(self): + macros_by_name = {} + # search root package macros + for macro in self.root_package_macros.values(): + macros_by_name[macro.name] = macro + # search miscellaneous non-internal packages + for fnamespace in self.packages.values(): + for macro in fnamespace.values(): + macros_by_name[macro.name] = macro + # search all internal packages + for macro in self.internal_packages_namespace.values(): + macros_by_name[macro.name] = macro + self.macros_by_name = macros_by_name + + def _add_macro_to( + self, + package_namespaces: Dict[str, MacroNamespace], + macro: ParsedMacro, + ): + if macro.package_name in package_namespaces: + namespace = package_namespaces[macro.package_name] + else: + namespace = {} + package_namespaces[macro.package_name] = namespace + + if macro.name in namespace: + raise_duplicate_macro_name( + macro, macro, macro.package_name + ) + package_namespaces[macro.package_name][macro.name] = macro + + def add_macro(self, macro: ParsedMacro): + macro_name: str = macro.name + + # internal macros (from plugins) will be processed separately from + # project macros, so store them in a different place + if macro.package_name in self.internal_package_names: + self._add_macro_to(self.internal_packages, macro) + else: + # if it's not an internal package + self._add_macro_to(self.packages, macro) + # add to root_package_macros if it's in the root package + if macro.package_name == self.root_project_name: + self.root_package_macros[macro_name] = macro + + def add_macros(self): + for macro in self.macros.values(): + self.add_macro(macro) + + def get_macro_id(self, local_package, macro_name): + local_package_macros = {} + if (local_package not in self.internal_package_names and + local_package in self.packages): + local_package_macros = self.packages[local_package] + # First: search the local packages for this macro + if macro_name in local_package_macros: + return local_package_macros[macro_name].unique_id + if macro_name in self.macros_by_name: + return self.macros_by_name[macro_name].unique_id + return None + + +# Currently this is just used by test processing in the schema +# parser (in connection with the MacroResolver). Future work +# will extend the use of these classes to other parsing areas. +# One of the features of this class compared to the MacroNamespace +# is that you can limit the number of macros provided to the +# context dictionary in the 'to_dict' manifest method. 
+class TestMacroNamespace: + def __init__( + self, macro_resolver, ctx, node, thread_ctx, depends_on_macros + ): + self.macro_resolver = macro_resolver + self.ctx = ctx + self.node = node + self.thread_ctx = thread_ctx + local_namespace = {} + if depends_on_macros: + for macro_unique_id in depends_on_macros: + macro = self.macro_resolver.macros[macro_unique_id] + local_namespace[macro.name] = MacroGenerator( + macro, self.ctx, self.node, self.thread_ctx, + ) + self.local_namespace = local_namespace + + def get_from_package( + self, package_name: Optional[str], name: str + ) -> Optional[MacroGenerator]: + macro = None + if package_name is None: + macro = self.macro_resolver.macros_by_name.get(name) + elif package_name == GLOBAL_PROJECT_NAME: + macro = self.macro_resolver.internal_packages_namespace.get(name) + elif package_name in self.macro_resolver.packages: + macro = self.macro_resolver.packages[package_name].get(name) + else: + raise_compiler_error( + f"Could not find package '{package_name}'" + ) + macro_func = MacroGenerator( + macro, self.ctx, self.node, self.thread_ctx + ) + return macro_func diff --git a/core/dbt/context/macros.py b/core/dbt/context/macros.py index dd78d7f7727..6332fb967ca 100644 --- a/core/dbt/context/macros.py +++ b/core/dbt/context/macros.py @@ -15,6 +15,10 @@ FullNamespace = Dict[str, NamespaceMember] +# The point of this class is to collect the various macros +# and provide the ability to flatten them into the ManifestContexts +# that are created for jinja, so that macro calls can be resolved. +# Creates special iterators and _keys methods to flatten the lists. class MacroNamespace(Mapping): def __init__( self, @@ -37,12 +41,16 @@ def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]: } yield self.global_project_namespace + # provides special keys method for MacroNamespace iterator + # returns keys from local_namespace, global_namespace, packages, + # global_project_namespace def _keys(self) -> Set[str]: keys: Set[str] = set() for search in self._search_order(): keys.update(search) return keys + # special iterator using special keys def __iter__(self) -> Iterator[str]: for key in self._keys(): yield key @@ -72,6 +80,10 @@ def get_from_package( ) +# This class builds the MacroNamespace by adding macros to +# internal_packages or packages, and locals/globals. +# Call 'build_namespace' to return a MacroNamespace.
+# This is used by ManifestContext (and subclasses) class MacroNamespaceBuilder: def __init__( self, @@ -83,10 +95,15 @@ def __init__( ) -> None: self.root_package = root_package self.search_package = search_package + # internal packages comes from get_adapter_package_names self.internal_package_names = set(internal_packages) self.internal_package_names_order = internal_packages + # macro_func is added here if in root package self.globals: FlatNamespace = {} + # macro_func is added here if it's the package for this node self.locals: FlatNamespace = {} + # Create a dictionary of [package name][macro name] = + # MacroGenerator object which acts like a function self.internal_packages: Dict[str, FlatNamespace] = {} self.packages: Dict[str, FlatNamespace] = {} self.thread_ctx = thread_ctx @@ -94,25 +111,28 @@ def __init__( def _add_macro_to( self, - heirarchy: Dict[str, FlatNamespace], + hierarchy: Dict[str, FlatNamespace], macro: ParsedMacro, macro_func: MacroGenerator, ): - if macro.package_name in heirarchy: - namespace = heirarchy[macro.package_name] + if macro.package_name in hierarchy: + namespace = hierarchy[macro.package_name] else: namespace = {} - heirarchy[macro.package_name] = namespace + hierarchy[macro.package_name] = namespace if macro.name in namespace: raise_duplicate_macro_name( macro_func.macro, macro, macro.package_name ) - heirarchy[macro.package_name][macro.name] = macro_func + hierarchy[macro.package_name][macro.name] = macro_func def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]): macro_name: str = macro.name + # MacroGenerator is in clients/jinja.py + # a MacroGenerator object is a callable object that will + # execute the MacroGenerator.__call__ function macro_func: MacroGenerator = MacroGenerator( macro, ctx, self.node, self.thread_ctx ) @@ -122,10 +142,12 @@ def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]): if macro.package_name in self.internal_package_names: self._add_macro_to(self.internal_packages, macro, macro_func) else: + # if it's not an internal package self._add_macro_to(self.packages, macro, macro_func) - + # add to locals if it's the package this node is in if macro.package_name == self.search_package: self.locals[macro_name] = macro_func + # add to globals if it's in the root package elif macro.package_name == self.root_package: self.globals[macro_name] = macro_func @@ -143,6 +165,7 @@ def build_namespace( global_project_namespace: FlatNamespace = {} for pkg in reversed(self.internal_package_names_order): if pkg in self.internal_packages: + # add the macros pointed to by this package name global_project_namespace.update(self.internal_packages[pkg]) return MacroNamespace( diff --git a/core/dbt/context/manifest.py b/core/dbt/context/manifest.py index 60fbd6c9d1d..e9c99e33952 100644 --- a/core/dbt/context/manifest.py +++ b/core/dbt/context/manifest.py @@ -2,7 +2,8 @@ from dbt.clients.jinja import MacroStack from dbt.contracts.connection import AdapterRequiredConfig -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import Manifest, AnyManifest +from dbt.context.macro_resolver import TestMacroNamespace from .configured import ConfiguredContext @@ -19,17 +20,25 @@ class ManifestContext(ConfiguredContext): def __init__( self, config: AdapterRequiredConfig, - manifest: Manifest, + manifest: AnyManifest, search_package: str, ) -> None: super().__init__(config) self.manifest = manifest + # this is the package of the node for which this context was built self.search_package = search_package self.macro_stack 
= MacroStack() + # This namespace is used by the BaseDatabaseWrapper in jinja rendering. + # The namespace is passed to it when it's constructed. It expects + # to be able to do: namespace.get_from_package(..) + self.namespace = self._build_namespace() + + def _build_namespace(self): + # this takes all the macros in the manifest and adds them + # to the MacroNamespaceBuilder stored in self.namespace builder = self._get_namespace_builder() - self.namespace = builder.build_namespace( - self.manifest.macros.values(), - self._ctx, + return builder.build_namespace( + self.manifest.macros.values(), self._ctx ) def _get_namespace_builder(self) -> MacroNamespaceBuilder: @@ -46,9 +55,15 @@ def _get_namespace_builder(self) -> MacroNamespaceBuilder: None, ) + # This does not use the Mashumaro code def to_dict(self): dct = super().to_dict() - dct.update(self.namespace) + # This moves all of the macros in the 'namespace' into top level + # keys in the manifest dictionary + if isinstance(self.namespace, TestMacroNamespace): + dct.update(self.namespace.local_namespace) + else: + dct.update(self.namespace) return dct diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index ea39a73555e..7515680ef19 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -10,15 +10,18 @@ from dbt.adapters.base.column import Column from dbt.adapters.factory import get_adapter, get_adapter_package_names from dbt.clients import agate_helper -from dbt.clients.jinja import get_rendered, MacroGenerator +from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack from dbt.config import RuntimeConfig, Project from .base import contextmember, contextproperty, Var from .configured import FQNLookup from .context_config import ContextConfig +from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace from .macros import MacroNamespaceBuilder, MacroNamespace from .manifest import ManifestContext -from dbt.contracts.graph.manifest import Manifest, Disabled from dbt.contracts.connection import AdapterResponse +from dbt.contracts.graph.manifest import ( + Manifest, AnyManifest, Disabled, MacroManifest +) from dbt.contracts.graph.compiled import ( CompiledResource, CompiledSeedNode, @@ -141,6 +144,7 @@ def dispatch( for prefix in self._get_adapter_macro_prefixes(): search_name = f'{prefix}__{macro_name}' try: + # this uses the namespace from the context macro = self._namespace.get_from_package( package_name, search_name ) @@ -638,10 +642,13 @@ def __init__( self.context_config: Optional[ContextConfig] = context_config self.provider: Provider = provider self.adapter = get_adapter(self.config) + # The macro namespace is used in creating the DatabaseWrapper self.db_wrapper = self.provider.DatabaseWrapper( self.adapter, self.namespace ) + # This overrides the method in ManifestContext, and provides + # a model, which the ManifestContext builder does not def _get_namespace_builder(self): internal_packages = get_adapter_package_names( self.config.credentials.type @@ -1203,7 +1210,7 @@ def __init__( self, model: ParsedMacro, config: RuntimeConfig, - manifest: Manifest, + manifest: AnyManifest, provider: Provider, search_package: Optional[str], ) -> None: @@ -1289,34 +1296,28 @@ def this(self) -> Optional[RelationProxy]: return self.db_wrapper.Relation.create_from(self.config, self.model) +# This is called by '_context_for', used in 'render_with_context' def generate_parser_model( model: ManifestNode, config: RuntimeConfig, - manifest: Manifest, + manifest: MacroManifest, 
context_config: ContextConfig, ) -> Dict[str, Any]: + # The __init__ method of ModelContext also initializes + # a ManifestContext object which creates a MacroNamespaceBuilder + # which adds every macro in the Manifest. ctx = ModelContext( model, config, manifest, ParseProvider(), context_config ) - return ctx.to_dict() - - -def generate_parser_macro( - macro: ParsedMacro, - config: RuntimeConfig, - manifest: Manifest, - package_name: Optional[str], -) -> Dict[str, Any]: - ctx = MacroContext( - macro, config, manifest, ParseProvider(), package_name - ) + # The 'to_dict' method in ManifestContext moves all of the macro names + # in the macro 'namespace' up to top level keys return ctx.to_dict() def generate_generate_component_name_macro( macro: ParsedMacro, config: RuntimeConfig, - manifest: Manifest, + manifest: MacroManifest, ) -> Dict[str, Any]: ctx = MacroContext( macro, config, manifest, GenerateNameProvider(), None @@ -1369,7 +1370,7 @@ def __call__(self, *args) -> str: def generate_parse_exposure( exposure: ParsedExposure, config: RuntimeConfig, - manifest: Manifest, + manifest: MacroManifest, package_name: str, ) -> Dict[str, Any]: project = config.load_dependencies()[package_name] @@ -1387,3 +1388,57 @@ def generate_parse_exposure( manifest, ) } + + +# This class is currently used by the schema parser in order +# to limit the number of macros in the context by using +# the TestMacroNamespace class TestContext(ProviderContext): + def __init__( + self, + model, + config: RuntimeConfig, + manifest: Manifest, + provider: Provider, + context_config: Optional[ContextConfig], + macro_resolver: MacroResolver, + ) -> None: + # this must be before super init so that macro_resolver exists for + # build_namespace + self.macro_resolver = macro_resolver + self.thread_ctx = MacroStack() + super().__init__(model, config, manifest, provider, context_config) + self._build_test_namespace() + + def _build_namespace(self): + return {} + + # ManifestContext._build_namespace provides a complete namespace of all + # macros; the _build_namespace override above returns an empty dict instead. + # This method builds a namespace containing only the macros in the test + # node's 'depends_on.macros' by using the TestMacroNamespace + def _build_test_namespace(self): + depends_on_macros = [] + if self.model.depends_on and self.model.depends_on.macros: + depends_on_macros = self.model.depends_on.macros + macro_namespace = TestMacroNamespace( + self.macro_resolver, self.ctx, self.node, self.thread_ctx, + depends_on_macros + ) + self._namespace = macro_namespace + + +def generate_test_context( + model: ManifestNode, + config: RuntimeConfig, + manifest: Manifest, + context_config: ContextConfig, + macro_resolver: MacroResolver +) -> Dict[str, Any]: + ctx = TestContext( + model, config, manifest, ParseProvider(), context_config, + macro_resolver + ) + # The 'to_dict' method in ManifestContext moves all of the macro names + # in the macro 'namespace' up to top level keys + return ctx.to_dict() diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index edfce5fd1aa..86c6978c7a1 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -2,28 +2,29 @@ import itertools from dataclasses import dataclass, field from typing import ( - Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable, + Any, ClassVar, Dict, Tuple, Iterable, Optional, List, Callable, ) +from dbt.exceptions import InternalException +from dbt.utils import translate_aliases +from dbt.logger import
GLOBAL_LOGGER as logger from typing_extensions import Protocol - -from hologram import JsonSchemaMixin -from hologram.helpers import ( - StrEnum, register_pattern, ExtensibleJsonSchemaMixin +from dbt.dataclass_schema import ( + dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, + ValidatedStringMixin, register_pattern ) - from dbt.contracts.util import Replaceable -from dbt.exceptions import InternalException -from dbt.utils import translate_aliases -from dbt.logger import GLOBAL_LOGGER as logger + +class Identifier(ValidatedStringMixin): + ValidationRegex = r'^[A-Za-z_][A-Za-z0-9_]+$' -Identifier = NewType('Identifier', str) +# we need register_pattern for jsonschema validation register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$') @dataclass -class AdapterResponse(JsonSchemaMixin): +class AdapterResponse(dbtClassMixin): _message: str code: Optional[str] = None rows_affected: Optional[int] = None @@ -40,20 +41,19 @@ class ConnectionState(StrEnum): @dataclass(init=False) -class Connection(ExtensibleJsonSchemaMixin, Replaceable): +class Connection(ExtensibleDbtClassMixin, Replaceable): type: Identifier - name: Optional[str] + name: Optional[str] = None state: ConnectionState = ConnectionState.INIT transaction_open: bool = False - # prevent serialization _handle: Optional[Any] = None - _credentials: JsonSchemaMixin = field(init=False) + _credentials: Optional[Any] = None def __init__( self, type: Identifier, name: Optional[str], - credentials: JsonSchemaMixin, + credentials: dbtClassMixin, state: ConnectionState = ConnectionState.INIT, transaction_open: bool = False, handle: Optional[Any] = None, @@ -113,7 +113,7 @@ def resolve(self, connection: Connection) -> Connection: # will work. @dataclass # type: ignore class Credentials( - ExtensibleJsonSchemaMixin, + ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta ): @@ -132,7 +132,7 @@ def connection_info( ) -> Iterable[Tuple[str, Any]]: """Return an ordered iterator of key/value pairs for pretty-printing. """ - as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases) + as_dict = self.to_dict(options={'keep_none': True}) connection_keys = set(self._connection_keys()) aliases: List[str] = [] if with_aliases: @@ -148,9 +148,10 @@ def _connection_keys(self) -> Tuple[str, ...]: raise NotImplementedError @classmethod - def from_dict(cls, data): + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) data = cls.translate_aliases(data) - return super().from_dict(data) + return data @classmethod def translate_aliases( @@ -158,31 +159,26 @@ def translate_aliases( ) -> Dict[str, Any]: return translate_aliases(kwargs, cls._ALIASES, recurse) - def to_dict(self, omit_none=True, validate=False, *, with_aliases=False): - serialized = super().to_dict(omit_none=omit_none, validate=validate) - if with_aliases: - serialized.update({ - new_name: serialized[canonical_name] + def __post_serialize__(self, dct, options=None): + # no super() -- do we need it? + if self._ALIASES: + dct.update({ + new_name: dct[canonical_name] for new_name, canonical_name in self._ALIASES.items() - if canonical_name in serialized + if canonical_name in dct }) - return serialized + return dct class UserConfigContract(Protocol): send_anonymous_usage_stats: bool - use_colors: Optional[bool] - partial_parse: Optional[bool] - printer_width: Optional[int] + use_colors: Optional[bool] = None + partial_parse: Optional[bool] = None + printer_width: Optional[int] = None def set_values(self, cookie_dir: str) -> None: ... 
- def to_dict( - self, omit_none: bool = True, validate: bool = False - ) -> Dict[str, Any]: - ... - class HasCredentials(Protocol): credentials: Credentials @@ -216,7 +212,7 @@ def to_target_dict(self): @dataclass -class QueryComment(JsonSchemaMixin): +class QueryComment(dbtClassMixin): comment: str = DEFAULT_QUERY_COMMENT append: bool = False diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py index ec3798ef7ac..049905d03b5 100644 --- a/core/dbt/contracts/files.py +++ b/core/dbt/contracts/files.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Union -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from dbt.exceptions import InternalException @@ -15,7 +15,7 @@ @dataclass -class FilePath(JsonSchemaMixin): +class FilePath(dbtClassMixin): searched_path: str relative_path: str project_root: str @@ -51,7 +51,7 @@ def seed_too_large(self) -> bool: @dataclass -class FileHash(JsonSchemaMixin): +class FileHash(dbtClassMixin): name: str # the hash type name checksum: str # the hashlib.hash_type().hexdigest() of the file contents @@ -91,7 +91,7 @@ def from_contents(cls, contents: str, name='sha256') -> 'FileHash': @dataclass -class RemoteFile(JsonSchemaMixin): +class RemoteFile(dbtClassMixin): @property def searched_path(self) -> str: return 'from remote system' @@ -110,7 +110,7 @@ def original_file_path(self): @dataclass -class SourceFile(JsonSchemaMixin): +class SourceFile(dbtClassMixin): """Define a source file in dbt""" path: Union[FilePath, RemoteFile] # the path information checksum: FileHash diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index 388fe069b63..c8f08710b94 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -19,19 +19,19 @@ from dbt.node_types import NodeType from dbt.contracts.util import Replaceable -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from dataclasses import dataclass, field from typing import Optional, List, Union, Dict, Type @dataclass -class InjectedCTE(JsonSchemaMixin, Replaceable): +class InjectedCTE(dbtClassMixin, Replaceable): id: str sql: str @dataclass -class CompiledNodeMixin(JsonSchemaMixin): +class CompiledNodeMixin(dbtClassMixin): # this is a special mixin class to provide a required argument. If a node # is missing a `compiled` flag entirely, it must not be a CompiledNode. compiled: bool @@ -178,8 +178,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource: raise ValueError('invalid resource_type: {}' .format(compiled.resource_type)) - # validate=False to allow extra keys from compiling - return cls.from_dict(compiled.to_dict(), validate=False) + return cls.from_dict(compiled.to_dict()) NonSourceCompiledNode = Union[ diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index c3e4354c270..ee24e62db84 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -428,8 +428,85 @@ def _update_into(dest: MutableMapping[str, T], new_item: T): dest[unique_id] = new_item +# This contains macro methods that are in both the Manifest +# and the MacroManifest +class MacroMethods: + # Just to make mypy happy. There must be a better way. 
+ def __init__(self): + self.macros = [] + self.metadata = {} + + def find_macro_by_name( + self, name: str, root_project_name: str, package: Optional[str] + ) -> Optional[ParsedMacro]: + """Find a macro in the graph by its name and package name, or None for + any package. The root project name is used to determine priority: + - locally defined macros come first + - then imported macros + - then macros defined in the root project + """ + filter: Optional[Callable[[MacroCandidate], bool]] = None + if package is not None: + def filter(candidate: MacroCandidate) -> bool: + return package == candidate.macro.package_name + + candidates: CandidateList = self._find_macros_by_name( + name=name, + root_project_name=root_project_name, + filter=filter, + ) + + return candidates.last() + + def find_generate_macro_by_name( + self, component: str, root_project_name: str + ) -> Optional[ParsedMacro]: + """ + The `generate_X_name` macros are similar to regular ones, but ignore + imported packages. + - if there is a `generate_{component}_name` macro in the root + project, return it + - return the `generate_{component}_name` macro from the 'dbt' + internal project + """ + def filter(candidate: MacroCandidate) -> bool: + return candidate.locality != Locality.Imported + + candidates: CandidateList = self._find_macros_by_name( + name=f'generate_{component}_name', + root_project_name=root_project_name, + # filter out imported packages + filter=filter, + ) + return candidates.last() + + def _find_macros_by_name( + self, + name: str, + root_project_name: str, + filter: Optional[Callable[[MacroCandidate], bool]] = None + ) -> CandidateList: + """Find macros by their name. + """ + # avoid an import cycle + from dbt.adapters.factory import get_adapter_package_names + candidates: CandidateList = CandidateList() + packages = set(get_adapter_package_names(self.metadata.adapter_type)) + for unique_id, macro in self.macros.items(): + if macro.name != name: + continue + candidate = MacroCandidate( + locality=_get_locality(macro, root_project_name, packages), + macro=macro, + ) + if filter is None or filter(candidate): + candidates.append(candidate) + + return candidates + + @dataclass -class Manifest: +class Manifest(MacroMethods): """The manifest for the full graph, after parsing and during compilation. """ # These attributes are both positional and by keyword. 
If an attribute @@ -450,27 +527,6 @@ class Manifest: _refs_cache: Optional[RefableCache] = None _lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock) - @classmethod - def from_macros( - cls, - macros: Optional[MutableMapping[str, ParsedMacro]] = None, - files: Optional[MutableMapping[str, SourceFile]] = None, - ) -> 'Manifest': - if macros is None: - macros = {} - if files is None: - files = {} - return cls( - nodes={}, - sources={}, - macros=macros, - docs={}, - exposures={}, - selectors={}, - disabled=[], - files=files, - ) - def sync_update_node( self, new_node: NonSourceCompiledNode ) -> NonSourceCompiledNode: @@ -508,10 +564,12 @@ def build_flat_graph(self): """ self.flat_graph = { 'nodes': { - k: v.to_dict(omit_none=False) for k, v in self.nodes.items() + k: v.to_dict(options={'keep_none': True}) + for k, v in self.nodes.items() }, 'sources': { - k: v.to_dict(omit_none=False) for k, v in self.sources.items() + k: v.to_dict(options={'keep_none': True}) + for k, v in self.sources.items() } } @@ -536,30 +594,6 @@ def find_disabled_source_by_name( assert isinstance(result, ParsedSourceDefinition) return result - def _find_macros_by_name( - self, - name: str, - root_project_name: str, - filter: Optional[Callable[[MacroCandidate], bool]] = None - ) -> CandidateList: - """Find macros by their name. - """ - # avoid an import cycle - from dbt.adapters.factory import get_adapter_package_names - candidates: CandidateList = CandidateList() - packages = set(get_adapter_package_names(self.metadata.adapter_type)) - for unique_id, macro in self.macros.items(): - if macro.name != name: - continue - candidate = MacroCandidate( - locality=_get_locality(macro, root_project_name, packages), - macro=macro, - ) - if filter is None or filter(candidate): - candidates.append(candidate) - - return candidates - def _materialization_candidates_for( self, project_name: str, materialization_name: str, @@ -581,50 +615,6 @@ def _materialization_candidates_for( for m in self._find_macros_by_name(full_name, project_name) ) - def find_macro_by_name( - self, name: str, root_project_name: str, package: Optional[str] - ) -> Optional[ParsedMacro]: - """Find a macro in the graph by its name and package name, or None for - any package. The root project name is used to determine priority: - - locally defined macros come first - - then imported macros - - then macros defined in the root project - """ - filter: Optional[Callable[[MacroCandidate], bool]] = None - if package is not None: - def filter(candidate: MacroCandidate) -> bool: - return package == candidate.macro.package_name - - candidates: CandidateList = self._find_macros_by_name( - name=name, - root_project_name=root_project_name, - filter=filter, - ) - - return candidates.last() - - def find_generate_macro_by_name( - self, component: str, root_project_name: str - ) -> Optional[ParsedMacro]: - """ - The `generate_X_name` macros are similar to regular ones, but ignore - imported packages. 
- - if there is a `generate_{component}_name` macro in the root - project, return it - - return the `generate_{component}_name` macro from the 'dbt' - internal project - """ - def filter(candidate: MacroCandidate) -> bool: - return candidate.locality != Locality.Imported - - candidates: CandidateList = self._find_macros_by_name( - name=f'generate_{component}_name', - root_project_name=root_project_name, - # filter out imported packages - filter=filter, - ) - return candidates.last() - def find_materialization_macro_by_name( self, project_name: str, materialization_name: str, adapter_type: str ) -> Optional[ParsedMacro]: @@ -763,10 +753,10 @@ def writable_manifest(self): parent_map=backward_edges, ) - def to_dict(self, omit_none=True, validate=False): - return self.writable_manifest().to_dict( - omit_none=omit_none, validate=validate - ) + # When 'to_dict' is called on the Manifest, it substitues a + # WritableManifest + def __pre_serialize__(self, options=None): + return self.writable_manifest() def write(self, path): self.writable_manifest().write(path) @@ -944,6 +934,19 @@ def __reduce_ex__(self, protocol): return self.__class__, args +class MacroManifest(MacroMethods): + def __init__(self, macros, files): + self.macros = macros + self.files = files + self.metadata = ManifestMetadata() + # This is returned by the 'graph' context property + # in the ProviderContext class. + self.flat_graph = {} + + +AnyManifest = Union[Manifest, MacroManifest] + + @dataclass @schema_version('manifest', 1) class WritableManifest(ArtifactMixin): diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index f8b5352049a..09370e5f563 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -2,19 +2,12 @@ from enum import Enum from itertools import chain from typing import ( - Any, List, Optional, Dict, MutableMapping, Union, Type, NewType, Tuple, - TypeVar, Callable, cast, Hashable + Any, List, Optional, Dict, MutableMapping, Union, Type, + TypeVar, Callable, +) +from dbt.dataclass_schema import ( + dbtClassMixin, ValidationError, register_pattern, ) - -# TODO: patch+upgrade hologram to avoid this jsonschema import -import jsonschema # type: ignore - -# This is protected, but we really do want to reuse this logic, and the cache! -# It would be nice to move the custom error picking stuff into hologram! 
-from hologram import _validate_schema -from hologram import JsonSchemaMixin, ValidationError -from hologram.helpers import StrEnum, register_pattern - from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed from dbt.exceptions import CompilationException, InternalException from dbt.contracts.util import Replaceable, list_str @@ -170,22 +163,15 @@ def insensitive_patterns(*patterns: str): return '^({})$'.format('|'.join(lowercased)) -Severity = NewType('Severity', str) - -register_pattern(Severity, insensitive_patterns('warn', 'error')) - - -class SnapshotStrategy(StrEnum): - Timestamp = 'timestamp' - Check = 'check' +class Severity(str): + pass -class All(StrEnum): - All = 'all' +register_pattern(Severity, insensitive_patterns('warn', 'error')) @dataclass -class Hook(JsonSchemaMixin, Replaceable): +class Hook(dbtClassMixin, Replaceable): sql: str transaction: bool = True index: Optional[int] = None @@ -313,29 +299,6 @@ def _extract_dict( ) return result - def to_dict( - self, - omit_none: bool = True, - validate: bool = False, - *, - omit_hidden: bool = True, - ) -> Dict[str, Any]: - result = super().to_dict(omit_none=omit_none, validate=validate) - if omit_hidden and not omit_none: - for fld, target_field in self._get_fields(): - if target_field not in result: - continue - - # if the field is not None, preserve it regardless of the - # setting. This is in line with existing behavior, but isn't - # an endorsement of it! - if result[target_field] is not None: - continue - - if not ShowBehavior.should_show(fld): - del result[target_field] - return result - def update_from( self: T, data: Dict[str, Any], adapter_type: str, validate: bool = True ) -> T: @@ -344,7 +307,7 @@ def update_from( """ # sadly, this is a circular import from dbt.adapters.factory import get_config_class_by_name - dct = self.to_dict(omit_none=False, validate=False, omit_hidden=False) + dct = self.to_dict(options={'keep_none': True}) adapter_config_cls = get_config_class_by_name(adapter_type) @@ -358,21 +321,23 @@ def update_from( dct.update(data) # any validation failures must have come from the update - return self.from_dict(dct, validate=validate) + if validate: + self.validate(dct) + return self.from_dict(dct) def finalize_and_validate(self: T) -> T: - # from_dict will validate for us - dct = self.to_dict(omit_none=False, validate=False) + dct = self.to_dict(options={'keep_none': True}) + self.validate(dct) return self.from_dict(dct) def replace(self, **kwargs): - dct = self.to_dict(validate=False) + dct = self.to_dict() mapping = self.field_mapping() for key, value in kwargs.items(): new_key = mapping.get(key, key) dct[new_key] = value - return self.from_dict(dct, validate=False) + return self.from_dict(dct) @dataclass @@ -431,12 +396,33 @@ class NodeConfig(BaseConfig): full_refresh: Optional[bool] = None @classmethod - def from_dict(cls, data, validate=True): + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'} + # create a new dict because otherwise it gets overwritten in + # tests + new_dict = {} + for key in data: + new_dict[key] = data[key] + data = new_dict for key in hooks.ModelHookType: if key in data: data[key] = [hooks.get_hook_dict(h) for h in data[key]] - return super().from_dict(data, validate=validate) - + for field_name in field_map: + if field_name in data: + new_name = field_map[field_name] + data[new_name] = data.pop(field_name) + return data + + def 
__post_serialize__(self, dct, options=None): + dct = super().__post_serialize__(dct, options=options) + field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'} + for field_name in field_map: + if field_name in dct: + dct[field_map[field_name]] = dct.pop(field_name) + return dct + + # this is still used by jsonschema validation @classmethod def field_mapping(cls): return {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'} @@ -454,182 +440,49 @@ class TestConfig(NodeConfig): severity: Severity = Severity('ERROR') -SnapshotVariants = Union[ - 'TimestampSnapshotConfig', - 'CheckSnapshotConfig', - 'GenericSnapshotConfig', -] - - -def _relevance_without_strategy(error: jsonschema.ValidationError): - # calculate the 'relevance' of an error the normal jsonschema way, except - # if the validator is in the 'strategy' field and its conflicting with the - # 'enum'. This suppresses `"'timestamp' is not one of ['check']` and such - if 'strategy' in error.path and error.validator in {'enum', 'not'}: - length = 1 - else: - length = -len(error.path) - validator = error.validator - return length, validator not in {'anyOf', 'oneOf'} - - -@dataclass -class SnapshotWrapper(JsonSchemaMixin): - """This is a little wrapper to let us serialize/deserialize the - SnapshotVariants union. - """ - config: SnapshotVariants # mypy: ignore - - @classmethod - def validate(cls, data: Any): - config = data.get('config', {}) - - if config.get('strategy') == 'check': - schema = _validate_schema(CheckSnapshotConfig) - to_validate = config - - elif config.get('strategy') == 'timestamp': - schema = _validate_schema(TimestampSnapshotConfig) - to_validate = config - - else: - h_cls = cast(Hashable, cls) - schema = _validate_schema(h_cls) - to_validate = data - - validator = jsonschema.Draft7Validator(schema) - - error = jsonschema.exceptions.best_match( - validator.iter_errors(to_validate), - key=_relevance_without_strategy, - ) - - if error is not None: - raise ValidationError.create_from(error) from error - - @dataclass class EmptySnapshotConfig(NodeConfig): materialized: str = 'snapshot' -@dataclass(init=False) +@dataclass class SnapshotConfig(EmptySnapshotConfig): - unique_key: str = field(init=False, metadata=dict(init_required=True)) - target_schema: str = field(init=False, metadata=dict(init_required=True)) + strategy: Optional[str] = None + unique_key: Optional[str] = None + target_schema: Optional[str] = None target_database: Optional[str] = None + updated_at: Optional[str] = None + check_cols: Optional[Union[str, List[str]]] = None - def __init__( - self, - unique_key: str, - target_schema: str, - target_database: Optional[str] = None, - **kwargs - ) -> None: - self.unique_key = unique_key - self.target_schema = target_schema - self.target_database = target_database - # kwargs['materialized'] = materialized - super().__init__(**kwargs) - - # type hacks... 
@classmethod - def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore - fields: List[Tuple[Field, str]] = [] - for old_field, name in super()._get_fields(): - new_field = old_field - # tell hologram we're really an initvar - if old_field.metadata and old_field.metadata.get('init_required'): - new_field = field(init=True, metadata=old_field.metadata) - new_field.name = old_field.name - new_field.type = old_field.type - new_field._field_type = old_field._field_type # type: ignore - fields.append((new_field, name)) - return fields - - def finalize_and_validate(self: 'SnapshotConfig') -> SnapshotVariants: + def validate(cls, data): + super().validate(data) + if data.get('strategy') == 'check': + if not data.get('check_cols'): + raise ValidationError( + "A snapshot configured with the check strategy must " + "specify a check_cols configuration.") + if (isinstance(data['check_cols'], str) and + data['check_cols'] != 'all'): + raise ValidationError( + f"Invalid value for 'check_cols': {data['check_cols']}. " + "Expected 'all' or a list of strings.") + + elif data.get('strategy') == 'timestamp': + if not data.get('updated_at'): + raise ValidationError( + "A snapshot configured with the timestamp strategy " + "must specify an updated_at configuration.") + if data.get('check_cols'): + raise ValidationError( + "A 'timestamp' snapshot should not have 'check_cols'") + # If the strategy is not 'check' or 'timestamp' it's a custom strategy, + # formerly supported with GenericSnapshotConfig + + def finalize_and_validate(self): data = self.to_dict() - return SnapshotWrapper.from_dict({'config': data}).config - - -@dataclass(init=False) -class GenericSnapshotConfig(SnapshotConfig): - strategy: str = field(init=False, metadata=dict(init_required=True)) - - def __init__(self, strategy: str, **kwargs) -> None: - self.strategy = strategy - super().__init__(**kwargs) - - @classmethod - def _collect_json_schema( - cls, definitions: Dict[str, Any] - ) -> Dict[str, Any]: - # this is the method you want to override in hologram if you want - # to do clever things about the json schema and have classes that - # contain instances of your JsonSchemaMixin respect the change. - schema = super()._collect_json_schema(definitions) - - # Instead of just the strategy we'd calculate normally, say - # "this strategy except none of our specialization strategies". - strategies = [schema['properties']['strategy']] - for specialization in (TimestampSnapshotConfig, CheckSnapshotConfig): - strategies.append( - {'not': specialization.json_schema()['properties']['strategy']} - ) - - schema['properties']['strategy'] = { - 'allOf': strategies - } - return schema - - -@dataclass(init=False) -class TimestampSnapshotConfig(SnapshotConfig): - strategy: str = field( - init=False, - metadata=dict( - restrict=[str(SnapshotStrategy.Timestamp)], - init_required=True, - ), - ) - updated_at: str = field(init=False, metadata=dict(init_required=True)) - - def __init__( - self, strategy: str, updated_at: str, **kwargs - ) -> None: - self.strategy = strategy - self.updated_at = updated_at - super().__init__(**kwargs) - - -@dataclass(init=False) -class CheckSnapshotConfig(SnapshotConfig): - strategy: str = field( - init=False, - metadata=dict( - restrict=[str(SnapshotStrategy.Check)], - init_required=True, - ), - ) - # TODO: is there a way to get this to accept tuples of strings? 
Adding - # `Tuple[str, ...]` to the list of types results in this: - # ['email'] is valid under each of {'type': 'array', 'items': - # {'type': 'string'}}, {'type': 'array', 'items': {'type': 'string'}} - # but without it, parsing gets upset about values like `('email',)` - # maybe hologram itself should support this behavior? It's not like tuples - # are meaningful in json - check_cols: Union[All, List[str]] = field( - init=False, - metadata=dict(init_required=True), - ) - - def __init__( - self, strategy: str, check_cols: Union[All, List[str]], - **kwargs - ) -> None: - self.strategy = strategy - self.check_cols = check_cols - super().__init__(**kwargs) + self.validate(data) + return self.from_dict(data) RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = { diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 0a7e41c26a0..a4020e2a487 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -13,8 +13,9 @@ TypeVar, ) -from hologram import JsonSchemaMixin -from hologram.helpers import ExtensibleJsonSchemaMixin +from dbt.dataclass_schema import ( + dbtClassMixin, ExtensibleDbtClassMixin +) from dbt.clients.system import write_file from dbt.contracts.files import FileHash, MAXIMUM_SEED_SIZE_NAME @@ -38,20 +39,14 @@ TestConfig, SourceConfig, EmptySnapshotConfig, - SnapshotVariants, -) -# import these 3 so the SnapshotVariants forward ref works. -from .model_config import ( # noqa - TimestampSnapshotConfig, - CheckSnapshotConfig, - GenericSnapshotConfig, + SnapshotConfig, ) @dataclass class ColumnInfo( AdditionalPropertiesMixin, - ExtensibleJsonSchemaMixin, + ExtensibleDbtClassMixin, Replaceable ): name: str @@ -64,7 +59,7 @@ class ColumnInfo( @dataclass -class HasFqn(JsonSchemaMixin, Replaceable): +class HasFqn(dbtClassMixin, Replaceable): fqn: List[str] def same_fqn(self, other: 'HasFqn') -> bool: @@ -72,12 +67,12 @@ def same_fqn(self, other: 'HasFqn') -> bool: @dataclass -class HasUniqueID(JsonSchemaMixin, Replaceable): +class HasUniqueID(dbtClassMixin, Replaceable): unique_id: str @dataclass -class MacroDependsOn(JsonSchemaMixin, Replaceable): +class MacroDependsOn(dbtClassMixin, Replaceable): macros: List[str] = field(default_factory=list) # 'in' on lists is O(n) so this is O(n^2) for # of macros @@ -96,12 +91,22 @@ def add_node(self, value: str): @dataclass -class HasRelationMetadata(JsonSchemaMixin, Replaceable): +class HasRelationMetadata(dbtClassMixin, Replaceable): database: Optional[str] schema: str + # Can't set database to None like it ought to be + # because it messes up the subclasses and default parameters + # so hack it here + @classmethod + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + if 'database' not in data: + data['database'] = None + return data + -class ParsedNodeMixins(JsonSchemaMixin): +class ParsedNodeMixins(dbtClassMixin): resource_type: NodeType depends_on: DependsOn config: NodeConfig @@ -132,8 +137,12 @@ def patch(self, patch: 'ParsedNodePatch'): self.meta = patch.meta self.docs = patch.docs if flags.STRICT_MODE: - assert isinstance(self, JsonSchemaMixin) - self.to_dict(validate=True, omit_none=False) + # It seems odd that an instance can be invalid + # Maybe there should be validation or restrictions + # elsewhere? 
+ assert isinstance(self, dbtClassMixin) + dct = self.to_dict(options={'keep_none': True}) + self.validate(dct) def get_materialization(self): return self.config.materialized @@ -335,14 +344,14 @@ def same_body(self: T, other: T) -> bool: @dataclass -class TestMetadata(JsonSchemaMixin, Replaceable): - namespace: Optional[str] +class TestMetadata(dbtClassMixin, Replaceable): name: str - kwargs: Dict[str, Any] + kwargs: Dict[str, Any] = field(default_factory=dict) + namespace: Optional[str] = None @dataclass -class HasTestMetadata(JsonSchemaMixin): +class HasTestMetadata(dbtClassMixin): test_metadata: TestMetadata @@ -394,7 +403,7 @@ class IntermediateSnapshotNode(ParsedNode): @dataclass class ParsedSnapshotNode(ParsedNode): resource_type: NodeType = field(metadata={'restrict': [NodeType.Snapshot]}) - config: SnapshotVariants + config: SnapshotConfig @dataclass @@ -443,8 +452,10 @@ def patch(self, patch: ParsedMacroPatch): self.docs = patch.docs self.arguments = patch.arguments if flags.STRICT_MODE: - assert isinstance(self, JsonSchemaMixin) - self.to_dict(validate=True, omit_none=False) + # What does this actually validate? + assert isinstance(self, dbtClassMixin) + dct = self.to_dict(options={'keep_none': True}) + self.validate(dct) def same_contents(self, other: Optional['ParsedMacro']) -> bool: if other is None: diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 6d2fcf3a508..165e22e4b1a 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -8,8 +8,9 @@ import dbt.helper_types # noqa:F401 from dbt.exceptions import CompilationException -from hologram import JsonSchemaMixin -from hologram.helpers import StrEnum, ExtensibleJsonSchemaMixin +from dbt.dataclass_schema import ( + dbtClassMixin, StrEnum, ExtensibleDbtClassMixin +) from dataclasses import dataclass, field from datetime import timedelta @@ -18,7 +19,7 @@ @dataclass -class UnparsedBaseNode(JsonSchemaMixin, Replaceable): +class UnparsedBaseNode(dbtClassMixin, Replaceable): package_name: str root_path: str path: str @@ -66,12 +67,12 @@ class UnparsedRunHook(UnparsedNode): @dataclass -class Docs(JsonSchemaMixin, Replaceable): +class Docs(dbtClassMixin, Replaceable): show: bool = True @dataclass -class HasDocs(AdditionalPropertiesMixin, ExtensibleJsonSchemaMixin, +class HasDocs(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replaceable): name: str description: str = '' @@ -100,7 +101,7 @@ class UnparsedColumn(HasTests): @dataclass -class HasColumnDocs(JsonSchemaMixin, Replaceable): +class HasColumnDocs(dbtClassMixin, Replaceable): columns: Sequence[HasDocs] = field(default_factory=list) @@ -110,7 +111,7 @@ class HasColumnTests(HasColumnDocs): @dataclass -class HasYamlMetadata(JsonSchemaMixin): +class HasYamlMetadata(dbtClassMixin): original_file_path: str yaml_key: str package_name: str @@ -127,7 +128,7 @@ class UnparsedNodeUpdate(HasColumnTests, HasTests, HasYamlMetadata): @dataclass -class MacroArgument(JsonSchemaMixin): +class MacroArgument(dbtClassMixin): name: str type: Optional[str] = None description: str = '' @@ -148,7 +149,7 @@ def plural(self) -> str: @dataclass -class Time(JsonSchemaMixin, Replaceable): +class Time(dbtClassMixin, Replaceable): count: int period: TimePeriod @@ -159,7 +160,7 @@ def exceeded(self, actual_age: float) -> bool: @dataclass -class FreshnessThreshold(JsonSchemaMixin, Mergeable): +class FreshnessThreshold(dbtClassMixin, Mergeable): warn_after: Optional[Time] = None error_after: Optional[Time] = None filter: 
Optional[str] = None @@ -180,7 +181,7 @@ def __bool__(self): @dataclass class AdditionalPropertiesAllowed( AdditionalPropertiesMixin, - ExtensibleJsonSchemaMixin + ExtensibleDbtClassMixin ): _extra: Dict[str, Any] = field(default_factory=dict) @@ -212,7 +213,7 @@ def __bool__(self): @dataclass -class Quoting(JsonSchemaMixin, Mergeable): +class Quoting(dbtClassMixin, Mergeable): database: Optional[bool] = None schema: Optional[bool] = None identifier: Optional[bool] = None @@ -230,15 +231,18 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests): external: Optional[ExternalTable] = None tags: List[str] = field(default_factory=list) - def to_dict(self, omit_none=True, validate=False): - result = super().to_dict(omit_none=omit_none, validate=validate) - if omit_none and self.freshness is None: - result['freshness'] = None - return result + def __post_serialize__(self, dct, options=None): + dct = super().__post_serialize__(dct) + keep_none = False + if options and 'keep_none' in options and options['keep_none']: + keep_none = True + if not keep_none and self.freshness is None: + dct['freshness'] = None + return dct @dataclass -class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable): +class UnparsedSourceDefinition(dbtClassMixin, Replaceable): name: str description: str = '' meta: Dict[str, Any] = field(default_factory=dict) @@ -257,15 +261,18 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable): def yaml_key(self) -> 'str': return 'sources' - def to_dict(self, omit_none=True, validate=False): - result = super().to_dict(omit_none=omit_none, validate=validate) - if omit_none and self.freshness is None: - result['freshness'] = None - return result + def __post_serialize__(self, dct, options=None): + dct = super().__post_serialize__(dct) + keep_none = False + if options and 'keep_none' in options and options['keep_none']: + keep_none = True + if not keep_none and self.freshness is None: + dct['freshness'] = None + return dct @dataclass -class SourceTablePatch(JsonSchemaMixin): +class SourceTablePatch(dbtClassMixin): name: str description: Optional[str] = None meta: Optional[Dict[str, Any]] = None @@ -283,7 +290,7 @@ class SourceTablePatch(JsonSchemaMixin): columns: Optional[Sequence[UnparsedColumn]] = None def to_patch_dict(self) -> Dict[str, Any]: - dct = self.to_dict(omit_none=True) + dct = self.to_dict() remove_keys = ('name') for key in remove_keys: if key in dct: @@ -296,7 +303,7 @@ def to_patch_dict(self) -> Dict[str, Any]: @dataclass -class SourcePatch(JsonSchemaMixin, Replaceable): +class SourcePatch(dbtClassMixin, Replaceable): name: str = field( metadata=dict(description='The name of the source to override'), ) @@ -320,7 +327,7 @@ class SourcePatch(JsonSchemaMixin, Replaceable): tags: Optional[List[str]] = None def to_patch_dict(self) -> Dict[str, Any]: - dct = self.to_dict(omit_none=True) + dct = self.to_dict() remove_keys = ('name', 'overrides', 'tables', 'path') for key in remove_keys: if key in dct: @@ -340,7 +347,7 @@ def get_table_named(self, name: str) -> Optional[SourceTablePatch]: @dataclass -class UnparsedDocumentation(JsonSchemaMixin, Replaceable): +class UnparsedDocumentation(dbtClassMixin, Replaceable): package_name: str root_path: str path: str @@ -400,13 +407,13 @@ class MaturityType(StrEnum): @dataclass -class ExposureOwner(JsonSchemaMixin, Replaceable): +class ExposureOwner(dbtClassMixin, Replaceable): email: str name: Optional[str] = None @dataclass -class UnparsedExposure(JsonSchemaMixin, Replaceable): +class UnparsedExposure(dbtClassMixin, 
Replaceable): name: str type: ExposureType owner: ExposureOwner diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 0d531167324..a7ea5d91f6c 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -4,25 +4,39 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt import tracking from dbt import ui - -from hologram import JsonSchemaMixin, ValidationError -from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \ - ExtensibleJsonSchemaMixin - +from dbt.dataclass_schema import ( + dbtClassMixin, ValidationError, + HyphenatedDbtClassMixin, + ExtensibleDbtClassMixin, + register_pattern, ValidatedStringMixin +) from dataclasses import dataclass, field -from typing import Optional, List, Dict, Union, Any, NewType +from typing import Optional, List, Dict, Union, Any +from mashumaro.types import SerializableType PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True -Name = NewType('Name', str) +class Name(ValidatedStringMixin): + ValidationRegex = r'^[^\d\W]\w*$' + + register_pattern(Name, r'^[^\d\W]\w*$') + +class SemverString(str, SerializableType): + def _serialize(self) -> str: + return self + + @classmethod + def _deserialize(cls, value: str) -> 'SemverString': + return SemverString(value) + + # this does not support the full semver (does not allow a trailing -fooXYZ) and # is not restrictive enough for full semver, (allows '1.0'). But it's like # 'semver lite'. -SemverString = NewType('SemverString', str) register_pattern( SemverString, r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$', @@ -30,15 +44,15 @@ @dataclass -class Quoting(JsonSchemaMixin, Mergeable): - identifier: Optional[bool] - schema: Optional[bool] - database: Optional[bool] - project: Optional[bool] +class Quoting(dbtClassMixin, Mergeable): + schema: Optional[bool] = None + database: Optional[bool] = None + project: Optional[bool] = None + identifier: Optional[bool] = None @dataclass -class Package(Replaceable, HyphenatedJsonSchemaMixin): +class Package(Replaceable, HyphenatedDbtClassMixin): pass @@ -54,7 +68,7 @@ class LocalPackage(Package): @dataclass class GitPackage(Package): git: str - revision: Optional[RawVersion] + revision: Optional[RawVersion] = None warn_unpinned: Optional[bool] = None def get_revisions(self) -> List[str]: @@ -80,7 +94,7 @@ def get_versions(self) -> List[str]: @dataclass -class PackageConfig(JsonSchemaMixin, Replaceable): +class PackageConfig(dbtClassMixin, Replaceable): packages: List[PackageSpec] @@ -96,13 +110,13 @@ def from_project(cls, project): @dataclass -class Downloads(ExtensibleJsonSchemaMixin, Replaceable): +class Downloads(ExtensibleDbtClassMixin, Replaceable): tarball: str @dataclass class RegistryPackageMetadata( - ExtensibleJsonSchemaMixin, + ExtensibleDbtClassMixin, ProjectPackageMetadata, ): downloads: Downloads @@ -154,7 +168,7 @@ class RegistryPackageMetadata( @dataclass -class Project(HyphenatedJsonSchemaMixin, Replaceable): +class Project(HyphenatedDbtClassMixin, Replaceable): name: Name version: Union[SemverString, float] config_version: int @@ -191,18 +205,16 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable): query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue() @classmethod - def from_dict(cls, data, validate=True) -> 'Project': - result = super().from_dict(data, validate=validate) - if result.name in BANNED_PROJECT_NAMES: + def validate(cls, data): + 
super().validate(data) + if data['name'] in BANNED_PROJECT_NAMES: raise ValidationError( - f'Invalid project name: {result.name} is a reserved word' + f"Invalid project name: {data['name']} is a reserved word" ) - return result - @dataclass -class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract): +class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract): send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS use_colors: Optional[bool] = None partial_parse: Optional[bool] = None @@ -222,7 +234,7 @@ def set_values(self, cookie_dir): @dataclass -class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable): +class ProfileConfig(HyphenatedDbtClassMixin, Replaceable): profile_name: str = field(metadata={'preserve_underscore': True}) target_name: str = field(metadata={'preserve_underscore': True}) config: UserConfig @@ -233,10 +245,10 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable): @dataclass class ConfiguredQuoting(Quoting, Replaceable): - identifier: bool - schema: bool - database: Optional[bool] - project: Optional[bool] + identifier: bool = True + schema: bool = True + database: Optional[bool] = None + project: Optional[bool] = None @dataclass @@ -249,5 +261,5 @@ class Configuration(Project, ProfileConfig): @dataclass -class ProjectList(JsonSchemaMixin): +class ProjectList(dbtClassMixin): projects: Dict[str, Project] diff --git a/core/dbt/contracts/relation.py b/core/dbt/contracts/relation.py index bc755f1a541..6bcb58c749e 100644 --- a/core/dbt/contracts/relation.py +++ b/core/dbt/contracts/relation.py @@ -1,12 +1,11 @@ from collections.abc import Mapping from dataclasses import dataclass, fields from typing import ( - Optional, TypeVar, Generic, Dict, + Optional, Dict, ) from typing_extensions import Protocol -from hologram import JsonSchemaMixin -from hologram.helpers import StrEnum +from dbt.dataclass_schema import dbtClassMixin, StrEnum from dbt import deprecations from dbt.contracts.util import Replaceable @@ -32,7 +31,7 @@ class HasQuoting(Protocol): quoting: Dict[str, bool] -class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping): +class FakeAPIObject(dbtClassMixin, Replaceable, Mapping): # override the mapping truthiness, len is always >1 def __bool__(self): return True @@ -58,16 +57,13 @@ def incorporate(self, **kwargs): return self.from_dict(value) -T = TypeVar('T') - - @dataclass -class _ComponentObject(FakeAPIObject, Generic[T]): - database: T - schema: T - identifier: T +class Policy(FakeAPIObject): + database: bool = True + schema: bool = True + identifier: bool = True - def get_part(self, key: ComponentName) -> T: + def get_part(self, key: ComponentName) -> bool: if key == ComponentName.Database: return self.database elif key == ComponentName.Schema: @@ -80,25 +76,18 @@ def get_part(self, key: ComponentName) -> T: .format(key, list(ComponentName)) ) - def replace_dict(self, dct: Dict[ComponentName, T]): - kwargs: Dict[str, T] = {} + def replace_dict(self, dct: Dict[ComponentName, bool]): + kwargs: Dict[str, bool] = {} for k, v in dct.items(): kwargs[str(k)] = v return self.replace(**kwargs) @dataclass -class Policy(_ComponentObject[bool]): - database: bool = True - schema: bool = True - identifier: bool = True - - -@dataclass -class Path(_ComponentObject[Optional[str]]): - database: Optional[str] - schema: Optional[str] - identifier: Optional[str] +class Path(FakeAPIObject): + database: Optional[str] = None + schema: Optional[str] = None + identifier: Optional[str] = None def __post_init__(self): # handle 
pesky jinja2.Undefined sneaking in here and messing up rende @@ -120,3 +109,22 @@ def get_lowered_part(self, key: ComponentName) -> Optional[str]: if part is not None: part = part.lower() return part + + def get_part(self, key: ComponentName) -> Optional[str]: + if key == ComponentName.Database: + return self.database + elif key == ComponentName.Schema: + return self.schema + elif key == ComponentName.Identifier: + return self.identifier + else: + raise ValueError( + 'Got a key of {}, expected one of {}' + .format(key, list(ComponentName)) + ) + + def replace_dict(self, dct: Dict[ComponentName, str]): + kwargs: Dict[str, str] = {} + for k, v in dct.items(): + kwargs[str(k)] = v + return self.replace(**kwargs) diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 61680ce0c67..43fc9117ce6 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -17,20 +17,21 @@ GLOBAL_LOGGER as logger, ) from dbt.utils import lowercase -from hologram.helpers import StrEnum -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin, StrEnum import agate from dataclasses import dataclass, field from datetime import datetime -from typing import Union, Dict, List, Optional, Any, NamedTuple, Sequence +from typing import ( + Union, Dict, List, Optional, Any, NamedTuple, Sequence, +) from dbt.clients.system import write_json @dataclass -class TimingInfo(JsonSchemaMixin): +class TimingInfo(dbtClassMixin): name: str started_at: Optional[datetime] = None completed_at: Optional[datetime] = None @@ -87,13 +88,20 @@ class FreshnessStatus(StrEnum): @dataclass -class BaseResult(JsonSchemaMixin): +class BaseResult(dbtClassMixin): status: Union[RunStatus, TestStatus, FreshnessStatus] timing: List[TimingInfo] thread_id: str execution_time: float - message: Optional[Union[str, int]] adapter_response: Dict[str, Any] + message: Optional[Union[str, int]] + + @classmethod + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + if 'message' not in data: + data['message'] = None + return data @dataclass @@ -103,7 +111,11 @@ class NodeResult(BaseResult): @dataclass class RunResult(NodeResult): - agate_table: Optional[agate.Table] = None + agate_table: Optional[agate.Table] = field( + default=None, metadata={ + 'serialize': lambda x: None, 'deserialize': lambda x: None + } + ) @property def skipped(self): @@ -111,7 +123,7 @@ def skipped(self): @dataclass -class ExecutionResult(JsonSchemaMixin): +class ExecutionResult(dbtClassMixin): results: Sequence[BaseResult] elapsed_time: float @@ -193,8 +205,8 @@ def from_execution_results( args=args ) - def write(self, path: str, omit_none=False): - write_json(path, self.to_dict(omit_none=omit_none)) + def write(self, path: str): + write_json(path, self.to_dict(options={'keep_none': True})) @dataclass @@ -253,14 +265,14 @@ class FreshnessErrorEnum(StrEnum): @dataclass -class SourceFreshnessRuntimeError(JsonSchemaMixin): +class SourceFreshnessRuntimeError(dbtClassMixin): unique_id: str error: Optional[Union[str, int]] status: FreshnessErrorEnum @dataclass -class SourceFreshnessOutput(JsonSchemaMixin): +class SourceFreshnessOutput(dbtClassMixin): unique_id: str max_loaded_at: datetime snapshotted_at: datetime @@ -374,40 +386,40 @@ def from_result(cls, base: FreshnessResult): @dataclass -class StatsItem(JsonSchemaMixin): +class StatsItem(dbtClassMixin): id: str label: str value: Primitive - description: Optional[str] include: bool + description: Optional[str] = 
None StatsDict = Dict[str, StatsItem] @dataclass -class ColumnMetadata(JsonSchemaMixin): +class ColumnMetadata(dbtClassMixin): type: str - comment: Optional[str] index: int name: str + comment: Optional[str] = None ColumnMap = Dict[str, ColumnMetadata] @dataclass -class TableMetadata(JsonSchemaMixin): +class TableMetadata(dbtClassMixin): type: str - database: Optional[str] schema: str name: str - comment: Optional[str] - owner: Optional[str] + database: Optional[str] = None + comment: Optional[str] = None + owner: Optional[str] = None @dataclass -class CatalogTable(JsonSchemaMixin, Replaceable): +class CatalogTable(dbtClassMixin, Replaceable): metadata: TableMetadata columns: ColumnMap stats: StatsDict @@ -430,12 +442,18 @@ class CatalogMetadata(BaseArtifactMetadata): @dataclass -class CatalogResults(JsonSchemaMixin): +class CatalogResults(dbtClassMixin): nodes: Dict[str, CatalogTable] sources: Dict[str, CatalogTable] - errors: Optional[List[str]] + errors: Optional[List[str]] = None _compile_results: Optional[Any] = None + def __post_serialize__(self, dct, options=None): + dct = super().__post_serialize__(dct, options=options) + if '_compile_results' in dct: + del dct['_compile_results'] + return dct + @dataclass @schema_version('catalog', 1) diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 462ecae99af..213f13717bb 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -5,8 +5,7 @@ from datetime import datetime, timedelta from typing import Optional, Union, List, Any, Dict, Type, Sequence -from hologram import JsonSchemaMixin -from hologram.helpers import StrEnum +from dbt.dataclass_schema import dbtClassMixin, StrEnum from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import WritableManifest @@ -34,16 +33,25 @@ @dataclass -class RPCParameters(JsonSchemaMixin): - timeout: Optional[float] +class RPCParameters(dbtClassMixin): task_tags: TaskTags + timeout: Optional[float] + + @classmethod + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + if 'timeout' not in data: + data['timeout'] = None + if 'task_tags' not in data: + data['task_tags'] = None + return data @dataclass class RPCExecParameters(RPCParameters): name: str sql: str - macros: Optional[str] + macros: Optional[str] = None @dataclass @@ -132,7 +140,7 @@ class StatusParameters(RPCParameters): @dataclass -class GCSettings(JsonSchemaMixin): +class GCSettings(dbtClassMixin): # start evicting the longest-ago-ended tasks here maxsize: int # start evicting all tasks before now - auto_reap_age when we have this @@ -254,7 +262,7 @@ def from_local_result( @dataclass -class ResultTable(JsonSchemaMixin): +class ResultTable(dbtClassMixin): column_names: List[str] rows: List[Any] @@ -411,21 +419,31 @@ def finished(self) -> bool: @dataclass -class TaskTiming(JsonSchemaMixin): +class TaskTiming(dbtClassMixin): state: TaskHandlerState start: Optional[datetime] end: Optional[datetime] elapsed: Optional[float] + # These ought to be defaults but superclass order doesn't + # allow that to work + @classmethod + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + for field_name in ('start', 'end', 'elapsed'): + if field_name not in data: + data[field_name] = None + return data + @dataclass class TaskRow(TaskTiming): task_id: TaskID - request_id: Union[str, int] request_source: str method: str - timeout: Optional[float] - tags: TaskTags + request_id: 
Union[str, int] + tags: TaskTags = None + timeout: Optional[float] = None @dataclass @@ -451,7 +469,7 @@ class KillResult(RemoteResult): @dataclass @schema_version('remote-manifest-result', 1) class GetManifestResult(RemoteResult): - manifest: Optional[WritableManifest] + manifest: Optional[WritableManifest] = None # this is kind of carefuly structured: BlocksManifestTasks is implied by @@ -475,6 +493,16 @@ class PollResult(RemoteResult, TaskTiming): end: Optional[datetime] elapsed: Optional[float] + # These ought to be defaults but superclass order doesn't + # allow that to work + @classmethod + def __pre_deserialize__(cls, data, options=None): + data = super().__pre_deserialize__(data, options=options) + for field_name in ('start', 'end', 'elapsed'): + if field_name not in data: + data[field_name] = None + return data + @dataclass @schema_version('poll-remote-deps-result', 1) diff --git a/core/dbt/contracts/selection.py b/core/dbt/contracts/selection.py index a5bba15c749..a93429a8dae 100644 --- a/core/dbt/contracts/selection.py +++ b/core/dbt/contracts/selection.py @@ -1,18 +1,18 @@ from dataclasses import dataclass -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from typing import List, Dict, Any, Union @dataclass -class SelectorDefinition(JsonSchemaMixin): +class SelectorDefinition(dbtClassMixin): name: str definition: Union[str, Dict[str, Any]] description: str = '' @dataclass -class SelectorFile(JsonSchemaMixin): +class SelectorFile(dbtClassMixin): selectors: List[SelectorDefinition] version: int = 2 diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index 817ac3e264b..26657925774 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -7,13 +7,12 @@ from dbt.clients.system import write_json, read_json from dbt.exceptions import ( - IncompatibleSchemaException, InternalException, RuntimeException, ) from dbt.version import __version__ from dbt.tracking import get_invocation_id -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin MacroKey = Tuple[str, str] SourceKey = Tuple[str, str] @@ -57,8 +56,10 @@ def merged(self, *args): class Writable: - def write(self, path: str, omit_none: bool = False): - write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore + def write(self, path: str): + write_json( + path, self.to_dict(options={'keep_none': True}) # type: ignore + ) class AdditionalPropertiesMixin: @@ -69,22 +70,41 @@ class AdditionalPropertiesMixin: """ ADDITIONAL_PROPERTIES = True + # This takes attributes in the dictionary that are + # not in the class definitions and puts them in an + # _extra dict in the class @classmethod - def from_dict(cls, data, validate=True): - self = super().from_dict(data=data, validate=validate) - keys = self.to_dict(validate=False, omit_none=False) + def __pre_deserialize__(cls, data, options=None): + # dir() did not work because fields with + # metadata settings are not found + # The original version of this would create the + # object first and then update extra with the + # extra keys, but that won't work here, so + # we're copying the dict so we don't insert the + # _extra in the original data. 
This also requires + # that Mashumaro actually build the '_extra' field + cls_keys = cls._get_field_names() + new_dict = {} for key, value in data.items(): - if key not in keys: - self.extra[key] = value - return self + if key not in cls_keys and key != '_extra': + if '_extra' not in new_dict: + new_dict['_extra'] = {} + new_dict['_extra'][key] = value + else: + new_dict[key] = value + data = new_dict + data = super().__pre_deserialize__(data, options=options) + return data - def to_dict(self, omit_none=True, validate=False): - data = super().to_dict(omit_none=omit_none, validate=validate) + def __post_serialize__(self, dct, options=None): + data = super().__post_serialize__(dct, options=options) data.update(self.extra) + if '_extra' in data: + del data['_extra'] return data def replace(self, **kwargs): - dct = self.to_dict(omit_none=False, validate=False) + dct = self.to_dict(options={'keep_none': True}) dct.update(kwargs) return self.from_dict(dct) @@ -135,7 +155,7 @@ def get_metadata_env() -> Dict[str, str]: @dataclasses.dataclass -class BaseArtifactMetadata(JsonSchemaMixin): +class BaseArtifactMetadata(dbtClassMixin): dbt_schema_version: str dbt_version: str = __version__ generated_at: datetime = dataclasses.field( @@ -158,7 +178,7 @@ def inner(cls: Type[VersionedSchema]): @dataclasses.dataclass -class VersionedSchema(JsonSchemaMixin): +class VersionedSchema(dbtClassMixin): dbt_schema_version: ClassVar[SchemaVersion] @classmethod @@ -180,18 +200,9 @@ class ArtifactMixin(VersionedSchema, Writable, Readable): metadata: BaseArtifactMetadata @classmethod - def from_dict( - cls: Type[T], data: Dict[str, Any], validate: bool = True - ) -> T: + def validate(cls, data): + super().validate(data) if cls.dbt_schema_version is None: raise InternalException( 'Cannot call from_dict with no schema version!' ) - - if validate: - expected = str(cls.dbt_schema_version) - found = data.get('metadata', {}).get(SCHEMA_VERSION_KEY) - if found != expected: - raise IncompatibleSchemaException(expected, found) - - return super().from_dict(data=data, validate=validate) diff --git a/core/dbt/dataclass_schema.py b/core/dbt/dataclass_schema.py new file mode 100644 index 00000000000..1382d489c50 --- /dev/null +++ b/core/dbt/dataclass_schema.py @@ -0,0 +1,170 @@ +from typing import ( + Type, ClassVar, Dict, cast, TypeVar +) +import re +from dataclasses import fields +from enum import Enum +from datetime import datetime +from dateutil.parser import parse + +from hologram import JsonSchemaMixin, FieldEncoder, ValidationError + +from mashumaro import DataClassDictMixin +from mashumaro.types import SerializableEncoder, SerializableType + + +class DateTimeSerializableEncoder(SerializableEncoder[datetime]): + @classmethod + def _serialize(cls, value: datetime) -> str: + out = value.isoformat() + # Assume UTC if timezone is missing + if value.tzinfo is None: + out = out + "Z" + return out + + @classmethod + def _deserialize(cls, value: str) -> datetime: + return ( + value if isinstance(value, datetime) else parse(cast(str, value)) + ) + + +TV = TypeVar("TV") + + +# This class pulls in both JsonSchemaMixin from Hologram and +# DataClassDictMixin from our fork of Mashumaro. The 'to_dict' +# and 'from_dict' methods come from Mashumaro. Building +# jsonschemas for every class and the 'validate' method +# come from Hologram. 
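A rough usage sketch of the combined mixin described above, assuming this branch of dbt is installed; Widget is a made-up contract class, not part of the diff:

from dataclasses import dataclass
from typing import Optional

from dbt.dataclass_schema import dbtClassMixin


@dataclass
class Widget(dbtClassMixin):
    # hypothetical contract class, for illustration only
    name: str
    comment: Optional[str] = None


w = Widget(name='a')
dct = w.to_dict()                              # mashumaro; None values dropped
full = w.to_dict(options={'keep_none': True})  # keeps {'comment': None}
Widget.validate(dct)                           # hologram jsonschema validation
again = Widget.from_dict(dct)                  # mashumaro deserialization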
+class dbtClassMixin(DataClassDictMixin, JsonSchemaMixin): + """Mixin which adds methods to generate a JSON schema and + convert to and from JSON encodable dicts with validation + against the schema + """ + + _serializable_encoders: ClassVar[Dict[str, SerializableEncoder]] = { + 'datetime.datetime': DateTimeSerializableEncoder(), + } + _hyphenated: ClassVar[bool] = False + ADDITIONAL_PROPERTIES: ClassVar[bool] = False + + # This is called by the mashumaro to_dict in order to handle + # nested classes. + # Munges the dict that's returned. + def __post_serialize__(self, dct, options=None): + keep_none = False + if options and 'keep_none' in options and options['keep_none']: + keep_none = True + if not keep_none: # remove attributes that are None + new_dict = {k: v for k, v in dct.items() if v is not None} + dct = new_dict + + if self._hyphenated: + new_dict = {} + for key in dct: + if '_' in key: + new_key = key.replace('_', '-') + new_dict[new_key] = dct[key] + else: + new_dict[key] = dct[key] + dct = new_dict + + return dct + + # This is called by the mashumaro _from_dict method, before + # performing the conversion to a dict + @classmethod + def __pre_deserialize__(cls, data, options=None): + if cls._hyphenated: + new_dict = {} + for key in data: + if '-' in key: + new_key = key.replace('-', '_') + new_dict[new_key] = data[key] + else: + new_dict[key] = data[key] + data = new_dict + return data + + # This is used in the hologram._encode_field method, which calls + # a 'to_dict' method which does not have the same parameters in + # hologram and in mashumaro. + def _local_to_dict(self, **kwargs): + args = {} + if 'omit_none' in kwargs and kwargs['omit_none'] is False: + args['options'] = {'keep_none': True} + return self.to_dict(**args) + + +class ValidatedStringMixin(str, SerializableType): + ValidationRegex = '' + + @classmethod + def _deserialize(cls, value: str) -> 'ValidatedStringMixin': + cls.validate(value) + return ValidatedStringMixin(value) + + def _serialize(self) -> str: + return str(self) + + @classmethod + def validate(cls, value): + res = re.match(cls.ValidationRegex, value) + + if res is None: + raise ValidationError(f"Invalid value: {value}") # TODO + + +# These classes must be in this order or it doesn't work +class StrEnum(str, SerializableType, Enum): + def __str__(self): + return self.value + + # https://docs.python.org/3.6/library/enum.html#using-automatic-values + def _generate_next_value_(name, *_): + return name + + def _serialize(self) -> str: + return self.value + + @classmethod + def _deserialize(cls, value: str): + return cls(value) + + +class HyphenatedDbtClassMixin(dbtClassMixin): + # used by from_dict/to_dict + _hyphenated: ClassVar[bool] = True + + # used by jsonschema validation, _get_fields + @classmethod + def field_mapping(cls): + result = {} + for field in fields(cls): + skip = field.metadata.get("preserve_underscore") + if skip: + continue + + if "_" in field.name: + result[field.name] = field.name.replace("_", "-") + return result + + +class ExtensibleDbtClassMixin(dbtClassMixin): + ADDITIONAL_PROPERTIES = True + + +# This is used by Hologram in jsonschema validation +def register_pattern(base_type: Type, pattern: str) -> None: + """base_type should be a typing.NewType that should always have the given + regex pattern. That means that its underlying type ('__supertype__') had + better be a str! 
+ """ + + class PatternEncoder(FieldEncoder): + @property + def json_schema(self): + return {"type": "string", "pattern": pattern} + + dbtClassMixin.register_field_encoders({base_type: PatternEncoder()}) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 1a1134e3a26..82cd9267bc4 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -7,14 +7,14 @@ from dbt import flags from dbt.ui import line_wrap_message -import hologram +import dbt.dataclass_schema def validator_error_message(exc): - """Given a hologram.ValidationError (which is basically a + """Given a dbt.dataclass_schema.ValidationError (which is basically a jsonschema.ValidationError), return the relevant parts as a string """ - if not isinstance(exc, hologram.ValidationError): + if not isinstance(exc, dbt.dataclass_schema.ValidationError): return str(exc) path = "[%s]" % "][".join(map(repr, exc.relative_path)) return 'at path {}: {}'.format(path, exc.message) diff --git a/core/dbt/graph/cli.py b/core/dbt/graph/cli.py index c39c142bba1..93da3306dff 100644 --- a/core/dbt/graph/cli.py +++ b/core/dbt/graph/cli.py @@ -1,6 +1,6 @@ # special support for CLI argument parsing. import itertools -import yaml +from dbt.clients.yaml_helper import yaml, Loader, Dumper # noqa: F401 from typing import ( Dict, List, Optional, Tuple, Any, Union @@ -236,7 +236,7 @@ def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec: ) # if key isn't a valid method name, this will raise - base = SelectionCriteria.from_dict(definition, dct) + base = SelectionCriteria.selection_criteria_from_dict(definition, dct) if diff_arg is None: return base else: diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index b43efce45ce..0320f594e46 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type, Optional -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum from .graph import UniqueId diff --git a/core/dbt/graph/selector_spec.py b/core/dbt/graph/selector_spec.py index 417696b2056..50f94c58538 100644 --- a/core/dbt/graph/selector_spec.py +++ b/core/dbt/graph/selector_spec.py @@ -102,7 +102,9 @@ def parse_method( return method_name, method_arguments @classmethod - def from_dict(cls, raw: Any, dct: Dict[str, Any]) -> 'SelectionCriteria': + def selection_criteria_from_dict( + cls, raw: Any, dct: Dict[str, Any] + ) -> 'SelectionCriteria': if 'value' not in dct: raise RuntimeException( f'Invalid node spec "{raw}" - no search value!' @@ -150,7 +152,7 @@ def from_single_spec(cls, raw: str) -> 'SelectionCriteria': # bad spec! 
raise RuntimeException(f'Invalid selector spec "{raw}"') - return cls.from_dict(raw, result.groupdict()) + return cls.selection_criteria_from_dict(raw, result.groupdict()) class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta): diff --git a/core/dbt/helper_types.py b/core/dbt/helper_types.py index ca69e019864..1e38b971279 100644 --- a/core/dbt/helper_types.py +++ b/core/dbt/helper_types.py @@ -2,14 +2,27 @@ from dataclasses import dataclass from datetime import timedelta from pathlib import Path -from typing import NewType, Tuple, AbstractSet +from typing import Tuple, AbstractSet, Union -from hologram import ( - FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError +from dbt.dataclass_schema import ( + dbtClassMixin, ValidationError, StrEnum, ) -from hologram.helpers import StrEnum +from hologram import FieldEncoder, JsonDict +from mashumaro.types import SerializableType -Port = NewType('Port', int) + +class Port(int, SerializableType): + @classmethod + def _deserialize(cls, value: Union[int, str]) -> 'Port': + try: + value = int(value) + except ValueError: + raise ValidationError(f'Cannot encode {value} into port number') + + return Port(value) + + def _serialize(self) -> int: + return self class PortEncoder(FieldEncoder): @@ -66,12 +79,12 @@ def __eq__(self, other): @dataclass -class NoValue(JsonSchemaMixin): +class NoValue(dbtClassMixin): """Sometimes, you want a way to say none that isn't None""" novalue: NVEnum = NVEnum.novalue -JsonSchemaMixin.register_field_encoders({ +dbtClassMixin.register_field_encoders({ Port: PortEncoder(), timedelta: TimeDeltaFieldEncoder(), Path: PathEncoder(), diff --git a/core/dbt/hooks.py b/core/dbt/hooks.py index 26403226e3b..0603f2adb7a 100644 --- a/core/dbt/hooks.py +++ b/core/dbt/hooks.py @@ -1,4 +1,4 @@ -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum import json from typing import Union, Dict, Any diff --git a/core/dbt/logger.py b/core/dbt/logger.py index ccaad7edc63..1916f49020e 100644 --- a/core/dbt/logger.py +++ b/core/dbt/logger.py @@ -13,7 +13,7 @@ import colorama import logbook -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin # Colorama needs some help on windows because we're using logger.info # intead of print(). 
If the Windows env doesn't have a TERM var set, @@ -45,11 +45,10 @@ ExceptionInformation = str -Extras = Dict[str, Any] @dataclass -class LogMessage(JsonSchemaMixin): +class LogMessage(dbtClassMixin): timestamp: datetime message: str channel: str @@ -57,7 +56,7 @@ class LogMessage(JsonSchemaMixin): levelname: str thread_name: str process: int - extra: Optional[Extras] = None + extra: Optional[Dict[str, Any]] = None exc_info: Optional[ExceptionInformation] = None @classmethod @@ -215,7 +214,7 @@ def process(self, record): class TimingProcessor(logbook.Processor): - def __init__(self, timing_info: Optional[JsonSchemaMixin] = None): + def __init__(self, timing_info: Optional[dbtClassMixin] = None): self.timing_info = timing_info super().__init__() diff --git a/core/dbt/node_types.py b/core/dbt/node_types.py index e6502ca7030..b21d738c4c5 100644 --- a/core/dbt/node_types.py +++ b/core/dbt/node_types.py @@ -1,6 +1,6 @@ from typing import List -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum class NodeType(StrEnum): diff --git a/core/dbt/parser/analysis.py b/core/dbt/parser/analysis.py index 8d13d99368b..3a9ae0cc894 100644 --- a/core/dbt/parser/analysis.py +++ b/core/dbt/parser/analysis.py @@ -13,7 +13,9 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode: - return ParsedAnalysisNode.from_dict(dct, validate=validate) + if validate: + ParsedAnalysisNode.validate(dct) + return ParsedAnalysisNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index e4c4f6d9442..7f603d02d61 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -5,7 +5,7 @@ List, Dict, Any, Iterable, Generic, TypeVar ) -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError from dbt import utils from dbt.clients.jinja import MacroGenerator @@ -23,7 +23,7 @@ from dbt.contracts.files import ( SourceFile, FilePath, FileHash ) -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import MacroManifest from dbt.contracts.graph.parsed import HasUniqueID from dbt.contracts.graph.unparsed import UnparsedNode from dbt.exceptions import ( @@ -99,7 +99,7 @@ def __init__( results: ParseResult, project: Project, root_project: RuntimeConfig, - macro_manifest: Manifest, + macro_manifest: MacroManifest, ) -> None: super().__init__(results, project) self.root_project = root_project @@ -108,9 +108,10 @@ def __init__( class RelationUpdate: def __init__( - self, config: RuntimeConfig, manifest: Manifest, component: str + self, config: RuntimeConfig, macro_manifest: MacroManifest, + component: str ) -> None: - macro = manifest.find_generate_macro_by_name( + macro = macro_manifest.find_generate_macro_by_name( component=component, root_project_name=config.project_name, ) @@ -120,7 +121,7 @@ def __init__( ) root_context = generate_generate_component_name_macro( - macro, config, manifest + macro, config, macro_manifest ) self.updater = MacroGenerator(macro, root_context) self.component = component @@ -144,18 +145,21 @@ def __init__( results: ParseResult, project: Project, root_project: RuntimeConfig, - macro_manifest: Manifest, + macro_manifest: MacroManifest, ) -> None: super().__init__(results, project, root_project, macro_manifest) self._update_node_database = RelationUpdate( - manifest=macro_manifest, config=root_project, component='database' + macro_manifest=macro_manifest, config=root_project, + component='database' ) 
self._update_node_schema = RelationUpdate( - manifest=macro_manifest, config=root_project, component='schema' + macro_manifest=macro_manifest, config=root_project, + component='schema' ) self._update_node_alias = RelationUpdate( - manifest=macro_manifest, config=root_project, component='alias' + macro_manifest=macro_manifest, config=root_project, + component='alias' ) @abc.abstractclassmethod @@ -252,7 +256,7 @@ def _create_parsetime_node( } dct.update(kwargs) try: - return self.parse_from_dict(dct) + return self.parse_from_dict(dct, validate=True) except ValidationError as exc: msg = validator_error_message(exc) # this is a bit silly, but build an UnparsedNode just for error @@ -275,20 +279,24 @@ def _context_for( def render_with_context( self, parsed_node: IntermediateNode, config: ContextConfig ) -> None: - """Given the parsed node and a ContextConfig to use during parsing, - render the node's sql wtih macro capture enabled. + # Given the parsed node and a ContextConfig to use during parsing, + # render the node's sql with macro capture enabled. + # Note: this mutates the config object when config calls are rendered. - Note: this mutates the config object when config() calls are rendered. - """ # during parsing, we don't have a connection, but we might need one, so # we have to acquire it. with get_adapter(self.root_project).connection_for(parsed_node): context = self._context_for(parsed_node, config) + # this goes through the process of rendering, but just throws away + # the rendered result; the point is to capture which macros are called. get_rendered( parsed_node.raw_sql, context, parsed_node, capture_macros=True ) + # This is taking the original config for the node, converting it to a dict, + # updating the config with new config passed in, then re-creating the + # config from the dict in the node.
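Roughly, per the comment above, the merge looks like this (a sketch of the idea, not the exact body of update_parsed_node_config):

# sketch: merge new config values into the node's existing config
final_config_dict = parsed_node.config.to_dict()
final_config_dict.update(config_dict)
# re-create the config object from the merged dict and put it back on the node
parsed_node.config = parsed_node.config.from_dict(final_config_dict)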
def update_parsed_node_config( self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] ) -> None: diff --git a/core/dbt/parser/data_test.py b/core/dbt/parser/data_test.py index 52a9d3dbd42..53f8d10eab8 100644 --- a/core/dbt/parser/data_test.py +++ b/core/dbt/parser/data_test.py @@ -12,7 +12,9 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> ParsedDataTestNode: - return ParsedDataTestNode.from_dict(dct, validate=validate) + if validate: + ParsedDataTestNode.validate(dct) + return ParsedDataTestNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/hooks.py b/core/dbt/parser/hooks.py index 9edc5ea2d15..dfcd6431569 100644 --- a/core/dbt/parser/hooks.py +++ b/core/dbt/parser/hooks.py @@ -79,7 +79,9 @@ def get_paths(self) -> List[FilePath]: return [path] def parse_from_dict(self, dct, validate=True) -> ParsedHookNode: - return ParsedHookNode.from_dict(dct, validate=validate) + if validate: + ParsedHookNode.validate(dct) + return ParsedHookNode.from_dict(dct) @classmethod def get_compiled_path(cls, block: HookBlock): diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index fd0e7b3e572..df7768c79ed 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -23,7 +23,9 @@ from dbt.context.docs import generate_runtime_docs from dbt.contracts.files import FilePath, FileHash from dbt.contracts.graph.compiled import ManifestNode -from dbt.contracts.graph.manifest import Manifest, Disabled +from dbt.contracts.graph.manifest import ( + Manifest, MacroManifest, AnyManifest, Disabled +) from dbt.contracts.graph.parsed import ( ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure ) @@ -51,7 +53,7 @@ from dbt.ui import warning_tag from dbt.version import __version__ -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle' PARSING_STATE = DbtProcessState('parsing') @@ -59,14 +61,14 @@ @dataclass -class ParserInfo(JsonSchemaMixin): +class ParserInfo(dbtClassMixin): parser: str elapsed: float path_count: int = 0 @dataclass -class ProjectLoaderInfo(JsonSchemaMixin): +class ProjectLoaderInfo(dbtClassMixin): project_name: str elapsed: float parsers: List[ParserInfo] @@ -74,7 +76,7 @@ class ProjectLoaderInfo(JsonSchemaMixin): @dataclass -class ManifestLoaderInfo(JsonSchemaMixin, Writable): +class ManifestLoaderInfo(dbtClassMixin, Writable): path_count: int = 0 is_partial_parse_enabled: Optional[bool] = None parse_project_elapsed: Optional[float] = None @@ -137,16 +139,19 @@ def __init__( self, root_project: RuntimeConfig, all_projects: Mapping[str, Project], - macro_hook: Optional[Callable[[Manifest], Any]] = None, + macro_hook: Optional[Callable[[AnyManifest], Any]] = None, ) -> None: self.root_project: RuntimeConfig = root_project self.all_projects: Mapping[str, Project] = all_projects - self.macro_hook: Callable[[Manifest], Any] + self.macro_hook: Callable[[AnyManifest], Any] if macro_hook is None: self.macro_hook = lambda m: None else: self.macro_hook = macro_hook + # results holds all of the nodes created by parsing, + # in dictionaries: nodes, sources, docs, macros, exposures, + # macro_patches, patches, source_patches, files, etc self.results: ParseResult = make_parse_result( root_project, all_projects, ) @@ -210,7 +215,7 @@ def _get_file(self, path: FilePath, parser: BaseParser) -> FileBlock: def parse_project( self, project: Project, - macro_manifest: Manifest, + macro_manifest: MacroManifest, 
old_results: Optional[ParseResult], ) -> None: parsers: List[Parser] = [] @@ -252,7 +257,7 @@ def parse_project( self._perf_info.path_count + total_path_count ) - def load_only_macros(self) -> Manifest: + def load_only_macros(self) -> MacroManifest: old_results = self.read_parse_results() for project in self.all_projects.values(): @@ -261,17 +266,20 @@ def load_only_macros(self) -> Manifest: self.parse_with_cache(path, parser, old_results) # make a manifest with just the macros to get the context - macro_manifest = Manifest.from_macros( + macro_manifest = MacroManifest( macros=self.results.macros, files=self.results.files ) self.macro_hook(macro_manifest) return macro_manifest - def load(self, macro_manifest: Manifest): + # This is where the main action happens + def load(self, macro_manifest: MacroManifest): + # if partial parse is enabled, load old results old_results = self.read_parse_results() if old_results is not None: logger.debug('Got an acceptable cached parse result') + # store the macros & files from the adapter macro manifest self.results.macros.update(macro_manifest.macros) self.results.files.update(macro_manifest.files) @@ -423,8 +431,8 @@ def create_manifest(self) -> Manifest: def load_all( cls, root_config: RuntimeConfig, - macro_manifest: Manifest, - macro_hook: Callable[[Manifest], Any], + macro_manifest: MacroManifest, + macro_hook: Callable[[AnyManifest], Any], ) -> Manifest: with PARSING_STATE: start_load_all = time.perf_counter() @@ -449,8 +457,8 @@ def load_all( def load_macros( cls, root_config: RuntimeConfig, - macro_hook: Callable[[Manifest], Any], - ) -> Manifest: + macro_hook: Callable[[AnyManifest], Any], + ) -> MacroManifest: with PARSING_STATE: projects = root_config.load_dependencies() loader = cls(root_config, projects, macro_hook) @@ -841,14 +849,14 @@ def process_node( def load_macro_manifest( config: RuntimeConfig, - macro_hook: Callable[[Manifest], Any], -) -> Manifest: + macro_hook: Callable[[AnyManifest], Any], +) -> MacroManifest: return ManifestLoader.load_macros(config, macro_hook) def load_manifest( config: RuntimeConfig, - macro_manifest: Manifest, - macro_hook: Callable[[Manifest], Any], + macro_manifest: MacroManifest, + macro_hook: Callable[[AnyManifest], Any], ) -> Manifest: return ManifestLoader.load_all(config, macro_manifest, macro_hook) diff --git a/core/dbt/parser/models.py b/core/dbt/parser/models.py index 339004d267e..e7c04838511 100644 --- a/core/dbt/parser/models.py +++ b/core/dbt/parser/models.py @@ -11,7 +11,9 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> ParsedModelNode: - return ParsedModelNode.from_dict(dct, validate=validate) + if validate: + ParsedModelNode.validate(dct) + return ParsedModelNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/results.py b/core/dbt/parser/results.py index a1b5def9df2..f2fc3b9d5f1 100644 --- a/core/dbt/parser/results.py +++ b/core/dbt/parser/results.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import TypeVar, MutableMapping, Mapping, Union, List -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from dbt.contracts.files import RemoteFile, FileHash, SourceFile from dbt.contracts.graph.compiled import CompileResultNode @@ -62,7 +62,7 @@ def dict_field(): @dataclass -class ParseResult(JsonSchemaMixin, Writable, Replaceable): +class ParseResult(dbtClassMixin, Writable, Replaceable): vars_hash: FileHash profile_hash: FileHash project_hashes: MutableMapping[str, 
FileHash] diff --git a/core/dbt/parser/rpc.py b/core/dbt/parser/rpc.py index 1b03663732f..9bf54c68f89 100644 --- a/core/dbt/parser/rpc.py +++ b/core/dbt/parser/rpc.py @@ -26,7 +26,9 @@ def get_paths(self): return [] def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode: - return ParsedRPCNode.from_dict(dct, validate=validate) + if validate: + ParsedRPCNode.validate(dct) + return ParsedRPCNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/schema_test_builders.py b/core/dbt/parser/schema_test_builders.py index 2832aaeeeef..b6a2d31aa62 100644 --- a/core/dbt/parser/schema_test_builders.py +++ b/core/dbt/parser/schema_test_builders.py @@ -179,6 +179,7 @@ class TestBuilder(Generic[Testable]): - or it may not be namespaced (test) """ + # The 'test_name' is used to find the 'macro' that implements the test TEST_NAME_PATTERN = re.compile( r'((?P([a-zA-Z_][0-9a-zA-Z_]*))\.)?' r'(?P([a-zA-Z_][0-9a-zA-Z_]*))' @@ -302,6 +303,8 @@ def get_test_name(self) -> Tuple[str, str]: name = '{}_{}'.format(self.namespace, name) return get_nice_schema_test_name(name, self.target.name, self.args) + # this is the 'raw_sql' that's used in 'render_update' and execution + # of the test macro def build_raw_sql(self) -> str: return ( "{{{{ config(severity='{severity}') }}}}" diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 128842441e4..f176a8e1731 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -6,9 +6,9 @@ Iterable, Dict, Any, Union, List, Optional, Generic, TypeVar, Type ) -from hologram import ValidationError, JsonSchemaMixin +from dbt.dataclass_schema import ValidationError, dbtClassMixin -from dbt.adapters.factory import get_adapter +from dbt.adapters.factory import get_adapter, get_adapter_package_names from dbt.clients.jinja import get_rendered, add_rendered_test_kwargs from dbt.clients.yaml_helper import load_yaml_text from dbt.config.renderer import SchemaYamlRenderer @@ -20,7 +20,10 @@ ) from dbt.context.configured import generate_schema_yml from dbt.context.target import generate_target_context -from dbt.context.providers import generate_parse_exposure +from dbt.context.providers import ( + generate_parse_exposure, generate_test_context +) +from dbt.context.macro_resolver import MacroResolver from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import SourceFile from dbt.contracts.graph.model_config import SourceConfig @@ -173,6 +176,15 @@ def __init__( self.raw_renderer = SchemaYamlRenderer(ctx) + internal_package_names = get_adapter_package_names( + self.root_project.credentials.type + ) + self.macro_resolver = MacroResolver( + self.macro_manifest.macros, + self.root_project.project_name, + internal_package_names + ) + @classmethod def get_compiled_path(cls, block: FileBlock) -> str: # should this raise an error? 
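A hedged sketch of how the resolver built above is consulted later in this diff (in render_test_update in schemas.py): the test name is turned into the unique id of the macro that implements it, and the two built-in shortcut macros are special-cased. The node and builder arguments are stand-ins for the diff's own objects:

def resolve_test_macro(macro_resolver, node, builder) -> str:
    # e.g. builder.name == 'unique' resolves to 'macro.dbt.test_unique'
    macro_unique_id = macro_resolver.get_macro_id(
        node.package_name, 'test_' + builder.name)
    # record the dependency so rendering only needs this macro in context
    node.depends_on.add_macro(macro_unique_id)
    return macro_unique_id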
@@ -202,9 +214,11 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode: - return ParsedSchemaTestNode.from_dict(dct, validate=validate) + if validate: + ParsedSchemaTestNode.validate(dct) + return ParsedSchemaTestNode.from_dict(dct) - def _parse_format_version( + def _check_format_version( self, yaml: YamlBlock ) -> None: path = yaml.path.relative_path @@ -374,7 +388,8 @@ def create_test_node( 'checksum': FileHash.empty().to_dict(), } try: - return self.parse_from_dict(dct) + ParsedSchemaTestNode.validate(dct) + return ParsedSchemaTestNode.from_dict(dct) except ValidationError as exc: msg = validator_error_message(exc) # this is a bit silly, but build an UnparsedNode just for error @@ -387,6 +402,7 @@ def create_test_node( ) raise CompilationException(msg, node=node) from exc + # lots of time spent in this method def _parse_generic_test( self, target: Testable, @@ -425,6 +441,7 @@ def _parse_generic_test( # is not necessarily this package's name fqn = self.get_fqn(fqn_path, builder.fqn_name) + # this is the config that is used in render_update config = self.initial_config(fqn) metadata = { @@ -447,9 +464,53 @@ def _parse_generic_test( column_name=column_name, test_metadata=metadata, ) - self.render_update(node, config) + self.render_test_update(node, config, builder) + return node + # This does special shortcut processing for the two + # most common internal macros, not_null and unique, + # which avoids the jinja rendering to resolve config + # and variables, etc, which might be in the macro. + # In the future we will look at generalizing this + # more to handle additional macros or to use static + # parsing to avoid jinja overhead. + def render_test_update(self, node, config, builder): + macro_unique_id = self.macro_resolver.get_macro_id( + node.package_name, 'test_' + builder.name) + # Add the depends_on here so we can limit the macros added + # to the context in rendering processing + node.depends_on.add_macro(macro_unique_id) + if (macro_unique_id in + ['macro.dbt.test_not_null', 'macro.dbt.test_unique']): + self.update_parsed_node(node, config) + node.unrendered_config['severity'] = builder.severity() + node.config['severity'] = builder.severity() + # source node tests are processed at patch_source time + if isinstance(builder.target, UnpatchedSourceDefinition): + sources = [builder.target.fqn[-2], builder.target.fqn[-1]] + node.sources.append(sources) + else: # all other nodes + node.refs.append([builder.target.name]) + else: + try: + # make a base context that doesn't have the magic kwargs field + context = generate_test_context( + node, self.root_project, self.macro_manifest, config, + self.macro_resolver, + ) + # update with rendered test kwargs (which collects any refs) + add_rendered_test_kwargs(context, node, capture_macros=True) + # the parsed node is not rendered in the native context. 
+ get_rendered( + node.raw_sql, context, node, capture_macros=True + ) + self.update_parsed_node(node, config) + except ValidationError as exc: + # we got a ValidationError - probably bad types in config() + msg = validator_error_message(exc) + raise CompilationException(msg, node=node) from exc + def parse_source_test( self, target: UnpatchedSourceDefinition, @@ -561,10 +622,13 @@ def parse_exposures(self, block: YamlBlock) -> None: def parse_file(self, block: FileBlock) -> None: dct = self._yaml_from_file(block.file) - # mark the file as seen, even if there are no macros in it + + # mark the file as seen, in ParseResult.files self.results.get_file(block.file) + if dct: try: + # This does a deep_map to check for circular references dct = self.raw_renderer.render_data(dct) except CompilationException as exc: raise CompilationException( @@ -572,28 +636,58 @@ def parse_file(self, block: FileBlock) -> None: f'project {self.project.project_name}: {exc}' ) from exc + # contains the FileBlock and the data (dictionary) yaml_block = YamlBlock.from_file_block(block, dct) - self._parse_format_version(yaml_block) + # checks version + self._check_format_version(yaml_block) parser: YamlDocsReader - for key in NodeType.documentable(): - plural = key.pluralize() - if key == NodeType.Source: - parser = SourceParser(self, yaml_block, plural) - elif key == NodeType.Macro: - parser = MacroPatchParser(self, yaml_block, plural) - elif key == NodeType.Analysis: - parser = AnalysisPatchParser(self, yaml_block, plural) - elif key == NodeType.Exposure: - # handle exposures separately, but they are - # technically still "documentable" - continue - else: - parser = TestablePatchParser(self, yaml_block, plural) + + # There are 7 kinds of parsers: + # Model, Seed, Snapshot, Source, Macro, Analysis, Exposures + + # NonSourceParser.parse(), TestablePatchParser is a variety of + # NodePatchParser + if 'models' in dct: + parser = TestablePatchParser(self, yaml_block, 'models') for test_block in parser.parse(): self.parse_tests(test_block) - self.parse_exposures(yaml_block) + + # NonSourceParser.parse() + if 'seeds' in dct: + parser = TestablePatchParser(self, yaml_block, 'seeds') + for test_block in parser.parse(): + self.parse_tests(test_block) + + # NonSourceParser.parse() + if 'snapshots' in dct: + parser = TestablePatchParser(self, yaml_block, 'snapshots') + for test_block in parser.parse(): + self.parse_tests(test_block) + + # This parser uses SourceParser.parse() which doesn't return + # any test blocks. Source tests are handled at a later point + # in the process. 
+ if 'sources' in dct: + parser = SourceParser(self, yaml_block, 'sources') + parser.parse() + + # NonSourceParser.parse() + if 'macros' in dct: + parser = MacroPatchParser(self, yaml_block, 'macros') + for test_block in parser.parse(): + self.parse_tests(test_block) + + # NonSourceParser.parse() + if 'analyses' in dct: + parser = AnalysisPatchParser(self, yaml_block, 'analyses') + for test_block in parser.parse(): + self.parse_tests(test_block) + + # parse exposures + if 'exposures' in dct: + self.parse_exposures(yaml_block) Parsed = TypeVar( @@ -610,11 +704,14 @@ def parse_file(self, block: FileBlock) -> None: ) +# abstract base class (ABCMeta) class YamlReader(metaclass=ABCMeta): def __init__( self, schema_parser: SchemaParser, yaml: YamlBlock, key: str ) -> None: self.schema_parser = schema_parser + # key: models, seeds, snapshots, sources, macros, + # analyses, exposures self.key = key self.yaml = yaml @@ -634,6 +731,9 @@ def default_database(self): def root_project(self): return self.schema_parser.root_project + # for the different schema subparsers ('models', 'source', etc) + # get the list of dicts pointed to by the key in the yaml config, + # ensure that the dicts have string keys def get_key_dicts(self) -> Iterable[Dict[str, Any]]: data = self.yaml.data.get(self.key, []) if not isinstance(data, list): @@ -643,7 +743,10 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]: ) path = self.yaml.path.original_file_path + # for each dict in the data (which is a list of dicts) for entry in data: + # check that entry is a dict and that all dict values + # are strings if coerce_dict_str(entry) is not None: yield entry else: @@ -659,19 +762,22 @@ def parse(self) -> List[TestBlock]: raise NotImplementedError('parse is abstract') -T = TypeVar('T', bound=JsonSchemaMixin) +T = TypeVar('T', bound=dbtClassMixin) class SourceParser(YamlDocsReader): def _target_from_dict(self, cls: Type[T], data: Dict[str, Any]) -> T: path = self.yaml.path.original_file_path try: + cls.validate(data) return cls.from_dict(data) except (ValidationError, JSONValidationException) as exc: msg = error_context(path, self.key, data, exc) raise CompilationException(msg) from exc + # the other parse method returns TestBlocks. This one doesn't. 
def parse(self) -> List[TestBlock]: + # get a verified list of dicts for the key handled by this parser for data in self.get_key_dicts(): data = self.project.credentials.translate_aliases( data, recurse=True @@ -714,10 +820,12 @@ def add_source_definitions(self, source: UnparsedSourceDefinition) -> None: self.results.add_source(self.yaml.file, result) +# This class has three main subclasses: TestablePatchParser (models, +# seeds, snapshots), MacroPatchParser, and AnalysisPatchParser class NonSourceParser(YamlDocsReader, Generic[NonSourceTarget, Parsed]): @abstractmethod def _target_type(self) -> Type[NonSourceTarget]: - raise NotImplementedError('_unsafe_from_dict not implemented') + raise NotImplementedError('_target_type not implemented') @abstractmethod def get_block(self, node: NonSourceTarget) -> TargetBlock: @@ -732,33 +840,55 @@ def parse_patch( def parse(self) -> List[TestBlock]: node: NonSourceTarget test_blocks: List[TestBlock] = [] + # get list of 'node' objects + # UnparsedNodeUpdate (TestablePatchParser, models, seeds, snapshots) + # = HasColumnTests, HasTests + # UnparsedAnalysisUpdate (UnparsedAnalysisParser, analyses) + # = HasColumnDocs, HasDocs + # UnparsedMacroUpdate (MacroPatchParser, 'macros') + # = HasDocs + # correspond to this parser's 'key' for node in self.get_unparsed_target(): + # node_block is a TargetBlock (Macro or Analysis) + # or a TestBlock (all of the others) node_block = self.get_block(node) if isinstance(node_block, TestBlock): + # TestablePatchParser = models, seeds, snapshots test_blocks.append(node_block) if isinstance(node, (HasColumnDocs, HasColumnTests)): + # UnparsedNodeUpdate and UnparsedAnalysisUpdate refs: ParserRef = ParserRef.from_target(node) else: refs = ParserRef() + # This adds the node_block to self.results (a ParseResult + # object) as a ParsedNodePatch or ParsedMacroPatch self.parse_patch(node_block, refs) return test_blocks def get_unparsed_target(self) -> Iterable[NonSourceTarget]: path = self.yaml.path.original_file_path - for data in self.get_key_dicts(): + # get verified list of dicts for the 'key' that this + # parser handles + key_dicts = self.get_key_dicts() + for data in key_dicts: + # add extra data to each dict. 
This updates the dicts + # in the parser yaml data.update({ 'original_file_path': path, 'yaml_key': self.key, 'package_name': self.project.project_name, }) try: - model = self._target_type().from_dict(data) + # target_type: UnparsedNodeUpdate, UnparsedAnalysisUpdate, + # or UnparsedMacroUpdate + self._target_type().validate(data) + node = self._target_type().from_dict(data) except (ValidationError, JSONValidationException) as exc: msg = error_context(path, self.key, data, exc) raise CompilationException(msg) from exc else: - yield model + yield node class NodePatchParser( @@ -866,6 +996,7 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> ParsedExposure: def parse(self) -> Iterable[ParsedExposure]: for data in self.get_key_dicts(): try: + UnparsedExposure.validate(data) unparsed = UnparsedExposure.from_dict(data) except (ValidationError, JSONValidationException) as exc: msg = error_context(self.yaml.path, self.key, data, exc) diff --git a/core/dbt/parser/seeds.py b/core/dbt/parser/seeds.py index c3b514432f4..ae8283f24f2 100644 --- a/core/dbt/parser/seeds.py +++ b/core/dbt/parser/seeds.py @@ -13,7 +13,9 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode: - return ParsedSeedNode.from_dict(dct, validate=validate) + if validate: + ParsedSeedNode.validate(dct) + return ParsedSeedNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/snapshots.py b/core/dbt/parser/snapshots.py index 56dfa296b30..16e5c52ab0d 100644 --- a/core/dbt/parser/snapshots.py +++ b/core/dbt/parser/snapshots.py @@ -1,7 +1,7 @@ import os from typing import List -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError from dbt.contracts.graph.parsed import ( IntermediateSnapshotNode, ParsedSnapshotNode @@ -26,7 +26,9 @@ def get_paths(self): ) def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode: - return IntermediateSnapshotNode.from_dict(dct, validate=validate) + if validate: + IntermediateSnapshotNode.validate(dct) + return IntermediateSnapshotNode.from_dict(dct) @property def resource_type(self) -> NodeType: diff --git a/core/dbt/parser/sources.py b/core/dbt/parser/sources.py index 6dd9ab739f8..30f018f6029 100644 --- a/core/dbt/parser/sources.py +++ b/core/dbt/parser/sources.py @@ -6,7 +6,7 @@ Set, ) from dbt.config import RuntimeConfig -from dbt.contracts.graph.manifest import Manifest, SourceKey +from dbt.contracts.graph.manifest import MacroManifest, SourceKey from dbt.contracts.graph.parsed import ( UnpatchedSourceDefinition, ParsedSourceDefinition, @@ -33,7 +33,7 @@ def __init__( ) -> None: self.results = results self.root_project = root_project - self.macro_manifest = Manifest.from_macros( + self.macro_manifest = MacroManifest( macros=self.results.macros, files=self.results.files ) diff --git a/core/dbt/perf_utils.py b/core/dbt/perf_utils.py index 782859a88a3..e73941a918e 100644 --- a/core/dbt/perf_utils.py +++ b/core/dbt/perf_utils.py @@ -3,7 +3,7 @@ """ from dbt.adapters.factory import get_adapter from dbt.parser.manifest import load_manifest -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import Manifest, MacroManifest from dbt.config import RuntimeConfig @@ -23,10 +23,10 @@ def get_full_manifest( config.clear_dependencies() adapter.clear_macro_manifest() - internal: Manifest = adapter.load_macro_manifest() + macro_manifest: MacroManifest = adapter.load_macro_manifest() return load_manifest( config, - internal, + macro_manifest, 
adapter.connections.set_query_header, ) diff --git a/core/dbt/rpc/logger.py b/core/dbt/rpc/logger.py index 047747293bb..f5d6b86add4 100644 --- a/core/dbt/rpc/logger.py +++ b/core/dbt/rpc/logger.py @@ -1,8 +1,7 @@ import logbook import logbook.queues from jsonrpc.exceptions import JSONRPCError -from hologram import JsonSchemaMixin -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -25,8 +24,11 @@ class QueueMessageType(StrEnum): terminating = frozenset((Error, Result, Timeout)) +# This class was subclassed from JsonSchemaMixin, but it +# doesn't appear to be necessary, and Mashumaro does not +# handle logbook.LogRecord @dataclass -class QueueMessage(JsonSchemaMixin): +class QueueMessage: message_type: QueueMessageType diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py index 74563305107..5e9ffdc1707 100644 --- a/core/dbt/rpc/method.py +++ b/core/dbt/rpc/method.py @@ -3,7 +3,7 @@ from copy import deepcopy from typing import List, Optional, Type, TypeVar, Generic, Dict, Any -from hologram import JsonSchemaMixin, ValidationError +from dbt.dataclass_schema import dbtClassMixin, ValidationError from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags from dbt.exceptions import NotImplementedException, InternalException @@ -109,7 +109,7 @@ def run(self): 'the run() method on builtins should never be called' ) - def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin: + def __call__(self, **kwargs: Dict[str, Any]) -> dbtClassMixin: try: params = self.get_parameters().from_dict(kwargs) except ValidationError as exc: diff --git a/core/dbt/rpc/response_manager.py b/core/dbt/rpc/response_manager.py index 7bf9ae746f9..1d44f7e0cbe 100644 --- a/core/dbt/rpc/response_manager.py +++ b/core/dbt/rpc/response_manager.py @@ -1,7 +1,7 @@ import json from typing import Callable, Dict, Any -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from jsonrpc.exceptions import ( JSONRPCParseError, JSONRPCInvalidRequestException, @@ -90,11 +90,14 @@ def handle_valid_request( @classmethod def _get_responses(cls, requests, dispatcher): for output in super()._get_responses(requests, dispatcher): - # if it's a result, check if it's a JsonSchemaMixin and if so call + # if it's a result, check if it's a dbtClassMixin and if so call # to_dict if hasattr(output, 'result'): - if isinstance(output.result, JsonSchemaMixin): - output.result = output.result.to_dict(omit_none=False) + if isinstance(output.result, dbtClassMixin): + # Note: errors in to_dict do not show up anywhere in + # the output and all you get is a generic 500 error + output.result = \ + output.result.to_dict(options={'keep_none': True}) yield output @classmethod diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index f2ed10a6d49..c657f4e3de7 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -9,7 +9,7 @@ ) from typing_extensions import Protocol -from hologram import JsonSchemaMixin, ValidationError +from dbt.dataclass_schema import dbtClassMixin, ValidationError import dbt.exceptions import dbt.flags @@ -283,7 +283,7 @@ def __init__( # - The actual thread that this represents, which writes its data to # the result and logs. The atomicity of list.append() and item # assignment means we don't need a lock. 
- self.result: Optional[JsonSchemaMixin] = None + self.result: Optional[dbtClassMixin] = None self.error: Optional[RPCException] = None self.state: TaskHandlerState = TaskHandlerState.NotStarted self.logs: List[LogMessage] = [] @@ -453,6 +453,7 @@ def _collect_parameters(self): ) try: + cls.validate(self.task_kwargs) return cls.from_dict(self.task_kwargs) except ValidationError as exc: # raise a TypeError to indicate invalid parameters so we get a nice diff --git a/core/dbt/rpc/task_handler_protocol.py b/core/dbt/rpc/task_handler_protocol.py index 1479d9030e2..3fca77741a3 100644 --- a/core/dbt/rpc/task_handler_protocol.py +++ b/core/dbt/rpc/task_handler_protocol.py @@ -14,11 +14,11 @@ class TaskHandlerProtocol(Protocol): - started: Optional[datetime] - ended: Optional[datetime] - state: TaskHandlerState task_id: TaskID - process: Optional[multiprocessing.Process] + state: TaskHandlerState + started: Optional[datetime] = None + ended: Optional[datetime] = None + process: Optional[multiprocessing.Process] = None @property def request_id(self) -> Union[str, int]: diff --git a/core/dbt/semver.py b/core/dbt/semver.py index 2948975c651..8ce6e4f4dea 100644 --- a/core/dbt/semver.py +++ b/core/dbt/semver.py @@ -4,8 +4,7 @@ from dbt.exceptions import VersionsNotCompatibleException import dbt.utils -from hologram import JsonSchemaMixin -from hologram.helpers import StrEnum +from dbt.dataclass_schema import dbtClassMixin, StrEnum from typing import Optional @@ -18,12 +17,12 @@ class Matchers(StrEnum): @dataclass -class VersionSpecification(JsonSchemaMixin): - major: Optional[str] - minor: Optional[str] - patch: Optional[str] - prerelease: Optional[str] - build: Optional[str] +class VersionSpecification(dbtClassMixin): + major: Optional[str] = None + minor: Optional[str] = None + patch: Optional[str] = None + prerelease: Optional[str] = None + build: Optional[str] = None matcher: Matchers = Matchers.EXACT diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 3a043ac4ff0..86c3ac57b85 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Dict, List, Any, Optional, Tuple, Set -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError from .compile import CompileTask diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index d78223eef8b..5f91cf052a7 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -110,7 +110,7 @@ def generate_json(self): for node in self._iterate_selected_nodes(): yield json.dumps({ k: v - for k, v in node.to_dict(omit_none=False).items() + for k, v in node.to_dict(options={'keep_none': True}).items() if k in self.ALLOWED_KEYS }) diff --git a/core/dbt/task/parse.py b/core/dbt/task/parse.py index 4acfee3631a..6aa122153dc 100644 --- a/core/dbt/task/parse.py +++ b/core/dbt/task/parse.py @@ -8,12 +8,17 @@ # snakeviz dbt.cprof from dbt.task.base import ConfiguredTask from dbt.adapters.factory import get_adapter -from dbt.parser.manifest import Manifest, ManifestLoader, _check_manifest +from dbt.parser.manifest import ( + Manifest, MacroManifest, ManifestLoader, _check_manifest +) from dbt.logger import DbtProcessState, print_timestamped_line +from dbt.clients.system import write_file from dbt.graph import Graph import time from typing import Optional import os +import json +import dbt.utils MANIFEST_FILE_NAME = 'manifest.json' PERF_INFO_FILE_NAME = 'perf_info.json' @@ -33,7 +38,8 @@ def write_manifest(self): def write_perf_info(self): path = 
os.path.join(self.config.target_path, PERF_INFO_FILE_NAME) - self.loader._perf_info.write(path) + write_file(path, json.dumps(self.loader._perf_info, + cls=dbt.utils.JSONEncoder, indent=4)) print_timestamped_line(f"Performance info: {path}") # This method takes code that normally exists in other files @@ -47,7 +53,7 @@ def write_perf_info(self): def get_full_manifest(self): adapter = get_adapter(self.config) # type: ignore - macro_manifest: Manifest = adapter.load_macro_manifest() + macro_manifest: MacroManifest = adapter.load_macro_manifest() print_timestamped_line("Macro manifest loaded") root_config = self.config macro_hook = adapter.connections.set_query_header diff --git a/core/dbt/task/rpc/cli.py b/core/dbt/task/rpc/cli.py index 20855f1dac6..f05dee3e9f7 100644 --- a/core/dbt/task/rpc/cli.py +++ b/core/dbt/task/rpc/cli.py @@ -1,6 +1,6 @@ import abc import shlex -import yaml +from dbt.clients.yaml_helper import Dumper, yaml # noqa: F401 from typing import Type, Optional diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 001a230f3bb..0fbbdde0e4d 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -3,7 +3,7 @@ import time from typing import List, Dict, Any, Iterable, Set, Tuple, Optional, AbstractSet -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from .compile import CompileRunner, CompileTask @@ -96,6 +96,7 @@ def get_hooks_by_tags( def get_hook(source, index): hook_dict = get_hook_dict(source) hook_dict.setdefault('index', index) + Hook.validate(hook_dict) return Hook.from_dict(hook_dict) @@ -191,7 +192,7 @@ def after_execute(self, result): def _build_run_model_result(self, model, context): result = context['load_result']('main') adapter_response = {} - if isinstance(result.response, JsonSchemaMixin): + if isinstance(result.response, dbtClassMixin): adapter_response = result.response.to_dict() return RunResult( node=model, diff --git a/core/dbt/tracking.py b/core/dbt/tracking.py index 1eba8860f3d..b36513996c9 100644 --- a/core/dbt/tracking.py +++ b/core/dbt/tracking.py @@ -1,6 +1,8 @@ from typing import Optional -from dbt.clients import yaml_helper +from dbt.clients.yaml_helper import ( # noqa:F401 + yaml, safe_load, Loader, Dumper, +) from dbt.logger import GLOBAL_LOGGER as logger from dbt import version as dbt_version from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger @@ -12,7 +14,6 @@ import platform import uuid import requests -import yaml import os sp_logger.setLevel(100) @@ -147,7 +148,7 @@ def get_cookie(self): else: with open(self.cookie_path, "r") as fh: try: - user = yaml_helper.safe_load(fh) + user = safe_load(fh) if user is None: user = self.set_cookie() except yaml.reader.ReaderError: diff --git a/core/dbt/utils.py b/core/dbt/utils.py index f100fa27a65..c4618a7ef27 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -415,7 +415,7 @@ def restrict_to(*restrictions): def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]: """For annoying mypy reasons, this helper makes dealing with nested dicts easier. You get either `None` if it's not a Dict[str, Any], or the - Dict[str, Any] you expected (to pass it to JsonSchemaMixin.from_dict(...)). + Dict[str, Any] you expected (to pass it to dbtClassMixin.from_dict(...)). 
""" if (isinstance(value, dict) and all(isinstance(k, str) for k in value)): return value @@ -539,7 +539,9 @@ def fqn_search( level_config = root.get(level, None) if not isinstance(level_config, dict): break - yield copy.deepcopy(level_config) + # This used to do a 'deepcopy', + # but it didn't seem to be necessary + yield level_config root = level_config diff --git a/core/setup.py b/core/setup.py index bd37c39a202..3929099edf9 100644 --- a/core/setup.py +++ b/core/setup.py @@ -70,7 +70,7 @@ def read(fname): 'json-rpc>=1.12,<2', 'werkzeug>=0.15,<2.0', 'dataclasses==0.6;python_version<"3.7"', - 'hologram==0.0.12', + # 'hologram==0.0.12', # must be updated prior to release 'logbook>=1.5,<1.6', 'typing-extensions>=3.7.4,<3.8', # the following are all to match snowflake-connector-python diff --git a/dev_requirements.txt b/dev_requirements.txt index a7887c64529..490bca96d11 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,3 +13,6 @@ mypy==0.782 wheel twine pytest-logbook>=1.2.0,<1.3 +git+https://github.com/fishtown-analytics/hologram.git@mashumaro-support +git+https://github.com/fishtown-analytics/dbt-mashumaro.git@dbt-customizations +jsonschema diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index da3fc7d99de..d660e75dd1d 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -27,7 +27,7 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.version import __version__ as dbt_version -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum BQ_QUERY_JOB_SPLIT = '-----Query Job SQL Follows-----' diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index 72d56b9b5f5..d45dc286862 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Any, Set, Union -from hologram import JsonSchemaMixin, ValidationError +from dbt.dataclass_schema import dbtClassMixin, ValidationError import dbt.deprecations import dbt.exceptions @@ -47,7 +47,7 @@ def sql_escape(string): @dataclass -class PartitionConfig(JsonSchemaMixin): +class PartitionConfig(dbtClassMixin): field: str data_type: str = 'date' granularity: str = 'day' @@ -69,6 +69,7 @@ def parse(cls, raw_partition_by) -> Optional['PartitionConfig']: if raw_partition_by is None: return None try: + cls.validate(raw_partition_by) return cls.from_dict(raw_partition_by) except ValidationError as exc: msg = dbt.exceptions.validator_error_message(exc) @@ -84,7 +85,7 @@ def parse(cls, raw_partition_by) -> Optional['PartitionConfig']: @dataclass -class GrantTarget(JsonSchemaMixin): +class GrantTarget(dbtClassMixin): dataset: str project: str @@ -808,6 +809,7 @@ def grant_access_to(self, entity, entity_type, role, grant_target_dict): conn = self.connections.get_thread_connection() client = conn.handle + GrantTarget.validate(grant_target_dict) grant_target = GrantTarget.from_dict(grant_target_dict) dataset = client.get_dataset( self.connections.dataset_from_id(grant_target.render()) diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index 5d735406dd6..96bb0b0f5d2 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -17,9 
+17,9 @@ class PostgresCredentials(Credentials): host: str user: str - role: Optional[str] port: Port password: str # on postgres the password is mandatory + role: Optional[str] = None search_path: Optional[str] = None keepalives_idle: int = 0 # 0 means to use the default value sslmode: Optional[str] = None diff --git a/plugins/redshift/dbt/adapters/redshift/connections.py b/plugins/redshift/dbt/adapters/redshift/connections.py index 731685a5c2d..4fe65eb3bcf 100644 --- a/plugins/redshift/dbt/adapters/redshift/connections.py +++ b/plugins/redshift/dbt/adapters/redshift/connections.py @@ -10,8 +10,7 @@ import boto3 -from hologram import FieldEncoder, JsonSchemaMixin -from hologram.helpers import StrEnum +from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum from dataclasses import dataclass, field from typing import Optional, List @@ -28,7 +27,7 @@ def json_schema(self): return {'type': 'integer', 'minimum': 0, 'maximum': 65535} -JsonSchemaMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()}) +dbtClassMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()}) class RedshiftConnectionMethod(StrEnum): diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index 374352eba03..4821cdaad09 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -30,16 +30,16 @@ class SnowflakeCredentials(Credentials): account: str user: str - warehouse: Optional[str] - role: Optional[str] - password: Optional[str] - authenticator: Optional[str] - private_key_path: Optional[str] - private_key_passphrase: Optional[str] - token: Optional[str] - oauth_client_id: Optional[str] - oauth_client_secret: Optional[str] - query_tag: Optional[str] + warehouse: Optional[str] = None + role: Optional[str] = None + password: Optional[str] = None + authenticator: Optional[str] = None + private_key_path: Optional[str] = None + private_key_passphrase: Optional[str] = None + token: Optional[str] = None + oauth_client_id: Optional[str] = None + oauth_client_secret: Optional[str] = None + query_tag: Optional[str] = None client_session_keep_alive: bool = False def __post_init__(self): diff --git a/scripts/check_libyaml.py b/scripts/check_libyaml.py new file mode 100755 index 00000000000..d4dd98f2260 --- /dev/null +++ b/scripts/check_libyaml.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +try: + from yaml import ( + CLoader as Loader, + CSafeLoader as SafeLoader, + CDumper as Dumper + ) +except ImportError: + from yaml import ( + Loader, SafeLoader, Dumper + ) + +if Loader.__name__ == 'CLoader': + print("libyaml is working") +elif Loader.__name__ == 'Loader': + print("libyaml is not working") + print("Check the python executable and pyyaml for libyaml support") diff --git a/scripts/collect-artifact-schema.py b/scripts/collect-artifact-schema.py index ecdf98a82eb..5de3d1e3969 100644 --- a/scripts/collect-artifact-schema.py +++ b/scripts/collect-artifact-schema.py @@ -3,7 +3,7 @@ from typing import Dict, Any import json -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from dbt.contracts.graph.manifest import WritableManifest from dbt.contracts.results import ( CatalogArtifact, RunResultsArtifact, FreshnessExecutionResultArtifact @@ -11,7 +11,7 @@ @dataclass -class Schemas(JsonSchemaMixin): +class Schemas(dbtClassMixin): manifest: Dict[str, Any] catalog: Dict[str, Any] run_results: Dict[str, Any] diff --git 
a/scripts/collect-dbt-contexts.py b/scripts/collect-dbt-contexts.py index 24d66757f42..cf353f5ffbe 100644 --- a/scripts/collect-dbt-contexts.py +++ b/scripts/collect-dbt-contexts.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass from typing import List, Optional, Iterable, Union, Dict, Any -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from dbt.context.base import BaseContext @@ -21,20 +21,20 @@ @dataclass -class ContextValue(JsonSchemaMixin): +class ContextValue(dbtClassMixin): name: str value: str # a type description doc: Optional[str] @dataclass -class MethodArgument(JsonSchemaMixin): +class MethodArgument(dbtClassMixin): name: str value: str # a type description @dataclass -class ContextMethod(JsonSchemaMixin): +class ContextMethod(dbtClassMixin): name: str args: List[MethodArgument] result: str # a type description @@ -42,7 +42,7 @@ class ContextMethod(JsonSchemaMixin): @dataclass -class Unknown(JsonSchemaMixin): +class Unknown(dbtClassMixin): name: str value: str doc: Optional[str] @@ -96,7 +96,7 @@ def collect(cls): @dataclass -class ContextCatalog(JsonSchemaMixin): +class ContextCatalog(dbtClassMixin): base: List[ContextMember] target: List[ContextMember] model: List[ContextMember] diff --git a/scripts/dtr.py b/scripts/dtr.py index 78580da78c4..0920f90c743 100755 --- a/scripts/dtr.py +++ b/scripts/dtr.py @@ -7,10 +7,11 @@ import sys # Python version defaults to 3.6 -# To run postgres integration tests: `dtr.py -i --pg` (this is the default) -# To run postgres integration tests, clearing `dbt.log` beforehand: `dtr.py -il --pg` -# To run postgres + redshift integration tests: `dtr.py -i --pg --rs` -# To drop to pdb on failure, add `--pdb` +# To run postgres integration tests: `dtr.py -i -t pg` (this is the default) +# To run postgres integration tests, clearing `dbt.log` beforehand: `dtr.py -il -t pg` +# dtr.py -i -t pg -a test/integration/029_docs_generate_tests +# To run postgres + redshift integration tests: `dtr.py -i -t pg -t rs` +# To drop to pdb on failure, add `--pdb` or `-p` # To run mypy tests: `dtr.py -m`. # To run flake8 test: `dtr.py -f`. 
# To run unit tests: `dtr.py -u` @@ -82,12 +83,12 @@ def parse_args(argv): ) parser.add_argument('-v', '--python-version', - default='36', choices=['27', '36', '37', '38'], + default='38', choices=['36', '37', '38'], help='what python version to run') parser.add_argument( '-t', '--types', default=None, - help='The types of tests to run, if this is an integration run, as csv' + help='The types of tests to run, if this is an integration run' ) parser.add_argument( '-c', '--continue', diff --git a/test/integration/004_simple_snapshot_test/test_simple_snapshot.py b/test/integration/004_simple_snapshot_test/test_simple_snapshot.py index e3ea509c707..f522fee583e 100644 --- a/test/integration/004_simple_snapshot_test/test_simple_snapshot.py +++ b/test/integration/004_simple_snapshot_test/test_simple_snapshot.py @@ -531,7 +531,7 @@ def test__postgres__invalid(self): with self.assertRaises(dbt.exceptions.CompilationException) as exc: self.run_dbt(['compile'], expect_pass=False) - self.assertIn('target_schema', str(exc.exception)) + self.assertIn('Compilation Error in model ref_snapshot', str(exc.exception)) class TestCheckCols(TestSimpleSnapshotFiles): diff --git a/test/integration/047_dbt_ls_test/test_ls.py b/test/integration/047_dbt_ls_test/test_ls.py index f4ef4726a07..747c4a1f947 100644 --- a/test/integration/047_dbt_ls_test/test_ls.py +++ b/test/integration/047_dbt_ls_test/test_ls.py @@ -92,6 +92,7 @@ def expect_snapshot_output(self): 'database': None, 'schema': None, 'alias': None, + 'check_cols': None, }, 'alias': 'my_snapshot', 'resource_type': 'snapshot', diff --git a/test/integration/100_rpc_test/test_rpc.py b/test/integration/100_rpc_test/test_rpc.py index 733ce6341a6..0517e9e7484 100644 --- a/test/integration/100_rpc_test/test_rpc.py +++ b/test/integration/100_rpc_test/test_rpc.py @@ -66,7 +66,7 @@ def is_up(self): def start(self): super().start() - for _ in range(180): + for _ in range(240): if self.is_up(): break time.sleep(0.5) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 0380c7e095e..00a4ef3b885 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -6,7 +6,7 @@ from requests.exceptions import ConnectionError from unittest.mock import patch, MagicMock, Mock, create_autospec, ANY -import hologram +import dbt.dataclass_schema import dbt.flags as flags @@ -19,6 +19,7 @@ from dbt.clients import agate_helper import dbt.exceptions from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.context.providers import RuntimeConfigObject import google.cloud.bigquery @@ -364,7 +365,7 @@ def test_view_temp_relation(self): 'identifier': False } } - BigQueryRelation.from_dict(kwargs) + BigQueryRelation.validate(kwargs) def test_view_relation(self): kwargs = { @@ -379,7 +380,7 @@ def test_view_relation(self): 'schema': True } } - BigQueryRelation.from_dict(kwargs) + BigQueryRelation.validate(kwargs) def test_table_relation(self): kwargs = { @@ -394,7 +395,7 @@ def test_table_relation(self): 'schema': True } } - BigQueryRelation.from_dict(kwargs) + BigQueryRelation.validate(kwargs) def test_external_source_relation(self): kwargs = { @@ -409,7 +410,7 @@ def test_external_source_relation(self): 'schema': True } } - BigQueryRelation.from_dict(kwargs) + BigQueryRelation.validate(kwargs) def test_invalid_relation(self): kwargs = { @@ -424,8 +425,8 @@ def test_invalid_relation(self): 'schema': True } } - with self.assertRaises(hologram.ValidationError): - BigQueryRelation.from_dict(kwargs) + with 
self.assertRaises(dbt.dataclass_schema.ValidationError): + BigQueryRelation.validate(kwargs) class TestBigQueryInformationSchema(unittest.TestCase): @@ -451,6 +452,7 @@ def test_replace(self): 'identifier': True, } } + BigQueryRelation.validate(kwargs) relation = BigQueryRelation.from_dict(kwargs) info_schema = relation.information_schema() @@ -808,7 +810,7 @@ def test_parse_partition_by(self): def test_hours_to_expiration(self): adapter = self.get_adapter('oauth') mock_config = create_autospec( - dbt.context.providers.RuntimeConfigObject) + RuntimeConfigObject) config = {'hours_to_expiration': 4} mock_config.get.side_effect = lambda name: config.get(name) @@ -822,7 +824,7 @@ def test_hours_to_expiration(self): def test_hours_to_expiration_temporary(self): adapter = self.get_adapter('oauth') mock_config = create_autospec( - dbt.context.providers.RuntimeConfigObject) + RuntimeConfigObject) config={'hours_to_expiration': 4} mock_config.get.side_effect = lambda name: config.get(name) diff --git a/test/unit/test_context.py b/test/unit/test_context.py index cf673286b47..b9098d943cc 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -415,16 +415,6 @@ def test_query_header_context(config, manifest_fx): assert_has_keys(REQUIRED_QUERY_HEADER_KEYS, MAYBE_KEYS, ctx) -def test_macro_parse_context(config, manifest_fx, get_adapter, get_include_paths): - ctx = providers.generate_parser_macro( - macro=manifest_fx.macros['macro.root.macro_a'], - config=config, - manifest=manifest_fx, - package_name='root', - ) - assert_has_keys(REQUIRED_MACRO_KEYS, MAYBE_KEYS, ctx) - - def test_macro_runtime_context(config, manifest_fx, get_adapter, get_include_paths): ctx = providers.generate_runtime_macro( macro=manifest_fx.macros['macro.root.macro_a'], diff --git a/test/unit/test_contracts_graph_compiled.py b/test/unit/test_contracts_graph_compiled.py index c1c1653bac7..5d6e6a4d8d9 100644 --- a/test/unit/test_contracts_graph_compiled.py +++ b/test/unit/test_contracts_graph_compiled.py @@ -16,6 +16,7 @@ assert_fails_validation, dict_replace, replace_config, + compare_dicts, ) diff --git a/test/unit/test_contracts_graph_parsed.py b/test/unit/test_contracts_graph_parsed.py index dc3d593241e..7e968ed9fe5 100644 --- a/test/unit/test_contracts_graph_parsed.py +++ b/test/unit/test_contracts_graph_parsed.py @@ -4,15 +4,12 @@ from dbt.node_types import NodeType from dbt.contracts.files import FileHash from dbt.contracts.graph.model_config import ( - All, NodeConfig, SeedConfig, TestConfig, - TimestampSnapshotConfig, - CheckSnapshotConfig, + SnapshotConfig, SourceConfig, EmptySnapshotConfig, - SnapshotStrategy, Hook, ) from dbt.contracts.graph.parsed import ( @@ -44,8 +41,8 @@ ) from dbt import flags -from hologram import ValidationError -from .utils import ContractTestCase, assert_symmetric, assert_from_dict, assert_fails_validation, dict_replace, replace_config +from dbt.dataclass_schema import ValidationError +from .utils import ContractTestCase, assert_symmetric, assert_from_dict, assert_to_dict, compare_dicts, assert_fails_validation, dict_replace, replace_config @pytest.fixture(autouse=True) @@ -724,7 +721,7 @@ def test_patch_parsed_model(basic_parsed_model_object, basic_parsed_model_patch_ def test_patch_parsed_model_invalid(basic_parsed_model_object, basic_parsed_model_patch_object): - pre_patch = basic_parsed_model_object + pre_patch = basic_parsed_model_object # ParsedModelNode patch = basic_parsed_model_patch_object.replace(description=None) with pytest.raises(ValidationError): 
pre_patch.patch(patch) @@ -1144,7 +1141,9 @@ def test_basic_schema_test_node(minimal_parsed_schema_test_dict, basic_parsed_sc def test_complex_schema_test_node(complex_parsed_schema_test_dict, complex_parsed_schema_test_object): - node = complex_parsed_schema_test_object + # this tests for the presence of _extra keys + node = complex_parsed_schema_test_object # ParsedSchemaTestNode + assert(node.config._extra['extra_key']) node_dict = complex_parsed_schema_test_dict assert_symmetric(node, node_dict) assert node.empty is False @@ -1185,8 +1184,8 @@ def basic_timestamp_snapshot_config_dict(): @pytest.fixture def basic_timestamp_snapshot_config_object(): - return TimestampSnapshotConfig( - strategy=SnapshotStrategy.Timestamp, + return SnapshotConfig( + strategy='timestamp', updated_at='last_update', unique_key='id', target_database='some_snapshot_db', @@ -1217,11 +1216,11 @@ def complex_timestamp_snapshot_config_dict(): @pytest.fixture def complex_timestamp_snapshot_config_object(): - cfg = TimestampSnapshotConfig( + cfg = SnapshotConfig( column_types={'a': 'text'}, materialized='snapshot', post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')], - strategy=SnapshotStrategy.Timestamp, + strategy='timestamp', target_database='some_snapshot_db', target_schema='some_snapshot_schema', updated_at='last_update', @@ -1241,20 +1240,14 @@ def test_basic_timestamp_snapshot_config(basic_timestamp_snapshot_config_dict, b def test_complex_timestamp_snapshot_config(complex_timestamp_snapshot_config_dict, complex_timestamp_snapshot_config_object): cfg = complex_timestamp_snapshot_config_object cfg_dict = complex_timestamp_snapshot_config_dict - assert_symmetric(cfg, cfg_dict, TimestampSnapshotConfig) - - -def test_invalid_wrong_strategy(basic_timestamp_snapshot_config_dict): - bad_type = basic_timestamp_snapshot_config_dict - bad_type['strategy'] = 'check' - assert_fails_validation(bad_type, TimestampSnapshotConfig) + assert_symmetric(cfg, cfg_dict, SnapshotConfig) def test_invalid_missing_updated_at(basic_timestamp_snapshot_config_dict): bad_fields = basic_timestamp_snapshot_config_dict del bad_fields['updated_at'] bad_fields['check_cols'] = 'all' - assert_fails_validation(bad_fields, TimestampSnapshotConfig) + assert_fails_validation(bad_fields, SnapshotConfig) @pytest.fixture @@ -1279,9 +1272,9 @@ def basic_check_snapshot_config_dict(): @pytest.fixture def basic_check_snapshot_config_object(): - return CheckSnapshotConfig( - strategy=SnapshotStrategy.Check, - check_cols=All.All, + return SnapshotConfig( + strategy='check', + check_cols='all', unique_key='id', target_database='some_snapshot_db', target_schema='some_snapshot_schema', @@ -1311,11 +1304,11 @@ def complex_set_snapshot_config_dict(): @pytest.fixture def complex_set_snapshot_config_object(): - cfg = CheckSnapshotConfig( + cfg = SnapshotConfig( column_types={'a': 'text'}, materialized='snapshot', post_hook=[Hook(sql='insert into blah(a, b) select "1", 1')], - strategy=SnapshotStrategy.Check, + strategy='check', check_cols=['a', 'b'], target_database='some_snapshot_db', target_schema='some_snapshot_schema', @@ -1328,7 +1321,7 @@ def complex_set_snapshot_config_object(): def test_basic_snapshot_config(basic_check_snapshot_config_dict, basic_check_snapshot_config_object): cfg_dict = basic_check_snapshot_config_dict cfg = basic_check_snapshot_config_object - assert_symmetric(cfg, cfg_dict, CheckSnapshotConfig) + assert_symmetric(cfg, cfg_dict, SnapshotConfig) pickle.loads(pickle.dumps(cfg)) @@ -1342,20 +1335,20 @@ def 
test_complex_snapshot_config(complex_set_snapshot_config_dict, complex_set_s def test_invalid_check_wrong_strategy(basic_check_snapshot_config_dict): wrong_strategy = basic_check_snapshot_config_dict wrong_strategy['strategy'] = 'timestamp' - assert_fails_validation(wrong_strategy, CheckSnapshotConfig) + assert_fails_validation(wrong_strategy, SnapshotConfig) def test_invalid_missing_check_cols(basic_check_snapshot_config_dict): wrong_fields = basic_check_snapshot_config_dict del wrong_fields['check_cols'] - with pytest.raises(ValidationError, match=r"'check_cols' is a required property"): - CheckSnapshotConfig.from_dict(wrong_fields) + with pytest.raises(ValidationError, match=r"A snapshot configured with the check strategy"): + SnapshotConfig.validate(wrong_fields) def test_invalid_check_value(basic_check_snapshot_config_dict): invalid_check_type = basic_check_snapshot_config_dict invalid_check_type['check_cols'] = 'some' - assert_fails_validation(invalid_check_type, CheckSnapshotConfig) + assert_fails_validation(invalid_check_type, SnapshotConfig) @pytest.fixture @@ -1429,8 +1422,8 @@ def basic_timestamp_snapshot_object(): schema='test_schema', alias='bar', tags=[], - config=TimestampSnapshotConfig( - strategy=SnapshotStrategy.Timestamp, + config=SnapshotConfig( + strategy='timestamp', unique_key='id', updated_at='last_update', target_database='some_snapshot_db', @@ -1559,10 +1552,10 @@ def basic_check_snapshot_object(): schema='test_schema', alias='bar', tags=[], - config=CheckSnapshotConfig( - strategy=SnapshotStrategy.Check, + config=SnapshotConfig( + strategy='check', unique_key='id', - check_cols=All.All, + check_cols='all', target_database='some_snapshot_db', target_schema='some_snapshot_schema', ), diff --git a/test/unit/test_contracts_project.py b/test/unit/test_contracts_project.py index 68b6c9496ea..c100eb24fcc 100644 --- a/test/unit/test_contracts_project.py +++ b/test/unit/test_contracts_project.py @@ -1,6 +1,6 @@ from .utils import ContractTestCase -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError from dbt.contracts.project import Project @@ -34,7 +34,7 @@ def test_invalid_name(self): 'config-version': 2, } with self.assertRaises(ValidationError): - self.ContractType.from_dict(dct) + self.ContractType.validate(dct) def test_unsupported_version(self): dct = { @@ -43,5 +43,5 @@ def test_unsupported_version(self): 'profile': 'test', 'project-root': '/usr/src/app', } - with self.assertRaises(ValidationError): + with self.assertRaises(Exception): self.ContractType.from_dict(dct) diff --git a/test/unit/test_deps.py b/test/unit/test_deps.py index 3e3eacc499d..4741a77e1bb 100644 --- a/test/unit/test_deps.py +++ b/test/unit/test_deps.py @@ -16,7 +16,7 @@ from dbt.contracts.project import PackageConfig from dbt.semver import VersionSpecifier -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError class TestLocalPackage(unittest.TestCase): @@ -33,7 +33,7 @@ def test_init(self): class TestGitPackage(unittest.TestCase): def test_init(self): a_contract = GitPackage.from_dict( - {'git': 'http://example.com', 'revision': '0.0.1'} + {'git': 'http://example.com', 'revision': '0.0.1'}, ) self.assertEqual(a_contract.git, 'http://example.com') self.assertEqual(a_contract.revision, '0.0.1') @@ -52,17 +52,17 @@ def test_init(self): def test_invalid(self): with self.assertRaises(ValidationError): - GitPackage.from_dict( - {'git': 'http://example.com', 'version': '0.0.1'} + GitPackage.validate( + {'git': 'http://example.com', 
'version': '0.0.1'}, ) def test_resolve_ok(self): a_contract = GitPackage.from_dict( - {'git': 'http://example.com', 'revision': '0.0.1'} + {'git': 'http://example.com', 'revision': '0.0.1'}, ) b_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.1', - 'warn-unpinned': False} + 'warn-unpinned': False}, ) a = GitUnpinnedPackage.from_contract(a_contract) b = GitUnpinnedPackage.from_contract(b_contract) @@ -78,10 +78,10 @@ def test_resolve_ok(self): def test_resolve_fail(self): a_contract = GitPackage.from_dict( - {'git': 'http://example.com', 'revision': '0.0.1'} + {'git': 'http://example.com', 'revision': '0.0.1'}, ) b_contract = GitPackage.from_dict( - {'git': 'http://example.com', 'revision': '0.0.2'} + {'git': 'http://example.com', 'revision': '0.0.2'}, ) a = GitUnpinnedPackage.from_contract(a_contract) b = GitUnpinnedPackage.from_contract(b_contract) @@ -170,8 +170,8 @@ def test_init(self): def test_invalid(self): with self.assertRaises(ValidationError): - RegistryPackage.from_dict( - {'package': 'namespace/name', 'key': 'invalid'} + RegistryPackage.validate( + {'package': 'namespace/name', 'key': 'invalid'}, ) def test_resolve_ok(self): diff --git a/test/unit/test_docs_blocks.py b/test/unit/test_docs_blocks.py index f8c228b9bb5..7870fdb3457 100644 --- a/test/unit/test_docs_blocks.py +++ b/test/unit/test_docs_blocks.py @@ -2,7 +2,7 @@ import unittest from dbt.contracts.files import SourceFile, FileHash, FilePath -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import Manifest, MacroManifest from dbt.contracts.graph.parsed import ParsedDocumentation from dbt.node_types import NodeType from dbt.parser import docs @@ -147,7 +147,7 @@ def test_load_file(self): results=ParseResult.rpc(), root_project=self.root_project_config, project=self.subdir_project_config, - macro_manifest=Manifest.from_macros()) + macro_manifest=MacroManifest({}, {})) file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md') @@ -172,7 +172,7 @@ def test_load_file_extras(self): results=ParseResult.rpc(), root_project=self.root_project_config, project=self.subdir_project_config, - macro_manifest=Manifest.from_macros()) + macro_manifest=MacroManifest({}, {})) file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md') @@ -189,7 +189,7 @@ def test_multiple_raw_blocks(self): results=ParseResult.rpc(), root_project=self.root_project_config, project=self.subdir_project_config, - macro_manifest=Manifest.from_macros()) + macro_manifest=MacroManifest({}, {})) file_block = self._build_file(MULTIPLE_RAW_BLOCKS, 'test_file.md') diff --git a/test/unit/test_docs_generate.py b/test/unit/test_docs_generate.py index 49576e9e100..96dc568ff88 100644 --- a/test/unit/test_docs_generate.py +++ b/test/unit/test_docs_generate.py @@ -32,7 +32,7 @@ def generate_catalog_dict(self, columns): sources=sources, errors=None, ) - return result.to_dict(omit_none=False)['nodes'] + return result.to_dict(options={'keep_none': True})['nodes'] def test__unflatten_empty(self): columns = {} diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 20cd61522d3..944d39219c6 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -13,7 +13,7 @@ import dbt.utils import dbt.parser.manifest from dbt.contracts.files import SourceFile, FileHash, FilePath -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import Manifest, MacroManifest from dbt.parser.results import ParseResult from dbt.parser.base import BaseParser from 
dbt.graph import NodeSelector, parse_difference @@ -105,9 +105,8 @@ def _mock_parse_result(config, all_projects): self.mock_source_file = self.load_source_file_patcher.start() self.mock_source_file.side_effect = lambda path: [n for n in self.mock_models if n.path == path][0] - self.macro_manifest = Manifest.from_macros(macros={ - n.unique_id: n for n in generate_name_macros('test_models_compile') - }) + self.macro_manifest = MacroManifest( + {n.unique_id: n for n in generate_name_macros('test_models_compile')}, {}) def filesystem_iter(iter_self): if 'sql' not in iter_self.extension: diff --git a/test/unit/test_macro_resolver.py b/test/unit/test_macro_resolver.py new file mode 100644 index 00000000000..e5fc03e365a --- /dev/null +++ b/test/unit/test_macro_resolver.py @@ -0,0 +1,56 @@ +import itertools +import unittest +import os +from typing import Set, Dict, Any +from unittest import mock + +import pytest + +# make sure 'redshift' is available +from dbt.adapters import postgres, redshift +from dbt.adapters import factory +from dbt.adapters.base import AdapterConfig +from dbt.contracts.graph.parsed import ( + ParsedModelNode, NodeConfig, DependsOn, ParsedMacro +) +from dbt.context import base, target, configured, providers, docs, manifest, macros +from dbt.contracts.files import FileHash +from dbt.node_types import NodeType +import dbt.exceptions +from .utils import profile_from_dict, config_from_parts_or_dicts, inject_adapter, clear_plugin +from .mock_adapter import adapter_factory + +from dbt.context.macro_resolver import MacroResolver + + +def mock_macro(name, package_name): + macro = mock.MagicMock( + __class__=ParsedMacro, + package_name=package_name, + resource_type='macro', + unique_id=f'macro.{package_name}.{name}', + ) + # Mock(name=...) does not set the `name` attribute, this does. 
+ macro.name = name + return macro + +class TestMacroResolver(unittest.TestCase): + + def test_resolver(self): + data = [ + {'package_name': 'my_test', 'name': 'unique'}, + {'package_name': 'my_test', 'name': 'macro_xx'}, + {'package_name': 'one', 'name': 'unique'}, + {'package_name': 'one', 'name': 'not_null'}, + {'package_name': 'two', 'name': 'macro_a'}, + {'package_name': 'two', 'name': 'macro_b'}, + ] + macros = {} + for mdata in data: + macro = mock_macro(mdata['name'], mdata['package_name']) + macros[macro.unique_id] = macro + resolver = MacroResolver(macros, 'my_test', ['one']) + assert(resolver) + self.assertEqual(resolver.get_macro_id('one', 'not_null'), 'macro.one.not_null') + + diff --git a/test/unit/test_model_config.py b/test/unit/test_model_config.py index 841b14ed93c..c3195a70e32 100644 --- a/test/unit/test_model_config.py +++ b/test/unit/test_model_config.py @@ -1,12 +1,12 @@ from dataclasses import dataclass, field -from hologram import JsonSchemaMixin +from dbt.dataclass_schema import dbtClassMixin from typing import List, Dict import pytest from dbt.contracts.graph.model_config import MergeBehavior, ShowBehavior, CompareBehavior @dataclass -class ThingWithMergeBehavior(JsonSchemaMixin): +class ThingWithMergeBehavior(dbtClassMixin): default_behavior: int appended: List[str] = field(metadata={'merge': MergeBehavior.Append}) updated: Dict[str, int] = field(metadata={'merge': MergeBehavior.Update}) @@ -35,7 +35,7 @@ def test_merge_behavior_from_field(): @dataclass -class ThingWithShowBehavior(JsonSchemaMixin): +class ThingWithShowBehavior(dbtClassMixin): default_behavior: int hidden: str = field(metadata={'show_hide': ShowBehavior.Hide}) shown: float = field(metadata={'show_hide': ShowBehavior.Show}) @@ -61,7 +61,7 @@ def test_show_behavior_from_field(): @dataclass -class ThingWithCompareBehavior(JsonSchemaMixin): +class ThingWithCompareBehavior(dbtClassMixin): default_behavior: int included: float = field(metadata={'compare': CompareBehavior.Include}) excluded: str = field(metadata={'compare': CompareBehavior.Exclude}) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 9064a2e3af2..90b9d012a9a 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -20,9 +20,9 @@ from dbt.node_types import NodeType from dbt.contracts.files import SourceFile, FileHash, FilePath -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.manifest import Manifest, MacroManifest from dbt.contracts.graph.model_config import ( - NodeConfig, TestConfig, TimestampSnapshotConfig, SnapshotStrategy, + NodeConfig, TestConfig, SnapshotConfig ) from dbt.contracts.graph.parsed import ( ParsedModelNode, ParsedMacro, ParsedNodePatch, DependsOn, ColumnInfo, @@ -126,8 +126,8 @@ def setUp(self): self.parser_patcher = mock.patch('dbt.parser.base.get_adapter') self.factory_parser = self.parser_patcher.start() - self.macro_manifest = Manifest.from_macros( - macros={m.unique_id: m for m in generate_name_macros('root')} + self.macro_manifest = MacroManifest( + {m.unique_id: m for m in generate_name_macros('root')}, {} ) def tearDown(self): @@ -544,8 +544,8 @@ def test_single_block(self): package_name='snowplow', original_file_path=normalize('snapshots/nested/snap_1.sql'), root_path=get_abs_os_path('./dbt_modules/snowplow'), - config=TimestampSnapshotConfig( - strategy=SnapshotStrategy.Timestamp, + config=SnapshotConfig( + strategy='timestamp', updated_at='last_update', target_database='dbt', target_schema='analytics', @@ -604,8 +604,8 @@ def 
test_multi_block(self): package_name='snowplow', original_file_path=normalize('snapshots/nested/snap_1.sql'), root_path=get_abs_os_path('./dbt_modules/snowplow'), - config=TimestampSnapshotConfig( - strategy=SnapshotStrategy.Timestamp, + config=SnapshotConfig( + strategy='timestamp', updated_at='last_update', target_database='dbt', target_schema='analytics', @@ -634,8 +634,8 @@ def test_multi_block(self): package_name='snowplow', original_file_path=normalize('snapshots/nested/snap_1.sql'), root_path=get_abs_os_path('./dbt_modules/snowplow'), - config=TimestampSnapshotConfig( - strategy=SnapshotStrategy.Timestamp, + config=SnapshotConfig( + strategy='timestamp', updated_at='last_update', target_database='dbt', target_schema='analytics', diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index cf5637f38ff..5e336d494ba 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -138,7 +138,7 @@ def test_default_session_is_not_used_when_iam_used(self): self.config.credentials.cluster_id = 'clusterid' with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEquals(boto3.DEFAULT_SESSION.client.call_count, 0, + self.assertEqual(boto3.DEFAULT_SESSION.client.call_count, 0, "The redshift client should not be created using the default session because the session object is not thread-safe") def test_default_session_is_not_used_when_iam_not_used(self): @@ -146,7 +146,7 @@ def test_default_session_is_not_used_when_iam_not_used(self): self.config.credentials = self.config.credentials.replace(method=None) with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEquals(boto3.DEFAULT_SESSION.client.call_count, 0, + self.assertEqual(boto3.DEFAULT_SESSION.client.call_count, 0, "The redshift client should not be created using the default session because the session object is not thread-safe") def test_cancel_open_connections_empty(self): diff --git a/test/unit/utils.py b/test/unit/utils.py index 8c3355ce88e..96bb0e2f140 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -10,7 +10,7 @@ import agate import pytest -from hologram import ValidationError +from dbt.dataclass_schema import ValidationError def normalize(path): @@ -150,6 +150,7 @@ def assert_to_dict(self, obj, dct): def assert_from_dict(self, obj, dct, cls=None): if cls is None: cls = self.ContractType + cls.validate(dct) self.assertEqual(cls.from_dict(dct), obj) def assert_symmetric(self, obj, dct, cls=None): @@ -161,9 +162,28 @@ def assert_fails_validation(self, dct, cls=None): cls = self.ContractType with self.assertRaises(ValidationError): + cls.validate(dct) cls.from_dict(dct) +def compare_dicts(dict1, dict2): + first_set = set(dict1.keys()) + second_set = set(dict2.keys()) + print(f"--- Difference between first and second keys: {first_set.difference(second_set)}") + print(f"--- Difference between second and first keys: {second_set.difference(first_set)}") + common_keys = set(first_set).intersection(set(second_set)) + found_differences = False + for key in common_keys: + if dict1[key] != dict2[key] : + print(f"--- --- first dict: {key}: {str(dict1[key])}") + print(f"--- --- second dict: {key}: {str(dict2[key])}") + found_differences = True + if found_differences: + print("--- Found differences in dictionaries") + else: + print("--- Found no differences in dictionaries") + + def 
assert_to_dict(obj, dct): assert obj.to_dict() == dct @@ -171,6 +191,7 @@ def assert_to_dict(obj, dct): def assert_from_dict(obj, dct, cls=None): if cls is None: cls = obj.__class__ + cls.validate(dct) assert cls.from_dict(dct) == obj @@ -181,6 +202,7 @@ def assert_symmetric(obj, dct, cls=None): def assert_fails_validation(dct, cls): with pytest.raises(ValidationError): + cls.validate(dct) cls.from_dict(dct) diff --git a/third-party-stubs/mashumaro/__init__.pyi b/third-party-stubs/mashumaro/__init__.pyi new file mode 100644 index 00000000000..662aa36260e --- /dev/null +++ b/third-party-stubs/mashumaro/__init__.pyi @@ -0,0 +1,5 @@ +from mashumaro.exceptions import MissingField as MissingField +from mashumaro.serializer.base.dict import DataClassDictMixin as DataClassDictMixin +from mashumaro.serializer.json import DataClassJSONMixin as DataClassJSONMixin +from mashumaro.serializer.msgpack import DataClassMessagePackMixin as DataClassMessagePackMixin +from mashumaro.serializer.yaml import DataClassYAMLMixin as DataClassYAMLMixin diff --git a/third-party-stubs/mashumaro/exceptions.pyi b/third-party-stubs/mashumaro/exceptions.pyi new file mode 100644 index 00000000000..cd4b7a18dc6 --- /dev/null +++ b/third-party-stubs/mashumaro/exceptions.pyi @@ -0,0 +1,37 @@ +from mashumaro.meta.helpers import type_name as type_name +from typing import Any, Optional + +class MissingField(LookupError): + field_name: Any = ... + field_type: Any = ... + holder_class: Any = ... + def __init__(self, field_name: Any, field_type: Any, holder_class: Any) -> None: ... + @property + def field_type_name(self): ... + @property + def holder_class_name(self): ... + +class UnserializableDataError(TypeError): ... + +class UnserializableField(UnserializableDataError): + field_name: Any = ... + field_type: Any = ... + holder_class: Any = ... + msg: Any = ... + def __init__(self, field_name: Any, field_type: Any, holder_class: Any, msg: Optional[Any] = ...) -> None: ... + @property + def field_type_name(self): ... + @property + def holder_class_name(self): ... + +class InvalidFieldValue(ValueError): + field_name: Any = ... + field_type: Any = ... + field_value: Any = ... + holder_class: Any = ... + msg: Any = ... + def __init__(self, field_name: Any, field_type: Any, field_value: Any, holder_class: Any, msg: Optional[Any] = ...) -> None: ... + @property + def field_type_name(self): ... + @property + def holder_class_name(self): ... diff --git a/third-party-stubs/mashumaro/meta/__init__.pyi b/third-party-stubs/mashumaro/meta/__init__.pyi new file mode 100644 index 00000000000..e69de29bb2d diff --git a/third-party-stubs/mashumaro/meta/helpers.pyi b/third-party-stubs/mashumaro/meta/helpers.pyi new file mode 100644 index 00000000000..a176cc07dda --- /dev/null +++ b/third-party-stubs/mashumaro/meta/helpers.pyi @@ -0,0 +1,11 @@ +from typing import Any + +def get_imported_module_names(): ... +def get_type_origin(t: Any): ... +def type_name(t: Any): ... +def is_special_typing_primitive(t: Any): ... +def is_generic(t: Any): ... +def is_union(t: Any): ... +def is_type_var(t: Any): ... +def is_class_var(t: Any): ... +def is_init_var(t: Any): ... 
diff --git a/third-party-stubs/mashumaro/meta/macros.pyi b/third-party-stubs/mashumaro/meta/macros.pyi
new file mode 100644
index 00000000000..c44b85e172a
--- /dev/null
+++ b/third-party-stubs/mashumaro/meta/macros.pyi
@@ -0,0 +1,6 @@
+from typing import Any
+
+PY_36: Any
+PY_37: Any
+PY_38: Any
+PY_39: Any
diff --git a/third-party-stubs/mashumaro/meta/patch.pyi b/third-party-stubs/mashumaro/meta/patch.pyi
new file mode 100644
index 00000000000..d3fa7446eb6
--- /dev/null
+++ b/third-party-stubs/mashumaro/meta/patch.pyi
@@ -0,0 +1 @@
+def patch_fromisoformat() -> None: ...
diff --git a/third-party-stubs/mashumaro/serializer/__init__.pyi b/third-party-stubs/mashumaro/serializer/__init__.pyi
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/third-party-stubs/mashumaro/serializer/base/__init__.pyi b/third-party-stubs/mashumaro/serializer/base/__init__.pyi
new file mode 100644
index 00000000000..2a72874963c
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/base/__init__.pyi
@@ -0,0 +1 @@
+from .dict import DataClassDictMixin as DataClassDictMixin
diff --git a/third-party-stubs/mashumaro/serializer/base/dict.pyi b/third-party-stubs/mashumaro/serializer/base/dict.pyi
new file mode 100644
index 00000000000..568079dac8e
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/base/dict.pyi
@@ -0,0 +1,11 @@
+from typing import Any, Mapping, Dict, Optional
+
+class DataClassDictMixin:
+    def __init_subclass__(cls, **kwargs: Any) -> None: ...
+    def __pre_serialize__(self, options: Optional[Dict[str, Any]]) -> Any: ...
+    def __post_serialize__(self, dct: Mapping, options: Optional[Dict[str, Any]]) -> Any: ...
+    @classmethod
+    def __pre_deserialize__(cls: Any, dct: Mapping, options: Optional[Dict[str, Any]]) -> Any: ...
+    def to_dict( self, use_bytes: bool = False, use_enum: bool = False, use_datetime: bool = False, options: Optional[Dict[str, Any]] = None) -> dict: ...
+    @classmethod
+    def from_dict( cls, d: Mapping, use_bytes: bool = False, use_enum: bool = False, use_datetime: bool = False, options: Optional[Dict[str, Any]] = None) -> Any: ...
diff --git a/third-party-stubs/mashumaro/serializer/base/helpers.pyi b/third-party-stubs/mashumaro/serializer/base/helpers.pyi
new file mode 100644
index 00000000000..f286a1a95d1
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/base/helpers.pyi
@@ -0,0 +1,3 @@
+from typing import Any
+
+def parse_timezone(s: str) -> Any: ...
diff --git a/third-party-stubs/mashumaro/serializer/base/metaprogramming.pyi b/third-party-stubs/mashumaro/serializer/base/metaprogramming.pyi
new file mode 100644
index 00000000000..3b6267e8034
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/base/metaprogramming.pyi
@@ -0,0 +1,33 @@
+from mashumaro.meta.helpers import *
+from mashumaro.serializer.base.helpers import *
+from base64 import decodebytes as decodebytes, encodebytes as encodebytes
+from mashumaro.exceptions import InvalidFieldValue as InvalidFieldValue, MissingField as MissingField, UnserializableDataError as UnserializableDataError, UnserializableField as UnserializableField
+from mashumaro.meta.patch import patch_fromisoformat as patch_fromisoformat
+from mashumaro.types import SerializableType as SerializableType, SerializationStrategy as SerializationStrategy
+from typing import Any
+
+NoneType: Any
+INITIAL_MODULES: Any
+
+class CodeBuilder:
+    cls: Any = ...
+    lines: Any = ...
+    modules: Any = ...
+    globals: Any = ...
+    def __init__(self, cls: Any) -> None: ...
+    def reset(self) -> None: ...
+    @property
+    def namespace(self): ...
+    @property
+    def annotations(self): ...
+    @property
+    def fields(self): ...
+    @property
+    def defaults(self): ...
+    def add_line(self, line: Any) -> None: ...
+    def indent(self) -> None: ...
+    def compile(self) -> None: ...
+    def add_from_dict(self) -> None: ...
+    def add_to_dict(self) -> None: ...
+    def add_pack_union(self, fname: Any, ftype: Any, parent: Any, variant_types: Any, value_name: Any): ...
+    def add_unpack_union(self, fname: Any, ftype: Any, parent: Any, variant_types: Any, value_name: Any): ...
diff --git a/third-party-stubs/mashumaro/serializer/json.pyi b/third-party-stubs/mashumaro/serializer/json.pyi
new file mode 100644
index 00000000000..135f1a4616f
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/json.pyi
@@ -0,0 +1,13 @@
+from mashumaro.serializer.base import DataClassDictMixin as DataClassDictMixin
+from typing import Any, Callable, Dict, Mapping, Type, TypeVar, Union
+
+DEFAULT_DICT_PARAMS: Any
+EncodedData = Union[str, bytes, bytearray]
+Encoder = Callable[[Dict], EncodedData]
+Decoder = Callable[[EncodedData], Dict]
+T = TypeVar('T', bound='DataClassJSONMixin')
+
+class DataClassJSONMixin(DataClassDictMixin):
+    def to_json(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ...
+    @classmethod
+    def from_json(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> DataClassDictMixin: ...
diff --git a/third-party-stubs/mashumaro/serializer/msgpack.pyi b/third-party-stubs/mashumaro/serializer/msgpack.pyi
new file mode 100644
index 00000000000..b3372e74c1b
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/msgpack.pyi
@@ -0,0 +1,13 @@
+from mashumaro.serializer.base import DataClassDictMixin as DataClassDictMixin
+from typing import Any, Callable, Dict, Mapping, Type, TypeVar, Union
+
+DEFAULT_DICT_PARAMS: Any
+EncodedData = Union[str, bytes, bytearray]
+Encoder = Callable[[Dict], EncodedData]
+Decoder = Callable[[EncodedData], Dict]
+T = TypeVar('T', bound='DataClassMessagePackMixin')
+
+class DataClassMessagePackMixin(DataClassDictMixin):
+    def to_msgpack(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ...
+    @classmethod
+    def from_msgpack(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> DataClassDictMixin: ...
diff --git a/third-party-stubs/mashumaro/serializer/yaml.pyi b/third-party-stubs/mashumaro/serializer/yaml.pyi
new file mode 100644
index 00000000000..66ccc4077e7
--- /dev/null
+++ b/third-party-stubs/mashumaro/serializer/yaml.pyi
@@ -0,0 +1,13 @@
+from mashumaro.serializer.base import DataClassDictMixin as DataClassDictMixin
+from typing import Any, Callable, Dict, Mapping, Type, TypeVar, Union
+
+DEFAULT_DICT_PARAMS: Any
+EncodedData = Union[str, bytes]
+Encoder = Callable[[Dict], EncodedData]
+Decoder = Callable[[EncodedData], Dict]
+T = TypeVar('T', bound='DataClassYAMLMixin')
+
+class DataClassYAMLMixin(DataClassDictMixin):
+    def to_yaml(self, encoder: Encoder=..., dict_params: Mapping=..., **encoder_kwargs: Any) -> EncodedData: ...
+    @classmethod
+    def from_yaml(cls: Type[T], data: EncodedData, decoder: Decoder=..., dict_params: Mapping=..., **decoder_kwargs: Any) -> T: ...
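The stubs above only add type information for the mashumaro mixins; they do not change runtime behavior. As a quick orientation, here is a small usage sketch of the `to_dict`/`from_dict`/`to_json` surface those `.pyi` files describe. The `ExampleNode` class is hypothetical and exists only to show the signatures.

```python
from dataclasses import dataclass
from typing import Optional

from mashumaro import DataClassJSONMixin  # also provides to_dict/from_dict


@dataclass
class ExampleNode(DataClassJSONMixin):
    # Hypothetical fields, for illustration only.
    name: str
    alias: Optional[str] = None


node = ExampleNode(name='my_model')
as_dict = node.to_dict()                # signature typed by serializer/base/dict.pyi
restored = ExampleNode.from_dict(as_dict)  # classmethod, returns an ExampleNode
as_json = node.to_json()                # signature typed by serializer/json.pyi
assert restored == node
```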
diff --git a/third-party-stubs/mashumaro/types.pyi b/third-party-stubs/mashumaro/types.pyi
new file mode 100644
index 00000000000..6a704f20cc0
--- /dev/null
+++ b/third-party-stubs/mashumaro/types.pyi
@@ -0,0 +1,26 @@
+from typing import Any, Optional, TypeVar, Generic
+
+TV = TypeVar("TV")
+
+class SerializableEncoder(Generic[TV]):
+    @classmethod
+    def _serialize(cls, value): ...
+    @classmethod
+    def _deserialize(cls, value): ...
+
+
+class SerializableType:
+    def _serialize(self): ...
+    @classmethod
+    def _deserialize(cls, value): ...
+
+
+class SerializationStrategy:
+    def _serialize(self, value): ...
+    def _deserialize(self, value): ...
+
+
+class RoundedDecimal(SerializationStrategy):
+    exp: Any = ...
+    rounding: Any = ...
+    def __init__(self, places: Optional[Any] = ..., rounding: Optional[Any] = ...) -> None: ...
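For context on `types.pyi`: `SerializableType` is mashumaro's hook for field types that control their own (de)serialization via `_serialize`/`_deserialize`. The sketch below is illustrative only; `Port` and `Connection` are hypothetical classes, and only the protocol names come from the stub.

```python
from dataclasses import dataclass

from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableType


class Port(SerializableType):
    # Hypothetical field type that serializes itself as a plain int.
    def __init__(self, number: int) -> None:
        self.number = number

    def _serialize(self) -> int:
        return self.number

    @classmethod
    def _deserialize(cls, value: int) -> 'Port':
        return cls(value)


@dataclass
class Connection(DataClassDictMixin):
    host: str
    port: Port  # mashumaro calls Port._serialize/_deserialize for this field


conn = Connection.from_dict({'host': 'localhost', 'port': 5439})
assert conn.to_dict() == {'host': 'localhost', 'port': 5439}
```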