Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model_sql helper to enable isolated unit testing #3476

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,31 @@ def __call__(self, *args: str) -> RelationProxy:
return self.resolve(args[0], args[1])


# RefSourceOverride is used to conditionally override the behaviour of
# the ref() and source() macros
class RefSourceOverride():
def __init__(
self,
base: Union[BaseRefResolver, BaseSourceResolver],
overrides: Dict[str, Any]
) -> None:
self.base = base
self.overrides = overrides

def __call__(self, *args: str) -> Any:
var_name = ""
if len(args) == 1:
var_name = args[0]
elif len(args) == 2:
package, name = args
var_name = package + "." + name

if var_name in self.overrides:
return self.overrides[var_name]

return self.base(*args)


class Config(Protocol):
def __init__(self, model, context_config: Optional[ContextConfig]):
...
Expand Down Expand Up @@ -582,6 +607,22 @@ def _generate_merged(self) -> Mapping[str, Any]:
return merged


# VarOverride is used to conditionally override the behaviour of the var() macro
class VarOverride():
def __init__(
self,
base: ModelConfiguredVar,
overrides: Dict[str, Any]
) -> None:
self.base = base
self.overrides = overrides

def __call__(self, var_name, default):
if var_name in self.overrides:
return self.overrides[var_name]
return self.base(var_name, default)


class ParseVar(ModelConfiguredVar):
def get_missing_var(self, var_name):
# in the parser, just always return None.
Expand Down Expand Up @@ -1133,6 +1174,38 @@ def graph(self) -> Dict[str, Any]:
def ctx_model(self) -> Dict[str, Any]:
return self.model.to_dict(omit_none=True)

@contextmember
def model_sql(self, model_name: str, overrides: Dict[str, Any]) -> str:
"""The `model_sql` method wil return compiled SQL code of the model without
wrapping CTEs arount it. The optional second argument allow for
"""

if not isinstance(self.manifest, Manifest):
return "not valid manifest"

ctx = self.to_dict()
if "var" in overrides:
ctx["var"] = VarOverride(ctx["var"], overrides["var"])

if "ref" in overrides:
ctx["ref"] = RefSourceOverride(ctx["ref"], overrides["ref"])

if "source" in overrides:
ctx["source"] = RefSourceOverride(ctx["source"], overrides["source"])

if not model_name in self.manifest.nodes:
raise ValidationException(
'model_name in cte() call not found; "{}" to be one of {}'
.format(model_name, ','.join(map(str, self.manifest.nodes.keys()))))

node = self.manifest.nodes[model_name]

return get_rendered(
string=node.raw_sql,
ctx=ctx,
node=node,
)

@contextproperty
def pre_hooks(self) -> Optional[List[Dict[str, Any]]]:
return None
Expand Down