Commit
new feature: addnewtrack (test pass)
ldzhangyx committed Jul 25, 2023
1 parent fd8950f commit df07290
Showing 5 changed files with 134 additions and 59 deletions.
18 changes: 13 additions & 5 deletions main.py
@@ -1,29 +1,34 @@
import os

import torchaudio
from melodytalk.dependencies.audiocraft import MusicGen
from melodytalk.dependencies.audiocraft.models import MusicGen
from melodytalk.dependencies.audiocraft.data.audio import audio_write
from melodytalk.dependencies.laion_clap.hook import CLAP_Module
from datetime import datetime
import torch
from melodytalk.utils import CLAP_post_filter

MODEL_NAME = 'melody'
DURATION = 5
DURATION = 35
CFG_COEF = 3
SAMPLES = 5
# PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
PROMPT = "Pop dance music loop with catchy melodies, tropical percussion, and upbeat rhythms, perfect for the beach. "
PROMPT = "music loop with saxophone solo. instrument: saxophone."
melody_conditioned = True

os.environ["CUDA_VISIBLE_DEVICES"] = "3"

model = MusicGen.get_pretrained(MODEL_NAME, device='cuda')

DURATION_1 = min(DURATION, 30)
DURATION_2 = max(DURATION - 30, 0)
DURATION_1 = min(DURATION, 40)
DURATION_2 = max(DURATION - 40, 0)
OVERLAP = 8

model.set_generation_params(duration=DURATION_1,
cfg_coef=CFG_COEF,)  # generate the first DURATION_1 seconds.

# CLAP_model = CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device="cuda")
# CLAP_model.load_ckpt("/home/intern-2023-02/melodytalk/melodytalk/pretrained/music_audioset_epoch_15_esc_90.14.pt")
# wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
descriptions = [PROMPT] * SAMPLES
# 'A slow and heartbreaking love song at tempo of 60',
@@ -35,6 +40,9 @@ def generate():
else:
melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav')
wav = model.generate_continuation(melody[None].expand(SAMPLES, -1, -1), sr, descriptions, progress=True)
# the generated wav contains the melody input, we need to cut it
wav = wav[..., int(melody.shape[-1] / sr * model.sample_rate):]
# best_wav, _ = CLAP_post_filter(CLAP_model, PROMPT, wav, model.sample_rate)
if DURATION_2 > 0:
wav_ = wav[:, :, -OVERLAP * model.sample_rate:]
model.set_generation_params(duration=(OVERLAP + DURATION_2))
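Note: the remainder of this generate() hunk is collapsed above. As a rough guide to the two-pass scheme that DURATION_1, DURATION_2, and OVERLAP set up, here is a minimal sketch (an assumption based on the visible lines, not the commit's actual code):

# Hypothetical completion of the collapsed body of generate(): the last OVERLAP
# seconds of pass 1 prompt pass 2, and the overlapping region is kept only once.
if DURATION_2 > 0:
    wav_ = wav[:, :, -OVERLAP * model.sample_rate:]
    model.set_generation_params(duration=OVERLAP + DURATION_2, cfg_coef=CFG_COEF)
    wav = torch.cat([wav[:, :, :-OVERLAP * model.sample_rate],
                     model.generate_continuation(wav_, model.sample_rate,
                                                 descriptions, progress=True)], dim=-1)

for idx, one_wav in enumerate(wav):
    # audio_write appends the .wav suffix and applies loudness normalization
    audio_write(f"{datetime.now().strftime('%Y%m%d-%H%M%S')}_{idx}", one_wav.cpu(),
                model.sample_rate, strategy="loudness", loudness_compressor=True)
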
2 changes: 1 addition & 1 deletion melodytalk/dependencies/audiocraft/models/musicgen.py
@@ -37,7 +37,7 @@ class MusicGen:
lm (LMModel): Language model over discrete representations.
"""
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
max_duration: float = 30):
max_duration: float = 40):
self.name = name
self.compression_model = compression_model
self.lm = lm
20 changes: 12 additions & 8 deletions melodytalk/main.py
@@ -13,7 +13,6 @@

from modules import *


MELODYTALK_PREFIX = """MelodyTalk is designed to be able to assist with a wide range of text and music related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. MelodyTalk is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
MelodyTalk is able to process and understand large amounts of text and music. As a language model, MelodyTalk cannot directly read music, but it has a list of tools to finish different music tasks. Each music will have a file name formed as "music/xxx.wav", and MelodyTalk can invoke different tools to indirectly understand music. When talking about music, MelodyTalk is very strict about the file name and will never fabricate nonexistent files.
@@ -114,8 +113,12 @@

class ConversationBot(object):
def __init__(self):
load_dict = {"Text2Music":"cuda:0", "ExtractTrack":"cuda:0", "Text2MusicWithMelody":"cuda:0", "Text2MusicWithDrum":"cuda:0"}
template_dict = None #{ "Text2MusicwithChord": "cuda:0"} # "Accompaniment": "cuda:0",
load_dict = {"Text2Music": "cuda:0",
"ExtractTrack": "cuda:0",
"Text2MusicWithMelody": "cuda:0",
"Text2MusicWithDrum": "cuda:0",
"AddNewTrack": "cuda:0"}
template_dict = None # { "Text2MusicwithChord": "cuda:0"} # "Accompaniment": "cuda:0",

print(f"Initializing MelodyTalk, load_dict={load_dict}, template_dict={template_dict}")

@@ -178,7 +181,7 @@ def run_text(self, text, state):
state = state + [(text, res['output'])]
if len(res['intermediate_steps']) > 0:
audio_filename = res['intermediate_steps'][-1][1]
state = state + [(None,(audio_filename,))]
state = state + [(None, (audio_filename,))]
# print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
# f"Current Memory: {self.agent.memory.buffer}")
return state, state
@@ -213,12 +216,13 @@ def run_recording(self, file_path, state, txt, lang):
def clear_input_audio(self):
return gr.Audio.update(value=None)


if __name__ == '__main__':
if not os.path.exists("checkpoints"):
os.mkdir("checkpoints")
bot = ConversationBot()
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
lang = gr.Radio(choices = ['Chinese','English'], value=None, label='Language')
lang = gr.Radio(choices=['Chinese', 'English'], value=None, label='Language')
chatbot = gr.Chatbot(elem_id="chatbot", label="MelodyTalk")
state = gr.State([])
with gr.Row(visible=False) as input_raws:
@@ -228,11 +232,11 @@ def clear_input_audio(self):
with gr.Column(scale=0.15, min_width=0):
clear = gr.Button("Clear")
with gr.Column(scale=0.15, min_width=0):
btn = gr.UploadButton("Upload",file_types=["audio"])
btn = gr.UploadButton("Upload", file_types=["audio"])

with gr.Row(visible=False) as record_raws:
with gr.Column(scale=0.7):
rec_audio = gr.Audio(source='microphone', type='filepath', interactive=True)
rec_audio = gr.Audio(source='microphone', type='filepath', interactive=True, show_label=False)
with gr.Column(scale=0.15, min_width=0):
rec_clear = gr.Button("Re-recording")
with gr.Column(scale=0.15, min_width=0):
@@ -251,4 +255,4 @@ def clear_input_audio(self):
clear.click(lambda: [], None, state)
clear.click(bot.clear_input_audio, None, rec_audio)
demo.launch(server_name="0.0.0.0", server_port=7860,
ssl_certfile="cert.pem", ssl_keyfile="key.pem", ssl_verify=False)
ssl_certfile="cert.pem", ssl_keyfile="key.pem", ssl_verify=False)
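
Note: this diff does not show how load_dict is consumed. In the TaskMatrix-style layout this file appears to follow, each key names a tool class in modules.py and each value the CUDA device it loads on; a hypothetical sketch (an illustration, not code from this commit) is:

from langchain.agents.tools import Tool  # import assumed, not shown in this diff

def build_tools(load_dict, namespace):
    # Instantiate each named tool class on its device, e.g. namespace=vars(modules).
    models = {name: namespace[name](device=device) for name, device in load_dict.items()}
    tools = []
    for instance in models.values():
        for attr in dir(instance):
            if attr.startswith("inference"):
                func = getattr(instance, attr)
                # the @prompts decorator attaches .name and .description to each inference method
                tools.append(Tool(name=func.name, description=func.description, func=func))
    return models, tools
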
50 changes: 26 additions & 24 deletions melodytalk/modules.py
@@ -1,6 +1,8 @@
from shutil import copyfile
from dataclasses import dataclass

import torch

# text2music
from melodytalk.dependencies.audiocraft.models import MusicGen
from melodytalk.dependencies.audiocraft.data.audio import audio_write
@@ -18,8 +20,10 @@
# musicgen_model = MusicGen.get_pretrained('large')
# musicgen_model.set_generation_params(duration=DURATION)

musicgen_model_melody = MusicGen.get_pretrained('melody')
musicgen_model_melody.set_generation_params(duration=DURATION)
musicgen_model = MusicGen.get_pretrained('melody')
musicgen_model.set_generation_params(duration=DURATION)

# musicgen_model = torch.compile(musicgen_model)

# Initialize CLAP post filter
CLAP_model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device="cuda")
@@ -94,7 +98,7 @@ class Text2MusicWithMelody(object):
def __init__(self, device):
print("Initializing Text2MusicWithMelody")
self.device = device
self.model = musicgen_model_melody
self.model = musicgen_model

@prompts(
name="Generate music from user input text with given melody condition",
@@ -123,11 +127,11 @@ class Text2MusicWithDrum(object):
def __init__(self, device):
print("Initializing Text2MusicWithDrum")
self.device = device
self.model = musicgen_model_melody
self.model = musicgen_model

@prompts(
name="Generate music from user input text based on the drum track provided.",
description="useful if you want to generate music from a user input text and a previous given drum track."
name="Generate music from user input text based on the drum audio file provided.",
description="useful if you want to generate music from a user input text and a previous given drum audio file."
"Do not use it when no previous music file (generated of uploaded) in the history."
"like: generate a pop song based on the provided drum pattern above."
"The input to this tool should be a comma separated string of two, "
@@ -140,18 +144,16 @@ def inference(self, inputs):
print(f"Generating music from text with drum condition, Input text: {text}, Drum: {music_filename}.")
updated_music_filename = get_new_audio_name(music_filename, func_name="with_drum")
drum, sr = torchaudio.load(music_filename)
self.model.set_generation_params(duration=30)
self.model.set_generation_params(duration=35)
wav = self.model.generate_continuation(prompt=drum[None].expand(GENERATION_CANDIDATE, -1, -1), prompt_sr=sr,
descriptions=[text] * GENERATION_CANDIDATE, progress=False)
self.model.set_generation_params(duration=DURATION)
# cut drum prompt
wav = wav[:, drum.shape[1]:drum.shape[1] + DURATION * sr]
# TODO: split tracks by beats

wav = wav[..., int(drum.shape[-1] / sr * self.model.sample_rate):]
splitted_audios = split_audio_tensor_by_downbeats(wav.cpu(), self.model.sample_rate, True)
# select the best one by CLAP scores
best_wav, _ = CLAP_post_filter(CLAP_model, text, wav, self.model.sample_rate)
audio_write(updated_music_filename[:-4],
best_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
print(f"CLAP post filter for {len(splitted_audios)} candidates.")
best_wav, _ = CLAP_post_filter(CLAP_model, text, splitted_audios.cuda(), self.model.sample_rate)
print(f"\nProcessed Text2MusicWithDrum, Output Music: {updated_music_filename}.")
return updated_music_filename

@@ -160,7 +162,7 @@ class AddNewTrack(object):
def __init__(self, device):
print("Initializing AddNewTrack")
self.device = device
self.model = musicgen_model_melody
self.model = musicgen_model

@prompts(
name="Add a new track to the given music loop",
@@ -172,23 +174,23 @@ def __init__(self, device):

def inference(self, inputs):
music_filename, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
text = description_to_attributes(text)
text = addtrack_demand_to_description(text)
print(f"Adding a new track, Input text: {text}, Previous track: {music_filename}.")
updated_music_filename = get_new_audio_name(music_filename, func_name="add_track")
updated_music_filename = get_new_audio_name(music_filename, func_name="addtrack")
p_track, sr = torchaudio.load(music_filename)
self.model.set_generation_params(duration=30)
wav = self.model.generate_continuation(prompt=p_track[None].expand(GENERATION_CANDIDATE, -1, -1), prompt_sr=sr,
self.model.set_generation_params(duration=35)
wav = self.model.generate_continuation(prompt=p_track[None].expand(GENERATION_CANDIDATE, -1, -1), prompt_sample_rate=sr,
descriptions=[text] * GENERATION_CANDIDATE, progress=False)
self.model.set_generation_params(duration=DURATION)
# cut drum prompt
wav = wav[:, p_track.shape[1]:p_track.shape[1] + DURATION * sr]
# TODO: split tracks by beats

wav = wav[..., int(p_track.shape[-1] / sr * self.model.sample_rate):]
splitted_audios = split_audio_tensor_by_downbeats(wav.cpu(), self.model.sample_rate, True)
# select the best one by CLAP scores
best_wav, _ = CLAP_post_filter(CLAP_model, text, wav, self.model.sample_rate)
print(f"CLAP post filter for {len(splitted_audios)} candidates.")
best_wav, _ = CLAP_post_filter(CLAP_model, text, splitted_audios.cuda(), self.model.sample_rate)
audio_write(updated_music_filename[:-4],
best_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
print(f"\nProcessed Text2MusicWithDrum, Output Music: {updated_music_filename}.")
print(f"\nProcessed AddNewTrack, Output Music: {updated_music_filename}.")
return updated_music_filename


@@ -217,7 +219,7 @@ def inference(self, inputs):
if mode == "extract":
instrument_mode = instrument
elif mode == "remove":
instrument_mode = f"no_{instrument}"
instrument_mode = f"no{instrument}"
else:
raise ValueError("mode must be `extract` or `remove`.")

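Note: CLAP_post_filter (imported in main.py above from melodytalk.utils) and split_audio_tensor_by_downbeats are defined in a utils module whose diff is not shown on this page. For orientation only, a CLAP-based post filter can be sketched as below (an assumed illustration, not the project's implementation): embed the text prompt and each candidate clip, then keep the candidate with the highest text-audio similarity.

import torch
import torchaudio

def clap_post_filter_sketch(clap_model, text, audio_candidates, audio_sr):
    # audio_candidates: (N, T) or (N, 1, T) tensor of generated clips at audio_sr.
    # laion_clap resampling/quantization details are simplified here.
    if audio_candidates.dim() == 3:
        audio_candidates = audio_candidates.squeeze(1)
    audio_48k = torchaudio.functional.resample(audio_candidates, audio_sr, 48000)  # CLAP expects 48 kHz
    text_emb = clap_model.get_text_embedding([text], use_tensor=True)                # (1, D)
    audio_emb = clap_model.get_audio_embedding_from_data(x=audio_48k, use_tensor=True)  # (N, D)
    scores = torch.nn.functional.cosine_similarity(audio_emb, text_emb)              # (N,)
    best = int(scores.argmax())
    return audio_candidates[best], float(scores[best])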