diff --git a/python/stringalign/evaluation.py b/python/stringalign/evaluation.py
index 0aa4dc3..19f8f66 100644
--- a/python/stringalign/evaluation.py
+++ b/python/stringalign/evaluation.py
@@ -1,4 +1,4 @@
-from collections import Counter, deque
+from collections import Counter, defaultdict, deque
 from collections.abc import Generator, Mapping
 from dataclasses import dataclass
 from itertools import chain
@@ -210,6 +210,34 @@ def alignment_operators(self) -> Counter[AlignmentOperation]:
     def confusion_matrix(self) -> StringConfusionMatrix:
         return sum((le.confusion_matrix for le in self.line_errors), start=StringConfusionMatrix.get_empty())
 
+    @property
+    def line_error_raw_lookup(self) -> dict[AlignmentOperation, frozenset[LineError]]:
+        """Map each raw alignment operation to the line errors that contain it.
+
+        Every operation found in a line error's ``raw_alignment`` becomes a key;
+        its value is the frozen set of all line errors in which it occurs.
+        """
+        out: defaultdict[AlignmentOperation, set[LineError]] = defaultdict(set)
+        for line_error in self.line_errors:
+            for alignment_op in line_error.raw_alignment:
+                out[alignment_op].add(line_error)
+
+        return {k: frozenset(v) for k, v in out.items()}
+
+    @property
+    def line_error_aggregated_lookup(self) -> dict[AlignmentOperation, frozenset[LineError]]:
+        """Map each alignment operation to the line errors that contain it.
+
+        Every operation found in a line error's ``alignment`` becomes a key;
+        its value is the frozen set of all line errors in which it occurs.
+        """
+        out: defaultdict[AlignmentOperation, set[LineError]] = defaultdict(set)
+        for line_error in self.line_errors:
+            for alignment_op in line_error.alignment:
+                out[alignment_op].add(line_error)
+
+        return {k: frozenset(v) for k, v in out.items()}
+
     @classmethod
     def from_strings(
         cls,
diff --git a/tests/evaluation/TranscriptionEvaluator/test_line_error_aggregated_lookup.py b/tests/evaluation/TranscriptionEvaluator/test_line_error_aggregated_lookup.py
new file mode 100644
index 0000000..9a93f44
--- /dev/null
+++ b/tests/evaluation/TranscriptionEvaluator/test_line_error_aggregated_lookup.py
@@ -0,0 +1,13 @@
+from stringalign.align import Replace
+from stringalign.evaluation import TranscriptionEvaluator
+
+
+def test_aggregated_lookup():
+    """Each checked operation from ``alignment`` maps to exactly the line errors containing it."""
+    evaluator = TranscriptionEvaluator.from_strings(
+        references=["abc", "def", "aaa"],
+        predictions=["bbc", "deg", "abb"],
+    )
+    assert evaluator.line_error_aggregated_lookup[Replace("b", "a")] == frozenset({evaluator.line_errors[0]})
+    assert evaluator.line_error_aggregated_lookup[Replace("g", "f")] == frozenset({evaluator.line_errors[1]})
+    assert evaluator.line_error_aggregated_lookup[Replace("bb", "aa")] == frozenset({evaluator.line_errors[2]})
diff --git a/tests/evaluation/TranscriptionEvaluator/test_line_error_raw_lookup.py b/tests/evaluation/TranscriptionEvaluator/test_line_error_raw_lookup.py
new file mode 100644
index 0000000..b5c9424
--- /dev/null
+++ b/tests/evaluation/TranscriptionEvaluator/test_line_error_raw_lookup.py
@@ -0,0 +1,18 @@
+from stringalign.align import Replace
+from stringalign.evaluation import TranscriptionEvaluator
+
+
+def test_raw_lookup():
+    """Each checked operation from ``raw_alignment`` maps to exactly the line errors containing it."""
+    evaluator = TranscriptionEvaluator.from_strings(
+        references=["abc", "def", "aaa"],
+        predictions=["bbc", "deg", "abb"],
+    )
+    assert evaluator.line_error_raw_lookup[Replace("b", "a")] == frozenset(
+        {evaluator.line_errors[0], evaluator.line_errors[2]}
+    )
+    assert evaluator.line_error_raw_lookup[Replace("g", "f")] == frozenset(
+        {
+            evaluator.line_errors[1],
+        }
+    )