Showing 24 changed files with 1,717 additions and 3 deletions.
81 changes: 81 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/captioning.py
@@ -0,0 +1,81 @@
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed

from melodytalk.dependencies.lpmc.music_captioning.model.bart import BartCaptionModel
from melodytalk.dependencies.lpmc.utils.eval_utils import load_pretrained
from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST
from omegaconf import OmegaConf

parser = argparse.ArgumentParser(description='PyTorch MSD Training')
parser.add_argument('--gpu', default=1, type=int,
                    help='GPU id to use.')
parser.add_argument("--framework", default="transfer", type=str)
parser.add_argument("--caption_type", default="lp_music_caps", type=str)
parser.add_argument("--max_length", default=128, type=int)
parser.add_argument("--num_beams", default=5, type=int)
parser.add_argument("--model_type", default="last", type=str)
parser.add_argument("--audio_path", default="../../dataset/samples/orchestra.wav", type=str)

def get_audio(audio_path, duration=10, target_sr=16000):
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(0, False)  # downmix to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad sequence up to one full chunk
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    # split the waveform into non-overlapping 10-second chunks, one row per chunk
    ceil = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
    return audio

def main():
    args = parser.parse_args()
    captioning(args)

def captioning(args):
    save_dir = f"exp/{args.framework}/{args.caption_type}/"
    config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml"))
    model = BartCaptionModel(max_length=config.max_length)
    model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed)
    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
    model.eval()

    audio_tensor = get_audio(audio_path=args.audio_path)
    if args.gpu is not None:
        audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True)

    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=args.num_beams,
        )
    inference = {}
    number_of_chunks = range(audio_tensor.shape[0])
    for chunk, text in zip(number_of_chunks, output):
        time = f"{chunk * 10}:00-{(chunk + 1) * 10}:00"
        item = {"text": text, "time": time}
        inference[chunk] = item
        print(item)

if __name__ == '__main__':
    main()
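For orientation, a minimal sketch of driving this script in-process rather than via the CLI. The Namespace mirrors the argparse defaults above; the checkpoint directory exp/transfer/lp_music_caps/ and the sample wav are assumptions about the local checkout, not files added by this commit:

from argparse import Namespace
from melodytalk.dependencies.lpmc.music_captioning.captioning import captioning

# Mirrors the argparse defaults above; adjust gpu and paths to your machine.
args = Namespace(gpu=0, framework="transfer", caption_type="lp_music_caps",
                 max_length=128, num_beams=5, model_type="last",
                 audio_path="../../dataset/samples/orchestra.wav")
captioning(args)  # prints one {"text": ..., "time": ...} entry per 10-second chunk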
58 changes: 58 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py
@@ -0,0 +1,58 @@
import os
import random
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from lpmc.utils.audio_utils import load_audio, STR_CH_FIRST

class MC_Dataset(Dataset):
    def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
        self.data_path = data_path
        self.split = split
        self.caption_type = caption_type
        self.audio_enc = audio_enc
        self.sr = sr  # needed by the raw-audio branch of load_audio below
        self.n_samples = int(sr * duration)
        self.annotation = load_dataset("seungheondoh/LP-MusicCaps-MC")
        self.get_split()

    def get_split(self):
        if self.split in ("train", "test"):
            # both splits keep only items whose audio was successfully crawled
            self.fl = [i for i in self.annotation[self.split] if i['is_crawled']]
        else:
            raise ValueError(f"Unexpected split name: {self.split}")

    def load_audio(self, audio_path, file_type):
        if file_type == ".npy":
            audio = np.load(audio_path, mmap_mode='r')
        else:
            audio, _ = load_audio(
                path=audio_path,
                ch_format=STR_CH_FIRST,
                sample_rate=self.sr,
                downmix_to_mono=True
            )
        if len(audio.shape) == 2:
            audio = audio.squeeze(0)
        input_size = int(self.n_samples)
        if audio.shape[-1] < input_size:  # pad short clips up to one full chunk
            pad = np.zeros(input_size)
            pad[:audio.shape[-1]] = audio
            audio = pad
        # sample a random 10-second crop
        random_idx = random.randint(0, audio.shape[-1] - self.n_samples)
        audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx + self.n_samples]).astype('float32'))
        return audio_tensor

    def __getitem__(self, index):
        item = self.fl[index]
        fname = item['fname']
        text = item['caption_ground_truth']
        audio_path = os.path.join(self.data_path, "music_caps", 'npy', fname + ".npy")
        audio_tensor = self.load_audio(audio_path, file_type=audio_path[-4:])
        return fname, text, audio_tensor

    def __len__(self):
        return len(self.fl)
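A minimal sketch of consuming MC_Dataset, assuming the MusicCaps waveforms have been pre-extracted to {data_path}/music_caps/npy/ (an assumption about local preprocessing, not part of this commit; the import path follows this file's location):

from torch.utils.data import DataLoader
from lpmc.music_captioning.datasets.mc import MC_Dataset

ds = MC_Dataset(data_path="../../dataset", split="train", caption_type="gt")
loader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
fname, text, audio = ds[0]  # audio: float32 tensor of 160000 samples (10 s at 16 kHz)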
64 changes: 64 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py
@@ -0,0 +1,64 @@
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset

class MSD_Balanced_Dataset(Dataset):
    def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
        self.data_path = data_path
        self.split = split
        self.audio_enc = audio_enc
        self.n_samples = int(sr * duration)
        self.caption_type = caption_type
        self.dataset = load_dataset("seungheondoh/LP-MusicCaps-MSD")
        self.get_split()

    def get_split(self):
        self.tags = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tags.json"), 'r'))
        self.tag_to_track = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tag_to_track.json"), 'r'))
        self.annotation = {instance['track_id']: instance for instance in self.dataset[self.split]}

    def load_caption(self, item):
        caption_pool = []
        if (self.caption_type in "write") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_writing'])
        if (self.caption_type in "summary") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_summary'])
        if (self.caption_type in "creative") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_paraphrase'])
        if (self.caption_type in "predict") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_attribute_prediction'])
        # randomly select one caption from the eligible captions
        sampled_caption = random.choice(caption_pool)
        return sampled_caption

    def load_audio(self, audio_path, file_type):
        audio = np.load(audio_path, mmap_mode='r')
        if len(audio.shape) == 2:
            audio = audio.squeeze(0)
        input_size = int(self.n_samples)
        if audio.shape[-1] < input_size:  # pad short clips up to one full chunk
            pad = np.zeros(input_size)
            pad[:audio.shape[-1]] = audio
            audio = pad
        # sample a random 10-second crop
        random_idx = random.randint(0, audio.shape[-1] - self.n_samples)
        audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx + self.n_samples]).astype('float32'))
        return audio_tensor

    def __getitem__(self, index):
        tag = random.choice(self.tags)  # uniform random sample over tags
        fname = random.choice(self.tag_to_track[tag])  # uniform random sample track within the tag
        item = self.annotation[fname]
        track_id = item['track_id']
        gt_caption = ""  # no ground truth captions for MSD
        text = self.load_caption(item)
        audio_path = os.path.join(self.data_path, 'msd', 'npy', item['path'].replace(".mp3", ".npy"))
        audio_tensor = self.load_audio(audio_path, file_type=".npy")
        return fname, gt_caption, text, audio_tensor

    def __len__(self):
        return 2**13  # one "epoch" is 8192 freshly re-sampled tag/track pairs
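Note the sampling scheme in __getitem__: a tag is drawn uniformly first, then a track within that tag, so rare tags are over-represented relative to their raw frequency (hence "Balanced"). A self-contained sketch of the same idea, with a purely illustrative toy tag_to_track mapping:

import random

tag_to_track = {"jazz": ["t1", "t2"], "rock": ["t3"], "ambient": ["t4", "t5", "t6"]}
tag = random.choice(list(tag_to_track))   # every tag is equally likely,
track = random.choice(tag_to_track[tag])  # so tracks under rare tags surface more often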
@@ -0,0 +1,51 @@
import os
import argparse
import json
import numpy as np
from datasets import load_dataset
from lpmc.utils.metrics import bleu, meteor, rouge, bertscore, vocab_novelty, caption_novelty

def inference_parsing(dataset, args):
    ground_truths = [i['caption_ground_truth'] for i in dataset]
    inference = json.load(open(os.path.join(args.save_dir, args.framework, args.caption_type, 'inference_temp.json'), 'r'))
    id2pred = {item['audio_id']: item['predictions'] for item in inference.values()}
    predictions = [id2pred[i['fname']] for i in dataset]
    return predictions, ground_truths

def main(args):
    dataset = load_dataset("seungheondoh/LP-MusicCaps-MC")
    train_data = [i for i in dataset['train'] if i['is_crawled']]
    test_data = [i for i in dataset['test'] if i['is_crawled']]
    tr_ground_truths = [i['caption_ground_truth'] for i in train_data]
    predictions, ground_truths = inference_parsing(test_data, args)
    length_avg = np.mean([len(cap.split()) for cap in predictions])
    length_std = np.std([len(cap.split()) for cap in predictions])

    vocab_size, vocab_novel_score = vocab_novelty(predictions, tr_ground_truths)
    cap_novel_score = caption_novelty(predictions, tr_ground_truths)
    results = {
        "bleu1": bleu(predictions, ground_truths, order=1),
        "bleu2": bleu(predictions, ground_truths, order=2),
        "bleu3": bleu(predictions, ground_truths, order=3),
        "bleu4": bleu(predictions, ground_truths, order=4),
        "meteor_1.0": meteor(predictions, ground_truths),  # https://github.com/huggingface/evaluate/issues/115
        "rougeL": rouge(predictions, ground_truths),
        "bertscore": bertscore(predictions, ground_truths),
        "vocab_size": vocab_size,
        "vocab_novelty": vocab_novel_score,
        "caption_novelty": cap_novel_score,
        "length_avg": length_avg,
        "length_std": length_std
    }
    os.makedirs(os.path.join(args.save_dir, args.framework, args.caption_type), exist_ok=True)
    with open(os.path.join(args.save_dir, args.framework, args.caption_type, "results.json"), mode="w") as io:
        json.dump(results, io, indent=4)
    print(results)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", default="./exp", type=str)
    parser.add_argument("--framework", default="supervised", type=str)
    parser.add_argument("--caption_type", default="gt", type=str)
    args = parser.parse_args()
    main(args=args)
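For reference, a minimal sketch of driving main() above in-process with the same values its argparse block defaults to; it assumes exp/supervised/gt/inference_temp.json already exists from a prior inference run (that step is outside this file, and this file's path is not shown in this diff):

from argparse import Namespace

# `main` is the function defined in the evaluation script above.
main(args=Namespace(save_dir="./exp", framework="supervised", caption_type="gt"))
# -> writes ./exp/supervised/gt/results.json and prints the metric dict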
27 changes: 27 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml
@@ -0,0 +1,27 @@
framework: bart
data_dir: ../../dataset
train_data: msd_balence
text_type: all
arch: transformer
workers: 12
epochs: 4096
warmup_epochs: 125
start_epoch: 0
batch_size: 256
world_size: 1
lr: 0.0001
min_lr: 1.0e-09
rank: 0
dist_url: tcp://localhost:12312
dist_backend: nccl
seed: null
gpu: 0
print_freq: 100
multiprocessing_distributed: false
cos: true
bart_pretrain: false
label_smoothing: 0.1
use_early_stopping: false
eval_sample: 0
max_length: 110
distributed: false
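captioning.py above resolves exactly this kind of file at run time via OmegaConf.load(os.path.join(save_dir, "hparams.yaml")), with save_dir built from --framework and --caption_type. A minimal sketch of inspecting this particular config, assuming it sits at the path this commit adds:

from omegaconf import OmegaConf

config = OmegaConf.load("exp/pretrain/lp_music_caps/hparams.yaml")
print(config.max_length, config.batch_size, config.train_data)  # 110 256 msd_balence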
14 changes: 14 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json
@@ -0,0 +1,14 @@
{
    "bleu1": 0.19774511999248645,
    "bleu2": 0.06701659207795346,
    "bleu3": 0.02165946217889739,
    "bleu4": 0.007868711032317937,
    "meteor_1.0": 0.12883275733962948,
    "rougeL": 0.1302757267270023,
    "bertscore": 0.8451152142608364,
    "vocab_size": 1686,
    "vocab_diversity": 0.4721233689205219,
    "caption_novelty": 1.0,
    "length_avg": 45.26699542092286,
    "length_std": 27.99358111821529
}
26 changes: 26 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml
@@ -0,0 +1,26 @@
framework: bart
data_dir: ../../dataset
train_data: music_caps
text_type: gt
arch: transformer
workers: 8
epochs: 100
warmup_epochs: 1
start_epoch: 0
batch_size: 64
world_size: 1
lr: 0.0001
min_lr: 1.0e-09
rank: 0
dist_url: tcp://localhost:12312
dist_backend: nccl
seed: null
gpu: 0
print_freq: 100
multiprocessing_distributed: false
cos: true
bart_pretrain: false
label_smoothing: 0.1
use_early_stopping: false
eval_sample: 64
max_length: 128
14 changes: 14 additions & 0 deletions
melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json
@@ -0,0 +1,14 @@
{
    "bleu1": 0.2850628531445219,
    "bleu2": 0.13762330787082025,
    "bleu3": 0.07588914964864207,
    "bleu4": 0.047923589711449235,
    "meteor_1.0": 0.20624392533703412,
    "rougeL": 0.19215453358420784,
    "bertscore": 0.870539891400695,
    "vocab_size": 2240,
    "vocab_diversity": 0.005357142857142857,
    "caption_novelty": 0.6899661781285231,
    "length_avg": 46.68193025713279,
    "length_std": 16.516237946145356
}