diff --git a/embeddings/data/datamodule.py b/embeddings/data/datamodule.py index c72bc16c..4182198a 100644 --- a/embeddings/data/datamodule.py +++ b/embeddings/data/datamodule.py @@ -60,6 +60,7 @@ def __init__( dataloader_kwargs: Optional[Dict[str, Any]] = None, seed: int = 441, ) -> None: + self.has_setup = False self.dataset_name_or_path = dataset_name_or_path self.tokenizer_name_or_path = tokenizer_name_or_path self.target_field = target_field @@ -74,9 +75,10 @@ def __init__( self.load_dataset_kwargs = load_dataset_kwargs if load_dataset_kwargs else {} self.dataloader_kwargs = dataloader_kwargs if dataloader_kwargs else {} self.seed = seed - dataset_info = self.load_dataset()["train"].info + self.setup() super().__init__( - dataset_info=dataset_info, dataset_version=dataset_info.version.version_str + dataset_info=self.dataset["train"].info, + dataset_version=self.dataset["train"].info.version.version_str, ) @abc.abstractmethod @@ -94,13 +96,15 @@ def convert_to_features( pass def prepare_data(self) -> None: - self.load_dataset(preparation_step=True) AutoTokenizer.from_pretrained(self.tokenizer_name_or_path) def setup(self, stage: Optional[str] = None) -> None: - self.dataset = self.load_dataset() - self.prepare_labels() - self.process_data() + if not self.has_setup: + self.dataset = self.load_dataset() + self.prepare_labels() + self.process_data() + self.has_setup = True + assert all(hasattr(self, attr) for attr in ["num_classes", "target_names", "dataset"]) def load_dataset(self, preparation_step: bool = False) -> DatasetDict: dataset = embeddings_dataset.Dataset( diff --git a/embeddings/model/lightning_module/huggingface_module.py b/embeddings/model/lightning_module/huggingface_module.py index a2196819..ed23b6e5 100644 --- a/embeddings/model/lightning_module/huggingface_module.py +++ b/embeddings/model/lightning_module/huggingface_module.py @@ -1,7 +1,7 @@ import abc import sys from collections import ChainMap -from typing import Any, Dict, Optional, Type 
+from typing import Any, Dict, List, Optional, Type from torchmetrics import F1, Accuracy, MetricCollection, Precision, Recall from transformers import AutoConfig, AutoModel @@ -15,6 +15,7 @@ def __init__( self, model_name_or_path: T_path, downstream_model_type: Type["AutoModel"], + num_classes: int, finetune_last_n_layers: int, metrics: Optional[MetricCollection] = None, config_kwargs: Optional[Dict[str, Any]] = None, @@ -24,13 +25,15 @@ def __init__( self.save_hyperparameters({"downstream_model_type": downstream_model_type.__name__}) self.downstream_model_type = downstream_model_type self.config_kwargs = config_kwargs if config_kwargs else {} + self.target_names: Optional[List[str]] = None + self._init_model() + self._init_metrics() def setup(self, stage: Optional[str] = None) -> None: if stage in ("fit", None): - self.configure_model() - self.configure_metrics() + assert self.trainer is not None + self.target_names = self.trainer.datamodule.target_names if self.hparams.use_scheduler: - assert self.trainer is not None train_loader = self.trainer.datamodule.train_dataloader() gpus = getattr(self.trainer, "gpus") if getattr(self.trainer, "gpus") else 0 tb_size = self.hparams.train_batch_size * max(1, gpus) @@ -39,11 +42,10 @@ def setup(self, stage: Optional[str] = None) -> None: (len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs) ) - def configure_model(self) -> None: - assert self.trainer is not None + def _init_model(self) -> None: self.config = AutoConfig.from_pretrained( self.hparams.model_name_or_path, - num_labels=self.trainer.datamodule.num_classes, + num_labels=self.hparams.num_classes, **self.config_kwargs, ) self.model: AutoModel = self.downstream_model_type.from_pretrained( @@ -72,24 +74,22 @@ def freeze_transformer(self, finetune_last_n_layers: int) -> None: param.requires_grad = False def get_default_metrics(self) -> MetricCollection: - assert self.trainer is not None - num_classes = self.trainer.datamodule.num_classes - if 
num_classes > 2: + if self.hparams.num_classes > 2: metrics = MetricCollection( [ - Accuracy(num_classes=num_classes), - Precision(num_classes=num_classes, average="macro"), - Recall(num_classes=num_classes, average="macro"), - F1(num_classes=num_classes, average="macro"), + Accuracy(num_classes=self.hparams.num_classes), + Precision(num_classes=self.hparams.num_classes, average="macro"), + Recall(num_classes=self.hparams.num_classes, average="macro"), + F1(num_classes=self.hparams.num_classes, average="macro"), ] ) else: metrics = MetricCollection( [ - Accuracy(num_classes=num_classes), - Precision(num_classes=num_classes), - Recall(num_classes=num_classes), - F1(num_classes=num_classes), + Accuracy(num_classes=self.hparams.num_classes), + Precision(num_classes=self.hparams.num_classes), + Recall(num_classes=self.hparams.num_classes), + F1(num_classes=self.hparams.num_classes), ] ) return metrics @@ -100,3 +100,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if isinstance(inputs, tuple): inputs = dict(ChainMap(*inputs)) return self.model(**inputs) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + assert self.trainer is not None + checkpoint["target_names"] = self.trainer.datamodule.target_names + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.target_names = checkpoint["target_names"] diff --git a/embeddings/model/lightning_module/lightning_module.py b/embeddings/model/lightning_module/lightning_module.py index 1f20bc26..9c5a9808 100644 --- a/embeddings/model/lightning_module/lightning_module.py +++ b/embeddings/model/lightning_module/lightning_module.py @@ -6,6 +6,7 @@ import pytorch_lightning as pl import torch from numpy import typing as nptyping +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn.functional import softmax from torch.optim import Optimizer @@ -14,9 +15,12 @@ from transformers import 
get_linear_schedule_with_warmup from embeddings.data.datamodule import HuggingFaceDataset +from embeddings.utils.loggers import get_logger Model = TypeVar("Model") +_logger = get_logger(__name__) + class LightningModule(pl.LightningModule, abc.ABC, Generic[Model]): def __init__( @@ -68,11 +72,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Optional[Tuple[STEP_OUTPUT, def predict( self, dataloader: DataLoader[HuggingFaceDataset] ) -> Dict[str, nptyping.NDArray[Any]]: - assert self.trainer is not None - logits_predictions = self.trainer.predict( - dataloaders=dataloader, return_predictions=True, ckpt_path="best" - ) - logits, predictions = zip(*logits_predictions) + logits, predictions = zip(*self._predict_with_trainer(dataloader)) probabilities = softmax(torch.cat(logits), dim=1).numpy() predictions = torch.cat(predictions).numpy() ground_truth = torch.cat([x["labels"] for x in dataloader]).numpy() @@ -80,7 +80,23 @@ def predict( assert all(isinstance(x, np.ndarray) for x in result.values()) return result - def configure_metrics(self) -> None: + def _predict_with_trainer(self, dataloader: DataLoader[HuggingFaceDataset]) -> torch.Tensor: + assert self.trainer is not None + try: + return self.trainer.predict( + model=self, dataloaders=dataloader, return_predictions=True, ckpt_path="best" + ) + except MisconfigurationException: # model loaded but not fitted + _logger.warning( + "The best model checkpoint cannot be loaded because trainer.fit has not been called. Using current weights for prediction." 
+ ) + return self.trainer.predict( + model=self, + dataloaders=dataloader, + return_predictions=True, + ) + + def _init_metrics(self) -> None: if self.metrics is None: self.metrics = self.get_default_metrics() self.train_metrics = self.metrics.clone(prefix="train/") @@ -132,13 +148,13 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[Any]]: ) if self.hparams.use_scheduler: - lr_schedulers = self.configure_schedulers(optimizer=optimizer) + lr_schedulers = self._get_schedulers(optimizer=optimizer) else: lr_schedulers = [] return [optimizer], lr_schedulers - def configure_schedulers(self, optimizer: Optimizer) -> List[Dict[str, Any]]: + def _get_schedulers(self, optimizer: Optimizer) -> List[Dict[str, Any]]: scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=self.hparams.warmup_steps, diff --git a/embeddings/model/lightning_module/sequence_labeling.py b/embeddings/model/lightning_module/sequence_labeling.py index d6a6805c..0e1e25fd 100644 --- a/embeddings/model/lightning_module/sequence_labeling.py +++ b/embeddings/model/lightning_module/sequence_labeling.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional, Tuple import torch +from datasets import ClassLabel from pytorch_lightning.utilities.types import STEP_OUTPUT from torchmetrics import MetricCollection from transformers import AutoModelForTokenClassification @@ -18,6 +19,7 @@ class SequenceLabelingModule(HuggingFaceLightningModule): def __init__( self, model_name_or_path: T_path, + num_classes: int, finetune_last_n_layers: int, metrics: Optional[MetricCollection] = None, ignore_index: int = -100, @@ -27,12 +29,21 @@ def __init__( super().__init__( model_name_or_path=model_name_or_path, downstream_model_type=self.downstream_model_type, + num_classes=num_classes, finetune_last_n_layers=finetune_last_n_layers, metrics=metrics, config_kwargs=config_kwargs, task_model_kwargs=task_model_kwargs, ) self.ignore_index = ignore_index + self.class_label: Optional[ClassLabel] = None + 
+ def setup(self, stage: Optional[str] = None) -> None: + if stage in ("fit", None): + assert self.trainer is not None + self.class_label = self.trainer.datamodule.dataset["train"].features["labels"].feature + assert isinstance(self.class_label, ClassLabel) + super().setup(stage=stage) def shared_step(self, **batch: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: outputs = self.forward(**batch) @@ -73,3 +84,11 @@ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: else: _logger.warning("Missing labels for the test data") return None + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + checkpoint["class_label"] = self.class_label + super().on_save_checkpoint(checkpoint=checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.class_label = checkpoint["class_label"] + super().on_load_checkpoint(checkpoint=checkpoint) diff --git a/embeddings/model/lightning_module/text_classification.py b/embeddings/model/lightning_module/text_classification.py index 0f7d8ca3..a79db619 100644 --- a/embeddings/model/lightning_module/text_classification.py +++ b/embeddings/model/lightning_module/text_classification.py @@ -18,6 +18,7 @@ class TextClassificationModule(HuggingFaceLightningModule): def __init__( self, model_name_or_path: T_path, + num_classes: int, finetune_last_n_layers: int, metrics: Optional[MetricCollection] = None, config_kwargs: Optional[Dict[str, Any]] = None, @@ -26,6 +27,7 @@ def __init__( super().__init__( model_name_or_path=model_name_or_path, downstream_model_type=self.downstream_model_type, + num_classes=num_classes, finetune_last_n_layers=finetune_last_n_layers, metrics=metrics, config_kwargs=config_kwargs, diff --git a/embeddings/pipeline/lightning_classification.py b/embeddings/pipeline/lightning_classification.py index f8afc414..e3ed9039 100644 --- a/embeddings/pipeline/lightning_classification.py +++ b/embeddings/pipeline/lightning_classification.py @@ -57,6 +57,7 @@ def 
__init__( task = TextClassificationTask( model_name_or_path=embedding_name_or_path, output_path=output_path, + num_classes=datamodule.num_classes, finetune_last_n_layers=config.finetune_last_n_layers, model_config_kwargs=config.model_config_kwargs, task_model_kwargs=config.task_model_kwargs, diff --git a/embeddings/pipeline/lightning_sequence_labeling.py b/embeddings/pipeline/lightning_sequence_labeling.py index 1386bd34..fd970db5 100644 --- a/embeddings/pipeline/lightning_sequence_labeling.py +++ b/embeddings/pipeline/lightning_sequence_labeling.py @@ -62,6 +62,7 @@ def __init__( task = SequenceLabelingTask( model_name_or_path=embedding_name_or_path, output_path=output_path, + num_classes=datamodule.num_classes, finetune_last_n_layers=config.finetune_last_n_layers, model_config_kwargs=config.model_config_kwargs, task_model_kwargs=config.task_model_kwargs, diff --git a/embeddings/task/flair_task/flair_task.py b/embeddings/task/flair_task/flair_task.py index 3d4bb5c9..6fc746b0 100644 --- a/embeddings/task/flair_task/flair_task.py +++ b/embeddings/task/flair_task/flair_task.py @@ -1,9 +1,10 @@ import abc from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import flair from flair.data import Corpus, Dictionary, Sentence +from flair.models import SequenceTagger from flair.trainers import ModelTrainer from numpy import typing as nptyping from typing_extensions import Literal @@ -24,6 +25,7 @@ def __init__( self, output_path: T_path = RESULTS_PATH, task_train_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any ): super().__init__() self.model: Optional[flair.nn.Model] = None @@ -106,3 +108,31 @@ def get_y(data: List[Sentence], y_type: str, y_dictionary: Dictionary) -> nptypi @abc.abstractmethod def remove_labels_from_data(data: List[Sentence], y_type: str) -> None: pass + + @classmethod + def restore_task_model( + cls, + checkpoint_path: T_path, + output_path: T_path, + flair_model: 
Type[flair.nn.Model], + task_train_kwargs: Optional[Dict[str, Any]], + ) -> "FlairTask": + model = flair_model.load(checkpoint_path) + task_kwargs = ( + {"hidden_size": model.hidden_size} if isinstance(model, SequenceTagger) else {} + ) + task = cls( + output_path=output_path, task_train_kwargs=task_train_kwargs or {}, **task_kwargs + ) + task.model = model + return task + + @classmethod + @abc.abstractmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + ) -> "FlairTask": + pass diff --git a/embeddings/task/flair_task/sequence_labeling.py b/embeddings/task/flair_task/sequence_labeling.py index aef4e189..6349f9c4 100644 --- a/embeddings/task/flair_task/sequence_labeling.py +++ b/embeddings/task/flair_task/sequence_labeling.py @@ -64,3 +64,17 @@ def remove_labels_from_data(data: List[Sentence], y_type: str) -> None: for sent in data: for token in sent: token.remove_labels(y_type) + + @classmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + ) -> "FlairTask": + return cls.restore_task_model( + checkpoint_path=checkpoint_path, + output_path=output_path, + flair_model=SequenceTagger, + task_train_kwargs=task_train_kwargs, + ) diff --git a/embeddings/task/flair_task/text_classification.py b/embeddings/task/flair_task/text_classification.py index 8120e08a..7dc8a31b 100644 --- a/embeddings/task/flair_task/text_classification.py +++ b/embeddings/task/flair_task/text_classification.py @@ -57,3 +57,17 @@ def get_y(data: List[Sentence], y_type: str, y_dictionary: Dictionary) -> nptypi def remove_labels_from_data(data: List[Sentence], y_type: str) -> None: for sentence in data: sentence.remove_labels(y_type) + + @classmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + ) -> "FlairTask": + return 
cls.restore_task_model( + checkpoint_path=checkpoint_path, + output_path=output_path, + flair_model=TextClassifier, + task_train_kwargs=task_train_kwargs, + ) diff --git a/embeddings/task/lightning_task/lightning_task.py b/embeddings/task/lightning_task/lightning_task.py index e0e6be69..55a5d61b 100644 --- a/embeddings/task/lightning_task/lightning_task.py +++ b/embeddings/task/lightning_task/lightning_task.py @@ -1,6 +1,6 @@ import abc from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Type import pytorch_lightning as pl import torch @@ -8,11 +8,13 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from torch.utils.data import DataLoader +from transformers import AutoModel from embeddings.data.datamodule import HuggingFaceDataModule from embeddings.data.dataset import LightingDataModuleSubset from embeddings.data.io import T_path from embeddings.model.lightning_module.huggingface_module import HuggingFaceLightningModule +from embeddings.model.lightning_module.lightning_module import LightningModule from embeddings.task.task import Task from embeddings.utils.lightning_callbacks.best_epoch_callback import BestEpochCallback from embeddings.utils.loggers import LightningLoggingConfig, get_logger @@ -60,7 +62,7 @@ def best_validation_score(self) -> Optional[float]: def _get_callbacks(self, dataset_subsets: Sequence[str]) -> List[Callback]: callbacks: List[Callback] = [ - ModelCheckpoint(dirpath=self.output_path.joinpath("checkpoints")) + ModelCheckpoint(dirpath=self.output_path.joinpath("checkpoints"), save_last=True) ] if "validation" in dataset_subsets: callbacks.append(BestEpochCallback()) @@ -91,7 +93,9 @@ def fit( raise e @abc.abstractmethod - def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any]]: + def predict( + self, dataloader: DataLoader[Any], return_names: 
bool = True + ) -> Dict[str, nptyping.NDArray[Any]]: pass def fit_predict( @@ -114,3 +118,42 @@ def fit_predict( @abc.abstractmethod def build_task_model(self) -> None: pass + + @classmethod + def restore_task_model( + cls, + checkpoint_path: T_path, + output_path: T_path, + lightning_module: Type[LightningModule[AutoModel]], + task_train_kwargs: Optional[Dict[str, Any]], + logging_config: Optional[LightningLoggingConfig], + ) -> "LightningTask": + model = lightning_module.load_from_checkpoint(str(checkpoint_path)) + trainer = pl.Trainer(default_root_dir=str(output_path), **task_train_kwargs or {}) + init_kwargs = { + "model_name_or_path": model.hparams.model_name_or_path, + "output_path": output_path, + "num_classes": model.hparams.num_classes, + "finetune_last_n_layers": model.hparams.finetune_last_n_layers, + "model_config_kwargs": model.hparams.config_kwargs, + "task_model_kwargs": model.hparams.task_model_kwargs, + "task_train_kwargs": task_train_kwargs or {}, + "early_stopping_kwargs": {}, + "logging_config": logging_config or LightningLoggingConfig(), + } + task = cls(**init_kwargs) + task.model = model + task.trainer = trainer + model.trainer = trainer + return task + + @classmethod + @abc.abstractmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + logging_config: Optional[LightningLoggingConfig] = None, + ) -> "LightningTask": + pass diff --git a/embeddings/task/lightning_task/sequence_labeling.py b/embeddings/task/lightning_task/sequence_labeling.py index 82f58896..9e08da99 100644 --- a/embeddings/task/lightning_task/sequence_labeling.py +++ b/embeddings/task/lightning_task/sequence_labeling.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np from numpy import typing as nptyping @@ -15,32 +15,33 @@ def __init__( self, model_name_or_path: T_path, output_path: T_path, + num_classes: int, 
model_config_kwargs: Dict[str, Any], task_model_kwargs: Dict[str, Any], task_train_kwargs: Dict[str, Any], early_stopping_kwargs: Dict[str, Any], logging_config: LightningLoggingConfig, - train_batch_size: int = 32, - eval_batch_size: int = 32, finetune_last_n_layers: int = -1, ) -> None: super().__init__(output_path, task_train_kwargs, early_stopping_kwargs, logging_config) self.model_name_or_path = model_name_or_path + self.num_classes = num_classes self.model_config_kwargs = model_config_kwargs self.task_model_kwargs = task_model_kwargs - self.train_batch_size = train_batch_size - self.eval_batch_size = eval_batch_size self.finetune_last_n_layers = finetune_last_n_layers def build_task_model(self) -> None: self.model = SequenceLabelingModule( model_name_or_path=self.model_name_or_path, + num_classes=self.num_classes, finetune_last_n_layers=self.finetune_last_n_layers, config_kwargs=self.model_config_kwargs, task_model_kwargs=self.task_model_kwargs, ) - def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any]]: + def predict( + self, dataloader: DataLoader[Any], return_names: bool = True + ) -> Dict[str, nptyping.NDArray[Any]]: assert self.model is not None results = self.model.predict(dataloader=dataloader) predictions, ground_truth, probabilities = ( @@ -54,23 +55,35 @@ def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any ground_truth[i] = self._map_filter_data(gt, gt) probabilities[i] = [x for x in probs[gt != self.model.ignore_index]] - assert self.trainer is not None - assert hasattr(self.trainer, "datamodule") - names = getattr(self.trainer, "datamodule").target_names - return { + results = { "y_pred": np.array(predictions, dtype=object), "y_true": np.array(ground_truth, dtype=object), "y_probabilities": np.array(probabilities, dtype=object), - "names": np.array(names), + "names": np.array(self.model.target_names), } + return results def _map_filter_data( self, data: nptyping.NDArray[Any], 
ground_truth_data: nptyping.NDArray[Any] ) -> List[str]: assert self.model is not None - assert self.trainer is not None - assert hasattr(self.trainer, "datamodule") return [ - getattr(self.trainer, "datamodule").id2str(x.item()) + self.model.class_label.int2str(x.item()) for x in data[ground_truth_data != self.model.ignore_index] ] + + @classmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + logging_config: Optional[LightningLoggingConfig] = None, + ) -> "LightningTask": + return cls.restore_task_model( + checkpoint_path=checkpoint_path, + output_path=output_path, + task_train_kwargs=task_train_kwargs, + lightning_module=SequenceLabelingModule, + logging_config=logging_config, + ) diff --git a/embeddings/task/lightning_task/text_classification.py b/embeddings/task/lightning_task/text_classification.py index 7b91ade8..442e9406 100644 --- a/embeddings/task/lightning_task/text_classification.py +++ b/embeddings/task/lightning_task/text_classification.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import numpy as np from numpy import typing as nptyping @@ -15,6 +15,7 @@ def __init__( self, model_name_or_path: T_path, output_path: T_path, + num_classes: int, model_config_kwargs: Dict[str, Any], task_model_kwargs: Dict[str, Any], task_train_kwargs: Dict[str, Any], @@ -24,22 +25,41 @@ def __init__( ) -> None: super().__init__(output_path, task_train_kwargs, early_stopping_kwargs, logging_config) self.model_name_or_path = model_name_or_path + self.num_classes = num_classes self.model_config_kwargs = model_config_kwargs self.task_model_kwargs = task_model_kwargs self.finetune_last_n_layers = finetune_last_n_layers + self.task_train_kwargs = task_train_kwargs def build_task_model(self) -> None: self.model = TextClassificationModule( model_name_or_path=self.model_name_or_path, + num_classes=self.num_classes, 
finetune_last_n_layers=self.finetune_last_n_layers, config_kwargs=self.model_config_kwargs, task_model_kwargs=self.task_model_kwargs, ) - def predict(self, dataloader: DataLoader[Any]) -> Dict[str, nptyping.NDArray[Any]]: + def predict( + self, dataloader: DataLoader[Any], return_names: bool = True + ) -> Dict[str, nptyping.NDArray[Any]]: assert self.model is not None results = self.model.predict(dataloader=dataloader) - assert self.trainer is not None - assert hasattr(self.trainer, "datamodule") - results["names"] = np.array(getattr(self.trainer, "datamodule").target_names) + results["names"] = np.array(self.model.target_names) return results + + @classmethod + def from_checkpoint( + cls, + checkpoint_path: T_path, + output_path: T_path, + task_train_kwargs: Optional[Dict[str, Any]] = None, + logging_config: Optional[LightningLoggingConfig] = None, + ) -> "LightningTask": + return cls.restore_task_model( + checkpoint_path=checkpoint_path, + output_path=output_path, + task_train_kwargs=task_train_kwargs, + lightning_module=TextClassificationModule, + logging_config=logging_config, + ) diff --git a/examples/evaluate_sequence_labelling.py b/examples/evaluate_sequence_labelling.py index f5484ef0..d92d11e7 100644 --- a/examples/evaluate_sequence_labelling.py +++ b/examples/evaluate_sequence_labelling.py @@ -43,7 +43,6 @@ def run( input_column_name=input_column_name, target_column_name=target_column_name, output_path=output_path, - hidden_size=hidden_size, evaluation_mode=evaluation_mode, tagging_scheme=tagging_scheme, ) diff --git a/tests/test_lightning_classification_pipeline.py b/tests/test_lightning_classification_pipeline.py index 5dc064f7..64257bbd 100644 --- a/tests/test_lightning_classification_pipeline.py +++ b/tests/test_lightning_classification_pipeline.py @@ -6,12 +6,14 @@ import numpy as np import pytest import pytorch_lightning as pl +import torch from _pytest.tmpdir import TempdirFactory from embeddings.config.lightning_config import 
LightningAdvancedConfig from embeddings.pipeline.hf_preprocessing_pipeline import HuggingFacePreprocessingPipeline from embeddings.pipeline.lightning_classification import LightningClassificationPipeline from embeddings.pipeline.lightning_pipeline import LightningPipeline +from embeddings.task.lightning_task.text_classification import TextClassificationTask @pytest.fixture(scope="module") @@ -21,13 +23,7 @@ def tmp_path_module(tmpdir_factory: TempdirFactory) -> Path: @pytest.fixture(scope="module") -def pipeline_kwargs() -> Dict[str, Any]: - return {"embedding_name_or_path": "allegro/herbert-base-cased"} - - -@pytest.fixture(scope="module") -def dataset_kwargs(tmp_path_module) -> Dict[str, Any]: - path = str(tmp_path_module) +def dataset_kwargs(tmp_path_module: Path) -> Dict[str, Any]: pipeline = HuggingFacePreprocessingPipeline( dataset_name="clarin-pl/polemo2-official", load_dataset_kwargs={ @@ -36,7 +32,7 @@ def dataset_kwargs(tmp_path_module) -> Dict[str, Any]: "test_domains": ["hotels", "medicine"], "text_cfg": "text", }, - persist_path=path, + persist_path=str(tmp_path_module), sample_missing_splits=None, ignore_test_subset=False, downsample_splits=(0.01, 0.01, 0.05), @@ -45,7 +41,7 @@ def dataset_kwargs(tmp_path_module) -> Dict[str, Any]: pipeline.run() return { - "dataset_name_or_path": path, + "dataset_name_or_path": tmp_path_module, "input_column_name": ["text"], "target_column_name": "target", } @@ -88,33 +84,34 @@ def config() -> LightningAdvancedConfig: def lightning_classification_pipeline( dataset_kwargs: Dict[str, Any], config: LightningAdvancedConfig, - result_path: "TemporaryDirectory[str]", -) -> Tuple[ - LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], - "TemporaryDirectory[str]", -]: - return ( - LightningClassificationPipeline( - embedding_name_or_path="allegro/herbert-base-cased", - output_path=result_path.name, - config=config, - devices="auto", - accelerator="cpu", - **dataset_kwargs, - ), - result_path, + 
tmp_path_module: Path, +) -> LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]]: + return LightningClassificationPipeline( + embedding_name_or_path="allegro/herbert-base-cased", + output_path=tmp_path_module, + config=config, + devices="auto", + accelerator="cpu", + **dataset_kwargs, ) def test_lightning_classification_pipeline( - lightning_classification_pipeline: Tuple[ - LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], - "TemporaryDirectory[str]", + lightning_classification_pipeline: LightningPipeline[ + datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any] ], + tmp_path_module: Path, ) -> None: pl.seed_everything(441, workers=True) - pipeline, path = lightning_classification_pipeline + pipeline = lightning_classification_pipeline result = pipeline.run() + + assert_result_values(result) + assert_result_types(result) + assert_inference_from_checkpoint(result, pipeline, tmp_path_module) + + +def assert_result_values(result: Dict[str, Any]) -> None: np.testing.assert_almost_equal( result["accuracy"]["accuracy"], 0.3783783, decimal=pytest.decimal ) @@ -128,6 +125,8 @@ def test_lightning_classification_pipeline( result["recall__average_macro"]["recall"], 0.2333333, decimal=pytest.decimal ) + +def assert_result_types(result: Dict[str, Any]) -> None: assert "data" in result assert "y_pred" in result["data"] assert "y_true" in result["data"] @@ -141,3 +140,24 @@ def test_lightning_classification_pipeline( assert result["data"]["y_true"].dtype == np.int64 assert result["data"]["y_probabilities"].dtype == np.float32 assert isinstance(result["data"]["names"][0], str) + + +def assert_inference_from_checkpoint( + result: Dict[str, Any], + pipeline: LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], + tmp_path_module: Path, +) -> None: + ckpt_path = tmp_path_module / "checkpoints" / "last.ckpt" + task_from_ckpt = TextClassificationTask.from_checkpoint( + 
checkpoint_path=ckpt_path.resolve(), + output_path=tmp_path_module, + ) + + model_state_dict = pipeline.model.task.model.model.state_dict() + model_from_ckpt_state_dict = task_from_ckpt.model.model.state_dict() + assert model_state_dict.keys() == model_from_ckpt_state_dict.keys() + for k in model_state_dict.keys(): + assert torch.equal(model_state_dict[k], model_from_ckpt_state_dict[k]) + + predictions = task_from_ckpt.predict(pipeline.datamodule.test_dataloader()) + assert np.array_equal(result["data"]["y_probabilities"], predictions["y_probabilities"]) diff --git a/tests/test_lightning_sequence_labeling_pipeline.py b/tests/test_lightning_sequence_labeling_pipeline.py index 9804f19a..ed710f14 100644 --- a/tests/test_lightning_sequence_labeling_pipeline.py +++ b/tests/test_lightning_sequence_labeling_pipeline.py @@ -5,12 +5,14 @@ import numpy as np import pytest import pytorch_lightning as pl +import torch from _pytest.tmpdir import TempdirFactory from embeddings.config.lightning_config import LightningAdvancedConfig from embeddings.pipeline.hf_preprocessing_pipeline import HuggingFacePreprocessingPipeline from embeddings.pipeline.lightning_pipeline import LightningPipeline from embeddings.pipeline.lightning_sequence_labeling import LightningSequenceLabelingPipeline +from embeddings.task.lightning_task.sequence_labeling import SequenceLabelingTask @pytest.fixture(scope="module") @@ -20,12 +22,11 @@ def tmp_path_module(tmpdir_factory: TempdirFactory) -> Path: @pytest.fixture(scope="module") -def dataset_kwargs(tmp_path_module) -> Dict[str, Any]: - path = str(tmp_path_module) +def dataset_kwargs(tmp_path_module: Path) -> Dict[str, Any]: pipeline = HuggingFacePreprocessingPipeline( dataset_name="clarin-pl/kpwr-ner", load_dataset_kwargs=None, - persist_path=path, + persist_path=str(tmp_path_module), sample_missing_splits=None, ignore_test_subset=False, downsample_splits=(0.01, 0.01, 0.05), @@ -34,7 +35,7 @@ def dataset_kwargs(tmp_path_module) -> Dict[str, Any]: 
pipeline.run() return { - "dataset_name_or_path": path, + "dataset_name_or_path": tmp_path_module, "input_column_name": "tokens", "target_column_name": "ner", } @@ -79,28 +80,32 @@ def config() -> LightningAdvancedConfig: def lightning_sequence_labeling_pipeline( dataset_kwargs: Dict[str, Any], config: LightningAdvancedConfig, - tmp_path: Path, -) -> Tuple[LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], Path]: - return ( - LightningSequenceLabelingPipeline( - output_path=tmp_path, - embedding_name_or_path="allegro/herbert-base-cased", - config=config, - **dataset_kwargs, - ), - tmp_path, + tmp_path_module: Path, +) -> LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]]: + return LightningSequenceLabelingPipeline( + output_path=tmp_path_module, + embedding_name_or_path="allegro/herbert-base-cased", + config=config, + **dataset_kwargs, ) def test_lightning_sequence_labeling_pipeline( - lightning_sequence_labeling_pipeline: Tuple[ - LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], - Path, + lightning_sequence_labeling_pipeline: LightningPipeline[ + datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any] ], + tmp_path_module: Path, ) -> None: pl.seed_everything(441) - pipeline, path = lightning_sequence_labeling_pipeline + pipeline = lightning_sequence_labeling_pipeline result = pipeline.run() + + assert_result_values(result) + assert_result_types(result) + assert_inference_from_checkpoint(result, pipeline, tmp_path_module) + + +def assert_result_values(result: Dict[str, Any]) -> None: np.testing.assert_almost_equal( result["seqeval__mode_None__scheme_None"]["overall_accuracy"], 0.0015690, @@ -120,6 +125,8 @@ def test_lightning_sequence_labeling_pipeline( decimal=pytest.decimal, ) + +def assert_result_types(result: Dict[str, Any]) -> None: assert "data" in result assert "y_pred" in result["data"] assert "y_true" in result["data"] @@ -138,3 +145,26 @@ def 
test_lightning_sequence_labeling_pipeline( assert isinstance(result["data"]["y_probabilities"][0][0], np.ndarray) assert isinstance(result["data"]["names"][0], str) assert isinstance(result["data"]["y_probabilities"][0][0][0], np.float32) + + +def assert_inference_from_checkpoint( + result: Dict[str, Any], + pipeline: LightningPipeline[datasets.DatasetDict, Dict[str, np.ndarray], Dict[str, Any]], + tmp_path_module: Path, +) -> None: + ckpt_path = tmp_path_module / "checkpoints" / "last.ckpt" + task_from_ckpt = SequenceLabelingTask.from_checkpoint( + checkpoint_path=ckpt_path.resolve(), + output_path=tmp_path_module, + ) + + model_state_dict = pipeline.model.task.model.model.state_dict() + model_from_ckpt_state_dict = task_from_ckpt.model.model.state_dict() + assert model_state_dict.keys() == model_from_ckpt_state_dict.keys() + for k in model_state_dict.keys(): + assert torch.equal(model_state_dict[k], model_from_ckpt_state_dict[k]) + + predictions = task_from_ckpt.predict(pipeline.datamodule.test_dataloader()) + assert np.array_equal( + result["data"]["y_probabilities"][0][0], predictions["y_probabilities"][0][0] + ) diff --git a/tests/test_sequence_labelling.py b/tests/test_sequence_labelling.py index 5233514b..adebfa4b 100644 --- a/tests/test_sequence_labelling.py +++ b/tests/test_sequence_labelling.py @@ -162,10 +162,24 @@ def test_pos_tagging_pipeline( flair.device = torch.device("cpu") pipeline, path = pos_tagging_pipeline result = pipeline.run() - path.cleanup() np.testing.assert_almost_equal(result["UnitSeqeval"]["overall_f1"], 0.1450381) + task_from_ckpt = SequenceLabeling.from_checkpoint( + checkpoint_path=(Path(path.name) / "final-model.pt"), output_path=path.name + ) + loaded_data = pipeline.data_loader.load(pipeline.dataset) + transformed_data = pipeline.transformation.transform(loaded_data) + test_data = transformed_data.test + + y_pred, loss = task_from_ckpt.predict(test_data) + y_true = task_from_ckpt.get_y(test_data, task_from_ckpt.y_type, 
task_from_ckpt.y_dictionary) + results_from_ckpt = pipeline.evaluator.evaluate({"y_pred": y_pred, "y_true": y_true}) + + assert np.array_equal(result["data"]["y_pred"], results_from_ckpt["data"]["y_pred"]) + + path.cleanup() + def test_ner_tagging_pipeline( ner_tagging_pipeline: Tuple[ @@ -179,10 +193,24 @@ def test_ner_tagging_pipeline( flair.device = torch.device("cpu") pipeline, path = ner_tagging_pipeline result = pipeline.run() - path.cleanup() np.testing.assert_almost_equal(result["seqeval__mode_None__scheme_None"]["overall_f1"], 0.0) + task_from_ckpt = SequenceLabeling.from_checkpoint( + checkpoint_path=(Path(path.name) / "final-model.pt"), output_path=path.name + ) + loaded_data = pipeline.data_loader.load(pipeline.dataset) + transformed_data = pipeline.transformation.transform(loaded_data) + test_data = transformed_data.test + + y_pred, loss = task_from_ckpt.predict(test_data) + y_true = task_from_ckpt.get_y(test_data, task_from_ckpt.y_type, task_from_ckpt.y_dictionary) + results_from_ckpt = pipeline.evaluator.evaluate({"y_pred": y_pred, "y_true": y_true}) + + assert np.array_equal(result["data"]["y_pred"], results_from_ckpt["data"]["y_pred"]) + + path.cleanup() + def test_pos_tagging_pipeline_local_embedding( pos_tagging_pipeline_local_embedding: Tuple[ @@ -196,9 +224,8 @@ def test_pos_tagging_pipeline_local_embedding( flair.device = torch.device("cpu") pipeline, path = pos_tagging_pipeline_local_embedding result = pipeline.run() - path.cleanup() - np.testing.assert_almost_equal(result["UnitSeqeval"]["overall_f1"], 0.1832061) + path.cleanup() def test_ner_tagging_pipeline_local_embedding( @@ -213,8 +240,7 @@ def test_ner_tagging_pipeline_local_embedding( flair.device = torch.device("cpu") pipeline, path = ner_tagging_pipeline_local_embedding result = pipeline.run() - path.cleanup() - np.testing.assert_almost_equal( result["seqeval__mode_None__scheme_None"]["overall_f1"], 0.0107816 ) + path.cleanup() diff --git a/tests/test_text_classification.py 
b/tests/test_text_classification.py index 433ebbed..56d895ac 100644 --- a/tests/test_text_classification.py +++ b/tests/test_text_classification.py @@ -63,12 +63,25 @@ def test_text_classification_pipeline( flair.set_seed(441) pipeline, path = text_classification_pipeline result = pipeline.run() - path.cleanup() np.testing.assert_almost_equal(result["accuracy"]["accuracy"], 0.3333333) np.testing.assert_almost_equal(result["f1__average_macro"]["f1"], 0.1666666) np.testing.assert_almost_equal(result["precision__average_macro"]["precision"], 0.1111111) np.testing.assert_almost_equal(result["recall__average_macro"]["recall"], 0.3333333) + task_from_ckpt = TextClassification.from_checkpoint( + checkpoint_path=(Path(path.name) / "final-model.pt"), output_path=path.name + ) + loaded_data = pipeline.data_loader.load(pipeline.dataset) + transformed_data = pipeline.transformation.transform(loaded_data) + test_data = transformed_data.test + + y_pred, loss = task_from_ckpt.predict(test_data) + y_true = task_from_ckpt.get_y(test_data, task_from_ckpt.y_type, task_from_ckpt.y_dictionary) + results_from_ckpt = pipeline.evaluator.evaluate({"y_pred": y_pred, "y_true": y_true}) + assert np.array_equal(result["data"]["y_pred"], results_from_ckpt["data"]["y_pred"]) + + path.cleanup() + @pytest.fixture(scope="module") def text_classification_pipeline_local_embedding( @@ -111,8 +124,22 @@ def test_text_classification_pipeline_local_embedding( flair.set_seed(441) pipeline, path = text_classification_pipeline_local_embedding result = pipeline.run() - path.cleanup() + np.testing.assert_almost_equal(result["accuracy"]["accuracy"], 0.3333333) np.testing.assert_almost_equal(result["f1__average_macro"]["f1"], 0.3333333) np.testing.assert_almost_equal(result["precision__average_macro"]["precision"], 0.3333333) np.testing.assert_almost_equal(result["recall__average_macro"]["recall"], 0.3333333) + + task_from_ckpt = TextClassification.from_checkpoint( + checkpoint_path=(Path(path.name) / 
"final-model.pt"), output_path=path.name + ) + loaded_data = pipeline.data_loader.load(pipeline.dataset) + transformed_data = pipeline.transformation.transform(loaded_data) + test_data = transformed_data.test + + y_pred, loss = task_from_ckpt.predict(test_data) + y_true = task_from_ckpt.get_y(test_data, task_from_ckpt.y_type, task_from_ckpt.y_dictionary) + results_from_ckpt = pipeline.evaluator.evaluate({"y_pred": y_pred, "y_true": y_true}) + assert np.array_equal(result["data"]["y_pred"], results_from_ckpt["data"]["y_pred"]) + + path.cleanup() diff --git a/tutorials/validate_flair_models_inference.ipynb b/tutorials/validate_flair_models_inference.ipynb new file mode 100644 index 00000000..22858c62 --- /dev/null +++ b/tutorials/validate_flair_models_inference.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "418b9661-aea2-4990-8e26-e7f0e167b9b2", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f23bfd51-d1f4-4321-aed9-96f51b171fe9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")\n", + "\n", + "from embeddings.data.data_loader import HuggingFaceDataLoader\n", + "from embeddings.defaults import DATASET_PATH, RESULTS_PATH\n", + "from embeddings.embedding.auto_flair import AutoFlairWordEmbedding\n", + "from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator\n", + "from embeddings.model.flair_model import FlairModel\n", + "from embeddings.pipeline.standard_pipeline import StandardPipeline\n", + "from embeddings.task.flair_task.sequence_labeling import SequenceLabeling\n", + "from embeddings.transformation.flair_transformation.column_corpus_transformation import (\n", + " ColumnCorpusTransformation,\n", + ")\n", + "from embeddings.data.dataset import Dataset\n", + "\n", + "from 
embeddings.transformation.flair_transformation.downsample_corpus_transformation import (\n", + " DownsampleFlairCorpusTransformation,\n", + ")\n", + "from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (\n", + " SampleSplitsFlairCorpusTransformation,\n", + ")\n", + "from embeddings.utils.utils import build_output_path" + ] + }, + { + "cell_type": "markdown", + "id": "5e4c2372-8314-4868-a576-8f0988aae888", + "metadata": {}, + "source": [ + "### Run downsampled flair pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd4fb7d6-1e81-4bea-9bd3-b4a4bec87fc9", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_name_or_path = \"clarin-pl/word2vec-kgr10\"\n", + "dataset_name = \"clarin-pl/kpwr-ner\"\n", + "\n", + "output_path = build_output_path(RESULTS_PATH, embedding_name_or_path, dataset_name)\n", + "\n", + "dataset = Dataset(dataset_name)\n", + "data_loader = HuggingFaceDataLoader()\n", + "transformation = (\n", + " ColumnCorpusTransformation(\"tokens\", \"ner\")\n", + " .then(SampleSplitsFlairCorpusTransformation(dev_fraction=0.1, seed=441))\n", + " .then(DownsampleFlairCorpusTransformation(downsample_train=0.005, downsample_dev=0.01, downsample_test=0.01))\n", + ")\n", + "task = SequenceLabeling(\n", + " output_path,\n", + " hidden_size=256,\n", + " task_train_kwargs={\"max_epochs\": 1, \"mini_batch_size\": 64},\n", + ")\n", + "embedding = AutoFlairWordEmbedding.from_hub(embedding_name_or_path)\n", + "model = FlairModel(embedding, task)\n", + "evaluator = SequenceLabelingEvaluator()\n", + "\n", + "pipeline = StandardPipeline(dataset, data_loader, transformation, model, evaluator)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f69e538-332d-4278-977d-7002fe2b67bd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "_ = pipeline.run()" + ] + }, + { + "cell_type": "markdown", + "id": "44613ef9-a9d4-4d5c-980c-9e0f68bc3525", + "metadata": {}, + 
"source": [ + "### Load model from checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ffae7c5-1734-4e4a-81bd-55170a5c14ca", + "metadata": {}, + "outputs": [], + "source": [ + "!ls $output_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e75f303-fb82-4cfd-9a81-5d42e994e606", + "metadata": {}, + "outputs": [], + "source": [ + "task_from_ckpt = SequenceLabeling.from_checkpoint(checkpoint_path=(output_path / \"final-model.pt\"), output_path=output_path)" + ] + }, + { + "cell_type": "markdown", + "id": "802762c3-8246-465b-bb9c-2336134a51bd", + "metadata": {}, + "source": [ + "### Predict for test data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c0fadf3-aa47-407d-ad79-e5633532eafa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "loaded_data = data_loader.load(dataset)\n", + "transformed_data = transformation.transform(loaded_data)\n", + "test_data = transformed_data.test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df0d1eda-5bf0-40d7-97ed-16eab007a0f3", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred, loss = task_from_ckpt.predict(test_data)\n", + "y_true = task_from_ckpt.get_y(test_data, task_from_ckpt.y_type, task_from_ckpt.y_dictionary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f9d6dde-601a-4d0c-bb48-e9177e7002c9", + "metadata": {}, + "outputs": [], + "source": [ + "evaluator.evaluate({\"y_pred\": y_pred, \"y_true\": y_true})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:clarinpl-embeddings]", + "language": "python", + "name": "conda-env-clarinpl-embeddings-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} \ No newline at end of file diff --git a/tutorials/validate_lightning_models_inference.ipynb b/tutorials/validate_lightning_models_inference.ipynb new file mode 100644 index 00000000..ee0cab56 --- /dev/null +++ b/tutorials/validate_lightning_models_inference.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3d3ac2b5-06e8-46bc-a626-9384a35920e5", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1019b750-cebe-438b-b1ab-434d6f756864", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")\n", + "from typing import Any, Dict\n", + "\n", + "import pytorch_lightning as pl\n", + "from embeddings.config.lightning_config import LightningAdvancedConfig\n", + "from embeddings.defaults import DATASET_PATH, RESULTS_PATH\n", + "from embeddings.model.lightning_module.text_classification import (\n", + " TextClassificationModule,\n", + ")\n", + "from embeddings.pipeline.hf_preprocessing_pipeline import (\n", + " HuggingFacePreprocessingPipeline,\n", + ")\n", + "from embeddings.pipeline.lightning_classification import LightningClassificationPipeline\n", + "from embeddings.task.lightning_task.text_classification import TextClassificationTask\n", + "from embeddings.utils.utils import build_output_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d0e06e2-3c5a-420b-b065-31d5ccd6b255", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_name_or_path = \"hf-internal-testing/tiny-albert\"\n", + "dataset_name = \"clarin-pl/polemo2-official\"\n", + "\n", + "dataset_path = build_output_path(DATASET_PATH, embedding_name_or_path, dataset_name)\n", + "output_path = build_output_path(RESULTS_PATH, embedding_name_or_path, dataset_name)" + ] + }, + { + "cell_type": "markdown", + "id": "b6d0098c-41ec-473a-954a-709f7fb05922", + 
"metadata": {}, + "source": [ + "### Preprocess and downsample data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "095d1c88-900f-4275-a879-f9efdb73265a", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_data(path: str) -> Dict[str, Any]:\n", + " pipeline = HuggingFacePreprocessingPipeline(\n", + " dataset_name=dataset_name,\n", + " load_dataset_kwargs={\n", + " \"train_domains\": [\"hotels\", \"medicine\"],\n", + " \"dev_domains\": [\"hotels\", \"medicine\"],\n", + " \"test_domains\": [\"hotels\", \"medicine\"],\n", + " \"text_cfg\": \"text\",\n", + " },\n", + " persist_path=path,\n", + " sample_missing_splits=None,\n", + " ignore_test_subset=False,\n", + " downsample_splits=(0.01, 0.01, 0.05),\n", + " seed=441,\n", + " )\n", + " pipeline.run()\n", + "\n", + " return {\n", + " \"dataset_name_or_path\": path,\n", + " \"input_column_name\": [\"text\"],\n", + " \"target_column_name\": \"target\",\n", + " }\n", + "\n", + "\n", + "dataset_kwargs = preprocess_data(dataset_path)" + ] + }, + { + "cell_type": "markdown", + "id": "159445cd-fb59-4964-aca2-ce9c18a8cf5e", + "metadata": {}, + "source": [ + "### Train simple downsampled pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cb7ebd4-182c-4797-b5de-a7069313a901", + "metadata": {}, + "outputs": [], + "source": [ + "config = LightningAdvancedConfig(\n", + " finetune_last_n_layers=0,\n", + " task_train_kwargs={\"max_epochs\": 1, \"deterministic\": True,},\n", + " task_model_kwargs={\n", + " \"learning_rate\": 5e-4,\n", + " \"train_batch_size\": 32,\n", + " \"eval_batch_size\": 32,\n", + " \"use_scheduler\": True,\n", + " \"optimizer\": \"AdamW\",\n", + " \"adam_epsilon\": 1e-8,\n", + " \"warmup_steps\": 100,\n", + " \"weight_decay\": 0.0,\n", + " },\n", + " datamodule_kwargs={\"max_seq_length\": 64,},\n", + " early_stopping_kwargs={\"monitor\": \"val/Loss\", \"mode\": \"min\", \"patience\": 3,},\n", + " tokenizer_kwargs={},\n", + " 
batch_encoding_kwargs={},\n", + " dataloader_kwargs={},\n", + " model_config_kwargs={},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "148a0089-f461-4948-93fa-04f2e34ac9e0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pipeline = LightningClassificationPipeline(\n", + " embedding_name_or_path=embedding_name_or_path,\n", + " output_path=output_path,\n", + " config=config,\n", + " devices=\"auto\",\n", + " accelerator=\"cpu\",\n", + " **dataset_kwargs\n", + ")\n", + "result = pipeline.run()" + ] + }, + { + "cell_type": "markdown", + "id": "491215dc-9960-4ad0-bc14-6d61d1fafac8", + "metadata": {}, + "source": [ + "### Load model from checkpoint automatically generated with Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee9e824c-00f1-45b0-9e32-1bd33f364f3a", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_path = output_path / \"checkpoints\" / \"last.ckpt\"\n", + "ckpt_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2785fcbc-1c95-4d23-807f-a14569992354", + "metadata": {}, + "outputs": [], + "source": [ + "task_from_ckpt = TextClassificationTask.from_checkpoint(\n", + " checkpoint_path=ckpt_path, output_path=output_path,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "13272a49-8ef5-41af-80a3-5cf3b7b677c7", + "metadata": {}, + "source": [ + "#### Alternatively we can load the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b32fd93-e43d-4c42-961e-53232bf9e02e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model_from_ckpt = TextClassificationModule.load_from_checkpoint(str(ckpt_path))" + ] + }, + { + "cell_type": "markdown", + "id": "103e7972-c386-4c44-9b58-0385213f20f8", + "metadata": {}, + "source": [ + "The warning appears when loading the model, however, it was validated that the loaded weights are the same as the weights that are being saved. 
The reason for this is that when the model_state_dict keys are loaded from the cached huggingface model some of them (cls.(...)) do not match the keys from the state_dict of the model weights that are saved.\n", + "\n", + "https://github.com/CLARIN-PL/embeddings/issues/225" + ] + }, + { + "cell_type": "markdown", + "id": "88e7a6c7-449f-4d0c-9042-a5f98aebc14b", + "metadata": {}, + "source": [ + "### Use task from checkpoint for predictions" + ] + }, + { + "cell_type": "markdown", + "id": "c5eeab69-e13c-4ba4-b0ea-2473555915d9", + "metadata": {}, + "source": [ + "`return_names` needs to be set to False since it uses the `datamodule` to retrieve the names while the datamodule is not loaded to `Trainer` in the `LightningTask` since we have not fitted it yet." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ad7b9b0-823a-4c8e-aac5-61a333558ed1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "test_dataloader = pipeline.datamodule.test_dataloader()\n", + "preds = task_from_ckpt.predict(test_dataloader)\n", + "preds" + ] + }, + { + "cell_type": "markdown", + "id": "9c789d71-2368-4add-8a7b-f51571aecfbd", + "metadata": {}, + "source": [ + "Alternatively we can implicitly assign the `datamodule` to `Trainer` in `LightningTask`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9836dc5d-8ee2-46fc-b7d8-94841cc13ce5", + "metadata": {}, + "outputs": [], + "source": [ + "task_from_ckpt.trainer.datamodule = pipeline.datamodule\n", + "preds_with_names = task_from_ckpt.predict(test_dataloader, return_names=True)\n", + "preds_with_names" + ] + }, + { + "cell_type": "markdown", + "id": "29c321e2-9ecc-4b65-936b-c8e7cca1155a", + "metadata": {}, + "source": [ + "We can also use previously loaded lightning model (`LightningModule`) outside of the task and get the predictions. To do this we also need to initialize a `Trainer`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3afa250-2937-4aad-bb3c-172a68639892", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(default_root_dir=str(output_path))\n", + "preds_from_model = trainer.predict(model_from_ckpt, dataloaders=test_dataloader)\n", + "preds_from_model" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:embeddings]", + "language": "python", + "name": "conda-env-embeddings-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}