From c078a98680d757863b9ce411e15e723a0b6d0364 Mon Sep 17 00:00:00 2001
From: Gabriele Picco
Date: Thu, 1 Dec 2022 15:32:33 +0000
Subject: [PATCH] :recycle: Refactor fewrel dataset

---
 zshot/evaluation/dataset/fewrel/fewrel.py | 40 ++++-------------------
 zshot/tests/evaluation/test_evaluation.py | 14 +++-----
 2 files changed, 10 insertions(+), 44 deletions(-)

diff --git a/zshot/evaluation/dataset/fewrel/fewrel.py b/zshot/evaluation/dataset/fewrel/fewrel.py
index 6e80f17..cb50a00 100644
--- a/zshot/evaluation/dataset/fewrel/fewrel.py
+++ b/zshot/evaluation/dataset/fewrel/fewrel.py
@@ -1,32 +1,10 @@
-from datasets import load_dataset
-from tqdm import tqdm
-
-from zshot import Linker
-from zshot.utils.data_models import Span
-
-
-class DummyLinkerEnd2End(Linker):
-    @property
-    def is_end2end(self) -> bool:
-        return True
+from typing import Optional, Union
 
-    def predict(self, data):
-        return [
-            [
-                Span(
-                    item["start"],
-                    item["end"],
-                    item["label"],
-                )
-                for item in doc_ents
-            ]
-            for doc_ents in enumerate(data)
-        ]
+from datasets import load_dataset, Split
+from tqdm import tqdm
 
 
 def get_entity_data(e, tokenized_sentence):
-    # import pdb
-    # pdb.set_trace()
     d = {"start": None, "end": None, "label": e["type"]}
     token_indices = e["indices"][0]
     s = ""
@@ -39,21 +17,15 @@
         if idx == token_indices[-1]:
             d["end"] = curr_idx
     d["sentence"] = s.strip()
-    # pdb.set_trace()
     return d
 
 
-def get_few_rel_data(split_name="val_wiki", limit=-1):
+def get_few_rel_data(split_name: Optional[Union[str, Split]] = "val_wiki"):
     wiki_val = load_dataset("few_rel", split=split_name)
-    relations_descriptions = wiki_val["names"][:limit]
-    tokenized_sentences = wiki_val["tokens"][:limit]
+    relations_descriptions = wiki_val["names"]
+    tokenized_sentences = wiki_val["tokens"]
     sentences = [" ".join(tokens) for tokens in tokenized_sentences]
-    if limit != -1:
-        sentences = sentences[:limit]
     gt = [item[0] for item in relations_descriptions]
-    # label_mapping = {l: idx for idx, l in enumerate(list(set(gt)))}
-    # gt = [label_mapping.get(item) for item in gt]
     heads = wiki_val["head"]
     tails = wiki_val["tail"]
     entities_data = []
diff --git a/zshot/tests/evaluation/test_evaluation.py b/zshot/tests/evaluation/test_evaluation.py
index b78a641..61a5e00 100644
--- a/zshot/tests/evaluation/test_evaluation.py
+++ b/zshot/tests/evaluation/test_evaluation.py
@@ -1,4 +1,3 @@
-# import pdb
 from typing import Iterator, List, Tuple
 
 import spacy
@@ -323,7 +322,6 @@ def get_dataset(self, gt: List[str], sentence: List[str]):
             "labels": gt,
         }
         dataset = Dataset.from_dict(data_dict)
-        # dataset.entities = ENTITIES
         return dataset
 
     def test_relation_classification_prediction(self):
@@ -332,11 +330,9 @@
             sentences,
             relations_descriptions,
             gt,
-        ) = get_few_rel_data(split_name="val_wiki", limit=5)
+        ) = get_few_rel_data(split_name="val_wiki[0:5]")
 
-        # pdb.set_trace()
         custom_evaluator = RelationExtractorEvaluator()
-        # pdb.set_trace()
         pipe = get_relation_extraction_pipeline(
             entities_data,
             [
@@ -344,14 +340,12 @@
                 for name, desc in set([(i, j) for i, j in relations_descriptions])
             ],
         )
-        # pdb.set_trace()
-        custom_evaluator.compute(
+        results = custom_evaluator.compute(
             pipe,
             self.get_dataset(gt, sentences),
             input_column="sentences",
             label_column="labels",
             metric=RelEval(),
         )
-        # print("metrics: {}".format(metrics))
-        # pdb.set_trace()
-        assert True
+        assert len(sentences) == 5
+        assert results is not None
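
Usage note for reviewers: the removed `limit` argument is superseded by Hugging Face
`datasets` split slicing, so callers now bound the number of examples in the split
expression passed through to `load_dataset`. A minimal sketch of the new call pattern
(the variable names and return order are taken from the updated test; the slicing
syntax is standard `datasets` behavior, not something this patch adds):

    from zshot.evaluation.dataset.fewrel.fewrel import get_few_rel_data

    # Before this patch: get_few_rel_data(split_name="val_wiki", limit=5)
    # After: the bound lives in the split expression, resolved by
    # datasets.load_dataset, so only the first five val_wiki examples
    # are materialized.
    entities_data, sentences, relations_descriptions, gt = get_few_rel_data(
        split_name="val_wiki[0:5]"
    )

    assert len(sentences) == 5  # mirrors the updated assertion in test_evaluation.py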