
Commit

Fixed transformers move to accelerate (#717)
Changed the ids to a queue
franz101 authored Jul 20, 2023
1 parent a726fe8 commit 1820911
Showing 10 changed files with 208 additions and 140 deletions.
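The substance of the change, spread across the diffs below: the per-step model outputs moved from a plain dict to an attribute-based store, and batch ids are now buffered in a FIFO queue rather than a single slot. A minimal sketch of what such a store could look like, assuming only what the diffs show (the field names embs/logits/ids/ids_queue and the to_dict/clear helpers); this is an illustration, not the actual dataquality class:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from torch import Tensor


@dataclass
class ModelOutputsStore:
    # Assumed shape, inferred from the diffs on this page.
    embs: Optional[Tensor] = None
    logits: Optional[Tensor] = None
    ids: Optional[List[int]] = None
    # Ids can be registered for a batch before the previous batch is logged,
    # so they are buffered first-in-first-out instead of overwritten.
    ids_queue: List[List[int]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        # Only the loggable fields are forwarded to dq.log_model_outputs.
        return {"embs": self.embs, "logits": self.logits, "ids": self.ids}

    def clear(self) -> None:
        self.embs = self.logits = self.ids = None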
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
 """


-__version__ = "0.9.7"
+__version__ = "0.9.8"

 import sys
 from typing import Any, List, Optional
17 changes: 9 additions & 8 deletions dataquality/integrations/torch.py
@@ -150,21 +150,22 @@ def _on_step_end(self) -> None:
         extracted in the hooks and we need to log them in the on_step_end
         method.
         """
+        model_outputs_store = self.torch_helper_data.model_outputs_store
         # Workaround for multiprocessing
-        if model_outputs_store.get("ids") is None and len(
+        if model_outputs_store.ids is None and len(
             self.torch_helper_data.dl_next_idx_ids
         ):
-            model_outputs_store["ids"] = self.torch_helper_data.dl_next_idx_ids.pop(0)
+            model_outputs_store.ids = self.torch_helper_data.dl_next_idx_ids.pop(0)

         # Log only if embedding exists
-        assert model_outputs_store.get("embs") is not None, GalileoException(
+        assert model_outputs_store.embs is not None, GalileoException(
             "Embedding passed to the logger can not be logged"
         )
-        assert model_outputs_store.get("logits") is not None, GalileoException(
+        assert model_outputs_store.logits is not None, GalileoException(
             "Logits passed to the logger can not be logged"
         )
-        assert model_outputs_store.get("ids") is not None, GalileoException(
+        assert model_outputs_store.ids is not None, GalileoException(
             "id column missing in dataset (needed to map rows to the indices/ids)"
         )
         # Convert the indices to ids
@@ -173,10 +174,10 @@ def _on_step_end(self) -> None:
             "Current split must be set before logging"
         )
         cur_split = cur_split.lower()  # type: ignore
-        model_outputs_store["ids"] = map_indices_to_ids(
-            self.logger_config.idx_to_id_map[cur_split], model_outputs_store["ids"]
+        model_outputs_store.ids = map_indices_to_ids(
+            self.logger_config.idx_to_id_map[cur_split], model_outputs_store.ids
         )
-        dq.log_model_outputs(**model_outputs_store)
+        dq.log_model_outputs(**model_outputs_store.to_dict())
         model_outputs_store.clear()
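For context, the index-to-id translation used above could behave as follows. The real map_indices_to_ids ships with dataquality; this is only a behavioral sketch, assuming the per-split map is keyed by dataloader position:

from typing import Dict, List


def map_indices_to_ids(idx_to_id_map: Dict[int, int], indices: List[int]) -> List[int]:
    # Dataloaders yield positional indices; the logger needs the stable ids
    # attached to the dataset rows, so each index is translated through the
    # split's map before logging.
    return [idx_to_id_map[idx] for idx in indices]


# Hypothetical usage mirroring the hunk above:
ids = map_indices_to_ids({0: 101, 1: 102, 2: 103}, [2, 0])
assert ids == [103, 101]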


37 changes: 30 additions & 7 deletions dataquality/integrations/torch_semantic_segmentation.py
@@ -169,8 +169,12 @@ def _dq_logit_hook(
             logits = model_output["out"]
         else:
             logits = model_output
+        if not isinstance(logits, Tensor):
+            raise ValueError(
+                "Logits are not a tensor. Please ensure the logits are a tensor."
+            )
         model_outputs_store = self.torch_helper_data.model_outputs_store
-        model_outputs_store["logits"] = logits
+        model_outputs_store.logits = logits

     def _dq_classifier_hook_with_step_end(
         self,
@@ -205,7 +209,7 @@ def _dq_input_hook(
         """
         # model input comes as a tuple of length 1
-        self.torch_helper_data.model_input = model_input[0].detach().cpu().numpy()
+        self.torch_helper_data.model_input = model_input[0].detach().cpu()

     def get_image_ids_and_image_paths(
         self, split: str, logging_data: Dict[str, Any]
@@ -359,11 +363,21 @@ def get_argmax_probs(
             Tuple[torch.Tensor, torch.Tensor]: argmax and logits tensors
         """
         # resize the logits to the input size based on hooks
-        preds = self.torch_helper_data.model_outputs_store["logits"]
+        preds = self.torch_helper_data.model_outputs_store.logits
+        if preds is None:
+            raise ValueError(
+                "Logits are missing in dataquality,"
+                " have connected to the right model layer?"
+            )
+        elif not isinstance(preds, Tensor):
+            raise ValueError(
+                f"Logits are not a tensor. Please ensure the logits are a tensor. \
+                Got {type(preds)}"
+            )
         if preds.dtype == torch.float16:
             preds = preds.to(torch.float32)
         input_shape = self.torch_helper_data.model_input.shape[-2:]
-        preds = F.interpolate(preds, size=input_shape, mode="bilinear")
+        preds = Tensor(F.interpolate(preds, size=input_shape, mode="bilinear"))

         # checks whether the model is (n, classes, w, h), or (n, w, h, classes)
         # takes the max in case of binary classification
@@ -394,9 +408,18 @@ def _on_step_end(self) -> None:
         # if we have not inferred the number of classes from the model architecture

         # takes the max of the logits shape and 2 in case of binary classification
-        self.number_classes = max(
-            self.torch_helper_data.model_outputs_store["logits"].shape[1], 2
-        )
+        logits = self.torch_helper_data.model_outputs_store.logits
+        if logits is None:
+            raise ValueError(
+                "Logits are missing in dataquality,"
+                " have connected to the right model layer?"
+            )
+        elif not isinstance(logits, Tensor):
+            raise ValueError(
+                f"Logits are not a tensor. Please ensure the logits are a tensor. \
+                Got {type(logits)}"
+            )
+        self.number_classes = max(logits.shape[1], 2)
         if not self.init_lm_labels_flag:
             self._init_lm_labels()
             self.init_lm_labels_flag = True
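Two details in get_argmax_probs are easy to miss: the float16 cast exists likely because bilinear upsampling of half-precision tensors is not supported on CPU, and the Tensor(...) wrap keeps the result typed as a Tensor for the attribute-based store. A runnable sketch of the resize-then-argmax pattern, assuming logits shaped (n, classes, h, w); the helper name is hypothetical, not dataquality API:

import torch
import torch.nn.functional as F


def resize_and_argmax(logits: torch.Tensor, size: tuple) -> torch.Tensor:
    # Hypothetical helper illustrating the pattern above.
    if logits.dtype == torch.float16:
        # Promote to float32 before interpolating (half-precision bilinear
        # upsampling is unsupported on CPU).
        logits = logits.to(torch.float32)
    # Stretch the low-resolution logits back to the input's spatial size.
    logits = F.interpolate(logits, size=size, mode="bilinear")
    # Class prediction per pixel: argmax over the channel dimension.
    return logits.argmax(dim=1)


preds = resize_and_argmax(torch.randn(1, 2, 16, 16), (64, 64))
assert preds.shape == (1, 64, 64)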
9 changes: 5 additions & 4 deletions dataquality/integrations/transformers_trainer.py
@@ -80,20 +80,21 @@ def __init__(
     def _do_log(self) -> None:
         """Log the model outputs (called by the hook)"""
         # Log only if embedding exists
-        assert self.model_outputs_store.get("embs") is not None, GalileoException(
+        self.model_outputs_store.ids = self.model_outputs_store.ids_queue.pop(0)
+        assert self.model_outputs_store.embs is not None, GalileoException(
             "Embedding passed to the logger can not be logged"
         )
-        assert self.model_outputs_store.get("logits") is not None, GalileoException(
+        assert self.model_outputs_store.logits is not None, GalileoException(
             "Logits passed to the logger can not be logged"
         )
-        assert self.model_outputs_store.get("ids") is not None, GalileoException(
+        assert self.model_outputs_store.ids is not None, GalileoException(
             "Did you map IDs to your dataset before watching the model? You can run:\n"
             "`ds= dataset.map(lambda x, idx: {'id': idx}, with_indices=True)`\n"
             "id (index) column is needed in the dataset for logging"
         )

         # 🔭🌕 Galileo logging
-        dq.log_model_outputs(**self.model_outputs_store)
+        dq.log_model_outputs(**self.model_outputs_store.to_dict())
         self.model_outputs_store.clear()

     def validate(
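Why a queue: under accelerate, ids for the next batch can apparently be registered before the current batch's outputs are logged, so a single ids slot would be clobbered. Enqueuing per-batch ids and popping them FIFO in _do_log keeps ids aligned with the outputs being logged. A toy illustration, reusing the ModelOutputsStore sketch from the top of this page (assumed, not the actual class):

store = ModelOutputsStore()

# Producer side (e.g. the patched collate step): one entry per batch.
store.ids_queue.append([0, 1, 2, 3])
store.ids_queue.append([4, 5, 6, 7])

# Consumer side (_do_log): always take the oldest batch's ids first.
store.ids = store.ids_queue.pop(0)
assert store.ids == [0, 1, 2, 3]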
