Commit

⚡ Enhance evaluation flow & Add time to evaluation (#36)
* ♻️ Enhance evaluation flow

Enhance and simplify the evaluation flow

* ♻️ Codestyle fix

* 🐛 Fix dataset loading

* 🧑‍💻 Improve coverage

* 🐛 Fix tests
GabrielePicco authored Nov 18, 2022
1 parent e195caa commit 2a84de0
Showing 11 changed files with 163 additions and 165 deletions.
3 changes: 2 additions & 1 deletion requirements/test.txt
@@ -3,4 +3,5 @@ pytest-cov>=3.0.0
 setuptools~=60.0.0
 flair==0.11.3
 flake8>=4.0.1
-coverage>=6.4.1
+coverage>=6.4.1
+IPython
4 changes: 2 additions & 2 deletions zshot/evaluation/__init__.py
@@ -1,2 +1,2 @@
-from zshot.evaluation.dataset import load_medmentions  # noqa: F401
-from zshot.evaluation.dataset import load_ontonotes  # noqa: F401
+from zshot.evaluation.dataset import load_medmentions_zs  # noqa: F401
+from zshot.evaluation.dataset import load_ontonotes_zs  # noqa: F401
4 changes: 2 additions & 2 deletions zshot/evaluation/dataset/__init__.py
@@ -1,2 +1,2 @@
-from zshot.evaluation.dataset.med_mentions.med_mentions import load_medmentions  # noqa: F401
-from zshot.evaluation.dataset.ontonotes.onto_notes import load_ontonotes  # noqa: F401
+from zshot.evaluation.dataset.med_mentions.med_mentions import load_medmentions_zs  # noqa: F401
+from zshot.evaluation.dataset.ontonotes.onto_notes import load_ontonotes_zs  # noqa: F401
16 changes: 10 additions & 6 deletions zshot/evaluation/dataset/dataset.py
@@ -1,16 +1,20 @@
 from typing import List, Optional, Dict
 
 import datasets
+from datasets import Split, Dataset
+from datasets.table import Table
+
+from zshot.utils.data_models import Entity
 
 
 class DatasetWithEntities(datasets.Dataset):
     # TODO: Implement save and load methods
-    def __init__(self, arrow_table: datasets.table.Table,
+
+    def __init__(self, arrow_table: Table,
                  info: Optional[datasets.info.DatasetInfo] = None,
                  split: Optional[datasets.splits.NamedSplit] = None,
-                 indices_table: Optional[datasets.table.Table] = None,
+                 indices_table: Optional[Table] = None,
                  fingerprint: Optional[str] = None,
-                 entities: List[Dict[str, str]] = None):
+                 entities: List[Entity] = None):
        super().__init__(arrow_table=arrow_table, info=info, split=split,
                         indices_table=indices_table, fingerprint=fingerprint)
        self.entities = entities
@@ -21,9 +25,9 @@ def from_dict(
             mapping: dict,
             features: Optional[datasets.features.Features] = None,
             info: Optional[datasets.info.DatasetInfo] = None,
-            split: Optional[datasets.splits.NamedSplit] = None,
+            split: Optional[Split] = None,
             entities: List[Dict[str, str]] = None
-    ) -> "DatasetWithEntities":
+    ) -> Dataset:
         dataset = super().from_dict(mapping=mapping, features=features, info=info, split=split)
         dataset.entities = entities
         return dataset
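Usage note (not part of the commit): with the updated signature, from_dict can attach Entity objects directly to a small in-memory dataset. Only the call shape comes from this diff; the tokens, tags and descriptions below are made-up placeholders.

from datasets import Split

from zshot.evaluation.dataset.dataset import DatasetWithEntities
from zshot.utils.data_models import Entity

# Toy one-sentence dataset; labels and descriptions are illustrative only.
toy = DatasetWithEntities.from_dict(
    {
        "tokens": [["IBM", "is", "based", "in", "Armonk"]],
        "ner_tags": [["B-ORG", "O", "O", "O", "B-LOC"]],
    },
    split=Split.TEST,
    entities=[
        Entity(name="ORG", description="Companies, agencies and institutions"),
        Entity(name="LOC", description="Geographical locations"),
    ],
)
print(toy.entities)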
28 changes: 20 additions & 8 deletions zshot/evaluation/dataset/med_mentions/med_mentions.py
@@ -1,7 +1,8 @@
 import json
-from typing import Dict, Union
+import re
+from typing import Union, Optional
 
-from datasets import load_dataset, Split
+from datasets import load_dataset, Split, Dataset, DatasetDict
 from huggingface_hub import hf_hub_download
 
 from zshot.evaluation.dataset.dataset import DatasetWithEntities
@@ -11,15 +12,26 @@
 ENTITIES_FN = "entities.json"
 
 
-def load_medmentions() -> Dict[Union[str, Split], DatasetWithEntities]:
-    dataset = load_dataset(REPO_ID)
-    entities_file = hf_hub_download(repo_id=REPO_ID, repo_type='dataset',
+def load_medmentions_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Union[DatasetDict, Dataset]:
+    dataset = load_dataset(REPO_ID, split=split, **kwargs)
+    entities_file = hf_hub_download(repo_id=REPO_ID,
+                                    repo_type='dataset',
                                     filename=ENTITIES_FN)
     with open(entities_file, "r") as f:
         entities = json.load(f)
 
-    for split in dataset:
-        entities_split = [Entity(name=k, description=v) for k, v in entities[split].items()]
-        dataset[split] = DatasetWithEntities(dataset[split].data, entities=entities_split)
+    if split:
+        entities_split = [Entity(name=k, description=v) for k, v in entities[get_simple_split(split)].items()]
+        dataset = DatasetWithEntities(dataset.data, entities=entities_split)
+    else:
+        for split in dataset:
+            entities_split = [Entity(name=k, description=v) for k, v in entities[split].items()]
+            dataset[split] = DatasetWithEntities(dataset[split].data, entities=entities_split)
 
     return dataset
 
 
+def get_simple_split(split: str) -> str:
+    first_not_alph = re.search(r'\W+', split)
+    first_not_alph_chr = first_not_alph.start() if first_not_alph else len(split)
+    return split[: first_not_alph_chr]
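A short usage sketch for the new loader (not part of the commit; the split name is an assumption about the dataset behind REPO_ID):

from zshot.evaluation import load_medmentions_zs
from zshot.evaluation.dataset.med_mentions.med_mentions import get_simple_split

# Loading a single split returns one DatasetWithEntities; without a split
# argument every available split is wrapped and returned instead.
test_set = load_medmentions_zs(split="test")
print(len(test_set), [e.name for e in test_set.entities][:5])

# get_simple_split strips slicing syntax so the entities file can still be
# indexed, e.g. a datasets slice expression "test[0:100]" maps back to "test".
assert get_simple_split("test[0:100]") == "test"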
64 changes: 37 additions & 27 deletions zshot/evaluation/dataset/ontonotes/onto_notes.py
@@ -1,6 +1,7 @@
-from typing import Dict, Union
+import re
+from typing import Dict, Union, Optional
 
-from datasets import ClassLabel, load_dataset, DatasetDict, Split
+from datasets import ClassLabel, load_dataset, DatasetDict, Split, Dataset
 
 from zshot.evaluation.dataset.dataset import DatasetWithEntities
 from zshot.evaluation.dataset.ontonotes.entities import ONTONOTES_ENTITIES
@@ -54,30 +55,39 @@ def remove_out_of_split(sentence, split):
     return sentence
 
 
-def load_ontonotes() -> Dict[Union[str, Split], DatasetWithEntities]:
-    dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12")
-    ontonotes_zs = DatasetDict()
-
-    for split in dataset_zs:
-        dataset_zs[split] = dataset_zs[split].map(lambda example, idx: {
-            "sentences": [remove_out_of_split(s, split) for s in example['sentences']]
-        }, with_indices=True)
-        dataset_zs[split] = dataset_zs[split].map(lambda example, idx: {
-            "sentences": list(filter(is_not_empty, example['sentences']))
-        }, with_indices=True)
-
-        tokens = []
-        ner_tags = []
-        for example in dataset_zs[split]:
-            tokens += [s['words'] for s in example['sentences']]
-            ner_tags += [[labels.int2str(ent) for ent in s['named_entities']] for s in example['sentences']]
-
-        split_entities = [ent for ent in ONTONOTES_ENTITIES
-                          if ent.name in ['NEG'] + CLASSES_PER_SPLIT[split] and ent.name not in TRIVIAL_CLASSES]
+def load_ontonotes_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Union[Dict[DatasetWithEntities,
+                                                                                          Dataset], Dataset]:
+    dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12", split=split, **kwargs)
+    if split:
+        ontonotes_zs = preprocess_spit(dataset_zs, get_simple_split(split))
+    else:
+        ontonotes_zs = DatasetDict()
+        for split in dataset_zs:
+            ontonotes_zs[split] = preprocess_spit(dataset_zs[split], split)
+    return ontonotes_zs
 
-        ontonotes_zs[split] = DatasetWithEntities.from_dict({
-            'tokens': tokens,
-            'ner_tags': ner_tags
-        }, split=split, entities=split_entities)
 
-    return ontonotes_zs
+def preprocess_spit(dataset, split):
+    dataset = dataset.map(lambda example, idx: {
+        "sentences": [remove_out_of_split(s, split) for s in example['sentences']]
+    }, with_indices=True)
+    dataset = dataset.map(lambda example, idx: {
+        "sentences": list(filter(is_not_empty, example['sentences']))
+    }, with_indices=True)
+    tokens = []
+    ner_tags = []
+    for example in dataset:
+        tokens += [s['words'] for s in example['sentences']]
+        ner_tags += [[labels.int2str(ent) for ent in s['named_entities']] for s in example['sentences']]
+    split_entities = [ent for ent in ONTONOTES_ENTITIES
+                      if ent.name in ['NEG'] + CLASSES_PER_SPLIT[split] and ent.name not in TRIVIAL_CLASSES]
+    return DatasetWithEntities.from_dict({
+        'tokens': tokens,
+        'ner_tags': ner_tags
+    }, split=split, entities=split_entities)
+
+
+def get_simple_split(split: str) -> str:
+    first_not_alph = re.search(r'\W+', split)
+    first_not_alph_chr = first_not_alph.start() if first_not_alph else len(split)
+    return split[: first_not_alph_chr]
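An analogous sketch for OntoNotes (again not part of the commit; it assumes the standard train/validation/test split names of conll2012_ontonotesv5 and that the corpus can be downloaded):

from zshot.evaluation import load_ontonotes_zs

# Load and preprocess a single split; only the zero-shot classes configured
# for that split, minus the trivial ones, are attached as entities.
val = load_ontonotes_zs(split="validation")
print(val[0]["tokens"][:10])
print(val[0]["ner_tags"][:10])
print([ent.name for ent in val.entities])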
26 changes: 21 additions & 5 deletions zshot/evaluation/run_evaluation.py
@@ -2,8 +2,9 @@
 
 import spacy
 from zshot import PipelineConfig
+from zshot.evaluation import load_medmentions_zs, load_ontonotes_zs
 from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
-from zshot.evaluation.zshot_evaluate import evaluate
+from zshot.evaluation.zshot_evaluate import evaluate, prettify_evaluate_report
 from zshot.linker import LinkerTARS, LinkerSMXM, LinkerRegen
 from zshot.mentions_extractor import MentionsExtractorSpacy, MentionsExtractorFlair, MentionsExtractorSMXM
 from zshot.mentions_extractor.utils import ExtractorType
@@ -26,7 +27,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", default="ontonotes", type=str,
                     help="Name or path to the validation data. Comma separated")
-parser.add_argument("--splits", required=False, default="train, test, validation", type=str,
+parser.add_argument("--splits", required=False, default="test", type=str,
                     help="Splits to evaluate. Comma separated")
 parser.add_argument("--mode", required=False, default="full", type=str,
                     help="Evaluation mode. One of: full; mentions_extractor; linker")
@@ -36,8 +37,8 @@
                     help="Linker to evaluate. One of: all; tars")
 args = parser.parse_args()
 
-args.splits = args.splits.split(",")
-args.dataset = args.dataset.split(",")
+splits = args.splits.split(",")
+datasets = args.dataset.split(",")
 
 configs = {}
 if args.mentions_extractor == "all":
@@ -78,4 +79,19 @@
     nlp = spacy.blank("en") if "spacy" not in key else spacy.load("en_core_web_sm")
     nlp.add_pipe("zshot", config=config, last=True)
 
-    print(evaluate(nlp, args.dataset, splits=args.splits, metric=Seqeval()))
+    for dataset_name in datasets:
+        for split in splits:
+            if dataset_name.lower() == "medmentions":
+                dataset = load_medmentions_zs(split)
+            elif dataset_name.lower() == "ontonotes":
+                dataset = load_ontonotes_zs(split)
+            else:
+                raise ValueError(f"{dataset_name} not supported")
+            nlp.get_pipe("zshot").mentions = dataset.entities
+            nlp.get_pipe("zshot").entities = dataset.entities
+            field_names = ["Metric"]
+            field_name = f"{dataset_name} {split}"
+            field_names.append(field_name)
+
+            evaluation = evaluate(nlp, dataset, metric=Seqeval())
+            print(prettify_evaluate_report(evaluation, name=f"{dataset_name}-{split}"))
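Put together, each per-config iteration now amounts to something like the following sketch (not part of the commit). The PipelineConfig arguments follow the zshot README and are an assumption here, since the configs dict itself is outside this hunk:

import spacy

from zshot import PipelineConfig
from zshot.evaluation import load_ontonotes_zs
from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
from zshot.evaluation.zshot_evaluate import evaluate, prettify_evaluate_report
from zshot.linker import LinkerTARS

# One dataset/split pair evaluated with a single linker configuration.
dataset = load_ontonotes_zs("test")
nlp = spacy.blank("en")
nlp.add_pipe("zshot", config=PipelineConfig(linker=LinkerTARS()), last=True)
nlp.get_pipe("zshot").mentions = dataset.entities
nlp.get_pipe("zshot").entities = dataset.entities

evaluation = evaluate(nlp, dataset, metric=Seqeval())
print(prettify_evaluate_report(evaluation, name="ontonotes-test"))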
139 changes: 32 additions & 107 deletions zshot/evaluation/zshot_evaluate.py
@@ -1,128 +1,53 @@
-from typing import Optional, List, Union
+from typing import Optional, Union, Dict
 
 import spacy
+from datasets import Dataset
 from evaluate import EvaluationModule
 from prettytable import PrettyTable
-from zshot.evaluation import load_medmentions, load_ontonotes
 
 from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
 def evaluate(nlp: spacy.language.Language,
-             datasets: Union[str, List[str]],
-             splits: Optional[Union[str, List[str]]] = None,
+             dataset: Dataset,
              metric: Optional[Union[str, EvaluationModule]] = None,
-             batch_size: Optional[int] = 16) -> str:
+             batch_size: Optional[int] = 16) -> dict:
     """ Evaluate a spacy zshot model
     :param nlp: Spacy Language pipeline with ZShot components
-    :param datasets: Dataset or list of datasets to evaluate
-    :param splits: Optional. Split or list of splits to evaluate. All splits available by default
+    :param dataset: Dataset used to evaluate
     :param metric: Metrics to use in evaluation.
         Options available: precision, recall, f1-score-micro, f1-score-macro. All by default
-    :return: Result of the evaluation. Dict with precision, recall and f1-score for each component
+    :return: Result of the evaluation. Dict with metrics results for each component
     :param batch_size: the batch size
     """
-    linker_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
-    mentions_extractor_evaluator = MentionsExtractorEvaluator("token-classification")
-
-    if type(splits) == str:
-        splits = [splits]
-
-    if type(datasets) == str:
-        datasets = [datasets]
-
-    result = {}
-    field_names = ["Metric"]
-    for dataset_name in datasets:
-        if dataset_name.lower() == "medmentions":
-            dataset = load_medmentions()
-        else:
-            dataset = load_ontonotes()
-
-        for split in splits:
-            field_name = f"{dataset_name} {split}"
-            field_names.append(field_name)
-            nlp.get_pipe("zshot").mentions = dataset[split].entities
-            nlp.get_pipe("zshot").entities = dataset[split].entities
-            if nlp.get_pipe("zshot").linker:
-                pipe = LinkerPipeline(nlp, batch_size)
-                res_tmp = {
-                    'linker': linker_evaluator.compute(pipe, dataset[split], metric=metric)
-                }
-                if field_name not in result:
-                    result.update(
-                        {
-                            field_name: res_tmp
-                        }
-                    )
-                else:
-                    result[field_name].update(res_tmp)
-            if nlp.get_pipe("zshot").mentions_extractor:
-                pipe = MentionsExtractorPipeline(nlp, batch_size)
-                res_tmp = {
-                    'mentions_extractor': mentions_extractor_evaluator.compute(pipe, dataset[split],
-                                                                               metric=metric)
-                }
-                if field_name not in result:
-                    result.update(
-                        {
-                            field_name: res_tmp
-                        }
-                    )
-                else:
-                    result[field_name].update(res_tmp)
+    linker_evaluator = ZeroShotTokenClassificationEvaluator()
+    mentions_extractor_evaluator = MentionsExtractorEvaluator()
 
-    table = PrettyTable()
-    table.field_names = field_names
-    rows = []
+    results = {}
     if nlp.get_pipe("zshot").linker:
-        linker_precisions = []
-        linker_recalls = []
-        linker_micros = []
-        linker_macros = []
-        linker_accuracies = []
-        for field_name in field_names:
-            if field_name == "Metric":
-                continue
-            linker_precisions.append("{:.2f}%".format(result[field_name]['linker']['overall_precision_macro'] * 100))
-            linker_recalls.append("{:.2f}%".format(result[field_name]['linker']['overall_recall_macro'] * 100))
-            linker_accuracies.append("{:.2f}%".format(result[field_name]['linker']['overall_accuracy'] * 100))
-            linker_micros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_micro'] * 100))
-            linker_macros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_macro'] * 100))
-
-        rows.append(["Linker Precision"] + linker_precisions)
-        rows.append(["Linker Recall"] + linker_recalls)
-        rows.append(["Linker Accuracy"] + linker_accuracies)
-        rows.append(["Linker F1-score micro"] + linker_micros)
-        rows.append(["Linker F1-score macro"] + linker_macros)
-
+        pipe = LinkerPipeline(nlp, batch_size)
+        results['linker'] = linker_evaluator.compute(pipe, dataset, metric=metric)
     if nlp.get_pipe("zshot").mentions_extractor:
-        mentions_extractor_precisions = []
-        mentions_extractor_recalls = []
-        mentions_extractor_micros = []
-        mentions_extractor_accuracies = []
-        mentions_extractor_macros = []
-        for field_name in field_names:
-            if field_name == "Metric":
-                continue
-            mentions_extractor_precisions.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_precision_macro'] * 100))
-            mentions_extractor_recalls.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_recall_macro'] * 100))
-            mentions_extractor_accuracies.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_accuracy'] * 100))
-            mentions_extractor_micros.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_micro'] * 100))
-            mentions_extractor_macros.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_macro'] * 100))
-
-        rows.append(["Mentions extractor Precision"] + mentions_extractor_precisions)
-        rows.append(["Mentions extractor Recall"] + mentions_extractor_recalls)
-        rows.append(["Mentions extractor Accuracy"] + mentions_extractor_accuracies)
-        rows.append(["Mentions extractor F1-score micro"] + mentions_extractor_micros)
-        rows.append(["Mentions extractor F1-score macro"] + mentions_extractor_macros)
+        pipe = MentionsExtractorPipeline(nlp, batch_size)
+        results['mentions_extractor'] = mentions_extractor_evaluator.compute(pipe, dataset, metric=metric)
+    return results
 
-    table.add_rows(rows)
 
-    return table.get_string()
+def prettify_evaluate_report(evaluation: Dict, name: str = "", decimals: int = 4) -> list[PrettyTable]:
+    """
+    Convert an evaluation report Dict to a formatted string
+    :param evaluation: The evaluation report dict
+    :param name: Reference name
+    :param decimals: Number of decimals to show
+    :return: Formatted evaluation table as string, for each component
+    """
+    tables = []
+    for component in evaluation:
+        table = PrettyTable()
+        table.field_names = ["Metric", name]
+        for metric in evaluation[component]:
+            if isinstance(evaluation[component][metric], (float, int)):
+                table.add_row([metric, f'{evaluation[component][metric]:.{decimals}f}'])
+        tables.append(table)
+    return tables
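A minimal sketch of the new helper in isolation (not from the commit). The metric names mirror the keys used by the old table-building code; the numeric values are placeholders only:

from zshot.evaluation.zshot_evaluate import prettify_evaluate_report

# Shape of the dict returned by evaluate(): one entry per component,
# each mapping metric names to numeric values. Values below are placeholders.
evaluation = {
    "linker": {
        "overall_precision_macro": 0.6123,
        "overall_recall_macro": 0.5871,
        "overall_f1_macro": 0.5994,
    }
}
for table in prettify_evaluate_report(evaluation, name="ontonotes test"):
    print(table)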