From 107cd2fe3f2c8361140b205d1d5c8e2aaf8550f0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 May 2023 06:13:43 +0200 Subject: [PATCH] simplify keys / get (#190) --- src/pytorch_tabular/config/config.py | 11 +++++------ src/pytorch_tabular/models/base_model.py | 3 +-- src/pytorch_tabular/ssl_models/base_model.py | 3 +-- src/pytorch_tabular/tabular_model.py | 5 +---- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 97b7153b..6f4b4c3f 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -44,12 +44,11 @@ def _read_yaml(filename): def _validate_choices(cls): for key in cls.__dataclass_fields__.keys(): atr = cls.__dataclass_fields__[key] - if atr.init: - if "choices" in atr.metadata.keys(): - if getattr(cls, key) not in atr.metadata.get("choices"): - raise ValueError( - f"{getattr(cls, key)} is not a valid choice for {key}. Please choose from on of the following: {atr.metadata['choices']}" - ) + if atr.init and "choices" in atr.metadata.keys(): + if getattr(cls, key) not in atr.metadata.get("choices"): + raise ValueError( + f"{getattr(cls, key)} is not a valid choice for {key}. Please choose from on of the following: {atr.metadata['choices']}" + ) @dataclass diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index faf2cd28..15868f2a 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -47,8 +47,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo The merged configuration. """ # using base config values if exist - if "embedding_dims" in config.keys() and config.embedding_dims is not None: - inferred_config.embedding_dims = config.embedding_dims + inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims merged_config = OmegaConf.merge(OmegaConf.to_container(config), OmegaConf.to_container(inferred_config)) return merged_config diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index e00489de..ec7a94b0 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -27,8 +27,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo The merged configuration. """ # using base config values if exist - if "embedding_dims" in config.keys() and config.embedding_dims is not None: - inferred_config.embedding_dims = config.embedding_dims + inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims merged_config = OmegaConf.merge(OmegaConf.to_container(config), OmegaConf.to_container(inferred_config)) return merged_config diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 39c05754..e1ec7100 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -314,10 +314,7 @@ def _load_weights(cls, model, path: Union[str, Path]) -> None: None """ ckpt = pl_load(path, map_location=lambda storage, loc: storage) - if "state_dict" in ckpt.keys(): - model.load_state_dict(ckpt["state_dict"]) - else: - model.load_state_dict(ckpt) + model.load_state_dict(ckpt.get("state_dict") or ckpt) @classmethod def load_model(cls, dir: str, map_location=None, strict=True):