From 40d3a4fba0a61479d296d55dbd9b126b78ae8351 Mon Sep 17 00:00:00 2001 From: Mikael Rapp Date: Fri, 18 Jun 2021 15:57:31 +0200 Subject: [PATCH] Add model_sql helper to enable isolated unit testing --- core/dbt/context/providers.py | 73 +++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 24c265c83da..d37b1b1f56c 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -274,6 +274,31 @@ def __call__(self, *args: str) -> RelationProxy: return self.resolve(args[0], args[1]) +# RefSourceOverride is used to conditionally override the behaviour of +# the ref() and source() macros +class RefSourceOverride(): + def __init__( + self, + base: Union[BaseRefResolver, BaseSourceResolver], + overrides: Dict[str, Any] + ) -> None: + self.base = base + self.overrides = overrides + + def __call__(self, *args: str) -> Any: + var_name = "" + if len(args) == 1: + var_name = args[0] + elif len(args) == 2: + package, name = args + var_name = package + "." + name + + if var_name in self.overrides: + return self.overrides[var_name] + + return self.base(*args) + + class Config(Protocol): def __init__(self, model, context_config: Optional[ContextConfig]): ... @@ -582,6 +607,22 @@ def _generate_merged(self) -> Mapping[str, Any]: return merged +# VarOverride is used to conditionally override the behaviour of the var() macro +class VarOverride(): + def __init__( + self, + base: ModelConfiguredVar, + overrides: Dict[str, Any] + ) -> None: + self.base = base + self.overrides = overrides + + def __call__(self, var_name, default): + if var_name in self.overrides: + return self.overrides[var_name] + return self.base(var_name, default) + + class ParseVar(ModelConfiguredVar): def get_missing_var(self, var_name): # in the parser, just always return None. @@ -1133,6 +1174,38 @@ def graph(self) -> Dict[str, Any]: def ctx_model(self) -> Dict[str, Any]: return self.model.to_dict(omit_none=True) + @contextmember + def model_sql(self, model_name: str, overrides: Dict[str, Any]) -> str: + """The `model_sql` method wil return compiled SQL code of the model without + wrapping CTEs arount it. The optional second argument allow for + """ + + if not isinstance(self.manifest, Manifest): + return "not valid manifest" + + ctx = self.to_dict() + if "var" in overrides: + ctx["var"] = VarOverride(ctx["var"], overrides["var"]) + + if "ref" in overrides: + ctx["ref"] = RefSourceOverride(ctx["ref"], overrides["ref"]) + + if "source" in overrides: + ctx["source"] = RefSourceOverride(ctx["source"], overrides["source"]) + + if not model_name in self.manifest.nodes: + raise ValidationException( + 'model_name in cte() call not found; "{}" to be one of {}' + .format(model_name, ','.join(map(str, self.manifest.nodes.keys())))) + + node = self.manifest.nodes[model_name] + + return get_rendered( + string=node.raw_sql, + ctx=ctx, + node=node, + ) + @contextproperty def pre_hooks(self) -> Optional[List[Dict[str, Any]]]: return None