Feature/trained models inference #226

Merged · 22 commits · Apr 13, 2022
Commits
57e7a75
feat: allow for loading lightning model from ckpt
djaniak Mar 24, 2022
4e76235
fix: missing arg in lightning_sequence_labeling.py
djaniak Mar 31, 2022
e76c948
feat: notebook with inference example for lightning
djaniak Mar 31, 2022
7858aa4
feat: notebook with inference example for flair
djaniak Mar 31, 2022
3da9c87
feat: implement load_from_ckpt method for LightningTask
djaniak Apr 6, 2022
03943f7
fix: restore inference after rebase for lightning
djaniak Apr 7, 2022
29519d4
feat: tests for lightning inference
djaniak Apr 7, 2022
282f4ab
refactor: move lightning inference notebook to tutorials and refactor
djaniak Apr 7, 2022
b6a2771
refactor: switch to herbert for testing inference
djaniak Apr 7, 2022
a990502
feat: implement flair task from_checkpoint method
djaniak Apr 7, 2022
cb8758d
refactor: lightning inference and datamodule
djaniak Apr 8, 2022
a5eae47
tests: add flair trained model inference
djaniak Apr 8, 2022
8ef0b25
refactor: update notebooks with current code
djaniak Apr 8, 2022
2396b36
refactor: naming in Lightning modules
djaniak Apr 8, 2022
38c1438
fix: tests and notebooks after rebase
djaniak Apr 8, 2022
5bebc1a
refactor: loading model from ckpt for lightning
djaniak Apr 8, 2022
8434919
fix(tests): flair inference tests
djaniak Apr 9, 2022
cfb2cb1
fix(tests): isort
djaniak Apr 10, 2022
2c830c8
fix: inference for lightning pipelines
djaniak Apr 12, 2022
63e96b7
refactor: inference tests for flair and lightning
djaniak Apr 12, 2022
f7385da
misc: update tutorial notebook
djaniak Apr 12, 2022
2566140
refactor: pr issues
djaniak Apr 13, 2022
16 changes: 10 additions & 6 deletions embeddings/data/datamodule.py
@@ -60,6 +60,7 @@ def __init__(
dataloader_kwargs: Optional[Dict[str, Any]] = None,
seed: int = 441,
) -> None:
self.has_setup = False
self.dataset_name_or_path = dataset_name_or_path
self.tokenizer_name_or_path = tokenizer_name_or_path
self.target_field = target_field
@@ -74,9 +75,10 @@ def __init__(
self.load_dataset_kwargs = load_dataset_kwargs if load_dataset_kwargs else {}
self.dataloader_kwargs = dataloader_kwargs if dataloader_kwargs else {}
self.seed = seed
dataset_info = self.load_dataset()["train"].info
self.setup()
super().__init__(
dataset_info=dataset_info, dataset_version=dataset_info.version.version_str
dataset_info=self.dataset["train"].info,
dataset_version=self.dataset["train"].info.version.version_str,
)

@abc.abstractmethod
@@ -94,13 +96,15 @@ def convert_to_features(
pass

def prepare_data(self) -> None:
self.load_dataset(preparation_step=True)
AutoTokenizer.from_pretrained(self.tokenizer_name_or_path)

def setup(self, stage: Optional[str] = None) -> None:
self.dataset = self.load_dataset()
self.prepare_labels()
self.process_data()
if not self.has_setup:
self.dataset = self.load_dataset()
self.prepare_labels()
self.process_data()
self.has_setup = True
assert all(hasattr(self, attr) for attr in ["num_classes", "target_names", "dataset"])

def load_dataset(self, preparation_step: bool = False) -> DatasetDict:
dataset = embeddings_dataset.Dataset(
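The motivation for the `has_setup` flag: the datamodule now calls `setup()` eagerly from `__init__` (so `num_classes`, `target_names`, and dataset metadata exist before any `Trainer` does), yet Lightning will call `setup()` again later. A minimal, self-contained sketch of the idempotent-setup pattern — class and helper names here are illustrative, not the project's actual API:

```python
from typing import Any, Dict, Optional


class GuardedDataModule:
    """Illustrative sketch of the idempotent setup() guard added in datamodule.py."""

    def __init__(self) -> None:
        self.has_setup = False
        self.setup()  # eager call: dataset metadata is needed already in __init__

    def setup(self, stage: Optional[str] = None) -> None:
        # Lightning calls setup() again before fit/validate/test/predict;
        # the flag makes the expensive work run exactly once.
        if not self.has_setup:
            self.dataset = self._load_dataset()
            self._prepare_labels()
            self._process_data()
            self.has_setup = True
        assert hasattr(self, "dataset")

    def _load_dataset(self) -> Dict[str, Any]:
        return {"train": ["example"]}  # placeholder for the real dataset loading

    def _prepare_labels(self) -> None:
        pass

    def _process_data(self) -> None:
        pass
```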
43 changes: 25 additions & 18 deletions embeddings/model/lightning_module/huggingface_module.py
@@ -1,7 +1,7 @@
import abc
import sys
from collections import ChainMap
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Type

from torchmetrics import F1, Accuracy, MetricCollection, Precision, Recall
from transformers import AutoConfig, AutoModel
@@ -15,6 +15,7 @@ def __init__(
self,
model_name_or_path: T_path,
downstream_model_type: Type["AutoModel"],
num_classes: int,
finetune_last_n_layers: int,
metrics: Optional[MetricCollection] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
@@ -24,13 +25,15 @@ def __init__(
self.save_hyperparameters({"downstream_model_type": downstream_model_type.__name__})
self.downstream_model_type = downstream_model_type
self.config_kwargs = config_kwargs if config_kwargs else {}
self.target_names: Optional[List[str]] = None
self._init_model()
self._init_metrics()

def setup(self, stage: Optional[str] = None) -> None:
if stage in ("fit", None):
self.configure_model()
self.configure_metrics()
assert self.trainer is not None
self.target_names = self.trainer.datamodule.target_names
if self.hparams.use_scheduler:
assert self.trainer is not None
train_loader = self.trainer.datamodule.train_dataloader()
gpus = getattr(self.trainer, "gpus") if getattr(self.trainer, "gpus") else 0
tb_size = self.hparams.train_batch_size * max(1, gpus)
@@ -39,11 +42,10 @@ def setup(self, stage: Optional[str] = None) -> None:
(len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs)
)

def configure_model(self) -> None:
assert self.trainer is not None
def _init_model(self) -> None:
self.config = AutoConfig.from_pretrained(
self.hparams.model_name_or_path,
num_labels=self.trainer.datamodule.num_classes,
num_labels=self.hparams.num_classes,
**self.config_kwargs,
)
self.model: AutoModel = self.downstream_model_type.from_pretrained(
@@ -72,24 +74,22 @@ def freeze_transformer(self, finetune_last_n_layers: int) -> None:
param.requires_grad = False

def get_default_metrics(self) -> MetricCollection:
assert self.trainer is not None
num_classes = self.trainer.datamodule.num_classes
if num_classes > 2:
if self.hparams.num_classes > 2:
metrics = MetricCollection(
[
Accuracy(num_classes=num_classes),
Precision(num_classes=num_classes, average="macro"),
Recall(num_classes=num_classes, average="macro"),
F1(num_classes=num_classes, average="macro"),
Accuracy(num_classes=self.hparams.num_classes),
Precision(num_classes=self.hparams.num_classes, average="macro"),
Recall(num_classes=self.hparams.num_classes, average="macro"),
F1(num_classes=self.hparams.num_classes, average="macro"),
]
)
else:
metrics = MetricCollection(
[
Accuracy(num_classes=num_classes),
Precision(num_classes=num_classes),
Recall(num_classes=num_classes),
F1(num_classes=num_classes),
Accuracy(num_classes=self.hparams.num_classes),
Precision(num_classes=self.hparams.num_classes),
Recall(num_classes=self.hparams.num_classes),
F1(num_classes=self.hparams.num_classes),
]
)
return metrics
@@ -100,3 +100,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
if isinstance(inputs, tuple):
inputs = dict(ChainMap(*inputs))
return self.model(**inputs)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
assert self.trainer is not None
checkpoint["target_names"] = self.trainer.datamodule.target_names

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.target_names = checkpoint["target_names"]
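Two related changes here make the module restorable without a live datamodule: `num_classes` becomes a constructor argument stored in `hparams` (so `_init_model` and `get_default_metrics` no longer reach into `trainer.datamodule`), and `target_names` are round-tripped through the checkpoint via `on_save_checkpoint` / `on_load_checkpoint`. A hedged sketch of the checkpoint-hook pattern — a standalone toy module, not the project's `HuggingFaceLightningModule`:

```python
from typing import Any, Dict, List, Optional

import pytorch_lightning as pl
import torch


class LabelAwareModule(pl.LightningModule):
    """Toy module showing how label metadata can be persisted in the checkpoint dict."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()  # num_classes is now a plain hyperparameter
        self.layer = torch.nn.Linear(4, num_classes)
        self.target_names: Optional[List[str]] = None

    def setup(self, stage: Optional[str] = None) -> None:
        if stage in ("fit", None) and self.trainer is not None:
            # During training the names come from the datamodule ...
            self.target_names = self.trainer.datamodule.target_names

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # ... and are written into the checkpoint, so a model restored with
        # load_from_checkpoint() can report readable labels without a datamodule.
        checkpoint["target_names"] = self.target_names

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.target_names = checkpoint["target_names"]
```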
32 changes: 24 additions & 8 deletions embeddings/model/lightning_module/lightning_module.py
@@ -6,6 +6,7 @@
import pytorch_lightning as pl
import torch
from numpy import typing as nptyping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn.functional import softmax
from torch.optim import Optimizer
@@ -14,9 +15,12 @@
from transformers import get_linear_schedule_with_warmup

from embeddings.data.datamodule import HuggingFaceDataset
from embeddings.utils.loggers import get_logger

Model = TypeVar("Model")

_logger = get_logger(__name__)


class LightningModule(pl.LightningModule, abc.ABC, Generic[Model]):
def __init__(
@@ -68,19 +72,31 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Optional[Tuple[STEP_OUTPUT,
def predict(
self, dataloader: DataLoader[HuggingFaceDataset]
) -> Dict[str, nptyping.NDArray[Any]]:
assert self.trainer is not None
logits_predictions = self.trainer.predict(
dataloaders=dataloader, return_predictions=True, ckpt_path="best"
)
logits, predictions = zip(*logits_predictions)
logits, predictions = zip(*self._predict_with_trainer(dataloader))
probabilities = softmax(torch.cat(logits), dim=1).numpy()
predictions = torch.cat(predictions).numpy()
ground_truth = torch.cat([x["labels"] for x in dataloader]).numpy()
result = {"y_pred": predictions, "y_true": ground_truth, "y_probabilities": probabilities}
assert all(isinstance(x, np.ndarray) for x in result.values())
return result

def configure_metrics(self) -> None:
def _predict_with_trainer(self, dataloader: DataLoader[HuggingFaceDataset]) -> torch.Tensor:
assert self.trainer is not None
try:
return self.trainer.predict(
model=self, dataloaders=dataloader, return_predictions=True, ckpt_path="best"
)
except MisconfigurationException: # model loaded but not fitted
_logger.warning(
"The best model checkpoint cannot be loaded because trainer.fit has not been called. Using current weights for prediction."
)
return self.trainer.predict(
model=self,
dataloaders=dataloader,
return_predictions=True,
)

def _init_metrics(self) -> None:
if self.metrics is None:
self.metrics = self.get_default_metrics()
self.train_metrics = self.metrics.clone(prefix="train/")
@@ -132,13 +148,13 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[Any]]:
)

if self.hparams.use_scheduler:
lr_schedulers = self.configure_schedulers(optimizer=optimizer)
lr_schedulers = self._get_schedulers(optimizer=optimizer)
else:
lr_schedulers = []

return [optimizer], lr_schedulers

def configure_schedulers(self, optimizer: Optimizer) -> List[Dict[str, Any]]:
def _get_schedulers(self, optimizer: Optimizer) -> List[Dict[str, Any]]:
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.hparams.warmup_steps,
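`_predict_with_trainer` prefers the best checkpoint tracked during `trainer.fit`, but a model restored purely for inference has never been fitted, so `ckpt_path="best"` raises `MisconfigurationException`; the method then retries with the weights already in memory. A sketch of that fallback, assuming a standard `Trainer`/`LightningModule` pair:

```python
from typing import Any, List

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader


def predict_with_best_or_current(
    trainer: Trainer, model: LightningModule, dataloader: DataLoader
) -> List[Any]:
    try:
        # Preferred path: load the best checkpoint produced by trainer.fit().
        return trainer.predict(
            model=model, dataloaders=dataloader, return_predictions=True, ckpt_path="best"
        )
    except MisconfigurationException:
        # trainer.fit() was never called (e.g. the model came from a .ckpt file),
        # so no "best" checkpoint exists; predict with the current weights instead.
        return trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)
```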
19 changes: 19 additions & 0 deletions embeddings/model/lightning_module/sequence_labeling.py
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Tuple

import torch
from datasets import ClassLabel
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torchmetrics import MetricCollection
from transformers import AutoModelForTokenClassification
@@ -18,6 +19,7 @@ class SequenceLabelingModule(HuggingFaceLightningModule):
def __init__(
self,
model_name_or_path: T_path,
num_classes: int,
finetune_last_n_layers: int,
metrics: Optional[MetricCollection] = None,
ignore_index: int = -100,
@@ -27,12 +29,21 @@ def __init__(
super().__init__(
model_name_or_path=model_name_or_path,
downstream_model_type=self.downstream_model_type,
num_classes=num_classes,
finetune_last_n_layers=finetune_last_n_layers,
metrics=metrics,
config_kwargs=config_kwargs,
task_model_kwargs=task_model_kwargs,
)
self.ignore_index = ignore_index
self.class_label: Optional[ClassLabel] = None

def setup(self, stage: Optional[str] = None) -> None:
if stage in ("fit", None):
assert self.trainer is not None
self.class_label = self.trainer.datamodule.dataset["train"].features["labels"].feature
assert isinstance(self.class_label, ClassLabel)
super().setup(stage=stage)

def shared_step(self, **batch: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
outputs = self.forward(**batch)
@@ -73,3 +84,11 @@ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
else:
_logger.warning("Missing labels for the test data")
return None

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["class_label"] = self.class_label
super().on_save_checkpoint(checkpoint=checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.class_label = checkpoint["class_label"]
super().on_load_checkpoint(checkpoint=checkpoint)
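`class_label` is the `datasets.ClassLabel` feature of the training split, captured in `setup` and stored in the checkpoint; at inference time it lets predicted label ids be decoded back into tag strings without access to the original dataset. A small example of that decoding (the tag set below is hypothetical):

```python
from datasets import ClassLabel

# Hypothetical tag set; in the pipeline this object comes from
# dataset["train"].features["labels"].feature and is restored from the checkpoint.
class_label = ClassLabel(names=["O", "B-PER", "I-PER"])

predicted_ids = [0, 1, 2, 0]
predicted_tags = [class_label.int2str(i) for i in predicted_ids]
print(predicted_tags)  # ['O', 'B-PER', 'I-PER', 'O']
```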
2 changes: 2 additions & 0 deletions embeddings/model/lightning_module/text_classification.py
@@ -18,6 +18,7 @@ class TextClassificationModule(HuggingFaceLightningModule):
def __init__(
self,
model_name_or_path: T_path,
num_classes: int,
finetune_last_n_layers: int,
metrics: Optional[MetricCollection] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
@@ -26,6 +27,7 @@
super().__init__(
model_name_or_path=model_name_or_path,
downstream_model_type=self.downstream_model_type,
num_classes=num_classes,
finetune_last_n_layers=finetune_last_n_layers,
metrics=metrics,
config_kwargs=config_kwargs,
1 change: 1 addition & 0 deletions embeddings/pipeline/lightning_classification.py
@@ -57,6 +57,7 @@ def __init__(
task = TextClassificationTask(
model_name_or_path=embedding_name_or_path,
output_path=output_path,
num_classes=datamodule.num_classes,
finetune_last_n_layers=config.finetune_last_n_layers,
model_config_kwargs=config.model_config_kwargs,
task_model_kwargs=config.task_model_kwargs,
1 change: 1 addition & 0 deletions embeddings/pipeline/lightning_sequence_labeling.py
@@ -62,6 +62,7 @@ def __init__(
task = SequenceLabelingTask(
model_name_or_path=embedding_name_or_path,
output_path=output_path,
num_classes=datamodule.num_classes,
finetune_last_n_layers=config.finetune_last_n_layers,
model_config_kwargs=config.model_config_kwargs,
task_model_kwargs=config.task_model_kwargs,
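In both pipelines the number of classes is now resolved from the datamodule (which has already run `setup()` in its constructor) and handed to the task as an explicit argument, instead of the module discovering it from `trainer.datamodule` during training. A toy illustration of that flow of information — the classes below are illustrative stand-ins, not the real pipeline types:

```python
from typing import List

from datasets import ClassLabel


class ToyDataModule:
    def __init__(self, label_names: List[str]) -> None:
        # Because setup-style work happens at construction time, label metadata
        # is already available when the pipeline wires everything together.
        self.class_label = ClassLabel(names=label_names)
        self.num_classes = self.class_label.num_classes


class ToyTask:
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes  # stored as a plain hyperparameter downstream


datamodule = ToyDataModule(["negative", "positive"])
task = ToyTask(num_classes=datamodule.num_classes)
print(task.num_classes)  # 2
```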
32 changes: 31 additions & 1 deletion embeddings/task/flair_task/flair_task.py
@@ -1,9 +1,10 @@
import abc
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

import flair
from flair.data import Corpus, Dictionary, Sentence
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from numpy import typing as nptyping
from typing_extensions import Literal
@@ -24,6 +25,7 @@ def __init__(
self,
output_path: T_path = RESULTS_PATH,
task_train_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
):
super().__init__()
self.model: Optional[flair.nn.Model] = None
@@ -106,3 +108,31 @@ def get_y(data: List[Sentence], y_type: str, y_dictionary: Dictionary) -> nptypi
@abc.abstractmethod
def remove_labels_from_data(data: List[Sentence], y_type: str) -> None:
pass

@classmethod
def restore_task_model(
cls,
checkpoint_path: T_path,
output_path: T_path,
flair_model: Type[flair.nn.Model],
task_train_kwargs: Optional[Dict[str, Any]],
) -> "FlairTask":
model = flair_model.load(checkpoint_path)
task_kwargs = (
{"hidden_size": model.hidden_size} if isinstance(model, SequenceTagger) else {}
)
task = cls(
output_path=output_path, task_train_kwargs=task_train_kwargs or {}, **task_kwargs
)
task.model = model
return task

@classmethod
@abc.abstractmethod
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
) -> "FlairTask":
pass
14 changes: 14 additions & 0 deletions embeddings/task/flair_task/sequence_labeling.py
@@ -64,3 +64,17 @@ def remove_labels_from_data(data: List[Sentence], y_type: str) -> None:
for sent in data:
for token in sent:
token.remove_labels(y_type)

@classmethod
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
) -> "FlairTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
flair_model=SequenceTagger,
task_train_kwargs=task_train_kwargs,
)
14 changes: 14 additions & 0 deletions embeddings/task/flair_task/text_classification.py
@@ -57,3 +57,17 @@ def get_y(data: List[Sentence], y_type: str, y_dictionary: Dictionary) -> nptypi
def remove_labels_from_data(data: List[Sentence], y_type: str) -> None:
for sentence in data:
sentence.remove_labels(y_type)

@classmethod
def from_checkpoint(
cls,
checkpoint_path: T_path,
output_path: T_path,
task_train_kwargs: Optional[Dict[str, Any]] = None,
) -> "FlairTask":
return cls.restore_task_model(
checkpoint_path=checkpoint_path,
output_path=output_path,
flair_model=TextClassifier,
task_train_kwargs=task_train_kwargs,
)
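Taken together, `restore_task_model` and the two `from_checkpoint` overrides let a trained Flair model be wrapped back into a task object for inference only. A hedged usage sketch — the concrete class name, import path, and checkpoint path below are assumptions, not taken from this diff:

```python
# Assumed import/class name for illustration; the diff only shows the FlairTask base
# class and the from_checkpoint overrides, not how the concrete tasks are exported.
from embeddings.task.flair_task.sequence_labeling import SequenceLabeling

task = SequenceLabeling.from_checkpoint(
    checkpoint_path="checkpoints/final-model.pt",  # a model saved earlier by flair's ModelTrainer
    output_path="results/",
)
# Internally this calls SequenceTagger.load(checkpoint_path), re-creates the task
# (passing hidden_size for sequence taggers), and attaches the loaded model to task.model.
```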