✨ Add zero shot relations extraction (#27)
* ✨ Add zero shot relations extraction

* 🎨 Code refactor

* 🐛 Fix tests

* 🐛 Fix tests

* 🐛 Fix tests

* ✨ Fix displacy inverse rel

* 🎨 Fix codestyle

* 🎨 Increase coverage

* 🎨 Fix code style

* 🎨 Increase coverage

* 🎨 Fix code style

* 🐛 Fix test

* performance improvements

* perf improvement

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>
Co-authored-by: Alberto Purpura <alp@ibm.com>
GabrielePicco and Alberto Purpura authored Nov 8, 2022
1 parent d80d1f6 commit bea91c7
Showing 38 changed files with 1,847 additions and 184 deletions.
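For orientation, here is a minimal usage sketch (not part of the diff) of how the relation-extraction stage added by this commit might be wired into a zshot pipeline. It assumes `PipelineConfig` accepts `relations` and `relations_extractor` arguments mirroring the existing `entities`/`linker` ones, and that a `Relation` data model exists under `zshot.utils.data_models`; neither is shown in this excerpt, and the concrete extractor class is left as a placeholder.

import spacy

from zshot import PipelineConfig
from zshot.utils.data_models import Relation  # assumed data model, not shown in this diff

nlp = spacy.load("en_core_web_sm")
config = PipelineConfig(
    relations=[
        Relation(name="founded by", description="The organization was founded by this person"),
    ],
    # relations_extractor=...,  # one of the extractors added by this commit
)
nlp.add_pipe("zshot", config=config, last=True)
doc = nlp("IBM was founded by Charles Ranlett Flint.")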
4 changes: 2 additions & 2 deletions setup.py
@@ -26,8 +26,8 @@
"prettytable>=3.4",
"torch>=1",
"transformers>=4.20",
"datasets>=2.2.2",
"evaluate>=0.2.2",
"datasets>=2.3.0",
"evaluate>=0.3.0",
"seqeval>=1.2.2",
],
entry_points="""
4 changes: 2 additions & 2 deletions zshot/__init__.py
@@ -1,4 +1,4 @@
-from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig # noqa: F401
-from zshot.utils.displacy import displacy # noqa: F401
+from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig, RelationsExtractor # noqa: F401
+from zshot.utils.displacy.displacy import displacy # noqa: F401

__version__ = '0.0.4'
74 changes: 74 additions & 0 deletions zshot/evaluation/dataset/fewrel/fewrel.py
@@ -0,0 +1,74 @@
from datasets import load_dataset
from tqdm import tqdm

from zshot import Linker
from zshot.utils.data_models import Span


class DummyLinkerEnd2End(Linker):
    @property
    def is_end2end(self) -> bool:
        return True

    def predict(self, data):
        # Each element of `data` is the list of gold entity dicts for one document.
        return [
            [
                Span(
                    item["start"],
                    item["end"],
                    item["label"],
                )
                for item in doc_ents
            ]
            for doc_ents in data
        ]


def get_entity_data(e, tokenized_sentence):
    # Map an entity's token indices onto character offsets in the
    # whitespace-joined sentence.
    d = {"start": None, "end": None, "label": e["type"]}
    token_indices = e["indices"][0]
    s = ""
    curr_idx = 0
    for idx, token in enumerate(tokenized_sentence):
        if idx == token_indices[0]:
            d["start"] = curr_idx
        s += token + " "
        curr_idx = len(s.strip())
        if idx == token_indices[-1]:
            d["end"] = curr_idx
    d["sentence"] = s.strip()
    return d


def get_few_rel_data(split_name="val_wiki", limit=-1):
    wiki_val = load_dataset("few_rel", split=split_name)
    # A non-positive limit means "use the whole split"; slicing with -1
    # directly would silently drop the last example.
    limit = limit if limit > 0 else None
    relations_descriptions = wiki_val["names"][:limit]
    tokenized_sentences = wiki_val["tokens"][:limit]
    sentences = [" ".join(tokens) for tokens in tokenized_sentences]

    gt = [item[0] for item in relations_descriptions]
    heads = wiki_val["head"]
    tails = wiki_val["tail"]
    entities_data = []
    for idx in tqdm(range(len(tokenized_sentences))):
        e1 = heads[idx]
        e2 = tails[idx]
        entities_data.append(
            [
                get_entity_data(e1, tokenized_sentences[idx]),
                get_entity_data(e2, tokenized_sentences[idx]),
            ]
        )

    return entities_data, sentences, relations_descriptions, gt


if __name__ == "__main__":
    get_few_rel_data()
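For reference, a quick sketch of consuming this loader (split name and limit are example values; the loader indexes each "names" entry with `item[0]`, so the first element is treated as the gold relation name):

entities_data, sentences, relations_descriptions, gt = get_few_rel_data(
    split_name="val_wiki", limit=100
)
print(sentences[0])               # detokenized FewRel sentence
print(entities_data[0])           # [head, tail] dicts with character offsets
print(relations_descriptions[0])  # the "names" entry for this example
print(gt[0])                      # gold relation name (first element of "names")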
43 changes: 43 additions & 0 deletions zshot/evaluation/dataset/med_mentions/entities.py
@@ -0,0 +1,43 @@
from zshot.utils.data_models import Entity

MEDMENTIONS_TYPE_INV = {
    'T058': "Health_Care_Activity", "T062": "Research_Activity", "T037": "Injury_or_Poisoning",
    "T038": "Biologic_Function", "T005": "Virus", "T007": "Bacterium", "T204": "Eukaryote",
    "T017": "Anatomical_Structure", "T074": "Medical_Device", "T031": "Body_Substance", "T103": "Chemical",
    "T168": "Food", "T201": "Clinical_Attribute", "T033": "Finding", "T082": "Spatial_Concept",
    "T022": "Body_System", "T091": "Biomedical_Occupation_or_Discipline", "T092": "Organization",
    "T097": "Professional_or_Occupational_Group", "T098": "Population_Group", "T170": "Intellectual_Product",
    "NEG": "NEG"
}

MEDMENTIONS_SPLITS = {
    "train": ['Biologic_Function', 'Chemical', 'Health_Care_Activity', 'Anatomical_Structure', "Finding",
              "Spatial_Concept", "Intellectual_Product", "Research_Activity", 'Medical_Device', 'Eukaryote',
              'Population_Group'],
    "validation": ['Biomedical_Occupation_or_Discipline', 'Virus', 'Clinical_Attribute', 'Injury_or_Poisoning',
                   'Organization'],
    "test": ['Body_System', 'Food', 'Body_Substance', 'Bacterium', 'Professional_or_Occupational_Group']}

MEDMENTIONS_ENTITIES = [
    Entity(name="T058", description="Healthcare activity"),
    Entity(name="T062", description="Research activity"),
    Entity(name="T037", description="Injury or Poisoning"),
    Entity(name="T038", description="Biologic Function"),
    Entity(name="T005", description="Virus"),
    Entity(name="T007", description="Bacterium"),
    Entity(name="T204", description="Eukaryote"),
    Entity(name="T017", description="Anatomical structure"),
    Entity(name="T074", description="Medical device"),
    Entity(name="T031", description="Body substance"),
    Entity(name="T103", description="Chemical"),
    Entity(name="T168", description="Food"),
    Entity(name="T201", description="Clinical Attribute"),
    Entity(name="T033", description="Finding"),
    Entity(name="T082", description="Spatial Concept"),
    Entity(name="T022", description="Body System"),
    Entity(name="T091", description="Biomedical occupation or discipline"),
    Entity(name="T092", description="Organization"),
    Entity(name="T097", description="Professional or occupational group"),
    Entity(name="T098", description="Population group"),
    Entity(name="T170", description="Intellectual product")
]
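A small sketch of how these tables can be combined to get the `Entity` objects for one split. The helper name is ours, not part of the commit, and it assumes `Entity` exposes its name as an attribute (it is constructed with a `name` keyword above):

def entities_for_split(split: str):
    # Invert the id -> name mapping, then select the split's entities by type id.
    name_to_id = {name: tui for tui, name in MEDMENTIONS_TYPE_INV.items()}
    split_ids = {name_to_id[name] for name in MEDMENTIONS_SPLITS[split]}
    return [e for e in MEDMENTIONS_ENTITIES if e.name in split_ids]

train_entities = entities_for_split("train")  # 11 of the 21 entities above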
130 changes: 130 additions & 0 deletions zshot/evaluation/dataset/med_mentions/utils.py
@@ -0,0 +1,130 @@
import os

import spacy
from tqdm import tqdm

from zshot.evaluation.dataset.med_mentions.entities import MEDMENTIONS_TYPE_INV


class Token(object):
    def __init__(self, word, label, label_id):
        self.word = word
        self.label = label
        self.label_id = label_id


def convert_to_iob(id_, text, entities, nlp, end_index, start_index):
    # Tokenize `text` with spaCy and emit one (id, tokens) pair per sentence,
    # assigning IOB labels from the entity character offsets. `entities`,
    # `start_index` and `end_index` are parallel lists consumed front-to-back.
    count = 0
    sentences = []
    current_sent_pos = 0
    doc = nlp(text)
    for sent in doc.sents:
        sentence = []
        for tok in sent:
            word = tok.text
            ent = entities[0][1] if entities else "O"
            if entities and tok.idx + current_sent_pos == start_index[0]:
                count += 1
                label = 'B-' + MEDMENTIONS_TYPE_INV[ent]
                label_id = 'B-' + ent
                token = Token(word=word, label=label, label_id=label_id)
            elif entities and start_index[0] < tok.idx + current_sent_pos <= end_index[0]:
                label = 'I-' + MEDMENTIONS_TYPE_INV[ent]
                label_id = 'I-' + ent
                token = Token(word=word, label=label, label_id=label_id)
            else:
                token = Token(word=word, label='O', label_id="O")
            if entities and len(word) + tok.idx + current_sent_pos >= end_index[0]:
                # The current entity has been fully emitted; advance to the next one.
                start_index.pop(0)
                end_index.pop(0)
                entities.pop(0)
            sentence.append(token)
        current_sent_pos += len(sent) + 1
        sentences.append((id_, sentence))
    return sentences, count


def preprocess_medmentions(input_path):
    # Parse the PubTator-format MedMentions corpus: `|t|` lines carry titles,
    # `|a|` lines carry abstracts, and tab-separated lines carry entity
    # annotations with document-level character offsets.
    data_path = os.path.join(input_path, 'corpus_pubtator.txt')
    train_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_train.txt')
    dev_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_dev.txt')
    test_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_test.txt')

    nlp = spacy.load("en_core_web_sm")
    sentences = []
    with open(data_path, 'r') as f:
        data = f.readlines()
    ids = []
    titles = []
    abstracts = []
    entities_title = []
    entities_abstract = []
    ends_title = []
    ends_abstract = []
    starts_title = []
    starts_abstract = []
    id_tmp = None
    starts_title_tmp = []
    starts_abstract_tmp = []
    ends_title_tmp = []
    ends_abstract_tmp = []
    entities_title_tmp = []
    entities_abstract_tmp = []
    for line in data:
        if '|t|' in line:
            id_tmp = line.split("|t|")[0]
            ids.append(id_tmp)
            titles.append(line.split("|t|")[1])
        elif '|a|' in line:
            abstracts.append(line.split('|a|')[1])
        elif id_tmp in line:
            _, start_idx, end_idx, ent, label, _ = line.split('\t')
            # Skip annotations that overlap the previously accepted one.
            if (len(ends_title_tmp) > 0 and int(end_idx) <= ends_title_tmp[-1]) or \
                    (len(ends_abstract_tmp) > 0 and int(end_idx) - len(titles[-1]) <= ends_abstract_tmp[-1]):
                continue
            if int(end_idx) < len(titles[-1]):
                starts_title_tmp.append(int(start_idx))
                ends_title_tmp.append(int(end_idx))
                entities_title_tmp.append((ent, label))
            else:
                # Abstract offsets are shifted to be relative to the abstract text.
                starts_abstract_tmp.append(int(start_idx) - len(titles[-1]))
                ends_abstract_tmp.append(int(end_idx) - len(titles[-1]))
                entities_abstract_tmp.append((ent, label))
        else:
            # Blank separator line: flush the accumulated document.
            starts_title.append(starts_title_tmp)
            ends_title.append(ends_title_tmp)
            entities_title.append(entities_title_tmp)
            starts_abstract.append(starts_abstract_tmp)
            ends_abstract.append(ends_abstract_tmp)
            entities_abstract.append(entities_abstract_tmp)
            starts_title_tmp = []
            starts_abstract_tmp = []
            ends_title_tmp = []
            ends_abstract_tmp = []
            entities_title_tmp = []
            entities_abstract_tmp = []
            id_tmp = None

    count = 0
    for id_, title, entities_title_tmp, start_title, end_title, abstract, entities_abstract_tmp, \
            start_abstract, end_abstract in tqdm(zip(ids, titles, entities_title, starts_title, ends_title,
                                                     abstracts, entities_abstract, starts_abstract, ends_abstract)):
        sentences_title, count_t = convert_to_iob(id_, title, entities_title_tmp, nlp,
                                                  end_title, start_title)
        sentences_abstract, count_a = convert_to_iob(id_, abstract, entities_abstract_tmp, nlp,
                                                     end_abstract, start_abstract)
        count += count_a + count_t
        sentences += sentences_title
        sentences += sentences_abstract

    with open(train_id_path, "r") as f:
        train_ids = [line.strip() for line in f.readlines()]
    train_sentences = [sent for id_, sent in sentences if id_ in train_ids]
    with open(dev_id_path, "r") as f:
        dev_ids = [line.strip() for line in f.readlines()]
    dev_sentences = [sent for id_, sent in sentences if id_ in dev_ids]
    with open(test_id_path, "r") as f:
        test_ids = [line.strip() for line in f.readlines()]
    test_sentences = [sent for id_, sent in sentences if id_ in test_ids]

    return train_sentences, dev_sentences, test_sentences
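An illustrative call with a toy document (our own example, not from the corpus): one "Chemical" (T103) entity spanning characters 0-7. Note the three entity lists are consumed destructively via `pop(0)`:

import spacy

nlp = spacy.load("en_core_web_sm")
sents, n_found = convert_to_iob(
    id_="doc0",
    text="Aspirin reduces fever.",
    entities=[("Aspirin", "T103")],
    nlp=nlp,
    end_index=[7],
    start_index=[0],
)
doc_id, tokens = sents[0]
print([(t.word, t.label) for t in tokens])
# [('Aspirin', 'B-Chemical'), ('reduces', 'O'), ('fever', 'O'), ('.', 'O')]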
22 changes: 20 additions & 2 deletions zshot/evaluation/evaluator.py
@@ -1,9 +1,9 @@
from typing import Dict, List, Union

from datasets import Dataset
-from evaluate import TokenClassificationEvaluator
+from evaluate import (Evaluator, TokenClassificationEvaluator)

-from zshot.utils.alignment_utils import filter_overlapping_spans, AlignmentMode
+from zshot.utils.alignment_utils import AlignmentMode, filter_overlapping_spans
from zshot.utils.data_models import Span


@@ -56,3 +56,21 @@ def prepare_data(self, data: Union[str, Dataset], input_column: str, label_colum
                                      for sent in metric_inputs['references']]

        return metric_inputs, pipeline_inputs


class RelationExtractorEvaluator(Evaluator):
    def __init__(self, task="relation-extraction", default_metric_name=None):
        super().__init__(task, default_metric_name)

    def predictions_processor(self, predictions: List[List[Dict]], sentences: List[List[str]]):
        return {"predictions": predictions}

    def prepare_pipeline(
        self,
        model_or_pipeline,  # noqa: F821
        tokenizer=None,  # noqa: F821
        feature_extractor=None,  # noqa: F821
        device: int = None,
    ):
        # Tokenizer, feature extractor and device are accepted only for
        # interface compatibility; the zshot pipeline manages its own models.
        return super().prepare_pipeline(model_or_pipeline)
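A hedged sketch of driving this evaluator end to end, assuming the usual `evaluate.Evaluator.compute` entry point accepts an already-built callable pipeline; `nlp` stands for a zshot pipeline with a relations extractor and `dataset` for a prepared relation-extraction dataset, both outside this excerpt:

from zshot.evaluation.evaluator import RelationExtractorEvaluator
from zshot.evaluation.metrics.rel_eval import RelEval

evaluator = RelationExtractorEvaluator()
results = evaluator.compute(
    model_or_pipeline=nlp,  # assumed: a callable pipeline is passed through
    data=dataset,           # assumed: prepared elsewhere in this commit
    metric=RelEval(),
)
print(results)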
52 changes: 52 additions & 0 deletions zshot/evaluation/metrics/rel_eval.py
@@ -0,0 +1,52 @@
import datasets
import evaluate
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

_KWARGS_DESCRIPTION = """
Produces labelling scores, along with their sufficient statistics,
for a set of predictions against references.
Args:
    predictions: List of predicted relation labels (estimated targets as returned by a relation extractor)
    references: List of reference relation labels (ground-truth (correct) target values)
"""


class RelEval(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description="RelEval is a metric for evaluating relation extraction methods.",
            inputs_description=_KWARGS_DESCRIPTION,
            citation="alp@ibm.com",
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="label"),
                    "references": datasets.Value("string", id="label"),
                }
            ),
        )

    def _compute(
        self,
        predictions,
        references,
    ):
        scores = {}
        p, r, f1, _ = precision_recall_fscore_support(
            references, predictions, average="micro"
        )
        scores["overall_precision_micro"] = p
        scores["overall_recall_micro"] = r
        scores["overall_f1_micro"] = f1

        p, r, f1, _ = precision_recall_fscore_support(
            references, predictions, average="macro"
        )
        scores["overall_precision_macro"] = p
        scores["overall_recall_macro"] = r
        scores["overall_f1_macro"] = f1

        # Fraction of predictions that exactly match the reference label.
        acc = accuracy_score(references, predictions)
        scores["overall_accuracy"] = acc

        return scores
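Direct use of the metric follows from the code above (labels are plain strings, per the `Features` spec):

metric = RelEval()
scores = metric.compute(
    predictions=["founded by", "located in"],
    references=["founded by", "capital of"],
)
print(scores["overall_f1_micro"])  # 0.5: one of two labels matches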
1 change: 0 additions & 1 deletion zshot/evaluation/metrics/seqeval/seqeval.py
@@ -38,7 +38,6 @@
IOE1
IOE2
IOBES
-See the [README.md] file at https://github.com/chakki-works/seqeval for more information.
"""

_KWARGS_DESCRIPTION = """
(Diffs for the remaining changed files are not shown in this excerpt.)
