Skip to content

Commit

Permalink
Use updated Mashumaro code
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Mar 3, 2021
1 parent 344a144 commit 31c88f9
Show file tree
Hide file tree
Showing 44 changed files with 159 additions and 143 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Contributors:
### Under the hood
- Bump werkzeug upper bound dependency to `<v2.0` ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
- Performance fixes for many different things ([#2862](https://github.com/fishtown-analytics/dbt/issues/2862), [#3034](https://github.com/fishtown-analytics/dbt/pull/3034))
- Update code to use Mashumaro 2.0 ([#3138](https://github.com/fishtown-analytics/dbt/pull/3138))

Contributors:
- [@Bl3f](https://github.com/Bl3f) ([#3011](https://github.com/fishtown-analytics/dbt/pull/3011))
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _get_field_named(cls, field_name):
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.to_dict() == other.to_dict()
return self.to_dict(omit_none=True) == other.to_dict(omit_none=True)

@classmethod
def get_default_quote_policy(cls) -> Policy:
Expand Down Expand Up @@ -185,10 +185,10 @@ def quoted(self, identifier):
def create_from_source(
cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any
) -> Self:
source_quoting = source.quoting.to_dict()
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop('column', None)
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(),
cls.get_default_quote_policy().to_dict(omit_none=True),
source_quoting,
kwargs.get('quote_policy', {}),
)
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def write_graph(self, outfile: str, manifest: Manifest):
"""
out_graph = self.graph.copy()
for node_id in self.graph.nodes():
data = manifest.expect(node_id).to_dict()
data = manifest.expect(node_id).to_dict(omit_none=True)
out_graph.add_node(node_id, **data)
nx.write_gpickle(out_graph, outfile)

Expand Down Expand Up @@ -339,7 +339,7 @@ def _recursively_prepend_ctes(
model.compiled_sql = injected_sql
model.extra_ctes_injected = True
model.extra_ctes = prepended_ctes
model.validate(model.to_dict())
model.validate(model.to_dict(omit_none=True))

manifest.update_node(model)

Expand Down Expand Up @@ -388,7 +388,7 @@ def _compile_node(

logger.debug("Compiling {}".format(node.unique_id))

data = node.to_dict()
data = node.to_dict(omit_none=True)
data.update({
'compiled': False,
'compiled_sql': None,
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def to_profile_info(
'credentials': self.credentials,
}
if serialize_credentials:
result['config'] = self.config.to_dict()
result['credentials'] = self.credentials.to_dict()
result['config'] = self.config.to_dict(omit_none=True)
result['credentials'] = self.credentials.to_dict(omit_none=True)
return result

def to_target_dict(self) -> Dict[str, Any]:
Expand All @@ -125,7 +125,7 @@ def to_target_dict(self) -> Dict[str, Any]:
'name': self.target_name,
'target_name': self.target_name,
'profile_name': self.profile_name,
'config': self.config.to_dict(),
'config': self.config.to_dict(omit_none=True),
})
return target

Expand All @@ -138,7 +138,7 @@ def __eq__(self, other: object) -> bool:
def validate(self):
try:
if self.credentials:
dct = self.credentials.to_dict()
dct = self.credentials.to_dict(omit_none=True)
self.credentials.validate(dct)
dct = self.to_profile_info(serialize_credentials=True)
ProfileConfig.validate(dct)
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def create_project(self, rendered: RenderComponents) -> 'Project':
# break many things
quoting: Dict[str, Any] = {}
if cfg.quoting is not None:
quoting = cfg.quoting.to_dict()
quoting = cfg.quoting.to_dict(omit_none=True)

models: Dict[str, Any]
seeds: Dict[str, Any]
Expand Down Expand Up @@ -578,10 +578,11 @@ def to_project_config(self, with_packages=False):
'config-version': self.config_version,
})
if self.query_comment:
result['query-comment'] = self.query_comment.to_dict()
result['query-comment'] = \
self.query_comment.to_dict(omit_none=True)

if with_packages:
result.update(self.packages.to_dict())
result.update(self.packages.to_dict(omit_none=True))

return result

Expand Down
4 changes: 2 additions & 2 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def from_parts(
get_relation_class_by_name(profile.credentials.type)
.get_default_quote_policy()
.replace_dict(_project_quoting_dict(project, profile))
).to_dict()
).to_dict(omit_none=True)

cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))

Expand Down Expand Up @@ -391,7 +391,7 @@ def __getattribute__(self, name):
f"'UnsetConfig' object has no attribute {name}"
)

def __post_serialize__(self, dct, options=None):
def __post_serialize__(self, dct):
return {}


Expand Down
1 change: 1 addition & 0 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,5 @@ def flags(self) -> Any:

def generate_base_context(cli_vars: Dict[str, Any]) -> Dict[str, Any]:
ctx = BaseContext(cli_vars)
# This is not a Mashumaro to_dict call
return ctx.to_dict()
2 changes: 1 addition & 1 deletion core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def calculate_node_config_dict(
base=base,
)
finalized = config.finalize_and_validate()
return finalized.to_dict()
return finalized.to_dict(omit_none=True)


class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
Expand Down
1 change: 1 addition & 0 deletions core/dbt/context/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,5 @@ def generate_runtime_docs(
current_project: str,
) -> Dict[str, Any]:
ctx = DocsRuntimeContext(config, target, manifest, current_project)
# This is not a Mashumaro to_dict call
return ctx.to_dict()
6 changes: 3 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,7 @@ def graph(self) -> Dict[str, Any]:

@contextproperty('model')
def ctx_model(self) -> Dict[str, Any]:
return self.model.to_dict()
return self.model.to_dict(omit_none=True)

@contextproperty
def pre_hooks(self) -> Optional[List[Dict[str, Any]]]:
Expand Down Expand Up @@ -1231,15 +1231,15 @@ def pre_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
return [
h.to_dict() for h in self.model.config.pre_hook
h.to_dict(omit_none=True) for h in self.model.config.pre_hook
]

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if isinstance(self.model, ParsedSourceDefinition):
return []
return [
h.to_dict() for h in self.model.config.post_hook
h.to_dict(omit_none=True) for h in self.model.config.post_hook
]

@contextproperty
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def connection_info(
) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing.
"""
as_dict = self.to_dict(options={'keep_none': True})
as_dict = self.to_dict(omit_none=False)
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
Expand All @@ -148,8 +148,8 @@ def _connection_keys(self) -> Tuple[str, ...]:
raise NotImplementedError

@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
data = cls.translate_aliases(data)
return data

Expand All @@ -159,7 +159,7 @@ def translate_aliases(
) -> Dict[str, Any]:
return translate_aliases(kwargs, cls._ALIASES, recurse)

def __post_serialize__(self, dct, options=None):
def __post_serialize__(self, dct):
# no super() -- do we need it?
if self._ALIASES:
dct.update({
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource:
raise ValueError('invalid resource_type: {}'
.format(compiled.resource_type))

return cls.from_dict(compiled.to_dict())
return cls.from_dict(compiled.to_dict(omit_none=True))


NonSourceCompiledNode = Union[
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def build_edges(nodes: List[ManifestNode]):


def _deepcopy(value):
return value.from_dict(value.to_dict())
return value.from_dict(value.to_dict(omit_none=True))


class Locality(enum.IntEnum):
Expand Down Expand Up @@ -564,11 +564,11 @@ def build_flat_graph(self):
"""
self.flat_graph = {
'nodes': {
k: v.to_dict(options={'keep_none': True})
k: v.to_dict(omit_none=False)
for k, v in self.nodes.items()
},
'sources': {
k: v.to_dict(options={'keep_none': True})
k: v.to_dict(omit_none=False)
for k, v in self.sources.items()
}
}
Expand Down Expand Up @@ -755,7 +755,7 @@ def writable_manifest(self):

# When 'to_dict' is called on the Manifest, it substitues a
# WritableManifest
def __pre_serialize__(self, options=None):
def __pre_serialize__(self):
return self.writable_manifest()

def write(self, path):
Expand Down
16 changes: 8 additions & 8 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def update_from(
"""
# sadly, this is a circular import
from dbt.adapters.factory import get_config_class_by_name
dct = self.to_dict(options={'keep_none': True})
dct = self.to_dict(omit_none=False)

adapter_config_cls = get_config_class_by_name(adapter_type)

Expand All @@ -326,12 +326,12 @@ def update_from(
return self.from_dict(dct)

def finalize_and_validate(self: T) -> T:
dct = self.to_dict(options={'keep_none': True})
dct = self.to_dict(omit_none=False)
self.validate(dct)
return self.from_dict(dct)

def replace(self, **kwargs):
dct = self.to_dict()
dct = self.to_dict(omit_none=True)

mapping = self.field_mapping()
for key, value in kwargs.items():
Expand Down Expand Up @@ -396,8 +396,8 @@ class NodeConfig(BaseConfig):
full_refresh: Optional[bool] = None

@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
field_map = {'post-hook': 'post_hook', 'pre-hook': 'pre_hook'}
# create a new dict because otherwise it gets overwritten in
# tests
Expand All @@ -414,8 +414,8 @@ def __pre_deserialize__(cls, data, options=None):
data[new_name] = data.pop(field_name)
return data

def __post_serialize__(self, dct, options=None):
dct = super().__post_serialize__(dct, options=options)
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
field_map = {'post_hook': 'post-hook', 'pre_hook': 'pre-hook'}
for field_name in field_map:
if field_name in dct:
Expand Down Expand Up @@ -480,7 +480,7 @@ def validate(cls, data):
# formerly supported with GenericSnapshotConfig

def finalize_and_validate(self):
data = self.to_dict()
data = self.to_dict(omit_none=True)
self.validate(data)
return self.from_dict(data)

Expand Down
8 changes: 4 additions & 4 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class HasRelationMetadata(dbtClassMixin, Replaceable):
# because it messes up the subclasses and default parameters
# so hack it here
@classmethod
def __pre_deserialize__(cls, data, options=None):
data = super().__pre_deserialize__(data, options=options)
def __pre_deserialize__(cls, data):
data = super().__pre_deserialize__(data)
if 'database' not in data:
data['database'] = None
return data
Expand Down Expand Up @@ -141,7 +141,7 @@ def patch(self, patch: 'ParsedNodePatch'):
# Maybe there should be validation or restrictions
# elsewhere?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(options={'keep_none': True})
dct = self.to_dict(omit_none=False)
self.validate(dct)

def get_materialization(self):
Expand Down Expand Up @@ -454,7 +454,7 @@ def patch(self, patch: ParsedMacroPatch):
if flags.STRICT_MODE:
# What does this actually validate?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(options={'keep_none': True})
dct = self.to_dict(omit_none=False)
self.validate(dct)

def same_contents(self, other: Optional['ParsedMacro']) -> bool:
Expand Down
18 changes: 6 additions & 12 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,9 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)

def __post_serialize__(self, dct, options=None):
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
keep_none = False
if options and 'keep_none' in options and options['keep_none']:
keep_none = True
if not keep_none and self.freshness is None:
if 'freshness' not in dct and self.freshness is None:
dct['freshness'] = None
return dct

Expand All @@ -261,12 +258,9 @@ class UnparsedSourceDefinition(dbtClassMixin, Replaceable):
def yaml_key(self) -> 'str':
return 'sources'

def __post_serialize__(self, dct, options=None):
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
keep_none = False
if options and 'keep_none' in options and options['keep_none']:
keep_none = True
if not keep_none and self.freshness is None:
if 'freshnewss' not in dct and self.freshness is None:
dct['freshness'] = None
return dct

Expand All @@ -290,7 +284,7 @@ class SourceTablePatch(dbtClassMixin):
columns: Optional[Sequence[UnparsedColumn]] = None

def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict()
dct = self.to_dict(omit_none=True)
remove_keys = ('name')
for key in remove_keys:
if key in dct:
Expand Down Expand Up @@ -327,7 +321,7 @@ class SourcePatch(dbtClassMixin, Replaceable):
tags: Optional[List[str]] = None

def to_patch_dict(self) -> Dict[str, Any]:
dct = self.to_dict()
dct = self.to_dict(omit_none=True)
remove_keys = ('name', 'overrides', 'tables', 'path')
for key in remove_keys:
if key in dct:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __len__(self):
return len(fields(self.__class__))

def incorporate(self, **kwargs):
value = self.to_dict()
value = self.to_dict(omit_none=True)
value = deep_merge(value, kwargs)
return self.from_dict(value)

Expand Down
Loading

0 comments on commit 31c88f9

Please sign in to comment.