Skip to content

Commit

Permalink
refactor: loading model from ckpt for lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Apr 8, 2022
1 parent 38c1438 commit 5bebc1a
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 7 deletions.
4 changes: 1 addition & 3 deletions embeddings/task/lightning_task/lightning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def restore_task_model(
output_path: T_path,
lightning_module: Type[LightningModule[AutoModel]],
task_train_kwargs: Optional[Dict[str, Any]],
early_stopping_kwargs: Optional[Dict[str, Any]],
logging_config: Optional[LightningLoggingConfig],
) -> "LightningTask":
model = lightning_module.load_from_checkpoint(str(checkpoint_path))
Expand All @@ -139,7 +138,7 @@ def restore_task_model(
"model_config_kwargs": model.hparams.config_kwargs,
"task_model_kwargs": model.hparams.task_model_kwargs,
"task_train_kwargs": task_train_kwargs or {},
"early_stopping_kwargs": early_stopping_kwargs or {},
"early_stopping_kwargs": {},
"logging_config": logging_config or LightningLoggingConfig(),
}
task = cls(**init_kwargs)
Expand All @@ -155,7 +154,6 @@ def from_checkpoint(
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
early_stopping_kwargs: Optional[Dict[str, Any]] = None,
logging_config: Optional[LightningLoggingConfig] = None,
) -> "LightningTask":
pass
2 changes: 0 additions & 2 deletions embeddings/task/lightning_task/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,12 @@ def from_checkpoint(
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
early_stopping_kwargs: Optional[Dict[str, Any]] = None,
logging_config: Optional[LightningLoggingConfig] = None,
) -> "LightningTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
task_train_kwargs=task_train_kwargs,
early_stopping_kwargs=early_stopping_kwargs,
lightning_module=SequenceLabelingModule,
logging_config=logging_config,
)
2 changes: 0 additions & 2 deletions embeddings/task/lightning_task/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,12 @@ def from_checkpoint(
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
early_stopping_kwargs: Optional[Dict[str, Any]] = None,
logging_config: Optional[LightningLoggingConfig] = None,
) -> "LightningTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
task_train_kwargs=task_train_kwargs,
early_stopping_kwargs=early_stopping_kwargs,
lightning_module=TextClassificationModule,
logging_config=logging_config,
)

0 comments on commit 5bebc1a

Please sign in to comment.