Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ldzhangyx committed Jun 28, 2023
1 parent 45d4156 commit 56adabd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
33 changes: 13 additions & 20 deletions melodytalk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,25 @@


class ConversationBot(object):
def __init__(self, load_dict):
print(f"Initializing MelodyTalk, load_dict={load_dict}")
if 'Text2Music' not in load_dict:
raise ValueError("You have to load Text2Music as a basic function for MelodyTalk.")
def __init__(self):
load_dict = {"Text2Music":"cuda:0", "ExtractTrack":"cuda:0", "Text2MusicWithMelody":"cuda:0", "SimpleTracksMixing":"cuda:0"}
template_dict = {"Accompaniment": "cuda:0"}

print(f"Initializing MelodyTalk, load_dict={load_dict}, template_dict={template_dict}")

self.models = {}
# Load Basic Foundation Models
for class_name, device in load_dict.items():
self.models[class_name] = globals()[class_name](device=device)

# Load Template Foundation Models
for class_name, module in globals().items():
if getattr(module, 'template_model', False):
template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if
k != 'self'}
loaded_names = set([type(e).__name__ for e in self.models.values()])
if template_required_names.issubset(loaded_names):
self.models[class_name] = globals()[class_name](
**{name: self.models[name] for name in template_required_names})
for class_name, device in template_dict.items():
template_required_names = {k for k in inspect.signature(globals()[class_name].__init__).parameters.keys() if
k != 'self'}
loaded_names = set([type(e).__name__ for e in self.models.values()])
if template_required_names.issubset(loaded_names):
self.models[class_name] = globals()[class_name](
**{name: self.models[name] for name in template_required_names})

print(f"All the Available Functions: {self.models}")

Expand Down Expand Up @@ -225,14 +225,7 @@ def clear_input_audio(self):
if __name__ == '__main__':
if not os.path.exists("checkpoints"):
os.mkdir("checkpoints")
parser = argparse.ArgumentParser()
parser.add_argument('--load', type=str, default="Text2Music_cuda:0, "
"ExtractTrack_cuda:0, "
"Text2MusicWithMelody_cuda:0,"
"SimpleTracksMixing_cuda:0")
args = parser.parse_args()
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
bot = ConversationBot(load_dict=load_dict)
bot = ConversationBot()
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
lang = gr.Radio(choices = ['Chinese','English'], value=None, label='Language')
chatbot = gr.Chatbot(elem_id="chatbot", label="MelodyTalk")
Expand Down
3 changes: 1 addition & 2 deletions melodytalk/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ def inference(self, inputs):

class Accompaniment(object):
template_model = True
def __init__(self, device, Text2MusicWithMelody, ExtractTrack, SimpleTracksMixing):
def __init__(self, Text2MusicWithMelody, ExtractTrack, SimpleTracksMixing):
print("Initializing Accompaniment")
self.device = device
self.Text2MusicWithMelody = Text2MusicWithMelody
self.ExtractTrack = ExtractTrack
self.SimpleTracksMixing = SimpleTracksMixing
Expand Down
15 changes: 6 additions & 9 deletions melodytalk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def description_to_attributes(description: str) -> str:
:return:
"""

openai_prompt = f"""Please catch the bpm and key attributes from the original description text. If the description text does not mention it, do not add it. Here are two examples:
openai_prompt = f"""Please format the bpm and key attributes from the original description text and keep the rest unchanged.
If the description text does not mention it, do not add it. Here are two examples:
Q: Generate a love pop song in C major of 120 bpm.
A: Generate a love pop song. bpm: 120. key: Cmaj.
Q: Generate a love pop song in a minor.
A: Generate a love pop song. key: Amin.
Q: love pop song in a minor, creating a romantic atmosphere.
A: love pop song, creating a romantic atmosphere. key: Amin.
Q: {description}.
A:
"""
A: """

response = openai.Completion.create(
model="text-davinci-003",
Expand All @@ -81,7 +81,6 @@ def description_to_attributes(description: str) -> str:
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
stop=["\n"]
)

return response.choices[0].text
Expand All @@ -101,8 +100,7 @@ def chord_generation(description: str, chord_num: int = 4) -> tp.List:
A: Dm - Bb - F - C
Q: {description}. {chord_num} chords.
A:
"""
A: """

response = openai.Completion.create(
model="text-davinci-003",
Expand All @@ -112,7 +110,6 @@ def chord_generation(description: str, chord_num: int = 4) -> tp.List:
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
stop=["\n"]
)

chord_list = [i.strip() for i in response.choices[0].text.split('-')]
Expand Down

0 comments on commit 56adabd

Please sign in to comment.