Skip to content

Commit

Permalink
refactor: pr issues
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Apr 13, 2022
1 parent f7385da commit 2566140
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 71 deletions.
71 changes: 7 additions & 64 deletions embeddings/model/lightning_module/sequence_labeling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import torch
from datasets import ClassLabel
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torchmetrics import MetricCollection
from transformers import AutoModelForTokenClassification
Expand Down Expand Up @@ -35,18 +36,13 @@ def __init__(
task_model_kwargs=task_model_kwargs,
)
self.ignore_index = ignore_index
self._str2int: Optional[Dict[str, int]] = None
self._int2str: Optional[Dict[int, str]] = None
self.class_label: Optional[ClassLabel] = None

def setup(self, stage: Optional[str] = None) -> None:
# Reads label metadata from the trainer's datamodule and calls the parent `setup`
# with the same `stage`.
# NOTE(review): this span comes from a rendered diff without +/- markers. The
# `_int2str` / `_str2int` assignments and the `class_label` assignment look like the
# two sides of one change rather than code meant to coexist -- confirm against the
# repository before relying on it.
# The metadata is only read for the "fit" stage or when no stage is given.
if stage in ("fit", None):
# A trainer must already be attached, since the datamodule is reached through it.
assert self.trainer is not None
self._int2str = (
self.trainer.datamodule.dataset["train"].features["labels"].feature._int2str
)
self._str2int = (
self.trainer.datamodule.dataset["train"].features["labels"].feature._str2int
)
# Keep the label feature of the "train" split itself.
self.class_label = self.trainer.datamodule.dataset["train"].features["labels"].feature
# Fail early if the "labels" feature is not a `ClassLabel`.
assert isinstance(self.class_label, ClassLabel)
super().setup(stage=stage)

def shared_step(self, **batch: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -89,63 +85,10 @@ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
_logger.warning("Missing labels for the test data")
return None

def str2int(self, values: Union[str, Iterable[Any]]) -> Union[int, Iterable[Any]]:
    """Convert class-name strings to integer class ids.

    Conversion class name string => integer duplicated from huggingface ClassLabel.

    Args:
        values: A single label string, or an iterable of labels.

    Returns:
        A single ``int`` when ``values`` is a string; otherwise a list with one
        ``int`` per label.

    Raises:
        AssertionError: If ``values`` is neither a string nor an iterable.
        KeyError: If ``self._str2int`` is set and a label is missing from it even
            after surrounding whitespace is stripped.
        ValueError: If ``self._str2int`` is unset/empty and a label cannot be
            parsed as an integer in ``[0, self.hparams.num_classes)``.
    """
    assert isinstance(values, str) or isinstance(
        values, Iterable
    ), f"Values {values} should be a string or an Iterable (list, numpy array, pytorch, tensorflow tensors)"
    return_list = not isinstance(values, str)
    labels = values if return_list else [values]

    output = []
    for value in labels:
        if self._str2int:
            # Retry with surrounding whitespace stripped when the raw label is unknown.
            if value not in self._str2int:
                value = str(value).strip()
            output.append(self._str2int[str(value)])
            continue

        # No names provided, try to integerize. Parse once and validate before
        # storing; TypeError (e.g. a None label) is reported the same way as an
        # unparsable string instead of escaping as a different exception type.
        try:
            parsed = int(value)
        except (TypeError, ValueError):
            parsed = None
        if parsed is None or not 0 <= parsed < self.hparams.num_classes:
            raise ValueError(f"Invalid string class label {value}")
        output.append(parsed)
    return output if return_list else output[0]

def int2str(self, values: Union[int, Iterable[Any]]) -> Union[str, Iterable[Any]]:
    """Convert integer class ids to class-name strings.

    Conversion integer => class name string duplicated from huggingface ClassLabel.

    Args:
        values: A single ``int`` id, or an iterable of ids. One-shot iterables
            (generators, iterators) are supported.

    Returns:
        A single string when ``values`` is an ``int``; otherwise a list with one
        string per id. Names are taken from ``self._int2str`` when it is set and
        non-empty; otherwise each id is rendered with ``str``.

    Raises:
        AssertionError: If ``values`` is neither an ``int`` nor an iterable.
        ValueError: If an id lies outside ``[0, self.hparams.num_classes)``.
    """
    assert isinstance(values, int) or isinstance(
        values, Iterable
    ), f"Values {values} should be an integer or an Iterable (list, numpy array, pytorch, tensorflow tensors)"
    return_list = not isinstance(values, int)
    # Materialise once: the ids are walked twice (validation, then lookup), which
    # would otherwise exhaust a generator during validation and yield an empty result.
    ids = list(values) if return_list else [values]

    for v in ids:
        if not 0 <= v < self.hparams.num_classes:
            raise ValueError(f"Invalid integer class label {v:d}")

    if self._int2str:
        output = [self._int2str[int(v)] for v in ids]
    else:
        # No names provided, return str(values)
        output = [str(v) for v in ids]
    return output if return_list else output[0]

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# Writes this module's label metadata into the `checkpoint` dict, then calls the
# parent hook with the same dict.
# NOTE(review): rendered diff without +/- markers -- the "_int2str"/"_str2int" entries
# and the "class_label" entry look like the two sides of one change; confirm which
# keys the current code actually writes.
checkpoint["_int2str"] = self._int2str
checkpoint["_str2int"] = self._str2int
checkpoint["class_label"] = self.class_label
super().on_save_checkpoint(checkpoint=checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# Restores this module's label metadata from the `checkpoint` dict, then calls the
# parent hook with the same dict. Each key is read with plain indexing, so a
# checkpoint lacking a key raises KeyError.
# NOTE(review): rendered diff without +/- markers -- the "_int2str"/"_str2int" reads
# and the "class_label" read look like the two sides of one change; confirm which
# keys the current code actually expects.
self._int2str = checkpoint["_int2str"]
self._str2int = checkpoint["_str2int"]
self.class_label = checkpoint["class_label"]
super().on_load_checkpoint(checkpoint=checkpoint)
6 changes: 3 additions & 3 deletions embeddings/task/lightning_task/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ def predict(
"y_pred": np.array(predictions, dtype=object),
"y_true": np.array(ground_truth, dtype=object),
"y_probabilities": np.array(probabilities, dtype=object),
"names": np.array(self.model.target_names),
}
if return_names:
results["names"] = np.array(self.model.target_names)
return results

def _map_filter_data(
self, data: nptyping.NDArray[Any], ground_truth_data: nptyping.NDArray[Any]
) -> List[str]:
# Keeps the entries of `data` at positions where `ground_truth_data` differs from
# the model's `ignore_index`, and converts each kept entry to a label via `int2str`.
# NOTE(review): rendered diff without +/- markers -- the two element expressions
# inside the list (`self.model.int2str` vs `self.model.class_label.int2str`) look
# like the two sides of one change, not one valid comprehension; confirm.
assert self.model is not None
return [
self.model.int2str(x.item()) for x in data[ground_truth_data != self.model.ignore_index]
self.model.class_label.int2str(x.item())
for x in data[ground_truth_data != self.model.ignore_index]
]

@classmethod
Expand Down
5 changes: 1 addition & 4 deletions embeddings/task/lightning_task/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ def predict(
) -> Dict[str, nptyping.NDArray[Any]]:
assert self.model is not None
results = self.model.predict(dataloader=dataloader)
if return_names:
assert self.trainer is not None
assert hasattr(self.trainer, "datamodule")
results["names"] = np.array(self.model.target_names)
results["names"] = np.array(self.model.target_names)
return results

@classmethod
Expand Down

0 comments on commit 2566140

Please sign in to comment.