Chore/speedup setfit (#636)

---------

Co-authored-by: Franz Gusto <franz@rungalileo.io>
Ben Epstein and franz101 authored May 31, 2023
1 parent 7683b39 commit d0050d8
Showing 2 changed files with 169 additions and 69 deletions.
124 changes: 55 additions & 69 deletions dataquality/integrations/setfit.py
@@ -6,6 +6,7 @@
import dataquality as dq
from dataquality.schemas.split import Split
from dataquality.utils.patcher import Cleanup, Patch, PatchManager, RefManager
from dataquality.utils.setfit import log_preds_setfit

if TYPE_CHECKING:
from datasets import Dataset
@@ -157,6 +158,7 @@ def __init__(
labels: List[str] = [],
finish: bool = True,
wait: bool = False,
batch_size: Optional[int] = None,
) -> None:
"""Patch to SetFit trainer to run dataquality after training.
:param setfit_trainer: SetFit trainer
@@ -174,6 +176,7 @@ def __init__(
self.wait = wait
self.project_name = project_name
self.run_name = run_name
self.batch_size = batch_size

def _patch(self) -> "Patch":
"""Patch SetFit trainer by replacing train function with self."""
@@ -185,13 +188,16 @@ def _patch(self) -> "Patch":
setattr(self.trainer, self.function_name, self)
return self

def __call__(self, *args: Any, **kwds: Any) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call train function and run dataquality after training."""
batch_size = kwds.get("batch_size", self.trainer.batch_size)
batch_size = kwargs.get("batch_size", self.trainer.batch_size)
if batch_size is not None and len(args) > 1:
batch_size = args[1]
# If batch_size is set in watch function, override the batch_size
if self.batch_size is not None:
batch_size = self.batch_size

res = self.old_fn(*args, **kwds)
res = self.old_fn(*args, **kwargs)
model = self.trainer.model
dq_hook = SetFitModelHook(model)
dq_store = dq_hook.store
@@ -210,12 +216,12 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
init_kwargs["run_name"] = self.run_name

dq.init("text_classification", **init_kwargs)
labels: Any = self.labels
labels: List = self.labels
if not labels:
labels = dq.get_data_logger().logger_config.labels
if not labels:
labels = getattr(train_dataset.features.get("label", {}), "names", None)
assert labels, "Labels must be set (watch(trainer, labels=[...]))"
labels = getattr(train_dataset.features.get("label", {}), "names", [])
assert len(labels), "Labels must be set (watch(trainer, labels=[...]))"
dq.set_labels_for_run(labels)
datasets = [train_dataset]
if eval_dataset is not None:
@@ -224,6 +230,7 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
eval_dataset, self.trainer.column_mapping
)
datasets.append(eval_dataset)

for split in [Split.training, Split.validation]:
if split == Split.training:
dataset = train_dataset
@@ -233,24 +240,15 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
continue
if "id" not in dataset.features:
dataset = dataset.map(lambda x, idx: {"id": idx}, with_indices=True)
for i in range(0, len(dataset), batch_size):
batch = dataset[i : i + batch_size]
model.predict_proba(batch["text"])
# 🔭🌕 Galileo logging
dq.log_data_samples(
texts=batch["text"],
ids=batch["id"],
labels=[labels[label_id] for label_id in batch["label"]],
split=split,
)
# 🔭🌕 Galileo logging
dq.log_model_outputs(
ids=batch["id"],
probs=dq_store["output"],
embs=dq_store["input_args"][0],
split=split,
epoch=0,
)

log_preds_setfit(
model=model,
dataset=dataset,
dq_store=dq_store,
batch_size=batch_size,
split=split,
)

if self.finish:
dq.finish(wait=self.wait)

@@ -273,27 +271,40 @@ def unwatch(setfit_obj: Optional[Union["SetFitModel", "SetFitTrainer"]]) -> None

def watch(
setfit: Union["SetFitModel", "SetFitTrainer"],
labels: List[str] = [],
labels: Optional[List[str]] = None,
project_name: str = "",
run_name: str = "",
finish: bool = True,
wait: bool = False,
batch_size: Optional[int] = None,
) -> Optional[Callable]:
"""Watch SetFit model by replacing predict_proba function with SetFitModelHook.
:param model: SetFit model"""
"""Watch a SetFit model or trainer and extract model outputs for dataquality.
Returns a function that can be used to evaluate the model on a dataset.
:param setfit: SetFit model or trainer
:param labels: list of labels
:param project_name: name of project
:param run_name: name of run
:param finish: whether to run dq.finish after evaluation
:param wait: whether to wait for dq.finish
:param batch_size: batch size for evaluation
:return: dq_evaluate function
"""
from setfit import SetFitTrainer

labels = labels or []
model = setfit

setfitmanager = PatchManager()

if setfit.__class__.__name__ == "SetFitTrainer":
model = setfit.model
if isinstance(setfit, SetFitTrainer):
patched_trainer = _PatchSetFitTrainer(
setfit,
labels=labels,
finish=finish,
wait=wait,
run_name=run_name,
project_name=project_name,
batch_size=batch_size,
)
setfitmanager.add_patch(patched_trainer)
return None
@@ -307,7 +318,6 @@ def evaluate(model: "SetFitModel") -> Callable:
:return: SetFitModelHook object"""
dq_hook = SetFitModelHook(model)
dq_store = dq_hook.store
labels = dq.get_data_logger().logger_config.labels

helper_data = dq.get_data_logger().logger_config.helper_data

@@ -319,11 +329,7 @@ def dq_evaluate(
dataset: "Dataset",
split: Split,
inference_name: Optional[str] = None,
column_mapping: Optional[Dict] = {
"text": "text",
"id": "id",
"label": "label",
},
column_mapping: Optional[Dict] = None,
batch_size: int = 64,
) -> torch.Tensor:
"""Evaluate SetFit model and log input and output to Galileo.
@@ -333,42 +339,22 @@ def dq_evaluate(
:param column_mapping: mapping of column names (if different from default)
:return: output of SetFitModel.predict_proba function"""

text_col = "text"
id_col = "id"
label_col = "label"
column_mapping = column_mapping or dict(
text="text",
id="id",
label="label",
)

if column_mapping is not None:
dataset = _apply_column_mapping(dataset, column_mapping)
preds: List[torch.Tensor] = []

for i in range(0, len(dataset), batch_size):
batch = dataset[i : i + batch_size]

assert text_col in batch, f"column '{text_col}' must be in batch"
assert id_col in batch, f"column '{id_col}' text must be in batch"

pred = model.predict_proba(batch[text_col])
preds.append(pred)
# 🔭🌕 Galileo logging
log_args = dict(texts=batch["text"], ids=batch[id_col], split=split)
inference_dict: Dict[str, str] = {}
if inference_name is not None:
log_args["inference_name"] = inference_name
inference_dict["inference_name"] = inference_name
else:
assert label_col in batch, f"column '{label_col}' must be in batch"
log_args["labels"] = [labels[label] for label in batch[label_col]]

dq.log_data_samples(**log_args)
# 🔭🌕 Galileo logging
dq.log_model_outputs(
ids=batch[id_col],
probs=dq_store["output"],
embs=dq_store["input_args"][0],
split=split,
epoch=0,
**inference_dict, # type: ignore
)

return torch.concat(preds)
return log_preds_setfit(
model=model,
dataset=dataset,
dq_store=dq_store,
batch_size=batch_size,
split=split,
inference_name=inference_name,
)

return dq_evaluate
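
For context, a minimal sketch of how the patched trainer path might be used after this change. The watch signature and the new batch_size override come from the diff above; the checkpoint, project/run names, and toy data are placeholders, not part of the commit.

from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer

from dataquality.integrations.setfit import watch

# Toy dataset using the default column names ("text"/"label") that the
# integration's column_mapping expects; contents are illustrative only.
train_ds = Dataset.from_dict(
    {"text": ["great product", "terrible service"], "label": [1, 0]}
)
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-MiniLM-L3-v2"  # placeholder checkpoint
)
trainer = SetFitTrainer(model=model, train_dataset=train_ds)

# watch() patches trainer.train; the batch_size added in this commit
# overrides the batch size used for the post-training logging passes.
watch(
    trainer,
    labels=["negative", "positive"],
    project_name="my_project",  # placeholder project/run names
    run_name="setfit_demo",
    batch_size=64,
)
trainer.train()  # dq.init, logging, and dq.finish run inside the patch

Passing a bare SetFitModel instead returns the dq_evaluate closure defined above, which logs each predict_proba batch through the new log_preds_setfit helper.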
114 changes: 114 additions & 0 deletions dataquality/utils/setfit.py
@@ -0,0 +1,114 @@
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Dict, List, Optional

import torch
from torch import Tensor

import dataquality as dq
from dataquality.schemas.split import Split

BATCH_LOG_SIZE = 10_000

if TYPE_CHECKING:
from datasets import Dataset
from setfit import SetFitModel


@dataclass
class DataSampleLogArgs:
texts: List[str]
ids: List[int]
split: Split
inference_name: Optional[str]
labels: List

def __init__(
self,
split: Split,
inference_name: Optional[str] = None,
) -> None:
"""DataSampleLogArgs is a helper class for logging data samples to Galileo.
:param split: The split of the data samples (for example "training")
:param inference_name: The name of the inference (for example "inference_run_1")
"""
self.texts = []
self.ids = []
self.labels = []
self.split = split
self.inference_name = inference_name

def clear(self) -> None:
"""Resets the arrays of the class."""
self.texts.clear()
self.ids.clear()
self.labels.clear()


def log_preds_setfit(
model: "SetFitModel",
dataset: "Dataset",
split: Split,
dq_store: Dict,
batch_size: int,
inference_name: Optional[str] = None,
return_preds: bool = False,
) -> Tensor:
"""Logs the data samples and model outputs for a SetFit model.
:param model: The SetFit model
:param dataset: The dataset in the form of a HuggingFace Dataset
:param split: The split of the data samples (for example "training")
:param dq_store: The dataquality store
:param batch_size: The batch size
:param inference_name: The name of the inference (for example "inference_run_1")
:param return_preds: Whether to return the predictions
:return: The predictions
"""
text_col = "text"
id_col = "id"
label_col = "label"
preds: List[Tensor] = []
log_args: DataSampleLogArgs = DataSampleLogArgs(split=split)
inference_dict: Dict[str, str] = {}
if inference_name is not None:
log_args.inference_name = inference_name
inference_dict["inference_name"] = inference_name

labels = dq.get_data_logger().logger_config.labels

# Iterate over the dataset in batches and log the data samples
# and model outputs
for i in range(0, len(dataset), batch_size):
batch = dataset[i : i + batch_size]
assert text_col in batch, f"column '{text_col}' must be in batch"
assert id_col in batch, f"column '{id_col}' must be in batch"

if inference_name is None:
assert label_col in batch, f"column '{label_col}' must be in batch"
log_args.labels += [labels[label] for label in batch[label_col]]

pred = model.predict_proba(batch[text_col])
if return_preds:
preds.append(pred)
# 🔭🌕 Galileo logging
log_args.texts += batch[text_col]
log_args.ids += batch[id_col]

if len(log_args.texts) >= BATCH_LOG_SIZE:
dq.log_data_samples(**asdict(log_args))
log_args.clear()

# 🔭🌕 Galileo logging
dq.log_model_outputs(
ids=batch[id_col],
probs=dq_store["output"],
embs=dq_store["input_args"][0],
split=split,
epoch=0,
**inference_dict, # type: ignore
)
# Log any leftovers
if log_args.texts:
dq.log_data_samples(**asdict(log_args))
if not return_preds:
return torch.tensor([])
return torch.concat(preds)
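
To make the new helper concrete, a sketch of calling log_preds_setfit directly. It assumes SetFitModelHook is importable from dataquality.integrations.setfit (it is used unqualified there) and that a run and labels are already configured, since the helper reads labels from the logger config.

import dataquality as dq
from datasets import Dataset
from setfit import SetFitModel

from dataquality.integrations.setfit import SetFitModelHook
from dataquality.schemas.split import Split
from dataquality.utils.setfit import log_preds_setfit

dq.init("text_classification")  # assumes a configured dataquality setup
dq.set_labels_for_run(["negative", "positive"])

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-MiniLM-L3-v2"  # placeholder checkpoint
)
hook = SetFitModelHook(model)  # captures embeddings/probs from predict_proba

# The dataset must carry the text/id/label columns the helper asserts on.
ds = Dataset.from_dict(
    {"text": ["great", "awful"], "id": [0, 1], "label": [1, 0]}
)

preds = log_preds_setfit(
    model=model,
    dataset=ds,
    split=Split.training,
    dq_store=hook.store,  # hook fills "input_args"/"output" on each batch
    batch_size=64,
    return_preds=True,  # concatenates the per-batch predict_proba outputs
)

Data samples are buffered in DataSampleLogArgs and flushed every BATCH_LOG_SIZE (10,000) rows, so dq.log_data_samples is called far less often than once per batch on large datasets, while model outputs are still logged per batch.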
