diff --git a/main.py b/main.py
index f357e5e..3373a61 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@
from melodytalk.dependencies.laion_clap.hook import CLAP_Module
from datetime import datetime
import torch
+import demucs.separate
from melodytalk.utils import CLAP_post_filter
MODEL_NAME = 'melody'
@@ -13,7 +14,7 @@
CFG_COEF = 3
SAMPLES = 5
# PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
-PROMPT = "music loop with saxophone solo. instrument: saxophone."
+PROMPT = "music loop with a very strong and rhythmic drum pattern."
melody_conditioned = True
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
@@ -38,7 +39,9 @@ def generate():
if not melody_conditioned:
wav = model.generate(descriptions, progress=True) # generates 3 samples.
else:
- melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav')
+ # demucs.separate.main(["-n", "htdemucs_6s", "--two-stems", 'drums', '/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav' ])
+ melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/separated/htdemucs_6s/20230705-155518_3/no_drums.wav')
+
wav = model.generate_continuation(melody[None].expand(SAMPLES, -1, -1), sr, descriptions, progress=True)
# the generated wav contains the melody input, we need to cut it
wav = wav[..., int(melody.shape[-1] / sr * model.sample_rate):]
diff --git a/melodytalk/dependencies/lpmc/music_captioning/captioning.py b/melodytalk/dependencies/lpmc/music_captioning/captioning.py
new file mode 100644
index 0000000..6d14c5b
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/captioning.py
@@ -0,0 +1,81 @@
+import argparse
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+
+from melodytalk.dependencies.lpmc.music_captioning.model.bart import BartCaptionModel
+from melodytalk.dependencies.lpmc.utils.eval_utils import load_pretrained
+from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST
+from omegaconf import OmegaConf
+
+parser = argparse.ArgumentParser(description='PyTorch MSD Training')
+parser.add_argument('--gpu', default=1, type=int,
+ help='GPU id to use.')
+parser.add_argument("--framework", default="transfer", type=str)
+parser.add_argument("--caption_type", default="lp_music_caps", type=str)
+parser.add_argument("--max_length", default=128, type=int)
+parser.add_argument("--num_beams", default=5, type=int)
+parser.add_argument("--model_type", default="last", type=str)
+parser.add_argument("--audio_path", default="../../dataset/samples/orchestra.wav", type=str)
+
+def get_audio(audio_path, duration=10, target_sr=16000):
+ n_samples = int(duration * target_sr)
+ audio, sr = load_audio(
+ path= audio_path,
+ ch_format= STR_CH_FIRST,
+ sample_rate= target_sr,
+ downmix_to_mono= True,
+ )
+ if len(audio.shape) == 2:
+ audio = audio.mean(0, False) # to mono
+ input_size = int(n_samples)
+ if audio.shape[-1] < input_size: # pad sequence
+ pad = np.zeros(input_size)
+ pad[: audio.shape[-1]] = audio
+ audio = pad
+ ceil = int(audio.shape[-1] // n_samples)
+ audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
+ return audio
+
+def main():
+ args = parser.parse_args()
+ captioning(args)
+
+def captioning(args):
+ save_dir = f"exp/{args.framework}/{args.caption_type}/"
+ config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml"))
+ model = BartCaptionModel(max_length = config.max_length)
+ model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed)
+ torch.cuda.set_device(args.gpu)
+ model = model.cuda(args.gpu)
+ model.eval()
+
+ audio_tensor = get_audio(audio_path = args.audio_path)
+ if args.gpu is not None:
+ audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True)
+
+ with torch.no_grad():
+ output = model.generate(
+ samples=audio_tensor,
+ num_beams=args.num_beams,
+ )
+ inference = {}
+ number_of_chunks = range(audio_tensor.shape[0])
+ for chunk, text in zip(number_of_chunks, output):
+ time = f"{chunk * 10}:00-{(chunk + 1) * 10}:00"
+ item = {"text":text,"time":time}
+ inference[chunk] = item
+ print(item)
+
+if __name__ == '__main__':
+ main()
+
+
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py b/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py
new file mode 100644
index 0000000..eb0375e
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py
@@ -0,0 +1,58 @@
+import os
+import random
+import numpy as np
+import pandas as pd
+import torch
+from datasets import load_dataset
+from torch.utils.data import Dataset
+from lpmc.utils.audio_utils import load_audio, STR_CH_FIRST
+
+class MC_Dataset(Dataset):
+ def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
+ self.data_path = data_path
+ self.split = split
+ self.caption_type = caption_type
+ self.audio_enc = audio_enc
+        self.sr = sr
+        self.n_samples = int(sr * duration)
+ self.annotation = load_dataset("seungheondoh/LP-MusicCaps-MC")
+ self.get_split()
+
+ def get_split(self):
+        if self.split in ("train", "test"):
+            self.fl = [i for i in self.annotation[self.split] if i['is_crawled']]
+ else:
+ raise ValueError(f"Unexpected split name: {self.split}")
+
+ def load_audio(self, audio_path, file_type):
+ if file_type == ".npy":
+ audio = np.load(audio_path, mmap_mode='r')
+ else:
+ audio, _ = load_audio(
+ path=audio_path,
+ ch_format= STR_CH_FIRST,
+ sample_rate= self.sr,
+ downmix_to_mono= True
+ )
+ if len(audio.shape) == 2:
+ audio = audio.squeeze(0)
+ input_size = int(self.n_samples)
+ if audio.shape[-1] < input_size:
+ pad = np.zeros(input_size)
+ pad[:audio.shape[-1]] = audio
+ audio = pad
+ random_idx = random.randint(0, audio.shape[-1]-self.n_samples)
+ audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx+self.n_samples]).astype('float32'))
+ return audio_tensor
+
+ def __getitem__(self, index):
+ item = self.fl[index]
+ fname = item['fname']
+ text = item['caption_ground_truth']
+ audio_path = os.path.join(self.data_path, "music_caps", 'npy', fname + ".npy")
+ audio_tensor = self.load_audio(audio_path, file_type=audio_path[-4:])
+ return fname, text, audio_tensor
+
+ def __len__(self):
+ return len(self.fl)
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py b/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py
new file mode 100644
index 0000000..cfc03ed
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py
@@ -0,0 +1,64 @@
+import os
+import json
+import random
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from datasets import load_dataset
+
+class MSD_Balanced_Dataset(Dataset):
+ def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
+ self.data_path = data_path
+ self.split = split
+ self.audio_enc = audio_enc
+ self.n_samples = int(sr * duration)
+ self.caption_type = caption_type
+ self.dataset = load_dataset("seungheondoh/LP-MusicCaps-MSD")
+ self.get_split()
+
+ def get_split(self):
+ self.tags = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tags.json"), 'r'))
+ self.tag_to_track = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tag_to_track.json"), 'r'))
+ self.annotation = {instance['track_id'] : instance for instance in self.dataset[self.split]}
+
+ def load_caption(self, item):
+ caption_pool = []
+ if (self.caption_type in "write") or (self.caption_type == "lp_music_caps"):
+ caption_pool.append(item['caption_writing'])
+ if (self.caption_type in "summary") or (self.caption_type == "lp_music_caps"):
+ caption_pool.append(item['caption_summary'])
+ if (self.caption_type in "creative") or (self.caption_type == "lp_music_caps"):
+ caption_pool.append(item['caption_paraphrase'])
+ if (self.caption_type in "predict") or (self.caption_type == "lp_music_caps"):
+ caption_pool.append(item['caption_attribute_prediction'])
+ # randomly select one caption from multiple captions
+ sampled_caption = random.choice(caption_pool)
+ return sampled_caption
+
+ def load_audio(self, audio_path, file_type):
+ audio = np.load(audio_path, mmap_mode='r')
+ if len(audio.shape) == 2:
+ audio = audio.squeeze(0)
+ input_size = int(self.n_samples)
+ if audio.shape[-1] < input_size:
+ pad = np.zeros(input_size)
+ pad[:audio.shape[-1]] = audio
+ audio = pad
+ random_idx = random.randint(0, audio.shape[-1]-self.n_samples)
+ audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx+self.n_samples]).astype('float32'))
+        # audio_tensor = torch.randn(16000*10)  # debug override (random noise instead of loaded audio); disabled
+ return audio_tensor
+
+ def __getitem__(self, index):
+ tag = random.choice(self.tags) # uniform random sample tag
+ fname = random.choice(self.tag_to_track[tag]) # uniform random sample track
+ item = self.annotation[fname]
+ track_id = item['track_id']
+        gt_caption = "" # no ground truth
+ text = self.load_caption(item)
+ audio_path = os.path.join(self.data_path,'msd','npy', item['path'].replace(".mp3", ".npy"))
+ audio_tensor = self.load_audio(audio_path, file_type=".npy")
+ return fname, gt_caption, text, audio_tensor
+
+ def __len__(self):
+ return 2**13
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/eval.py b/melodytalk/dependencies/lpmc/music_captioning/eval.py
new file mode 100644
index 0000000..e9d8362
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/eval.py
@@ -0,0 +1,51 @@
+import os
+import argparse
+import json
+import numpy as np
+from datasets import load_dataset
+from lpmc.utils.metrics import bleu, meteor, rouge, bertscore, vocab_novelty, caption_novelty
+
+def inference_parsing(dataset, args):
+ ground_truths = [i['caption_ground_truth'] for i in dataset]
+ inference = json.load(open(os.path.join(args.save_dir, args.framework, args.caption_type, 'inference_temp.json'), 'r'))
+ id2pred = {item['audio_id']:item['predictions'] for item in inference.values()}
+ predictions = [id2pred[i['fname']] for i in dataset]
+ return predictions, ground_truths
+
+def main(args):
+ dataset = load_dataset("seungheondoh/LP-MusicCaps-MC")
+ train_data = [i for i in dataset['train'] if i['is_crawled']]
+ test_data = [i for i in dataset['test'] if i['is_crawled']]
+ tr_ground_truths = [i['caption_ground_truth'] for i in train_data]
+ predictions, ground_truths = inference_parsing(test_data, args)
+ length_avg = np.mean([len(cap.split()) for cap in predictions])
+ length_std = np.std([len(cap.split()) for cap in predictions])
+
+ vocab_size, vocab_novel_score = vocab_novelty(predictions, tr_ground_truths)
+ cap_novel_score = caption_novelty(predictions, tr_ground_truths)
+ results = {
+ "bleu1": bleu(predictions, ground_truths, order=1),
+ "bleu2": bleu(predictions, ground_truths, order=2),
+ "bleu3": bleu(predictions, ground_truths, order=3),
+ "bleu4": bleu(predictions, ground_truths, order=4),
+ "meteor_1.0": meteor(predictions, ground_truths), # https://github.com/huggingface/evaluate/issues/115
+ "rougeL": rouge(predictions, ground_truths),
+ "bertscore": bertscore(predictions, ground_truths),
+ "vocab_size": vocab_size,
+ "vocab_novelty": vocab_novel_score,
+ "caption_novelty": cap_novel_score,
+ "length_avg": length_avg,
+ "length_std": length_std
+ }
+ os.makedirs(os.path.join(args.save_dir, args.framework, args.caption_type), exist_ok=True)
+ with open(os.path.join(args.save_dir, args.framework, args.caption_type, f"results.json"), mode="w") as io:
+ json.dump(results, io, indent=4)
+ print(results)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--save_dir", default="./exp", type=str)
+ parser.add_argument("--framework", default="supervised", type=str)
+ parser.add_argument("--caption_type", default="gt", type=str)
+ args = parser.parse_args()
+ main(args=args)
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml
new file mode 100644
index 0000000..6b791f6
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/hparams.yaml
@@ -0,0 +1,27 @@
+framework: bart
+data_dir: ../../dataset
+train_data: msd_balence
+text_type: all
+arch: transformer
+workers: 12
+epochs: 4096
+warmup_epochs: 125
+start_epoch: 0
+batch_size: 256
+world_size: 1
+lr: 0.0001
+min_lr: 1.0e-09
+rank: 0
+dist_url: tcp://localhost:12312
+dist_backend: nccl
+seed: null
+gpu: 0
+print_freq: 100
+multiprocessing_distributed: false
+cos: true
+bart_pretrain: false
+label_smoothing: 0.1
+use_early_stopping: false
+eval_sample: 0
+max_length: 110
+distributed: false
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json
new file mode 100644
index 0000000..9e57a16
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/pretrain/lp_music_caps/results.json
@@ -0,0 +1,14 @@
+{
+ "bleu1": 0.19774511999248645,
+ "bleu2": 0.06701659207795346,
+ "bleu3": 0.02165946217889739,
+ "bleu4": 0.007868711032317937,
+ "meteor_1.0": 0.12883275733962948,
+ "rougeL": 0.1302757267270023,
+ "bertscore": 0.8451152142608364,
+ "vocab_size": 1686,
+ "vocab_diversity": 0.4721233689205219,
+ "caption_novelty": 1.0,
+ "length_avg": 45.26699542092286,
+ "length_std": 27.99358111821529
+}
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml
new file mode 100644
index 0000000..04829e5
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/hparams.yaml
@@ -0,0 +1,26 @@
+framework: bart
+data_dir: ../../dataset
+train_data: music_caps
+text_type: gt
+arch: transformer
+workers: 8
+epochs: 100
+warmup_epochs: 1
+start_epoch: 0
+batch_size: 64
+world_size: 1
+lr: 0.0001
+min_lr: 1.0e-09
+rank: 0
+dist_url: tcp://localhost:12312
+dist_backend: nccl
+seed: null
+gpu: 0
+print_freq: 100
+multiprocessing_distributed: false
+cos: true
+bart_pretrain: false
+label_smoothing: 0.1
+use_early_stopping: false
+eval_sample: 64
+max_length: 128
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json
new file mode 100644
index 0000000..acc0e3c
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/supervised/gt/results.json
@@ -0,0 +1,14 @@
+{
+ "bleu1": 0.2850628531445219,
+ "bleu2": 0.13762330787082025,
+ "bleu3": 0.07588914964864207,
+ "bleu4": 0.047923589711449235,
+ "meteor_1.0": 0.20624392533703412,
+ "rougeL": 0.19215453358420784,
+ "bertscore": 0.870539891400695,
+ "vocab_size": 2240,
+ "vocab_diversity": 0.005357142857142857,
+ "caption_novelty": 0.6899661781285231,
+ "length_avg": 46.68193025713279,
+ "length_std": 16.516237946145356
+}
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml
new file mode 100644
index 0000000..728a20a
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/hparams.yaml
@@ -0,0 +1,25 @@
+framework: transfer
+data_dir: ../../dataset
+text_type: gt
+arch: transformer
+workers: 8
+epochs: 100
+warmup_epochs: 20
+start_epoch: 0
+batch_size: 64
+world_size: 1
+lr: 0.0001
+min_lr: 1.0e-09
+rank: 0
+dist_url: tcp://localhost:12312
+dist_backend: nccl
+seed: null
+gpu: 1
+print_freq: 10
+multiprocessing_distributed: false
+cos: true
+bart_pretrain: false
+label_smoothing: 0.1
+use_early_stopping: false
+eval_sample: 64
+max_length: 128
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json
new file mode 100644
index 0000000..55961d8
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/exp/transfer/lp_music_caps/results.json
@@ -0,0 +1,14 @@
+{
+ "bleu1": 0.29093124850986807,
+ "bleu2": 0.1486508686790579,
+ "bleu3": 0.08928000372127677,
+ "bleu4": 0.06046496381102543,
+ "meteor_1.0": 0.22386547507696472,
+ "rougeL": 0.2148899248811179,
+ "bertscore": 0.8777867602017111,
+ "vocab_size": 1695,
+ "vocab_diversity": 0.014749262536873156,
+ "caption_novelty": 0.9606498194945848,
+ "length_avg": 42.47234941880944,
+ "length_std": 14.336758651069372
+}
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/infer.py b/melodytalk/dependencies/lpmc/music_captioning/infer.py
new file mode 100644
index 0000000..3bddae6
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/infer.py
@@ -0,0 +1,100 @@
+import argparse
+import os
+import json
+import torch
+import torch.nn.parallel
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+
+from lpmc.music_captioning.datasets.mc import MC_Dataset
+from lpmc.music_captioning.model.bart import BartCaptionModel
+from lpmc.utils.eval_utils import load_pretrained
+from tqdm import tqdm
+from omegaconf import OmegaConf
+
+parser = argparse.ArgumentParser(description='PyTorch MSD Training')
+parser.add_argument('--data_dir', type=str, default="../../dataset")
+parser.add_argument('--framework', type=str, default="supervised")
+parser.add_argument("--caption_type", default="gt", type=str)
+parser.add_argument('--arch', default='transformer')
+parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
+ help='number of data loading workers')
+parser.add_argument('--epochs', default=100, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--warmup_epochs', default=10, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N')
+parser.add_argument('--world-size', default=1, type=int,
+ help='number of nodes for distributed training')
+parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
+ metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--min_lr', default=1e-9, type=float)
+parser.add_argument('--seed', default=42, type=int,
+ help='seed for initializing training. ')
+parser.add_argument('--gpu', default=1, type=int,
+ help='GPU id to use.')
+parser.add_argument('--print_freq', default=50, type=int)
+parser.add_argument("--cos", default=True, type=bool)
+parser.add_argument("--label_smoothing", default=0.1, type=float)
+parser.add_argument("--max_length", default=128, type=int)
+parser.add_argument("--num_beams", default=5, type=int)
+parser.add_argument("--model_type", default="last", type=str)
+
+
+def main():
+ args = parser.parse_args()
+ main_worker(args)
+
+def main_worker(args):
+ test_dataset = MC_Dataset(
+ data_path = args.data_dir,
+ split="test",
+ caption_type = "gt"
+ )
+ print(len(test_dataset))
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=args.workers, pin_memory=True, drop_last=False)
+ model = BartCaptionModel(
+ max_length = args.max_length,
+ label_smoothing = args.label_smoothing,
+ )
+ eval(args, model, test_dataset, test_loader, args.num_beams)
+
+def eval(args, model, test_dataset, test_loader, num_beams=5):
+ save_dir = f"exp/{args.framework}/{args.caption_type}/"
+ config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml"))
+ model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed)
+ torch.cuda.set_device(args.gpu)
+ model = model.cuda(args.gpu)
+ model.eval()
+
+ inference_results = {}
+ idx = 0
+ for batch in tqdm(test_loader):
+        fname, text, audio_tensor = batch
+ if args.gpu is not None:
+ audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True)
+ with torch.no_grad():
+ output = model.generate(
+ samples=audio_tensor,
+ num_beams=num_beams,
+ )
+ for audio_id, gt, pred in zip(fname, text, output):
+ inference_results[idx] = {
+ "predictions": pred,
+ "true_captions": gt,
+ "audio_id": audio_id
+ }
+ idx += 1
+
+ with open(os.path.join(save_dir, f"inference.json"), mode="w") as io:
+ json.dump(inference_results, io, indent=4)
+
+if __name__ == '__main__':
+ main()
+
+
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/model/bart.py b/melodytalk/dependencies/lpmc/music_captioning/model/bart.py
new file mode 100644
index 0000000..308214c
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/model/bart.py
@@ -0,0 +1,153 @@
+### code reference: https://github.com/XinhaoMei/WavCaps/blob/master/captioning/models/bart_captioning.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from lpmc.music_captioning.model.modules import AudioEncoder
+from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
+
+class BartCaptionModel(nn.Module):
+ def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
+ super(BartCaptionModel, self).__init__()
+        # non-finetuning case
+ bart_config = BartConfig.from_pretrained(bart_type)
+ self.tokenizer = BartTokenizer.from_pretrained(bart_type)
+ self.bart = BartForConditionalGeneration(bart_config)
+
+ self.n_sample = sr * duration
+ self.hop_length = int(0.01 * sr) # hard coding hop_size
+ self.n_frames = int(self.n_sample // self.hop_length)
+ self.num_of_stride_conv = num_of_conv - 1
+ self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
+ self.audio_encoder = AudioEncoder(
+ n_mels = n_mels, # hard coding n_mel
+ n_ctx = self.n_ctx,
+ audio_dim = audio_dim,
+ text_dim = self.bart.config.hidden_size,
+ num_of_stride_conv = self.num_of_stride_conv
+ )
+
+ self.max_length = max_length
+ self.loss_fct = nn.CrossEntropyLoss(label_smoothing= label_smoothing, ignore_index=-100)
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+ def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+        Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+ return shifted_input_ids
+
+ def forward_encoder(self, audio):
+ audio_embs = self.audio_encoder(audio)
+ encoder_outputs = self.bart.model.encoder(
+ input_ids=None,
+ inputs_embeds=audio_embs,
+ return_dict=True
+ )["last_hidden_state"]
+ return encoder_outputs, audio_embs
+
+ def forward_decoder(self, text, encoder_outputs):
+ text = self.tokenizer(text,
+ padding='longest',
+ truncation=True,
+ max_length=self.max_length,
+ return_tensors="pt")
+ input_ids = text["input_ids"].to(self.device)
+ attention_mask = text["attention_mask"].to(self.device)
+
+ decoder_targets = input_ids.masked_fill(
+ input_ids == self.tokenizer.pad_token_id, -100
+ )
+
+ decoder_input_ids = self.shift_tokens_right(
+ decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
+ )
+
+ decoder_outputs = self.bart(
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=attention_mask,
+ inputs_embeds=None,
+ labels=None,
+ encoder_outputs=(encoder_outputs,),
+ return_dict=True
+ )
+ lm_logits = decoder_outputs["logits"]
+ loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
+ return loss
+
+ def forward(self, audio, text):
+ encoder_outputs, _ = self.forward_encoder(audio)
+ loss = self.forward_decoder(text, encoder_outputs)
+ return loss
+
+ def generate(self,
+ samples,
+ use_nucleus_sampling=False,
+ num_beams=5,
+ max_length=128,
+ min_length=2,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ ):
+
+ # self.bart.force_bos_token_to_be_generated = True
+ audio_embs = self.audio_encoder(samples)
+ encoder_outputs = self.bart.model.encoder(
+ input_ids=None,
+ attention_mask=None,
+ head_mask=None,
+ inputs_embeds=audio_embs,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=True)
+
+ input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+ input_ids[:, 0] = self.bart.config.decoder_start_token_id
+ decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+ if use_nucleus_sampling:
+ outputs = self.bart.generate(
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ repetition_penalty=1.1)
+ else:
+ outputs = self.bart.generate(input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ head_mask=None,
+ decoder_head_mask=None,
+ inputs_embeds=None,
+ decoder_inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ repetition_penalty=repetition_penalty)
+
+ captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ return captions
diff --git a/melodytalk/dependencies/lpmc/music_captioning/model/modules.py b/melodytalk/dependencies/lpmc/music_captioning/model/modules.py
new file mode 100644
index 0000000..788baa6
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/model/modules.py
@@ -0,0 +1,95 @@
+### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
+
+import os
+import torch
+import torchaudio
+import numpy as np
+import torch.nn.functional as F
+from torch import Tensor, nn
+from typing import Dict, Iterable, Optional
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 1024
+N_MELS = 128
+HOP_LENGTH = int(0.01 * SAMPLE_RATE)
+DURATION = 10
+N_SAMPLES = int(DURATION * SAMPLE_RATE)
+N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+class MelEncoder(nn.Module):
+ """
+    time-frequency representation
+ """
+ def __init__(self,
+ sample_rate= 16000,
+ f_min=0,
+ f_max=8000,
+ n_fft=1024,
+ win_length=1024,
+ hop_length = int(0.01 * 16000),
+ n_mels = 128,
+ power = None,
+ pad= 0,
+ normalized= False,
+ center= True,
+ pad_mode= "reflect"
+ ):
+ super(MelEncoder, self).__init__()
+ self.window = torch.hann_window(win_length)
+ self.spec_fn = torchaudio.transforms.Spectrogram(
+ n_fft = n_fft,
+ win_length = win_length,
+ hop_length = hop_length,
+ power = power
+ )
+ self.mel_scale = torchaudio.transforms.MelScale(
+ n_mels,
+ sample_rate,
+ f_min,
+ f_max,
+ n_fft // 2 + 1)
+
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+ def forward(self, wav):
+ spec = self.spec_fn(wav)
+ power_spec = spec.real.abs().pow(2)
+ mel_spec = self.mel_scale(power_spec)
+ mel_spec = self.amplitude_to_db(mel_spec) # Log10(max(reference value and amin))
+ return mel_spec
+
+class AudioEncoder(nn.Module):
+ def __init__(
+ self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
+ ):
+ super().__init__()
+ self.mel_encoder = MelEncoder(n_mels=n_mels)
+ self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
+ self.conv_stack = nn.ModuleList([])
+ for _ in range(num_of_stride_conv):
+ self.conv_stack.append(
+ nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
+ )
+ # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, waveform)
+ single channel wavform
+ """
+ x = self.mel_encoder(x) # (batch_size, n_mels, n_ctx)
+ x = F.gelu(self.conv1(x))
+ for conv in self.conv_stack:
+ x = F.gelu(conv(x))
+ x = x.permute(0, 2, 1)
+ x = (x + self.positional_embedding).to(x.dtype)
+ return x
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py b/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py
new file mode 100644
index 0000000..0ead2c9
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/preprocessor.py
@@ -0,0 +1,58 @@
+import os
+import json
+import random
+import multiprocessing
+from contextlib import contextmanager
+
+import numpy as np
+from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST
+from sklearn.preprocessing import MultiLabelBinarizer
+from tqdm import tqdm
+
+# hard-coded hyperparameters
+DATASET_PATH = "../../dataset/msd"
+MUSIC_SAMPLE_RATE = 16000
+DURATION = 30
+DATA_LENGTH = int(MUSIC_SAMPLE_RATE * DURATION)
+
+@contextmanager
+def poolcontext(*args, **kwargs):
+ pool = multiprocessing.Pool(*args, **kwargs)
+ yield pool
+ pool.terminate()
+
+def msd_resampler(sample):
+ path = sample['path']
+ save_name = os.path.join(DATASET_PATH,'npy', path.replace(".mp3",".npy"))
+ src, _ = load_audio(
+ path=os.path.join(DATASET_PATH,'songs',path),
+ ch_format= STR_CH_FIRST,
+ sample_rate= MUSIC_SAMPLE_RATE,
+ downmix_to_mono= True)
+ if src.shape[-1] < DATA_LENGTH: # short case
+ pad = np.zeros(DATA_LENGTH)
+ pad[:src.shape[-1]] = src
+ src = pad
+ elif src.shape[-1] > DATA_LENGTH: # too long case
+ src = src[:DATA_LENGTH]
+
+ if not os.path.exists(os.path.dirname(save_name)):
+ os.makedirs(os.path.dirname(save_name))
+ np.save(save_name, src.astype(np.float32))
+
+def build_tag_to_track(msd_dataset, split):
+ """
+    for the balanced sampler, we build a tag_to_track graph
+ """
+ mlb = MultiLabelBinarizer()
+ indexs = [i['track_id'] for i in msd_dataset[split]]
+ binary = mlb.fit_transform([i['tag'] for i in msd_dataset[split]])
+ tags = list(mlb.classes_)
+ tag_to_track = {}
+ for idx, tag in enumerate(tqdm(tags)):
+ track_list = [indexs[i] for i in binary[:,idx].nonzero()[0]]
+ tag_to_track[tag] = track_list
+
+ with open(os.path.join(DATASET_PATH, f"{split}_tag_to_track.json"), mode="w") as io:
+ json.dump(tag_to_track, io, indent=4)
+
+ with open(os.path.join(DATASET_PATH, f"{split}_tags.json"), mode="w") as io:
+ json.dump(tags, io, indent=4)
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/readme.md b/melodytalk/dependencies/lpmc/music_captioning/readme.md
new file mode 100644
index 0000000..441c77b
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/readme.md
@@ -0,0 +1,120 @@
+# Audio-to-Caption using Cross Modal Encoder-Decoder
+
+We used a cross-modal encoder-decoder transformer architecture.
+
+1. Similar to Whisper, the encoder takes a log-mel spectrogram and processes it with six convolution layers, each with a filter width of 3 and the GELU activation function. With the exception of the first layer, each convolution layer has a stride of two. The output of the convolution layers is combined with the sinusoidal position encoding and then processed by the encoder transformer blocks (a minimal sketch of this frontend appears right after this list).
+
+2. Following the BART architecture, our encoder and decoder each have a hidden width of 768 and 6 transformer blocks. The decoder processes tokenized text captions using transformer blocks with a multi-head attention module that masks future tokens for causality. The music and caption representations are fed into the cross-modal attention layer, and the language-model head of the decoder predicts the next token autoregressively using the cross-entropy loss.
+
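+The following is a minimal, illustrative sketch of the convolutional audio frontend described in step 1, adapted from `model/modules.py` in this repository; the hyperparameters shown are the defaults used there.
+
+```python
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ConvFrontend(nn.Module):
+    """Six 1-D convolutions over a log-mel spectrogram; all but the first use stride 2."""
+    def __init__(self, n_mels=128, audio_dim=768, num_of_conv=6):
+        super().__init__()
+        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
+        self.conv_stack = nn.ModuleList(
+            [nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
+             for _ in range(num_of_conv - 1)]
+        )
+
+    def forward(self, mel):              # mel: (batch, n_mels, n_frames)
+        x = F.gelu(self.conv1(mel))
+        for conv in self.conv_stack:     # each stride-2 conv halves the time axis
+            x = F.gelu(conv(x))
+        return x.permute(0, 2, 1)        # (batch, n_ctx, audio_dim), fed to the BART encoder
+```
+
+Sinusoidal position encodings are added to this output before it enters the BART encoder, as in `AudioEncoder` in `model/modules.py`.
+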
+- **Supervised Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/supervised.pth)
+- **Pretrain Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/pretrain.pth)
+- **Transfer Model** : [download link](https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth)
+
+
+
+
+
+## 0. Quick Start
+```bash
+# download pretrain model weight from huggingface
+
+wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/supervised.pth -O exp/supervised/gt/last.pth
+wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth -O exp/transfer/lp_music_caps/last.pth
+wget https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/pretrain.pth -O exp/pretrain/lp_music_caps/last.pth
+python captioning.py --audio_path ../../dataset/samples/orchestra.wav
+```
+
+```bash
+{
+ 'text': "This is a symphonic orchestra playing a piece that's riveting, thrilling and exciting.
+ The peace would be suitable in a movie when something grand and impressive happens.
+ There are clarinets, tubas, trumpets and french horns being played. The brass instruments help create that sense of a momentous occasion.",
+ 'time': '0:00-10:00'
+}
+{
+ 'text': 'This is a classical music piece from a movie soundtrack.
+ There is a clarinet playing the main melody while a brass section and a flute are playing the melody.
+ The rhythmic background is provided by the acoustic drums. The atmosphere is epic and victorious.
+ This piece could be used in the soundtrack of a historical drama movie during the scenes of an army marching towards the end.',
+'time': '10:00-20:00'
+}
+```
+
+## 1. Preprocessing audio with ffmpeg
+
+For fast training, we resample the audio to a 16 kHz sampling rate and save it as `.npy`.
+
+```python
+def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
+ """
+    Decoding, downmixing, and downsampling by ffmpeg.
+ Returns a channel-first audio signal.
+
+ Args:
+ path:
+ sample_rate:
+ downmix_to_mono:
+
+ Returns:
+ (audio signal, sample rate)
+ """
+
+ def _decode_resample_by_ffmpeg(filename, sr):
+ """decode, downmix, and resample audio file"""
+ channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option
+ resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option
+ cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
+ p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ return out
+
+ src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
+ return src.T, sr
+```
+
+The command that runs the `multiprocessing`-based resampler is as follows. We also provide preprocessing code for balanced data loading; see the usage sketch after the command.
+
+```
+# multiprocessing resampling & build a tag-to-track linked list for balanced data loading.
+python preprocessor.py
+```
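+
+For reference, here is a sketch of how the helpers in `preprocessor.py` could be wired together; the worker count and the split are illustrative assumptions, and the script is assumed to run from the directory containing `preprocessor.py`.
+
+```python
+from datasets import load_dataset
+from preprocessor import poolcontext, msd_resampler, build_tag_to_track
+
+if __name__ == "__main__":
+    msd = load_dataset("seungheondoh/LP-MusicCaps-MSD")
+    samples = list(msd["train"])               # each item carries a 'path' to the source mp3
+    with poolcontext(processes=8) as pool:     # 8 workers is an illustrative choice
+        pool.map(msd_resampler, samples)       # resample every track to 16 kHz and save as .npy
+    build_tag_to_track(msd, split="train")     # tag-to-track index for balanced sampling
+```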
+
+## 2. Train & Eval Supervised Model (Baseline)
+
+Download the [MusicCaps audio](https://github.com/seungheondoh/music_caps_dl). If you have difficulty obtaining the audio, please request it for research purposes.
+
+```
+# train supervised baseline model
+python train.py --framework supervised --train_data mc --caption_type gt --warmup_epochs 1 --label_smoothing 0.1 --max_length 128 --batch-size 64 --epochs 100
+
+# inference caption
+python infer.py --framework supervised --train_data mc --caption_type gt --num_beams 5 --model_type last
+
+# eval
+python eval.py --framework supervised --caption_type gt
+```
+
+## 3. Pretrain, Transfer Music Captioning Model (Proposed)
+
+Download the MSD audio. If you have difficulty obtaining the audio, please request it for research purposes.
+
+```
+# train pretrain model
+python train.py --framework pretrain --train_data msd --caption_type lp_music_caps --warmup_epochs 125 --label_smoothing 0.1 --max_length 110 --batch-size 256 --epochs 4096
+
+# train transfer model
+python transfer.py --caption_type gt --warmup_epochs 1 --label_smoothing 0.1 --max_length 128 --batch-size 64 --epochs 100
+
+# inference caption
+python infer.py --framework transfer --caption_type lp_music_caps --num_beams 5 --model_type last
+
+# eval
+python eval.py --framework transfer --caption_type lp_music_caps
+```
+
+### License
+This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
+
+
+### Acknowledgement
+We would like to thank [Whisper](https://github.com/openai/whisper) for the audio frontend, [WavCaps](https://github.com/XinhaoMei/WavCaps) for the audio-captioning training code, and [deezer-playntell](https://github.com/deezer/playntell) for the content-based captioning evaluation protocol.
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/train.py b/melodytalk/dependencies/lpmc/music_captioning/train.py
new file mode 100644
index 0000000..62a899d
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/train.py
@@ -0,0 +1,134 @@
+import argparse
+import math
+import random
+import shutil
+import torch
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+
+from lpmc.music_captioning.datasets.mc import MC_Dataset
+from lpmc.music_captioning.datasets.msd import MSD_Balanced_Dataset
+from lpmc.music_captioning.model.bart import BartCaptionModel
+from lpmc.utils.train_utils import Logger, AverageMeter, ProgressMeter, EarlyStopping, save_hparams
+
+parser = argparse.ArgumentParser(description='PyTorch MSD Training')
+parser.add_argument('--framework', type=str, default="pretrain")
+parser.add_argument('--data_dir', type=str, default="../../dataset")
+parser.add_argument('--train_data', type=str, default="msd")
+parser.add_argument("--caption_type", default="lp_music_caps", type=str) # lp_music_caps
+parser.add_argument('--arch', default='transformer')
+parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
+ help='number of data loading workers')
+parser.add_argument('--epochs', default=100, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--warmup_epochs', default=125, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N')
+parser.add_argument('--world-size', default=1, type=int,
+ help='number of nodes for distributed training')
+parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
+ metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--min_lr', default=1e-9, type=float)
+parser.add_argument('--seed', default=None, type=int,
+ help='seed for initializing training. ')
+parser.add_argument('--gpu', default=1, type=int,
+ help='GPU id to use.')
+parser.add_argument('--print_freq', default=10, type=int)
+parser.add_argument("--cos", default=True, type=bool)
+parser.add_argument("--label_smoothing", default=0.1, type=float)
+parser.add_argument("--max_length", default=128, type=int)
+parser.add_argument("--resume", default=None, type=bool)
+
+def main():
+ args = parser.parse_args()
+ if args.seed is not None:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ cudnn.deterministic = True
+ main_worker(args)
+
+def main_worker(args):
+ if args.train_data == "msd":
+ train_dataset = MSD_Balanced_Dataset(
+ data_path = args.data_dir,
+ split="train",
+ caption_type = args.caption_type
+ )
+ elif args.train_data == "mc":
+ train_dataset = MC_Dataset(
+ data_path = args.data_dir,
+ split="train",
+ caption_type = args.caption_type
+ )
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=args.workers, pin_memory=True, drop_last=True)
+
+ model = BartCaptionModel(
+ max_length = args.max_length,
+ label_smoothing = args.label_smoothing
+ )
+ torch.cuda.set_device(args.gpu)
+ model = model.cuda(args.gpu)
+
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr)
+ save_dir = f"exp/{args.framework}/{args.caption_type}/"
+
+ logger = Logger(save_dir)
+ save_hparams(args, save_dir)
+
+ for epoch in range(args.start_epoch, args.epochs):
+ train(train_loader, model, optimizer, epoch, logger, args)
+
+ torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, f'{save_dir}/last.pth')
+ print("We are at epoch:", epoch)
+
+def train(train_loader, model, optimizer, epoch, logger, args):
+ train_losses = AverageMeter('Train Loss', ':.4e')
+ progress = ProgressMeter(len(train_loader),[train_losses],prefix="Epoch: [{}]".format(epoch))
+ iters_per_epoch = len(train_loader)
+ model.train()
+ for data_iter_step, batch in enumerate(train_loader):
+ lr = adjust_learning_rate(optimizer, data_iter_step / iters_per_epoch + epoch, args)
+ fname, text, audio_tensor = batch
+ if args.gpu is not None:
+ audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True)
+ # compute output
+ loss = model(audio=audio_tensor, text=text)
+ train_losses.step(loss.item(), audio_tensor.size(0))
+ logger.log_train_loss(loss, epoch * iters_per_epoch + data_iter_step)
+ logger.log_learning_rate(lr, epoch * iters_per_epoch + data_iter_step)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if data_iter_step % args.print_freq == 0:
+ progress.display(data_iter_step)
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, 'model_best.pth.tar')
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
+
+if __name__ == '__main__':
+ main()
+
+
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/music_captioning/transfer.py b/melodytalk/dependencies/lpmc/music_captioning/transfer.py
new file mode 100644
index 0000000..6889857
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/music_captioning/transfer.py
@@ -0,0 +1,125 @@
+import argparse
+import math
+import os
+import random
+import shutil
+import torch
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+from lpmc.music_captioning.datasets.mc import MC_Dataset
+from lpmc.music_captioning.model.bart import BartCaptionModel
+from lpmc.utils.train_utils import Logger, AverageMeter, ProgressMeter, EarlyStopping, save_hparams
+from lpmc.utils.eval_utils import load_pretrained, print_model_params
+from omegaconf import OmegaConf
+
+parser = argparse.ArgumentParser(description='PyTorch MSD Training')
+parser.add_argument('--framework', type=str, default="transfer")
+parser.add_argument('--data_dir', type=str, default="../../dataset")
+parser.add_argument("--caption_type", default="lp_music_caps", type=str)
+parser.add_argument('--arch', default='transformer')
+parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
+ help='number of data loading workers')
+parser.add_argument('--epochs', default=100, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--warmup_epochs', default=20, type=int, metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N')
+parser.add_argument('--world-size', default=1, type=int,
+ help='number of nodes for distributed training')
+parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
+ metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--min_lr', default=1e-9, type=float)
+parser.add_argument('--seed', default=None, type=int,
+ help='seed for initializing training. ')
+parser.add_argument('--gpu', default=1, type=int,
+ help='GPU id to use.')
+parser.add_argument('--print_freq', default=10, type=int)
+parser.add_argument("--cos", default=True, type=bool)
+parser.add_argument("--label_smoothing", default=0.1, type=float)
+parser.add_argument("--max_length", default=128, type=int)
+
+def main():
+ args = parser.parse_args()
+ if args.seed is not None:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ cudnn.deterministic = True
+ main_worker(args)
+
+def main_worker(args):
+ train_dataset = MC_Dataset(
+ data_path = args.data_dir,
+ split="train",
+ caption_type = "gt"
+ )
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=args.workers, pin_memory=True, drop_last=False)
+
+ model = BartCaptionModel(
+ max_length = args.max_length,
+ label_smoothing = args.label_smoothing
+ )
+ pretrain_dir = f"exp/pretrain/{args.caption_type}/"
+ config = OmegaConf.load(os.path.join(pretrain_dir, "hparams.yaml"))
+ model, save_epoch = load_pretrained(args, pretrain_dir, model, model_types="last", mdp=config.multiprocessing_distributed)
+ print_model_params(model)
+
+ torch.cuda.set_device(args.gpu)
+ model = model.cuda(args.gpu)
+
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr)
+ save_dir = f"exp/transfer/{args.caption_type}"
+ logger = Logger(save_dir)
+ save_hparams(args, save_dir)
+ for epoch in range(args.start_epoch, args.epochs):
+ train(train_loader, model, optimizer, epoch, logger, args)
+ torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, f'{save_dir}/last.pth')
+
+def train(train_loader, model, optimizer, epoch, logger, args):
+ train_losses = AverageMeter('Train Loss', ':.4e')
+ progress = ProgressMeter(len(train_loader),[train_losses],prefix="Epoch: [{}]".format(epoch))
+ iters_per_epoch = len(train_loader)
+ model.train()
+ for data_iter_step, batch in enumerate(train_loader):
+ lr = adjust_learning_rate(optimizer, data_iter_step / iters_per_epoch + epoch, args)
+        fname, text, audio_embs = batch  # MC_Dataset yields (fname, text, audio_tensor)
+ if args.gpu is not None:
+ audio_embs = audio_embs.cuda(args.gpu, non_blocking=True)
+ # compute output
+ loss = model(audio=audio_embs, text=text)
+ train_losses.step(loss.item(), audio_embs.size(0))
+ logger.log_train_loss(loss, epoch * iters_per_epoch + data_iter_step)
+ logger.log_learning_rate(lr, epoch * iters_per_epoch + data_iter_step)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if data_iter_step % args.print_freq == 0:
+ progress.display(data_iter_step)
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, 'model_best.pth.tar')
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
+
+if __name__ == '__main__':
+ main()
+
+
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/utils/audio_utils.py b/melodytalk/dependencies/lpmc/utils/audio_utils.py
new file mode 100644
index 0000000..d033238
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/audio_utils.py
@@ -0,0 +1,247 @@
+STR_CLIP_ID = 'clip_id'
+STR_AUDIO_SIGNAL = 'audio_signal'
+STR_TARGET_VECTOR = 'target_vector'
+
+
+STR_CH_FIRST = 'channels_first'
+STR_CH_LAST = 'channels_last'
+
+import io
+import os
+import tqdm
+import logging
+import subprocess
+from typing import Tuple
+from pathlib import Path
+
+# import librosa
+import numpy as np
+import soundfile as sf
+
+import itertools
+from numpy.fft import irfft
+
+def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
+ """
+    Decoding, downmixing, and downsampling by ffmpeg.
+ Returns a channel-first audio signal.
+
+ Args:
+ path:
+ sample_rate:
+ downmix_to_mono:
+
+ Returns:
+ (audio signal, sample rate)
+ """
+
+ def _decode_resample_by_ffmpeg(filename, sr):
+ """decode, downmix, and resample audio file"""
+ channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option
+ resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option
+ cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
+ p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ return out
+
+ src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
+ return src.T, sr
+
+
+def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
+ """
+ Decoding, downmixing, and downsampling by librosa.
+ Returns a channel-first audio signal.
+ """
+ src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)
+ return src, sr
+
+
+def load_audio(
+ path: str or Path,
+ ch_format: str,
+ sample_rate: int = None,
+ downmix_to_mono: bool = False,
+ resample_by: str = 'ffmpeg',
+ **kwargs,
+) -> Tuple[np.ndarray, int]:
+ """A wrapper of librosa.load that:
+ - forces the returned audio to be 2-dim,
+ - defaults to sr=None, and
+ - defaults to downmix_to_mono=False.
+
+ The audio decoding is done by `audioread` or `soundfile` package and ultimately, often by ffmpeg.
+ The resampling is done by `librosa`'s child package `resampy`.
+
+ Args:
+ path: audio file path
+ ch_format: one of 'channels_first' or 'channels_last'
+ sample_rate: target sampling rate. if None, use the rate of the audio file
+ downmix_to_mono:
+ resample_by (str): 'librosa' or 'ffmpeg'. it decides backend for audio decoding and resampling.
+ **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type.
+
+ Returns:
+ (audio, sr) tuple
+ """
+ if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
+ raise ValueError(f'ch_format is wrong here -> {ch_format}')
+
+ if os.stat(path).st_size > 8000:
+ if resample_by == 'librosa':
+ src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
+ elif resample_by == 'ffmpeg':
+ src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
+ else:
+            raise NotImplementedError(f'resample_by: "{resample_by}" is not supported yet')
+ else:
+ raise ValueError('Given audio is too short!')
+ return src, sr
+
+ # if src.ndim == 1:
+ # src = np.expand_dims(src, axis=0)
+ # # now always 2d and channels_first
+
+ # if ch_format == STR_CH_FIRST:
+ # return src, sr
+ # else:
+ # return src.T, sr
+
+def ms(x):
+ """Mean value of signal `x` squared.
+ :param x: Dynamic quantity.
+ :returns: Mean squared of `x`.
+ """
+ return (np.abs(x)**2.0).mean()
+
+def normalize(y, x=None):
+ """normalize power in y to a (standard normal) white noise signal.
+ Optionally normalize to power in signal `x`.
+ #The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1.
+ """
+ if x is not None:
+ x = ms(x)
+ else:
+ x = 1.0
+ return y * np.sqrt(x / ms(y))
+
+def noise(N, color='white', state=None):
+ """Noise generator.
+ :param N: Amount of samples.
+ :param color: Color of noise.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+ """
+ try:
+ return _noise_generators[color](N, state)
+ except KeyError:
+ raise ValueError("Incorrect color.")
+
+def white(N, state=None):
+ """
+ White noise.
+ :param N: Amount of samples.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+    White noise has a constant power density. Its narrowband spectrum is therefore flat.
+ The power in white noise will increase by a factor of two for each octave band,
+ and therefore increases with 3 dB per octave.
+ """
+ state = np.random.RandomState() if state is None else state
+ return state.randn(N)
+
+def pink(N, state=None):
+ """
+ Pink noise.
+ :param N: Amount of samples.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+ Pink noise has equal power in bands that are proportionally wide.
+ Power density decreases with 3 dB per octave.
+ """
+ state = np.random.RandomState() if state is None else state
+ uneven = N % 2
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+ S = np.sqrt(np.arange(len(X)) + 1.) # +1 to avoid divide by zero
+ y = (irfft(X / S)).real
+ if uneven:
+ y = y[:-1]
+ return normalize(y)
+
+def blue(N, state=None):
+ """
+ Blue noise.
+ :param N: Amount of samples.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+ Power increases with 6 dB per octave.
+ Power density increases with 3 dB per octave.
+ """
+ state = np.random.RandomState() if state is None else state
+ uneven = N % 2
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+ S = np.sqrt(np.arange(len(X))) # Filter
+ y = (irfft(X * S)).real
+ if uneven:
+ y = y[:-1]
+ return normalize(y)
+
+def brown(N, state=None):
+ """
+    Brown noise.
+ :param N: Amount of samples.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+ Power decreases with -3 dB per octave.
+ Power density decreases with 6 dB per octave.
+ """
+ state = np.random.RandomState() if state is None else state
+ uneven = N % 2
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+ S = (np.arange(len(X)) + 1) # Filter
+ y = (irfft(X / S)).real
+ if uneven:
+ y = y[:-1]
+ return normalize(y)
+
+def violet(N, state=None):
+ """
+    Violet noise.
+ :param N: Amount of samples.
+ :param state: State of PRNG.
+ :type state: :class:`np.random.RandomState`
+ Power increases with +9 dB per octave.
+ Power density increases with +6 dB per octave.
+ """
+ state = np.random.RandomState() if state is None else state
+ uneven = N % 2
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+ S = (np.arange(len(X))) # Filter
+ y = (irfft(X * S)).real
+ if uneven:
+ y = y[:-1]
+ return normalize(y)
+
+_noise_generators = {
+ 'white': white,
+ 'pink': pink,
+ 'blue': blue,
+ 'brown': brown,
+ 'violet': violet,
+}
+
+def noise_generator(N=44100, color='white', state=None):
+ """Noise generator.
+ :param N: Amount of unique samples to generate.
+ :param color: Color of noise.
+ Generate `N` amount of unique samples and cycle over these samples.
+ """
+ #yield from itertools.cycle(noise(N, color)) # Python 3.3
+ for sample in itertools.cycle(noise(N, color, state)):
+ yield sample
+
+def heaviside(N):
+ """Heaviside.
+ Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`.
+ """
+ return 0.5 * (np.sign(N) + 1)
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/utils/eval_utils.py b/melodytalk/dependencies/lpmc/utils/eval_utils.py
new file mode 100644
index 0000000..9fefc9d
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/eval_utils.py
@@ -0,0 +1,30 @@
+import os
+import json
+import torch
+import numpy as np
+import pandas as pd
+from sklearn import metrics
+
+
+def print_model_params(model):
+ n_parameters = sum(p.numel() for p in model.parameters())
+ train_n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print("============")
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+ print('number train of params (M): %.2f' % (train_n_parameters / 1.e6))
+ print("============")
+
+def load_pretrained(args, save_dir, model, model_types="last", mdp=False):
+ pretrained_object = torch.load(f'{save_dir}/{model_types}.pth', map_location='cpu')
+ state_dict = pretrained_object['state_dict']
+ save_epoch = pretrained_object['epoch']
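+    # If the checkpoint was saved from a (Distributed)DataParallel model (mdp=True),
+    # strip the leading 'module.' from parameter names before loading.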
+ if mdp:
+ for k in list(state_dict.keys()):
+ if k.startswith('module.'):
+ state_dict[k[len("module."):]] = state_dict[k]
+ del state_dict[k]
+ model.load_state_dict(state_dict)
+ torch.cuda.set_device(args.gpu)
+ model = model.cuda(args.gpu)
+ model.eval()
+ return model, save_epoch
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/utils/metrics.py b/melodytalk/dependencies/lpmc/utils/metrics.py
new file mode 100644
index 0000000..d3cb799
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/metrics.py
@@ -0,0 +1,144 @@
+"""Placeholder for metrics."""
+import string
+from functools import partial
+import evaluate
+import numpy as np
+import torch
+import torchmetrics.retrieval as retrieval_metrics
+# CAPTIONING METRICS
+def bleu(predictions, ground_truths, order):
+ bleu_eval = evaluate.load("bleu")
+ return bleu_eval.compute(
+ predictions=predictions, references=ground_truths, max_order=order
+ )["bleu"]
+
+def meteor(predictions, ground_truths):
+ # https://github.com/huggingface/evaluate/issues/115
+ meteor_eval = evaluate.load("meteor")
+ return meteor_eval.compute(predictions=predictions, references=ground_truths)[
+ "meteor"
+ ]
+
+
+def rouge(predictions, ground_truths):
+ rouge_eval = evaluate.load("rouge")
+ return rouge_eval.compute(predictions=predictions, references=ground_truths)[
+ "rougeL"
+ ]
+
+
+def bertscore(predictions, ground_truths):
+ bertscore_eval = evaluate.load("bertscore")
+ score = bertscore_eval.compute(
+ predictions=predictions, references=ground_truths, lang="en"
+ )["f1"]
+ return np.mean(score)
+
+
+def vocab_diversity(predictions, references):
+ train_caps_tokenized = [
+ train_cap.translate(str.maketrans("", "", string.punctuation)).lower().split()
+ for train_cap in references
+ ]
+ gen_caps_tokenized = [
+ gen_cap.translate(str.maketrans("", "", string.punctuation)).lower().split()
+ for gen_cap in predictions
+ ]
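+    # NOTE: `Vocabulary` is not defined or imported in this module; it is assumed to be
+    # provided by the project's vocabulary utilities.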
+ training_vocab = Vocabulary(train_caps_tokenized, min_count=2).idx2word
+ generated_vocab = Vocabulary(gen_caps_tokenized, min_count=1).idx2word
+
+ return len(generated_vocab) / len(training_vocab)
+
+
+def vocab_novelty(predictions, tr_ground_truths):
+ predictions_token, tr_ground_truths_token = [], []
+ for gen, ref in zip(predictions, tr_ground_truths):
+ predictions_token.extend(gen.lower().replace(",","").replace(".","").split())
+ tr_ground_truths_token.extend(ref.lower().replace(",","").replace(".","").split())
+
+ predictions_vocab = set(predictions_token)
+ new_vocab = predictions_vocab.difference(set(tr_ground_truths_token))
+
+ vocab_size = len(predictions_vocab)
+ novel_v = len(new_vocab) / vocab_size
+ return vocab_size, novel_v
+
+def caption_novelty(predictions, tr_ground_truths):
+ unique_pred_captions = set(predictions)
+ unique_train_captions = set(tr_ground_truths)
+
+ new_caption = unique_pred_captions.difference(unique_train_captions)
+ novel_c = len(new_caption) / len(unique_pred_captions)
+ return novel_c
+
+def metric_1(predictions, ground_truths) -> float:
+    """Placeholder metric; currently always returns 0.0.
+ Args:
+ predictions: A list of predictions.
+ ground_truths: A list of ground truths.
+ Returns:
+ metric_1: A float number, the metric_1 score.
+ """
+ return 0.0
+
+
+# RETRIEVAL METRICS
+def _prepare_torchmetrics_input(scores, query2target_idx):
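+    # Build the tensors expected by torchmetrics retrieval metrics:
+    # target[q][i] is True when item i is relevant to query q, and `indexes`
+    # tags every score in row q with the query id q.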
+ target = [
+ [i in target_idxs for i in range(len(scores[0]))]
+ for query_idx, target_idxs in query2target_idx.items()
+ ]
+ indexes = torch.arange(len(scores)).unsqueeze(1).repeat((1, len(target[0])))
+ return torch.as_tensor(scores), torch.as_tensor(target), indexes
+
+
+def _call_torchmetrics(
+ metric: retrieval_metrics.RetrievalMetric, scores, query2target_idx, **kwargs
+):
+ preds, target, indexes = _prepare_torchmetrics_input(scores, query2target_idx)
+ return metric(preds, target, indexes=indexes, **kwargs).item()
+
+
+def recall(predicted_scores, query2target_idx, k: int) -> float:
+ """Compute retrieval recall score at cutoff k.
+
+ Args:
+ predicted_scores: N x M similarity matrix
+ query2target_idx: a dictionary with
+ key: unique query idx
+ values: list of target idx
+ k: number of top-k results considered
+ Returns:
+ average score of recall@k
+ """
+ recall_metric = retrieval_metrics.RetrievalRecall(k=k)
+ return _call_torchmetrics(recall_metric, predicted_scores, query2target_idx)
+
+
+def mean_average_precision(predicted_scores, query2target_idx) -> float:
+ """Compute retrieval mean average precision (MAP) score at cutoff k.
+
+ Args:
+ predicted_scores: N x M similarity matrix
+ query2target_idx: a dictionary with
+ key: unique query idx
+ values: list of target idx
+ Returns:
+ MAP@k score
+ """
+ map_metric = retrieval_metrics.RetrievalMAP()
+ return _call_torchmetrics(map_metric, predicted_scores, query2target_idx)
+
+
+def mean_reciprocal_rank(predicted_scores, query2target_idx) -> float:
+ """Compute retrieval mean reciprocal rank (MRR) score.
+
+ Args:
+ predicted_scores: N x M similarity matrix
+ query2target_idx: a dictionary with
+ key: unique query idx
+ values: list of target idx
+ Returns:
+ MRR score
+ """
+ mrr_metric = retrieval_metrics.RetrievalMRR()
+ return _call_torchmetrics(mrr_metric, predicted_scores, query2target_idx)
\ No newline at end of file
diff --git a/melodytalk/dependencies/lpmc/utils/train_utils.py b/melodytalk/dependencies/lpmc/utils/train_utils.py
new file mode 100644
index 0000000..be5c17c
--- /dev/null
+++ b/melodytalk/dependencies/lpmc/utils/train_utils.py
@@ -0,0 +1,120 @@
+import os
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from omegaconf import DictConfig, OmegaConf
+
+def save_hparams(args, save_path):
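+    # Serialize the argparse namespace to <save_path>/hparams.yaml so the run configuration can be reproduced.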
+ save_config = OmegaConf.create(vars(args))
+ os.makedirs(save_path, exist_ok=True)
+ OmegaConf.save(config=save_config, f= os.path.join(save_path, "hparams.yaml"))
+
+class EarlyStopping():
+ def __init__(self, min_max="min", tolerance=20, min_delta=1e-9):
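+        # Stop once the monitored metric has failed to improve by more than `min_delta`
+        # for `tolerance` consecutive checks ("min" for losses, "max" for accuracies/scores).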
+ self.tolerance = tolerance
+ self.min_delta = min_delta
+ self.min_max = min_max
+ self.counter = 0
+ self.early_stop = False
+
+ def min_stopping(self, valid_loss, best_valid_loss):
+ if (valid_loss - best_valid_loss) > self.min_delta:
+ self.counter +=1
+ if self.counter >= self.tolerance:
+ self.early_stop = True
+ else:
+ self.counter = 0
+
+ def max_stopping(self, valid_acc, best_valid_acc):
+ if (best_valid_acc - valid_acc) > self.min_delta:
+ self.counter +=1
+ if self.counter >= self.tolerance:
+ self.early_stop = True
+ else:
+ self.counter = 0
+
+    def __call__(self, valid_metric, best_metric):
+        if self.min_max == "min":
+            self.min_stopping(valid_metric, best_metric)
+        elif self.min_max == "max":
+            self.max_stopping(valid_metric, best_metric)
+        else:
+            raise ValueError(f"Unexpected min_max value: {self.min_max}")
+
+class Logger(SummaryWriter):
+ def __init__(self, logdir):
+ super(Logger, self).__init__(logdir)
+
+ def log_train_loss(self, loss, steps):
+ self.add_scalar('train_loss', loss.item(), steps)
+
+ def log_val_loss(self, loss, epochs):
+ self.add_scalar('valid_loss', loss.item(), epochs)
+
+ def log_caption_matric(self, metric, epochs, name="acc"):
+ self.add_scalar(f'{name}', metric, epochs)
+
+ def log_logitscale(self, logitscale, epochs):
+ self.add_scalar('logit_scale', logitscale.item(), epochs)
+
+ def log_learning_rate(self, lr, epochs):
+ self.add_scalar('lr', lr, epochs)
+
+ def log_roc(self, roc, epochs):
+ self.add_scalar('roc', roc, epochs)
+
+ def log_pr(self, pr, epochs):
+ self.add_scalar('pr', pr, epochs)
+
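+# Tracks the running value, sum and average of a scalar metric and formats it for logging.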
+class AverageMeter(object):
+    def __init__(self, name, fmt, init_steps=0):
+ self.name = name
+ self.fmt = fmt
+ self.steps = init_steps
+ self.reset()
+
+ def reset(self):
+ self.val = 0.0
+ self.sum = 0.0
+ self.num = 0
+ self.avg = 0.0
+
+ def step(self, val, num=1):
+ self.val = val
+ self.sum += num*val
+ self.num += num
+ self.steps += 1
+ self.avg = self.sum/self.num
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ print('\t'.join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+ fmt = '{:' + str(num_digits) + 'd}'
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+def load_pretrained(pretrain_path, model):
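+    # Keep only the query-encoder backbone weights (keys under 'module.encoder_q', excluding
+    # its projection MLP) and strip the 'module.encoder_q.0.' prefix so they load into the bare encoder.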
+ checkpoint= torch.load(pretrain_path, map_location='cpu')
+ state_dict = checkpoint['state_dict']
+ for k in list(state_dict.keys()):
+ # retain only encoder_q up to before the embedding layer
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.1.mlp'):
+ state_dict[k[len("module.encoder_q.0."):]] = state_dict[k]
+ del state_dict[k]
+ model.load_state_dict(state_dict)
+ return model
\ No newline at end of file
diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index 8bc1ed3..153b2a7 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -324,7 +324,18 @@ def inference(self, inputs):
class MusicCaptioning(object):
def __init__(self):
- raise NotImplementedError
+ print("Initializing MusicCaptioning")
+
+ @prompts(
+ name="Describe the current music.",
+        description="useful if you want to describe the current music. "
+                    "Like: describe the current music, or what does the current music sound like. "
+                    "The input to this tool should be the music_filename. "
+ )
+ def inference(self, inputs):
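+        # Not implemented yet. A plausible implementation (an assumption, not the author's
+        # confirmed design) would load the LP-MusicCaps BartCaptionModel from
+        # dependencies/lpmc/music_captioning/captioning.py, run it on the audio file named
+        # by `inputs` (the music_filename), and return the generated caption.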
+ pass
+
# class Text2MusicwithChord(object):
# template_model = True