Skip to content

Commit

Permalink
respond to some feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle Wigley committed May 20, 2021
1 parent bacac1b commit 87ed860
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 60 deletions.
17 changes: 1 addition & 16 deletions core/dbt/contracts/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,7 @@ class CompiledSchemaTestNode(CompiledNode, HasTestMetadata):
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity') and
self.unrendered_config.get('where') ==
other.unrendered_config.get('where') and
self.unrendered_config.get('limit') ==
other.unrendered_config.get('limit') and
self.unrendered_config.get('fail_calc') ==
other.unrendered_config.get('fail_calc') and
self.unrendered_config.get('error_if') ==
other.unrendered_config.get('error_if') and
self.unrendered_config.get('warn_if') ==
other.unrendered_config.get('warn_if')
)

# TODO: this is unused
def same_column_name(self, other) -> bool:
return self.column_name == other.column_name

Expand Down
23 changes: 23 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,29 @@ class TestConfig(NodeConfig):
warn_if: str = "!= 0"
error_if: str = "!= 0"

@classmethod
def same_contents(
cls, unrendered: Dict[str, Any], other: Dict[str, Any]
) -> bool:
"""This is like __eq__, except it explicitly checks certain fields."""
modifiers = [
'severity',
'where',
'limit',
'fail_calc',
'warn_if',
'error_if'
]

seen = set()
for _, target_name in cls._get_fields():
key = target_name
seen.add(key)
if key in modifiers:
if not cls.compare_key(unrendered, other, key):
return False
return True


@dataclass
class EmptySnapshotConfig(NodeConfig):
Expand Down
17 changes: 1 addition & 16 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,22 +369,7 @@ class ParsedSchemaTestNode(ParsedNode, HasTestMetadata):
column_name: Optional[str] = None
config: TestConfig = field(default_factory=TestConfig)

def same_config(self, other) -> bool:
return (
self.unrendered_config.get('severity') ==
other.unrendered_config.get('severity') and
self.unrendered_config.get('where') ==
other.unrendered_config.get('where') and
self.unrendered_config.get('limit') ==
other.unrendered_config.get('limit') and
self.unrendered_config.get('fail_calc') ==
other.unrendered_config.get('fail_calc') and
self.unrendered_config.get('error_if') ==
other.unrendered_config.get('error_if') and
self.unrendered_config.get('warn_if') ==
other.unrendered_config.get('warn_if')
)

# TODO: this is unused
def same_column_name(self, other) -> bool:
return self.column_name == other.column_name

Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def skipped(self):
return self.status == RunStatus.Skipped


@dataclass
class TestResult(RunResult):
failures: int = 0


@dataclass
class ExecutionResult(dbtClassMixin):
results: Sequence[BaseResult]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
{% call statement('main', fetch_result=True) -%}

select
{{ fail_calc }} as validation_errors,
{{ fail_calc }} as failures,
{{ fail_calc }} {{ warn_if }} as should_warn,
{{ fail_calc }} {{ error_if }} as should_error,
'{{ warn_if }}' as warn_if,
'{{ error_if }}' as error_if
{{ fail_calc }} {{ error_if }} as should_error
from (
{{ sql }}
{{ "limit " ~ limit if limit }}
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/schema_test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def severity(self) -> Optional[str]:
def where(self) -> Optional[str]:
return self.modifiers.get('where')

def limit(self) -> Optional[str]:
def limit(self) -> Optional[int]:
return self.modifiers.get('limit')

def warn_if(self) -> Optional[str]:
Expand Down
8 changes: 1 addition & 7 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,6 @@ def get_unique_id_mapping(
return node_map, source_map


def _coerce_decimal(value):
if isinstance(value, dbt.utils.DECIMALS):
return float(value)
return value


class GenerateTask(CompileTask):
def _get_manifest(self) -> Manifest:
if self.manifest is None:
Expand Down Expand Up @@ -248,7 +242,7 @@ def run(self) -> CatalogArtifact:
catalog_table, exceptions = adapter.get_catalog(self.manifest)

catalog_data: List[PrimitiveDict] = [
dict(zip(catalog_table.column_names, map(_coerce_decimal, row)))
dict(zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row)))
for row in catalog_table
]

Expand Down
36 changes: 20 additions & 16 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
CompiledTestNode,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import RunResult, TestStatus, PrimitiveDict
from dbt.contracts.results import TestResult, TestStatus, PrimitiveDict
from dbt.context.providers import generate_runtime_model
from dbt.clients.jinja import MacroGenerator
from dbt.exceptions import (
Expand All @@ -31,12 +31,10 @@


@dataclass
class TestResult(dbtClassMixin):
validation_errors: int
class TestResultData(dbtClassMixin):
failures: int
should_warn: bool
should_error: bool
warn_if: str
error_if: str


class TestRunner(CompileRunner):
Expand All @@ -60,7 +58,7 @@ def execute_test(
self,
test: Union[CompiledDataTestNode, CompiledSchemaTestNode],
manifest: Manifest
) -> TestResult:
) -> TestResultData:
context = generate_runtime_model(
test, self.config, manifest
)
Expand Down Expand Up @@ -97,45 +95,51 @@ def execute_test(
f"1 row"
)
num_cols = len(table.columns)
if num_cols != 5:
if num_cols != 3:
raise InternalException(
f"dbt internally failed to execute {test.unique_id}: "
f"Returned {num_cols} columns, but expected "
f"5 columns"
f"3 columns"
)

test_result_data: PrimitiveDict = dict(zip(table.column_names, table.rows[0]))

return TestResult.from_dict(test_result_data)
test_result_dct: PrimitiveDict = dict(
zip(
[column_name.lower() for column_name in table.column_names],
map(utils._coerce_decimal, table.rows[0])
)
)
TestResultData.validate(test_result_dct)
return TestResultData.from_dict(test_result_dct)

def execute(self, test: CompiledTestNode, manifest: Manifest):
result = self.execute_test(test, manifest)

severity = test.config.severity.upper()
thread_id = threading.current_thread().name
num_errors = utils.pluralize(result.validation_errors, 'result')
num_errors = utils.pluralize(result.failures, 'result')
status = None
message = None
if severity == "ERROR" and result.should_error:
status = TestStatus.Fail
message = f'Got {num_errors}, configured to fail if {result.error_if}'
message = f'Got {num_errors}, configured to fail if {test.config.error_if}'
elif result.should_warn:
if flags.WARN_ERROR:
status = TestStatus.Fail
message = f'Got {num_errors}, configured to fail if {result.warn_if}'
message = f'Got {num_errors}, configured to fail if {test.config.warn_if}'
else:
status = TestStatus.Warn
message = f'Got {num_errors}, configured to warn if {result.warn_if}'
message = f'Got {num_errors}, configured to warn if {test.config.warn_if}'
else:
status = TestStatus.Pass

return RunResult(
return TestResult(
node=test,
status=status,
timing=[],
thread_id=thread_id,
execution_time=0,
message=message,
failures=result.failures,
adapter_response={},
)

Expand Down
6 changes: 6 additions & 0 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,12 @@ def coerce_dict_str(value: Any) -> Optional[Dict[str, Any]]:
return None


def _coerce_decimal(value):
if isinstance(value, DECIMALS):
return float(value)
return value


def lowercase(value: Optional[str]) -> Optional[str]:
if value is None:
return None
Expand Down

0 comments on commit 87ed860

Please sign in to comment.