feat: implement load_from_ckpt method for LightningTask
djaniak committed Apr 6, 2022
1 parent 05246f6 commit c1dbe65
Showing 5 changed files with 114 additions and 35 deletions.
4 changes: 3 additions & 1 deletion embeddings/model/lightning_module/lightning_module.py
@@ -66,7 +66,9 @@ def predict(
self, dataloader: DataLoader[HuggingFaceDataset]
) -> Dict[str, nptyping.NDArray[Any]]:
assert self.trainer is not None
predictions = self.trainer.predict(dataloaders=dataloader, return_predictions=True)
predictions = self.trainer.predict(
model=self, dataloaders=dataloader, return_predictions=True
)
predictions = torch.cat(predictions).numpy()
assert isinstance(predictions, np.ndarray)
ground_truth = torch.cat([x["labels"] for x in dataloader]).numpy()
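The `model=self` fix above matters for checkpoint restoration: a `Trainer` built fresh in `restore_task_model` (below) has never gone through `fit`, so `Trainer.predict` has no previously attached module to fall back on and must be given one explicitly. A minimal sketch of the difference, assuming a restored `module` and a `dataloader` (both names illustrative):

import pytorch_lightning as pl

trainer = pl.Trainer()

# Without an explicit model, a never-fitted trainer has nothing to run
# and PyTorch Lightning raises a misconfiguration error:
# trainer.predict(dataloaders=dataloader, return_predictions=True)

# Passing the module explicitly works even for a fresh trainer:
predictions = trainer.predict(
    model=module, dataloaders=dataloader, return_predictions=True
)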
40 changes: 38 additions & 2 deletions embeddings/task/lightning_task/lightning_task.py
@@ -1,18 +1,20 @@
import abc
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

import pytorch_lightning as pl
import torch
from numpy import typing as nptyping
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from transformers import AutoModel

from embeddings.data.datamodule import HuggingFaceDataModule
from embeddings.data.dataset import LightingDataModuleSubset
from embeddings.data.io import T_path
from embeddings.model.lightning_module.huggingface_module import HuggingFaceLightningModule
from embeddings.model.lightning_module.lightning_module import LightningModule
from embeddings.task.task import Task
from embeddings.utils.lightning_callbacks.best_epoch_callback import BestEpochCallback
from embeddings.utils.loggers import get_logger
@@ -98,6 +100,40 @@ def fit_predict(
def build_task_model(self) -> None:
pass

@classmethod
def restore_task_model(
cls,
checkpoint_path: T_path,
output_path: T_path,
lightning_module: Type[LightningModule[AutoModel]],
task_train_kwargs: Optional[Dict[str, Any]],
early_stopping_kwargs: Optional[Dict[str, Any]],
) -> "LightningTask":
model = lightning_module.load_from_checkpoint(str(checkpoint_path))
trainer = pl.Trainer(default_root_dir=str(output_path), **task_train_kwargs or {})
init_kwargs = {
"model_name_or_path": model.hparams.model_name_or_path,
"output_path": output_path,
"num_classes": model.hparams.num_classes,
"finetune_last_n_layers": model.hparams.finetune_last_n_layers,
"model_config_kwargs": model.hparams.config_kwargs,
"task_model_kwargs": model.hparams.task_model_kwargs,
"task_train_kwargs": task_train_kwargs or {},
"early_stopping_kwargs": early_stopping_kwargs or {},
}
task = cls(**init_kwargs)
task.model = model
task.trainer = trainer
model.trainer = trainer
return task

@classmethod
@abc.abstractmethod
def restore_task_model(self, checkpoint_path: str) -> None:
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]],
early_stopping_kwargs: Optional[Dict[str, Any]],
) -> "LightningTask":
pass
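
The two classmethods split the restore logic: `restore_task_model` does the generic work (load the module from the checkpoint, build a fresh `Trainer`, re-create the task from the module's saved hyperparameters, and wire module and trainer back together), while the abstract `from_checkpoint` lets each concrete task supply its own `LightningModule` subclass. Note that the reconstruction only works if the checkpointed module persisted its hyperparameters (e.g. via `self.save_hyperparameters()` in `__init__`). A sketch of that contract, using the hparams `restore_task_model` reads back (the module class and checkpoint path are placeholders):

module = TextClassificationModule.load_from_checkpoint("path/to/checkpoint.ckpt")
for hparam in (
    "model_name_or_path",
    "num_classes",
    "finetune_last_n_layers",
    "config_kwargs",
    "task_model_kwargs",
):
    # restore_task_model reads each of these from module.hparams
    assert hasattr(module.hparams, hparam)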
25 changes: 17 additions & 8 deletions embeddings/task/lightning_task/sequence_labeling.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import numpy as np
from numpy import typing as nptyping
@@ -19,17 +19,13 @@ def __init__(
task_model_kwargs: Dict[str, Any],
task_train_kwargs: Dict[str, Any],
early_stopping_kwargs: Dict[str, Any],
train_batch_size: int = 32,
eval_batch_size: int = 32,
finetune_last_n_layers: int = -1,
) -> None:
super().__init__(output_path, task_train_kwargs, early_stopping_kwargs)
self.model_name_or_path = model_name_or_path
self.num_classes = num_classes
self.model_config_kwargs = model_config_kwargs
self.task_model_kwargs = task_model_kwargs
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.finetune_last_n_layers = finetune_last_n_layers

def build_task_model(self) -> None:
@@ -41,9 +37,6 @@ def build_task_model(self) -> None:
task_model_kwargs=self.task_model_kwargs,
)

def restore_task_model(self, checkpoint_path: str) -> None:
self.model = SequenceLabelingModule.load_from_checkpoint(checkpoint_path)

def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any]]:
assert self.model is not None
results = self.model.predict(dataloader=dataloader)
@@ -63,3 +56,19 @@ def _map_filter_data(
getattr(self.trainer, "datamodule").id2str(x.item())
for x in data[ground_truth_data != self.model.ignore_index]
]

@classmethod
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]],
early_stopping_kwargs: Optional[Dict[str, Any]],
) -> "LightningTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
task_train_kwargs=task_train_kwargs,
early_stopping_kwargs=early_stopping_kwargs,
lightning_module=SequenceLabelingModule,
)
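
A minimal usage sketch for the new entry point (assuming the task class in this file is `SequenceLabelingTask`; the checkpoint path is a placeholder):

task = SequenceLabelingTask.from_checkpoint(
    checkpoint_path="lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt",
    output_path=".",
    task_train_kwargs={},
    early_stopping_kwargs={},
)

One caveat worth noting: `predict` on the restored task calls `_map_filter_data`, which reads `self.trainer.datamodule` to map label ids back to tag strings, and the freshly built trainer has no datamodule attached yet.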
22 changes: 18 additions & 4 deletions embeddings/task/lightning_task/text_classification.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from numpy import typing as nptyping
from torch.utils.data import DataLoader
@@ -26,6 +26,7 @@ def __init__(
self.model_config_kwargs = model_config_kwargs
self.task_model_kwargs = task_model_kwargs
self.finetune_last_n_layers = finetune_last_n_layers
self.task_train_kwargs = task_train_kwargs

def build_task_model(self) -> None:
self.model = TextClassificationModule(
@@ -36,9 +37,22 @@ def build_task_model(self) -> None:
task_model_kwargs=self.task_model_kwargs,
)

def restore_task_model(self, checkpoint_path: str) -> None:
self.model = TextClassificationModule.load_from_checkpoint(checkpoint_path)

def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any]]:
assert self.model is not None
return self.model.predict(dataloader=dataloader)

@classmethod
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]],
early_stopping_kwargs: Optional[Dict[str, Any]],
) -> "LightningTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
task_train_kwargs=task_train_kwargs,
early_stopping_kwargs=early_stopping_kwargs,
lightning_module=TextClassificationModule,
)
58 changes: 38 additions & 20 deletions notebooks/validate_lightning_models_inference.ipynb
@@ -28,6 +28,7 @@
"import pytorch_lightning as pl\n",
"import torch\n",
"from embeddings.defaults import RESULTS_PATH\n",
"from embeddings.task.lightning_task.text_classification import TextClassificationTask\n",
"from embeddings.model.lightning_module.text_classification import (\n",
" TextClassificationModule,\n",
")\n",
@@ -52,12 +53,13 @@
},
"outputs": [],
"source": [
"embedding_name_or_path = \"allegro/herbert-base-cased\"\n",
"embedding_name_or_path = \"hf-internal-testing/tiny-albert\"\n",
"dataset_name = \"clarin-pl/polemo2-official\"\n",
"input_columns_name = \"text\"\n",
"target_column_name = \"target\"\n",
"path = TemporaryDirectory()\n",
"output_path = path.name\n",
"# path = TemporaryDirectory()\n",
"# output_path = path.name\n",
"output_path = \".\"\n",
"\n",
"pipeline = LightningClassificationPipeline(\n",
" embedding_name_or_path=embedding_name_or_path,\n",
@@ -72,12 +74,7 @@
" \"test_domains\": [\"hotels\", \"medicine\"],\n",
" \"text_cfg\": \"text\",\n",
" },\n",
" datamodule_kwargs={\n",
" \"max_seq_length\": 64,\n",
" \"downsample_train\": 0.005,\n",
" \"downsample_val\": 0.01,\n",
" \"downsample_test\": 0.01,\n",
" },\n",
" datamodule_kwargs={\"max_seq_length\": 64,},\n",
" task_train_kwargs={\n",
" \"max_epochs\": 1,\n",
" \"devices\": \"auto\",\n",
@@ -106,11 +103,26 @@
"ckpt_path = (\n",
" Path(output_path)\n",
" / \"lightning_logs\"\n",
" / \"version_0\"\n",
" / \"version_1\"\n",
" / \"checkpoints\"\n",
" / \"epoch=0-step=0.ckpt\"\n",
" / \"epoch=0-step=180.ckpt\"\n",
")\n",
"ckpt_path"
"ckpt_path.resolve()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2785fcbc-1c95-4d23-807f-a14569992354",
"metadata": {},
"outputs": [],
"source": [
"task_from_ckpt = TextClassificationTask.from_checkpoint(\n",
" checkpoint_path=ckpt_path,\n",
" output_path=output_path,\n",
" task_train_kwargs={},\n",
" early_stopping_kwargs={},\n",
")"
]
},
{
@@ -143,7 +155,7 @@
"outputs": [],
"source": [
"model_state_dict = pipeline.model.task.model.model.state_dict()\n",
"model_from_ckpt_state_dict = model_from_ckpt.model.state_dict()"
"model_from_ckpt_state_dict = task_from_ckpt.model.model.state_dict()"
]
},
{
@@ -185,6 +197,12 @@
"pipeline.model.task.trainer.save_checkpoint(\"example.ckpt\")\n",
"new_model = TextClassificationModule.load_from_checkpoint(\n",
" checkpoint_path=\"example.ckpt\"\n",
")\n",
"new_task_from_ckpt = TextClassificationTask.from_checkpoint(\n",
" checkpoint_path=ckpt_path,\n",
" output_path=output_path,\n",
" task_train_kwargs={},\n",
" early_stopping_kwargs={},\n",
")"
]
},
@@ -199,31 +217,31 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f3afa250-2937-4aad-bb3c-172a68639892",
"id": "4ad7b9b0-823a-4c8e-aac5-61a333558ed1",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer()\n",
"test_dataloader = pipeline.datamodule.test_dataloader()\n",
"predictions = trainer.predict(model_from_ckpt, dataloaders=test_dataloader)"
"preds = task_from_ckpt.predict(test_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09f45c8b-791b-43b4-9826-f798d48b9d97",
"id": "f3afa250-2937-4aad-bb3c-172a68639892",
"metadata": {},
"outputs": [],
"source": [
"predictions"
"trainer = pl.Trainer()\n",
"preds_other = trainer.predict(model_from_ckpt, dataloaders=test_dataloader)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:clarinpl-embeddings]",
"display_name": "Python [conda env:embeddings]",
"language": "python",
"name": "conda-env-clarinpl-embeddings-py"
"name": "conda-env-embeddings-py"
},
"language_info": {
"codemirror_mode": {
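For reference, the state-dict check the notebook performs amounts to a tensor-by-tensor comparison between the in-memory pipeline's weights and the restored ones; a sketch using the two variables defined above:

import torch

assert model_state_dict.keys() == model_from_ckpt_state_dict.keys()
for name in model_state_dict:
    assert torch.equal(model_state_dict[name], model_from_ckpt_state_dict[name]), name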
