⚡ Enhance evaluation flow & Add time to evaluation (#36)
* ♻️ Enhance evaluation flow: enhance and simplify the evaluation flow
* ♻️ Codestyle fix
* 🐛 Fix dataset loading
* 🧑‍💻 Improve coverage
* 🐛 Fix tests
1 parent: e195caa · commit: 2a84de0
Showing 11 changed files with 163 additions and 165 deletions.
@@ -3,4 +3,5 @@ pytest-cov>=3.0.0
 setuptools~=60.0.0
 flair==0.11.3
 flake8>=4.0.1
-coverage>=6.4.1
+coverage>=6.4.1
+IPython
@@ -1,2 +1,2 @@
-from zshot.evaluation.dataset import load_medmentions  # noqa: F401
-from zshot.evaluation.dataset import load_ontonotes  # noqa: F401
+from zshot.evaluation.dataset import load_medmentions_zs  # noqa: F401
+from zshot.evaluation.dataset import load_ontonotes_zs  # noqa: F401
@@ -1,2 +1,2 @@
-from zshot.evaluation.dataset.med_mentions.med_mentions import load_medmentions  # noqa: F401
-from zshot.evaluation.dataset.ontonotes.onto_notes import load_ontonotes  # noqa: F401
+from zshot.evaluation.dataset.med_mentions.med_mentions import load_medmentions_zs  # noqa: F401
+from zshot.evaluation.dataset.ontonotes.onto_notes import load_ontonotes_zs  # noqa: F401
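Downstream code that imports the old loader names needs updating after this rename. A minimal before/after sketch of the import; the loaders' call signatures are not shown in this commit, so the calls below are illustrative only:

    # Before this commit:
    # from zshot.evaluation import load_medmentions, load_ontonotes
    # After this commit:
    from zshot.evaluation import load_medmentions_zs, load_ontonotes_zs

    # Illustrative calls; any arguments the loaders take are not shown in this diff.
    medmentions = load_medmentions_zs()
    ontonotes = load_ontonotes_zs()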
@@ -1,128 +1,53 @@
-from typing import Optional, List, Union
+from typing import Optional, Union, Dict
 
 import spacy
+from datasets import Dataset
 from evaluate import EvaluationModule
 from prettytable import PrettyTable
-from zshot.evaluation import load_medmentions, load_ontonotes
 
 from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
 def evaluate(nlp: spacy.language.Language,
-             datasets: Union[str, List[str]],
-             splits: Optional[Union[str, List[str]]] = None,
+             dataset: Dataset,
              metric: Optional[Union[str, EvaluationModule]] = None,
-             batch_size: Optional[int] = 16) -> str:
+             batch_size: Optional[int] = 16) -> dict:
     """ Evaluate a spacy zshot model
     :param nlp: Spacy Language pipeline with ZShot components
-    :param datasets: Dataset or list of datasets to evaluate
-    :param splits: Optional. Split or list of splits to evaluate. All splits available by default
+    :param dataset: Dataset used to evaluate
     :param metric: Metrics to use in evaluation.
         Options available: precision, recall, f1-score-micro, f1-score-macro. All by default
-    :return: Result of the evaluation. Dict with precision, recall and f1-score for each component
+    :return: Result of the evaluation. Dict with metrics results for each component
     :param batch_size: the batch size
     """
-    linker_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
-    mentions_extractor_evaluator = MentionsExtractorEvaluator("token-classification")
-
-    if type(splits) == str:
-        splits = [splits]
-
-    if type(datasets) == str:
-        datasets = [datasets]
-
-    result = {}
-    field_names = ["Metric"]
-    for dataset_name in datasets:
-        if dataset_name.lower() == "medmentions":
-            dataset = load_medmentions()
-        else:
-            dataset = load_ontonotes()
-
-        for split in splits:
-            field_name = f"{dataset_name} {split}"
-            field_names.append(field_name)
-            nlp.get_pipe("zshot").mentions = dataset[split].entities
-            nlp.get_pipe("zshot").entities = dataset[split].entities
-            if nlp.get_pipe("zshot").linker:
-                pipe = LinkerPipeline(nlp, batch_size)
-                res_tmp = {
-                    'linker': linker_evaluator.compute(pipe, dataset[split], metric=metric)
-                }
-                if field_name not in result:
-                    result.update(
-                        {
-                            field_name: res_tmp
-                        }
-                    )
-                else:
-                    result[field_name].update(res_tmp)
-            if nlp.get_pipe("zshot").mentions_extractor:
-                pipe = MentionsExtractorPipeline(nlp, batch_size)
-                res_tmp = {
-                    'mentions_extractor': mentions_extractor_evaluator.compute(pipe, dataset[split],
-                                                                               metric=metric)
-                }
-                if field_name not in result:
-                    result.update(
-                        {
-                            field_name: res_tmp
-                        }
-                    )
-                else:
-                    result[field_name].update(res_tmp)
+    linker_evaluator = ZeroShotTokenClassificationEvaluator()
+    mentions_extractor_evaluator = MentionsExtractorEvaluator()
 
-    table = PrettyTable()
-    table.field_names = field_names
-    rows = []
+    results = {}
     if nlp.get_pipe("zshot").linker:
-        linker_precisions = []
-        linker_recalls = []
-        linker_micros = []
-        linker_macros = []
-        linker_accuracies = []
-        for field_name in field_names:
-            if field_name == "Metric":
-                continue
-            linker_precisions.append("{:.2f}%".format(result[field_name]['linker']['overall_precision_macro'] * 100))
-            linker_recalls.append("{:.2f}%".format(result[field_name]['linker']['overall_recall_macro'] * 100))
-            linker_accuracies.append("{:.2f}%".format(result[field_name]['linker']['overall_accuracy'] * 100))
-            linker_micros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_micro'] * 100))
-            linker_macros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_macro'] * 100))
-
-        rows.append(["Linker Precision"] + linker_precisions)
-        rows.append(["Linker Recall"] + linker_recalls)
-        rows.append(["Linker Accuracy"] + linker_accuracies)
-        rows.append(["Linker F1-score micro"] + linker_micros)
-        rows.append(["Linker F1-score macro"] + linker_macros)
-
+        pipe = LinkerPipeline(nlp, batch_size)
+        results['linker'] = linker_evaluator.compute(pipe, dataset, metric=metric)
     if nlp.get_pipe("zshot").mentions_extractor:
-        mentions_extractor_precisions = []
-        mentions_extractor_recalls = []
-        mentions_extractor_micros = []
-        mentions_extractor_accuracies = []
-        mentions_extractor_macros = []
-        for field_name in field_names:
-            if field_name == "Metric":
-                continue
-            mentions_extractor_precisions.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_precision_macro'] * 100))
-            mentions_extractor_recalls.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_recall_macro'] * 100))
-            mentions_extractor_accuracies.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_accuracy'] * 100))
-            mentions_extractor_micros.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_micro'] * 100))
-            mentions_extractor_macros.append(
-                "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_macro'] * 100))
-
-        rows.append(["Mentions extractor Precision"] + mentions_extractor_precisions)
-        rows.append(["Mentions extractor Recall"] + mentions_extractor_recalls)
-        rows.append(["Mentions extractor Accuracy"] + mentions_extractor_accuracies)
-        rows.append(["Mentions extractor F1-score micro"] + mentions_extractor_micros)
-        rows.append(["Mentions extractor F1-score macro"] + mentions_extractor_macros)
-
-    table.add_rows(rows)
-
-    return table.get_string()
+        pipe = MentionsExtractorPipeline(nlp, batch_size)
+        results['mentions_extractor'] = mentions_extractor_evaluator.compute(pipe, dataset, metric=metric)
+    return results
+
+
+def prettify_evaluate_report(evaluation: Dict, name: str = "", decimals: int = 4) -> list[PrettyTable]:
+    """
+    Convert an evaluation report Dict to a formatted string
+    :param evaluation: The evaluation report dict
+    :param name: Reference name
+    :param decimals: Number of decimals to show
+    :return: Formatted evaluation table as string, for each component
+    """
+    tables = []
+    for component in evaluation:
+        table = PrettyTable()
+        table.field_names = ["Metric", name]
+        for metric in evaluation[component]:
+            if isinstance(evaluation[component][metric], (float, int)):
+                table.add_row([metric, f'{evaluation[component][metric]:.{decimals}f}'])
+        tables.append(table)
+    return tables
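With this change, evaluate() no longer loads datasets or builds tables itself: it takes a single prepared Dataset, runs whichever ZShot components are present, and returns a plain dict, while table rendering moves into prettify_evaluate_report(). A rough usage sketch, assuming the changed module keeps an import path like zshot.evaluation.zshot_evaluate (the file path is not visible in this excerpt), that the renamed loaders return a split-keyed dataset (as the old dataset[split] access suggests), and that nlp already has a configured "zshot" pipe:

    import spacy

    # Assumed import paths; adjust to the actual module locations in the repository.
    from zshot.evaluation import load_ontonotes_zs
    from zshot.evaluation.zshot_evaluate import evaluate, prettify_evaluate_report

    nlp = spacy.load("en_core_web_sm")
    # ... add and configure the "zshot" pipe (linker and/or mentions extractor) here; not covered by this commit.

    dataset = load_ontonotes_zs()                            # renamed loader from this commit
    results = evaluate(nlp, dataset["test"], batch_size=16)  # "test" split is illustrative; evaluate() expects one Dataset
    for table in prettify_evaluate_report(results, name="OntoNotes test"):
        print(table)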