Commit: update

ldzhangyx committed Aug 25, 2023
1 parent e682cb6, commit 18329ba
Showing 10 changed files with 719 additions and 53 deletions.
Empty file.
Empty file.
7 changes: 5 additions & 2 deletions melodytalk/dependencies/lpmc/music_captioning/captioning.py
@@ -45,9 +45,11 @@ def get_audio(audio_path, duration=10, target_sr=16000):
     audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
     return audio
 
-def main():
+def main(audio_path=None):
     args = parser.parse_args()
-    captioning(args)
+    if audio_path is not None:
+        args.audio_path = audio_path
+    return captioning(args)
 
 def captioning(args):
     save_dir = f"exp/{args.framework}/{args.caption_type}/"
@@ -74,6 +76,7 @@ def captioning(args):
         item = {"text":text,"time":time}
         inference[chunk] = item
         print(item)
+    return inference
 
 if __name__ == '__main__':
     main()
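The change above makes the captioning script callable from other code: main now accepts an optional audio_path that overrides the parsed --audio_path argument and returns the per-chunk inference dict built by captioning. Below is a minimal caller sketch, assuming the module is imported under the path shown in the diff and the parser's remaining defaults are acceptable; the audio file name is a placeholder.

from melodytalk.dependencies.lpmc.music_captioning.captioning import main

# main() still calls parser.parse_args(), so this works best when the calling
# process has no conflicting command-line flags of its own.
captions = main(audio_path="outputs/generated_clip.wav")
for chunk, item in captions.items():
    # each value carries the caption text and the time span of its 10 s chunk
    print(chunk, item["time"], item["text"])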
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from lpmc.music_captioning.model.modules import AudioEncoder
+from melodytalk.dependencies.lpmc.music_captioning.model.modules import AudioEncoder
 from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
 
 class BartCaptionModel(nn.Module):
Empty file.
96 changes: 96 additions & 0 deletions melodytalk/dependencies/transplayer/inference.py
@@ -0,0 +1,96 @@
import librosa
import resampy

def transform(filepath):
    # Load audio, resample to 16 kHz if needed, and return the CQT magnitude.
    audio, sr = librosa.load(filepath)
    if sr != 16000:
        audio = resampy.resample(audio, sr, 16000)
        sr = 16000
    cqt_representation = librosa.cqt(audio, sr=sr, hop_length=256)

    cqt_magnitude = np.abs(cqt_representation)
    return cqt_magnitude


import os
import argparse
import torch
import numpy as np
from math import ceil
from model import Generator

device = 'cuda:0'


def pad_seq(x, base=32):
    len_out = int(base * ceil(float(x.shape[0]) / base))
    len_pad = len_out - x.shape[0]
    assert len_pad >= 0
    return np.pad(x, ((0, len_pad), (0, 0)), 'constant'), len_pad


def inference(input_file_path,
              output_file_path,
              org='piano', trg='piano',
              cp_path=None):
    # NOTE: input_file_path/output_file_path are currently unused; the CQT feature
    # path and length are read from the module-level `config` parsed in __main__.
    G = Generator(dim_neck=32,
                  dim_emb=4,
                  dim_pre=512,
                  freq=32).eval().to(device)
    if cp_path is not None and os.path.exists(cp_path):
        save_info = torch.load(cp_path)
        G.load_state_dict(save_info["model"])

    # one-hot instrument indices for the source and target timbres
    ins_list = ['harp', 'trumpet', 'epiano', 'viola', 'piano', 'guitar', 'organ', 'flute']
    ins_org = org
    ins_trg = trg
    emb_org = ins_list.index(ins_org)
    emb_trg = ins_list.index(ins_trg)
    # emb_org = [i == ins_org for i in ins_list]
    # emb_trg = [i == ins_trg for i in ins_list]
    emb_org = torch.unsqueeze(torch.tensor(emb_org), dim=0).to(device)
    emb_trg = torch.unsqueeze(torch.tensor(emb_trg), dim=0).to(device)

    # log-CQT features, padded to a multiple of the model's downsampling factor
    x_org = np.log(np.load(config.feature_path).T)[:config.feature_len]
    # x_org = np.load(config.spectrogram_path).T
    x_org, len_pad = pad_seq(x_org)
    x_org = torch.from_numpy(x_org[np.newaxis, :, :]).to(device)

    # first pass: reconstruct with the source instrument embedding
    with torch.no_grad():
        _, x_identic_psnt, _ = G(x_org, emb_org, emb_org)
        if len_pad == 0:
            x_trg = x_identic_psnt[0, 0, :, :].cpu().numpy()
        else:
            x_trg = x_identic_psnt[0, 0, :-len_pad, :].cpu().numpy()

    np.save(os.path.basename(config.feature_path)[:-4] + "_" + ins_org + "_" + ins_org + ".npy", x_trg.T)
    print("result saved.")

    # second pass: convert to the target instrument embedding
    with torch.no_grad():
        _, x_identic_psnt, _ = G(x_org, emb_org, emb_trg)
        if len_pad == 0:
            x_trg = x_identic_psnt[0, 0, :, :].cpu().numpy()
        else:
            x_trg = x_identic_psnt[0, 0, :-len_pad, :].cpu().numpy()

    np.save(os.path.basename(config.feature_path)[:-4] + "_" + ins_org + "_" + ins_trg + ".npy", x_trg.T)
    print("result saved.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model configuration.
    parser.add_argument('--lambda_cd', type=float, default=0, help='weight for hidden code loss')
    # Training configuration.
    parser.add_argument('--feature_path', type=str, default='../../data_syn/cropped/piano_all_00.wav_cqt.npy')
    parser.add_argument('--feature_len', type=int, default=2400)
    # parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations')
    # parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')

    # Miscellaneous.
    parser.add_argument('--cp_path', type=str,
                        default="../../autovc_cp/weights_log_cqt_down32_neck32_onehot4_withcross")

    config = parser.parse_args()
    print(config)
    # Pass the checkpoint path explicitly; feature path/length are read from `config` inside inference().
    inference(config.feature_path, None, cp_path=config.cp_path)
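For orientation, a minimal usage sketch of this new timbre-transfer script. It assumes the snippet runs from melodytalk/dependencies/transplayer/ (the script does `from model import Generator` and uses relative default paths), that a trained checkpoint exists at the default --cp_path location, and that a CUDA device is available (the module hard-codes device = 'cuda:0'); the audio and feature file names are placeholders.

import argparse
import numpy as np
import inference as transplayer

# 1. Compute and save a CQT magnitude feature for the source recording.
cqt = transplayer.transform("piano_take_01.wav")          # (freq_bins, frames)
np.save("piano_take_01.wav_cqt.npy", cqt)

# 2. inference() reads the feature path/length from the module-level `config`
#    (normally set by argparse in __main__), so populate it before calling.
transplayer.config = argparse.Namespace(
    feature_path="piano_take_01.wav_cqt.npy",
    feature_len=2400,
)

# 3. Convert piano -> guitar; this writes <feature stem>_piano_piano.npy
#    (identity reconstruction) and <feature stem>_piano_guitar.npy.
transplayer.inference("piano_take_01.wav", None,
                      org="piano", trg="guitar",
                      cp_path="../../autovc_cp/weights_log_cqt_down32_neck32_onehot4_withcross")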