-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
37 lines (31 loc) · 1.38 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from pyannote.metrics.diarization import DiarizationErrorRate, JaccardErrorRate
from pyannote.core import Annotation, Segment
class DiarizationReference():
    """Score diarization hypotheses against a single reference RTTM file.

    The reference path is stored at construction time and is loaded from
    disk each time ``calculate_metrics`` is called.
    """

    # A SPEAKER record needs at least this many whitespace-separated fields
    # so that the speaker-name field (index 7) exists.
    _MIN_SPEAKER_FIELDS = 8

    def __init__(self, reference_rttm):
        """Remember the reference RTTM path.

        Args:
            reference_rttm: Path to the reference RTTM file. It is not opened
                here; it is read by ``calculate_metrics``.
        """
        self.reference_rttm = reference_rttm

    @staticmethod
    def _iter_rttm_turns(lines, source="<rttm>"):
        """Yield ``(start, end, speaker)`` for each SPEAKER record in *lines*.

        Field 3 is the turn onset, field 4 its duration and field 7 the
        speaker name. Blank lines and records whose first field is not
        ``SPEAKER`` (e.g. ``SPKR-INFO`` records or ``;;`` comments) are
        skipped.

        Args:
            lines: Iterable of text lines (an open text file works).
            source: Name used in error messages, typically the file path.

        Yields:
            Tuples ``(start, end, speaker)`` with ``end = start + duration``.

        Raises:
            ValueError: If a SPEAKER line has too few fields or a
                non-numeric onset/duration.
        """
        for lineno, line in enumerate(lines, start=1):
            fields = line.split()
            # Only SPEAKER records describe speaker turns.
            if not fields or fields[0] != "SPEAKER":
                continue
            if len(fields) < DiarizationReference._MIN_SPEAKER_FIELDS:
                raise ValueError(
                    f"{source}:{lineno}: malformed RTTM SPEAKER line "
                    f"(expected at least "
                    f"{DiarizationReference._MIN_SPEAKER_FIELDS} fields): "
                    f"{line.rstrip()!r}"
                )
            try:
                start = float(fields[3])
                duration = float(fields[4])
            except ValueError as err:
                raise ValueError(
                    f"{source}:{lineno}: invalid onset/duration in RTTM "
                    f"line: {line.rstrip()!r}"
                ) from err
            yield start, start + duration, fields[7]

    def _load_rttm_to_annotation(self, rttm_file) -> "Annotation":
        """Load an RTTM file into a pyannote ``Annotation`` for scoring.

        Args:
            rttm_file: Path to an RTTM file.

        Returns:
            An ``Annotation`` with one labelled track per SPEAKER record.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If a SPEAKER line is malformed.
        """
        annotation = Annotation()
        with open(rttm_file, "r", encoding="utf-8") as file:
            turns = self._iter_rttm_turns(file, source=str(rttm_file))
            # Give each turn its own track id: assigning with a bare
            # segment key reuses one default track, so two speakers sharing
            # the exact same segment would overwrite each other.
            for track, (start, end, speaker) in enumerate(turns):
                annotation[Segment(start, end), track] = speaker
        return annotation

    def calculate_metrics(self, hypotheses_rttms, type):
        """Score each hypothesis RTTM against the reference RTTM.

        Args:
            hypotheses_rttms: Iterable of hypothesis RTTM paths.
            type: ``'der'`` for diarization error rate or ``'jer'`` for
                Jaccard error rate. (The name shadows the builtin; it is
                kept for compatibility with existing keyword callers.)

        Returns:
            A list with one value per hypothesis, in input order. Each value
            is what the pyannote metric object returns for the
            (reference, hypothesis) pair.

        Raises:
            ValueError: If *type* is not ``'der'`` or ``'jer'``, or an RTTM
                line is malformed.
        """
        # Validate before doing any file I/O so a bad metric name fails fast.
        if type == 'der':
            metric = DiarizationErrorRate()
        elif type == 'jer':
            metric = JaccardErrorRate()
        else:
            raise ValueError(
                f"Unsupported metric type {type!r}; expected 'der' or 'jer'."
            )
        reference = self._load_rttm_to_annotation(self.reference_rttm)
        return [
            metric(reference, self._load_rttm_to_annotation(hypothesis_rttm))
            for hypothesis_rttm in hypotheses_rttms
        ]