Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: s2s auto metadata #777

Merged
merged 7 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


__version__ = "1.1.1"
__version__ = "1.1.2"

import sys
from typing import Any, List, Optional
Expand Down
4 changes: 2 additions & 2 deletions dataquality/dq_auto/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from dataquality.utils.auto import (
add_class_label_to_dataset,
add_val_data_if_missing,
get_meta_cols,
run_name_from_hf_dataset,
)
from dataquality.utils.auto_trainer import do_train
from dataquality.utils.setfit import _get_meta_cols

a = Analytics(ApiClient, dq.config)
a.log_import("auto_tc")
Expand Down Expand Up @@ -105,7 +105,7 @@ def _get_labels(dd: DatasetDict, labels: Optional[List[str]] = None) -> List[str
def _log_dataset_dict(dd: DatasetDict) -> None:
for key in dd:
ds = dd[key]
meta = _get_meta_cols(ds.features)
meta = get_meta_cols(ds.features)
if key in Split.get_valid_keys():
dq.log_dataset(ds, meta=meta, split=key)
else:
Expand Down
5 changes: 3 additions & 2 deletions dataquality/integrations/seq2seq/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dataquality.schemas.task_type import TaskType
from dataquality.utils.auto import (
add_val_data_if_missing,
get_meta_cols,
run_name_from_hf_dataset,
)
from dataquality.utils.torch import cleanup_cuda
Expand Down Expand Up @@ -138,12 +139,12 @@ def _log_dataset_dict(dd: DatasetDict, input_col: str, target_col: str) -> None:
for key in dd.keys():
ds: Dataset = dd[key]
if key in Split.get_valid_keys():
meta = get_meta_cols(ds.features, {input_col, target_col})
if input_col != "text" and "text" in ds.column_names:
ds = ds.rename_columns({"text": "_metadata_text"})
if target_col != "label" and "label" in ds.column_names:
ds = ds.rename_columns({"label": "_metadata_label"})

dq.log_dataset(ds, text=input_col, label=target_col, split=key)
dq.log_dataset(ds, text=input_col, label=target_col, split=key, meta=meta)


def auto(
Expand Down
11 changes: 7 additions & 4 deletions dataquality/integrations/setfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
)
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.auto import _apply_column_mapping, run_name_from_hf_dataset
from dataquality.utils.auto import (
_apply_column_mapping,
get_meta_cols,
run_name_from_hf_dataset,
)
from dataquality.utils.patcher import PatchManager
from dataquality.utils.setfit import (
SetFitModelHook,
_get_meta_cols,
_prepare_config,
_setup_patches,
get_trainer,
Expand Down Expand Up @@ -346,7 +349,7 @@ def do_model_eval(
for split in [Split.train, Split.test, Split.val]:
if split in encoded_data:
ds = encoded_data[split]
meta_columns = _get_meta_cols(ds.column_names)
meta_columns = get_meta_cols(ds.column_names)
dq_evaluate(
encoded_data[split],
split=split,
Expand All @@ -358,7 +361,7 @@ def do_model_eval(
inf_names = [k for k in encoded_data if k not in Split.get_valid_keys()]
for inf_name in inf_names:
ds = encoded_data[inf_name]
meta_columns = _get_meta_cols(ds.column_names)
meta_columns = get_meta_cols(ds.column_names)
dq_evaluate(
ds,
split=Split.inference, # type: ignore
Expand Down
1 change: 1 addition & 0 deletions dataquality/loggers/data_logger/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _get_input_df(self) -> DataFrame:
C.split_.value: [self.split] * len(self.ids),
C.token_label_positions.value: pa.array(self.token_label_positions),
C.token_label_offsets.value: pa.array(self.token_label_offsets),
**self.meta,
}
)

Expand Down
13 changes: 12 additions & 1 deletion dataquality/utils/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import warnings
from datetime import datetime
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Set, Union

import pandas as pd
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
Expand All @@ -14,6 +14,17 @@
from dataquality.utils.name import BAD_CHARS_REGEX


def get_meta_cols(
    cols: Iterable, reserved_cols: Optional[Set[str]] = None
) -> List[str]:
    """Return the meta columns of a dataset.

    A column is considered metadata unless it is one of the default
    reserved columns ("text", "label", "id") or appears in
    ``reserved_cols``.

    :param cols: Iterable of column names to filter.
    :param reserved_cols: Additional column names to exclude on top of
        the defaults. ``None`` (the default) means no extras.
    :return: Names of the metadata columns, preserving input order.
    """
    # Single union of caller-supplied reserved names and the built-in
    # defaults; ``set(... or ())`` tolerates any iterable (or None).
    excluded = set(reserved_cols or ()) | {"text", "label", "id"}
    return [col for col in cols if col not in excluded]


def load_data_from_str(data: str) -> Union[pd.DataFrame, Dataset]:
"""Loads string data from either hf or disk.

Expand Down
12 changes: 3 additions & 9 deletions dataquality/utils/setfit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -10,6 +10,7 @@
import dataquality as dq
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.auto import get_meta_cols
from dataquality.utils.patcher import Patch, PatchManager

BATCH_LOG_SIZE = 10_000
Expand All @@ -35,13 +36,6 @@ def clear(self) -> None:
self.meta.clear()


def _get_meta_cols(cols: Iterable) -> List[str]:
"""Returns the meta columns of a dataset."""
default_cols = ["text", "label", "id"]
meta_columns = [col for col in cols if col not in default_cols]
return meta_columns


def log_preds_setfit(
model: "SetFitModel",
dataset: Dataset,
Expand Down Expand Up @@ -79,7 +73,7 @@ def log_preds_setfit(
skip_logging = logger_config.helper_data[f"setfit_skip_input_log_{split}"]
# Iterate over the dataset in batches and log the data samples
# and model outputs
meta = _get_meta_cols(dataset.column_names)
meta = get_meta_cols(dataset.column_names)
for i in range(0, len(dataset), batch_size):
batch = dataset[i : i + batch_size]
assert text_col in batch, f"column '{text_col}' must be in batch"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ test = [
"setfit",
"accelerate>=0.19.0",
"typing-inspect==0.8.0",
"typing-extensions==4.0.0",
"typing-extensions==4.0.1",
"lightning",
]
dev = [
Expand Down