-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add zero shot relations extraction (#27)
* ✨ Add zero shot relations extraction * 🎨 Code refactor * 🐛 Fix tests * 🐛 Fix tests * 🐛 Fix tests * ✨ Fix displacy inverse rel * 🎨 Fix codestyle * 🎨 Increase coverage * 🎨 Fix code style * 🎨 Increase coverage * 🎨 Fix code style * 🐛 Fix test * performance improvements * perf improvement Signed-off-by: Gabriele Picco <piccogabriele@gmail.com> Co-authored-by: Alberto Purpura <alp@ibm.com>
- Loading branch information
1 parent
d80d1f6
commit bea91c7
Showing
38 changed files
with
1,847 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig # noqa: F401 | ||
from zshot.utils.displacy import displacy # noqa: F401 | ||
from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig, RelationsExtractor # noqa: F401 | ||
from zshot.utils.displacy.displacy import displacy # noqa: F401 | ||
|
||
__version__ = '0.0.4' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from datasets import load_dataset | ||
from tqdm import tqdm | ||
|
||
from zshot import Linker | ||
from zshot.utils.data_models import Span | ||
|
||
|
||
class DummyLinkerEnd2End(Linker):
    """Dummy end-to-end linker that wraps pre-computed entity dicts as Spans.

    Used for evaluation: the "prediction" is just the gold entity data passed
    in, converted to the Span objects the pipeline expects.
    """

    @property
    def is_end2end(self) -> bool:
        # Produces spans directly; no separate mentions extractor is needed.
        return True

    def predict(self, data):
        """Convert each document's entity dicts into ``Span`` objects.

        :param data: iterable of documents, each an iterable of dicts with
            "start", "end" and "label" keys.
        :return: one list of Spans per document.
        """
        # Bug fix: the original iterated ``enumerate(data)``, so each
        # ``doc_ents`` was an ``(index, doc)`` tuple and the inner loop walked
        # that tuple instead of the document's entities, making
        # ``item["start"]`` fail on the integer index.
        return [
            [
                Span(
                    item["start"],
                    item["end"],
                    item["label"],
                )
                for item in doc_ents
            ]
            for doc_ents in data
        ]
|
||
|
||
def get_entity_data(e, tokenized_sentence):
    """Compute character offsets of entity *e* in the whitespace-joined sentence.

    :param e: entity dict with a "type" key and an "indices" key whose first
        element is the (contiguous) list of token indices the entity covers.
    :param tokenized_sentence: list of token strings.
    :return: dict with "start", "end", "label" and "sentence" keys, where
        "sentence" is the tokens joined by single spaces.
    """
    # Cleanup: removed commented-out pdb debugging left in the original.
    data = {"start": None, "end": None, "label": e["type"]}
    token_indices = e["indices"][0]
    joined = ""
    char_offset = 0
    for idx, token in enumerate(tokenized_sentence):
        if idx == token_indices[0]:
            # NOTE(review): this offset is taken *before* the token is appended
            # and is computed on the stripped string, so for any token after the
            # first it points at the preceding space rather than the token's
            # first character. Kept as-is to preserve evaluation behavior —
            # confirm against the downstream consumer before "fixing".
            data["start"] = char_offset
        joined += token + " "
        char_offset = len(joined.strip())
        if idx == token_indices[-1]:
            data["end"] = char_offset
    data["sentence"] = joined.strip()
    return data
|
||
|
||
def get_few_rel_data(split_name="val_wiki", limit=-1):
    """Load a FewRel split and prepare it for relation-extraction evaluation.

    :param split_name: name of the FewRel split to load (e.g. "val_wiki").
    :param limit: maximum number of examples to keep; -1 keeps all of them.
    :return: tuple ``(entities_data, sentences, relations_descriptions, gt)``:
        per-example [head, tail] entity dicts (see :func:`get_entity_data`),
        whitespace-joined sentences, the per-example relation name/description
        pairs, and the ground-truth relation name per example.
    """
    wiki_val = load_dataset("few_rel", split=split_name)
    # Bug fix: the original sliced with [:limit] unconditionally, so the
    # default limit=-1 meant [:-1] and silently dropped the *last* example
    # instead of keeping everything. Map -1 to an open-ended slice.
    end = None if limit == -1 else limit
    relations_descriptions = wiki_val["names"][:end]
    tokenized_sentences = wiki_val["tokens"][:end]
    sentences = [" ".join(tokens) for tokens in tokenized_sentences]

    # Ground truth is the canonical relation name (first element of each pair).
    gt = [item[0] for item in relations_descriptions]
    # heads/tails are intentionally not sliced: the loop below is bounded by
    # len(tokenized_sentences), which already respects the limit.
    heads = wiki_val["head"]
    tails = wiki_val["tail"]
    entities_data = []
    for idx in tqdm(range(len(tokenized_sentences))):
        head_entity = heads[idx]
        tail_entity = tails[idx]
        entities_data.append(
            [
                get_entity_data(head_entity, tokenized_sentences[idx]),
                get_entity_data(tail_entity, tokenized_sentences[idx]),
            ]
        )

    return entities_data, sentences, relations_descriptions, gt


if __name__ == "__main__":
    get_few_rel_data()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from zshot.utils.data_models import Entity | ||
|
||
# Mapping from UMLS semantic-type ids (e.g. "T058") to the underscore-joined
# entity names used by the split definitions below. "NEG" maps to itself and
# acts as the negative marker.
# NOTE(review): "Anotomical_Structure" is misspelled (should be "Anatomical"),
# but it is used consistently here and in MEDMENTIONS_SPLITS — renaming it
# would break the cross-references, so it is left as-is.
MEDMENTIONS_TYPE_INV = {
    'T058': "Health_Care_Activity", "T062": "Research_Activity", "T037": "Injury_or_Poisoning",
    "T038": "Biologic_Function", "T005": "Virus", "T007": "Bacterium", "T204": "Eukaryote",
    "T017": "Anotomical_Structure", "T074": "Medical_Device", "T031": "Body_Substance", "T103": "Chemical",
    "T168": "Food", "T201": "Clinical_Attribute", "T033": "Finding", "T082": "Spatial_Concept",
    "T022": "Body_System", "T091": "Biomedical_Occupation_or_Discipline", "T092": "Organization",
    "T097": "Professional_or_Occupational_Group", "T098": "Population_Group", "T170": "Intellectual_Product",
    "NEG": "NEG"
}

# Zero-shot split of the entity types: the classes seen at train time are
# disjoint from those used for validation and test.
MEDMENTIONS_SPLITS = {
    "train": ['Biologic_Function', 'Chemical', 'Health_Care_Activity', 'Anotomical_Structure', "Finding",
              "Spatial_Concept", "Intellectual_Product", "Research_Activity", 'Medical_Device', 'Eukaryote',
              'Population_Group'],
    "validation": ['Biomedical_Occupation_or_Discipline', 'Virus', 'Clinical_Attribute', 'Injury_or_Poisoning',
                   'Organization'],
    "test": ['Body_System', 'Food', 'Body_Substance', 'Bacterium', 'Professional_or_Occupational_Group']}

# One Entity per UMLS semantic type: name is the type id, description the
# human-readable label fed to the zero-shot models.
MEDMENTIONS_ENTITIES = [
    Entity(name="T058", description="Healthcare activity"),
    Entity(name="T062", description="Research activity"),
    Entity(name="T037", description="Injury or Poisoning"),
    Entity(name="T038", description="Biologic Function"),
    Entity(name="T005", description="Virus"),
    Entity(name="T007", description="Bacterium"),
    Entity(name="T204", description="Eukaryote"),
    Entity(name="T017", description="Anotomical structure"),
    Entity(name="T074", description="Medical device"),
    Entity(name="T031", description="Body substance"),
    Entity(name="T103", description="Chemical"),
    Entity(name="T168", description="Food"),
    Entity(name="T201", description="Clinical Attribute"),
    Entity(name="T033", description="Finding"),
    Entity(name="T082", description="Spatial Concept"),
    Entity(name="T022", description="Body System"),
    Entity(name="T091", description="Biomedical occupation or discipline"),
    Entity(name="T092", description="Organization"),
    Entity(name="T097", description="Professional or occupational group"),
    Entity(name="T098", description="Population group"),
    Entity(name="T170", description="Intellectual product")
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os | ||
|
||
import spacy | ||
from tqdm import tqdm | ||
|
||
from zshot.evaluation.dataset.med_mentions.entities import MEDMENTIONS_TYPE_INV | ||
|
||
|
||
class Token:
    """Container for a single token of an IOB-tagged sentence.

    Holds the token's surface form together with its IOB tag in two flavors:
    a readable one and one based on the raw type id.
    """

    def __init__(self, word, label, label_id):
        # Surface text of the token.
        self.word = word
        # IOB tag built from the readable type name (e.g. "B-Chemical").
        self.label = label
        # IOB tag built from the raw type id (e.g. "B-T103").
        self.label_id = label_id
|
||
|
||
def convert_to_iob(id_, text, entities, nlp, end_index, start_index):
    """Tokenize *text* with spaCy and tag its tokens with IOB labels.

    :param id_: document id attached to every produced sentence.
    :param text: raw document text (title or abstract).
    :param entities: list of (mention, type_id) tuples, in document order.
        CONSUMED IN PLACE: entries are popped as their span is passed.
    :param nlp: loaded spaCy pipeline used for sentence/token segmentation.
    :param end_index: entity end character offsets; popped in lockstep with
        ``entities``.
    :param start_index: entity start character offsets; popped in lockstep.
    :return: tuple ``(sentences, count)`` where sentences is a list of
        ``(id_, [Token, ...])`` pairs and count is the number of B- tags
        emitted (entities actually matched to a token start).
    """
    count = 0
    sentences = []
    current_sent_pos = 0
    doc = nlp(text)
    for sent in doc.sents:
        sentence = []
        for tok in sent:
            word = tok.text
            # Only the first pending entity is ever considered; "O" when done.
            ent = entities[0][1] if entities else "O"
            if entities and tok.idx + current_sent_pos == start_index[0]:
                # Token starts exactly at the entity start -> B- tag.
                count += 1
                label = 'B-' + MEDMENTIONS_TYPE_INV[ent]
                label_id = 'B-' + ent
                token = Token(word=word, label=label, label_id=label_id)
            elif entities and start_index[0] < tok.idx + current_sent_pos <= end_index[0]:
                # Token begins strictly inside the entity span -> I- tag.
                label = 'I-' + MEDMENTIONS_TYPE_INV[ent]
                label_id = 'I-' + ent
                token = Token(word=word, label=label, label_id=label_id)
            else:
                token = Token(word=word, label='O', label_id="O")
            if entities and len(word) + tok.idx + current_sent_pos >= end_index[0]:
                # Entity fully covered by this token -> advance to the next one.
                start_index.pop(0)
                end_index.pop(0)
                entities.pop(0)
            sentence.append(token)
        # NOTE(review): spaCy's tok.idx is already a character offset within
        # the whole doc, while len(sent) is a *token* count — adding them
        # mixes units, so spans beyond the first sentence may be shifted.
        # Left untouched (behavior-sensitive); confirm against the corpus.
        current_sent_pos += len(sent) + 1
        sentences.append((id_, sentence))
    return sentences, count
|
||
|
||
def preprocess_medmentions(input_path):
    """Parse the MedMentions PubTator corpus into IOB-tagged sentences.

    Reads ``corpus_pubtator.txt`` plus the official train/dev/test pmid list
    files from *input_path*, converts every title and abstract to IOB-tagged
    sentences via :func:`convert_to_iob`, and partitions them by split.

    :param input_path: directory containing the MedMentions release files.
    :return: tuple ``(train_sentences, dev_sentences, test_sentences)``, each
        a list of ``[Token, ...]`` sentences.
    """
    data_path = os.path.join(input_path, 'corpus_pubtator.txt')
    train_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_train.txt')
    dev_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_dev.txt')
    test_id_path = os.path.join(input_path, 'corpus_pubtator_pmids_test.txt')

    nlp = spacy.load("en_core_web_sm")
    sentences = []
    with open(data_path, 'r') as f:
        data = f.readlines()
    # Per-document accumulators; the *_tmp lists collect the current document
    # and are flushed into the plural lists on the blank separator line.
    ids = []
    titles = []
    abstracts = []
    entities_title = []
    entities_abstract = []
    ends_title = []
    ends_abstract = []
    starts_title = []
    starts_abstract = []
    id_tmp = None
    starts_title_tmp = []
    starts_abstract_tmp = []
    ends_title_tmp = []
    ends_abstract_tmp = []
    entities_title_tmp = []
    entities_abstract_tmp = []
    for line in data:
        if '|t|' in line:
            # Title line: "pmid|t|Title text".
            id_tmp = line.split("|t|")[0]
            ids.append(id_tmp)
            titles.append(line.split("|t|")[1])
        elif '|a|' in line:
            # Abstract line: "pmid|a|Abstract text".
            abstracts.append(line.split('|a|')[1])
        elif id_tmp in line:
            # Tab-separated annotation line for the current pmid.
            # NOTE(review): per the PubTator format this looks like
            # pmid \t start \t end \t mention \t type \t CUI — so ``ent`` is
            # the mention text and ``label`` the semantic type id; confirm.
            _, start_idx, end_idx, ent, label, _ = line.split('\t')
            # Skip annotations nested inside / overlapping the previous one.
            if len(ends_title_tmp) > 0 and int(end_idx) <= ends_title_tmp[-1] or \
                    len(ends_abstract_tmp) > 0 and int(end_idx) - len(titles[-1]) <= ends_abstract_tmp[-1]:
                continue
            if int(end_idx) < len(titles[-1]):
                # Annotation falls inside the title.
                starts_title_tmp.append(int(start_idx))
                ends_title_tmp.append(int(end_idx))
                entities_title_tmp.append((ent, label))
            else:
                # Annotation falls inside the abstract; offsets are rebased
                # so they are relative to the abstract text.
                starts_abstract_tmp.append(int(start_idx) - len(titles[-1]))
                ends_abstract_tmp.append(int(end_idx) - len(titles[-1]))
                entities_abstract_tmp.append((ent, label))
        else:
            # Blank separator between documents: flush and reset accumulators.
            starts_title.append(starts_title_tmp)
            ends_title.append(ends_title_tmp)
            entities_title.append(entities_title_tmp)
            starts_abstract.append(starts_abstract_tmp)
            ends_abstract.append(ends_abstract_tmp)
            entities_abstract.append(entities_abstract_tmp)
            starts_title_tmp = []
            starts_abstract_tmp = []
            ends_title_tmp = []
            ends_abstract_tmp = []
            entities_title_tmp = []
            entities_abstract_tmp = []
            id_tmp = None

    # Convert every title and abstract to IOB sentences tagged with its pmid.
    count = 0
    for id_, title, entities_title_tmp, start_title, end_title, abstract, entities_abstract_tmp, \
            start_abstract, end_abstract in tqdm(zip(ids, titles, entities_title, starts_title, ends_title,
                                                     abstracts, entities_abstract, starts_abstract, ends_abstract)):
        sentences_title, count_t = convert_to_iob(id_, title, entities_title_tmp, nlp,
                                                  end_title, start_title)
        sentences_abstract, count_a = convert_to_iob(id_, abstract, entities_abstract_tmp, nlp,
                                                     end_abstract, start_abstract)
        count += count_a + count_t
        sentences += sentences_title
        sentences += sentences_abstract

    # Partition sentences by the official pmid split files.
    with open(train_id_path, "r") as f:
        train_ids = [line.strip() for line in f.readlines()]
        train_sentences = [sent for id_, sent in sentences if id_ in train_ids]
    with open(dev_id_path, "r") as f:
        dev_ids = [line.strip() for line in f.readlines()]
        dev_sentences = [sent for id_, sent in sentences if id_ in dev_ids]
    with open(test_id_path, "r") as f:
        test_ids = [line.strip() for line in f.readlines()]
        test_sentences = [sent for id_, sent in sentences if id_ in test_ids]

    return train_sentences, dev_sentences, test_sentences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import evaluate | ||
from sklearn.metrics import precision_recall_fscore_support | ||
from sklearn.metrics import accuracy_score | ||
import datasets | ||
|
||
# Usage text surfaced by the `evaluate` library for this metric.
_KWARGS_DESCRIPTION = """
Produces labelling scores along with its sufficient statistics
from a source against one or more references.
Args:
    predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
    references: List of List of reference labels (Ground truth (correct) target values)
"""


class RelEval(evaluate.Metric):
    """`evaluate` metric for relation extraction: micro/macro P, R, F1 and accuracy."""

    def _info(self):
        """Return the metric metadata (description, citation, input features)."""
        return evaluate.MetricInfo(
            description="RelEval is a framework for relation extraction methods evaluation.",
            inputs_description=_KWARGS_DESCRIPTION,
            citation="alp@ibm.com",
            features=datasets.Features(
                {
                    # Both inputs are flat sequences of string labels.
                    "predictions": datasets.Value("string", id="label"),
                    "references": datasets.Value("string", id="label"),
                }
            ),
        )

    def _compute(
        self,
        predictions,
        references,
    ):
        """Compute micro- and macro-averaged precision/recall/F1 plus accuracy.

        :param predictions: predicted relation labels.
        :param references: ground-truth relation labels.
        :return: dict of scores keyed by metric name.
        """
        scores = {}
        p, r, f1, _ = precision_recall_fscore_support(
            references, predictions, average="micro"
        )
        scores["overall_precision_micro"] = p
        scores["overall_recall_micro"] = r
        scores["overall_f1_micro"] = f1

        p, r, f1, _ = precision_recall_fscore_support(
            references, predictions, average="macro"
        )
        scores["overall_precision_macro"] = p
        scores["overall_recall_macro"] = r
        scores["overall_f1_macro"] = f1

        # NOTE(review): normalize=False makes this the raw *count* of correct
        # predictions, not an accuracy ratio in [0, 1] — confirm callers
        # expect a count under the key "overall_accuracy".
        acc = accuracy_score(references, predictions, normalize=False)
        scores["overall_accuracy"] = acc

        return scores
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.