Skip to content

Commit

Permalink
♻️ Refactor fewrel dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielePicco committed Dec 1, 2022
1 parent b737e06 commit c078a98
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 44 deletions.
40 changes: 6 additions & 34 deletions zshot/evaluation/dataset/fewrel/fewrel.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,10 @@
from datasets import load_dataset
from tqdm import tqdm

from zshot import Linker
from zshot.utils.data_models import Span


class DummyLinkerEnd2End(Linker):
@property
def is_end2end(self) -> bool:
return True
from typing import Optional, Union

def predict(self, data):
return [
[
Span(
item["start"],
item["end"],
item["label"],
)
for item in doc_ents
]
for doc_ents in enumerate(data)
]
from datasets import load_dataset, Split
from tqdm import tqdm


def get_entity_data(e, tokenized_sentence):
# import pdb
# pdb.set_trace()
d = {"start": None, "end": None, "label": e["type"]}
token_indices = e["indices"][0]
s = ""
Expand All @@ -39,21 +17,15 @@ def get_entity_data(e, tokenized_sentence):
if idx == token_indices[-1]:
d["end"] = curr_idx
d["sentence"] = s.strip()
# pdb.set_trace()
return d


def get_few_rel_data(split_name="val_wiki", limit=-1):
def get_few_rel_data(split_name: Optional[Union[str, Split]] = "val_wiki"):
wiki_val = load_dataset("few_rel", split=split_name)
relations_descriptions = wiki_val["names"][:limit]
tokenized_sentences = wiki_val["tokens"][:limit]
relations_descriptions = wiki_val["names"]
tokenized_sentences = wiki_val["tokens"]
sentences = [" ".join(tokens) for tokens in tokenized_sentences]
if limit != -1:
sentences = sentences[:limit]

gt = [item[0] for item in relations_descriptions]
# label_mapping = {l: idx for idx, l in enumerate(list(set(gt)))}
# gt = [label_mapping.get(item) for item in gt]
heads = wiki_val["head"]
tails = wiki_val["tail"]
entities_data = []
Expand Down
14 changes: 4 additions & 10 deletions zshot/tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# import pdb
from typing import Iterator, List, Tuple

import spacy
Expand Down Expand Up @@ -323,7 +322,6 @@ def get_dataset(self, gt: List[str], sentence: List[str]):
"labels": gt,
}
dataset = Dataset.from_dict(data_dict)
# dataset.entities = ENTITIES
return dataset

def test_relation_classification_prediction(self):
Expand All @@ -332,26 +330,22 @@ def test_relation_classification_prediction(self):
sentences,
relations_descriptions,
gt,
) = get_few_rel_data(split_name="val_wiki", limit=5)
) = get_few_rel_data(split_name="val_wiki[0:5]")

# pdb.set_trace()
custom_evaluator = RelationExtractorEvaluator()
# pdb.set_trace()
pipe = get_relation_extraction_pipeline(
entities_data,
[
Relation(name=name, description=desc)
for name, desc in set([(i, j) for i, j in relations_descriptions])
],
)
# pdb.set_trace()
custom_evaluator.compute(
results = custom_evaluator.compute(
pipe,
self.get_dataset(gt, sentences),
input_column="sentences",
label_column="labels",
metric=RelEval(),
)
# print("metrics: {}".format(metrics))
# pdb.set_trace()
assert True
assert len(sentences) == 5
assert results is not None

0 comments on commit c078a98

Please sign in to comment.