diff --git a/.DS_Store b/.DS_Store index 7a990dd..88ec8f8 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/main.py b/main.py index f357e5e..3373a61 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ from melodytalk.dependencies.laion_clap.hook import CLAP_Module from datetime import datetime import torch +import demucs.separate from melodytalk.utils import CLAP_post_filter MODEL_NAME = 'melody' @@ -13,7 +14,7 @@ CFG_COEF = 3 SAMPLES = 5 # PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.' -PROMPT = "music loop with saxophone solo. instrument: saxophone." +PROMPT = "music loop with a very strong and rhythmic drum pattern." melody_conditioned = True os.environ["CUDA_VISIBLE_DEVICES"] = "3" @@ -38,7 +39,9 @@ def generate(): if not melody_conditioned: wav = model.generate(descriptions, progress=True) # generates 3 samples. else: - melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav') + # demucs.separate.main(["-n", "htdemucs_6s", "--two-stems", 'drums', '/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav' ]) + melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/separated/htdemucs_6s/20230705-155518_3/no_drums.wav') + wav = model.generate_continuation(melody[None].expand(SAMPLES, -1, -1), sr, descriptions, progress=True) # the generated wav contains the melody input, we need to cut it wav = wav[..., int(melody.shape[-1] / sr * model.sample_rate):] diff --git a/melodytalk/dependencies/lpmc/music_captioning/captioning.py b/melodytalk/dependencies/lpmc/music_captioning/captioning.py new file mode 100644 index 0000000..6d14c5b --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/captioning.py @@ -0,0 +1,81 @@ +import argparse +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed + +from melodytalk.dependencies.lpmc.music_captioning.model.bart import BartCaptionModel +from melodytalk.dependencies.lpmc.utils.eval_utils import load_pretrained +from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST +from omegaconf import OmegaConf + +parser = argparse.ArgumentParser(description='PyTorch MSD Training') +parser.add_argument('--gpu', default=1, type=int, + help='GPU id to use.') +parser.add_argument("--framework", default="transfer", type=str) +parser.add_argument("--caption_type", default="lp_music_caps", type=str) +parser.add_argument("--max_length", default=128, type=int) +parser.add_argument("--num_beams", default=5, type=int) +parser.add_argument("--model_type", default="last", type=str) +parser.add_argument("--audio_path", default="../../dataset/samples/orchestra.wav", type=str) + +def get_audio(audio_path, duration=10, target_sr=16000): + n_samples = int(duration * target_sr) + audio, sr = load_audio( + path= audio_path, + ch_format= STR_CH_FIRST, + sample_rate= target_sr, + downmix_to_mono= True, + ) + if len(audio.shape) == 2: + audio = audio.mean(0, False) # to mono + input_size = int(n_samples) + if audio.shape[-1] < input_size: # pad sequence + pad = np.zeros(input_size) + pad[: audio.shape[-1]] = audio + audio = pad + ceil = int(audio.shape[-1] // n_samples) + audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32')) + return 
audio + +def main(): + args = parser.parse_args() + captioning(args) + +def captioning(args): + save_dir = f"exp/{args.framework}/{args.caption_type}/" + config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml")) + model = BartCaptionModel(max_length = config.max_length) + model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed) + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + model.eval() + + audio_tensor = get_audio(audio_path = args.audio_path) + if args.gpu is not None: + audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True) + + with torch.no_grad(): + output = model.generate( + samples=audio_tensor, + num_beams=args.num_beams, + ) + inference = {} + number_of_chunks = range(audio_tensor.shape[0]) + for chunk, text in zip(number_of_chunks, output): + time = f"{chunk * 10}:00-{(chunk + 1) * 10}:00" + item = {"text":text,"time":time} + inference[chunk] = item + print(item) + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py b/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py new file mode 100644 index 0000000..eb0375e --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py @@ -0,0 +1,58 @@ +import os +import random +import numpy as np +import pandas as pd +import torch +from datasets import load_dataset +from torch.utils.data import Dataset +from lpmc.utils.audio_utils import load_audio, STR_CH_FIRST + +class MC_Dataset(Dataset): + def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"): + self.data_path = data_path + self.split = split + self.caption_type = caption_type + self.audio_enc = audio_enc + self.n_samples = int(sr * duration) + self.annotation = load_dataset("seungheondoh/LP-MusicCaps-MC") + self.get_split() + + def get_split(self): + if self.split == "train": + self.fl = [i for i in self.annotation[self.split] if i['is_crawled']] + elif self.split == "test": + self.fl = [i for i in self.annotation[self.split] if i['is_crawled']] + else: + raise ValueError(f"Unexpected split name: {self.split}") + + def load_audio(self, audio_path, file_type): + if file_type == ".npy": + audio = np.load(audio_path, mmap_mode='r') + else: + audio, _ = load_audio( + path=audio_path, + ch_format= STR_CH_FIRST, + sample_rate= self.sr, + downmix_to_mono= True + ) + if len(audio.shape) == 2: + audio = audio.squeeze(0) + input_size = int(self.n_samples) + if audio.shape[-1] < input_size: + pad = np.zeros(input_size) + pad[:audio.shape[-1]] = audio + audio = pad + random_idx = random.randint(0, audio.shape[-1]-self.n_samples) + audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx+self.n_samples]).astype('float32')) + return audio_tensor + + def __getitem__(self, index): + item = self.fl[index] + fname = item['fname'] + text = item['caption_ground_truth'] + audio_path = os.path.join(self.data_path, "music_caps", 'npy', fname + ".npy") + audio_tensor = self.load_audio(audio_path, file_type=audio_path[-4:]) + return fname, text, audio_tensor + + def __len__(self): + return len(self.fl) \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py b/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py new file mode 100644 index 0000000..cfc03ed --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py @@ -0,0 +1,64 @@ +import os +import json +import random +import numpy as np +import 
torch +from torch.utils.data import Dataset +from datasets import load_dataset + +class MSD_Balanced_Dataset(Dataset): + def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"): + self.data_path = data_path + self.split = split + self.audio_enc = audio_enc + self.n_samples = int(sr * duration) + self.caption_type = caption_type + self.dataset = load_dataset("seungheondoh/LP-MusicCaps-MSD") + self.get_split() + + def get_split(self): + self.tags = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tags.json"), 'r')) + self.tag_to_track = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tag_to_track.json"), 'r')) + self.annotation = {instance['track_id'] : instance for instance in self.dataset[self.split]} + + def load_caption(self, item): + caption_pool = [] + if (self.caption_type in "write") or (self.caption_type == "lp_music_caps"): + caption_pool.append(item['caption_writing']) + if (self.caption_type in "summary") or (self.caption_type == "lp_music_caps"): + caption_pool.append(item['caption_summary']) + if (self.caption_type in "creative") or (self.caption_type == "lp_music_caps"): + caption_pool.append(item['caption_paraphrase']) + if (self.caption_type in "predict") or (self.caption_type == "lp_music_caps"): + caption_pool.append(item['caption_attribute_prediction']) + # randomly select one caption from multiple captions + sampled_caption = random.choice(caption_pool) + return sampled_caption + + def load_audio(self, audio_path, file_type): + audio = np.load(audio_path, mmap_mode='r') + if len(audio.shape) == 2: + audio = audio.squeeze(0) + input_size = int(self.n_samples) + if audio.shape[-1] < input_size: + pad = np.zeros(input_size) + pad[:audio.shape[-1]] = audio + audio = pad + random_idx = random.randint(0, audio.shape[-1]-self.n_samples) + audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx+self.n_samples]).astype('float32')) + audio_tensor = torch.randn(16000*10) + return audio_tensor + + def __getitem__(self, index): + tag = random.choice(self.tags) # uniform random sample tag + fname = random.choice(self.tag_to_track[tag]) # uniform random sample track + item = self.annotation[fname] + track_id = item['track_id'] + gt_caption = "" # no ground turhth + text = self.load_caption(item) + audio_path = os.path.join(self.data_path,'msd','npy', item['path'].replace(".mp3", ".npy")) + audio_tensor = self.load_audio(audio_path, file_type=".npy") + return fname, gt_caption, text, audio_tensor + + def __len__(self): + return 2**13 \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/eval.py b/melodytalk/dependencies/lpmc/music_captioning/eval.py new file mode 100644 index 0000000..e9d8362 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/eval.py @@ -0,0 +1,51 @@ +import os +import argparse +import json +import numpy as np +from datasets import load_dataset +from lpmc.utils.metrics import bleu, meteor, rouge, bertscore, vocab_novelty, caption_novelty + +def inference_parsing(dataset, args): + ground_truths = [i['caption_ground_truth'] for i in dataset] + inference = json.load(open(os.path.join(args.save_dir, args.framework, args.caption_type, 'inference_temp.json'), 'r')) + id2pred = {item['audio_id']:item['predictions'] for item in inference.values()} + predictions = [id2pred[i['fname']] for i in dataset] + return predictions, ground_truths + +def main(args): + dataset = load_dataset("seungheondoh/LP-MusicCaps-MC") + train_data = [i for i in 
dataset['train'] if i['is_crawled']] + test_data = [i for i in dataset['test'] if i['is_crawled']] + tr_ground_truths = [i['caption_ground_truth'] for i in train_data] + predictions, ground_truths = inference_parsing(test_data, args) + length_avg = np.mean([len(cap.split()) for cap in predictions]) + length_std = np.std([len(cap.split()) for cap in predictions]) + + vocab_size, vocab_novel_score = vocab_novelty(predictions, tr_ground_truths) + cap_novel_score = caption_novelty(predictions, tr_ground_truths) + results = { + "bleu1": bleu(predictions, ground_truths, order=1), + "bleu2": bleu(predictions, ground_truths, order=2), + "bleu3": bleu(predictions, ground_truths, order=3), + "bleu4": bleu(predictions, ground_truths, order=4), + "meteor_1.0": meteor(predictions, ground_truths), # https://github.com/huggingface/evaluate/issues/115 + "rougeL": rouge(predictions, ground_truths), + "bertscore": bertscore(predictions, ground_truths), + "vocab_size": vocab_size, + "vocab_novelty": vocab_novel_score, + "caption_novelty": cap_novel_score, + "length_avg": length_avg, + "length_std": length_std + } + os.makedirs(os.path.join(args.save_dir, args.framework, args.caption_type), exist_ok=True) + with open(os.path.join(args.save_dir, args.framework, args.caption_type, f"results.json"), mode="w") as io: + json.dump(results, io, indent=4) + print(results) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--save_dir", default="./exp", type=str) + parser.add_argument("--framework", default="supervised", type=str) + parser.add_argument("--caption_type", default="gt", type=str) + args = parser.parse_args() + main(args=args) \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml new file mode 100644 index 0000000..6b791f6 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml @@ -0,0 +1,27 @@ +framework: bart +data_dir: ../../dataset +train_data: msd_balence +text_type: all +arch: transformer +workers: 12 +epochs: 4096 +warmup_epochs: 125 +start_epoch: 0 +batch_size: 256 +world_size: 1 +lr: 0.0001 +min_lr: 1.0e-09 +rank: 0 +dist_url: tcp://localhost:12312 +dist_backend: nccl +seed: null +gpu: 0 +print_freq: 100 +multiprocessing_distributed: false +cos: true +bart_pretrain: false +label_smoothing: 0.1 +use_early_stopping: false +eval_sample: 0 +max_length: 110 +distributed: false diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json new file mode 100644 index 0000000..9e57a16 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json @@ -0,0 +1,14 @@ +{ + "bleu1": 0.19774511999248645, + "bleu2": 0.06701659207795346, + "bleu3": 0.02165946217889739, + "bleu4": 0.007868711032317937, + "meteor_1.0": 0.12883275733962948, + "rougeL": 0.1302757267270023, + "bertscore": 0.8451152142608364, + "vocab_size": 1686, + "vocab_diversity": 0.4721233689205219, + "caption_novelty": 1.0, + "length_avg": 45.26699542092286, + "length_std": 27.99358111821529 +} \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml new file mode 100644 index 0000000..04829e5 
--- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml @@ -0,0 +1,26 @@ +framework: bart +data_dir: ../../dataset +train_data: music_caps +text_type: gt +arch: transformer +workers: 8 +epochs: 100 +warmup_epochs: 1 +start_epoch: 0 +batch_size: 64 +world_size: 1 +lr: 0.0001 +min_lr: 1.0e-09 +rank: 0 +dist_url: tcp://localhost:12312 +dist_backend: nccl +seed: null +gpu: 0 +print_freq: 100 +multiprocessing_distributed: false +cos: true +bart_pretrain: false +label_smoothing: 0.1 +use_early_stopping: false +eval_sample: 64 +max_length: 128 diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json new file mode 100644 index 0000000..acc0e3c --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json @@ -0,0 +1,14 @@ +{ + "bleu1": 0.2850628531445219, + "bleu2": 0.13762330787082025, + "bleu3": 0.07588914964864207, + "bleu4": 0.047923589711449235, + "meteor_1.0": 0.20624392533703412, + "rougeL": 0.19215453358420784, + "bertscore": 0.870539891400695, + "vocab_size": 2240, + "vocab_diversity": 0.005357142857142857, + "caption_novelty": 0.6899661781285231, + "length_avg": 46.68193025713279, + "length_std": 16.516237946145356 +} \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml new file mode 100644 index 0000000..728a20a --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml @@ -0,0 +1,25 @@ +framework: transfer +data_dir: ../../dataset +text_type: gt +arch: transformer +workers: 8 +epochs: 100 +warmup_epochs: 20 +start_epoch: 0 +batch_size: 64 +world_size: 1 +lr: 0.0001 +min_lr: 1.0e-09 +rank: 0 +dist_url: tcp://localhost:12312 +dist_backend: nccl +seed: null +gpu: 1 +print_freq: 10 +multiprocessing_distributed: false +cos: true +bart_pretrain: false +label_smoothing: 0.1 +use_early_stopping: false +eval_sample: 64 +max_length: 128 \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json new file mode 100644 index 0000000..55961d8 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json @@ -0,0 +1,14 @@ +{ + "bleu1": 0.29093124850986807, + "bleu2": 0.1486508686790579, + "bleu3": 0.08928000372127677, + "bleu4": 0.06046496381102543, + "meteor_1.0": 0.22386547507696472, + "rougeL": 0.2148899248811179, + "bertscore": 0.8777867602017111, + "vocab_size": 1695, + "vocab_diversity": 0.014749262536873156, + "caption_novelty": 0.9606498194945848, + "length_avg": 42.47234941880944, + "length_std": 14.336758651069372 +} \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/infer.py b/melodytalk/dependencies/lpmc/music_captioning/infer.py new file mode 100644 index 0000000..3bddae6 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/infer.py @@ -0,0 +1,100 @@ +import argparse +import os +import json +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed + +from lpmc.music_captioning.datasets.mc import MC_Dataset +from lpmc.music_captioning.model.bart import BartCaptionModel +from 
lpmc.utils.eval_utils import load_pretrained +from tqdm import tqdm +from omegaconf import OmegaConf + +parser = argparse.ArgumentParser(description='PyTorch MSD Training') +parser.add_argument('--data_dir', type=str, default="../../dataset") +parser.add_argument('--framework', type=str, default="supervised") +parser.add_argument("--caption_type", default="gt", type=str) +parser.add_argument('--arch', default='transformer') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--warmup_epochs', default=10, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N') +parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--min_lr', default=1e-9, type=float) +parser.add_argument('--seed', default=42, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=1, type=int, + help='GPU id to use.') +parser.add_argument('--print_freq', default=50, type=int) +parser.add_argument("--cos", default=True, type=bool) +parser.add_argument("--label_smoothing", default=0.1, type=float) +parser.add_argument("--max_length", default=128, type=int) +parser.add_argument("--num_beams", default=5, type=int) +parser.add_argument("--model_type", default="last", type=str) + + +def main(): + args = parser.parse_args() + main_worker(args) + +def main_worker(args): + test_dataset = MC_Dataset( + data_path = args.data_dir, + split="test", + caption_type = "gt" + ) + print(len(test_dataset)) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False) + model = BartCaptionModel( + max_length = args.max_length, + label_smoothing = args.label_smoothing, + ) + eval(args, model, test_dataset, test_loader, args.num_beams) + +def eval(args, model, test_dataset, test_loader, num_beams=5): + save_dir = f"exp/{args.framework}/{args.caption_type}/" + config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml")) + model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed) + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + model.eval() + + inference_results = {} + idx = 0 + for batch in tqdm(test_loader): + fname, text,audio_tensor = batch + if args.gpu is not None: + audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True) + with torch.no_grad(): + output = model.generate( + samples=audio_tensor, + num_beams=num_beams, + ) + for audio_id, gt, pred in zip(fname, text, output): + inference_results[idx] = { + "predictions": pred, + "true_captions": gt, + "audio_id": audio_id + } + idx += 1 + + with open(os.path.join(save_dir, f"inference.json"), mode="w") as io: + json.dump(inference_results, io, indent=4) + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/model/bart.py b/melodytalk/dependencies/lpmc/music_captioning/model/bart.py new file mode 100644 index 
0000000..308214c --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/model/bart.py @@ -0,0 +1,153 @@ +### code reference: https://github.com/XinhaoMei/WavCaps/blob/master/captioning/models/bart_captioning.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from lpmc.music_captioning.model.modules import AudioEncoder +from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig + +class BartCaptionModel(nn.Module): + def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768): + super(BartCaptionModel, self).__init__() + # non-finetunning case + bart_config = BartConfig.from_pretrained(bart_type) + self.tokenizer = BartTokenizer.from_pretrained(bart_type) + self.bart = BartForConditionalGeneration(bart_config) + + self.n_sample = sr * duration + self.hop_length = int(0.01 * sr) # hard coding hop_size + self.n_frames = int(self.n_sample // self.hop_length) + self.num_of_stride_conv = num_of_conv - 1 + self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1 + self.audio_encoder = AudioEncoder( + n_mels = n_mels, # hard coding n_mel + n_ctx = self.n_ctx, + audio_dim = audio_dim, + text_dim = self.bart.config.hidden_size, + num_of_stride_conv = self.num_of_stride_conv + ) + + self.max_length = max_length + self.loss_fct = nn.CrossEntropyLoss(label_smoothing= label_smoothing, ignore_index=-100) + + @property + def device(self): + return list(self.parameters())[0].device + + def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right.ls + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + return shifted_input_ids + + def forward_encoder(self, audio): + audio_embs = self.audio_encoder(audio) + encoder_outputs = self.bart.model.encoder( + input_ids=None, + inputs_embeds=audio_embs, + return_dict=True + )["last_hidden_state"] + return encoder_outputs, audio_embs + + def forward_decoder(self, text, encoder_outputs): + text = self.tokenizer(text, + padding='longest', + truncation=True, + max_length=self.max_length, + return_tensors="pt") + input_ids = text["input_ids"].to(self.device) + attention_mask = text["attention_mask"].to(self.device) + + decoder_targets = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100 + ) + + decoder_input_ids = self.shift_tokens_right( + decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id + ) + + decoder_outputs = self.bart( + input_ids=None, + attention_mask=None, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=attention_mask, + inputs_embeds=None, + labels=None, + encoder_outputs=(encoder_outputs,), + return_dict=True + ) + lm_logits = decoder_outputs["logits"] + loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1)) + return loss + + def forward(self, audio, text): + encoder_outputs, _ = self.forward_encoder(audio) + loss = self.forward_decoder(text, encoder_outputs) + return loss + + def generate(self, + samples, + 
use_nucleus_sampling=False, + num_beams=5, + max_length=128, + min_length=2, + top_p=0.9, + repetition_penalty=1.0, + ): + + # self.bart.force_bos_token_to_be_generated = True + audio_embs = self.audio_encoder(samples) + encoder_outputs = self.bart.model.encoder( + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=audio_embs, + output_attentions=None, + output_hidden_states=None, + return_dict=True) + + input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device) + input_ids[:, 0] = self.bart.config.decoder_start_token_id + decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device) + if use_nucleus_sampling: + outputs = self.bart.generate( + input_ids=None, + attention_mask=None, + decoder_input_ids=input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + repetition_penalty=1.1) + else: + outputs = self.bart.generate(input_ids=None, + attention_mask=None, + decoder_input_ids=input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + head_mask=None, + decoder_head_mask=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + repetition_penalty=repetition_penalty) + + captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + return captions diff --git a/melodytalk/dependencies/lpmc/music_captioning/model/modules.py b/melodytalk/dependencies/lpmc/music_captioning/model/modules.py new file mode 100644 index 0000000..788baa6 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/model/modules.py @@ -0,0 +1,95 @@ +### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py + +import os +import torch +import torchaudio +import numpy as np +import torch.nn.functional as F +from torch import Tensor, nn +from typing import Dict, Iterable, Optional + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 1024 +N_MELS = 128 +HOP_LENGTH = int(0.01 * SAMPLE_RATE) +DURATION = 10 +N_SAMPLES = int(DURATION * SAMPLE_RATE) +N_FRAMES = N_SAMPLES // HOP_LENGTH + 1 + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + +class MelEncoder(nn.Module): + """ + time-frequency represntation + """ + def __init__(self, + sample_rate= 16000, + f_min=0, + f_max=8000, + n_fft=1024, + win_length=1024, + hop_length = int(0.01 * 16000), + n_mels = 128, + power = None, + pad= 0, + normalized= False, + center= True, + pad_mode= "reflect" + ): + super(MelEncoder, self).__init__() + self.window = torch.hann_window(win_length) + self.spec_fn = torchaudio.transforms.Spectrogram( + n_fft = n_fft, + win_length = win_length, + hop_length = hop_length, + power = power + ) + self.mel_scale = torchaudio.transforms.MelScale( + n_mels, + sample_rate, + f_min, + f_max, + n_fft // 2 + 1) + + self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() + + def 
forward(self, wav): + spec = self.spec_fn(wav) + power_spec = spec.real.abs().pow(2) + mel_spec = self.mel_scale(power_spec) + mel_spec = self.amplitude_to_db(mel_spec) # Log10(max(reference value and amin)) + return mel_spec + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int, + ): + super().__init__() + self.mel_encoder = MelEncoder(n_mels=n_mels) + self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1) + self.conv_stack = nn.ModuleList([]) + for _ in range(num_of_stride_conv): + self.conv_stack.append( + nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1) + ) + # self.proj = nn.Linear(audio_dim, text_dim, bias=False) + self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim)) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, waveform) + single channel wavform + """ + x = self.mel_encoder(x) # (batch_size, n_mels, n_ctx) + x = F.gelu(self.conv1(x)) + for conv in self.conv_stack: + x = F.gelu(conv(x)) + x = x.permute(0, 2, 1) + x = (x + self.positional_embedding).to(x.dtype) + return x \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py b/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py new file mode 100644 index 0000000..0ead2c9 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py @@ -0,0 +1,58 @@ +import os +import random +from contextlib import contextmanager + +import json +from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST +from sklearn.preprocessing import MultiLabelBinarizer +from tqdm import tqdm + +# hard coding hpamras +DATASET_PATH = "../../dataset/msd" +MUSIC_SAMPLE_RATE = 16000 +DURATION = 30 +DATA_LENGTH = int(MUSIC_SAMPLE_RATE * DURATION) + +@contextmanager +def poolcontext(*args, **kwargs): + pool = multiprocessing.Pool(*args, **kwargs) + yield pool + pool.terminate() + +def msd_resampler(sample): + path = sample['path'] + save_name = os.path.join(DATASET_PATH,'npy', path.replace(".mp3",".npy")) + src, _ = load_audio( + path=os.path.join(DATASET_PATH,'songs',path), + ch_format= STR_CH_FIRST, + sample_rate= MUSIC_SAMPLE_RATE, + downmix_to_mono= True) + if src.shape[-1] < DATA_LENGTH: # short case + pad = np.zeros(DATA_LENGTH) + pad[:src.shape[-1]] = src + src = pad + elif src.shape[-1] > DATA_LENGTH: # too long case + src = src[:DATA_LENGTH] + + if not os.path.exists(os.path.dirname(save_name)): + os.makedirs(os.path.dirname(save_name)) + np.save(save_name, src.astype(np.float32)) + +def build_tag_to_track(msd_dataset, split): + """ + for balanced sampler, we bulid tag_to_track graph + """ + mlb = MultiLabelBinarizer() + indexs = [i['track_id'] for i in msd_dataset[split]] + binary = mlb.fit_transform([i['tag'] for i in msd_dataset[split]]) + tags = list(mlb.classes_) + tag_to_track = {} + for idx, tag in enumerate(tqdm(tags)): + track_list = [indexs[i] for i in binary[:,idx].nonzero()[0]] + tag_to_track[tag] = track_list + + with open(os.path.join(DATASET_PATH, f"{split}_tag_to_track.json"), mode="w") as io: + json.dump(tag_to_track, io, indent=4) + + with open(os.path.join(DATASET_PATH, f"{split}_tags.json"), mode="w") as io: + json.dump(tags, io, indent=4) \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/readme.md b/melodytalk/dependencies/lpmc/music_captioning/readme.md new file mode 100644 index 0000000..441c77b --- /dev/null +++ 
b/melodytalk/dependencies/lpmc/music_captioning/readme.md @@ -0,0 +1,120 @@ +# Audio-to-Caption using Cross Modal Encoder-Decoder + +We used a cross-modal encoder-decoder transformer architecture. + +1. Similar to Whisper, the encoder takes a log-mel spectrogram with six convolution layers with a filter width of 3 and the GELU activation function. With the exception of the first layer, each convolution layer has a stride of two. The output of the convolution layers is combined with the sinusoidal position encoding and then processed by the encoder transformer blocks. + +2. Following the BART architecture, our encoder and decoder both have 768 widths and 6 transformer blocks. The decoder processes tokenized text captions using transformer blocks with a multi-head attention module that includes a mask to hide future tokens for causality. The music and caption representations are fed into the cross-modal attention layer, and the head of the language model in the decoder predicts the next token autoregressively using the cross-entropy loss. + +- **Supervised Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/supervised.pth) +- **Pretrain Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/pretrain.pth) +- **Transfer Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth) + +
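For orientation, below is a minimal, self-contained sketch of the convolutional front-end that the two paragraphs above describe, written against plain `torch`/`torchaudio`. The names used here (`ConvFrontEnd`, `sinusoids`, the constants) are illustrative assumptions for the sketch only; the implementation actually added by this patch lives in `model/modules.py` (`MelEncoder`, `AudioEncoder`) and `model/bart.py` (`BartCaptionModel`).

```python
# Illustrative sketch of the audio front-end described above (not the patch's code).
# Assumes 16 kHz mono audio, 10 s chunks, 128 mel bins, BART hidden size 768.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

SR, DURATION, N_MELS, HIDDEN = 16000, 10, 128, 768
HOP = int(0.01 * SR)  # 10 ms hop, matching HOP_LENGTH in model/modules.py

def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional embedding (same formula as Whisper)."""
    log_inc = torch.log(torch.tensor(float(max_timescale))) / (channels // 2 - 1)
    inv = torch.exp(-log_inc * torch.arange(channels // 2))
    t = torch.arange(length)[:, None] * inv[None, :]
    return torch.cat([torch.sin(t), torch.cos(t)], dim=1)

class ConvFrontEnd(nn.Module):
    """Log-mel spectrogram -> 6 convs of width 3 with GELU, stride 2 after the first."""
    def __init__(self, num_convs=6):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=SR, n_fft=1024, hop_length=HOP, n_mels=N_MELS)
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.first = nn.Conv1d(N_MELS, HIDDEN, kernel_size=3, padding=1)
        self.strided = nn.ModuleList(
            nn.Conv1d(HIDDEN, HIDDEN, kernel_size=3, stride=2, padding=1)
            for _ in range(num_convs - 1))

    def forward(self, wav):                       # wav: (batch, samples)
        x = self.to_db(self.mel(wav))             # (batch, n_mels, frames)
        x = F.gelu(self.first(x))
        for conv in self.strided:
            x = F.gelu(conv(x))                   # each stride-2 conv roughly halves time
        x = x.permute(0, 2, 1)                    # (batch, frames', hidden)
        return x + sinusoids(x.shape[1], HIDDEN)  # add positional encoding

frontend = ConvFrontEnd()
feats = frontend(torch.randn(2, SR * DURATION))
print(feats.shape)  # e.g. torch.Size([2, 32, 768]), ready for the BART encoder
```

With a 16 kHz, 10-second input and a 10 ms hop, the five stride-2 convolutions reduce roughly 1000 mel frames to about 32 encoder positions before they enter the BART encoder blocks.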
+ +## 0. Quick Start +```bash +# download pretrain model weight from huggingface + +wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/supervised.pth -O exp/supervised/gt/last.pth +wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth -O exp/transfer/lp_music_caps/last.pth +wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/pretrain.pth -O exp/pretrain/lp_music_caps/last.pth +python captioning.py --audio_path ../../dataset/samples/orchestra.wav +``` + +```bash +{ + 'text': "This is a symphonic orchestra playing a piece that's riveting, thrilling and exciting. + The peace would be suitable in a movie when something grand and impressive happens. + There are clarinets, tubas, trumpets and french horns being played. The brass instruments help create that sense of a momentous occasion.", + 'time': '0:00-10:00' +} +{ + 'text': 'This is a classical music piece from a movie soundtrack. + There is a clarinet playing the main melody while a brass section and a flute are playing the melody. + The rhythmic background is provided by the acoustic drums. The atmosphere is epic and victorious. + This piece could be used in the soundtrack of a historical drama movie during the scenes of an army marching towards the end.', +'time': '10:00-20:00' +} +``` + +## 1. Preprocessing audio with ffmpeg + +For fast training, we resample audio at 16000 sampling rate and save it as `.npy`. + +```python +def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]: + """ + Decoding, downmixing, and downsampling by librosa. + Returns a channel-first audio signal. + + Args: + path: + sample_rate: + downmix_to_mono: + + Returns: + (audio signal, sample rate) + """ + + def _decode_resample_by_ffmpeg(filename, sr): + """decode, downmix, and resample audio file""" + channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option + resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option + cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -" + p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + return out + + src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate))) + return src.T, sr +``` + +The code using `multiprocessing`` is as follows. We also provide preprocessing code for balanced data loading. + +``` +# multiprocessing resampling & bulid tag-to-track linked list for balanced data loading. +python preprocessor.py +``` + +## 2. Train & Eval Supervised Model (Baseline) + +Download [MusicCaps audio](https://github.com/seungheondoh/music_caps_dl), if you hard to get audio please request for research purpose + +``` +# train supervised baseline model +python train.py --framework supervised --train_data mc --caption_type gt --warmup_epochs 1 --label_smoothing 0.1 --max_length 128 --batch-size 64 --epochs 100 + +# inference caption +python infer.py --framework supervised --train_data mc --caption_type gt --num_beams 5 --model_type last + +# eval +python eval.py --framework supervised --caption_type gt +``` + +## 3. 
Pretrain, Transfer Music Captioning Model (Proposed) + +Download MSD audio, if you hard to get audio please request for research purpose + +``` +# train pretrain model +python train.py --framework pretrain --train_data msd --caption_type lp_music_caps --warmup_epochs 125 --label_smoothing 0.1 --max_length 110 --batch-size 256 --epochs 4096 + +# train transfer model +python transfer.py --caption_type gt --warmup_epochs 1 --label_smoothing 0.1 --max_length 128 --batch-size 64 --epochs 100 + +# inference caption +python infer.py --framework transfer --caption_type lp_music_caps --num_beams 5 --model_type last + +# eval +python eval.py --framework transfer --caption_type lp_music_caps +``` + +### License +This project is under the CC-BY-NC 4.0 license. See LICENSE for details. + + +### Acknowledgement +We would like to thank the [Whisper](https://github.com/openai/whisper) for audio frontend, [WavCaps](https://github.com/XinhaoMei/WavCaps) for audio-captioning training code and [deezer-playntell](https://github.com/deezer/playntell) for contents based captioning evaluation protocal. \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/train.py b/melodytalk/dependencies/lpmc/music_captioning/train.py new file mode 100644 index 0000000..62a899d --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/train.py @@ -0,0 +1,134 @@ +import argparse +import math +import random +import shutil +import torch +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed + +from lpmc.music_captioning.datasets.mc import MC_Dataset +from lpmc.music_captioning.datasets.msd import MSD_Balanced_Dataset +from lpmc.music_captioning.model.bart import BartCaptionModel +from lpmc.utils.train_utils import Logger, AverageMeter, ProgressMeter, EarlyStopping, save_hparams + +parser = argparse.ArgumentParser(description='PyTorch MSD Training') +parser.add_argument('--framework', type=str, default="pretrain") +parser.add_argument('--data_dir', type=str, default="../../dataset") +parser.add_argument('--train_data', type=str, default="msd") +parser.add_argument("--caption_type", default="lp_music_caps", type=str) # lp_music_caps +parser.add_argument('--arch', default='transformer') +parser.add_argument('-j', '--workers', default=12, type=int, metavar='N', + help='number of data loading workers') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--warmup_epochs', default=125, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N') +parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--min_lr', default=1e-9, type=float) +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. 
') +parser.add_argument('--gpu', default=1, type=int, + help='GPU id to use.') +parser.add_argument('--print_freq', default=10, type=int) +parser.add_argument("--cos", default=True, type=bool) +parser.add_argument("--label_smoothing", default=0.1, type=float) +parser.add_argument("--max_length", default=128, type=int) +parser.add_argument("--resume", default=None, type=bool) + +def main(): + args = parser.parse_args() + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + main_worker(args) + +def main_worker(args): + if args.train_data == "msd": + train_dataset = MSD_Balanced_Dataset( + data_path = args.data_dir, + split="train", + caption_type = args.caption_type + ) + elif args.train_data == "mc": + train_dataset = MC_Dataset( + data_path = args.data_dir, + split="train", + caption_type = args.caption_type + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, drop_last=True) + + model = BartCaptionModel( + max_length = args.max_length, + label_smoothing = args.label_smoothing + ) + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr) + save_dir = f"exp/{args.framework}/{args.caption_type}/" + + logger = Logger(save_dir) + save_hparams(args, save_dir) + + for epoch in range(args.start_epoch, args.epochs): + train(train_loader, model, optimizer, epoch, logger, args) + + torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, f'{save_dir}/last.pth') + print("We are at epoch:", epoch) + +def train(train_loader, model, optimizer, epoch, logger, args): + train_losses = AverageMeter('Train Loss', ':.4e') + progress = ProgressMeter(len(train_loader),[train_losses],prefix="Epoch: [{}]".format(epoch)) + iters_per_epoch = len(train_loader) + model.train() + for data_iter_step, batch in enumerate(train_loader): + lr = adjust_learning_rate(optimizer, data_iter_step / iters_per_epoch + epoch, args) + fname, text, audio_tensor = batch + if args.gpu is not None: + audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True) + # compute output + loss = model(audio=audio_tensor, text=text) + train_losses.step(loss.item(), audio_tensor.size(0)) + logger.log_train_loss(loss, epoch * iters_per_epoch + data_iter_step) + logger.log_learning_rate(lr, epoch * iters_per_epoch + data_iter_step) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if data_iter_step % args.print_freq == 0: + progress.display(data_iter_step) + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/music_captioning/transfer.py b/melodytalk/dependencies/lpmc/music_captioning/transfer.py new file mode 100644 index 0000000..6889857 --- /dev/null +++ b/melodytalk/dependencies/lpmc/music_captioning/transfer.py @@ -0,0 +1,125 @@ +import argparse +import math +import os +import random +import shutil +import torch +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +from lpmc.music_captioning.datasets.mc import MC_Dataset +from lpmc.music_captioning.model.bart import BartCaptionModel +from lpmc.utils.train_utils import Logger, AverageMeter, ProgressMeter, EarlyStopping, save_hparams +from mcb.utils.eval_utils import load_pretrained, print_model_params +from omegaconf import OmegaConf + +parser = argparse.ArgumentParser(description='PyTorch MSD Training') +parser.add_argument('--framework', type=str, default="transfer") +parser.add_argument('--data_dir', type=str, default="../../dataset") +parser.add_argument("--caption_type", default="lp_music_caps", type=str) +parser.add_argument('--arch', default='transformer') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--warmup_epochs', default=20, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N') +parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--min_lr', default=1e-9, type=float) +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. 
') +parser.add_argument('--gpu', default=1, type=int, + help='GPU id to use.') +parser.add_argument('--print_freq', default=10, type=int) +parser.add_argument("--cos", default=True, type=bool) +parser.add_argument("--label_smoothing", default=0.1, type=float) +parser.add_argument("--max_length", default=128, type=int) + +def main(): + args = parser.parse_args() + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + main_worker(args) + +def main_worker(args): + train_dataset = MC_Dataset( + data_path = args.data_dir, + split="train", + caption_type = "gt" + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False) + + model = BartCaptionModel( + max_length = args.max_length, + label_smoothing = args.label_smoothing + ) + pretrain_dir = f"exp/pretrain/{args.caption_type}/" + config = OmegaConf.load(os.path.join(pretrain_dir, "hparams.yaml")) + model, save_epoch = load_pretrained(args, pretrain_dir, model, model_types="last", mdp=config.multiprocessing_distributed) + print_model_params(model) + + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + + optimizer = torch.optim.AdamW(model.parameters(), args.lr) + save_dir = f"exp/transfer/{args.caption_type}" + logger = Logger(save_dir) + save_hparams(args, save_dir) + for epoch in range(args.start_epoch, args.epochs): + train(train_loader, model, optimizer, epoch, logger, args) + torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, f'{save_dir}/last.pth') + +def train(train_loader, model, optimizer, epoch, logger, args): + train_losses = AverageMeter('Train Loss', ':.4e') + progress = ProgressMeter(len(train_loader),[train_losses],prefix="Epoch: [{}]".format(epoch)) + iters_per_epoch = len(train_loader) + model.train() + for data_iter_step, batch in enumerate(train_loader): + lr = adjust_learning_rate(optimizer, data_iter_step / iters_per_epoch + epoch, args) + fname, gt_caption, text, audio_embs = batch + if args.gpu is not None: + audio_embs = audio_embs.cuda(args.gpu, non_blocking=True) + # compute output + loss = model(audio=audio_embs, text=text) + train_losses.step(loss.item(), audio_embs.size(0)) + logger.log_train_loss(loss, epoch * iters_per_epoch + data_iter_step) + logger.log_learning_rate(lr, epoch * iters_per_epoch + data_iter_step) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if data_iter_step % args.print_freq == 0: + progress.display(data_iter_step) + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/utils/audio_utils.py b/melodytalk/dependencies/lpmc/utils/audio_utils.py new file mode 100644 index 0000000..d033238 --- /dev/null +++ b/melodytalk/dependencies/lpmc/utils/audio_utils.py @@ -0,0 +1,247 @@ +STR_CLIP_ID = 'clip_id' +STR_AUDIO_SIGNAL = 'audio_signal' +STR_TARGET_VECTOR = 'target_vector' + + +STR_CH_FIRST = 'channels_first' +STR_CH_LAST = 'channels_last' + +import io +import os +import tqdm +import logging +import subprocess +from typing import Tuple +from pathlib import Path + +# import librosa +import numpy as np +import soundfile as sf + +import itertools +from numpy.fft import irfft + +def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]: + """ + Decoding, downmixing, and downsampling by librosa. + Returns a channel-first audio signal. + + Args: + path: + sample_rate: + downmix_to_mono: + + Returns: + (audio signal, sample rate) + """ + + def _decode_resample_by_ffmpeg(filename, sr): + """decode, downmix, and resample audio file""" + channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option + resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option + cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -" + p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + return out + + src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate))) + return src.T, sr + + +def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]: + """ + Decoding, downmixing, and downsampling by librosa. + Returns a channel-first audio signal. + """ + src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs) + return src, sr + + +def load_audio( + path: str or Path, + ch_format: str, + sample_rate: int = None, + downmix_to_mono: bool = False, + resample_by: str = 'ffmpeg', + **kwargs, +) -> Tuple[np.ndarray, int]: + """A wrapper of librosa.load that: + - forces the returned audio to be 2-dim, + - defaults to sr=None, and + - defaults to downmix_to_mono=False. + + The audio decoding is done by `audioread` or `soundfile` package and ultimately, often by ffmpeg. + The resampling is done by `librosa`'s child package `resampy`. + + Args: + path: audio file path + ch_format: one of 'channels_first' or 'channels_last' + sample_rate: target sampling rate. if None, use the rate of the audio file + downmix_to_mono: + resample_by (str): 'librosa' or 'ffmpeg'. it decides backend for audio decoding and resampling. + **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type. 
+ + Returns: + (audio, sr) tuple + """ + if ch_format not in (STR_CH_FIRST, STR_CH_LAST): + raise ValueError(f'ch_format is wrong here -> {ch_format}') + + if os.stat(path).st_size > 8000: + if resample_by == 'librosa': + src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs) + elif resample_by == 'ffmpeg': + src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono) + else: + raise NotImplementedError(f'resample_by: "{resample_by}" is not supposred yet') + else: + raise ValueError('Given audio is too short!') + return src, sr + + # if src.ndim == 1: + # src = np.expand_dims(src, axis=0) + # # now always 2d and channels_first + + # if ch_format == STR_CH_FIRST: + # return src, sr + # else: + # return src.T, sr + +def ms(x): + """Mean value of signal `x` squared. + :param x: Dynamic quantity. + :returns: Mean squared of `x`. + """ + return (np.abs(x)**2.0).mean() + +def normalize(y, x=None): + """normalize power in y to a (standard normal) white noise signal. + Optionally normalize to power in signal `x`. + #The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1. + """ + if x is not None: + x = ms(x) + else: + x = 1.0 + return y * np.sqrt(x / ms(y)) + +def noise(N, color='white', state=None): + """Noise generator. + :param N: Amount of samples. + :param color: Color of noise. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + """ + try: + return _noise_generators[color](N, state) + except KeyError: + raise ValueError("Incorrect color.") + +def white(N, state=None): + """ + White noise. + :param N: Amount of samples. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + White noise has a constant power density. It's narrowband spectrum is therefore flat. + The power in white noise will increase by a factor of two for each octave band, + and therefore increases with 3 dB per octave. + """ + state = np.random.RandomState() if state is None else state + return state.randn(N) + +def pink(N, state=None): + """ + Pink noise. + :param N: Amount of samples. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + Pink noise has equal power in bands that are proportionally wide. + Power density decreases with 3 dB per octave. + """ + state = np.random.RandomState() if state is None else state + uneven = N % 2 + X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) + S = np.sqrt(np.arange(len(X)) + 1.) # +1 to avoid divide by zero + y = (irfft(X / S)).real + if uneven: + y = y[:-1] + return normalize(y) + +def blue(N, state=None): + """ + Blue noise. + :param N: Amount of samples. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + Power increases with 6 dB per octave. + Power density increases with 3 dB per octave. + """ + state = np.random.RandomState() if state is None else state + uneven = N % 2 + X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) + S = np.sqrt(np.arange(len(X))) # Filter + y = (irfft(X * S)).real + if uneven: + y = y[:-1] + return normalize(y) + +def brown(N, state=None): + """ + Violet noise. + :param N: Amount of samples. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + Power decreases with -3 dB per octave. + Power density decreases with 6 dB per octave. 
+ """ + state = np.random.RandomState() if state is None else state + uneven = N % 2 + X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) + S = (np.arange(len(X)) + 1) # Filter + y = (irfft(X / S)).real + if uneven: + y = y[:-1] + return normalize(y) + +def violet(N, state=None): + """ + Violet noise. Power increases with 6 dB per octave. + :param N: Amount of samples. + :param state: State of PRNG. + :type state: :class:`np.random.RandomState` + Power increases with +9 dB per octave. + Power density increases with +6 dB per octave. + """ + state = np.random.RandomState() if state is None else state + uneven = N % 2 + X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) + S = (np.arange(len(X))) # Filter + y = (irfft(X * S)).real + if uneven: + y = y[:-1] + return normalize(y) + +_noise_generators = { + 'white': white, + 'pink': pink, + 'blue': blue, + 'brown': brown, + 'violet': violet, +} + +def noise_generator(N=44100, color='white', state=None): + """Noise generator. + :param N: Amount of unique samples to generate. + :param color: Color of noise. + Generate `N` amount of unique samples and cycle over these samples. + """ + #yield from itertools.cycle(noise(N, color)) # Python 3.3 + for sample in itertools.cycle(noise(N, color, state)): + yield sample + +def heaviside(N): + """Heaviside. + Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`. + """ + return 0.5 * (np.sign(N) + 1) \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/utils/eval_utils.py b/melodytalk/dependencies/lpmc/utils/eval_utils.py new file mode 100644 index 0000000..9fefc9d --- /dev/null +++ b/melodytalk/dependencies/lpmc/utils/eval_utils.py @@ -0,0 +1,30 @@ +import os +import json +import torch +import numpy as np +import pandas as pd +from sklearn import metrics + + +def print_model_params(model): + n_parameters = sum(p.numel() for p in model.parameters()) + train_n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("============") + print('number of params (M): %.2f' % (n_parameters / 1.e6)) + print('number train of params (M): %.2f' % (train_n_parameters / 1.e6)) + print("============") + +def load_pretrained(args, save_dir, model, model_types="last", mdp=False): + pretrained_object = torch.load(f'{save_dir}/{model_types}.pth', map_location='cpu') + state_dict = pretrained_object['state_dict'] + save_epoch = pretrained_object['epoch'] + if mdp: + for k in list(state_dict.keys()): + if k.startswith('module.'): + state_dict[k[len("module."):]] = state_dict[k] + del state_dict[k] + model.load_state_dict(state_dict) + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + model.eval() + return model, save_epoch \ No newline at end of file diff --git a/melodytalk/dependencies/lpmc/utils/metrics.py b/melodytalk/dependencies/lpmc/utils/metrics.py new file mode 100644 index 0000000..d3cb799 --- /dev/null +++ b/melodytalk/dependencies/lpmc/utils/metrics.py @@ -0,0 +1,144 @@ +"""Placeholder for metrics.""" +from functools import partial +import evaluate +import numpy as np +import torch +import torchmetrics.retrieval as retrieval_metrics +# CAPTIONING METRICS +def bleu(predictions, ground_truths, order): + bleu_eval = evaluate.load("bleu") + return bleu_eval.compute( + predictions=predictions, references=ground_truths, max_order=order + )["bleu"] + +def meteor(predictions, ground_truths): + # https://github.com/huggingface/evaluate/issues/115 + meteor_eval = evaluate.load("meteor") + return 
diff --git a/melodytalk/dependencies/lpmc/utils/eval_utils.py b/melodytalk/dependencies/lpmc/utils/eval_utils.py
new file mode 100644
index 0000000..9fefc9d
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/eval_utils.py
@@ -0,0 +1,30 @@
+import os
+import json
+import torch
+import numpy as np
+import pandas as pd
+from sklearn import metrics
+
+
+def print_model_params(model):
+    n_parameters = sum(p.numel() for p in model.parameters())
+    train_n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print("============")
+    print('number of params (M): %.2f' % (n_parameters / 1.e6))
+    print('number of trainable params (M): %.2f' % (train_n_parameters / 1.e6))
+    print("============")
+
+def load_pretrained(args, save_dir, model, model_types="last", mdp=False):
+    pretrained_object = torch.load(f'{save_dir}/{model_types}.pth', map_location='cpu')
+    state_dict = pretrained_object['state_dict']
+    save_epoch = pretrained_object['epoch']
+    if mdp:
+        for k in list(state_dict.keys()):
+            if k.startswith('module.'):
+                state_dict[k[len("module."):]] = state_dict[k]
+            del state_dict[k]
+    model.load_state_dict(state_dict)
+    torch.cuda.set_device(args.gpu)
+    model = model.cuda(args.gpu)
+    model.eval()
+    return model, save_epoch
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/utils/metrics.py b/melodytalk/dependencies/lpmc/utils/metrics.py
new file mode 100644
index 0000000..d3cb799
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/metrics.py
@@ -0,0 +1,144 @@
+"""Captioning and retrieval metrics."""
+import string
+import evaluate
+import numpy as np
+import torch
+import torchmetrics.retrieval as retrieval_metrics
+# CAPTIONING METRICS
+def bleu(predictions, ground_truths, order):
+    bleu_eval = evaluate.load("bleu")
+    return bleu_eval.compute(
+        predictions=predictions, references=ground_truths, max_order=order
+    )["bleu"]
+
+def meteor(predictions, ground_truths):
+    # https://github.com/huggingface/evaluate/issues/115
+    meteor_eval = evaluate.load("meteor")
+    return meteor_eval.compute(predictions=predictions, references=ground_truths)[
+        "meteor"
+    ]
+
+
+def rouge(predictions, ground_truths):
+    rouge_eval = evaluate.load("rouge")
+    return rouge_eval.compute(predictions=predictions, references=ground_truths)[
+        "rougeL"
+    ]
+
+
+def bertscore(predictions, ground_truths):
+    bertscore_eval = evaluate.load("bertscore")
+    score = bertscore_eval.compute(
+        predictions=predictions, references=ground_truths, lang="en"
+    )["f1"]
+    return np.mean(score)
+
+
+def vocab_diversity(predictions, references):
+    train_caps_tokenized = [
+        train_cap.translate(str.maketrans("", "", string.punctuation)).lower().split()
+        for train_cap in references
+    ]
+    gen_caps_tokenized = [
+        gen_cap.translate(str.maketrans("", "", string.punctuation)).lower().split()
+        for gen_cap in predictions
+    ]
+    training_vocab = Vocabulary(train_caps_tokenized, min_count=2).idx2word  # NOTE: assumes a Vocabulary helper is importable here
+    generated_vocab = Vocabulary(gen_caps_tokenized, min_count=1).idx2word
+
+    return len(generated_vocab) / len(training_vocab)
+
+
+def vocab_novelty(predictions, tr_ground_truths):
+    predictions_token, tr_ground_truths_token = [], []
+    for gen, ref in zip(predictions, tr_ground_truths):
+        predictions_token.extend(gen.lower().replace(",","").replace(".","").split())
+        tr_ground_truths_token.extend(ref.lower().replace(",","").replace(".","").split())
+
+    predictions_vocab = set(predictions_token)
+    new_vocab = predictions_vocab.difference(set(tr_ground_truths_token))
+
+    vocab_size = len(predictions_vocab)
+    novel_v = len(new_vocab) / vocab_size
+    return vocab_size, novel_v
+
+def caption_novelty(predictions, tr_ground_truths):
+    unique_pred_captions = set(predictions)
+    unique_train_captions = set(tr_ground_truths)
+
+    new_caption = unique_pred_captions.difference(unique_train_captions)
+    novel_c = len(new_caption) / len(unique_pred_captions)
+    return novel_c
+
+def metric_1(predictions, ground_truths) -> float:
+    """Placeholder metric; currently always returns 0.0.
+    Args:
+        predictions: A list of predictions.
+        ground_truths: A list of ground truths.
+    Returns:
+        metric_1: A float number, the metric_1 score.
+    """
+    return 0.0
+
+
+# RETRIEVAL METRICS
+def _prepare_torchmetrics_input(scores, query2target_idx):
+    target = [
+        [i in target_idxs for i in range(len(scores[0]))]
+        for query_idx, target_idxs in query2target_idx.items()
+    ]
+    indexes = torch.arange(len(scores)).unsqueeze(1).repeat((1, len(target[0])))
+    return torch.as_tensor(scores), torch.as_tensor(target), indexes
+
+
+def _call_torchmetrics(
+    metric: retrieval_metrics.RetrievalMetric, scores, query2target_idx, **kwargs
+):
+    preds, target, indexes = _prepare_torchmetrics_input(scores, query2target_idx)
+    return metric(preds, target, indexes=indexes, **kwargs).item()
+
+
+def recall(predicted_scores, query2target_idx, k: int) -> float:
+    """Compute retrieval recall score at cutoff k.
+
+    Args:
+        predicted_scores: N x M similarity matrix
+        query2target_idx: a dictionary with
+            key: unique query idx
+            values: list of target idx
+        k: number of top-k results considered
+    Returns:
+        average score of recall@k
+    """
+    recall_metric = retrieval_metrics.RetrievalRecall(k=k)
+    return _call_torchmetrics(recall_metric, predicted_scores, query2target_idx)
+
+
+def mean_average_precision(predicted_scores, query2target_idx) -> float:
+    """Compute retrieval mean average precision (MAP) score.
+
+    Args:
+        predicted_scores: N x M similarity matrix
+        query2target_idx: a dictionary with
+            key: unique query idx
+            values: list of target idx
+    Returns:
+        MAP score
+    """
+    map_metric = retrieval_metrics.RetrievalMAP()
+    return _call_torchmetrics(map_metric, predicted_scores, query2target_idx)
+
+
+def mean_reciprocal_rank(predicted_scores, query2target_idx) -> float:
+    """Compute retrieval mean reciprocal rank (MRR) score.
+
+    Args:
+        predicted_scores: N x M similarity matrix
+        query2target_idx: a dictionary with
+            key: unique query idx
+            values: list of target idx
+    Returns:
+        MRR score
+    """
+    mrr_metric = retrieval_metrics.RetrievalMRR()
+    return _call_torchmetrics(mrr_metric, predicted_scores, query2target_idx)
\ No newline at end of file
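For a concrete sense of the input format these retrieval wrappers expect, the sketch below scores a toy 2 x 3 similarity matrix. It is an illustration only, not part of the patch, and assumes the package layout introduced here is on the import path and that the pinned torchmetrics version behaves as the wrappers above expect.

# Usage sketch (illustration only, not part of the patch).
from melodytalk.dependencies.lpmc.utils.metrics import (
    mean_average_precision,
    mean_reciprocal_rank,
)

scores = [
    [0.9, 0.1, 0.3],  # query 0: its true target (0) ranks first
    [0.2, 0.8, 0.4],  # query 1: its true target (2) ranks second
]
query2target_idx = {0: [0], 1: [2]}  # relevant target indices per query

print(mean_average_precision(scores, query2target_idx))  # 0.75
print(mean_reciprocal_rank(scores, query2target_idx))    # 0.75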
diff --git a/melodytalk/dependencies/lpmc/utils/train_utils.py b/melodytalk/dependencies/lpmc/utils/train_utils.py
new file mode 100644
index 0000000..be5c17c
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/train_utils.py
@@ -0,0 +1,117 @@
+import os
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from omegaconf import DictConfig, OmegaConf
+
+def save_hparams(args, save_path):
+    save_config = OmegaConf.create(vars(args))
+    os.makedirs(save_path, exist_ok=True)
+    OmegaConf.save(config=save_config, f=os.path.join(save_path, "hparams.yaml"))
+
+class EarlyStopping():
+    def __init__(self, min_max="min", tolerance=20, min_delta=1e-9):
+        self.tolerance = tolerance
+        self.min_delta = min_delta
+        self.min_max = min_max
+        self.counter = 0
+        self.early_stop = False
+
+    def min_stopping(self, valid_loss, best_valid_loss):
+        if (valid_loss - best_valid_loss) > self.min_delta:
+            self.counter += 1
+            if self.counter >= self.tolerance:
+                self.early_stop = True
+        else:
+            self.counter = 0
+
+    def max_stopping(self, valid_acc, best_valid_acc):
+        if (best_valid_acc - valid_acc) > self.min_delta:
+            self.counter += 1
+            if self.counter >= self.tolerance:
+                self.early_stop = True
+        else:
+            self.counter = 0
+
+    def __call__(self, valid_metric, best_metric):
+        if self.min_max == "min":
+            self.min_stopping(valid_metric, best_metric)
+        elif self.min_max == "max":
+            self.max_stopping(valid_metric, best_metric)
+        else:
+            raise ValueError(f"Unexpected min_max value: {self.min_max}")
+
+class Logger(SummaryWriter):
+    def __init__(self, logdir):
+        super(Logger, self).__init__(logdir)
+
+    def log_train_loss(self, loss, steps):
+        self.add_scalar('train_loss', loss.item(), steps)
+
+    def log_val_loss(self, loss, epochs):
+        self.add_scalar('valid_loss', loss.item(), epochs)
+
+    def log_caption_matric(self, metric, epochs, name="acc"):
+        self.add_scalar(f'{name}', metric, epochs)
+
+    def log_logitscale(self, logitscale, epochs):
+        self.add_scalar('logit_scale', logitscale.item(), epochs)
+
+    def log_learning_rate(self, lr, epochs):
+        self.add_scalar('lr', lr, epochs)
+
+    def log_roc(self, roc, epochs):
+        self.add_scalar('roc', roc, epochs)
+
+    def log_pr(self, pr, epochs):
+        self.add_scalar('pr', pr, epochs)
+
+class AverageMeter(object):
+    def __init__(self, name, fmt, init_steps=0):
+        self.name = name
+        self.fmt = fmt
+        self.steps = init_steps
+        self.reset()
+
+    def reset(self):
+        self.val = 0.0
+        self.sum = 0.0
+        self.num = 0
+        self.avg = 0.0
+
+    def step(self, val, num=1):
+        self.val = val
+        self.sum += num * val
+        self.num += num
+        self.steps += 1
+        self.avg = self.sum / self.num
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+def load_pretrained(pretrain_path, model):
+    checkpoint = torch.load(pretrain_path, map_location='cpu')
+    state_dict = checkpoint['state_dict']
+    for k in list(state_dict.keys()):
+        # retain only encoder_q up to before the embedding layer
+        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.1.mlp'):
+            state_dict[k[len("module.encoder_q.0."):]] = state_dict[k]
+        del state_dict[k]
+    model.load_state_dict(state_dict)
+    return model
\ No newline at end of file
diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index 8bc1ed3..153b2a7 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -324,7 +324,18 @@ def inference(self, inputs):
 class MusicCaptioning(object):
     def __init__(self):
-        raise NotImplementedError
+        print("Initializing MusicCaptioning")
+
+    @prompts(
+        name="Describe the current music.",
+        description="useful when you want to describe a piece of music. "
+                    "Like: describe the current music, or what does the current music sound like. "
+                    "The input to this tool should be the music_filename. "
+    )
+
+    def inference(self, inputs):
+        pass
+
 
 # class Text2MusicwithChord(object):
 #     template_model = True
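MusicCaptioning.inference is still a stub in this hunk. Purely as an illustration of how it could be wired up later, and not the repository's actual implementation, the tool might delegate to the LP-MusicCaps captioning entry point added elsewhere in this patch, along these lines:

# Hypothetical wiring sketch (illustration only, not part of the patch).
import argparse
from melodytalk.dependencies.lpmc.music_captioning.captioning import captioning

def describe_music(music_filename: str, gpu: int = 0) -> None:
    # The Namespace fields mirror the captioning script's CLI arguments.
    args = argparse.Namespace(
        gpu=gpu,
        framework="transfer",
        caption_type="lp_music_caps",
        max_length=128,
        num_beams=5,
        model_type="last",
        audio_path=music_filename,
    )
    captioning(args)  # prints one caption per 10-second chunk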