fix: Allow aggregated tasks within benchmarks (#1771)
* fix: Allow aggregated tasks within benchmarks

Fixes #1231

* feat: Update task filtering, fixing a bug in MTEB

- Updated task filtering, adding the exclusive_language_filter and hf_subsets arguments
- Fixed a bug in MTEB where cross-lingual splits were included
- Added missing language filtering to MTEB(europe, beta) and MTEB(indic, beta)

The following code outlines the problems:

```py
import mteb
from mteb.benchmarks import MTEB_ENG_CLASSIC

task = [t for t in MTEB_ENG_CLASSIC.tasks if t.metadata.name == "STS22"][0]
# was equivalent to:
task = mteb.get_task("STS22", languages=["eng"])
task.hf_subsets
# the language filter kept every subset containing English:
# ['en', 'de-en', 'es-en', 'pl-en', 'zh-en']
# However, it should be:
# ['en']

# with the changes it is:
task = [t for t in MTEB_ENG_CLASSIC.tasks if t.metadata.name == "STS22"][0]
task.hf_subsets
# ['en']
# equivalent to:
task = mteb.get_task("STS22", hf_subsets=["en"])
# which you can also obtain using the exclusive_language_filter (though not if there were multiple English splits):
task = mteb.get_task("STS22", languages=["eng"], exclusive_language_filter=True)
```
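
To put the headline change in context, here is a minimal sketch (not part of the diff; the exact task list depends on the benchmark definition) of iterating a benchmark that may now carry aggregate tasks:

```py
from mteb.benchmarks import MTEB_ENG_CLASSIC

# Benchmarks may now include aggregate tasks alongside regular ones;
# iterating over .tasks yields both kinds.
for t in MTEB_ENG_CLASSIC.tasks:
    print(t.metadata.name, t.metadata.type)
```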

* format

* remove "en-ext" from AmazonCounterfactualClassification

* fixed mteb(deu)

* fix: simplify in a few areas

* wip

* tmp

* sav

* Allow aggregated tasks within benchmarks
Fixes #1231

* ensure correct formatting of eval_langs

* ignore aggregate dataset

* clean up dummy cases

* add to mteb(eng, classic)

* format

* clean up

* Allow aggregated tasks within benchmarks
Fixes #1231

* added fixes from comments

* fix merge

* format

* Updated task type

* Added minor fix for dummy tasks
KennethEnevoldsen authored Jan 29, 2025
1 parent e1be438 commit 8fb59a4
Showing 18 changed files with 484 additions and 74 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -147,4 +147,5 @@ results/
 uv.lock
 
 # model loading tests
-model_names.txt
+model_names.txt
+mteb/leaderboard/__cached_results.json
3 changes: 2 additions & 1 deletion mteb/abstasks/AbsTask.py
@@ -63,7 +63,7 @@ class AbsTask(ABC):
     dataset: dict[HFSubset, DatasetDict] | None = None  # type: ignore
     data_loaded: bool = False
     is_multilingual: bool = False
-    hf_subsets: list[HFSubset] | None = None
+    hf_subsets: list[HFSubset]
 
     def __init__(self, seed: int = 42, **kwargs: Any):
         self.save_suffix = kwargs.get("save_suffix", "")
@@ -73,6 +73,7 @@ def __init__(self, seed: int = 42, **kwargs: Any):
         np.random.seed(self.seed)
         torch.manual_seed(self.seed)
         torch.cuda.manual_seed_all(self.seed)
+        self.hf_subsets = list(self.metadata.hf_subsets_to_langscripts.keys())
 
     def check_if_dataset_is_superseded(self):
         """Check if the dataset is superseded by a newer version"""
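
The practical effect, as a short sketch (assuming a task such as STS22): hf_subsets is now populated during __init__ and is never None, so downstream code no longer needs a None check.

```py
import mteb

# hf_subsets is derived from metadata.hf_subsets_to_langscripts at
# construction time, so it is always a list, even without filtering.
task = mteb.get_task("STS22")
assert isinstance(task.hf_subsets, list)
```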
2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskBitextMining.py
@@ -71,7 +71,7 @@ def evaluate(
         subsets_to_run: list[HFSubset] | None = None,
         *,
         encode_kwargs: dict[str, Any] = {},
-        **kwargs,
+        **kwargs: Any,
     ) -> dict[HFSubset, ScoresDict]:
         if not self.data_loaded:
             self.load_data()
25 changes: 20 additions & 5 deletions mteb/abstasks/TaskMetadata.py
@@ -82,8 +82,8 @@
     "machine-translated and verified",
     "machine-translated and localized",
     "LM-generated and verified",
+    "multiple",
 ]
-
 TASK_TYPE = Literal[
     "BitextMining",
     "Classification",
@@ -98,6 +98,7 @@
     "Speed",
 ]
 
+
 TASK_CATEGORY = Literal[
     "s2s",  # Sentence-to-sentence
     "s2p",  # Sentence-to-paragraph
@@ -169,9 +170,10 @@
     "gpl-3.0",
     "cdla-sharing-1.0",
     "mpl-2.0",
+    "multiple",
     ]
 )
-
+MODALITIES = Literal["text"]
 METRIC_NAME = str
 METRIC_VALUE = Union[int, float, dict[str, Any]]
 
@@ -228,13 +230,13 @@ class TaskMetadata(BaseModel):
         bibtex_citation: The BibTeX citation for the dataset. Should be an empty string if no citation is available.
     """
 
-    dataset: dict
+    dataset: dict[str, Any]
 
     name: str
     description: str
     prompt: str | PromptDict | None = None
     type: TASK_TYPE
-    modalities: list[Literal["text"]] = ["text"]
+    modalities: list[MODALITIES] = ["text"]
     category: TASK_CATEGORY | None = None
     reference: STR_URL | None = None
 
@@ -335,6 +337,15 @@ def _check_language_code(code):
                 f"Invalid script code: {script}, you can find valid ISO 15924 codes in {path_to_lang_scripts}"
             )
 
+    @property
+    def bcp47_codes(self) -> list[ISO_LANGUAGE_SCRIPT]:
+        """Return the language and script codes of the dataset, formatted in accordance with the BCP-47 standard."""
+        if isinstance(self.eval_langs, dict):
+            return sorted(
+                {lang for langs in self.eval_langs.values() for lang in langs}
+            )
+        return sorted(set(self.eval_langs))
+
     @property
     def languages(self) -> list[str]:
         """Return the languages of the dataset as iso639-3 codes."""
@@ -421,8 +432,12 @@ def n_samples(self) -> dict[str, int] | None:
         for subset, subset_value in stats.items():
            if subset == "hf_subset_descriptive_stats":
                continue
-            n_samples[subset] = subset_value["num_samples"]
+            n_samples[subset] = subset_value["num_samples"]  # type: ignore
         return n_samples
+
+    def __hash__(self) -> int:
+        return hash(self.model_dump_json())
+
     @property
     def revision(self) -> str:
         return self.dataset["revision"]
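
A quick, hedged sketch of the two additions above in use (the listed codes are illustrative, not exact output):

```py
import mteb

meta = mteb.get_task("STS22").metadata
meta.bcp47_codes  # e.g. ['ara-Arab', ..., 'eng-Latn'], sorted and deduplicated
meta.languages    # the same languages as iso639-3 codes, e.g. ['ara', ..., 'eng']
hash(meta)        # TaskMetadata is now hashable via its JSON dump
```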
172 changes: 172 additions & 0 deletions mteb/abstasks/aggregate_task_metadata.py
@@ -0,0 +1,172 @@
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any

from pydantic import ConfigDict, model_validator

from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import (
    ANNOTATOR_TYPE,
    LANGUAGES,
    LICENSES,
    MODALITIES,
    SAMPLE_CREATION_METHOD,
    STR_DATE,
    TASK_DOMAIN,
    TASK_SUBTYPE,
    TASK_TYPE,
    HFSubset,
    TaskMetadata,
)
from mteb.languages import ISO_LANGUAGE_SCRIPT

logger = logging.getLogger(__name__)


class AggregateTaskMetadata(TaskMetadata):
    """Metadata for an aggregation of tasks. This description only covers exceptions to TaskMetadata. Many of the fields, if not filled out, will be
    autofilled from its tasks.

    Attributes:
        name: The name of the aggregated task.
        description: A description of the task. Should explain the aggregation.
        prompt: An aggregate task does not have a prompt, thus this value is always None.
        dataset: The dataset for the aggregated task is specified in its tasks. The aggregate task thus only specifies the revision and uses a
            placeholder path.
        tasks: A list of tasks; the majority of the metadata is described within its tasks.
        eval_splits: The splits of the tasks used for evaluation.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    description: str
    dataset: dict[str, Any] = {
        "path": "aggregate tasks do not have a path",  # just a placeholder
        "revision": "1",
    }

    tasks: list[AbsTask]
    main_score: str
    type: TASK_TYPE
    eval_splits: list[str]
    eval_langs: LANGUAGES = []
    prompt: None = None
    reference: str | None = None
    bibtex_citation: str | None = None

    @property
    def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISO_LANGUAGE_SCRIPT]]:
        """Return a dictionary mapping huggingface subsets to languages."""
        return {"default": self.eval_langs}  # type: ignore

    @model_validator(mode="after")  # type: ignore
    def compute_unfilled_cases(self) -> AggregateTaskMetadata:
        if not self.eval_langs:
            self.eval_langs = self.compute_eval_langs()
        if not self.date:
            self.date = self.compute_date()
        if not self.domains:
            self.domains = self.compute_domains()
        if not self.task_subtypes:
            self.task_subtypes = self.compute_task_subtypes()
        if not self.license:
            self.license = self.compute_license()
        if not self.annotations_creators:
            self.annotations_creators = self.compute_annotations_creators()
        if not self.dialect:
            self.dialect = self.compute_dialect()
        if not self.sample_creation:
            self.sample_creation = self.compute_sample_creation()
        if not self.modalities:
            self.modalities = self.compute_modalities()

        return self

    def compute_eval_langs(self) -> list[ISO_LANGUAGE_SCRIPT]:
        langs = set()
        for task in self.tasks:
            langs.update(set(task.metadata.bcp47_codes))
        return list(langs)

    def compute_date(self) -> tuple[STR_DATE, STR_DATE] | None:
        # get min and max date from tasks
        dates = []
        for task in self.tasks:
            if task.metadata.date:
                dates.append(datetime.fromisoformat(task.metadata.date[0]))
                dates.append(datetime.fromisoformat(task.metadata.date[1]))

        if not dates:
            return None

        min_date = min(dates)
        max_date = max(dates)
        return min_date.isoformat(), max_date.isoformat()

    def compute_domains(self) -> list[TASK_DOMAIN] | None:
        domains = set()
        for task in self.tasks:
            if task.metadata.domains:
                domains.update(set(task.metadata.domains))
        if domains:
            return list(domains)
        return None

    def compute_task_subtypes(self) -> list[TASK_SUBTYPE] | None:
        subtypes = set()
        for task in self.tasks:
            if task.metadata.task_subtypes:
                subtypes.update(set(task.metadata.task_subtypes))
        if subtypes:
            return list(subtypes)
        return None

    def compute_license(self) -> LICENSES | None:
        licenses = set()
        for task in self.tasks:
            if task.metadata.license:
                licenses.add(task.metadata.license)
        if len(licenses) > 1:
            return "multiple"
        return None

    def compute_annotations_creators(self) -> ANNOTATOR_TYPE | None:
        creators = set()
        for task in self.tasks:
            if task.metadata.annotations_creators:
                creators.add(task.metadata.annotations_creators)
        if len(creators) > 1:
            logger.warning(
                f"Multiple annotations_creators found for tasks in {self.name}. Using None as annotations_creators."
            )
        return None

    def compute_dialect(self) -> list[str] | None:
        dialects = set()
        for task in self.tasks:
            if task.metadata.dialect:
                dialects.update(set(task.metadata.dialect))
        if dialects:
            return list(dialects)
        return None

    def compute_sample_creation(self) -> SAMPLE_CREATION_METHOD | None:
        sample_creations = set()
        for task in self.tasks:
            if task.metadata.sample_creation:
                sample_creations.add(task.metadata.sample_creation)
        if len(sample_creations) > 1:
            return "multiple"
        return None

    def compute_modalities(self) -> list[MODALITIES] | None:
        modalities = set()
        for task in self.tasks:
            if task.metadata.modalities:
                modalities.update(set(task.metadata.modalities))
        if modalities:
            return list(modalities)
        return None
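
As a rough usage sketch of the class above (the task names and field values are hypothetical choices for illustration, not taken from this commit):

```py
import mteb
from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata

# Fields left unset (date, domains, license, ...) are autofilled by the
# compute_unfilled_cases validator from the wrapped tasks.
meta = AggregateTaskMetadata(
    name="DummyAggregate",
    description="An illustrative aggregate of two existing STS tasks.",
    tasks=list(mteb.get_tasks(tasks=["STS22", "STSBenchmark"])),
    main_score="cosine_spearman",
    type="STS",
    eval_splits=["test"],
)
meta.eval_langs  # union of the subtasks' bcp47_codes
```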