-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLMTEB_retrieval.py
48 lines (40 loc) · 1.63 KB
/
LMTEB_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from mteb import MTEB
from sentence_transformers import SentenceTransformer
from mteb import AbsTaskRetrieval
from datasets import load_dataset, DatasetDict
from collections import defaultdict
def load_retrieval_data(hf_hub_name, eval_splits):
    """Load corpus, queries and relevance judgements for a retrieval task.

    Two datasets are fetched with ``load_dataset``: ``hf_hub_name`` (which
    must expose ``'corpus'`` and ``'queries'`` entries whose rows carry
    ``'id'`` and ``'text'``) and ``hf_hub_name + '-qrels'`` (whose rows carry
    ``'qid'``, ``'pid'`` and ``'score'``).

    Args:
        hf_hub_name: Dataset identifier passed to ``load_dataset``.
        eval_splits: Sequence of split names. Only the first entry is used;
            it selects the qrels split and is the single key of every
            returned mapping.

    Returns:
        A ``(corpus, queries, relevant_docs)`` tuple of ``DatasetDict``
        objects, each keyed by the chosen split:
          * corpus: ``{doc_id: {'text': doc_text}}``
          * queries: ``{query_id: query_text}``
          * relevant_docs: ``{query_id: {doc_id: score}}``
    """
    # Only the first requested split is evaluated; further entries are ignored.
    split = eval_splits[0]

    hub_data = load_dataset(hf_hub_name)
    judgement_rows = load_dataset(hf_hub_name + '-qrels')[split]

    # Index documents by id, keeping the nested {'text': ...} layout.
    documents = {}
    for row in hub_data['corpus']:
        documents[row['id']] = {'text': row['text']}

    # Index queries by id; queries map straight to their text.
    questions = {}
    for row in hub_data['queries']:
        questions[row['id']] = row['text']

    # Group judgements per query: query id -> {doc id -> score}.
    judgements = defaultdict(dict)
    for row in judgement_rows:
        score = row['score']
        judgements[row['qid']][row['pid']] = score

    return (
        DatasetDict({split: documents}),
        DatasetDict({split: questions}),
        DatasetDict({split: judgements}),
    )
class LongDocRetrieval(AbsTaskRetrieval):
    """Retrieval task definition built on ``AbsTaskRetrieval``.

    The task metadata lives in ``description``; the data itself is pulled in
    by ``load_data`` through ``load_retrieval_data``.
    """

    @property
    def description(self):
        """Return the task metadata dictionary.

        NOTE: ``'hf_hub_name'`` is still the placeholder ``'xxxx'`` and has
        to be replaced with a real dataset identifier before the task can
        load data.
        """
        # TODO implement hf name
        return dict(
            name='LongDocRetrieval',
            hf_hub_name='xxxx',
            eval_splits=['dev'],
            type='retrieval',
            category='s2p',
            main_score='ndcg_at_10',
        )

    def load_data(self, **kwargs):
        """Populate ``corpus``, ``queries`` and ``relevant_docs`` once.

        Does nothing when ``self.data_loaded`` is already truthy. Otherwise
        the three attributes are filled from ``load_retrieval_data`` using
        the hub name and eval splits from ``description``, and
        ``self.data_loaded`` is set to ``True``. ``kwargs`` is accepted but
        not used.

        NOTE(review): ``data_loaded`` is read before it is ever assigned
        here, so it is presumably initialised by the base class - confirm.
        """
        if not self.data_loaded:
            task_info = self.description
            loaded = load_retrieval_data(
                task_info['hf_hub_name'],
                task_info['eval_splits'],
            )
            self.corpus, self.queries, self.relevant_docs = loaded
            self.data_loaded = True
if __name__ == '__main__':
    # Script entry point: evaluate one embedding model on the task above.
    # "xxx" is a placeholder model name/path and must be replaced.
    encoder = SentenceTransformer("xxx")
    # Build the benchmark with a single task, then run it against the model.
    benchmark = MTEB(tasks=[LongDocRetrieval()])
    benchmark.run(encoder)