update system args
ldzhangyx committed Jul 14, 2023
1 parent c4090f5 commit b6bac89
Showing 2 changed files with 47 additions and 13 deletions.
18 changes: 9 additions & 9 deletions melodytalk/main.py
@@ -127,8 +127,8 @@

 class ConversationBot(object):
     def __init__(self):
-        load_dict = {"Text2Music":"cuda:0", "ExtractTrack":"cuda:0", "Text2MusicWithMelody":"cuda:0", "SimpleTracksMixing":"cuda:0"}
-        template_dict = { "Text2MusicwithChord": "cuda:0"} # "Accompaniment": "cuda:0",
+        load_dict = {"Text2Music":"cuda:0", "ExtractTrack":"cuda:0", "Text2MusicWithMelody":"cuda:0", "Text2MusicWithDrum":"cuda:0"}
+        template_dict = None #{ "Text2MusicwithChord": "cuda:0"} # "Accompaniment": "cuda:0",

         print(f"Initializing MelodyTalk, load_dict={load_dict}, template_dict={template_dict}")

@@ -138,13 +138,13 @@ def __init__(self):
             self.models[class_name] = globals()[class_name](device=device)

         # Load Template Foundation Models
-        for class_name, device in template_dict.items():
-            template_required_names = {k for k in inspect.signature(globals()[class_name].__init__).parameters.keys() if
-                                       k != 'self'}
-            loaded_names = set([type(e).__name__ for e in self.models.values()])
-            if template_required_names.issubset(loaded_names):
-                self.models[class_name] = globals()[class_name](
-                    **{name: self.models[name] for name in template_required_names})
+        # for class_name, device in template_dict.items():
+        #     template_required_names = {k for k in inspect.signature(globals()[class_name].__init__).parameters.keys() if
+        #                                k != 'self'}
+        #     loaded_names = set([type(e).__name__ for e in self.models.values()])
+        #     if template_required_names.issubset(loaded_names):
+        #         self.models[class_name] = globals()[class_name](
+        #             **{name: self.models[name] for name in template_required_names})

         print(f"All the Available Functions: {self.models}")

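For readers skimming the diff: the block commented out above wires up "template" tools by inspecting each template class's __init__ signature and injecting the already-loaded foundation models whose class names match the parameter names. A minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than the real MelodyTalk tools:

import inspect

# Hypothetical stand-ins: a foundation tool and a "template" tool that reuses it.
class Text2Music:
    def __init__(self, device):
        self.device = device

class Text2MusicwithChord:
    def __init__(self, Text2Music):  # the parameter name doubles as the required model's class name
        self.text2music = Text2Music

models = {"Text2Music": Text2Music(device="cpu")}
template_dict = {"Text2MusicwithChord": "cpu"}

for class_name, device in template_dict.items():
    required = {k for k in inspect.signature(globals()[class_name].__init__).parameters if k != "self"}
    loaded = {type(m).__name__ for m in models.values()}
    if required.issubset(loaded):
        # inject the already-loaded instances instead of re-initializing them
        models[class_name] = globals()[class_name](**{name: models[name] for name in required})

print(sorted(models))  # ['Text2Music', 'Text2MusicwithChord']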
42 changes: 38 additions & 4 deletions melodytalk/modules.py
@@ -16,7 +16,7 @@

 from utils import *

-DURATION = 15
+DURATION = 6
 GENERATION_CANDIDATE = 5

 # Initialize common models
@@ -107,9 +107,10 @@ def __init__(self, device):
         self.model = musicgen_model_melody

     @prompts(
-        name="Generate music from user input text with given drum pattern",
-        description="useful if you want to generate music from a user input text with a given drum pattern."
-                    "like: generate of pop song, and following the drum pattern above."
+        name="Generate music from user input text based on the drum audio file provided.",
+        description="useful if you want to generate music from a user input text and a previously given drum pattern."
+                    "Do not use it when there is no previous music file (generated or uploaded) in the history."
+                    "like: generate a pop song based on the provided drum pattern above."
                     "The input to this tool should be a comma separated string of two, "
                     "representing the music_filename and the text description."
     )
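The @prompts decorator itself is not part of this diff; in Visual-ChatGPT-style agents it is typically just a thin wrapper that attaches the tool's name and description to the method so the bot can register it. A sketch under that assumption:

def prompts(name, description):
    """Attach tool metadata to a method so an agent can expose it as a named tool."""
    def decorator(func):
        func.name = name
        func.description = description
        return func
    return decorator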
@@ -136,7 +137,40 @@ def inference(self, inputs):
         return updated_music_filename


+class AddNewTrack(object):
+    def __init__(self, device):
+        print("Initializing AddNewTrack")
+        self.device = device
+        self.model = musicgen_model_melody
+
+    @prompts(
+        name="Add a new track to the given music loop",
+        description="useful if you want to add a new track (usually a new instrument) to the given music."
+                    "like: add a saxophone to the given music, or add a piano arrangement to the given music."
+                    "The input to this tool should be a comma separated string of two, "
+                    "representing the music_filename and the text description."
+    )
+
+    def inference(self, inputs):
+        music_filename, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
+        text = description_to_attributes(text)
+        print(f"Adding a new track, Input Text: {text}, Music: {music_filename}.")
+        updated_music_filename = get_new_audio_name(music_filename, func_name="add_track")
+        music, sr = torchaudio.load(music_filename)
+        self.model.set_generation_params(duration=30)
+        wav = self.model.generate_continuation(prompt=music[None].expand(GENERATION_CANDIDATE, -1, -1), prompt_sr=sr,
+                                               descriptions=[text] * GENERATION_CANDIDATE, progress=False)
+        self.model.set_generation_params(duration=DURATION)
+        # cut the original music prompt from the generated continuation
+        wav = wav[:, music.shape[1]:music.shape[1] + DURATION * sr]
+        # TODO: split tracks by beats
+
+        # select the best candidate by CLAP score
+        best_wav, _ = CLAP_post_filter(CLAP_model, text, wav, self.model.sample_rate)
+        audio_write(updated_music_filename[:-4],
+                    best_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
+        print(f"\nProcessed AddNewTrack, Output Music: {updated_music_filename}.")
+        return updated_music_filename


 class ExtractTrack(object):
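For context, each tool's inference method takes one comma-separated string ("music_filename, text description"), so the new AddNewTrack tool would be driven roughly as below. This is a sketch with a hypothetical file path and device, not test code from the repository:

tool = AddNewTrack(device="cuda:0")
# The input follows the comma-separated contract declared in the @prompts description.
new_file = tool.inference("music/example_loop.wav, add a warm saxophone lead over the loop")
print(new_file)  # path to the regenerated audio that includes the requested track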

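CLAP_post_filter comes from utils and is also not shown here; conceptually it scores each generated candidate against the text prompt with CLAP embeddings and keeps the best match. A rough re-implementation sketch, where embed_text and embed_audio are assumed placeholder helpers rather than a real library API:

import torch
import torch.nn.functional as F

def clap_post_filter(clap_model, text, wav_batch, sample_rate):
    # Embed the prompt once and every candidate waveform, then pick the candidate
    # whose audio embedding has the highest cosine similarity to the text embedding.
    text_emb = clap_model.embed_text([text])                    # assumed helper -> (1, D)
    audio_emb = clap_model.embed_audio(wav_batch, sample_rate)  # assumed helper -> (N, D)
    scores = F.cosine_similarity(audio_emb, text_emb, dim=-1)   # (N,)
    best = int(scores.argmax())
    return wav_batch[best], float(scores[best])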