Merge pull request #80 from adgianv/expand_evaluator_class
Addressing issue #65 - Expanded evaluator class: extract results as a Dataframe
davidsbatista authored Sep 12, 2024
2 parents 97da28e + 98e3950 commit df0e695
Showing 4 changed files with 166 additions and 6 deletions.
8 changes: 5 additions & 3 deletions pyproject.toml
@@ -2,7 +2,6 @@
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"


[project]
name = "nervaluate"
version = "0.2.0"
@@ -15,11 +14,14 @@ readme = "README.md"
requires-python = ">=3.8"
keywords = ["named-entity-recognition", "ner", "evaluation-metrics", "partial-match-scoring", "nlp"]
license = {text = "MIT License"}
classifiers=[
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent"
]

[project.dependencies]
pandas = "==2.0.1"

[project.urls]
"Homepage" = "https://github.com/MantisAI/nervaluate"
"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues"
"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues"
3 changes: 2 additions & 1 deletion requirements_dev.txt
@@ -5,4 +5,5 @@ gitchangelog
mypy==1.3.0
pre-commit==3.3.1
pylint==2.17.4
pytest==7.3.1
pytest==7.3.1
pandas==2.0.1
49 changes: 48 additions & 1 deletion src/nervaluate/evaluate.py
@@ -1,6 +1,8 @@
import logging
from copy import deepcopy
from typing import List, Dict, Union, Tuple, Optional
import pandas as pd
from typing import List, Dict, Union, Tuple, Optional, Any
from collections import defaultdict

from .utils import conll_to_spans, find_overlap, list_to_spans

@@ -118,6 +120,51 @@ def evaluate(self) -> Tuple[Dict, Dict, Dict, Dict]:
)

return self.results, self.evaluation_agg_entities_type, self.evaluation_indices, self.evaluation_agg_indices

# Helper method to flatten a nested dictionary
def _flatten_dict(self, d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
"""
Flattens a nested dictionary.
Args:
d (dict): The dictionary to flatten.
parent_key (str): The base key string to prepend to each dictionary key.
sep (str): The separator to use when combining keys.
Returns:
dict: A flattened dictionary.
"""
items: List[Tuple[str, Any]] = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(self._flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)

# Modified results_to_dataframe method using the helper method
def results_to_dataframe(self) -> Any:
if not self.results:
raise ValueError("self.results should be defined.")

if not isinstance(self.results, dict) or not all(isinstance(v, dict) for v in self.results.values()):
raise ValueError("self.results must be a dictionary of dictionaries.")

# Flatten the nested results dictionary, including the 'entities' sub-dictionaries
flattened_results: Dict[str, Dict[str, Any]] = {}
for outer_key, inner_dict in self.results.items():
flattened_inner_dict = self._flatten_dict(inner_dict)
for inner_key, value in flattened_inner_dict.items():
if inner_key not in flattened_results:
flattened_results[inner_key] = {}
flattened_results[inner_key][outer_key] = value

# Convert the flattened results to a pandas DataFrame
try:
return pd.DataFrame(flattened_results)
except Exception as e:
raise RuntimeError("Error converting flattened results to DataFrame") from e


# flake8: noqa: C901
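For context, a minimal usage sketch of the new method (not part of the diff): it mirrors the construction used in tests/test_evaluator.py below and assumes evaluate() has been called first so that self.results is populated.

from nervaluate import Evaluator

evaluator = Evaluator(
    true=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'O', 'O']],
    pred=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'I-PER', 'O']],
    tags=['LOC', 'PER'],
)
evaluator.evaluate()  # populates evaluator.results with the four schemas

df = evaluator.results_to_dataframe()
# Rows are the evaluation schemas ('strict', 'ent_type', 'partial', 'exact');
# columns are the flattened result keys, e.g. 'precision', 'recall', 'f1',
# plus dotted keys such as 'entities.LOC.correct' wherever nested
# sub-dictionaries are present (the convention exercised by the test below).
print(df[['precision', 'recall', 'f1']])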
112 changes: 111 additions & 1 deletion tests/test_evaluator.py
@@ -1,6 +1,116 @@
# pylint: disable=C0302
# pylint: disable=too-many-lines
import pandas as pd
from nervaluate import Evaluator

def test_results_to_dataframe():
"""
Test the results_to_dataframe method.
"""
# Setup
evaluator = Evaluator(
true=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'O', 'O']],
pred=[['B-LOC', 'I-LOC', 'O'], ['B-PER', 'I-PER', 'O']],
tags=['LOC', 'PER']
)

# Mock results data for the purpose of this test
evaluator.results = {
'strict': {
'correct': 10,
'incorrect': 5,
'partial': 3,
'missed': 2,
'spurious': 4,
'precision': 0.625,
'recall': 0.6667,
'f1': 0.6452,
'entities': {
'LOC': {'correct': 4, 'incorrect': 1, 'partial': 0, 'missed': 1, 'spurious': 2},
'PER': {'correct': 3, 'incorrect': 2, 'partial': 1, 'missed': 0, 'spurious': 1},
'ORG': {'correct': 3, 'incorrect': 2, 'partial': 2, 'missed': 1, 'spurious': 1}
}
},
'ent_type': {
'correct': 8,
'incorrect': 4,
'partial': 1,
'missed': 3,
'spurious': 3,
'precision': 0.5714,
'recall': 0.6154,
'f1': 0.5926,
'entities': {
'LOC': {'correct': 3, 'incorrect': 2, 'partial': 1, 'missed': 1, 'spurious': 1},
'PER': {'correct': 2, 'incorrect': 1, 'partial': 0, 'missed': 2, 'spurious': 0},
'ORG': {'correct': 3, 'incorrect': 1, 'partial': 0, 'missed': 0, 'spurious': 2}
}
},
'partial': {
'correct': 7,
'incorrect': 3,
'partial': 4,
'missed': 1,
'spurious': 5,
'precision': 0.5385,
'recall': 0.6364,
'f1': 0.5833,
'entities': {
'LOC': {'correct': 2, 'incorrect': 1, 'partial': 1, 'missed': 1, 'spurious': 2},
'PER': {'correct': 3, 'incorrect': 1, 'partial': 1, 'missed': 0, 'spurious': 1},
'ORG': {'correct': 2, 'incorrect': 1, 'partial': 2, 'missed': 0, 'spurious': 2}
}
},
'exact': {
'correct': 9,
'incorrect': 6,
'partial': 2,
'missed': 2,
'spurious': 2,
'precision': 0.6,
'recall': 0.6429,
'f1': 0.6207,
'entities': {
'LOC': {'correct': 4, 'incorrect': 1, 'partial': 0, 'missed': 1, 'spurious': 1},
'PER': {'correct': 3, 'incorrect': 3, 'partial': 0, 'missed': 0, 'spurious': 0},
'ORG': {'correct': 2, 'incorrect': 2, 'partial': 2, 'missed': 1, 'spurious': 1}
}
}
}

# Expected DataFrame
expected_data = {
'correct': {'strict': 10, 'ent_type': 8, 'partial': 7, 'exact': 9},
'incorrect': {'strict': 5, 'ent_type': 4, 'partial': 3, 'exact': 6},
'partial': {'strict': 3, 'ent_type': 1, 'partial': 4, 'exact': 2},
'missed': {'strict': 2, 'ent_type': 3, 'partial': 1, 'exact': 2},
'spurious': {'strict': 4, 'ent_type': 3, 'partial': 5, 'exact': 2},
'precision': {'strict': 0.625, 'ent_type': 0.5714, 'partial': 0.5385, 'exact': 0.6},
'recall': {'strict': 0.6667, 'ent_type': 0.6154, 'partial': 0.6364, 'exact': 0.6429},
'f1': {'strict': 0.6452, 'ent_type': 0.5926, 'partial': 0.5833, 'exact': 0.6207},
'entities.LOC.correct': {'strict': 4, 'ent_type': 3, 'partial': 2, 'exact': 4},
'entities.LOC.incorrect': {'strict': 1, 'ent_type': 2, 'partial': 1, 'exact': 1},
'entities.LOC.partial': {'strict': 0, 'ent_type': 1, 'partial': 1, 'exact': 0},
'entities.LOC.missed': {'strict': 1, 'ent_type': 1, 'partial': 1, 'exact': 1},
'entities.LOC.spurious': {'strict': 2, 'ent_type': 1, 'partial': 2, 'exact': 1},
'entities.PER.correct': {'strict': 3, 'ent_type': 2, 'partial': 3, 'exact': 3},
'entities.PER.incorrect': {'strict': 2, 'ent_type': 1, 'partial': 1, 'exact': 3},
'entities.PER.partial': {'strict': 1, 'ent_type': 0, 'partial': 1, 'exact': 0},
'entities.PER.missed': {'strict': 0, 'ent_type': 2, 'partial': 0, 'exact': 0},
'entities.PER.spurious': {'strict': 1, 'ent_type': 0, 'partial': 1, 'exact': 0},
'entities.ORG.correct': {'strict': 3, 'ent_type': 3, 'partial': 2, 'exact': 2},
'entities.ORG.incorrect': {'strict': 2, 'ent_type': 1, 'partial': 1, 'exact': 2},
'entities.ORG.partial': {'strict': 2, 'ent_type': 0, 'partial': 2, 'exact': 2},
'entities.ORG.missed': {'strict': 1, 'ent_type': 0, 'partial': 0, 'exact': 1},
'entities.ORG.spurious': {'strict': 1, 'ent_type': 2, 'partial': 2, 'exact': 1}
}

expected_df = pd.DataFrame(expected_data)

# Execute
result_df = evaluator.results_to_dataframe()

# Assert
pd.testing.assert_frame_equal(result_df, expected_df)

def test_evaluator_simple_case():
true = [
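The new test can be run on its own with the pytest version pinned in requirements_dev.txt, e.g.:

pytest tests/test_evaluator.py::test_results_to_dataframe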
