Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New test configs: where, limit, warn_if, error_if, fail_calc #3336

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## dbt 0.20.0 (Release TBD)

### Features
- Add new test configs: `where`, `limit`, `warn_if`, `error_if`, `fail_calc` ([#3258](https://github.com/fishtown-analytics/dbt/issues/3258),[#3321](https://github.com/fishtown-analytics/dbt/issues/3321), [#3336](https://github.com/fishtown-analytics/dbt/pull/3336))
- Support optional `updated_at` config parameter with `check` strategy snapshots. If not supplied, will use current timestamp (default). ([#1844](https://github.com/fishtown-analytics/dbt/issues/1844), [#3376](https://github.com/fishtown-analytics/dbt/pull/3376))
- Add the opt-in `--use-experimental-parser` flag ([#3307](https://github.com/fishtown-analytics/dbt/issues/3307))

Expand Down
7 changes: 1 addition & 6 deletions core/dbt/contracts/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
)

# TODO: this is unused
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes... indeed it is. I confirmed that dbt test -m state:modified works as expected, leveraging the unrendered_config.

def same_column_name(self, other) -> bool:
return self.column_name == other.column_name

Expand Down
28 changes: 28 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,34 @@ class SeedConfig(NodeConfig):
class TestConfig(NodeConfig):
materialized: str = 'test'
severity: Severity = Severity('ERROR')
where: Optional[str] = None
limit: Optional[int] = None
fail_calc: str = "count(*)"
warn_if: str = "!= 0"
error_if: str = "!= 0"

@classmethod
def same_contents(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I see now that this is the proper place for it. Nice work, this logic is much clearer.

cls, unrendered: Dict[str, Any], other: Dict[str, Any]
) -> bool:
"""This is like __eq__, except it explicitly checks certain fields."""
modifiers = [
'severity',
'where',
'limit',
'fail_calc',
'warn_if',
'error_if'
]

seen = set()
for _, target_name in cls._get_fields():
key = target_name
seen.add(key)
if key in modifiers:
if not cls.compare_key(unrendered, other, key):
return False
return True


@dataclass
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,12 +369,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity')
)

# TODO: this is unused
def same_column_name(self, other) -> bool:
return self.column_name == other.column_name

Expand Down
7 changes: 6 additions & 1 deletion core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,15 @@ class BaseResult(dbtClassMixin):
execution_time: float
adapter_response: Dict[str, Any]
message: Optional[Union[str, int]]
failures: Optional[int]

@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
if 'message' not in data:
data['message'] = None
if 'failures' not in data:
data['failures'] = None
return data


Expand Down Expand Up @@ -157,7 +160,8 @@ def process_run_result(result: RunResult) -> RunResultOutput:
thread_id=result.thread_id,
execution_time=result.execution_time,
message=result.message,
adapter_response=result.adapter_response
adapter_response=result.adapter_response,
failures=result.failures
)


Expand Down Expand Up @@ -378,6 +382,7 @@ def from_result(cls, base: FreshnessResult):


Primitive = Union[bool, str, float, None]
PrimitiveDict = Dict[str, Primitive]

CatalogKey = NamedTuple(
'CatalogKey',
Expand Down
15 changes: 13 additions & 2 deletions core/dbt/include/global_project/macros/materializations/test.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
{%- materialization test, default -%}

{% set limit = config.get('limit') %}
{% set fail_calc = config.get('fail_calc') %}
{% set warn_if = config.get('warn_if') %}
{% set error_if = config.get('error_if') %}

{% call statement('main', fetch_result=True) -%}
select count(*) as validation_errors

select
{{ fail_calc }} as failures,
{{ fail_calc }} {{ warn_if }} as should_warn,
{{ fail_calc }} {{ error_if }} as should_error
from (
{{ sql }}
{{ "limit " ~ limit if limit }}
) _dbt_internal_test
{%- endcall %}

{% endcall %}

{%- endmaterialization -%}
32 changes: 26 additions & 6 deletions core/dbt/parser/schema_test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ class TestBuilder(Generic[Testable]):
r'(?P<test_name>([a-zA-Z_][0-9a-zA-Z_]*))'
)
# kwargs representing test configs
MODIFIER_ARGS = ('severity', 'tags', 'enabled')
MODIFIER_ARGS = (
'severity', 'tags', 'enabled', 'where', 'limit', 'warn_if', 'error_if', 'fail_calc'
)

def __init__(
self,
Expand Down Expand Up @@ -278,6 +280,21 @@ def severity(self) -> Optional[str]:
else:
return None

def where(self) -> Optional[str]:
return self.modifiers.get('where')

def limit(self) -> Optional[int]:
return self.modifiers.get('limit')

def warn_if(self) -> Optional[str]:
return self.modifiers.get('warn_if')

def error_if(self) -> Optional[str]:
return self.modifiers.get('error_if')

def fail_calc(self) -> Optional[str]:
return self.modifiers.get('fail_calc')

def tags(self) -> List[str]:
tags = self.modifiers.get('tags', [])
if isinstance(tags, str):
Expand Down Expand Up @@ -334,10 +351,13 @@ def build_raw_sql(self) -> str:
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to reorder "{config}{{{{ {macro}(**{kwargs_name}) }}}}" to be {{{{ {macro}(**{kwargs_name}) }}}}{config}, so that specific configs (supplied by the user in the modifiers) override the generic configs set in the test macro.

  # this is the 'raw_sql' that's used in 'render_update' and execution
  # of the test macro
  # config needs to be rendered last to take precedence over default configs
  # set in the macro (test definition)
  def build_raw_sql(self) -> str:
      return (
          "{{{{ {macro}(**{kwargs_name}) }}}}{config}"
      ).format(
          macro=self.macro_name(),
          config=self.construct_config(),
          kwargs_name=SCHEMA_TEST_KWARGS_NAME,
      )


def build_model_str(self):
targ = self.target
cfg_where = "config.get('where')"
if isinstance(self.target, UnparsedNodeUpdate):
fmt = "{{{{ ref('{0.name}') }}}}"
identifier = self.target.name
target_str = f"{{{{ ref('{targ.name}') }}}}"
elif isinstance(self.target, UnpatchedSourceDefinition):
fmt = "{{{{ source('{0.source.name}', '{0.table.name}') }}}}"
else:
raise self._bad_type()
return fmt.format(self.target)
identifier = self.target.table.name
target_str = f"{{{{ source('{targ.source.name}', '{targ.table.name}') }}}}"
filtered = f"(select * from {target_str} where {{{{{cfg_where}}}}}) {identifier}"
return f"{{% if {cfg_where} %}}{filtered}{{% else %}}{target_str}{{% endif %}}"
17 changes: 17 additions & 0 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,28 @@ def render_test_update(self, node, config, builder):
if (macro_unique_id in
['macro.dbt.test_not_null', 'macro.dbt.test_unique']):
self.update_parsed_node(node, config)
# manually set configs
if builder.severity() is not None:
node.unrendered_config['severity'] = builder.severity()
node.config['severity'] = builder.severity()
if builder.enabled() is not None:
node.unrendered_config['enabled'] = builder.enabled()
node.config['enabled'] = builder.enabled()
if builder.where() is not None:
node.unrendered_config['where'] = builder.where()
node.config['where'] = builder.where()
if builder.limit() is not None:
node.unrendered_config['limit'] = builder.limit()
node.config['limit'] = builder.limit()
if builder.warn_if() is not None:
node.unrendered_config['warn_if'] = builder.warn_if()
node.config['warn_if'] = builder.warn_if()
if builder.error_if() is not None:
node.unrendered_config['error_if'] = builder.error_if()
node.config['error_if'] = builder.error_if()
if builder.fail_calc() is not None:
node.unrendered_config['fail_calc'] = builder.fail_calc()
node.config['fail_calc'] = builder.fail_calc()
# source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
Expand Down
11 changes: 7 additions & 4 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def run_with_hooks(self, manifest):
return result

def _build_run_result(self, node, start_time, status, timing_info, message,
agate_table=None, adapter_response=None):
agate_table=None, adapter_response=None, failures=None):
execution_time = time.time() - start_time
thread_id = threading.current_thread().name
if adapter_response is None:
Expand All @@ -227,7 +227,8 @@ def _build_run_result(self, node, start_time, status, timing_info, message,
message=message,
node=node,
agate_table=agate_table,
adapter_response=adapter_response
adapter_response=adapter_response,
failures=failures
)

def error_result(self, node, message, start_time, timing_info):
Expand Down Expand Up @@ -256,7 +257,8 @@ def from_run_result(self, result, start_time, timing_info):
timing_info=timing_info,
message=result.message,
agate_table=result.agate_table,
adapter_response=result.adapter_response
adapter_response=result.adapter_response,
failures=result.failures
)

def skip_result(self, node, message):
Expand All @@ -268,7 +270,8 @@ def skip_result(self, node, message):
timing=[],
message=message,
node=node,
adapter_response={}
adapter_response={},
failures=None
)

def compile_and_execute(self, manifest, ctx):
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def execute(self, compiled_node, manifest):
thread_id=threading.current_thread().name,
execution_time=0,
message=None,
adapter_response={}
adapter_response={},
failures=None
)

def compile(self, manifest):
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/task/freshness.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def _build_run_result(
timing=timing_info,
message=message,
node=node,
adapter_response={}
adapter_response={},
failures=None,
)

def from_run_result(self, result, start_time, timing_info):
Expand Down Expand Up @@ -104,6 +105,7 @@ def execute(self, compiled_node, manifest):
execution_time=0,
message=None,
adapter_response={},
failures=None,
**freshness
)

Expand Down
13 changes: 2 additions & 11 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import (
NodeStatus, TableMetadata, CatalogTable, CatalogResults, Primitive,
NodeStatus, TableMetadata, CatalogTable, CatalogResults, PrimitiveDict,
CatalogKey, StatsItem, StatsDict, ColumnMetadata, CatalogArtifact
)
from dbt.exceptions import InternalException
Expand All @@ -37,9 +37,6 @@ def get_stripped_prefix(source: Dict[str, Any], prefix: str) -> Dict[str, Any]:
}


PrimitiveDict = Dict[str, Primitive]


def build_catalog_table(data) -> CatalogTable:
# build the new table's metadata + stats
metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_'))
Expand Down Expand Up @@ -193,12 +190,6 @@ def get_unique_id_mapping(
return node_map, source_map


def _coerce_decimal(value):
if isinstance(value, dbt.utils.DECIMALS):
return float(value)
return value


class GenerateTask(CompileTask):
def _get_manifest(self) -> Manifest:
if self.manifest is None:
Expand Down Expand Up @@ -251,7 +242,7 @@ def run(self) -> CatalogArtifact:
catalog_table, exceptions = adapter.get_catalog(self.manifest)

catalog_data: List[PrimitiveDict] = [
dict(zip(catalog_table.column_names, map(_coerce_decimal, row)))
dict(zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row)))
for row in catalog_table
]

Expand Down
18 changes: 7 additions & 11 deletions core/dbt/task/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dbt import utils

from dbt.contracts.results import (
FreshnessStatus, NodeResult, NodeStatus, TestStatus
FreshnessStatus, NodeStatus, TestStatus
)


Expand Down Expand Up @@ -115,7 +115,7 @@ def get_printable_result(


def print_test_result_line(
result: NodeResult, schema_name, index: int, total: int
result, index: int, total: int
) -> None:
model = result.node

Expand All @@ -128,11 +128,11 @@ def print_test_result_line(
color = ui.green
logger_fn = logger.info
elif result.status == TestStatus.Warn:
info = 'WARN {}'.format(result.message)
info = f'WARN {result.failures}'
color = ui.yellow
logger_fn = logger.warning
elif result.status == TestStatus.Fail:
info = 'FAIL {}'.format(result.message)
info = f'FAIL {result.failures}'
color = ui.red
logger_fn = logger.error
else:
Expand Down Expand Up @@ -291,14 +291,10 @@ def print_run_result_error(
result.node.name,
result.node.original_file_path))

try:
# if message is int, must be rows returned for a test
int(result.message)
except ValueError:
logger.error(" Status: {}".format(result.status))
if result.message:
logger.error(f" {result.message}")
else:
num_rows = utils.pluralize(result.message, 'result')
logger.error(" Got {}, expected 0.".format(num_rows))
logger.error(f" Status: {result.status}")

if result.node.build_path is not None:
with TextOnly():
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def _build_run_model_result(self, model, context):
thread_id=threading.current_thread().name,
execution_time=0,
message=str(result.response),
adapter_response=adapter_response
adapter_response=adapter_response,
failures=result.get('failures')
)

def _materialization_relations(
Expand Down
Loading