diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 97b7153b..6f4b4c3f 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -44,12 +44,11 @@ def _read_yaml(filename): def _validate_choices(cls): for key in cls.__dataclass_fields__.keys(): atr = cls.__dataclass_fields__[key] - if atr.init: - if "choices" in atr.metadata.keys(): - if getattr(cls, key) not in atr.metadata.get("choices"): - raise ValueError( - f"{getattr(cls, key)} is not a valid choice for {key}. Please choose from on of the following: {atr.metadata['choices']}" - ) + if atr.init and "choices" in atr.metadata.keys(): + if getattr(cls, key) not in atr.metadata.get("choices"): + raise ValueError( + f"{getattr(cls, key)} is not a valid choice for {key}. Please choose from one of the following: {atr.metadata['choices']}" + ) @@ dataclass diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index faf2cd28..15868f2a 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -47,8 +47,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo The merged configuration. 
""" # using base config values if exist - if "embedding_dims" in config.keys() and config.embedding_dims is not None: - inferred_config.embedding_dims = config.embedding_dims + inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims merged_config = OmegaConf.merge(OmegaConf.to_container(config), OmegaConf.to_container(inferred_config)) return merged_config diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index e00489de..ec7a94b0 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -27,8 +27,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo The merged configuration. """ # using base config values if exist - if "embedding_dims" in config.keys() and config.embedding_dims is not None: - inferred_config.embedding_dims = config.embedding_dims + inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims merged_config = OmegaConf.merge(OmegaConf.to_container(config), OmegaConf.to_container(inferred_config)) return merged_config diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 39c05754..e1ec7100 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -314,10 +314,7 @@ def _load_weights(cls, model, path: Union[str, Path]) -> None: None """ ckpt = pl_load(path, map_location=lambda storage, loc: storage) - if "state_dict" in ckpt.keys(): - model.load_state_dict(ckpt["state_dict"]) - else: - model.load_state_dict(ckpt) + model.load_state_dict(ckpt.get("state_dict") or ckpt) @classmethod def load_model(cls, dir: str, map_location=None, strict=True):