Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to using mashumaro jsonschema with acceptable performance #8437

Merged
merged 12 commits into from
Aug 30, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230718-145428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Switch from hologram to mashumaro jsonschema
time: 2023-07-18T14:54:28.41453-04:00
custom:
Author: gshank
Issue: "8426"
14 changes: 13 additions & 1 deletion core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,21 @@ def initial_result(self, resource_type: NodeType, base: bool) -> C:

def _update_from_config(self, result: C, partial: Dict[str, Any], validate: bool = False) -> C:
translated = self._active_project.credentials.translate_aliases(partial)
return result.update_from(
translated = self.translate_hook_names(translated)
updated = result.update_from(
translated, self._active_project.credentials.type, validate=validate
)
return updated

def translate_hook_names(self, project_dict):
# This is a kind of kludge because the fix for #6411 specifically allowed misspelling
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not the intended input format, should we raise a warning here indicating that? It wouldn't cause anything to fail, but providing some direction would make it easier for us to deprecate the incorrect spelling in the future (likely one less thing for folks to change for 2.0).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That ticket specifically allowed the "incorrect" spellings, so it's now a feature.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no :lolsob: emoji, why is there no :lolsob: emoji when I need one so badly.

That being said, we don't intend on ever migrating folks off of the "incorrect" spelling either?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd have to ask product and Doug :). If you want to open a ticket, go ahead. Not in scope for this one though...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The misspelling here we mean is, we'll accept either kebab case or snake case for these two configs, in the several places they could be potentially defined:

  • post-hook or post_hook
  • pre-hook or pre_hook

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The misspelling here we mean is, we'll accept either kebab case or snake case for these two configs

Agreed, I'm asking if we ever want to back out of that ditch, or support that for the foreseeable future.

# the hook field names in dbt_project.yml, which only ever worked because we didn't
# run validate on the dbt_project configs.
if "pre_hook" in project_dict:
project_dict["pre-hook"] = project_dict.pop("pre_hook")
if "post_hook" in project_dict:
project_dict["post-hook"] = project_dict.pop("post_hook")
return project_dict

def calculate_node_config_dict(
self,
Expand Down
17 changes: 7 additions & 10 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@
from dbt.events.functions import fire_event
from dbt.events.types import NewConnectionOpening
from dbt.events.contextvars import get_node_info
from typing_extensions import Protocol
from typing_extensions import Protocol, Annotated
from dbt.dataclass_schema import (
dbtClassMixin,
StrEnum,
ExtensibleDbtClassMixin,
HyphenatedDbtClassMixin,
ValidatedStringMixin,
register_pattern,
)
from dbt.contracts.util import Replaceable
from mashumaro.jsonschema.annotations import Pattern


class Identifier(ValidatedStringMixin):
ValidationRegex = r"^[A-Za-z_][A-Za-z0-9_]+$"


# we need register_pattern for jsonschema validation
register_pattern(Identifier, r"^[A-Za-z_][A-Za-z0-9_]+$")


@dataclass
class AdapterResponse(dbtClassMixin):
_message: str
Expand All @@ -55,7 +50,8 @@ class ConnectionState(StrEnum):

@dataclass(init=False)
class Connection(ExtensibleDbtClassMixin, Replaceable):
type: Identifier
# Annotated is used by mashumaro for jsonschema generation
type: Annotated[Identifier, Pattern(r"^[A-Za-z_][A-Za-z0-9_]+$")]
name: Optional[str] = None
state: ConnectionState = ConnectionState.INIT
transaction_open: bool = False
Expand Down Expand Up @@ -161,6 +157,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
# Need to fixup dbname => database, pass => password
data = cls.translate_aliases(data)
return data

Expand Down Expand Up @@ -220,10 +217,10 @@ def to_target_dict(self):


@dataclass
class QueryComment(HyphenatedDbtClassMixin):
class QueryComment(dbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False
job_label: bool = field(default=False, metadata={"alias": "job-label"})


class AdapterRequiredConfig(HasCredentials, Protocol):
Expand Down
44 changes: 9 additions & 35 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from enum import Enum
from itertools import chain
from typing import Any, List, Optional, Dict, Union, Type, TypeVar, Callable
from typing_extensions import Annotated

from dbt.dataclass_schema import (
dbtClassMixin,
ValidationError,
register_pattern,
StrEnum,
)
from dbt.contracts.graph.unparsed import AdditionalPropertiesAllowed, Docs
Expand All @@ -15,6 +15,7 @@
from dbt.exceptions import DbtInternalError, CompilationError
from dbt import hooks
from dbt.node_types import NodeType
from mashumaro.jsonschema.annotations import Pattern


M = TypeVar("M", bound="Metadata")
Expand Down Expand Up @@ -188,9 +189,6 @@ class Severity(str):
pass


register_pattern(Severity, insensitive_patterns("warn", "error"))


class OnConfigurationChangeOption(StrEnum):
Apply = "apply"
Continue = "continue"
Expand Down Expand Up @@ -376,15 +374,6 @@ def finalize_and_validate(self: T) -> T:
self.validate(dct)
return self.from_dict(dct)

def replace(self, **kwargs):
gshank marked this conversation as resolved.
Show resolved Hide resolved
dct = self.to_dict(omit_none=True)

mapping = self.field_mapping()
for key, value in kwargs.items():
new_key = mapping.get(key, key)
dct[new_key] = value
return self.from_dict(dct)


@dataclass
class SemanticModelConfig(BaseConfig):
Expand Down Expand Up @@ -447,11 +436,11 @@ class NodeConfig(NodeAndTestConfig):
persist_docs: Dict[str, Any] = field(default_factory=dict)
post_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "post-hook"},
)
pre_hook: List[Hook] = field(
default_factory=list,
metadata=MergeBehavior.Append.meta(),
metadata={"merge": MergeBehavior.Append, "alias": "pre-hook"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are aliases needed for Append now? Why doesn't packages need one on line 466?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "alias" is for handling the dashes in the names properly. Most of the other field definitions use that kind of hacky metadata=MergeBehavior.DictKeyAppend.meta() thing, which doesn't allow setting additional metadata.

)
quoting: Dict[str, Any] = field(
default_factory=dict,
Expand Down Expand Up @@ -511,30 +500,11 @@ def __post_init__(self):
@classmethod
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {"post-hook": "post_hook", "pre-hook": "pre_hook"}
# create a new dict because otherwise it gets overwritten in
# tests
new_dict = {}
for key in data:
new_dict[key] = data[key]
data = new_dict
for key in hooks.ModelHookType:
if key in data:
data[key] = [hooks.get_hook_dict(h) for h in data[key]]
for field_name in field_map:
if field_name in data:
new_name = field_map[field_name]
data[new_name] = data.pop(field_name)
return data

def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {"post_hook": "post-hook", "pre_hook": "pre-hook"}
for field_name in field_map:
if field_name in dct:
dct[field_map[field_name]] = dct.pop(field_name)
return dct

# this is still used by jsonschema validation
@classmethod
def field_mapping(cls):
Expand All @@ -554,6 +524,9 @@ def validate(cls, data):
raise ValidationError("A seed must have a materialized value of 'seed'")


SEVERITY_PATTERN = r"^([Ww][Aa][Rr][Nn]|[Ee][Rr][Rr][Oo][Rr])$"


@dataclass
class TestConfig(NodeAndTestConfig):
__test__ = False
Expand All @@ -564,7 +537,8 @@ class TestConfig(NodeAndTestConfig):
metadata=CompareBehavior.Exclude.meta(),
)
materialized: str = "test"
severity: Severity = Severity("ERROR")
# Annotated is used by mashumaro for jsonschema generation
severity: Annotated[Severity, Pattern(SEVERITY_PATTERN)] = Severity("ERROR")
store_failures: Optional[bool] = None
where: Optional[str] = None
limit: Optional[int] = None
Expand Down
36 changes: 18 additions & 18 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib

from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Literal

from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

Expand Down Expand Up @@ -556,18 +556,18 @@ def depends_on_macros(self):

@dataclass
class AnalysisNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Analysis]})
resource_type: Literal[NodeType.Analysis]


@dataclass
class HookNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Operation]})
resource_type: Literal[NodeType.Operation]
index: Optional[int] = None


@dataclass
class ModelNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Model]})
resource_type: Literal[NodeType.Model]
access: AccessType = AccessType.Protected
constraints: List[ModelLevelConstraint] = field(default_factory=list)
version: Optional[NodeVersion] = None
Expand Down Expand Up @@ -854,12 +854,12 @@ def same_contract(self, old, adapter_type=None) -> bool:
# TODO: rm?
@dataclass
class RPCNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.RPCCall]})
resource_type: Literal[NodeType.RPCCall]


@dataclass
class SqlNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.SqlOperation]})
resource_type: Literal[NodeType.SqlOperation]


# ====================================
Expand All @@ -869,7 +869,7 @@ class SqlNode(CompiledNode):

@dataclass
class SeedNode(ParsedNode): # No SQLDefaults!
resource_type: NodeType = field(metadata={"restrict": [NodeType.Seed]})
resource_type: Literal[NodeType.Seed]
config: SeedConfig = field(default_factory=SeedConfig)
# seeds need the root_path because the contents are not loaded initially
# and we need the root_path to load the seed later
Expand Down Expand Up @@ -995,7 +995,7 @@ def is_relational(self):

@dataclass
class SingularTestNode(TestShouldStoreFailures, CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
# Was not able to make mypy happy and keep the code working. We need to
# refactor the various configs.
config: TestConfig = field(default_factory=TestConfig) # type: ignore
Expand Down Expand Up @@ -1031,7 +1031,7 @@ class HasTestMetadata(dbtClassMixin):

@dataclass
class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
resource_type: Literal[NodeType.Test]
column_name: Optional[str] = None
file_key_name: Optional[str] = None
# Was not able to make mypy happy and keep the code working. We need to
Expand Down Expand Up @@ -1064,13 +1064,13 @@ class IntermediateSnapshotNode(CompiledNode):
# uses a regular node config, which the snapshot parser will then convert
# into a full ParsedSnapshotNode after rendering. Note: it currently does
# not work to set snapshot config in schema files because of the validation.
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: EmptySnapshotConfig = field(default_factory=EmptySnapshotConfig)


@dataclass
class SnapshotNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]})
resource_type: Literal[NodeType.Snapshot]
config: SnapshotConfig
defer_relation: Optional[DeferRelation] = None

Expand All @@ -1083,7 +1083,7 @@ class SnapshotNode(CompiledNode):
@dataclass
class Macro(BaseNode):
macro_sql: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Macro]})
resource_type: Literal[NodeType.Macro]
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -1113,7 +1113,7 @@ def depends_on_macros(self):
@dataclass
class Documentation(BaseNode):
block_contents: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Documentation]})
resource_type: Literal[NodeType.Documentation]

@property
def search_name(self):
Expand Down Expand Up @@ -1144,7 +1144,7 @@ class UnpatchedSourceDefinition(BaseNode):
source: UnparsedSourceDefinition
table: UnparsedSourceTableDefinition
fqn: List[str]
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]
patch_path: Optional[str] = None

def get_full_source_name(self):
Expand Down Expand Up @@ -1189,7 +1189,7 @@ class ParsedSourceMandatory(GraphNode, HasRelationMetadata):
source_description: str
loader: str
identifier: str
resource_type: NodeType = field(metadata={"restrict": [NodeType.Source]})
resource_type: Literal[NodeType.Source]


@dataclass
Expand Down Expand Up @@ -1316,7 +1316,7 @@ def search_name(self):
class Exposure(GraphNode):
type: ExposureType
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Exposure]})
resource_type: Literal[NodeType.Exposure]
description: str = ""
label: Optional[str] = None
maturity: Optional[MaturityType] = None
Expand Down Expand Up @@ -1465,7 +1465,7 @@ class Metric(GraphNode):
type_params: MetricTypeParams
filter: Optional[WhereFilter] = None
metadata: Optional[SourceFileMetadata] = None
resource_type: NodeType = field(metadata={"restrict": [NodeType.Metric]})
resource_type: Literal[NodeType.Metric]
meta: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
config: MetricConfig = field(default_factory=MetricConfig)
Expand Down Expand Up @@ -1548,7 +1548,7 @@ def same_contents(self, old: Optional["Metric"]) -> bool:
class Group(BaseNode):
name: str
owner: Owner
resource_type: NodeType = field(metadata={"restrict": [NodeType.Group]})
resource_type: Literal[NodeType.Group]


# ====================================
Expand Down
Loading