From b558d2f3211fdc0c563fe5cb8ff38bcc5282789c Mon Sep 17 00:00:00 2001
From: Yixiao Zhang
Date: Mon, 26 Jun 2023 20:03:11 +0900
Subject: [PATCH] new module: remix

---
 melodytalk/main.py    |  2 +-
 melodytalk/modules.py | 37 +++++++++++++++++++++++++++++++++++--
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/melodytalk/main.py b/melodytalk/main.py
index e5c9eec..e110d21 100644
--- a/melodytalk/main.py
+++ b/melodytalk/main.py
@@ -226,7 +226,7 @@ def clear_input_audio(self):
     if not os.path.exists("checkpoints"):
         os.mkdir("checkpoints")
     parser = argparse.ArgumentParser()
-    parser.add_argument('--load', type=str, default="Text2Music_cuda:0, ExtractTrack_cuda:0")
+    parser.add_argument('--load', type=str, default="Text2Music_cuda:0, ExtractTrack_cuda:0, Text2MusicWithMelody_cuda:0")
     args = parser.parse_args()
     load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
     bot = ConversationBot(load_dict=load_dict)
diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index ac90f98..a64e40f 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -2,21 +2,25 @@
 import uuid
 import torch
 from shutil import copyfile
+import torchaudio
 
 # text2music
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
-
 # source separation
 import demucs.separate
 
 from utils import prompts, get_new_audio_name
 
+
+# Initialize common models (one shared MusicGen instance for all tools)
+musicgen_model = MusicGen.get_pretrained('melody')
+
 class Text2Music(object):
     def __init__(self, device):
         print("Initializing Text2Music")
         self.device = device
-        self.model = MusicGen.get_pretrained('melody')
+        self.model = musicgen_model
 
         # Set generation params
         self.model.set_generation_params(duration=8)
@@ -38,6 +42,35 @@ def inference(self, text):
         print(f"\nProcessed Text2Music, Input Text: {text}, Output Music: {music_filename}.")
         return music_filename
 
 
+class Text2MusicWithMelody(object):
+    def __init__(self, device):
+        print("Initializing Text2MusicWithMelody")
+        self.device = device
+        self.model = musicgen_model
+
+        # Set generation params
+        self.model.set_generation_params(duration=8)
+
+    @prompts(
+        name="Generate music from user input text with melody condition",
+        description="useful if you want to generate, style transfer or remix music from a user input text with a given melody condition."
+                    "like: remix the given melody with text description, or doing style transfer as text described with the given melody."
+                    "The input to this tool should be a comma separated string of two, "
+                    "representing the music_filename and the text description."
+    )
+
+    def inference(self, inputs):
+        music_filename, text = (part.strip() for part in inputs.split(",", 1))
+        print(f"Generating music from text with melody condition, Input Text: {text}, Melody: {music_filename}.")
+        updated_music_filename = get_new_audio_name(music_filename, func_name="remix")
+        melody, sr = torchaudio.load(music_filename)
+        wav = self.model.generate_with_chroma([text], melody[None].expand(1, -1, -1), sr, progress=False)
+        wav = wav[0]  # batch size is 1
+        audio_write(updated_music_filename[:-4],
+                    wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
+        print(f"\nProcessed Text2MusicWithMelody, Output Music: {updated_music_filename}.")
+        return updated_music_filename
+
 class ExtractTrack(object):
     def __init__(self, device):