Merge pull request #4 from yngvem/update-confusion-matrix
Add extra utilities for the confusion matrix
MarieRoald authored Oct 13, 2024
2 parents e6aa545 + eb37fcf commit 1da84ee
Showing 8 changed files with 121 additions and 9 deletions.
55 changes: 46 additions & 9 deletions python/stringalign/statistics.py
@@ -4,14 +4,20 @@
from numbers import Number
from typing import Self, cast

from stringalign.align import AlignmentOperation, Keep, Replace
from stringalign.align import AlignmentOperation, Keep, Replace, aggregate_alignment, align_strings
from stringalign.tokenize import Tokenizer


@dataclass
def sort_by_values(d: dict[str, float], reverse=False) -> dict[str, float]:
return dict(sorted(d.items(), key=lambda x: x[1], reverse=reverse))


@dataclass(eq=True)
class StringConfusionMatrix:
true_positives: Counter[str]
false_positives: Counter[str] # Added characters
false_negatives: Counter[str] # Removed/missed characters
edit_counts: Counter[AlignmentOperation] # Count of each operation type
    # There are no true negatives when we compare strings.
    # Either a character is in the string or it is not.

@@ -28,12 +34,15 @@ def from_strings_and_alignment(
true_positives: Counter[str] = Counter()
false_positives: Counter[str] = Counter()
false_negatives: Counter[str] = Counter()
edit_counts: Counter[AlignmentOperation] = Counter()
for op in alignment:
if isinstance(op, Keep):
true_positives[next(ref_iter)] += 1
next(pred_iter)
continue

edit_counts[op] += 1

op = cast(Replace, op.generalize())
for char in op.substring:
false_positives[char] += 1
@@ -46,6 +55,25 @@ def from_strings_and_alignment(
true_positives=true_positives,
false_positives=false_positives,
false_negatives=false_negatives,
edit_counts=edit_counts,
)

@classmethod
def from_strings(
cls, reference: str, predicted: str, tokenizer: Tokenizer | None = None, aggregate: bool = False
) -> Self:
alignment = align_strings(reference, predicted, tokenizer=tokenizer)
if aggregate:
alignment = list(aggregate_alignment(alignment))
return cls.from_strings_and_alignment(reference, predicted, alignment)

@classmethod
def get_empty(cls) -> Self:
return cls(
true_positives=Counter(),
false_positives=Counter(),
false_negatives=Counter(),
edit_counts=Counter(),
)

def compute_true_positive_rate(self, aggregate_over: str | None = None) -> dict[str, float] | float:
@@ -57,7 +85,7 @@ def compute_true_positive_rate(self, aggregate_over: str | None = None) -> dict[
return tp / (tp + fn)

char_count = self.true_positives + self.false_negatives
return {key: self.true_positives[key] / char_count[key] for key in char_count}
return sort_by_values({key: self.true_positives[key] / char_count[key] for key in char_count}, reverse=True)

compute_recall = compute_true_positive_rate
compute_sensitivity = compute_true_positive_rate
@@ -74,7 +102,9 @@ def compute_positive_predictive_value(self, aggregate_over: str | None = None) -
return tp / (tp + fp)

predicted_positive = self.true_positives + self.false_positives
return {key: self.true_positives[key] / predicted_positive[key] for key in self.true_positives}
return sort_by_values(
{key: self.true_positives[key] / predicted_positive[key] for key in self.true_positives}, reverse=True
)

compute_precision = compute_positive_predictive_value

@@ -90,7 +120,9 @@ def compute_false_discovery_rate(self, aggregate_over: str | None = None) -> dic
return fp / (tp + fp)

predicted_positive = self.true_positives + self.false_positives
return {key: self.false_positives[key] / predicted_positive[key] for key in self.false_positives}
return sort_by_values(
{key: self.false_positives[key] / predicted_positive[key] for key in self.false_positives}, reverse=True
)

def compute_f1_score(self, aggregate_over: str | None = None) -> dict[str, float] | float:
"""The harmonic mean of the true positive rate and positive predictive value."""
@@ -104,10 +136,14 @@ def compute_f1_score(self, aggregate_over: str | None = None) -> dict[str, float
assert isinstance(tpr, dict) and isinstance(ppv, dict)
all_chars = set(self.true_positives) | set(self.false_positives) | set(self.false_negatives)
tpr, ppv = defaultdict(int, tpr), defaultdict(int, ppv)
return {
c: (tpr[c] * ppv[c]) / (0.5 * (tpr[c] + ppv[c] or 1)) # or 1 avoids division by 0, the value is 0 anyways
for c in all_chars
}
return sort_by_values(
{
c: (tpr[c] * ppv[c])
/ (0.5 * (tpr[c] + ppv[c] or 1)) # or 1 avoids division by 0, the value is 0 anyways
for c in all_chars
},
reverse=True,
)

compute_dice = compute_f1_score

@@ -118,6 +154,7 @@ def __add__(self, other: Self) -> Self:
true_positives=self.true_positives + other.true_positives,
false_positives=self.false_positives + other.false_positives,
false_negatives=self.false_negatives + other.false_negatives,
edit_counts=self.edit_counts + other.edit_counts,
)

__radd__ = __add__
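
With the new from_strings constructor and edit_counts field in place, a confusion matrix can be built straight from a reference/prediction pair, and the per-character metric dictionaries now come back sorted from highest to lowest value. A minimal usage sketch, assuming the stringalign API exactly as shown in this diff (the example strings are only illustrative):

from stringalign.statistics import StringConfusionMatrix

# Build a confusion matrix directly from two strings; the alignment is computed internally.
cm = StringConfusionMatrix.from_strings("abcbaa", "acdeai")

# Per-character recall and precision, now returned as value-sorted dictionaries.
print(cm.compute_recall())
print(cm.compute_precision())

# The new edit_counts field tallies every Insert/Delete/Replace operation in the alignment.
print(cm.edit_counts)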
26 changes: 26 additions & 0 deletions tests/statistics/ConfusionMatrix/test_add.py
@@ -1,5 +1,6 @@
from collections import Counter

from stringalign.align import Delete, Insert, Keep, Replace
from stringalign.statistics import StringConfusionMatrix


@@ -8,14 +9,39 @@ def test__add__() -> None:
true_positives=Counter({"a": 3, "b": 2, "c": 1}),
false_positives=Counter({"a": 1, "b": 1, "d": 1}),
false_negatives=Counter({"a": 1, "c": 1, "e": 1}),
edit_counts=Counter(
{
Keep("a"): 3,
Keep("b"): 2,
Keep("c"): 1,
Delete("a"): 1,
Delete("b"): 1,
Insert("a"): 1,
Insert("c"): 1,
Replace("d", "e"): 1,
}
),
)
cm2 = StringConfusionMatrix(
true_positives=Counter({"a": 3, "b": 2, "d": 1}),
false_positives=Counter({"a": 1, "b": 1, "f": 1}),
false_negatives=Counter({"a": 1, "c": 1, "g": 1}),
edit_counts=Counter(
{
Keep("a"): 3,
Keep("b"): 2,
Keep("d"): 1,
Delete("a"): 1,
Delete("b"): 1,
Insert("a"): 1,
Insert("c"): 1,
Replace("f", "g"): 1,
}
),
)

cm3 = cm1 + cm2
assert cm3.true_positives == cm1.true_positives + cm2.true_positives
assert cm3.false_positives == cm1.false_positives + cm2.false_positives
assert cm3.false_negatives == cm1.false_negatives + cm2.false_negatives
assert cm3.edit_counts == cm1.edit_counts + cm2.edit_counts
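
Because __add__ (and the __radd__ alias) now also merges edit_counts, per-line matrices can be folded into a corpus-level matrix. A sketch under the same assumptions; get_empty() is passed as the start value so the fold does not depend on how __add__ treats the integer 0 that sum() would otherwise use:

from stringalign.statistics import StringConfusionMatrix

pairs = [("first reference", "first prediction"), ("second reference", "second prediction")]
matrices = [StringConfusionMatrix.from_strings(reference, predicted) for reference, predicted in pairs]

# get_empty() acts as a neutral element, so this also works for an empty list of pairs.
total = sum(matrices, start=StringConfusionMatrix.get_empty())
print(total.compute_f1_score())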
1 change: 1 addition & 0 deletions tests/statistics/ConfusionMatrix/test_compute_f1_score.py
@@ -10,6 +10,7 @@ def sample_confusion_matrix() -> StringConfusionMatrix:
true_positives=Counter({"a": 3, "b": 2, "c": 1}),
false_positives=Counter({"a": 1, "b": 1, "d": 1}),
false_negatives=Counter({"a": 1, "c": 1, "e": 1}),
edit_counts=Counter(),
)


@@ -10,6 +10,7 @@ def sample_confusion_matrix() -> StringConfusionMatrix:
true_positives=Counter({"a": 3, "b": 2, "c": 1}),
false_positives=Counter({"a": 1, "b": 1, "d": 1}),
false_negatives=Counter({"a": 1, "c": 1, "e": 1}),
edit_counts=Counter(),
)


@@ -10,6 +10,7 @@ def sample_confusion_matrix() -> StringConfusionMatrix:
true_positives=Counter({"a": 3, "b": 2, "c": 1}),
false_positives=Counter({"a": 1, "b": 1, "d": 1}),
false_negatives=Counter({"a": 1, "c": 1, "e": 1}),
edit_counts=Counter(),
)


24 changes: 24 additions & 0 deletions tests/statistics/ConfusionMatrix/test_from_strings.py
@@ -0,0 +1,24 @@
from collections import Counter

from stringalign.align import Delete, Insert, Keep, Replace, align_strings
from stringalign.statistics import StringConfusionMatrix


def test_from_strings() -> None:
reference = "abcbaa"
predicted = "acdeai"
alignment = align_strings(reference=reference, predicted=predicted)

result1 = StringConfusionMatrix.from_strings_and_alignment(reference, predicted, alignment)
result2 = StringConfusionMatrix.from_strings(reference, predicted)

assert result1 == result2


def test_from_strings_empty() -> None:
result = StringConfusionMatrix.from_strings("", "")

assert result.true_positives == Counter()
assert result.false_positives == Counter()
assert result.false_negatives == Counter()
assert result.edit_counts == Counter()
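
The aggregate flag of from_strings is not exercised by these tests. A hypothetical call, assuming (based only on the function name) that aggregate_alignment merges adjacent edit operations before they are counted:

from stringalign.statistics import StringConfusionMatrix

# aggregate=True routes the alignment through aggregate_alignment before counting;
# the exact contents of edit_counts then depend on how that aggregation behaves.
cm = StringConfusionMatrix.from_strings("abcbaa", "acdeai", aggregate=True)
print(cm.edit_counts)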
@@ -14,6 +14,14 @@ def test_from_strings_and_alignment() -> None:
assert result.true_positives == Counter({"a": 2, "c": 1})
assert result.false_positives == Counter({"d": 1, "e": 1, "i": 1})
assert result.false_negatives == Counter({"b": 2, "a": 1})
assert result.edit_counts == Counter(
{
Insert("b"): 1,
Replace("d", "b"): 1,
Replace("e", "a"): 1,
Delete("i"): 1,
}
)


def test_from_strings_and_alignment_empty() -> None:
@@ -22,3 +30,4 @@ def test_from_strings_and_alignment_empty() -> None:
assert result.true_positives == Counter()
assert result.false_positives == Counter()
assert result.false_negatives == Counter()
assert result.edit_counts == Counter()
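
The edit_counts field also lends itself to quick error reporting after an evaluation run; Counter.most_common is standard-library behaviour, the rest assumes the API from this diff:

from stringalign.statistics import StringConfusionMatrix

cm = StringConfusionMatrix.from_strings("abcbaa", "acdeai")

# List the most frequent edit operations first to see which mistakes dominate.
for operation, count in cm.edit_counts.most_common(5):
    print(count, operation)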
13 changes: 13 additions & 0 deletions tests/statistics/ConfusionMatrix/test_get_empty.py
@@ -0,0 +1,13 @@
from collections import Counter

from stringalign.align import Delete, Insert, Keep, Replace, align_strings
from stringalign.statistics import StringConfusionMatrix


def test_get_empty() -> None:
result = StringConfusionMatrix.get_empty()

assert result.true_positives == Counter()
assert result.false_positives == Counter()
assert result.false_negatives == Counter()
assert result.edit_counts == Counter()
