diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index a64e40f..5856aa9 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -15,6 +15,7 @@
 
 # Initialze common models
 musicgen_model = MusicGen.get_pretrained('melody')
+musicgen_model.set_generation_params(duration=8)
 
 class Text2Music(object):
     def __init__(self, device):
@@ -22,9 +23,6 @@ def __init__(self, device):
         self.device = device
         self.model = musicgen_model
 
-        # Set generation params
-        self.model.set_generation_params(duration=8)
-
     @prompts(
         name="Generate music from user input text",
         description="useful if you want to generate music from a user input text and save it to a file."
@@ -48,9 +46,6 @@ def __init__(self, device):
         self.device = device
         self.model = musicgen_model
 
-        # Set generation params
-        self.model.set_generation_params(duration=8)
-
     @prompts(
         name="Generate music from user input text with melody condition",
         description="useful if you want to generate, style transfer or remix music from a user input text with a given melody condition."
@@ -72,6 +67,9 @@ def inference(self, inputs):
 
         return updated_music_filename
 
+
+
+
 class ExtractTrack(object):
     def __init__(self, device):
         print("Initializing ExtractTrack")
@@ -83,17 +81,25 @@ def __init__(self, device):
         ]
 
     @prompts(
-        name="Extract one track from a music file",
+        name="Separate one track from a music file to extract (return the single track) or remove (return the mixture of the rest tracks) it.",
         description="useful if you want to separate a track (must be one of `vocals`, `drums`, `bass`, `guitar`, `piano` or `other`) from a music file."
-                    "Like: separate vocals from a music file, or extract drums from a music file."
-                    "The input to this tool should be a comma separated string of two, "
-                    "representing the music_filename and the specific track name."
+                    "Like: separate vocals from a music file, or remove the drum track from a music file."
+                    "The input to this tool should be a comma separated string of three params, "
+                    "representing the music_filename, the specific track name, and the mode (must be `extract` or `remove`)."
     )
 
     def inference(self, inputs):
-        music_filename, instrument = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
-        print(f"Extracting {instrument} track from {music_filename}.")
-        updated_music_filename = get_new_audio_name(music_filename, func_name=f"{instrument}")
+        music_filename, instrument, mode = inputs.split(",")[0].strip(), inputs.split(",")[1].strip(), inputs.split(",")[2].strip()
+        print(f"{mode}ing {instrument} track from {music_filename}.")
+
+        if mode == "extract":
+            instrument_mode = instrument
+        elif mode == "remove":
+            instrument_mode = f"no_{instrument}"
+        else:
+            raise ValueError("mode must be `extract` or `remove`.")
+
+        updated_music_filename = get_new_audio_name(music_filename, func_name=f"{instrument_mode}")
 
         # fill params
         self.params_list[-2] = instrument
@@ -102,11 +108,49 @@ def inference(self, inputs):
         demucs.separate.main(self.params_list)
         # rename
         copyfile(
-            os.path.join("separated", "htdemucs_6s", music_filename[:-4].split("/")[-1], f"{instrument}.wav"),
+            os.path.join("separated", "htdemucs_6s", music_filename[:-4].split("/")[-1], f"{instrument_mode}.wav"),
             updated_music_filename
         )
         # delete the folder
         # os.system(f"rm -rf {os.path.join('separated', 'htdemucs_6s')}")
-        print(f"Processed Source Separation, Input Music: {music_filename}, Output Instrument: {instrument}.")
+        print(f"Processed Source Separation, Input Music: {music_filename}, Output Instrument: {instrument_mode}.")
 
         return updated_music_filename
+
+
+class SimpleTracksMixing(object):
+    def __init__(self, device):
+        print("Initializing SimpleTracksMixing")
+        self.device = device
+
+    @prompts(
+        name="Simply mixing two tracks from two music files.",
+        description="useful if you want to mix two tracks from two music files."
+                    "Like: mix the vocals track from a music file with the drums track from another music file."
+                    "The input to this tool should be a comma separated string of two, "
+                    "representing the first music_filename_1 and the second music_filename_2."
+    )
+
+    def inference(self, inputs):
+        music_filename_1, music_filename_2 = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
+        print(f"Mixing two tracks from two music files, Input Music 1: {music_filename_1}, Input Music 2: {music_filename_2}.")
+        updated_music_filename = get_new_audio_name(music_filename_1, func_name="mixing")
+        # load
+        wav_1, sr_1 = torchaudio.load(music_filename_1)
+        wav_2, sr_2 = torchaudio.load(music_filename_2)
+        # resample
+        if sr_1 != sr_2:
+            wav_2 = torchaudio.transforms.Resample(sr_2, sr_1)(wav_2)
+        # pad or cut
+        if wav_1.shape[-1] > wav_2.shape[-1]:
+            wav_2 = torch.cat([wav_2, torch.zeros_like(wav_1[:, wav_2.shape[-1]:])], dim=-1)
+        elif wav_1.shape[-1] < wav_2.shape[-1]:
+            wav_2 = wav_2[:, :wav_1.shape[-1]]
+        # mix
+        assert wav_1.shape == wav_2.shape  # channel, length
+        wav = torch.add(wav_1, wav_2)
+        # write
+        audio_write(updated_music_filename[:-4],
+                    wav.cpu(), sr_1, strategy="loudness", loudness_compressor=True)
+        print(f"\nProcessed TracksMixing, Output Music: {updated_music_filename}.")
+        return updated_music_filename
\ No newline at end of file