Skip to content

Commit

Permalink
new module: remix
Browse files Browse the repository at this point in the history
  • Loading branch information
ldzhangyx committed Jun 26, 2023
1 parent 65bdced commit b558d2f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion melodytalk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def clear_input_audio(self):
if not os.path.exists("checkpoints"):
os.mkdir("checkpoints")
parser = argparse.ArgumentParser()
parser.add_argument('--load', type=str, default="Text2Music_cuda:0, ExtractTrack_cuda:0")
parser.add_argument('--load', type=str, default="Text2Music_cuda:0, ExtractTrack_cuda:0, Text2MusicWithMelody_cuda:0")
args = parser.parse_args()
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
bot = ConversationBot(load_dict=load_dict)
Expand Down
37 changes: 35 additions & 2 deletions melodytalk/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@
import uuid
import torch
from shutil import copyfile
import torchaudio

# text2music
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

# source separation
import demucs.separate

from utils import prompts, get_new_audio_name


# Initialize common models
# One MusicGen 'melody' checkpoint is loaded when this module is imported.
# Text2Music and Text2MusicWithMelody both assign this same object to
# `self.model` and both call `set_generation_params` on it.
musicgen_model = MusicGen.get_pretrained('melody')

class Text2Music(object):
def __init__(self, device):
print("Initializing Text2Music")
self.device = device
self.model = MusicGen.get_pretrained('melody')
self.model = musicgen_model

# Set generation params
self.model.set_generation_params(duration=8)
Expand All @@ -38,6 +42,35 @@ def inference(self, text):
print(f"\nProcessed Text2Music, Input Text: {text}, Output Music: {music_filename}.")
return music_filename

class Text2MusicWithMelody(object):
    """Tool that generates music from a text description conditioned on a melody.

    The instance does not load its own model; it reuses the module-level
    ``musicgen_model`` object.
    """

    def __init__(self, device):
        """Store the requested device and attach the shared MusicGen model.

        Args:
            device: Device string supplied by the loader. It is only stored on
                the instance here; this class does not move the model itself.
        """
        print("Initializing Text2MusicWithMelody")
        self.device = device
        self.model = musicgen_model

        # Set generation params
        self.model.set_generation_params(duration=8)

    @prompts(
        name="Generate music from user input text with melody condition",
        description="useful if you want to generate, style transfer or remix music from a user input text with a given melody condition."
                    "like: remix the given melody with text description, or doing style transfer as text described with the given melody."
                    "The input to this tool should be a comma separated string of two, "
                    "representing the music_filename and the text description."
    )
    def inference(self, inputs):
        """Generate a new audio file from a melody file and a text description.

        Args:
            inputs: A string of the form ``"<music_filename>,<text description>"``.
                Only the FIRST comma separates the two fields, so the text
                description may itself contain commas.

        Returns:
            The filename produced by ``get_new_audio_name`` for the generated audio.

        Raises:
            IndexError: If ``inputs`` contains no comma at all.
        """
        # Split once: everything after the first comma is the description.
        # Splitting on every comma would silently drop the rest of a
        # description such as "jazz, slow tempo, with piano".
        fields = inputs.split(",", 1)
        music_filename, text = fields[0].strip(), fields[1].strip()
        print(f"Generating music from text with melody condition, Input Text: {text}, Melody: {music_filename}.")
        updated_music_filename = get_new_audio_name(music_filename, func_name="remix")
        melody, sr = torchaudio.load(music_filename)
        # `melody[None]` adds a leading dimension of size 1 (a batch of one melody)
        # to match the single-element description list.
        wav = self.model.generate_with_chroma([text], melody[None].expand(1, -1, -1), sr, progress=False)
        wav = wav[0]  # batch size is 1
        # The last four characters are dropped from the target name before writing
        # (presumably the ".wav" extension, which audio_write would add back -- TODO confirm).
        audio_write(updated_music_filename[:-4],
                    wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
        print(f"\nProcessed Text2MusicWithMelody, Output Music: {updated_music_filename}.")
        return updated_music_filename


class ExtractTrack(object):
def __init__(self, device):
Expand Down

0 comments on commit b558d2f

Please sign in to comment.