From c7693a5cee8ef8ad8f8c89629d8e6acc2c0da8e5 Mon Sep 17 00:00:00 2001 From: Yixiao Zhang Date: Mon, 26 Jun 2023 17:40:00 +0900 Subject: [PATCH] New feature: add demucs; chain thoughts pass --- melodytalk/main.py | 6 +++--- melodytalk/modules.py | 49 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/melodytalk/main.py b/melodytalk/main.py index fcaf70d..60de21e 100644 --- a/melodytalk/main.py +++ b/melodytalk/main.py @@ -24,7 +24,7 @@ from audiocraft.data.audio import audio_write from utils import prompts, seed_everything, cut_dialogue_history, get_new_audio_name -from modules import Text2Music +from modules import * MELODYTALK_PREFIX = """MelodyTalk is designed to be able to assist with a wide range of text and music related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. MelodyTalk is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. @@ -43,7 +43,7 @@ MelodyTalk has access to the following tools:""" -MELODYTALK_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format: +MELODYTALK_FORMAT_INSTRUCTIONS = """To use a tool, you MUST use the following format: ``` Thought: Do I need to use a tool? Yes @@ -226,7 +226,7 @@ def clear_input_audio(self): if not os.path.exists("checkpoints"): os.mkdir("checkpoints") parser = argparse.ArgumentParser() - parser.add_argument('--load', type=str, default="Text2Music_cuda:0") + parser.add_argument('--load', type=str, default="Text2Music_cuda:0, ExtractTrack_cuda:0") args = parser.parse_args() load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')} bot = ConversationBot(load_dict=load_dict) diff --git a/melodytalk/modules.py b/melodytalk/modules.py index 644d4d6..ac90f98 100644 --- a/melodytalk/modules.py +++ b/melodytalk/modules.py @@ -1,11 +1,16 @@ import os import uuid import torch +from shutil import copyfile +# text2music from audiocraft.models import MusicGen from audiocraft.data.audio import audio_write -from utils import prompts +# source separation +import demucs.separate + +from utils import prompts, get_new_audio_name class Text2Music(object): def __init__(self, device): @@ -31,4 +36,44 @@ def inference(self, text): audio_write(music_filename[:-4], wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True) print(f"\nProcessed Text2Music, Input Text: {text}, Output Music: {music_filename}.") - return music_filename \ No newline at end of file + return music_filename + + +class ExtractTrack(object): + def __init__(self, device): + print("Initializing ExtractTrack") + self.device = device + self.params_list = [ + "-n", "htdemucs_6s", # model selection + "--two-stems", None, # track name + None # original filename + ] + + @prompts( + name="Extract one track from a music file", + description="useful if you want to separate a track (must be one of `vocals`, `drums`, `bass`, `guitar`, `piano` or `other`) from a music file." + "Like: separate vocals from a music file, or extract drums from a music file." + "The input to this tool should be a comma separated string of two, " + "representing the music_filename and the specific track name." + ) + + def inference(self, inputs): + music_filename, instrument = inputs.split(",")[0].strip(), inputs.split(",")[1].strip() + print(f"Extracting {instrument} track from {music_filename}.") + updated_music_filename = get_new_audio_name(music_filename, func_name=f"{instrument}") + + # fill params + self.params_list[-2] = instrument + self.params_list[-1] = music_filename + # run + demucs.separate.main(self.params_list) + # rename + copyfile( + os.path.join("separated", "htdemucs_6s", music_filename[:-4].split("/")[-1],f"{instrument}.wav"), + updated_music_filename + ) + # delete the folder + # os.system(f"rm -rf {os.path.join('separated', 'htdemucs_6s')}") + + print(f"Processed Source Separation, Input Music: {music_filename}, Output Instrument: {instrument}.") + return updated_music_filename