diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index a69a58ccad9..15a96f34c83 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -119,22 +119,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata): column_name: Optional[str] = None config: TestConfig = field(default_factory=TestConfig) - def same_config(self, other) -> bool: - return ( - self.unrendered_config.get('severity') == - other.unrendered_config.get('severity') and - self.unrendered_config.get('where') == - other.unrendered_config.get('where') and - self.unrendered_config.get('limit') == - other.unrendered_config.get('limit') and - self.unrendered_config.get('fail_calc') == - other.unrendered_config.get('fail_calc') and - self.unrendered_config.get('error_if') == - other.unrendered_config.get('error_if') and - self.unrendered_config.get('warn_if') == - other.unrendered_config.get('warn_if') - ) - + # TODO: this is unused def same_column_name(self, other) -> bool: return self.column_name == other.column_name diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 56452eff69e..ff32b54987d 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -444,6 +444,29 @@ class TestConfig(NodeConfig): warn_if: str = "!= 0" error_if: str = "!= 0" + @classmethod + def same_contents( + cls, unrendered: Dict[str, Any], other: Dict[str, Any] + ) -> bool: + """This is like __eq__, except it explicitly checks certain fields.""" + modifiers = [ + 'severity', + 'where', + 'limit', + 'fail_calc', + 'warn_if', + 'error_if' + ] + + seen = set() + for _, target_name in cls._get_fields(): + key = target_name + seen.add(key) + if key in modifiers: + if not cls.compare_key(unrendered, other, key): + return False + return True + @dataclass class EmptySnapshotConfig(NodeConfig): diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 713e80788e3..60a8c436133 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -369,22 +369,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata): column_name: Optional[str] = None config: TestConfig = field(default_factory=TestConfig) - def same_config(self, other) -> bool: - return ( - self.unrendered_config.get('severity') == - other.unrendered_config.get('severity') and - self.unrendered_config.get('where') == - other.unrendered_config.get('where') and - self.unrendered_config.get('limit') == - other.unrendered_config.get('limit') and - self.unrendered_config.get('fail_calc') == - other.unrendered_config.get('fail_calc') and - self.unrendered_config.get('error_if') == - other.unrendered_config.get('error_if') and - self.unrendered_config.get('warn_if') == - other.unrendered_config.get('warn_if') - ) - + # TODO: this is unused def same_column_name(self, other) -> bool: return self.column_name == other.column_name diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 9bb11bc3ade..7076fea29fb 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -122,6 +122,11 @@ def skipped(self): return self.status == RunStatus.Skipped +@dataclass +class TestResult(RunResult): + failures: int = 0 + + @dataclass class ExecutionResult(dbtClassMixin): results: Sequence[BaseResult] diff --git a/core/dbt/include/global_project/macros/materializations/test.sql b/core/dbt/include/global_project/macros/materializations/test.sql index 6cd3e5e564d..dc1ca93de19 100644 --- a/core/dbt/include/global_project/macros/materializations/test.sql +++ b/core/dbt/include/global_project/macros/materializations/test.sql @@ -8,11 +8,9 @@ {% call statement('main', fetch_result=True) -%} select - {{ fail_calc }} as validation_errors, + {{ fail_calc }} as failures, {{ fail_calc }} {{ warn_if }} as should_warn, - {{ fail_calc }} {{ error_if }} as should_error, - '{{ warn_if }}' as warn_if, - '{{ error_if }}' as error_if + {{ fail_calc }} {{ error_if }} as should_error from ( {{ sql }} {{ "limit " ~ limit if limit }} diff --git a/core/dbt/parser/schema_test_builders.py b/core/dbt/parser/schema_test_builders.py index 735f97e9291..28842caa8f2 100644 --- a/core/dbt/parser/schema_test_builders.py +++ b/core/dbt/parser/schema_test_builders.py @@ -283,7 +283,7 @@ def severity(self) -> Optional[str]: def where(self) -> Optional[str]: return self.modifiers.get('where') - def limit(self) -> Optional[str]: + def limit(self) -> Optional[int]: return self.modifiers.get('limit') def warn_if(self) -> Optional[str]: diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index c36e84fef4c..de9134ed7fa 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -190,12 +190,6 @@ def get_unique_id_mapping( return node_map, source_map -def _coerce_decimal(value): - if isinstance(value, dbt.utils.DECIMALS): - return float(value) - return value - - class GenerateTask(CompileTask): def _get_manifest(self) -> Manifest: if self.manifest is None: @@ -248,7 +242,7 @@ def run(self) -> CatalogArtifact: catalog_table, exceptions = adapter.get_catalog(self.manifest) catalog_data: List[PrimitiveDict] = [ - dict(zip(catalog_table.column_names, map(_coerce_decimal, row))) + dict(zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row))) for row in catalog_table ] diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index eef3e26062c..c5b5bc342c9 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -14,7 +14,7 @@ CompiledTestNode, ) from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.results import RunResult, TestStatus, PrimitiveDict +from dbt.contracts.results import TestResult, TestStatus, PrimitiveDict from dbt.context.providers import generate_runtime_model from dbt.clients.jinja import MacroGenerator from dbt.exceptions import ( @@ -31,12 +31,10 @@ @dataclass -class TestResult(dbtClassMixin): - validation_errors: int +class TestResultData(dbtClassMixin): + failures: int should_warn: bool should_error: bool - warn_if: str - error_if: str class TestRunner(CompileRunner): @@ -60,7 +58,7 @@ def execute_test( self, test: Union[CompiledDataTestNode, CompiledSchemaTestNode], manifest: Manifest - ) -> TestResult: + ) -> TestResultData: context = generate_runtime_model( test, self.config, manifest ) @@ -97,45 +95,51 @@ def execute_test( f"1 row" ) num_cols = len(table.columns) - if num_cols != 5: + if num_cols != 3: raise InternalException( f"dbt internally failed to execute {test.unique_id}: " f"Returned {num_cols} columns, but expected " - f"5 columns" + f"3 columns" ) - test_result_data: PrimitiveDict = dict(zip(table.column_names, table.rows[0])) - - return TestResult.from_dict(test_result_data) + test_result_dct: PrimitiveDict = dict( + zip( + [column_name.lower() for column_name in table.column_names], + map(utils._coerce_decimal, table.rows[0]) + ) + ) + TestResultData.validate(test_result_dct) + return TestResultData.from_dict(test_result_dct) def execute(self, test: CompiledTestNode, manifest: Manifest): result = self.execute_test(test, manifest) severity = test.config.severity.upper() thread_id = threading.current_thread().name - num_errors = utils.pluralize(result.validation_errors, 'result') + num_errors = utils.pluralize(result.failures, 'result') status = None message = None if severity == "ERROR" and result.should_error: status = TestStatus.Fail - message = f'Got {num_errors}, configured to fail if {result.error_if}' + message = f'Got {num_errors}, configured to fail if {test.config.error_if}' elif result.should_warn: if flags.WARN_ERROR: status = TestStatus.Fail - message = f'Got {num_errors}, configured to fail if {result.warn_if}' + message = f'Got {num_errors}, configured to fail if {test.config.warn_if}' else: status = TestStatus.Warn - message = f'Got {num_errors}, configured to warn if {result.warn_if}' + message = f'Got {num_errors}, configured to warn if {test.config.warn_if}' else: status = TestStatus.Pass - return RunResult( + return TestResult( node=test, status=status, timing=[], thread_id=thread_id, execution_time=0, message=message, + failures=result.failures, adapter_response={}, ) diff --git a/core/dbt/utils.py b/core/dbt/utils.py index def02a55f6c..ffed4d6d074 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -420,6 +420,12 @@ def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]: return None +def _coerce_decimal(value): + if isinstance(value, DECIMALS): + return float(value) + return value + + def lowercase(value: Optional[str]) -> Optional[str]: if value is None: return None