fix: Allow aggregated tasks within benchmarks (#1771)
* fix: Allow aggregated tasks within benchmarks

  Fixes #1231

* feat: Update task filtering, fixing bug on MTEB

  - Updated task filtering, adding `exclusive_language_filter` and `hf_subset`
  - fixed a bug in MTEB where cross-lingual splits were included
  - added missing language filtering to MTEB(europe, beta) and MTEB(indic, beta)

  The following code outlines the problems:

  ```py
  import mteb
  from mteb.benchmarks import MTEB_ENG_CLASSIC

  task = [t for t in MTEB_ENG_CLASSIC.tasks if t.metadata.name == "STS22"][0]
  # was eq. to:
  task = mteb.get_task("STS22", languages=["eng"])
  task.hf_subsets
  # filtering to English datasets gave:
  # ['en', 'de-en', 'es-en', 'pl-en', 'zh-en']
  # However, it should be:
  # ['en']

  # with the changes it is:
  task = [t for t in MTEB_ENG_CLASSIC.tasks if t.metadata.name == "STS22"][0]
  task.hf_subsets
  # ['en']
  # eq. to
  task = mteb.get_task("STS22", hf_subsets=["en"])
  # which you can also obtain using the exclusive_language_filter
  # (though not if there were multiple English splits):
  task = mteb.get_task("STS22", languages=["eng"], exclusive_language_filter=True)
  ```

* format
* remove "en-ext" from AmazonCounterfactualClassification
* fixed mteb(deu)
* fix: simplify in a few areas
* wip
* tmp
* sav
* Allow aggregated tasks within benchmarks (fixes #1231)
* ensure correct formatting of eval_langs
* ignore aggregate dataset
* clean up dummy cases
* add to mteb(eng, classic)
* format
* clean up
* Allow aggregated tasks within benchmarks (fixes #1231)
* added fixes from comments
* fix merge
* format
* Updated task type
* Added minor fix for dummy tasks
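With this change, an aggregated task can be listed in a benchmark alongside regular tasks. A minimal sketch of what that enables (the task names, notably `CQADupstackRetrieval` as an aggregate, and the `Benchmark` import path are illustrative assumptions, not taken from this commit):

```py
# Hedged sketch: a benchmark definition that mixes regular tasks with an
# aggregated task. Task names here are illustrative assumptions.
import mteb
from mteb.benchmarks import Benchmark  # import path assumed

example_benchmark = Benchmark(
    name="MyBenchmark(example)",  # hypothetical benchmark name
    tasks=mteb.get_tasks(
        tasks=["Banking77Classification", "STS12", "CQADupstackRetrieval"]
    ),
    description="Toy benchmark mixing regular tasks with an aggregate task.",
)

# Such a benchmark runs like any other list of tasks:
model = mteb.get_model("sentence-transformers/all-MiniLM-L6-v2")
results = mteb.MTEB(tasks=example_benchmark.tasks).run(model)
```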
1 parent e1be438 · commit 8fb59a4
18 changed files with 484 additions and 74 deletions.
The first hunk adds the leaderboard cache to the ignore file:

```diff
@@ -147,4 +147,5 @@ results/
 uv.lock

 # model loading tests
-model_names.txt
+model_names.txt
+mteb/leaderboard/__cached_results.json
```
The central addition is a new 172-line file defining `AggregateTaskMetadata`:

```py
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any

from pydantic import ConfigDict, model_validator

from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import (
    ANNOTATOR_TYPE,
    LANGUAGES,
    LICENSES,
    MODALITIES,
    SAMPLE_CREATION_METHOD,
    STR_DATE,
    TASK_DOMAIN,
    TASK_SUBTYPE,
    TASK_TYPE,
    HFSubset,
    TaskMetadata,
)
from mteb.languages import ISO_LANGUAGE_SCRIPT

logger = logging.getLogger(__name__)


class AggregateTaskMetadata(TaskMetadata):
    """Metadata for an aggregation of tasks. This description only covers exceptions
    to TaskMetadata. Many of the fields, if not filled out, will be autofilled from
    its tasks.

    Attributes:
        name: The name of the aggregated task.
        description: A description of the task. Should explain the aggregation.
        prompt: An aggregate task does not have a prompt, thus this value is always None.
        dataset: The dataset for the aggregated task is specified in its tasks. The
            aggregate task thus only specifies the revision and uses a placeholder path.
        tasks: A list of tasks; the majority of the metadata is described within its tasks.
        eval_splits: The splits of the tasks used for evaluation.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    description: str
    dataset: dict[str, Any] = {
        "path": "aggregate tasks do not have a path",  # just a placeholder
        "revision": "1",
    }

    tasks: list[AbsTask]
    main_score: str
    type: TASK_TYPE
    eval_splits: list[str]
    eval_langs: LANGUAGES = []
    prompt: None = None
    reference: str | None = None
    bibtex_citation: str | None = None

    @property
    def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISO_LANGUAGE_SCRIPT]]:
        """Return a dictionary mapping huggingface subsets to languages."""
        return {"default": self.eval_langs}  # type: ignore

    @model_validator(mode="after")  # type: ignore
    def compute_unfilled_cases(self) -> AggregateTaskMetadata:
        if not self.eval_langs:
            self.eval_langs = self.compute_eval_langs()
        if not self.date:
            self.date = self.compute_date()
        if not self.domains:
            self.domains = self.compute_domains()
        if not self.task_subtypes:
            self.task_subtypes = self.compute_task_subtypes()
        if not self.license:
            self.license = self.compute_license()
        if not self.annotations_creators:
            self.annotations_creators = self.compute_annotations_creators()
        if not self.dialect:
            self.dialect = self.compute_dialect()
        if not self.sample_creation:
            self.sample_creation = self.compute_sample_creation()
        if not self.modalities:
            self.modalities = self.compute_modalities()

        return self

    def compute_eval_langs(self) -> list[ISO_LANGUAGE_SCRIPT]:
        langs = set()
        for task in self.tasks:
            langs.update(set(task.metadata.bcp47_codes))
        return list(langs)

    def compute_date(self) -> tuple[STR_DATE, STR_DATE] | None:
        # get the min and max date across tasks
        dates = []
        for task in self.tasks:
            if task.metadata.date:
                dates.append(datetime.fromisoformat(task.metadata.date[0]))
                dates.append(datetime.fromisoformat(task.metadata.date[1]))

        if not dates:
            return None

        min_date = min(dates)
        max_date = max(dates)
        return min_date.isoformat(), max_date.isoformat()

    def compute_domains(self) -> list[TASK_DOMAIN] | None:
        domains = set()
        for task in self.tasks:
            if task.metadata.domains:
                domains.update(set(task.metadata.domains))
        if domains:
            return list(domains)
        return None

    def compute_task_subtypes(self) -> list[TASK_SUBTYPE] | None:
        subtypes = set()
        for task in self.tasks:
            if task.metadata.task_subtypes:
                subtypes.update(set(task.metadata.task_subtypes))
        if subtypes:
            return list(subtypes)
        return None

    def compute_license(self) -> LICENSES | None:
        licenses = set()
        for task in self.tasks:
            if task.metadata.license:
                licenses.add(task.metadata.license)
        if len(licenses) > 1:
            return "multiple"
        return None

    def compute_annotations_creators(self) -> ANNOTATOR_TYPE | None:
        creators = set()
        for task in self.tasks:
            if task.metadata.annotations_creators:
                creators.add(task.metadata.annotations_creators)
        if len(creators) > 1:
            logger.warning(
                f"Multiple annotations_creators found for tasks in {self.name}. Using None as annotations_creators."
            )
        return None

    def compute_dialect(self) -> list[str] | None:
        dialects = set()
        for task in self.tasks:
            if task.metadata.dialect:
                dialects.update(set(task.metadata.dialect))
        if dialects:
            return list(dialects)
        return None

    def compute_sample_creation(self) -> SAMPLE_CREATION_METHOD | None:
        sample_creations = set()
        for task in self.tasks:
            if task.metadata.sample_creation:
                sample_creations.add(task.metadata.sample_creation)
        if len(sample_creations) > 1:
            return "multiple"
        return None

    def compute_modalities(self) -> list[MODALITIES] | None:
        modalities = set()
        for task in self.tasks:
            if task.metadata.modalities:
                modalities.update(set(task.metadata.modalities))
        if modalities:
            return list(modalities)
        return None
```
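As a usage note, the `compute_unfilled_cases` validator means most metadata can be left out and derived from the subtasks. A rough sketch of that behavior (the import path and the exact set of required constructor fields are assumptions and may differ by mteb version):

```py
# Hedged sketch: letting AggregateTaskMetadata autofill metadata from its tasks.
# The import path is an assumption; required fields may vary across versions.
import mteb
from mteb.abstasks.aggregated_task import AggregateTaskMetadata  # path assumed

sts_tasks = mteb.get_tasks(tasks=["STS12", "STS13"])

agg_meta = AggregateTaskMetadata(
    name="STSAggregateExample",  # hypothetical task name
    description="Illustrative aggregate of STS12 and STS13.",
    tasks=list(sts_tasks),
    main_score="cosine_spearman",
    type="STS",
    eval_splits=["test"],
)

# eval_langs, date, domains, etc. were left unset, so the model_validator
# derives them from the subtasks' metadata:
print(agg_meta.eval_langs)  # union of the subtasks' bcp47 codes
print(agg_meta.date)        # (min, max) across the subtasks' date ranges
```

Note the asymmetry in the autofill rules: `compute_license` only fills in "multiple" when the subtasks disagree, while `compute_annotations_creators` always resolves to `None`, merely warning when creators differ, so heterogeneous collections never inherit a misleading single value.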