From 18209118b1062676e7b121f318503291072d845e Mon Sep 17 00:00:00 2001 From: franz101 Date: Thu, 20 Jul 2023 10:03:22 -0700 Subject: [PATCH] Fixed transformers move to accelerate (#717) Changed the ids to a queue --- dataquality/__init__.py | 2 +- dataquality/integrations/torch.py | 17 +- .../torch_semantic_segmentation.py | 37 ++- .../integrations/transformers_trainer.py | 9 +- dataquality/utils/torch.py | 242 ++++++++++-------- dataquality/utils/transformers.py | 7 +- pyproject.toml | 2 +- tasks.py | 2 +- tests/conftest.py | 11 +- .../hf/test_text_classification_hf.py | 19 +- 10 files changed, 208 insertions(+), 140 deletions(-) diff --git a/dataquality/__init__.py b/dataquality/__init__.py index ed65ad371..19a4b7cf3 100644 --- a/dataquality/__init__.py +++ b/dataquality/__init__.py @@ -31,7 +31,7 @@ """ -__version__ = "0.9.7" +__version__ = "0.9.8" import sys from typing import Any, List, Optional diff --git a/dataquality/integrations/torch.py b/dataquality/integrations/torch.py index 2b5e7652e..7130d6d37 100644 --- a/dataquality/integrations/torch.py +++ b/dataquality/integrations/torch.py @@ -150,21 +150,22 @@ def _on_step_end(self) -> None: extracted in the hooks and we need to log them in the on_step_end method. """ + model_outputs_store = self.torch_helper_data.model_outputs_store # Workaround for multiprocessing - if model_outputs_store.get("ids") is None and len( + if model_outputs_store.ids is None and len( self.torch_helper_data.dl_next_idx_ids ): - model_outputs_store["ids"] = self.torch_helper_data.dl_next_idx_ids.pop(0) + model_outputs_store.ids = self.torch_helper_data.dl_next_idx_ids.pop(0) # Log only if embedding exists - assert model_outputs_store.get("embs") is not None, GalileoException( + assert model_outputs_store.embs is not None, GalileoException( "Embedding passed to the logger can not be logged" ) - assert model_outputs_store.get("logits") is not None, GalileoException( + assert model_outputs_store.logits is not None, GalileoException( "Logits passed to the logger can not be logged" ) - assert model_outputs_store.get("ids") is not None, GalileoException( + assert model_outputs_store.ids is not None, GalileoException( "id column missing in dataset (needed to map rows to the indices/ids)" ) # Convert the indices to ids @@ -173,10 +174,10 @@ def _on_step_end(self) -> None: "Current split must be set before logging" ) cur_split = cur_split.lower() # type: ignore - model_outputs_store["ids"] = map_indices_to_ids( - self.logger_config.idx_to_id_map[cur_split], model_outputs_store["ids"] + model_outputs_store.ids = map_indices_to_ids( + self.logger_config.idx_to_id_map[cur_split], model_outputs_store.ids ) - dq.log_model_outputs(**model_outputs_store) + dq.log_model_outputs(**model_outputs_store.to_dict()) model_outputs_store.clear() diff --git a/dataquality/integrations/torch_semantic_segmentation.py b/dataquality/integrations/torch_semantic_segmentation.py index afe200cdf..d8842cc2a 100644 --- a/dataquality/integrations/torch_semantic_segmentation.py +++ b/dataquality/integrations/torch_semantic_segmentation.py @@ -169,8 +169,12 @@ def _dq_logit_hook( logits = model_output["out"] else: logits = model_output + if not isinstance(logits, Tensor): + raise ValueError( + "Logits are not a tensor. Please ensure the logits are a tensor." 
+ ) model_outputs_store = self.torch_helper_data.model_outputs_store - model_outputs_store["logits"] = logits + model_outputs_store.logits = logits def _dq_classifier_hook_with_step_end( self, @@ -205,7 +209,7 @@ def _dq_input_hook( """ # model input comes as a tuple of length 1 - self.torch_helper_data.model_input = model_input[0].detach().cpu().numpy() + self.torch_helper_data.model_input = model_input[0].detach().cpu() def get_image_ids_and_image_paths( self, split: str, logging_data: Dict[str, Any] @@ -359,11 +363,21 @@ def get_argmax_probs( Tuple[torch.Tensor, torch.Tensor]: argmax and logits tensors """ # resize the logits to the input size based on hooks - preds = self.torch_helper_data.model_outputs_store["logits"] + preds = self.torch_helper_data.model_outputs_store.logits + if preds is None: + raise ValueError( + "Logits are missing in dataquality," + " have connected to the right model layer?" + ) + elif not isinstance(preds, Tensor): + raise ValueError( + f"Logits are not a tensor. Please ensure the logits are a tensor. \ + Got {type(preds)}" + ) if preds.dtype == torch.float16: preds = preds.to(torch.float32) input_shape = self.torch_helper_data.model_input.shape[-2:] - preds = F.interpolate(preds, size=input_shape, mode="bilinear") + preds = Tensor(F.interpolate(preds, size=input_shape, mode="bilinear")) # checks whether the model is (n, classes, w, h), or (n, w, h, classes) # takes the max in case of binary classification @@ -394,9 +408,18 @@ def _on_step_end(self) -> None: # if we have not inferred the number of classes from the model architecture # takes the max of the logits shape and 2 in case of binary classification - self.number_classes = max( - self.torch_helper_data.model_outputs_store["logits"].shape[1], 2 - ) + logits = self.torch_helper_data.model_outputs_store.logits + if logits is None: + raise ValueError( + "Logits are missing in dataquality," + " have connected to the right model layer?" + ) + elif not isinstance(logits, Tensor): + raise ValueError( + f"Logits are not a tensor. Please ensure the logits are a tensor. \ + Got {type(logits)}" + ) + self.number_classes = max(logits.shape[1], 2) if not self.init_lm_labels_flag: self._init_lm_labels() self.init_lm_labels_flag = True diff --git a/dataquality/integrations/transformers_trainer.py b/dataquality/integrations/transformers_trainer.py index a66969e3c..2b2522b93 100644 --- a/dataquality/integrations/transformers_trainer.py +++ b/dataquality/integrations/transformers_trainer.py @@ -80,20 +80,21 @@ def __init__( def _do_log(self) -> None: """Log the model outputs (called by the hook)""" # Log only if embedding exists - assert self.model_outputs_store.get("embs") is not None, GalileoException( + self.model_outputs_store.ids = self.model_outputs_store.ids_queue.pop(0) + assert self.model_outputs_store.embs is not None, GalileoException( "Embedding passed to the logger can not be logged" ) - assert self.model_outputs_store.get("logits") is not None, GalileoException( + assert self.model_outputs_store.logits is not None, GalileoException( "Logits passed to the logger can not be logged" ) - assert self.model_outputs_store.get("ids") is not None, GalileoException( + assert self.model_outputs_store.ids is not None, GalileoException( "Did you map IDs to your dataset before watching the model? 
You can run:\n" "`ds= dataset.map(lambda x, idx: {'id': idx}, with_indices=True)`\n" "id (index) column is needed in the dataset for logging" ) # 🔭🌕 Galileo logging - dq.log_model_outputs(**self.model_outputs_store) + dq.log_model_outputs(**self.model_outputs_store.to_dict()) self.model_outputs_store.clear() def validate( diff --git a/dataquality/utils/torch.py b/dataquality/utils/torch.py index 22372a003..6d1b410b7 100644 --- a/dataquality/utils/torch.py +++ b/dataquality/utils/torch.py @@ -8,6 +8,7 @@ from warnings import warn import numpy as np # noqa: F401 +import torch from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader @@ -26,13 +27,137 @@ from dataquality.utils.patcher import Borg, Patch, PatchManager +class ModelHookManager(Borg): + """ + Manages hooks for models. Has the ability to find the layer automatically. + Otherwise the layer or the layer name needs to be provided. + """ + + # Stores all hooks to remove them from the model later. + def __init__(self) -> None: + """Class to manage patches""" + super().__init__() + if not hasattr(self, "initialized"): + self.hooks: List[RemovableHandle] = [] + + def get_embedding_layer_auto(self, model: Module) -> Module: + """ + Use a scoring algorithm to find the embedding layer automatically + given a model. The higher the score the more likely it is the embedding layer. + """ + name, layer = next(model.named_children()) + print(f'Selected layer for the last hidden state embedding "{name}"') + return layer + + def get_layer_by_name(self, model: Module, name: str) -> Module: + """ + Iterate over each layer and stop once the the layer name matches + :param model: Model + :parm name: string + """ + queue: Queue = Queue() + start = model.named_children() + queue.put(start) + layer_names = [] + layer_names_str: str = "" + while not queue.empty(): + named_children = queue.get() + for layer_name, layer_model in named_children: + layer_names.append(layer_name) + layer_names_str = ", ".join(layer_names) + if layer_name == name: + print( + f'Found layer "{layer_name}" in ' + f'model layers: "{layer_names_str}"' + ) + return layer_model + layer_model._get_name() + queue.put(layer_model.named_children()) + raise GalileoException( + f"Layer could not be found in layers: {layer_names_str}. " + "make sure to check capitalization or pass layer directly." 
+ ) + + def attach_hooks_to_model( + self, + model: Module, + hook_fn: Callable, + model_layer: Optional[Layer] = None, + ) -> RemovableHandle: + """Attach hook and save it in our hook list""" + if model_layer is None: + selected_layer = self.get_embedding_layer_auto(model) + elif isinstance(model_layer, str): + selected_layer = self.get_layer_by_name(model, model_layer) + else: + selected_layer = model_layer + return self.attach_hook(selected_layer, hook_fn) + + def attach_classifier_hook( + self, + model: Module, + classifier_hook: Callable, + model_layer: Optional[Layer] = None, + ) -> RemovableHandle: + """Attach hook and save it in our hook list""" + if model_layer is None: + try: + selected_layer = self.get_layer_by_name(model, "classifier") + except GalileoException: + selected_layer = self.get_layer_by_name(model, "fc") + elif isinstance(model_layer, str): + selected_layer = self.get_layer_by_name(model, model_layer) + else: + selected_layer = model_layer + + return self.attach_hook(selected_layer, classifier_hook) + + def attach_hook(self, selected_layer: Module, hook: Callable) -> RemovableHandle: + """Register a hook and save it in our hook list""" + self.initialized = True + h = selected_layer.register_forward_hook(hook) + self.hooks.append(h) + return h + + def detach_hooks(self) -> None: + """Remove all hooks from the model""" + for h in self.hooks: + h.remove() + self.hooks = [] + + +@dataclass +class ModelOutputsStore: + embs: Optional[Tensor] = None + logits: Optional[Union[Tensor, Tuple[Tuple]]] = None + ids_queue: List[List[int]] = field(default_factory=list) + ids: Optional[List[int]] = None + + def clear(self) -> None: + """Resets the arrays of the class.""" + self.embs = None + self.logits = None + self.ids = None + + def clear_all(self) -> None: + """Resets the arrays of the class.""" + self.embs = None + self.logits = None + self.ids = None + self.ids_queue.clear() + + def to_dict(self) -> Dict[str, Any]: + """Returns the class as a dictionary.""" + return {"embs": self.embs, "logits": self.logits, "ids": self.ids} + + @dataclass class TorchHelper: model: Optional[Any] = None - hook_manager: Optional[Any] = None - model_outputs_store: Dict[str, Any] = field(default_factory=dict) + hook_manager: Optional[ModelHookManager] = None + model_outputs_store: ModelOutputsStore = field(default_factory=ModelOutputsStore) dl_next_idx_ids: List[Any] = field(default_factory=list) - model_input: Any = np.empty(0) + model_input: Tensor = torch.empty(0) batch: Dict[str, Any] = field(default_factory=dict) ids: List[Any] = field(default_factory=list) patches: List[Dict] = field(default_factory=list) @@ -42,7 +167,7 @@ def clear(self) -> None: self.dl_next_idx_ids.clear() self.model_outputs_store.clear() self.ids.clear() - self.model_input = np.empty(0) + self.model_input = torch.empty(0) self.batch.clear() @@ -140,7 +265,7 @@ def _dq_embedding_hook( # for NER tasks output_detached = output_detached[:, 1:, :] - self.torch_helper_data.model_outputs_store["embs"] = output_detached + self.torch_helper_data.model_outputs_store.embs = output_detached def _dq_logit_hook( self, @@ -182,7 +307,7 @@ def _dq_logit_hook( # through this dimension for NER tasks logits = logits[:, 1:, :] - self.torch_helper_data.model_outputs_store["logits"] = logits + self.torch_helper_data.model_outputs_store.logits = logits def _classifier_hook( self, @@ -267,105 +392,6 @@ def convert_fancy_idx_str_to_slice( return eval("np.s_[{}]".format(clean_str)) -class ModelHookManager(Borg): - """ - Manages hooks for 
models. Has the ability to find the layer automatically. - Otherwise the layer or the layer name needs to be provided. - """ - - # Stores all hooks to remove them from the model later. - def __init__(self) -> None: - """Class to manage patches""" - super().__init__() - if not hasattr(self, "initialized"): - self.hooks: List[RemovableHandle] = [] - - def get_embedding_layer_auto(self, model: Module) -> Module: - """ - Use a scoring algorithm to find the embedding layer automatically - given a model. The higher the score the more likely it is the embedding layer. - """ - name, layer = next(model.named_children()) - print(f'Selected layer for the last hidden state embedding "{name}"') - return layer - - def get_layer_by_name(self, model: Module, name: str) -> Module: - """ - Iterate over each layer and stop once the the layer name matches - :param model: Model - :parm name: string - """ - queue: Queue = Queue() - start = model.named_children() - queue.put(start) - layer_names = [] - layer_names_str: str = "" - while not queue.empty(): - named_children = queue.get() - for layer_name, layer_model in named_children: - layer_names.append(layer_name) - layer_names_str = ", ".join(layer_names) - if layer_name == name: - print( - f'Found layer "{layer_name}" in ' - f'model layers: "{layer_names_str}"' - ) - return layer_model - layer_model._get_name() - queue.put(layer_model.named_children()) - raise GalileoException( - f"Layer could not be found in layers: {layer_names_str}. " - "make sure to check capitalization or pass layer directly." - ) - - def attach_hooks_to_model( - self, - model: Module, - hook_fn: Callable, - model_layer: Optional[Layer] = None, - ) -> RemovableHandle: - """Attach hook and save it in our hook list""" - if model_layer is None: - selected_layer = self.get_embedding_layer_auto(model) - elif isinstance(model_layer, str): - selected_layer = self.get_layer_by_name(model, model_layer) - else: - selected_layer = model_layer - return self.attach_hook(selected_layer, hook_fn) - - def attach_classifier_hook( - self, - model: Module, - classifier_hook: Callable, - model_layer: Optional[Layer] = None, - ) -> RemovableHandle: - """Attach hook and save it in our hook list""" - if model_layer is None: - try: - selected_layer = self.get_layer_by_name(model, "classifier") - except GalileoException: - selected_layer = self.get_layer_by_name(model, "fc") - elif isinstance(model_layer, str): - selected_layer = self.get_layer_by_name(model, model_layer) - else: - selected_layer = model_layer - - return self.attach_hook(selected_layer, classifier_hook) - - def attach_hook(self, selected_layer: Module, hook: Callable) -> RemovableHandle: - """Register a hook and save it in our hook list""" - self.initialized = True - h = selected_layer.register_forward_hook(hook) - self.hooks.append(h) - return h - - def detach_hooks(self) -> None: - """Remove all hooks from the model""" - for h in self.hooks: - h.remove() - self.hooks = [] - - def unpatch(patches: List[Dict[str, Any]] = []) -> None: """ Unpatch all patched classes and instances @@ -497,7 +523,7 @@ class PatchSingleDataloaderIterator(Patch): def __init__( self, dataloader_cls: DataLoader, - store: Dict[str, List[int]], + store: ModelOutputsStore, fn_name: str = "_get_iterator", ): """Initializes the class with a collate function, @@ -537,7 +563,7 @@ class PatchSingleDataloaderNextIndex(Patch): def __init__( self, dataloader_cls: DataLoader, - store: Dict[str, List[int]], + store: ModelOutputsStore, fn_name: str = "_get_iterator", ): 
"""Initializes the class with a collate function, @@ -566,7 +592,7 @@ def _unpatch(self) -> None: def __call__(self, *args: Any, **kwargs: Any) -> List[dict]: indices = self._original_fn(*args, **kwargs) if indices: - self.store["ids"] = indices + self.store.ids = indices return indices diff --git a/dataquality/utils/transformers.py b/dataquality/utils/transformers.py index daee850fd..94c60859c 100644 --- a/dataquality/utils/transformers.py +++ b/dataquality/utils/transformers.py @@ -1,10 +1,11 @@ import inspect -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, List, Union from torch.nn import Module from transformers import Trainer from dataquality.utils.patcher import Patch +from dataquality.utils.torch import ModelOutputsStore class RemoveIdCollatePatch(Patch): @@ -15,7 +16,7 @@ def __init__( self, trainer_cls: Trainer, keep_cols: List[str], - store: Dict[str, List[int]], + store: ModelOutputsStore, fn_name: str = "data_collator", ): """Initializes the class with a collate function, @@ -59,7 +60,7 @@ def __call__(self, rows: List[dict]) -> List[dict]: elif len(self.keep_cols) == 0: clean_row[key] = value clean_rows.append(clean_row) - self.store["ids"] = ids + self.store.ids_queue.append(ids) return self._original_collate_fn(clean_rows) diff --git a/pyproject.toml b/pyproject.toml index 322949cdc..53507ffb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "cachetools>=4.2.4", "importlib-metadata<6.0.1", "datasets>=2.6", - "transformers>=4.17.0,<4.31.0", + "transformers>=4.17.0", "seqeval", "sentence-transformers>=2.2", "Pillow", diff --git a/tasks.py b/tasks.py index 9b3a5fbe1..750acff30 100644 --- a/tasks.py +++ b/tasks.py @@ -231,7 +231,7 @@ def update_version_number(ctx: Context, part: Optional[BumpType] = None) -> None with open(VERSION_FILE, "w") as f: for line in lines: if line.startswith("__version__"): - f.write(f'__version__ = "v{new_version}"\n') + f.write(f'__version__ = "{new_version}"\n') else: f.write(line) print(f"New version: {new_version}") diff --git a/tests/conftest.py b/tests/conftest.py index dd4377238..ad0b43bd8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,28 @@ import os import shutil +import warnings from typing import Any, Callable, Dict, Generator, List, Optional, Union from uuid import UUID import pytest import requests +import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer from vaex.dataframe import DataFrame import dataquality from dataquality import AggregateFunction, Condition, ConditionFilter, Operator, config from dataquality.clients import objectstore +from dataquality.exceptions import GalileoWarning from dataquality.loggers import BaseGalileoLogger from dataquality.schemas.task_type import TaskType from dataquality.utils.dq_logger import DQ_LOG_FILE_HOME from tests.test_utils.mock_request import MockResponse +try: + torch.set_default_device("cpu") +except AttributeError: + warnings.warn("Torch default device not set to CPU", GalileoWarning) DEFAULT_API_URL = "http://localhost:8088" UUID_STR = "399057bc-b276-4027-a5cf-48893ac45388" TEST_STORE_DIR = "TEST_STORE" @@ -32,8 +39,8 @@ tokenizer.save_pretrained(LOCAL_MODEL_PATH) try: - model = AutoModelForSequenceClassification.from_pretrained( - LOCAL_MODEL_PATH, device_map="cpu" + model = AutoModelForSequenceClassification.from_pretrained(LOCAL_MODEL_PATH).to( + "cpu" ) except Exception: model = AutoModelForSequenceClassification.from_pretrained(HF_TEST_BERT_PATH).to( diff --git 
a/tests/integrations/hf/test_text_classification_hf.py b/tests/integrations/hf/test_text_classification_hf.py index 8928aac57..672f41175 100644 --- a/tests/integrations/hf/test_text_classification_hf.py +++ b/tests/integrations/hf/test_text_classification_hf.py @@ -33,7 +33,10 @@ def preprocess_function(examples, tokenizer): return tokenizer( - examples["text"], padding="max_length", max_length=201, truncation=True + examples["text"], + padding="max_length", + max_length=201, + truncation=True, ) @@ -77,6 +80,7 @@ def compute_metrics(eval_pred): per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, num_train_epochs=1, + use_mps_device=False, weight_decay=0.01, load_best_model_at_end=True, metric_for_best_model=metric_name, @@ -96,7 +100,7 @@ def test_end_to_end_without_callback( """Base case: Training on a dataset""" trainer = Trainer( - model, + model.cpu(), args_default, train_dataset=encoded_train_dataset, eval_dataset=encoded_test_dataset, @@ -134,7 +138,7 @@ def test_hf_watch_e2e( dq.log_dataset(test_dataset, split="test") trainer = Trainer( - model, + model.cpu(), args_default, train_dataset=encoded_train_dataset, eval_dataset=encoded_test_dataset, @@ -186,15 +190,20 @@ def test_remove_unused_columns( per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, num_train_epochs=1, + use_mps_device=False, weight_decay=0.01, load_best_model_at_end=True, metric_for_best_model=metric_name, push_to_hub=False, + dataloader_drop_last=True, remove_unused_columns=False, + dataloader_num_workers=0, + dataloader_pin_memory=True, ) trainer = Trainer( - model, + model.cpu(), t_args, + # dataloader_drop_last=True, train_dataset=encoded_train_dataset.with_format( "torch", columns=["id", "attention_mask", "input_ids", "label"] ), @@ -223,7 +232,7 @@ def test_training_run( """Base case: Tests watch function to pass""" trainer = Trainer( - model, + model.cpu(), args_default, train_dataset=encoded_train_dataset_repeat, eval_dataset=encoded_test_dataset_repeat,
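
For context on the "Changed the ids to a queue" summary above: this patch swaps the plain dict of model outputs for a ModelOutputsStore dataclass and buffers batch ids in a FIFO ids_queue. The patched collate function appends each batch's ids as it strips the id column, and the logging step pops the oldest entry so the ids stay paired with the embeddings and logits the forward hooks captured for that same batch. A minimal standalone sketch of that flow follows; the dataclass mirrors the one added in this diff, while the toy collate loop and the print stand-in are illustrative only, not the library's API.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ModelOutputsStore:
    # Mirrors the dataclass added in dataquality/utils/torch.py in this diff.
    embs: Optional[Any] = None
    logits: Optional[Any] = None
    ids_queue: List[List[int]] = field(default_factory=list)
    ids: Optional[List[int]] = None

    def to_dict(self) -> Dict[str, Any]:
        return {"embs": self.embs, "logits": self.logits, "ids": self.ids}


store = ModelOutputsStore()

# Collate side (RemoveIdCollatePatch in the diff): the id column is removed
# from each batch but its values are remembered, in order, for logging.
for batch_ids in ([0, 1, 2, 3], [4, 5, 6, 7]):
    store.ids_queue.append(batch_ids)

# Logging side (_do_log in the diff): pop the oldest batch's ids so they line
# up with the embs/logits the forward hooks stored for that same batch.
while store.ids_queue:
    store.ids = store.ids_queue.pop(0)
    # Toy placeholders; in the library these are tensors set by the hooks.
    store.embs = f"embs for {store.ids}"
    store.logits = f"logits for {store.ids}"
    print(store.to_dict())  # stand-in for dq.log_model_outputs(**store.to_dict())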