diff --git a/melodytalk/main.py b/melodytalk/main.py
index 18d15d4..ed9fe3a 100644
--- a/melodytalk/main.py
+++ b/melodytalk/main.py
@@ -221,10 +221,18 @@ def clear_input_audio(self):
 if not os.path.exists("checkpoints"):
     os.mkdir("checkpoints")
 bot = ConversationBot()
+
 with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
+
+    gr.Markdown(
+        """This is a demo of our work *MelodyTalk*.
+        """
+    )
+
     lang = gr.Radio(choices=['Chinese', 'English'], value=None, label='Language')
     chatbot = gr.Chatbot(elem_id="chatbot", label="MelodyTalk")
     state = gr.State([])
+
     with gr.Row(visible=False) as input_raws:
         with gr.Column(scale=0.7):
             txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an audio").style(
diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index 49b9f3f..97db424 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -13,7 +13,7 @@
 from utils import *
 
 
-DURATION = 6
+DURATION = 8
 GENERATION_CANDIDATE = 5
 
 # Initialze common models
@@ -66,11 +66,10 @@ def description_to_attributes_wrapper(self, description: str) -> str:
 
 
+
 # attribute management
-global attribute_table
 attribute_table = GlobalAttributes()
-
 
 
 class Text2Music(object):
     def __init__(self, device):
         print("Initializing Text2Music")
@@ -86,6 +85,7 @@ def __init__(self, device):
 
     def inference(self, text):
         music_filename = os.path.join("music", f"{str(uuid.uuid4())[:8]}.wav")
+        attribute_table.descriptions = text
         text = description_to_attributes(text)  # convert text to attributes
         wav = self.model.generate([text], progress=False)
         wav = wav[0]  # batch size is 1
@@ -112,6 +112,7 @@ def __init__(self, device):
 
     def inference(self, inputs):
         music_filename, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
+        attribute_table.descriptions = text
         text = description_to_attributes(text)  # convert text to attributes
         print(f"Generating music from text with melody condition, Input Text: {text}, Melody: {music_filename}.")
         updated_music_filename = get_new_audio_name(music_filename, func_name="remix")
@@ -174,6 +175,7 @@ def __init__(self, device):
 
     def inference(self, inputs):
         music_filename, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
+        attribute_table.descriptions = merge_description(attribute_table.descriptions, text)
         text = addtrack_demand_to_description(text)
         print(f"Adding a new track, Input text: {text}, Previous track: {music_filename}.")
         updated_music_filename = get_new_audio_name(music_filename, func_name="addtrack")
@@ -187,7 +189,7 @@ def inference(self, inputs):
         splitted_audios = split_audio_tensor_by_downbeats(wav.cpu(), self.model.sample_rate, True)
         # select the best one by CLAP scores
         print(f"CLAP post filter for {len(splitted_audios)} candidates.")
-        best_wav, _ = CLAP_post_filter(CLAP_model, text, splitted_audios.cuda(), self.model.sample_rate)
+        best_wav, _ = CLAP_post_filter(CLAP_model, attribute_table.descriptions, splitted_audios.cuda(), self.model.sample_rate)
         audio_write(updated_music_filename[:-4], best_wav.cpu(), self.model.sample_rate,
                     strategy="loudness", loudness_compressor=True)
         print(f"\nProcessed AddNewTrack, Output Music: {updated_music_filename}.")
diff --git a/melodytalk/utils.py b/melodytalk/utils.py
index 21a5036..378c788 100644
--- a/melodytalk/utils.py
+++ b/melodytalk/utils.py
@@ -122,6 +122,32 @@ def addtrack_demand_to_description(description: str) -> str:
     return response.choices[0].text
 
+def merge_description(description_1: str, description_2: str) -> str:
+    openai_prompt = f"""Please merge two descriptions into one.
+
+    S1: Please generate a rock music with drum and guitar for me.
+    S2: Please add a saxophone track to this music.
+    A: rock music with drum, guitar and saxophone.
+
+    S1: Please generate a love pop song with piano and violin for me.
+    S2: Please remove the piano.
+    A: love pop song with violin.
+
+    S1: {description_1}.
+    S2: {description_2}.
+    A: """
+
+    response = openai.Completion.create(
+        model="text-davinci-003",
+        prompt=openai_prompt,
+        temperature=0,
+        max_tokens=100,
+        top_p=1,
+        frequency_penalty=0.0,
+        presence_penalty=0.0,
+    )
+
+    return response.choices[0].text
 
 def chord_generation(description: str) -> tp.List:
     """
     This function is a trick to generate chord sequence from the description.
@@ -216,7 +242,7 @@ def CLAP_post_filter(clap_model,
 
     # get the index of the most similar audio
     index = torch.argmax(similarity)
     best = audio_candidates[index].view(1, -1)
-    best = resampy.resample(best.cpu().numpy(), 48000, audio_sr, axis=-1)
+    best = torch.from_numpy(resampy.resample(best.cpu().numpy(), 48000, audio_sr, axis=-1))
     # return
     return best, similarity[index]
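Usage sketch (illustrative only, not part of the patch): how the new description tracking is meant to compose across tools — Text2Music and Text2MusicWithMelody store the raw request in attribute_table.descriptions, AddNewTrack merges the follow-up demand into it via merge_description, and the merged text (rather than only the add-track demand) becomes the CLAP post-filter query. The example values, and the names CLAP_model, candidates and sample_rate, are placeholders; the flat `from utils import ...` assumes the script runs from the melodytalk/ directory like modules.py does, and merge_description issues a real text-davinci-003 completion, so an OpenAI API key must be configured.

# Minimal sketch of the intended flow under the assumptions stated above.
from utils import merge_description, CLAP_post_filter

# 1) Text2Music / Text2MusicWithMelody store the raw request:
descriptions = "a rock music with drum and guitar"   # i.e. attribute_table.descriptions

# 2) AddNewTrack merges the follow-up demand into the running description
#    (calls text-davinci-003 via openai.Completion.create):
descriptions = merge_description(descriptions, "add a saxophone track to this music")
# expected, given the few-shot prompt: "rock music with drum, guitar and saxophone"

# 3) The merged description, not just the add-track demand, is used as the CLAP query.
#    CLAP_model, candidates and sample_rate stand in for the objects built in modules.py:
# best_wav, score = CLAP_post_filter(CLAP_model, descriptions, candidates.cuda(), sample_rate)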