
Addressing issue #65 - Expanded evaluator class: extract results as a Dataframe #80

Merged · 11 commits · Sep 12, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,4 +22,4 @@ classifiers=[

[project.urls]
"Homepage" = "https://github.com/MantisAI/nervaluate"
"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues"
"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues"
47 changes: 47 additions & 0 deletions src/nervaluate/evaluate.py
@@ -1,6 +1,8 @@
import logging
from copy import deepcopy
from typing import List, Dict, Union, Tuple, Optional
import pandas as pd
from collections import defaultdict

from .utils import conll_to_spans, find_overlap, list_to_spans

@@ -118,6 +120,51 @@ def evaluate(self) -> Tuple[Dict, Dict, Dict, Dict]:
)

return self.results, self.evaluation_agg_entities_type, self.evaluation_indices, self.evaluation_agg_indices

# Helper method to flatten a nested dictionary
def _flatten_dict(self, d: Dict, parent_key: str = '', sep: str = '.') -> Dict:
"""
Flattens a nested dictionary.

Args:
d (dict): The dictionary to flatten.
parent_key (str): The base key string to prepend to each dictionary key.
sep (str): The separator to use when combining keys.

Returns:
dict: A flattened dictionary.
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(self._flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)

# Convert the nested results dictionary into a flat pandas DataFrame
def results_to_dataframe(self) -> pd.DataFrame:
if not self.results:
raise ValueError("self.results should be defined.")

if not isinstance(self.results, dict) or not all(isinstance(v, dict) for v in self.results.values()):
raise ValueError("self.results must be a dictionary of dictionaries.")

# Flatten the nested results dictionary, including the 'entities' sub-dictionaries
flattened_results = {}
for outer_key, inner_dict in self.results.items():
flattened_inner_dict = self._flatten_dict(inner_dict)
for inner_key, value in flattened_inner_dict.items():
if inner_key not in flattened_results:
flattened_results[inner_key] = {}
flattened_results[inner_key][outer_key] = value

# Convert the flattened results to a pandas DataFrame
try:
return pd.DataFrame(flattened_results)
except Exception as e:
raise RuntimeError("Error converting flattened results to DataFrame") from e


# flake8: noqa: C901
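Taken together, the new helper and method turn the nested results dictionary into a table whose rows are the evaluation schemas and whose columns are the flattened metric names. A minimal usage sketch, assuming the list-of-tags input and the evaluate()-then-convert flow mirrored from the test below:

from nervaluate import Evaluator

# Two toy documents in list-of-tags format, as in the test case below.
true = [["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]]
pred = [["B-LOC", "I-LOC", "O"], ["B-PER", "I-PER", "O"]]

evaluator = Evaluator(true=true, pred=pred, tags=["LOC", "PER"])
evaluator.evaluate()  # populates evaluator.results

df = evaluator.results_to_dataframe()
# Rows: evaluation schemas (strict, ent_type, partial, exact).
# Columns: flattened metrics such as "precision" or "entities.LOC.correct".
print(df.loc["strict", "precision"])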
110 changes: 110 additions & 0 deletions tests/test_evaluator.py
@@ -1,6 +1,116 @@
# pylint: disable=C0302
from nervaluate import Evaluator
import pandas as pd

def test_results_to_dataframe():
"""
Test the results_to_dataframe method.
"""
# Setup
evaluator = Evaluator(
true=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'O', 'O']],
pred=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'I-PER', 'O']],
tags=['LOC', 'PER']
)

# Mock results data for the purpose of this test
evaluator.results = {
'strict': {
'correct': 10,
'incorrect': 5,
'partial': 3,
'missed': 2,
'spurious': 4,
'precision': 0.625,
'recall': 0.6667,
'f1': 0.6452,
'entities': {
'LOC': {'correct': 4, 'incorrect': 1, 'partial': 0, 'missed': 1, 'spurious': 2},
'PER': {'correct': 3, 'incorrect': 2, 'partial': 1, 'missed': 0, 'spurious': 1},
'ORG': {'correct': 3, 'incorrect': 2, 'partial': 2, 'missed': 1, 'spurious': 1}
}
},
'ent_type': {
'correct': 8,
'incorrect': 4,
'partial': 1,
'missed': 3,
'spurious': 3,
'precision': 0.5714,
'recall': 0.6154,
'f1': 0.5926,
'entities': {
'LOC': {'correct': 3, 'incorrect': 2, 'partial': 1, 'missed': 1, 'spurious': 1},
'PER': {'correct': 2, 'incorrect': 1, 'partial': 0, 'missed': 2, 'spurious': 0},
'ORG': {'correct': 3, 'incorrect': 1, 'partial': 0, 'missed': 0, 'spurious': 2}
}
},
'partial': {
'correct': 7,
'incorrect': 3,
'partial': 4,
'missed': 1,
'spurious': 5,
'precision': 0.5385,
'recall': 0.6364,
'f1': 0.5833,
'entities': {
'LOC': {'correct': 2, 'incorrect': 1, 'partial': 1, 'missed': 1, 'spurious': 2},
'PER': {'correct': 3, 'incorrect': 1, 'partial': 1, 'missed': 0, 'spurious': 1},
'ORG': {'correct': 2, 'incorrect': 1, 'partial': 2, 'missed': 0, 'spurious': 2}
}
},
'exact': {
'correct': 9,
'incorrect': 6,
'partial': 2,
'missed': 2,
'spurious': 2,
'precision': 0.6,
'recall': 0.6429,
'f1': 0.6207,
'entities': {
'LOC': {'correct': 4, 'incorrect': 1, 'partial': 0, 'missed': 1, 'spurious': 1},
'PER': {'correct': 3, 'incorrect': 3, 'partial': 0, 'missed': 0, 'spurious': 0},
'ORG': {'correct': 2, 'incorrect': 2, 'partial': 2, 'missed': 1, 'spurious': 1}
}
}
}

# Expected DataFrame
expected_data = {
'correct': {'strict': 10, 'ent_type': 8, 'partial': 7, 'exact': 9},
'incorrect': {'strict': 5, 'ent_type': 4, 'partial': 3, 'exact': 6},
'partial': {'strict': 3, 'ent_type': 1, 'partial': 4, 'exact': 2},
'missed': {'strict': 2, 'ent_type': 3, 'partial': 1, 'exact': 2},
'spurious': {'strict': 4, 'ent_type': 3, 'partial': 5, 'exact': 2},
'precision': {'strict': 0.625, 'ent_type': 0.5714, 'partial': 0.5385, 'exact': 0.6},
'recall': {'strict': 0.6667, 'ent_type': 0.6154, 'partial': 0.6364, 'exact': 0.6429},
'f1': {'strict': 0.6452, 'ent_type': 0.5926, 'partial': 0.5833, 'exact': 0.6207},
'entities.LOC.correct': {'strict': 4, 'ent_type': 3, 'partial': 2, 'exact': 4},
'entities.LOC.incorrect': {'strict': 1, 'ent_type': 2, 'partial': 1, 'exact': 1},
'entities.LOC.partial': {'strict': 0, 'ent_type': 1, 'partial': 1, 'exact': 0},
'entities.LOC.missed': {'strict': 1, 'ent_type': 1, 'partial': 1, 'exact': 1},
'entities.LOC.spurious': {'strict': 2, 'ent_type': 1, 'partial': 2, 'exact': 1},
'entities.PER.correct': {'strict': 3, 'ent_type': 2, 'partial': 3, 'exact': 3},
'entities.PER.incorrect': {'strict': 2, 'ent_type': 1, 'partial': 1, 'exact': 3},
'entities.PER.partial': {'strict': 1, 'ent_type': 0, 'partial': 1, 'exact': 0},
'entities.PER.missed': {'strict': 0, 'ent_type': 2, 'partial': 0, 'exact': 0},
'entities.PER.spurious': {'strict': 1, 'ent_type': 0, 'partial': 1, 'exact': 0},
'entities.ORG.correct': {'strict': 3, 'ent_type': 3, 'partial': 2, 'exact': 2},
'entities.ORG.incorrect': {'strict': 2, 'ent_type': 1, 'partial': 1, 'exact': 2},
'entities.ORG.partial': {'strict': 2, 'ent_type': 0, 'partial': 2, 'exact': 2},
'entities.ORG.missed': {'strict': 1, 'ent_type': 0, 'partial': 0, 'exact': 1},
'entities.ORG.spurious': {'strict': 1, 'ent_type': 2, 'partial': 2, 'exact': 1}
}

expected_df = pd.DataFrame(expected_data)

# Execute
result_df = evaluator.results_to_dataframe()

# Assert
pd.testing.assert_frame_equal(result_df, expected_df)

def test_evaluator_simple_case():
true = [
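As a design note (not part of this PR), pandas' built-in json_normalize could produce the same flat column names; a rough, self-contained sketch using a hypothetical miniature of evaluator.results:

import pandas as pd

# Hypothetical miniature of evaluator.results, only to illustrate the shape.
results = {
    "strict": {"correct": 10, "precision": 0.625,
               "entities": {"LOC": {"correct": 4}}},
    "exact": {"correct": 9, "precision": 0.6,
              "entities": {"LOC": {"correct": 4}}},
}

# json_normalize flattens one schema at a time; the dict key labels the row.
rows = {
    schema: pd.json_normalize(metrics, sep=".").iloc[0]
    for schema, metrics in results.items()
}
df = pd.DataFrame(rows).T  # schemas as rows, flattened metric names as columns
print(df.columns.tolist())  # ['correct', 'precision', 'entities.LOC.correct']

Flattening each schema separately and keeping the schemas as rows, as the PR does, gives the same orientation without depending on json_normalize's behaviour.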