Add more information to outputs (add instance indices for error types) #72

Merged · 12 commits · Mar 2, 2024
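This PR makes `compute_metrics()` and `Evaluator.evaluate()` report not only per-error-type counts but also which documents each outcome came from: for every evaluation schema and error type, the index of the document under evaluation is appended to a list. A minimal sketch of the shape of the added structures, using the key names from the diff below (the tag names are illustrative placeholders, not part of the PR):

from copy import deepcopy

# Five index lists, one per error type (key names as in the diff).
eval_indices = {
    "correct_indices": [],
    "incorrect_indices": [],
    "partial_indices": [],
    "missed_indices": [],
    "spurious_indices": [],
}

# One independent copy per evaluation schema...
evaluation_indices = {
    schema: deepcopy(eval_indices)
    for schema in ("strict", "ent_type", "partial", "exact")
}

# ...and the whole structure repeated once per entity label.
tags = ["PER", "LOC"]  # illustrative tags, not from the PR
evaluation_agg_indices = {tag: deepcopy(evaluation_indices) for tag in tags}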
99 changes: 94 additions & 5 deletions src/nervaluate/evaluate.py
@@ -4,6 +4,8 @@

from .utils import conll_to_spans, find_overlap, list_to_spans

logger = logging.getLogger(__name__)


class Evaluator: # pylint: disable=too-many-instance-attributes, too-few-public-methods
def __init__(
@@ -49,6 +51,23 @@ def __init__(

self.loader = loader

self.eval_indices = {
"correct_indices": [],
"incorrect_indices": [],
"partial_indices": [],
"missed_indices": [],
"spurious_indices": [],
}

# Create dicts to hold indices for correct/spurious/missing/etc examples
self.evaluation_indices = {
"strict": deepcopy(self.eval_indices),
"ent_type": deepcopy(self.eval_indices),
"partial": deepcopy(self.eval_indices),
"exact": deepcopy(self.eval_indices),
}
self.evaluation_agg_indices = {e: deepcopy(self.evaluation_indices) for e in tags}

def evaluate(self) -> Tuple[Dict, Dict]:
logging.debug("Imported %s predictions for %s true examples", len(self.pred), len(self.true))

@@ -60,16 +79,22 @@ def evaluate(self) -> Tuple[Dict, Dict]:
if len(self.true) != len(self.pred):
raise ValueError("Number of predicted documents does not equal the number of true documents")

for true_ents, pred_ents in zip(self.true, self.pred):
for index, (true_ents, pred_ents) in enumerate(zip(self.true, self.pred)):
# Compute results for one document
tmp_results, tmp_agg_results = compute_metrics(true_ents, pred_ents, self.tags)
tmp_results, tmp_agg_results, tmp_results_indices, tmp_agg_results_indices = compute_metrics(true_ents, pred_ents, self.tags, index)

# Cycle through each result and accumulate
# TODO: Combine these loops below:
for eval_schema in self.results:
for metric in self.results[eval_schema]:
self.results[eval_schema][metric] += tmp_results[eval_schema][metric]

# Accumulate indices for each error type
for error_type in self.evaluation_indices[eval_schema]:
self.evaluation_indices[eval_schema][error_type] += tmp_results_indices[eval_schema][error_type]

# Calculate global precision and recall
self.results = compute_precision_recall_wrapper(self.results)

@@ -81,17 +106,21 @@ def evaluate(self) -> Tuple[Dict, Dict]:
eval_schema
][metric]

# Accumulate indices for each error type per entity type
for error_type in self.evaluation_agg_indices[label][eval_schema]:
self.evaluation_agg_indices[label][eval_schema][error_type] += tmp_agg_results_indices[label][eval_schema][error_type]

# Calculate precision recall at the individual entity level
self.evaluation_agg_entities_type[label] = compute_precision_recall_wrapper(
self.evaluation_agg_entities_type[label]
)

return self.results, self.evaluation_agg_entities_type
return self.results, self.evaluation_agg_entities_type, self.evaluation_indices, self.evaluation_agg_indices


# flake8: noqa: C901
def compute_metrics( # type: ignore
true_named_entities, pred_named_entities, tags: List[str]
true_named_entities, pred_named_entities, tags: List[str], index: int = 0
): # pylint: disable=too-many-locals, too-many-branches, too-many-statements
"""
Compute metrics on the collected true and predicted named entities
@@ -128,6 +157,23 @@ def compute_metrics( # type: ignore
# results by entity type
evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}

eval_indices = {
"correct_indices": [],
"incorrect_indices": [],
"partial_indices": [],
"missed_indices": [],
"spurious_indices": [],
}

# Create dicts to hold indices for correct/spurious/missing/etc examples
evaluation_indices = {
"strict": deepcopy(eval_indices),
"ent_type": deepcopy(eval_indices),
"partial": deepcopy(eval_indices),
"exact": deepcopy(eval_indices),
}
evaluation_agg_indices = {e: deepcopy(evaluation_indices) for e in tags}

# keep track of entities that overlapped
true_which_overlapped_with_pred = []

@@ -163,12 +209,21 @@ def compute_metrics( # type: ignore
evaluation["ent_type"]["correct"] += 1
evaluation["exact"]["correct"] += 1
evaluation["partial"]["correct"] += 1
evaluation_indices["strict"]["correct_indices"].append(index)
evaluation_indices["ent_type"]["correct_indices"].append(index)
evaluation_indices["exact"]["correct_indices"].append(index)
evaluation_indices["partial"]["correct_indices"].append(index)


# for the agg. by label results
evaluation_agg_entities_type[pred["label"]]["strict"]["correct"] += 1
evaluation_agg_entities_type[pred["label"]]["ent_type"]["correct"] += 1
evaluation_agg_entities_type[pred["label"]]["exact"]["correct"] += 1
evaluation_agg_entities_type[pred["label"]]["partial"]["correct"] += 1
evaluation_agg_indices[pred["label"]]["strict"]["correct_indices"].append(index)
evaluation_agg_indices[pred["label"]]["ent_type"]["correct_indices"].append(index)
evaluation_agg_indices[pred["label"]]["exact"]["correct_indices"].append(index)
evaluation_agg_indices[pred["label"]]["partial"]["correct_indices"].append(index)

else:
# check for overlaps with any of the true entities
@@ -185,13 +240,17 @@ def compute_metrics( # type: ignore
if true["start"] == pred["start"] and pred["end"] == true["end"] and true["label"] != pred["label"]:
# overall results
evaluation["strict"]["incorrect"] += 1
evaluation_indices["strict"]["incorrect_indices"].append(index)
evaluation["ent_type"]["incorrect"] += 1
evaluation_indices["ent_type"]["incorrect_indices"].append(index)
evaluation["partial"]["correct"] += 1
evaluation["exact"]["correct"] += 1

# aggregated by entity type results
evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["strict"]["incorrect_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["ent_type"]["incorrect_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["partial"]["correct"] += 1
evaluation_agg_entities_type[true["label"]]["exact"]["correct"] += 1

@@ -210,15 +269,21 @@ def compute_metrics( # type: ignore
if pred["label"] == true["label"]:
# overall results
evaluation["strict"]["incorrect"] += 1
evaluation_indices["strict"]["incorrect_indices"].append(index)
evaluation["ent_type"]["correct"] += 1
evaluation["partial"]["partial"] += 1
evaluation_indices["partial"]["partial_indices"].append(index)
evaluation["exact"]["incorrect"] += 1
evaluation_indices["exact"]["incorrect_indices"].append(index)

# aggregated by entity type results
evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["strict"]["incorrect_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["ent_type"]["correct"] += 1
evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1
evaluation_agg_indices[true["label"]]["partial"]["partial_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["exact"]["incorrect_indices"].append(index)

found_overlap = True

@@ -228,17 +293,25 @@ def compute_metrics( # type: ignore

# overall results
evaluation["strict"]["incorrect"] += 1
evaluation_indices["strict"]["incorrect_indices"].append(index)
evaluation["ent_type"]["incorrect"] += 1
evaluation_indices["ent_type"]["incorrect_indices"].append(index)
evaluation["partial"]["partial"] += 1
evaluation_indices["partial"]["partial_indices"].append(index)
evaluation["exact"]["incorrect"] += 1
evaluation_indices["exact"]["incorrect_indices"].append(index)

# aggregated by entity type results
# Results against the true entity

evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["strict"]["incorrect_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1
evaluation_agg_indices[true["label"]]["partial"]["partial_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["ent_type"]["incorrect_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1
evaluation_agg_indices[true["label"]]["exact"]["incorrect_indices"].append(index)

# Results against the predicted entity
# evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1
@@ -248,9 +321,13 @@ def compute_metrics( # type: ignore
if not found_overlap:
# Overall results
evaluation["strict"]["spurious"] += 1
evaluation_indices["strict"]["spurious_indices"].append(index)
evaluation["ent_type"]["spurious"] += 1
evaluation_indices["ent_type"]["spurious_indices"].append(index)
evaluation["partial"]["spurious"] += 1
evaluation_indices["partial"]["spurious_indices"].append(index)
evaluation["exact"]["spurious"] += 1
evaluation_indices["exact"]["spurious_indices"].append(index)

# Aggregated by entity type results

@@ -270,9 +347,13 @@ def compute_metrics( # type: ignore

for true in spurious_tags:
evaluation_agg_entities_type[true]["strict"]["spurious"] += 1
evaluation_agg_indices[true]["strict"]["spurious_indices"].append(index)
evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1
evaluation_agg_indices[true]["ent_type"]["spurious_indices"].append(index)
evaluation_agg_entities_type[true]["partial"]["spurious"] += 1
evaluation_agg_indices[true]["partial"]["spurious_indices"].append(index)
evaluation_agg_entities_type[true]["exact"]["spurious"] += 1
evaluation_agg_indices[true]["exact"]["spurious_indices"].append(index)

# Scenario III: Entity was missed entirely.
for true in true_named_entities:
@@ -281,15 +362,23 @@ def compute_metrics( # type: ignore

# overall results
evaluation["strict"]["missed"] += 1
evaluation_indices["strict"]["missed_indices"].append(index)
evaluation["ent_type"]["missed"] += 1
evaluation_indices["ent_type"]["missed_indices"].append(index)
evaluation["partial"]["missed"] += 1
evaluation_indices["partial"]["missed_indices"].append(index)
evaluation["exact"]["missed"] += 1
evaluation_indices["exact"]["missed_indices"].append(index)

# for the agg. by label
evaluation_agg_entities_type[true["label"]]["strict"]["missed"] += 1
evaluation_agg_indices[true["label"]]["strict"]["missed_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["ent_type"]["missed"] += 1
evaluation_agg_indices[true["label"]]["ent_type"]["missed_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["partial"]["missed"] += 1
evaluation_agg_indices[true["label"]]["partial"]["missed_indices"].append(index)
evaluation_agg_entities_type[true["label"]]["exact"]["missed"] += 1
evaluation_agg_indices[true["label"]]["exact"]["missed_indices"].append(index)

# Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
# overall results, and use these to calculate precision and recall.
@@ -305,7 +394,7 @@ def compute_metrics( # type: ignore
for eval_type in entity_level:
evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(entity_level[eval_type])

return evaluation, evaluation_agg_entities_type
return evaluation, evaluation_agg_entities_type, evaluation_indices, evaluation_agg_indices


def compute_actual_possible(results: Dict) -> Dict:
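With this change, `evaluate()` returns four values instead of two. A rough usage sketch, assuming the span-dict input format from nervaluate's README; the documents, tags, and expected outputs are illustrative and not taken from this PR:

from nervaluate import Evaluator

# Two documents with one gold entity each (prodigy-style span dicts).
true = [
    [{"label": "PER", "start": 2, "end": 4}],
    [{"label": "LOC", "start": 1, "end": 2}],
]
pred = [
    [{"label": "PER", "start": 2, "end": 4}],  # exact match
    [{"label": "PER", "start": 1, "end": 2}],  # right boundaries, wrong label
]

evaluator = Evaluator(true, pred, tags=["PER", "LOC"])
results, results_per_tag, result_indices, result_indices_per_tag = evaluator.evaluate()

# Document indices that produced a strict-schema error:
print(result_indices["strict"]["incorrect_indices"])                 # expected: [1]
# The same error, attributed to the gold label:
print(result_indices_per_tag["LOC"]["strict"]["incorrect_indices"])  # expected: [1]

The index lists make it straightforward to pull out the specific documents behind each count for error analysis, which is the motivation stated in the PR title.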