Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/relation extractor evaluation #39

Merged
merged 14 commits into from
Dec 7, 2022
43 changes: 17 additions & 26 deletions zshot/evaluation/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,27 @@
from typing import List, Optional, Dict
from typing import List

import datasets
from datasets import Split, Dataset
from datasets import Dataset
from datasets.table import Table

from zshot.utils.data_models import Entity
from zshot.utils.data_models import Entity, Relation


class DatasetWithEntities(datasets.Dataset):
class DatasetWithRelations(Dataset):

# Older, explicit-argument constructor: forwards arrow_table, info, split,
# indices_table and fingerprint to the parent constructor, then stores the
# entity list unchanged on `self.entities` (None when it is not supplied).
def __init__(self, arrow_table: Table,
info: Optional[datasets.info.DatasetInfo] = None,
split: Optional[datasets.splits.NamedSplit] = None,
indices_table: Optional[Table] = None,
fingerprint: Optional[str] = None,
entities: List[Entity] = None):
super().__init__(arrow_table=arrow_table, info=info, split=split,
indices_table=indices_table, fingerprint=fingerprint)
self.entities = entities
# Constructor: passes `arrow_table` and every extra keyword argument straight
# through to the parent constructor, then stores `relations` unchanged on
# `self.relations` (None when it is not supplied).
def __init__(self, arrow_table: Table, relations: List[Relation] = None, **kwargs):
super().__init__(arrow_table=arrow_table, **kwargs)
self.relations = relations

def __repr__(self):
    """Return a multi-line summary of the dataset.

    The summary lists the feature names, the number of rows and the names of
    the attached relations. A dataset without relations (attribute missing or
    set to ``None``) is rendered with an empty list instead of raising.

    :return: Human-readable description of the dataset
    """
    # Resolve the names before formatting. A guard written as a comprehension
    # filter (`for r in rels if rels is not None`) is evaluated only after
    # iteration has started, so it cannot protect against `rels` being None.
    relations = getattr(self, "relations", None)
    relation_names = [rel.name for rel in relations] if relations is not None else []
    # The field is labelled "relations" because that is what it lists.
    return f"Dataset({{\n features: {list(self.features.keys())},\n num_rows: {self.num_rows}," \
           f"\n relations: {relation_names}\n}})"

# Alternate constructor: builds the dataset through the parent `from_dict`
# (mapping, features, info and split are forwarded as given) and then attaches
# `entities` to the resulting object as a plain attribute before returning it.
@classmethod
def from_dict(
cls,
mapping: dict,
features: Optional[datasets.features.Features] = None,
info: Optional[datasets.info.DatasetInfo] = None,
split: Optional[Split] = None,
entities: List[Dict[str, str]] = None
) -> Dataset:
dataset = super().from_dict(mapping=mapping, features=features, info=info, split=split)
dataset.entities = entities
return dataset

class DatasetWithEntities(Dataset):

# Constructor: passes `arrow_table` and every extra keyword argument straight
# through to the parent constructor, then stores `entities` unchanged on
# `self.entities` (None when it is not supplied).
def __init__(self, arrow_table: Table, entities: List[Entity] = None, **kwargs):
super().__init__(arrow_table=arrow_table, **kwargs)
self.entities = entities

def __repr__(self):
return f"Dataset({{\n features: {list(self.features.keys())},\n num_rows: {self.num_rows}," \
Expand Down
62 changes: 20 additions & 42 deletions zshot/evaluation/dataset/fewrel/fewrel.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,13 @@
from datasets import load_dataset
from tqdm import tqdm

from zshot import Linker
from zshot.utils.data_models import Span

from typing import Optional, Union, Dict

class DummyLinkerEnd2End(Linker):
@property
def is_end2end(self) -> bool:
return True
from datasets import load_dataset, Split, Dataset
from tqdm import tqdm

def predict(self, data):
    """Convert pre-computed entity records into spans, one list per document.

    :param data: Iterable of documents; each document is an iterable of
        mappings providing the "start", "end" and "label" keys
    :return: List with one list of ``Span`` objects per input document
    """
    # Iterate over the documents themselves. Wrapping `data` in `enumerate`
    # yields (index, document) tuples, so the inner loop would see the integer
    # index first and fail on `item["start"]`.
    return [
        [
            Span(
                item["start"],
                item["end"],
                item["label"],
            )
            for item in doc_ents
        ]
        for doc_ents in data
    ]
from zshot.evaluation.dataset.dataset import DatasetWithRelations
from zshot.utils.data_models import Relation


def get_entity_data(e, tokenized_sentence):
# import pdb
# pdb.set_trace()
d = {"start": None, "end": None, "label": e["type"]}
token_indices = e["indices"][0]
s = ""
Expand All @@ -39,23 +20,18 @@ def get_entity_data(e, tokenized_sentence):
if idx == token_indices[-1]:
d["end"] = curr_idx
d["sentence"] = s.strip()
# pdb.set_trace()
return d


def get_few_rel_data(split_name="val_wiki", limit=-1):
wiki_val = load_dataset("few_rel", split=split_name)
relations_descriptions = wiki_val["names"][:limit]
tokenized_sentences = wiki_val["tokens"][:limit]
def get_few_rel_data(split_name: Optional[Union[str, Split]] = "val_wiki") -> Union[Dict[DatasetWithRelations,
Dataset], Dataset]:
dataset = load_dataset("few_rel", split=split_name)
relations_descriptions = dataset["names"]
tokenized_sentences = dataset["tokens"]
sentences = [" ".join(tokens) for tokens in tokenized_sentences]
if limit != -1:
sentences = sentences[:limit]

gt = [item[0] for item in relations_descriptions]
# label_mapping = {l: idx for idx, l in enumerate(list(set(gt)))}
# gt = [label_mapping.get(item) for item in gt]
heads = wiki_val["head"]
tails = wiki_val["tail"]
heads = dataset["head"]
tails = dataset["tail"]
entities_data = []
for idx in tqdm(range(len(tokenized_sentences))):
e1 = heads[idx]
Expand All @@ -66,9 +42,11 @@ def get_few_rel_data(split_name="val_wiki", limit=-1):
get_entity_data(e2, tokenized_sentences[idx]),
]
)

return entities_data, sentences, relations_descriptions, gt


if __name__ == "__main__":
get_few_rel_data()
relations = [Relation(name=name, description=desc) for name, desc in set([(i, j) for i, j in relations_descriptions])]
dataset = Dataset.from_dict({
"sentences": sentences,
"sentence_entities": entities_data,
"labels": gt,
})
dataset.relations = relations
return dataset
10 changes: 6 additions & 4 deletions zshot/evaluation/dataset/ontonotes/onto_notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def remove_out_of_split(sentence, split):

def load_ontonotes_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Union[Dict[DatasetWithEntities,
Dataset], Dataset]:
dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12", split=split, **kwargs)
dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12", split=split, ignore_verifications=True, **kwargs)
if split:
ontonotes_zs = preprocess_spit(dataset_zs, get_simple_split(split))
else:
Expand All @@ -67,7 +67,7 @@ def load_ontonotes_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Un
return ontonotes_zs


def preprocess_spit(dataset, split):
def preprocess_spit(dataset, split) -> DatasetWithEntities:
dataset = dataset.map(lambda example, idx: {
"sentences": [remove_out_of_split(s, split) for s in example['sentences']]
}, with_indices=True)
Expand All @@ -81,10 +81,12 @@ def preprocess_spit(dataset, split):
ner_tags += [[labels.int2str(ent) for ent in s['named_entities']] for s in example['sentences']]
split_entities = [ent for ent in ONTONOTES_ENTITIES
if ent.name in ['NEG'] + CLASSES_PER_SPLIT[split] and ent.name not in TRIVIAL_CLASSES]
return DatasetWithEntities.from_dict({
dataset = Dataset.from_dict({
'tokens': tokens,
'ner_tags': ner_tags
}, split=split, entities=split_entities)
}, split=split)
dataset.entities = split_entities
return dataset


def get_simple_split(split: str) -> str:
Expand Down
7 changes: 7 additions & 0 deletions zshot/evaluation/metrics/rel_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,11 @@ def _compute(
acc = accuracy_score(references, predictions, normalize=False)
scores["overall_accuracy"] = acc

lab = sorted(list(set(references)))
p, r, f1, supp = precision_recall_fscore_support(
references, predictions, average=None, labels=lab
)
for idx, lab in enumerate(lab):
scores[lab] = {'precision': p[idx], 'recall': r[idx],
'f1': f1[idx], 'number': supp[idx]}
return scores
2 changes: 1 addition & 1 deletion zshot/evaluation/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs):
class RelationExtractorPipeline:
def __init__(self, nlp, batch_size=100):
self.nlp = nlp
self.task = "text-classification"
self.task = "relation-extraction"
self.batch_size = batch_size

def __call__(self, *args, **kwargs):
Expand Down
6 changes: 4 additions & 2 deletions zshot/relation_extractor/relation_extractor_zsrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpa
for i, e1 in enumerate(doc._.spans):
for j, e2 in enumerate(doc._.spans):
if (
i == j or (e1, e2) in items_to_process or (e2, e1) in items_to_process
i == j or (e1, e2) in items_to_process or (
e2, e1) in items_to_process
):
continue
else:
Expand All @@ -49,7 +50,8 @@ def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpa
p = relations_probs[pred_class_idx]
if p >= self.thr:
relations_doc.append(
RelationSpan(start=e1, end=e2, score=p, relation=self.relations[pred_class_idx])
RelationSpan(
start=e1, end=e2, score=p, relation=self.relations[pred_class_idx])
)
relations_pred.append(relations_doc)
return relations_pred
20 changes: 0 additions & 20 deletions zshot/relation_extractor/zsrc/zero_shot_rel_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def get_device():


device = get_device()
# torch.use_deterministic_algorithms(True)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
Expand Down Expand Up @@ -93,17 +92,6 @@ def predict(model, items_to_process, relation_description, batch_size=4):
return all_preds, all_probs


def softmax(x):
    """Return the softmax of ``x``, normalised along the first axis.

    The maximum along the first axis is subtracted before exponentiating.
    This leaves the result mathematically unchanged but keeps ``np.exp`` from
    overflowing to ``inf`` (and the quotient from becoming ``nan``) when the
    inputs are large.

    :param x: Array-like of scores; a 1-D input yields a single distribution,
        a 2-D input is normalised column-wise (over its first axis)
    :return: Array of the same shape as ``x`` whose entries sum to 1 along
        the first axis
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would reject an empty reduction.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=0))
    return exps / np.sum(exps, axis=0)


# def download_file_to_path(source_url, dest_path):
# dest_dir = os.path.dirname(dest_path)
# if not os.path.exists(dest_dir):
# os.makedirs(dest_dir)
# urllib.request.urlretrieve(source_url, dest_path)


def load_model():
model = ZSBert()
if not os.path.isfile(MODEL_PATH):
Expand Down Expand Up @@ -185,11 +173,3 @@ def extract_entity(sequence_output, e_mask):
outputs = (loss,) + outputs

return outputs


random.seed(seed)
device = get_device()


if __name__ == '__main__':
load_model()
5 changes: 5 additions & 0 deletions zshot/tests/evaluation/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def test_ontonotes_zs_split():
assert dataset.num_rows == 426


# Smoke test for slice-style split expressions: loading with
# split='test[0:10]' must still return a dataset that contains rows.
def test_ontonotes_zs_sub_split():
dataset = load_ontonotes_zs(split='test[0:10]')
assert dataset.num_rows > 0


def test_medmentions_zs():
dataset = load_medmentions_zs()
assert 'train' in dataset
Expand Down
61 changes: 20 additions & 41 deletions zshot/tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# import pdb
from typing import Iterator, List, Tuple

import spacy
from datasets import Dataset
from spacy.tokens import Doc

from zshot import Linker, MentionsExtractor, PipelineConfig
from zshot.evaluation.dataset.dataset import DatasetWithRelations
from zshot.evaluation.dataset.fewrel.fewrel import get_few_rel_data
from zshot.evaluation.evaluator import (
MentionsExtractorEvaluator,
RelationExtractorEvaluator,
ZeroShotTokenClassificationEvaluator,
)
from zshot.evaluation.metrics.rel_eval import RelEval
from zshot.evaluation.pipeline import (
LinkerPipeline,
MentionsExtractorPipeline,
Expand All @@ -20,8 +21,6 @@
from zshot.relation_extractor.relation_extractor_zsrc import RelationsExtractorZSRC
from zshot.utils.alignment_utils import AlignmentMode
from zshot.utils.data_models import Entity, Span
from zshot.utils.data_models.relation import Relation
from zshot.evaluation.metrics.rel_eval import RelEval

ENTITIES = [
Entity(name="FAC", description="A facility"),
Expand Down Expand Up @@ -59,16 +58,18 @@ def is_end2end(self) -> bool:
return True

def __init__(self, predictions):
# this dummy linker works correctly ONLY if no shuffling is done by spacy when batching documents
super().__init__()
self.predictions = predictions
self.curr_idx = 0

def predict(self, docs, batch_size=100):
rval = []
for data in self.predictions:
# pdb.set_trace()
for _ in docs:
rval.append(
[Span(item["start"], item["end"], item["label"]) for item in data]
[Span(item["start"], item["end"], item["label"]) for item in self.predictions[self.curr_idx]]
)
self.curr_idx += 1
return rval


Expand Down Expand Up @@ -116,13 +117,14 @@ def get_mentions_extractor_pipe(predictions: List[Tuple[str, str, float]]):
return MentionsExtractorPipeline(nlp)


def get_relation_extraction_pipeline(predictions, relations):
def get_relation_extraction_pipeline(dataset: DatasetWithRelations):

nlp = spacy.blank("en")
nlp_config = PipelineConfig(
relations_extractor=RelationsExtractorZSRC(thr=0.0),
linker=DummyLinkerEnd2EndForEval(predictions),
relations=relations,
) # [Relation(name="part_of", description="is an instance of something or part of it"), Relation(name="is_in", description="located in, based in"),],)
linker=DummyLinkerEnd2EndForEval(dataset["sentence_entities"]),
relations=dataset.relations,
)
nlp.add_pipe("zshot", config=nlp_config, last=True)
return RelationExtractorPipeline(nlp)

Expand Down Expand Up @@ -317,41 +319,18 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):


class TestZeroShotTextClassificationEvaluation:
def get_dataset(self, gt: List[str], sentence: List[str]):
data_dict = {
"sentences": sentence,
"labels": gt,
}
dataset = Dataset.from_dict(data_dict)
# dataset.entities = ENTITIES
return dataset

def test_relation_classification_prediction(self):
(
entities_data,
sentences,
relations_descriptions,
gt,
) = get_few_rel_data(split_name="val_wiki", limit=5)

# pdb.set_trace()
custom_evaluator = RelationExtractorEvaluator(task="text-classification")
# pdb.set_trace()
pipe = get_relation_extraction_pipeline(
entities_data,
[
Relation(name=name, description=desc)
for name, desc in set([(i, j) for i, j in relations_descriptions])
],
)
# pdb.set_trace()
custom_evaluator.compute(
dataset = get_few_rel_data(split_name="val_wiki[0:5]")

custom_evaluator = RelationExtractorEvaluator()
pipe = get_relation_extraction_pipeline(dataset)
results = custom_evaluator.compute(
pipe,
self.get_dataset(gt, sentences),
dataset,
input_column="sentences",
label_column="labels",
metric=RelEval(),
)
# print("metrics: {}".format(metrics))
# pdb.set_trace()
assert True
assert len(dataset) == 5
assert results is not None
Loading