salmon.py
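"""Evaluate an inference model on the SALMon benchmark.

For each requested SALMon part, the script compares the model's log-likelihood
on paired positive/negative recordings and reports the fraction of pairs where
the positive recording is preferred (ties count as 0.5).
"""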
import argparse
import json
from pathlib import Path

import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset

from baselines.inference import InferenceModelFactory

class SalmonDataset(Dataset):
    """Pairs of (positive, negative) audio samples for one SALMon part."""

    def __init__(self, salmon_path, part, load_audio=True):
        self.data = []
        self.load_audio = load_audio
        # Fall back to Hugging Face if no local data path was specified.
        self.use_hugging_face = salmon_path is None
        if self.use_hugging_face:
            print("Salmon path not specified, downloading dataset from Hugging Face ...")
            assert load_audio is True, "load_audio must be True when using Hugging Face"
            from datasets import load_dataset

            salmon = load_dataset("slprl/salmon", part)
            self.data = [
                [s["positive_audio"]["array"], s["negative_audio"]["array"]]
                for s in salmon["train"]
            ]
        else:
            salmon_path = Path(salmon_path)
            dir_path = salmon_path / part
            paths = list(dir_path.glob("*.wav"))

            # The sample index is the second "_"-separated token of the file stem;
            # find the largest index first so one slot can be allocated per sample.
            max_sample_index = -1
            for path in paths:
                sample_index = int(path.stem.split("_")[1])
                max_sample_index = max(max_sample_index, sample_index)

            # Group the files belonging to each sample and sort them so the
            # ordering within a pair is deterministic.
            self.data = [[] for _ in range(max_sample_index + 1)]
            for path in paths:
                sample_index = int(path.stem.split("_")[1])
                self.data[sample_index].append(str(path))
            for sample_list in self.data:
                sample_list.sort()
            self.data = [lst for lst in self.data if lst]  # drop empty slots
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample_files = self.data[idx]
        if self.use_hugging_face:
            # Hugging Face rows already hold raw arrays; wrap them as [1, T] float tensors.
            return [torch.tensor(arr, dtype=torch.float32).unsqueeze(0) for arr in sample_files]
        elif self.load_audio:
            # torchaudio.load returns (waveform, sample_rate); keep only the waveform.
            return [torchaudio.load(sample_file)[0] for sample_file in sample_files]
        else:
            return sample_files

def collate_fn(batch):
    # Each batch item is a [positive, negative] pair; regroup into two parallel lists.
    pos, neg = zip(*batch)
    return list(pos), list(neg)
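
# Usage sketch (illustrative, not exercised by this script): with salmon_path=None the
# dataset is pulled from the Hugging Face hub, and each item is a [positive, negative]
# pair of [1, T] float32 waveform tensors.
#
#   ds = SalmonDataset(None, "speaker_consistency")
#   positive, negative = ds[0]
#   pos_batch, neg_batch = collate_fn([ds[0], ds[1]])  # two lists of waveform tensors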

def main():
    parser = argparse.ArgumentParser(description="Run SALMon")
    parser.add_argument("-c", "--inference_model_config", type=str, required=True,
                        help="inference model config json")
    parser.add_argument("-s", "--salmon_path", default=None, type=str,
                        help="path to the downloaded SALMon dataset; if not specified, "
                             "SALMon is downloaded from Hugging Face")
    parser.add_argument("-p", "--parts", type=str, nargs="+", default=["all"],
                        help="SALMon parts to evaluate")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="batch size")
    args = parser.parse_args()

    salmon_path = args.salmon_path
    config_path = args.inference_model_config
    with open(config_path) as f:
        inference_model_config = json.load(f)

    inference_model = InferenceModelFactory.get_model(inference_model_config)
    if torch.cuda.is_available():
        inference_model = inference_model.to("cuda")

    if args.parts[0] == "all":
        args.parts = [
            "bg_alignment",
            "bg_all_consistency",
            "bg_domain_consistency",
            "gender_consistency",
            "rir_consistency",
            "sentiment_alignment",
            "sentiment_consistency",
            "speaker_consistency",
        ]

    print(f"Calculating {len(args.parts)} parts of SALMon for {inference_model} model")

    for part in args.parts:
        dataset = SalmonDataset(salmon_path, part, load_audio=True)
        assert len(dataset) > 0, f"no samples found for {part}"
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn)

        res_list = []
        with torch.no_grad():
            for pos_sample, neg_sample in dataloader:
                # Score 1 if the positive sample is more likely than the negative,
                # 0.5 on ties, 0 otherwise.
                pos_likelihood = inference_model.log_likelihood(pos_sample)
                neg_likelihood = inference_model.log_likelihood(neg_sample)
                res = torch.zeros_like(pos_likelihood)
                res[pos_likelihood > neg_likelihood] = 1
                res[pos_likelihood == neg_likelihood] = 0.5
                res[pos_likelihood < neg_likelihood] = 0
                res_list.append(res)

        res_list = torch.cat(res_list)
        print(f"SALMon - {part}: {res_list.float().mean().cpu():.4f}")

if __name__ == "__main__":
    main()
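
# Example invocation (the config filename below is illustrative; any JSON accepted by
# InferenceModelFactory.get_model works):
#
#   python salmon.py -c configs/my_model.json -p speaker_consistency bg_alignment -b 4
#
# Omitting -s downloads slprl/salmon from Hugging Face; "-p all" (the default) runs
# every part listed in main().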