Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: auto alpaca #778

Merged
merged 3 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


__version__ = "1.1.2"
__version__ = "1.1.3"

import sys
from typing import Any, List, Optional
Expand Down
11 changes: 0 additions & 11 deletions dataquality/dq_auto/base_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@
from datasets import Dataset, DatasetDict, load_dataset

from dataquality.exceptions import GalileoException
from dataquality.integrations.seq2seq.formatter import (
BaseFormatter,
DefaultFormatter,
get_formatter,
)
from dataquality.schemas.split import Split
from dataquality.utils.auto import (
_apply_column_mapping,
Expand All @@ -20,9 +15,6 @@
class BaseDatasetManager:
DEMO_DATASETS: List[str] = []

def __init__(self) -> None:
self.formatter: BaseFormatter = DefaultFormatter()

def _validate_dataset_dict(
self,
dd: DatasetDict,
Expand Down Expand Up @@ -148,16 +140,13 @@ def try_load_dataset_dict(
if hf_data:
if isinstance(hf_data, str):
dd = load_dataset(hf_data)
self.formatter = get_formatter(hf_data)
else:
dd = hf_data
assert isinstance(dd, DatasetDict), (
"hf_data must be a path to a huggingface DatasetDict in the hf hub or "
"a DatasetDict object. "
"If this is just a Dataset, pass it to `train_data`"
)
# Apply the datasets custom formatter on load dataset dict
dd = dd.map(self.formatter.format_sample)
return dd

return None
5 changes: 5 additions & 0 deletions dataquality/dq_auto/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pandas as pd
from datasets import Dataset, DatasetDict

from dataquality.integrations.seq2seq.formatter import BaseFormatter, DefaultFormatter


@dataclass
class BaseAutoDatasetConfig:
Expand Down Expand Up @@ -46,6 +48,9 @@ class BaseAutoDatasetConfig:
# Column names
input_col: str = "text"
target_col: str = "label"
# Dataset input / output formatter
max_train_size: Optional[int] = None
formatter: BaseFormatter = DefaultFormatter()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a more idiomatic way of initializing for a dataclass:

Suggested change
formatter: BaseFormatter = DefaultFormatter()
from dataclasses import field
...
formatter: BaseFormatter = field(default_factory=DefaultFormatter)


def __post_init__(self) -> None:
if not any([self.hf_data, self.train_path, self.train_data]):
Expand Down
17 changes: 12 additions & 5 deletions dataquality/integrations/seq2seq/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
add_val_data_if_missing,
get_meta_cols,
run_name_from_hf_dataset,
sample_dataset_dict,
)
from dataquality.utils.torch import cleanup_cuda

Expand Down Expand Up @@ -54,7 +55,7 @@ def try_load_dataset_dict_from_config(
hf_data = dataset_config.hf_data
if isinstance(hf_data, str):
dd = load_dataset(hf_data)
self.formatter = get_formatter(hf_data)
dataset_config.formatter = get_formatter(hf_data)
elif isinstance(hf_data, DatasetDict):
dd = hf_data
else:
Expand All @@ -64,8 +65,8 @@ def try_load_dataset_dict_from_config(
"If this is just a Dataset, pass it to `train_data`"
)

# Apply the datasets custom formatter on load dataset dict
dd = dd.map(self.formatter.format_sample)
dataset_config.input_col = dataset_config.formatter.input_col
dataset_config.target_col = dataset_config.formatter.target_col
return dd, dataset_config

return None, dataset_config
Expand Down Expand Up @@ -109,6 +110,9 @@ def get_dataset_dict_from_config(
if test_data is not None:
dd[Split.test] = self._convert_to_hf_dataset(test_data)

# Apply the datasets custom formatter on load dataset dict
dd = dd.map(dataset_config.formatter.format_sample)
dd = sample_dataset_dict(dd, dataset_config)
return self._validate_dataset_dict(dd, []), dataset_config

def _validate_dataset_dict(
Expand Down Expand Up @@ -139,11 +143,15 @@ def _log_dataset_dict(dd: DatasetDict, input_col: str, target_col: str) -> None:
for key in dd.keys():
ds: Dataset = dd[key]
if key in Split.get_valid_keys():
meta = get_meta_cols(ds.features, {input_col, target_col})
if input_col != "text" and "text" in ds.column_names:
ds = ds.rename_columns({"text": "_metadata_text"})
if target_col != "label" and "label" in ds.column_names:
ds = ds.rename_columns({"label": "_metadata_label"})
if input_col != "input" and "input" in ds.column_names:
ds = ds.rename_columns({"input": "_metadata_input"})
if target_col != "target" and "target" in ds.column_names:
ds = ds.rename_columns({"target": "_metadata_target"})
meta = get_meta_cols(ds.features, {input_col, target_col})
dq.log_dataset(ds, text=input_col, label=target_col, split=key, meta=meta)


Expand Down Expand Up @@ -231,7 +239,6 @@ def auto(
dq.init(TaskType.seq2seq, project_name=project_name, run_name=run_name)
input_col = dataset_config.input_col
target_col = dataset_config.target_col

# We 'watch' in get_trainer, which must happen before logging datasets
model, dataloaders = get_trainer(
dd,
Expand Down
4 changes: 3 additions & 1 deletion dataquality/integrations/seq2seq/formatter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Type
from typing import Dict, Optional, Type


@dataclass
class BaseFormatter(ABC):
name: str
input_col: str
target_col: str
max_train_size: Optional[int] = None

@abstractmethod
def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
Expand All @@ -31,6 +32,7 @@ class AlpacaFormatter(BaseFormatter):
name: str = "tatsu-lab/alpaca"
input_col: str = "formatted_input"
target_col: str = "output"
max_train_size: int = 1000

def format_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
"""Formats the alpaca dataset for seq2seq
Expand Down
34 changes: 34 additions & 0 deletions dataquality/utils/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,46 @@
import pandas as pd
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset

from dataquality.dq_auto.schema import BaseAutoDatasetConfig
from dataquality.exceptions import GalileoException, GalileoWarning
from dataquality.schemas.split import Split
from dataquality.schemas.task_type import TaskType
from dataquality.utils.name import BAD_CHARS_REGEX


def sample_dataset_dict(
    dd: DatasetDict, dataset_config: BaseAutoDatasetConfig
) -> DatasetDict:
    """Downsample the splits of a DatasetDict to the configured max train size.

    Behavior notes:
    - The config's explicit ``max_train_size`` takes precedence over the
      formatter's default; if neither is set (falsy), nothing is sampled.
    - The validation split is capped at 25% of the train cap.
    - Test and inference splits are never sampled.
    - Splits already at or below their cap are left untouched.
    """
    train_cap = (
        dataset_config.max_train_size or dataset_config.formatter.max_train_size
    )
    if not train_cap:
        return dd

    eval_cap = int(train_cap * 0.25)
    for split in list(dd):
        ds = dd[split]
        # Only train/validation splits get a cap; everything else keeps its size
        if split == Split.train:
            limit = min(len(ds), train_cap)
        elif split == Split.validation:
            limit = min(len(ds), eval_cap)
        else:
            limit = len(ds)

        if len(ds) > limit:
            # Deterministic prefix slice rather than a random sample
            ds = ds.select(range(limit))
        dd[split] = ds

    return dd


def get_meta_cols(
cols: Iterable, reserved_cols: Optional[Set[str]] = None
) -> List[str]:
Expand Down