
Commit

✅ Added datasets tests. Added Mentions extractor pipeline and evaluator tests

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>
marmg committed Sep 28, 2022
1 parent af73b3d commit ca3ad9a
Showing 2 changed files with 124 additions and 9 deletions.
21 changes: 21 additions & 0 deletions zshot/tests/evaluation/test_datasets.py
@@ -0,0 +1,21 @@
from zshot.evaluation import load_ontonotes, load_medmentions


def test_ontonotes():
dataset = load_ontonotes()
assert 'train' in dataset
assert 'test' in dataset
assert 'validation' in dataset
assert dataset['train'].num_rows == 41475
assert dataset['test'].num_rows == 426
assert dataset['validation'].num_rows == 1358


def test_medmentions():
dataset = load_medmentions()
assert 'train' in dataset
assert 'test' in dataset
assert 'validation' in dataset
assert dataset['train'].num_rows == 30923
assert dataset['test'].num_rows == 10304
assert dataset['validation'].num_rows == 10171
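
As a quick reference, a minimal sketch (outside pytest, assuming only the zshot.evaluation loaders exercised above) that prints the split sizes these tests assert on:

# Minimal sketch: load the evaluation datasets and inspect the splits the
# tests above check. Uses only the loaders imported in test_datasets.py.
from zshot.evaluation import load_medmentions, load_ontonotes

for name, loader in [("ontonotes", load_ontonotes), ("medmentions", load_medmentions)]:
    dataset = loader()
    for split in ("train", "validation", "test"):
        print(name, split, dataset[split].num_rows)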
112 changes: 103 additions & 9 deletions zshot/tests/evaluation/test_evaluation.py
@@ -4,9 +4,9 @@
from datasets import Dataset
from spacy.tokens import Doc

from zshot import PipelineConfig, Linker
from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator
from zshot.evaluation.pipeline import LinkerPipeline
from zshot import PipelineConfig, Linker, MentionsExtractor
from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
from zshot.utils.alignment_utils import AlignmentMode
from zshot.utils.data_models import Entity, Span

@@ -34,7 +34,26 @@ def predict(self, docs: Iterator[Doc], batch_size=None):
return sentences


def get_pipe(predictions: List[Tuple[str, str, float]]):
class DummyMentionsExtractor(MentionsExtractor):

def __init__(self, predictions: List[Tuple[str, str, float]]):
super().__init__()
self.predictions = predictions

def predict(self, docs: Iterator[Doc], batch_size=None):
sentences = []
for doc in docs:
preds = []
for span, label, score in self.predictions:
if span in doc.text:
preds.append(
Span(doc.text.find(span), doc.text.find(span) + len(span), label="MENTION", score=score))
sentences.append(preds)

return sentences


def get_linker_pipe(predictions: List[Tuple[str, str, float]]):
nlp = spacy.blank("en")
nlp_config = PipelineConfig(
linker=DummyLinker(predictions),
@@ -46,6 +65,18 @@ def get_pipe(predictions: List[Tuple[str, str, float]]):
return LinkerPipeline(nlp)


def get_mentions_extractor_pipe(predictions: List[Tuple[str, str, float]]):
nlp = spacy.blank("en")
nlp_config = PipelineConfig(
mentions_extractor=DummyMentionsExtractor(predictions),
entities=ENTITIES
)

nlp.add_pipe("zshot", config=nlp_config, last=True)

return MentionsExtractorPipeline(nlp)


def get_spans_predictions(span: str, label: str, sentence: str):
return [{'start': sentence.find(span),
'end': sentence.find(span) + len(span),
@@ -82,7 +113,7 @@ def test_prediction_token_based_evaluation_all_matching(self):
dataset = get_dataset(gt, sentences)

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
metrics = custom_evaluator.compute(get_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
@@ -96,7 +127,8 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
dataset = get_dataset(gt, sentences)

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
metrics = custom_evaluator.compute(get_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset, "seqeval")
metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset,
"seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
@@ -111,7 +143,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.expand)
pipe = get_pipe([('New Yo', 'FAC', 1)])
pipe = get_linker_pipe([('New Yo', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
@@ -127,7 +159,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_contract(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.contract)
pipe = get_pipe([('New York i', 'FAC', 1)])
pipe = get_linker_pipe([('New York i', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
@@ -143,7 +175,69 @@ def test_prediction_token_based_evaluation_partial_and_overlapping_spans(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.contract)
pipe = get_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
pipe = get_linker_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_f1"]) == 1.0
assert float(metrics["overall_accuracy"]) == 1.0


class TestMentionsExtractorEvaluator:

def test_prepare_data(self):
sentences = ['New York is beautiful']
gt = [['B-FAC', 'I-FAC', 'O', 'O']]
processed_gt = [['B-MENTION', 'I-MENTION', 'O', 'O']]

dataset = get_dataset(gt, sentences)

custom_evaluator = MentionsExtractorEvaluator("token-classification")

preds = custom_evaluator.prepare_data(dataset,
input_column="tokens", label_column="ner_tags",
join_by=" ")
assert preds[0]['references'] == processed_gt

def test_prediction_token_based_evaluation_all_matching(self):
sentences = ['New York is beautiful']
gt = [['B-FAC', 'I-FAC', 'O', 'O']]

dataset = get_dataset(gt, sentences)

custom_evaluator = MentionsExtractorEvaluator("token-classification")
metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_f1"]) == 1.0
assert float(metrics["overall_accuracy"]) == 1.0

def test_prediction_token_based_evaluation_overlapping_spans(self):
sentences = ['New York is beautiful']
gt = [['B-FAC', 'I-FAC', 'O', 'O']]

dataset = get_dataset(gt, sentences)

custom_evaluator = MentionsExtractorEvaluator("token-classification")
metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]),
dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_f1"]) == 1.0
assert float(metrics["overall_accuracy"]) == 1.0

def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
sentences = ['New York is beautiful']
gt = [['B-FAC', 'I-FAC', 'O', 'O']]

dataset = get_dataset(gt, sentences)

custom_evaluator = MentionsExtractorEvaluator("token-classification",
alignment_mode=AlignmentMode.expand)
pipe = get_mentions_extractor_pipe([('New Yo', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
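
For context, a minimal end-to-end sketch (not part of the commit) of driving the new MentionsExtractorEvaluator. It reuses the get_mentions_extractor_pipe helper added in this diff; the 'tokens'/'ner_tags' column layout mirrors the prepare_data call and the collapsed get_dataset helper, so treat that layout as an assumption:

# Minimal sketch: run MentionsExtractorEvaluator end to end on a toy dataset.
# get_mentions_extractor_pipe is the helper defined in test_evaluation.py above.
from datasets import Dataset

from zshot.evaluation.evaluator import MentionsExtractorEvaluator

# Toy token-level dataset; the column names follow the prepare_data call above
# (input_column="tokens", label_column="ner_tags") and are an assumption here.
dataset = Dataset.from_dict({
    "tokens": [["New", "York", "is", "beautiful"]],
    "ner_tags": [["B-FAC", "I-FAC", "O", "O"]],
})

pipe = get_mentions_extractor_pipe([("New York", "FAC", 1)])
evaluator = MentionsExtractorEvaluator("token-classification")
metrics = evaluator.compute(pipe, dataset, "seqeval")
print(metrics["overall_f1"], metrics["overall_accuracy"])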
