
Commit

update
ldzhangyx committed Aug 22, 2023
1 parent 05624e5 commit e682cb6
Showing 24 changed files with 1,717 additions and 3 deletions.
Binary file modified .DS_Store
7 changes: 5 additions & 2 deletions main.py
@@ -6,14 +6,15 @@
from melodytalk.dependencies.laion_clap.hook import CLAP_Module
from datetime import datetime
import torch
import demucs.separate
from melodytalk.utils import CLAP_post_filter

MODEL_NAME = 'melody'
DURATION = 35
CFG_COEF = 3
SAMPLES = 5
# PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
PROMPT = "music loop with saxophone solo. instrument: saxophone."
PROMPT = "music loop with a very strong and rhythmic drum pattern."
melody_conditioned = True

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
@@ -38,7 +38,9 @@ def generate():
    if not melody_conditioned:
        wav = model.generate(descriptions, progress=True)  # generates SAMPLES samples.
    else:
        melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav')
        # demucs.separate.main(["-n", "htdemucs_6s", "--two-stems", 'drums', '/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav' ])
        melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/separated/htdemucs_6s/20230705-155518_3/no_drums.wav')

        wav = model.generate_continuation(melody[None].expand(SAMPLES, -1, -1), sr, descriptions, progress=True)
        # the generated wav contains the melody input, we need to cut it
        wav = wav[..., int(melody.shape[-1] / sr * model.sample_rate):]
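To make that cropping step concrete, here is an illustrative sketch of the arithmetic it performs; the clip length and 44.1 kHz input rate are made-up numbers, and 32 kHz is assumed as the MusicGen output rate for the public checkpoints:

# Illustrative only: how many leading samples the slice above drops.
melody_len_samples = 10 * 44100   # assume a 10 s conditioning clip loaded at 44.1 kHz
sr = 44100                        # sample rate reported by torchaudio.load
model_sample_rate = 32000         # assumed MusicGen output rate
n_drop = int(melody_len_samples / sr * model_sample_rate)
print(n_drop)  # 320000 samples, i.e. the first 10 s of the generated waveform
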
81 changes: 81 additions & 0 deletions melodytalk/dependencies/lpmc/music_captioning/captioning.py
@@ -0,0 +1,81 @@
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed

from melodytalk.dependencies.lpmc.music_captioning.model.bart import BartCaptionModel
from melodytalk.dependencies.lpmc.utils.eval_utils import load_pretrained
from melodytalk.dependencies.lpmc.utils.audio_utils import load_audio, STR_CH_FIRST
from omegaconf import OmegaConf

parser = argparse.ArgumentParser(description='PyTorch MSD Training')
parser.add_argument('--gpu', default=1, type=int,
                    help='GPU id to use.')
parser.add_argument("--framework", default="transfer", type=str)
parser.add_argument("--caption_type", default="lp_music_caps", type=str)
parser.add_argument("--max_length", default=128, type=int)
parser.add_argument("--num_beams", default=5, type=int)
parser.add_argument("--model_type", default="last", type=str)
parser.add_argument("--audio_path", default="../../dataset/samples/orchestra.wav", type=str)

def get_audio(audio_path, duration=10, target_sr=16000):
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(0, False)  # to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad sequence
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    ceil = int(audio.shape[-1] // n_samples)
    # split into non-overlapping 10-second chunks -> tensor of shape (n_chunks, n_samples)
    audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
    return audio

def main():
    args = parser.parse_args()
    captioning(args)

def captioning(args):
    save_dir = f"exp/{args.framework}/{args.caption_type}/"
    config = OmegaConf.load(os.path.join(save_dir, "hparams.yaml"))
    model = BartCaptionModel(max_length=config.max_length)
    model, save_epoch = load_pretrained(args, save_dir, model, mdp=config.multiprocessing_distributed)
    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
    model.eval()

    audio_tensor = get_audio(audio_path=args.audio_path)
    if args.gpu is not None:
        audio_tensor = audio_tensor.cuda(args.gpu, non_blocking=True)

    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=args.num_beams,
        )
    inference = {}
    number_of_chunks = range(audio_tensor.shape[0])
    # one caption per 10-second chunk of the input audio
    for chunk, text in zip(number_of_chunks, output):
        time = f"{chunk * 10}:00-{(chunk + 1) * 10}:00"
        item = {"text": text, "time": time}
        inference[chunk] = item
        print(item)

if __name__ == '__main__':
    main()
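A minimal usage sketch, assuming a pretrained checkpoint and hparams.yaml already sit under exp/transfer/lp_music_caps/ as captioning() expects; the audio path is the script's own default and everything else is hypothetical:

# Hypothetical usage sketch: caption one audio file on GPU 0 with the defaults above.
from melodytalk.dependencies.lpmc.music_captioning import captioning as cap

args = cap.parser.parse_args([
    "--framework", "transfer",
    "--caption_type", "lp_music_caps",
    "--audio_path", "../../dataset/samples/orchestra.wav",
    "--gpu", "0",
])
cap.captioning(args)  # prints one {"text": ..., "time": ...} item per 10-second chunk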


58 changes: 58 additions & 0 deletions melodytalk/dependencies/lpmc/music_captioning/datasets/mc.py
@@ -0,0 +1,58 @@
import os
import random
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from lpmc.utils.audio_utils import load_audio, STR_CH_FIRST

class MC_Dataset(Dataset):
    def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
        self.data_path = data_path
        self.split = split
        self.caption_type = caption_type
        self.audio_enc = audio_enc
        self.sr = sr  # needed by load_audio() below
        self.n_samples = int(sr * duration)
        self.annotation = load_dataset("seungheondoh/LP-MusicCaps-MC")
        self.get_split()

    def get_split(self):
        if self.split == "train":
            self.fl = [i for i in self.annotation[self.split] if i['is_crawled']]
        elif self.split == "test":
            self.fl = [i for i in self.annotation[self.split] if i['is_crawled']]
        else:
            raise ValueError(f"Unexpected split name: {self.split}")

    def load_audio(self, audio_path, file_type):
        if file_type == ".npy":
            audio = np.load(audio_path, mmap_mode='r')
        else:
            audio, _ = load_audio(
                path=audio_path,
                ch_format=STR_CH_FIRST,
                sample_rate=self.sr,
                downmix_to_mono=True
            )
        if len(audio.shape) == 2:
            audio = audio.squeeze(0)
        input_size = int(self.n_samples)
        if audio.shape[-1] < input_size:  # pad short clips up to one window
            pad = np.zeros(input_size)
            pad[:audio.shape[-1]] = audio
            audio = pad
        # random 10-second crop
        random_idx = random.randint(0, audio.shape[-1] - self.n_samples)
        audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx + self.n_samples]).astype('float32'))
        return audio_tensor

    def __getitem__(self, index):
        item = self.fl[index]
        fname = item['fname']
        text = item['caption_ground_truth']
        audio_path = os.path.join(self.data_path, "music_caps", 'npy', fname + ".npy")
        audio_tensor = self.load_audio(audio_path, file_type=audio_path[-4:])
        return fname, text, audio_tensor

    def __len__(self):
        return len(self.fl)
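For context, a minimal sketch of how this dataset might be consumed; the batch size and worker count are arbitrary, and it assumes the pre-extracted .npy files exist under <data_path>/music_caps/npy/ as __getitem__ expects:

# Hypothetical usage sketch for MC_Dataset with a standard DataLoader.
from torch.utils.data import DataLoader

dataset = MC_Dataset(data_path="../../dataset", split="train", caption_type="gt")
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
fnames, texts, audio = next(iter(loader))
print(audio.shape)  # (64, 160000): 10 s of 16 kHz audio per item
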
64 changes: 64 additions & 0 deletions melodytalk/dependencies/lpmc/music_captioning/datasets/msd.py
@@ -0,0 +1,64 @@
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset

class MSD_Balanced_Dataset(Dataset):
    def __init__(self, data_path, split, caption_type, sr=16000, duration=10, audio_enc="wav"):
        self.data_path = data_path
        self.split = split
        self.audio_enc = audio_enc
        self.n_samples = int(sr * duration)
        self.caption_type = caption_type
        self.dataset = load_dataset("seungheondoh/LP-MusicCaps-MSD")
        self.get_split()

    def get_split(self):
        self.tags = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tags.json"), 'r'))
        self.tag_to_track = json.load(open(os.path.join(self.data_path, "msd", f"{self.split}_tag_to_track.json"), 'r'))
        self.annotation = {instance['track_id']: instance for instance in self.dataset[self.split]}

    def load_caption(self, item):
        caption_pool = []
        if (self.caption_type in "write") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_writing'])
        if (self.caption_type in "summary") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_summary'])
        if (self.caption_type in "creative") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_paraphrase'])
        if (self.caption_type in "predict") or (self.caption_type == "lp_music_caps"):
            caption_pool.append(item['caption_attribute_prediction'])
        # randomly select one caption from multiple captions
        sampled_caption = random.choice(caption_pool)
        return sampled_caption

    def load_audio(self, audio_path, file_type):
        audio = np.load(audio_path, mmap_mode='r')
        if len(audio.shape) == 2:
            audio = audio.squeeze(0)
        input_size = int(self.n_samples)
        if audio.shape[-1] < input_size:  # pad short clips up to one window
            pad = np.zeros(input_size)
            pad[:audio.shape[-1]] = audio
            audio = pad
        # random 10-second crop
        random_idx = random.randint(0, audio.shape[-1] - self.n_samples)
        audio_tensor = torch.from_numpy(np.array(audio[random_idx:random_idx + self.n_samples]).astype('float32'))
        # audio_tensor = torch.randn(16000 * 10)  # debug stub that would overwrite the real audio
        return audio_tensor

    def __getitem__(self, index):
        # index is ignored: sample a tag uniformly, then a track carrying that tag (tag-balanced sampling)
        tag = random.choice(self.tags)  # uniform random sample tag
        fname = random.choice(self.tag_to_track[tag])  # uniform random sample track
        item = self.annotation[fname]
        track_id = item['track_id']
        gt_caption = ""  # no ground truth
        text = self.load_caption(item)
        audio_path = os.path.join(self.data_path, 'msd', 'npy', item['path'].replace(".mp3", ".npy"))
        audio_tensor = self.load_audio(audio_path, file_type=".npy")
        return fname, gt_caption, text, audio_tensor

    def __len__(self):
        return 2**13  # virtual epoch length; samples are drawn at random in __getitem__
51 changes: 51 additions & 0 deletions melodytalk/dependencies/lpmc/music_captioning/eval.py
@@ -0,0 +1,51 @@
import os
import argparse
import json
import numpy as np
from datasets import load_dataset
from lpmc.utils.metrics import bleu, meteor, rouge, bertscore, vocab_novelty, caption_novelty

def inference_parsing(dataset, args):
    ground_truths = [i['caption_ground_truth'] for i in dataset]
    inference = json.load(open(os.path.join(args.save_dir, args.framework, args.caption_type, 'inference_temp.json'), 'r'))
    id2pred = {item['audio_id']: item['predictions'] for item in inference.values()}
    predictions = [id2pred[i['fname']] for i in dataset]
    return predictions, ground_truths

def main(args):
    dataset = load_dataset("seungheondoh/LP-MusicCaps-MC")
    train_data = [i for i in dataset['train'] if i['is_crawled']]
    test_data = [i for i in dataset['test'] if i['is_crawled']]
    tr_ground_truths = [i['caption_ground_truth'] for i in train_data]
    predictions, ground_truths = inference_parsing(test_data, args)
    length_avg = np.mean([len(cap.split()) for cap in predictions])
    length_std = np.std([len(cap.split()) for cap in predictions])

    vocab_size, vocab_novel_score = vocab_novelty(predictions, tr_ground_truths)
    cap_novel_score = caption_novelty(predictions, tr_ground_truths)
    results = {
        "bleu1": bleu(predictions, ground_truths, order=1),
        "bleu2": bleu(predictions, ground_truths, order=2),
        "bleu3": bleu(predictions, ground_truths, order=3),
        "bleu4": bleu(predictions, ground_truths, order=4),
        "meteor_1.0": meteor(predictions, ground_truths),  # https://github.com/huggingface/evaluate/issues/115
        "rougeL": rouge(predictions, ground_truths),
        "bertscore": bertscore(predictions, ground_truths),
        "vocab_size": vocab_size,
        "vocab_novelty": vocab_novel_score,
        "caption_novelty": cap_novel_score,
        "length_avg": length_avg,
        "length_std": length_std
    }
    os.makedirs(os.path.join(args.save_dir, args.framework, args.caption_type), exist_ok=True)
    with open(os.path.join(args.save_dir, args.framework, args.caption_type, "results.json"), mode="w") as io:
        json.dump(results, io, indent=4)
    print(results)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", default="./exp", type=str)
    parser.add_argument("--framework", default="supervised", type=str)
    parser.add_argument("--caption_type", default="gt", type=str)
    args = parser.parse_args()
    main(args=args)
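eval.py reads a predictions file produced by a prior inference run. Based on inference_parsing() above, a hypothetical sketch of the layout it expects at <save_dir>/<framework>/<caption_type>/inference_temp.json; the keys match what the code reads, and the values are placeholders:

# Hypothetical example of the inference_temp.json layout consumed by inference_parsing().
import json

inference_temp = {
    "0": {"audio_id": "example_fname", "predictions": "a placeholder generated caption"},
    "1": {"audio_id": "another_fname", "predictions": "another placeholder caption"},
}
with open("inference_temp.json", "w") as io:
    json.dump(inference_temp, io, indent=4)
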
@@ -0,0 +1,27 @@
framework: bart
data_dir: ../../dataset
train_data: msd_balence
text_type: all
arch: transformer
workers: 12
epochs: 4096
warmup_epochs: 125
start_epoch: 0
batch_size: 256
world_size: 1
lr: 0.0001
min_lr: 1.0e-09
rank: 0
dist_url: tcp://localhost:12312
dist_backend: nccl
seed: null
gpu: 0
print_freq: 100
multiprocessing_distributed: false
cos: true
bart_pretrain: false
label_smoothing: 0.1
use_early_stopping: false
eval_sample: 0
max_length: 110
distributed: false
@@ -0,0 +1,14 @@
{
"bleu1": 0.19774511999248645,
"bleu2": 0.06701659207795346,
"bleu3": 0.02165946217889739,
"bleu4": 0.007868711032317937,
"meteor_1.0": 0.12883275733962948,
"rougeL": 0.1302757267270023,
"bertscore": 0.8451152142608364,
"vocab_size": 1686,
"vocab_diversity": 0.4721233689205219,
"caption_novelty": 1.0,
"length_avg": 45.26699542092286,
"length_std": 27.99358111821529
}
@@ -0,0 +1,26 @@
framework: bart
data_dir: ../../dataset
train_data: music_caps
text_type: gt
arch: transformer
workers: 8
epochs: 100
warmup_epochs: 1
start_epoch: 0
batch_size: 64
world_size: 1
lr: 0.0001
min_lr: 1.0e-09
rank: 0
dist_url: tcp://localhost:12312
dist_backend: nccl
seed: null
gpu: 0
print_freq: 100
multiprocessing_distributed: false
cos: true
bart_pretrain: false
label_smoothing: 0.1
use_early_stopping: false
eval_sample: 64
max_length: 128
@@ -0,0 +1,14 @@
{
"bleu1": 0.2850628531445219,
"bleu2": 0.13762330787082025,
"bleu3": 0.07588914964864207,
"bleu4": 0.047923589711449235,
"meteor_1.0": 0.20624392533703412,
"rougeL": 0.19215453358420784,
"bertscore": 0.870539891400695,
"vocab_size": 2240,
"vocab_diversity": 0.005357142857142857,
"caption_novelty": 0.6899661781285231,
"length_avg": 46.68193025713279,
"length_std": 16.516237946145356
}