Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from hologram to mashumaro jsonschema #8132

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230718-145428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Switch from hologram to mashumaro jsonschema
time: 2023-07-18T14:54:28.41453-04:00
custom:
Author: gshank
Issue: "6776"
14 changes: 13 additions & 1 deletion core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,21 @@ def initial_result(self, resource_type: NodeType, base: bool) -> C:

def _update_from_config(self, result: C, partial: Dict[str, Any], validate: bool = False) -> C:
translated = self._active_project.credentials.translate_aliases(partial)
return result.update_from(
translated = self.translate_hook_names(translated)
updated = result.update_from(
translated, self._active_project.credentials.type, validate=validate
)
return updated

def translate_hook_names(self, project_dict):
# This is a kind of kludge because the fix for #6411 specifically allowed misspelling
# the hook field names in dbt_project.yml, which only ever worked because we didn't
# run validate on the dbt_project configs.
if "pre_hook" in project_dict:
project_dict["pre-hook"] = project_dict.pop("pre_hook")
if "post_hook" in project_dict:
project_dict["post-hook"] = project_dict.pop("post_hook")
return project_dict

def calculate_node_config_dict(
self,
Expand Down
17 changes: 7 additions & 10 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@
from dbt.events.functions import fire_event
from dbt.events.types import NewConnectionOpening
from dbt.events.contextvars import get_node_info
from typing_extensions import Protocol
from typing_extensions import Protocol, Annotated
from dbt.dataclass_schema import (
dbtClassMixin,
StrEnum,
ExtensibleDbtClassMixin,
HyphenatedDbtClassMixin,
ValidatedStringMixin,
register_pattern,
)
from dbt.contracts.util import Replaceable
from mashumaro.jsonschema.annotations import Pattern


class Identifier(ValidatedStringMixin):
ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"


# we need register_pattern for jsonschema validation
register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")


@dataclass
class AdapterResponse(dbtClassMixin):
_message: str
Expand All @@ -55,7 +50,8 @@ class ConnectionState(StrEnum):

@dataclass(init=False)
class Connection(ExtensibleDbtClassMixin, Replaceable):
type: Identifier
# Annotated is used by mashumaro for jsonschema generation
type: Annotated[Identifier, Pattern(r"^[A-Za-z_][A-Za-z0-9_]+$")]
name: Optional[str] = None
state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False
Expand Down Expand Up @@ -161,6 +157,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
# Need to fixup dbname => database, pass => password
data = cls.translate_aliases(data)
return data

Expand Down Expand Up @@ -220,10 +217,10 @@ def to_target_dict(self):


@dataclass
class QueryComment(HyphenatedDbtClassMixin):
class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False
job_label: bool = field(default=False, metadata={"alias": "job-label"})


class AdapterRequiredConfig(HasCredentials, Protocol):
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ def __init__(self, macros):


@dataclass
@schema_version("manifest", 10)
@schema_version("manifest", 11)
class WritableManifest(ArtifactMixin):
nodes: Mapping[UniqueID, ManifestNode] = field(
metadata=dict(description=("The nodes defined in the dbt project and its dependencies"))
Expand Down Expand Up @@ -1486,14 +1486,15 @@ def compatible_previous_versions(self):
("manifest", 7),
("manifest", 8),
("manifest", 9),
("manifest", 10),
]

@classmethod
def upgrade_schema_version(cls, data):
"""This overrides the "upgrade_schema_version" call in VersionedSchema (via
ArtifactMixin) to modify the dictionary passed in from earlier versions of the manifest."""
manifest_schema_version = get_manifest_schema_version(data)
if manifest_schema_version <= 9:
if manifest_schema_version <= 10:
data = upgrade_manifest_json(data, manifest_schema_version)
return cls.from_dict(data)

Expand Down
44 changes: 9 additions & 35 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from enum import Enum
from itertools import chain
from typing import Any, List, Optional, Dict, Union, Type, TypeVar, Callable
from typing_extensions import Annotated

from dbt.dataclass_schema import (
dbtClassMixin,
ValidationError,
register_pattern,
StrEnum,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed, Docs
Expand All @@ -15,6 +15,7 @@
from dbt.exceptions import DbtInternalError, CompilationError
from dbt import hooks
from dbt.node_types import NodeType
from mashumaro.jsonschema.annotations import Pattern


M = TypeVar("M", bound="Metadata")
Expand Down Expand Up @@ -188,9 +189,6 @@ class Severity(str):
pass


register_pattern(Severity, insensitive_patterns("warn", "error"))


class OnConfigurationChangeOption(StrEnum):
Apply = "apply"
Continue = "continue"
Expand Down Expand Up @@ -376,15 +374,6 @@ def finalize_and_validate(self: T) -> T:
self.validate(dct)
return self.from_dict(dct)

def replace(self, **kwargs):
dct = self.to_dict(omit_none=True)

mapping = self.field_mapping()
for key, value in kwargs.items():
new_key = mapping.get(key, key)
dct[new_key] = value
return self.from_dict(dct)


@dataclass
class SemanticModelConfig(BaseConfig):
Expand Down Expand Up @@ -447,11 +436,11 @@ class NodeConfig(NodeAndTestConfig):
persist_docs: Dict[str, Any] = field(default_factory=dict)
post_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "post-hook"},
)
pre_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "pre-hook"},
)
quoting: Dict[str, Any] = field(
default_factory=dict,
Expand Down Expand Up @@ -511,30 +500,11 @@ def __post_init__(self):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
for key in data:
new_dict[key] = data[key]
data = new_dict
for key in hooks.ModelHookType:
if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
for field_name in field_map:
if field_name in data:
new_name = field_map[field_name]
data[new_name] = data.pop(field_name)
return data

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
return dct

# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
Expand All @@ -554,6 +524,9 @@ def validate(cls, data):
raise ValidationError("A seed must have a materialized value of 'seed'")


SEVERITY_PATTERN = r"^([Ww][Aa][Rr][Nn]|[Ee][Rr][Rr][Oo][Rr])$"


@dataclass
class TestConfig(NodeAndTestConfig):
__test__ = False
Expand All @@ -564,7 +537,8 @@ class TestConfig(NodeAndTestConfig):
metadata=CompareBehavior.Exclude.meta(),
)
materialized: str = "test"
severity: Severity = Severity("ERROR")
# Annotated is used by mashumaro for jsonschema generation
severity: Annotated[Severity, Pattern(SEVERITY_PATTERN)] = Severity("ERROR")
store_failures: Optional[bool] = None
where: Optional[str] = None
limit: Optional[int] = None
Expand Down
36 changes: 18 additions & 18 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib

from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Literal

from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

Expand Down Expand Up @@ -555,18 +555,18 @@ def depends_on_macros(self):

@dataclass
class AnalysisNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
resource_type: Literal[NodeType.Analysis]


@dataclass
class HookNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
resource_type: Literal[NodeType.Operation]
index: Optional[int] = None


@dataclass
class ModelNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
resource_type: Literal[NodeType.Model]
access: AccessType = AccessType.Protected
constraints: List[ModelLevelConstraint] = field(default_factory=list)
version: Optional[NodeVersion] = None
Expand Down Expand Up @@ -775,12 +775,12 @@ def same_contract(self, old, adapter_type=None) -> bool:
# TODO: rm?
@dataclass
class RPCNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
resource_type: Literal[NodeType.RPCCall]


@dataclass
class SqlNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]})
resource_type: Literal[NodeType.SqlOperation]


# ====================================
Expand All @@ -790,7 +790,7 @@ class SqlNode(CompiledNode):

@dataclass
class SeedNode(ParsedNode): # No SQLDefaults!
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
resource_type: Literal[NodeType.Seed]
config: SeedConfig = field(default_factory=SeedConfig)
# seeds need the root_path because the contents are not loaded initially
# and we need the root_path to load the seed later
Expand Down Expand Up @@ -916,7 +916,7 @@ def is_relational(self):

@dataclass
class SingularTestNode(TestShouldStoreFailures, CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
# Was not able to make mypy happy and keep the code working. We need to
# refactor the various configs.
config: TestConfig = field(default_factory=TestConfig) # type: ignore
Expand Down Expand Up @@ -952,7 +952,7 @@ class HasTestMetadata(dbtClassMixin):

@dataclass
class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
column_name: Optional[str] = None
file_key_name: Optional[str] = None
# Was not able to make mypy happy and keep the code working. We need to
Expand Down Expand Up @@ -985,13 +985,13 @@ class IntermediateSnapshotNode(CompiledNode):
# uses a regular node config, which the snapshot parser will then convert
# into a full ParsedSnapshotNode after rendering. Note: it currently does
# not work to set snapshot config in schema files because of the validation.
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)


@dataclass
class SnapshotNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

Expand All @@ -1004,7 +1004,7 @@ class SnapshotNode(CompiledNode):
@dataclass
class Macro(BaseNode):
macro_sql: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
resource_type: Literal[NodeType.Macro]
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def depends_on_macros(self):
@dataclass
class Documentation(BaseNode):
block_contents: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Documentation]})
resource_type: Literal[NodeType.Documentation]

@property
def search_name(self):
Expand Down Expand Up @@ -1065,7 +1065,7 @@ class UnpatchedSourceDefinition(BaseNode):
source: UnparsedSourceDefinition
table: UnparsedSourceTableDefinition
fqn: List[str]
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]
patch_path: Optional[str] = None

def get_full_source_name(self):
Expand Down Expand Up @@ -1110,7 +1110,7 @@ class ParsedSourceMandatory(GraphNode, HasRelationMetadata):
source_description: str
loader: str
identifier: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]


@dataclass
Expand Down Expand Up @@ -1237,7 +1237,7 @@ def search_name(self):
class Exposure(GraphNode):
type: ExposureType
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Exposure]})
resource_type: Literal[NodeType.Exposure]
description: str = ""
label: Optional[str] = None
maturity: Optional[MaturityType] = None
Expand Down Expand Up @@ -1386,7 +1386,7 @@ class Metric(GraphNode):
type_params: MetricTypeParams
filter: Optional[WhereFilter] = None
metadata: Optional[SourceFileMetadata] = None
resource_type: NodeType = field(metadata={"restrict": [NodeType.Metric]})
resource_type: Literal[NodeType.Metric]
meta: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
config: MetricConfig = field(default_factory=MetricConfig)
Expand Down Expand Up @@ -1469,7 +1469,7 @@ def same_contents(self, old: Optional["Metric"]) -> bool:
class Group(BaseNode):
name: str
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Group]})
resource_type: Literal[NodeType.Group]


# ====================================
Expand Down
Loading