Skip to content

Commit

Permalink
Move node patch method to schema parser patch_node_properties and ref…
Browse files Browse the repository at this point in the history
…actor schema parsing (#7640)
  • Loading branch information
gshank authored May 17, 2023
1 parent 4f249b8 commit 5f7ae2f
Show file tree
Hide file tree
Showing 14 changed files with 1,112 additions and 1,226 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20230516-094241.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Move node patch method to schema parser patch_node_properties and refactor schema
parsing
time: 2023-05-16T09:42:41.793503-04:00
custom:
Author: gshank
Issue: "7430"
9 changes: 8 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ResultNode,
BaseNode,
ManifestOrPublicNode,
ModelNode,
)
from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion, UnparsedVersion
from dbt.contracts.graph.manifest_upgrade import upgrade_manifest_json
Expand Down Expand Up @@ -188,7 +189,13 @@ def find(
# If this is an unpinned ref (no 'version' arg was passed),
# AND this is a versioned node,
# AND this ref is being resolved at runtime -- get_node_info != {}
if version is None and node.is_versioned and get_node_info():
# Only ModelNodes can be versioned.
if (
isinstance(node, ModelNode)
and version is None
and node.is_versioned
and get_node_info()
):
# Check to see if newer versions are available, and log an "FYI" if so
max_version: UnparsedVersion = max(
[
Expand Down
69 changes: 2 additions & 67 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@
)
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.events.functions import warn_or_error
from dbt.exceptions import ParsingError, InvalidAccessTypeError, ContractBreakingChangeError
from dbt.exceptions import ParsingError, ContractBreakingChangeError
from dbt.events.types import (
SeedIncreased,
SeedExceedsLimitSamePath,
SeedExceedsLimitAndPathChanged,
SeedExceedsLimitChecksumChanged,
ValidationWarning,
)
from dbt.events.contextvars import set_contextvars
from dbt.flags import get_flags
Expand Down Expand Up @@ -443,63 +442,6 @@ def same_contract(self, old, adapter_type=None) -> bool:
# This would only apply to seeds
return True

def patch(self, patch: "ParsedNodePatch"):
"""Given a ParsedNodePatch, add the new information to the node."""
# NOTE: Constraint patching is awkwardly done in the parse_patch function
# which calls this one. We need to combine the logic.

# explicitly pick out the parts to update so we don't inadvertently
# step on the model name or anything
# Note: config should already be updated
self.patch_path: Optional[str] = patch.file_id
# update created_at so process_docs will run in partial parsing
self.created_at = time.time()
self.description = patch.description
self.columns = patch.columns
self.name = patch.name

# TODO: version, latest_version, and access are specific to ModelNodes, consider splitting out to ModelNode
if self.resource_type != NodeType.Model:
if patch.version:
warn_or_error(
ValidationWarning(
field_name="version",
resource_type=self.resource_type.value,
node_name=patch.name,
)
)
if patch.latest_version:
warn_or_error(
ValidationWarning(
field_name="latest_version",
resource_type=self.resource_type.value,
node_name=patch.name,
)
)
self.version = patch.version
self.latest_version = patch.latest_version

# This might not be the ideal place to validate the "access" field,
# but at this point we have the information we need to properly
# validate and we don't before this.
if patch.access:
if self.resource_type == NodeType.Model:
if AccessType.is_valid(patch.access):
self.access = AccessType(patch.access)
else:
raise InvalidAccessTypeError(
unique_id=self.unique_id,
field_value=patch.access,
)
else:
warn_or_error(
ValidationWarning(
field_name="access",
resource_type=self.resource_type.value,
node_name=patch.name,
)
)

def same_contents(self, old, adapter_type) -> bool:
if old is None:
return False
Expand Down Expand Up @@ -1015,14 +957,6 @@ class Macro(BaseNode):
created_at: float = field(default_factory=lambda: time.time())
supported_languages: Optional[List[ModelLanguage]] = None

def patch(self, patch: "ParsedMacroPatch"):
self.patch_path: Optional[str] = patch.file_id
self.description = patch.description
self.created_at = time.time()
self.meta = patch.meta
self.docs = patch.docs
self.arguments = patch.arguments

def same_contents(self, other: Optional["Macro"]) -> bool:
if other is None:
return False
Expand Down Expand Up @@ -1466,6 +1400,7 @@ class ParsedNodePatch(ParsedPatch):
access: Optional[str]
version: Optional[NodeVersion]
latest_version: Optional[NodeVersion]
constraints: List[Dict[str, Any]]


@dataclass
Expand Down
222 changes: 222 additions & 0 deletions core/dbt/parser/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from dbt.contracts.graph.unparsed import (
HasColumnProps,
UnparsedColumn,
UnparsedNodeUpdate,
UnparsedMacroUpdate,
UnparsedAnalysisUpdate,
UnparsedExposure,
UnparsedModelUpdate,
)
from dbt.contracts.graph.unparsed import NodeVersion, HasColumnTests, HasColumnDocs
from dbt.contracts.graph.nodes import (
UnpatchedSourceDefinition,
ColumnInfo,
ColumnLevelConstraint,
ConstraintType,
)
from dbt.parser.search import FileBlock
from typing import List, Dict, Any, TypeVar, Generic, Union, Optional
from dataclasses import dataclass
from dbt.exceptions import DbtInternalError, ParsingError


def trimmed(inp: str) -> str:
if len(inp) < 50:
return inp
return inp[:44] + "..." + inp[-3:]


TestDef = Union[str, Dict[str, Any]]


Target = TypeVar(
"Target",
UnparsedNodeUpdate,
UnparsedMacroUpdate,
UnparsedAnalysisUpdate,
UnpatchedSourceDefinition,
UnparsedExposure,
UnparsedModelUpdate,
)


ColumnTarget = TypeVar(
"ColumnTarget",
UnparsedModelUpdate,
UnparsedNodeUpdate,
UnparsedAnalysisUpdate,
UnpatchedSourceDefinition,
)

Versioned = TypeVar("Versioned", bound=UnparsedModelUpdate)

Testable = TypeVar("Testable", UnparsedNodeUpdate, UnpatchedSourceDefinition, UnparsedModelUpdate)


@dataclass
class YamlBlock(FileBlock):
data: Dict[str, Any]

@classmethod
def from_file_block(cls, src: FileBlock, data: Dict[str, Any]):
return cls(
file=src.file,
data=data,
)


@dataclass
class TargetBlock(YamlBlock, Generic[Target]):
target: Target

@property
def name(self):
return self.target.name

@property
def columns(self):
return []

@property
def tests(self) -> List[TestDef]:
return []

@classmethod
def from_yaml_block(cls, src: YamlBlock, target: Target) -> "TargetBlock[Target]":
return cls(
file=src.file,
data=src.data,
target=target,
)


@dataclass
class TargetColumnsBlock(TargetBlock[ColumnTarget], Generic[ColumnTarget]):
@property
def columns(self):
if self.target.columns is None:
return []
else:
return self.target.columns


@dataclass
class TestBlock(TargetColumnsBlock[Testable], Generic[Testable]):
@property
def tests(self) -> List[TestDef]:
if self.target.tests is None:
return []
else:
return self.target.tests

@property
def quote_columns(self) -> Optional[bool]:
return self.target.quote_columns

@classmethod
def from_yaml_block(cls, src: YamlBlock, target: Testable) -> "TestBlock[Testable]":
return cls(
file=src.file,
data=src.data,
target=target,
)


@dataclass
class VersionedTestBlock(TestBlock, Generic[Versioned]):
@property
def columns(self):
if not self.target.versions:
return super().columns
else:
raise DbtInternalError(".columns for VersionedTestBlock with versions")

@property
def tests(self) -> List[TestDef]:
if not self.target.versions:
return super().tests
else:
raise DbtInternalError(".tests for VersionedTestBlock with versions")

@classmethod
def from_yaml_block(cls, src: YamlBlock, target: Versioned) -> "VersionedTestBlock[Versioned]":
return cls(
file=src.file,
data=src.data,
target=target,
)


@dataclass
class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
test: Dict[str, Any]
column_name: Optional[str]
tags: List[str]
version: Optional[NodeVersion]

@classmethod
def from_test_block(
cls,
src: TestBlock,
test: Dict[str, Any],
column_name: Optional[str],
tags: List[str],
version: Optional[NodeVersion],
) -> "GenericTestBlock":
return cls(
file=src.file,
data=src.data,
target=src.target,
test=test,
column_name=column_name,
tags=tags,
version=version,
)


class ParserRef:
"""A helper object to hold parse-time references."""

def __init__(self):
self.column_info: Dict[str, ColumnInfo] = {}

def _add(self, column: HasColumnProps):
tags: List[str] = []
tags.extend(getattr(column, "tags", ()))
quote: Optional[bool]
if isinstance(column, UnparsedColumn):
quote = column.quote
else:
quote = None

if any(
c
for c in column.constraints
if "type" not in c or not ConstraintType.is_valid(c["type"])
):
raise ParsingError(f"Invalid constraint type on column {column.name}")

self.column_info[column.name] = ColumnInfo(
name=column.name,
description=column.description,
data_type=column.data_type,
constraints=[ColumnLevelConstraint.from_dict(c) for c in column.constraints],
meta=column.meta,
tags=tags,
quote=quote,
_extra=column.extra,
)

@classmethod
def from_target(cls, target: Union[HasColumnDocs, HasColumnTests]) -> "ParserRef":
refs = cls()
for column in target.columns:
refs._add(column)
return refs

@classmethod
def from_versioned_target(cls, target: Versioned, version: NodeVersion) -> "ParserRef":
refs = cls()
for base_column in target.get_columns_for_version(version):
refs._add(base_column)
return refs
4 changes: 2 additions & 2 deletions core/dbt/parser/generic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def resource_type(self) -> NodeType:
def get_compiled_path(cls, block: FileBlock):
return block.path.relative_path

def parse_generic_test(
def create_generic_test_macro(
self, block: jinja.BlockTag, base_node: UnparsedMacro, name: str
) -> Macro:
unique_id = self.generate_unique_id(name)
Expand Down Expand Up @@ -76,7 +76,7 @@ def parse_unparsed_generic_test(self, base_node: UnparsedMacro) -> Iterable[Macr
continue

name: str = generic_test_name.replace(MACRO_PREFIX, "")
node = self.parse_generic_test(block, base_node, name)
node = self.create_generic_test_macro(block, base_node, name)
yield node

def parse_file(self, block: FileBlock):
Expand Down
Loading

0 comments on commit 5f7ae2f

Please sign in to comment.