diff --git a/src/nervaluate/__init__.py b/src/nervaluate/__init__.py index 7e13ac8..bfa73e5 100644 --- a/src/nervaluate/__init__.py +++ b/src/nervaluate/__init__.py @@ -7,5 +7,7 @@ find_overlap, summary_report_ent, summary_report_overall, + summary_report_ents_indices, + summary_report_overall_indices, ) from .utils import collect_named_entities, conll_to_spans, list_to_spans, split_list diff --git a/src/nervaluate/evaluate.py b/src/nervaluate/evaluate.py index 0468f7a..9db3ae0 100644 --- a/src/nervaluate/evaluate.py +++ b/src/nervaluate/evaluate.py @@ -1,9 +1,11 @@ import logging from copy import deepcopy -from typing import List, Dict, Union, Tuple +from typing import List, Dict, Union, Tuple, Optional from .utils import conll_to_spans, find_overlap, list_to_spans +logger = logging.getLogger(__name__) + class Evaluator: # pylint: disable=too-many-instance-attributes, too-few-public-methods def __init__( @@ -49,7 +51,24 @@ def __init__( self.loader = loader - def evaluate(self) -> Tuple[Dict, Dict]: + self.eval_indices: Dict[str, List[int]] = { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + } + + # Create dicts to hold indices for correct/spurious/missing/etc examples + self.evaluation_indices = { + "strict": deepcopy(self.eval_indices), + "ent_type": deepcopy(self.eval_indices), + "partial": deepcopy(self.eval_indices), + "exact": deepcopy(self.eval_indices), + } + self.evaluation_agg_indices = {e: deepcopy(self.evaluation_indices) for e in tags} + + def evaluate(self) -> Tuple[Dict, Dict, Dict, Dict]: logging.debug("Imported %s predictions for %s true examples", len(self.pred), len(self.true)) if self.loader != "default": @@ -60,9 +79,11 @@ def evaluate(self) -> Tuple[Dict, Dict]: if len(self.true) != len(self.pred): raise ValueError("Number of predicted documents does not equal true") - for true_ents, pred_ents in zip(self.true, self.pred): + for index, (true_ents, pred_ents) in enumerate(zip(self.true, self.pred)): # Compute results for one message - tmp_results, tmp_agg_results = compute_metrics(true_ents, pred_ents, self.tags) + tmp_results, tmp_agg_results, tmp_results_indices, tmp_agg_results_indices = compute_metrics( + true_ents, pred_ents, self.tags, index + ) # Cycle through each result and accumulate # TODO: Combine these loops below: @@ -70,6 +91,10 @@ def evaluate(self) -> Tuple[Dict, Dict]: for metric in self.results[eval_schema]: self.results[eval_schema][metric] += tmp_results[eval_schema][metric] + # Accumulate indices for each error type + for error_type in self.evaluation_indices[eval_schema]: + self.evaluation_indices[eval_schema][error_type] += tmp_results_indices[eval_schema][error_type] + # Calculate global precision and recall self.results = compute_precision_recall_wrapper(self.results) @@ -81,17 +106,23 @@ def evaluate(self) -> Tuple[Dict, Dict]: eval_schema ][metric] + # Accumulate indices for each error type per entity type + for error_type in self.evaluation_agg_indices[label][eval_schema]: + self.evaluation_agg_indices[label][eval_schema][error_type] += tmp_agg_results_indices[label][ + eval_schema + ][error_type] + # Calculate precision recall at the individual entity level self.evaluation_agg_entities_type[label] = compute_precision_recall_wrapper( self.evaluation_agg_entities_type[label] ) - return self.results, self.evaluation_agg_entities_type + return self.results, self.evaluation_agg_entities_type, self.evaluation_indices, self.evaluation_agg_indices # flake8: noqa: 
C901 def compute_metrics( # type: ignore - true_named_entities, pred_named_entities, tags: List[str] + true_named_entities, pred_named_entities, tags: List[str], instance_index: int = 0 ): # pylint: disable=too-many-locals, too-many-branches, too-many-statements """ Compute metrics on the collected true and predicted named entities @@ -104,6 +135,9 @@ def compute_metrics( # type: ignore :tags: list of tags to be used + + :instance_index: + index of the example being evaluated. Used to record indices of correct/missing/spurious/exact/partial predictions. """ eval_metrics = { @@ -128,6 +162,23 @@ def compute_metrics( # type: ignore # results by entity type evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags} + eval_ent_indices: Dict[str, List[Tuple[int, int]]] = { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + } + + # Create dicts to hold indices for correct/spurious/missing/etc examples + evaluation_ent_indices = { + "strict": deepcopy(eval_ent_indices), + "ent_type": deepcopy(eval_ent_indices), + "partial": deepcopy(eval_ent_indices), + "exact": deepcopy(eval_ent_indices), + } + evaluation_agg_ent_indices = {e: deepcopy(evaluation_ent_indices) for e in tags} + # keep track of entities that overlapped true_which_overlapped_with_pred = [] @@ -149,7 +200,7 @@ def compute_metrics( # type: ignore pred_named_entities.sort(key=lambda x: x["end"]) # go through each predicted named-entity - for pred in pred_named_entities: + for within_instance_index, pred in enumerate(pred_named_entities): found_overlap = False # Check each of the potential scenarios in turn. See @@ -163,12 +214,28 @@ def compute_metrics( # type: ignore evaluation["ent_type"]["correct"] += 1 evaluation["exact"]["correct"] += 1 evaluation["partial"]["correct"] += 1 + evaluation_ent_indices["strict"]["correct_indices"].append((instance_index, within_instance_index)) + evaluation_ent_indices["ent_type"]["correct_indices"].append((instance_index, within_instance_index)) + evaluation_ent_indices["exact"]["correct_indices"].append((instance_index, within_instance_index)) + evaluation_ent_indices["partial"]["correct_indices"].append((instance_index, within_instance_index)) # for the agg. 
by label results evaluation_agg_entities_type[pred["label"]]["strict"]["correct"] += 1 evaluation_agg_entities_type[pred["label"]]["ent_type"]["correct"] += 1 evaluation_agg_entities_type[pred["label"]]["exact"]["correct"] += 1 evaluation_agg_entities_type[pred["label"]]["partial"]["correct"] += 1 + evaluation_agg_ent_indices[pred["label"]]["strict"]["correct_indices"].append( + (instance_index, within_instance_index) + ) + evaluation_agg_ent_indices[pred["label"]]["ent_type"]["correct_indices"].append( + (instance_index, within_instance_index) + ) + evaluation_agg_ent_indices[pred["label"]]["exact"]["correct_indices"].append( + (instance_index, within_instance_index) + ) + evaluation_agg_ent_indices[pred["label"]]["partial"]["correct_indices"].append( + (instance_index, within_instance_index) + ) else: # check for overlaps with any of the true entities @@ -185,13 +252,25 @@ def compute_metrics( # type: ignore if true["start"] == pred["start"] and pred["end"] == true["end"] and true["label"] != pred["label"]: # overall results evaluation["strict"]["incorrect"] += 1 + evaluation_ent_indices["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation["ent_type"]["incorrect"] += 1 + evaluation_ent_indices["ent_type"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation["partial"]["correct"] += 1 evaluation["exact"]["correct"] += 1 # aggregated by entity type results evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["ent_type"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["partial"]["correct"] += 1 evaluation_agg_entities_type[true["label"]]["exact"]["correct"] += 1 @@ -210,15 +289,33 @@ def compute_metrics( # type: ignore if pred["label"] == true["label"]: # overall results evaluation["strict"]["incorrect"] += 1 + evaluation_ent_indices["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation["ent_type"]["correct"] += 1 evaluation["partial"]["partial"] += 1 + evaluation_ent_indices["partial"]["partial_indices"].append( + (instance_index, within_instance_index) + ) evaluation["exact"]["incorrect"] += 1 + evaluation_ent_indices["exact"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) # aggregated by entity type results evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["ent_type"]["correct"] += 1 evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 + evaluation_agg_ent_indices[true["label"]]["partial"]["partial_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["exact"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) found_overlap = True @@ -228,17 +325,41 @@ def compute_metrics( # type: ignore # overall results evaluation["strict"]["incorrect"] += 1 + evaluation_ent_indices["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) 
evaluation["ent_type"]["incorrect"] += 1 + evaluation_ent_indices["ent_type"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation["partial"]["partial"] += 1 + evaluation_ent_indices["partial"]["partial_indices"].append( + (instance_index, within_instance_index) + ) evaluation["exact"]["incorrect"] += 1 + evaluation_ent_indices["exact"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) # aggregated by entity type results # Results against the true entity evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 + evaluation_agg_ent_indices[true["label"]]["partial"]["partial_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["ent_type"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 + evaluation_agg_ent_indices[true["label"]]["exact"]["incorrect_indices"].append( + (instance_index, within_instance_index) + ) # Results against the predicted entity # evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1 @@ -248,9 +369,13 @@ def compute_metrics( # type: ignore if not found_overlap: # Overall results evaluation["strict"]["spurious"] += 1 + evaluation_ent_indices["strict"]["spurious_indices"].append((instance_index, within_instance_index)) evaluation["ent_type"]["spurious"] += 1 + evaluation_ent_indices["ent_type"]["spurious_indices"].append((instance_index, within_instance_index)) evaluation["partial"]["spurious"] += 1 + evaluation_ent_indices["partial"]["spurious_indices"].append((instance_index, within_instance_index)) evaluation["exact"]["spurious"] += 1 + evaluation_ent_indices["exact"]["spurious_indices"].append((instance_index, within_instance_index)) # Aggregated by entity type results @@ -270,26 +395,54 @@ def compute_metrics( # type: ignore for true in spurious_tags: evaluation_agg_entities_type[true]["strict"]["spurious"] += 1 + evaluation_agg_ent_indices[true]["strict"]["spurious_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1 + evaluation_agg_ent_indices[true]["ent_type"]["spurious_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true]["partial"]["spurious"] += 1 + evaluation_agg_ent_indices[true]["partial"]["spurious_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true]["exact"]["spurious"] += 1 + evaluation_agg_ent_indices[true]["exact"]["spurious_indices"].append( + (instance_index, within_instance_index) + ) # Scenario III: Entity was missed entirely. 
- for true in true_named_entities: + for within_instance_index, true in enumerate(true_named_entities): if true in true_which_overlapped_with_pred: continue # overall results evaluation["strict"]["missed"] += 1 + evaluation_ent_indices["strict"]["missed_indices"].append((instance_index, within_instance_index)) evaluation["ent_type"]["missed"] += 1 + evaluation_ent_indices["ent_type"]["missed_indices"].append((instance_index, within_instance_index)) evaluation["partial"]["missed"] += 1 + evaluation_ent_indices["partial"]["missed_indices"].append((instance_index, within_instance_index)) evaluation["exact"]["missed"] += 1 + evaluation_ent_indices["exact"]["missed_indices"].append((instance_index, within_instance_index)) # for the agg. by label evaluation_agg_entities_type[true["label"]]["strict"]["missed"] += 1 + evaluation_agg_ent_indices[true["label"]]["strict"]["missed_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["ent_type"]["missed"] += 1 + evaluation_agg_ent_indices[true["label"]]["ent_type"]["missed_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["partial"]["missed"] += 1 + evaluation_agg_ent_indices[true["label"]]["partial"]["missed_indices"].append( + (instance_index, within_instance_index) + ) evaluation_agg_entities_type[true["label"]]["exact"]["missed"] += 1 + evaluation_agg_ent_indices[true["label"]]["exact"]["missed_indices"].append( + (instance_index, within_instance_index) + ) # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the # overall results, and use these to calculate precision and recall. @@ -305,7 +458,7 @@ def compute_metrics( # type: ignore for eval_type in entity_level: evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(entity_level[eval_type]) - return evaluation, evaluation_agg_entities_type + return evaluation, evaluation_agg_entities_type, evaluation_ent_indices, evaluation_agg_ent_indices def compute_actual_possible(results: Dict) -> Dict: @@ -465,3 +618,57 @@ def summary_report_overall(results: Dict, digits: int = 2) -> str: report += row_fmt.format(*row, width=width, digits=digits) return report + + +def summary_report_ents_indices(evaluation_agg_indices: Dict, error_schema: str, preds: Optional[List] = [[]]) -> str: + """ + Usage: print(summary_report_ents_indices(evaluation_agg_indices, 'partial', preds)) + """ + report = "" + for entity_type, entity_results in evaluation_agg_indices.items(): + report += f"\nEntity Type: {entity_type}\n" + error_data = entity_results[error_schema] + report += f" Error Schema: '{error_schema}'\n" + for category, indices in error_data.items(): + category_name = category.replace("_", " ").capitalize() + report += f" ({entity_type}) {category_name}:\n" + if indices: + for instance_index, entity_index in indices: + if preds is not [[]]: + pred = preds[instance_index][entity_index] # type: ignore + prediction_info = f"Label={pred['label']}, Start={pred['start']}, End={pred['end']}" + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + return report + + +def summary_report_overall_indices(evaluation_indices: Dict, error_schema: str, preds: Optional[List] = [[]]) -> str: + """ + Usage: print(summary_report_overall_indices(evaluation_indices, 'partial', preds)) + """ + report = "" + assert error_schema in 
evaluation_indices, f"Error schema '{error_schema}' not found in the results." + + error_data = evaluation_indices[error_schema] + report += f"Indices for error schema '{error_schema}':\n\n" + + for category, indices in error_data.items(): + category_name = category.replace("_", " ").capitalize() + report += f"{category_name} indices:\n" + if indices: + for instance_index, entity_index in indices: + if preds is not [[]]: + # Retrieve the corresponding prediction + pred = preds[instance_index][entity_index] # type: ignore + prediction_info = f"Label={pred['label']}, Start={pred['start']}, End={pred['end']}" + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + report += "\n" + + return report diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 409b33a..4d60cff 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,3 +1,4 @@ +# pylint: disable=C0302 from nervaluate import Evaluator @@ -17,7 +18,7 @@ def test_evaluator_simple_case(): ], ] evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 3, @@ -92,7 +93,7 @@ def test_evaluator_simple_case_filtered_tags(): ], ] evaluator = Evaluator(true, pred, tags=["PER", "LOC"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 3, @@ -159,7 +160,7 @@ def test_evaluator_extra_classes(): [{"label": "FOO", "start": 1, "end": 3}], ] evaluator = Evaluator(true, pred, tags=["ORG", "FOO"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 0, @@ -226,7 +227,7 @@ def test_evaluator_no_entities_in_prediction(): [], ] evaluator = Evaluator(true, pred, tags=["PER"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 0, @@ -293,7 +294,7 @@ def test_evaluator_compare_results_and_results_agg(): [{"label": "PER", "start": 2, "end": 4}], ] evaluator = Evaluator(true, pred, tags=["PER"]) - results, results_agg = evaluator.evaluate() + results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 1, @@ -426,7 +427,7 @@ def test_evaluator_compare_results_and_results_agg_1(): [{"label": "MISC", "start": 2, "end": 4}], ] evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) - results, results_agg = evaluator.evaluate() + results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 2, @@ -612,7 +613,7 @@ def test_evaluator_with_extra_keys_in_pred(): ], ] evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 3, @@ -686,7 +687,7 @@ def test_evaluator_with_extra_keys_in_true(): ], ] evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 3, @@ -759,7 +760,7 @@ def test_issue_29(): ] ] evaluator = Evaluator(true, pred, tags=["PER"]) - results, _ = evaluator.evaluate() + results, _, _, _ = evaluator.evaluate() expected = { "strict": { "correct": 1, @@ -815,3 +816,244 @@ def test_issue_29(): assert results["ent_type"] == expected["ent_type"] assert results["partial"] == expected["partial"] assert results["exact"] == 
expected["exact"] + + +def test_evaluator_compare_results_indices_and_results_agg_indices(): + """Check that the label level results match the total results.""" + true = [ + [{"label": "PER", "start": 2, "end": 4}], + ] + pred = [ + [{"label": "PER", "start": 2, "end": 4}], + ] + evaluator = Evaluator(true, pred, tags=["PER"]) + _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() + expected_evaluation_indices = { + "strict": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "ent_type": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "partial": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "exact": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + } + expected_evaluation_agg_indices = { + "PER": { + "strict": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "ent_type": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "partial": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "exact": { + "correct_indices": [(0, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + } + } + assert evaluation_agg_indices["PER"]["strict"] == expected_evaluation_agg_indices["PER"]["strict"] + assert evaluation_agg_indices["PER"]["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] + assert evaluation_agg_indices["PER"]["partial"] == expected_evaluation_agg_indices["PER"]["partial"] + assert evaluation_agg_indices["PER"]["exact"] == expected_evaluation_agg_indices["PER"]["exact"] + + assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] + assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] + assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] + assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] + + assert evaluation_indices["strict"] == expected_evaluation_agg_indices["PER"]["strict"] + assert evaluation_indices["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] + assert evaluation_indices["partial"] == expected_evaluation_agg_indices["PER"]["partial"] + assert evaluation_indices["exact"] == expected_evaluation_agg_indices["PER"]["exact"] + + +def test_evaluator_compare_results_indices_and_results_agg_indices_1(): + """Test case when model predicts a label not in the test data.""" + true = [ + [], + [{"label": "ORG", "start": 2, "end": 4}], + [{"label": "MISC", "start": 2, "end": 4}], + ] + pred = [ + [{"label": "PER", "start": 2, "end": 4}], + [{"label": "ORG", "start": 2, "end": 4}], + [{"label": "MISC", "start": 2, "end": 4}], + ] + evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) + _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() + + expected_evaluation_indices = { + "strict": { + "correct_indices": [(1, 0), (2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + 
}, + "ent_type": { + "correct_indices": [(1, 0), (2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + "partial": { + "correct_indices": [(1, 0), (2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + "exact": { + "correct_indices": [(1, 0), (2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + } + expected_evaluation_agg_indices = { + "PER": { + "strict": { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + "ent_type": { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + "partial": { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + "exact": { + "correct_indices": [], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [(0, 0)], + }, + }, + "ORG": { + "strict": { + "correct_indices": [(1, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "ent_type": { + "correct_indices": [(1, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "partial": { + "correct_indices": [(1, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "exact": { + "correct_indices": [(1, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + }, + "MISC": { + "strict": { + "correct_indices": [(2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "ent_type": { + "correct_indices": [(2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "partial": { + "correct_indices": [(2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + "exact": { + "correct_indices": [(2, 0)], + "incorrect_indices": [], + "partial_indices": [], + "missed_indices": [], + "spurious_indices": [], + }, + }, + } + assert evaluation_agg_indices["ORG"]["strict"] == expected_evaluation_agg_indices["ORG"]["strict"] + assert evaluation_agg_indices["ORG"]["ent_type"] == expected_evaluation_agg_indices["ORG"]["ent_type"] + assert evaluation_agg_indices["ORG"]["partial"] == expected_evaluation_agg_indices["ORG"]["partial"] + assert evaluation_agg_indices["ORG"]["exact"] == expected_evaluation_agg_indices["ORG"]["exact"] + + assert evaluation_agg_indices["MISC"]["strict"] == expected_evaluation_agg_indices["MISC"]["strict"] + assert evaluation_agg_indices["MISC"]["ent_type"] == expected_evaluation_agg_indices["MISC"]["ent_type"] + assert evaluation_agg_indices["MISC"]["partial"] == expected_evaluation_agg_indices["MISC"]["partial"] + assert evaluation_agg_indices["MISC"]["exact"] == expected_evaluation_agg_indices["MISC"]["exact"] + + assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] + assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] + assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] + assert evaluation_indices["exact"] == 
expected_evaluation_indices["exact"] diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 6408be7..80cc921 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -50,9 +50,9 @@ def test_loaders_produce_the_same_results(): evaluator_prod = Evaluator(true_prod, pred_prod, tags=["PER", "ORG", "MISC"]) - _, _ = evaluator_list.evaluate() - _, _ = evaluator_prod.evaluate() - _, _ = evaluator_conll.evaluate() + _, _, _, _ = evaluator_list.evaluate() + _, _, _, _ = evaluator_prod.evaluate() + _, _, _, _ = evaluator_conll.evaluate() assert evaluator_prod.pred == evaluator_list.pred == evaluator_conll.pred assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true diff --git a/tests/test_nervaluate.py b/tests/test_nervaluate.py index c4e0c07..57aff58 100644 --- a/tests/test_nervaluate.py +++ b/tests/test_nervaluate.py @@ -23,7 +23,7 @@ def test_compute_metrics_case_1(): {"label": "LOC", "start": 208, "end": 219}, {"label": "LOC", "start": 225, "end": 243}, ] - results, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "MISC"]) + results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "MISC"]) results = compute_precision_recall_wrapper(results) expected = { "strict": { @@ -81,7 +81,7 @@ def test_compute_metrics_case_1(): def test_compute_metrics_agg_scenario_3(): true_named_entities = [{"label": "PER", "start": 59, "end": 69}] pred_named_entities = [] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) expected_agg = { "PER": { "strict": { @@ -144,7 +144,7 @@ def test_compute_metrics_agg_scenario_3(): def test_compute_metrics_agg_scenario_2(): true_named_entities = [] pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) expected_agg = { "PER": { "strict": { @@ -207,7 +207,7 @@ def test_compute_metrics_agg_scenario_2(): def test_compute_metrics_agg_scenario_5(): true_named_entities = [{"label": "PER", "start": 59, "end": 69}] pred_named_entities = [{"label": "PER", "start": 57, "end": 69}] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) expected_agg = { "PER": { "strict": { @@ -270,7 +270,7 @@ def test_compute_metrics_agg_scenario_5(): def test_compute_metrics_agg_scenario_4(): true_named_entities = [{"label": "PER", "start": 59, "end": 69}] pred_named_entities = [{"label": "LOC", "start": 59, "end": 69}] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) expected_agg = { "PER": { "strict": { @@ -384,7 +384,7 @@ def test_compute_metrics_agg_scenario_4(): def test_compute_metrics_agg_scenario_1(): true_named_entities = [{"label": "PER", "start": 59, "end": 69}] pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) expected_agg = { "PER": { "strict": { @@ -447,7 +447,7 @@ def test_compute_metrics_agg_scenario_1(): def 
test_compute_metrics_agg_scenario_6(): true_named_entities = [{"label": "PER", "start": 59, "end": 69}] pred_named_entities = [{"label": "LOC", "start": 54, "end": 69}] - _, results_agg = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) + _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) expected_agg = { "PER": { "strict": { @@ -570,7 +570,7 @@ def test_compute_metrics_extra_tags_in_prediction(): {"label": "ORG", "start": 59, "end": 69}, # Correct {"label": "MISC", "start": 71, "end": 72}, # Wrong type ] - results, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) + results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) expected = { "strict": { "correct": 1, @@ -641,7 +641,7 @@ def test_compute_metrics_extra_tags_in_true(): {"label": "ORG", "start": 71, "end": 72}, # Spurious ] - results, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) + results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) expected = { "strict": { @@ -707,7 +707,7 @@ def test_compute_metrics_no_predictions(): {"label": "MISC", "start": 71, "end": 72}, ] pred_named_entities = [] - results, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "ORG", "MISC"]) + results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "ORG", "MISC"]) expected = { "strict": { "correct": 0, @@ -831,58 +831,58 @@ def test_compute_metrics_one_pred_two_true(): {"start": 0, "end": 17, "label": "A"}, ] - results1, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"]) - results2, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"]) + results1, _, _, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"]) + results2, _, _, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"]) expected = { - 'ent_type': { - 'correct': 1, - 'incorrect': 1, - 'partial': 0, - 'missed': 0, - 'spurious': 0, - 'possible': 2, - 'actual': 2, - 'precision': 0, - 'recall': 0, - 'f1': 0 + "ent_type": { + "correct": 1, + "incorrect": 1, + "partial": 0, + "missed": 0, + "spurious": 0, + "possible": 2, + "actual": 2, + "precision": 0, + "recall": 0, + "f1": 0, + }, + "partial": { + "correct": 0, + "incorrect": 0, + "partial": 2, + "missed": 0, + "spurious": 0, + "possible": 2, + "actual": 2, + "precision": 0, + "recall": 0, + "f1": 0, }, - 'partial': { - 'correct': 0, - 'incorrect': 0, - 'partial': 2, - 'missed': 0, - 'spurious': 0, - 'possible': 2, - 'actual': 2, - 'precision': 0, - 'recall': 0, - 'f1': 0 + "strict": { + "correct": 0, + "incorrect": 2, + "partial": 0, + "missed": 0, + "spurious": 0, + "possible": 2, + "actual": 2, + "precision": 0, + "recall": 0, + "f1": 0, }, - 'strict': { - 'correct': 0, - 'incorrect': 2, - 'partial': 0, - 'missed': 0, - 'spurious': 0, - 'possible': 2, - 'actual': 2, - 'precision': 0, - 'recall': 0, - 'f1': 0 + "exact": { + "correct": 0, + "incorrect": 2, + "partial": 0, + "missed": 0, + "spurious": 0, + "possible": 2, + "actual": 2, + "precision": 0, + "recall": 0, + "f1": 0, }, - 'exact': { - 'correct': 0, - 'incorrect': 2, - 'partial': 0, - 'missed': 0, - 'spurious': 0, - 'possible': 2, - 'actual': 2, - 'precision': 0, - 'recall': 0, - 'f1': 0 - } } assert results1 == expected
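
For reference, a minimal usage sketch of the extended API introduced by this patch: the four-value return from Evaluator.evaluate() and the two index summary reports. The input spans, tag set, and variable names below are illustrative only; they are not taken from the patch or its tests.

from nervaluate import Evaluator, summary_report_ents_indices, summary_report_overall_indices

# Illustrative documents in the default span-dict format; any label set works.
true = [[{"label": "PER", "start": 2, "end": 4}], []]
pred = [[{"label": "PER", "start": 2, "end": 4}], [{"label": "LOC", "start": 1, "end": 2}]]

evaluator = Evaluator(true, pred, tags=["PER", "LOC"])
results, results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate()

# Each recorded index is an (instance_index, within_instance_index) tuple; here the
# spurious LOC prediction in the second document is recorded as (1, 0).
print(result_indices["strict"]["spurious_indices"])

# Human-readable listings for one error schema, overall and aggregated by entity type;
# passing preds lets the report include each entity's label and span.
print(summary_report_overall_indices(result_indices, "strict", preds=pred))
print(summary_report_ents_indices(result_indices_by_tag, "strict", preds=pred))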