new feature: generate with drum
ldzhangyx committed Jul 14, 2023
1 parent 3b8ceff commit c4090f5
Showing 2 changed files with 144 additions and 105 deletions.
244 changes: 140 additions & 104 deletions melodytalk/modules.py
@@ -17,6 +17,7 @@
from utils import *

DURATION = 15
GENERATION_CANDIDATE = 5

# Initialize common models
# musicgen_model = MusicGen.get_pretrained('large')
@@ -99,6 +100,41 @@ def inference(self, inputs):
        print(f"\nProcessed Text2MusicWithMelody, Output Music: {updated_music_filename}.")
        return updated_music_filename

class Text2MusicWithDrum(object):
    def __init__(self, device):
        print("Initializing Text2MusicWithDrum")
        self.device = device
        self.model = musicgen_model_melody

    @prompts(
        name="Generate music from user input text with given drum pattern",
        description="useful if you want to generate music from a user input text with a given drum pattern, "
                    "like: generate a pop song following the drum pattern above. "
                    "The input to this tool should be a comma separated string of two, "
                    "representing the music_filename and the text description."
    )
    def inference(self, inputs):
        music_filename, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
        text = description_to_attributes(text)
        print(f"Generating music from text with drum condition, Input Text: {text}, Drum: {music_filename}.")
        updated_music_filename = get_new_audio_name(music_filename, func_name="with_drum")
        drum, sr = torchaudio.load(music_filename)
        self.model.set_generation_params(duration=30)
        # sample several candidates that continue the drum prompt
        # (audiocraft's MusicGen API names this argument prompt_sample_rate)
        wav = self.model.generate_continuation(prompt=drum[None].expand(GENERATION_CANDIDATE, -1, -1),
                                               prompt_sample_rate=sr,
                                               descriptions=[text] * GENERATION_CANDIDATE, progress=False)
        self.model.set_generation_params(duration=DURATION)
        # cut off the drum prompt along the time axis; the generated audio is at the
        # model's sample rate, so the prompt length must be converted from the input rate
        prompt_length = int(drum.shape[-1] / sr * self.model.sample_rate)
        wav = wav[:, :, prompt_length:prompt_length + DURATION * self.model.sample_rate]
        # TODO: split tracks by beats

        # select the best candidate by CLAP score
        best_wav, _ = CLAP_post_filter(CLAP_model, text, wav, self.model.sample_rate)
        audio_write(updated_music_filename[:-4],
                    best_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
        print(f"\nProcessed Text2MusicWithDrum, Output Music: {updated_music_filename}.")
        return updated_music_filename
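
As a quick usage sketch (not part of the commit; the path and description below are hypothetical), the tool is driven by the comma-separated convention given in its @prompts description:

# hypothetical usage sketch: the input string packs "music_filename, text description"
drum_tool = Text2MusicWithDrum(device="cuda")
output_path = drum_tool.inference("music/drum_loop.wav, a pop song with groovy bass and bright synths")
print(output_path)  # new filename derived by get_new_audio_name(..., func_name="with_drum")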




@@ -151,120 +187,120 @@ def inference(self, inputs):
        return updated_music_filename


class SimpleTracksMixing(object):
    def __init__(self, device):
        print("Initializing SimpleTracksMixing")
        self.device = device

    @prompts(
        name="Simply mixing two tracks from two music files.",
        description="useful if you want to mix two tracks from two music files."
                    "Like: mix the vocals track from a music file with the drums track from another music file."
                    "The input to this tool should be a comma separated string of two, "
                    "representing the first music_filename_1 and the second music_filename_2."
    )

    def inference(self, inputs):
        music_filename_1, music_filename_2 = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
        print(f"Mixing two tracks from two music files, Input Music 1: {music_filename_1}, Input Music 2: {music_filename_2}.")
        updated_music_filename = get_new_audio_name(music_filename_1, func_name="mixing")
        # load
        wav_1, sr_1 = torchaudio.load(music_filename_1)
        wav_2, sr_2 = torchaudio.load(music_filename_2)
        # resample
        if sr_1 != sr_2:
            wav_2 = torchaudio.transforms.Resample(sr_2, sr_1)(wav_2)
        # pad or cut
        if wav_1.shape[-1] > wav_2.shape[-1]:
            wav_2 = torch.cat([wav_2, torch.zeros_like(wav_1[:, wav_2.shape[-1]:])], dim=-1)
        elif wav_1.shape[-1] < wav_2.shape[-1]:
            wav_2 = wav_2[:, :wav_1.shape[-1]]
        # mix
        assert wav_1.shape == wav_2.shape  # channel, length
        wav = torch.add(wav_1, wav_2 * 0.7)
        # write
        audio_write(updated_music_filename[:-4],
                    wav.cpu(), sr_1, strategy="loudness", loudness_compressor=True)
        print(f"\nProcessed TracksMixing, Output Music: {updated_music_filename}.")
        return updated_music_filename
# class SimpleTracksMixing(object):
#     def __init__(self, device):
#         print("Initializing SimpleTracksMixing")
#         self.device = device
#
#     @prompts(
#         name="Simply mixing two tracks from two music files.",
#         description="useful if you want to mix two tracks from two music files."
#                     "Like: mix the vocals track from a music file with the drums track from another music file."
#                     "The input to this tool should be a comma separated string of two, "
#                     "representing the first music_filename_1 and the second music_filename_2."
#     )
#
#     def inference(self, inputs):
#         music_filename_1, music_filename_2 = inputs.split(",")[0].strip(), inputs.split(",")[1].strip()
#         print(f"Mixing two tracks from two music files, Input Music 1: {music_filename_1}, Input Music 2: {music_filename_2}.")
#         updated_music_filename = get_new_audio_name(music_filename_1, func_name="mixing")
#         # load
#         wav_1, sr_1 = torchaudio.load(music_filename_1)
#         wav_2, sr_2 = torchaudio.load(music_filename_2)
#         # resample
#         if sr_1 != sr_2:
#             wav_2 = torchaudio.transforms.Resample(sr_2, sr_1)(wav_2)
#         # pad or cut
#         if wav_1.shape[-1] > wav_2.shape[-1]:
#             wav_2 = torch.cat([wav_2, torch.zeros_like(wav_1[:, wav_2.shape[-1]:])], dim=-1)
#         elif wav_1.shape[-1] < wav_2.shape[-1]:
#             wav_2 = wav_2[:, :wav_1.shape[-1]]
#         # mix
#         assert wav_1.shape == wav_2.shape  # channel, length
#         wav = torch.add(wav_1, wav_2 * 0.7)
#         # write
#         audio_write(updated_music_filename[:-4],
#                     wav.cpu(), sr_1, strategy="loudness", loudness_compressor=True)
#         print(f"\nProcessed TracksMixing, Output Music: {updated_music_filename}.")
#         return updated_music_filename


class MusicCaptioning(object):
    def __init__(self):
        raise NotImplementedError

class Text2MusicwithChord(object):
    template_model = True
    def __init__(self, Text2Music):
        print("Initializing Text2MusicwithChord")
        self.Text2Music = Text2Music

    @prompts(
        name="Generate music from user input text and chord description",
        description="useful only if you want to generate music from a user input text and explicitly mention a chord description."
                    "Like: generate a pop love song with piano and a chord progression of C - F - G - C, or generate a sad music with a jazz chord progression."
                    "This tool will automatically extract chord information and generate music."
                    "The input to this tool should be the user input text. "
    )

    def inference(self, inputs):
        music_filename = os.path.join("music", f"{str(uuid.uuid4())[:8]}.wav")

        chords_list = chord_generation(inputs)
        preprocessed_input = description_to_attributes(inputs)

        for i, chord in enumerate(chords_list):
            text = f"{preprocessed_input} key: {chord}."
            self.Text2Music.model.set_generation_params(duration=(i + 1) * (DURATION / len(chords_list)))
            if i == 0:
                wav = self.Text2Music.model.generate([text], progress=False)
            else:
                wav = self.Text2Music.model.generate_continuation(wav,
                                                                  self.Text2Music.model.sample_rate,
                                                                  [text],
                                                                  progress=False)
            if i == len(chords_list) - 1:
                wav = wav[0]  # batch size is 1
                audio_write(music_filename[:-4],
                            wav.cpu(), self.Text2Music.model.sample_rate, strategy="loudness", loudness_compressor=True)
        self.Text2Music.model.set_generation_params(duration=DURATION)
        print(f"\nProcessed Text2Music, Input Text: {preprocessed_input}, Output Music: {music_filename}.")
        return music_filename
# class Text2MusicwithChord(object):
#     template_model = True
#     def __init__(self, Text2Music):
#         print("Initializing Text2MusicwithChord")
#         self.Text2Music = Text2Music
#
#     @prompts(
#         name="Generate music from user input text and chord description",
#         description="useful only if you want to generate music from a user input text and explicitly mention a chord description."
#                     "Like: generate a pop love song with piano and a chord progression of C - F - G - C, or generate a sad music with a jazz chord progression."
#                     "This tool will automatically extract chord information and generate music."
#                     "The input to this tool should be the user input text. "
#     )
#
#     def inference(self, inputs):
#         music_filename = os.path.join("music", f"{str(uuid.uuid4())[:8]}.wav")
#
#         chords_list = chord_generation(inputs)
#         preprocessed_input = description_to_attributes(inputs)
#
#         for i, chord in enumerate(chords_list):
#             text = f"{preprocessed_input} key: {chord}."
#             self.Text2Music.model.set_generation_params(duration=(i + 1) * (DURATION / len(chords_list)))
#             if i == 0:
#                 wav = self.Text2Music.model.generate([text], progress=False)
#             else:
#                 wav = self.Text2Music.model.generate_continuation(wav,
#                                                                   self.Text2Music.model.sample_rate,
#                                                                   [text],
#                                                                   progress=False)
#             if i == len(chords_list) - 1:
#                 wav = wav[0]  # batch size is 1
#                 audio_write(music_filename[:-4],
#                             wav.cpu(), self.Text2Music.model.sample_rate, strategy="loudness", loudness_compressor=True)
#         self.Text2Music.model.set_generation_params(duration=DURATION)
#         print(f"\nProcessed Text2Music, Input Text: {preprocessed_input}, Output Music: {music_filename}.")
#         return music_filename



class MusicInpainting(object):
    def __init__(self):
        raise NotImplementedError

class Accompaniment(object):
    template_model = True
    def __init__(self, Text2MusicWithMelody, ExtractTrack, SimpleTracksMixing):
        print("Initializing Accompaniment")
        self.Text2MusicWithMelody = Text2MusicWithMelody
        self.ExtractTrack = ExtractTrack
        self.SimpleTracksMixing = SimpleTracksMixing

    @prompts(
        name="Generate accompaniment music from user input text, keeping the given melody or track",
        description="useful if you want to style transfer or remix music from a user input text with a given melody."
                    "Unlike Text2MusicWithMelody, this tool will keep the given melody track instead of re-generate it."
                    "Note that the user must assign a track (it must be one of `vocals`, `drums`, `bass`, `guitar`, `piano` or `other`) to keep."
                    "like: keep the guitar track and remix the given music with text description, "
                    "or generate accompaniment as text described with the given vocal track."
                    "The input to this tool should be a comma separated string of three, "
                    "representing the music_filename, track name, and the text description."
    )

    def inference(self, inputs):
        music_filename, track_name, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip(), inputs.split(",")[2].strip()
        print(f"Generating music from text with accompaniment condition, Input Text: {text}, Previous music: {music_filename}, Track: {track_name}.")
        # separate the track
        updated_main_track = self.ExtractTrack.inference(f"{music_filename}, {track_name}, extract")
        # generate music
        updated_new_music = self.Text2MusicWithMelody.inference(f"{updated_main_track}, {text}")
        # remove the track in accompaniment
        updated_accompaniment = self.ExtractTrack.inference(f"{updated_new_music}, {track_name}, remove")
        # mix
        updated_music_filename = self.SimpleTracksMixing.inference(f"{updated_main_track}, {updated_accompaniment}")
        return updated_music_filename
# class Accompaniment(object):
#     template_model = True
#     def __init__(self, Text2MusicWithMelody, ExtractTrack, SimpleTracksMixing):
#         print("Initializing Accompaniment")
#         self.Text2MusicWithMelody = Text2MusicWithMelody
#         self.ExtractTrack = ExtractTrack
#         self.SimpleTracksMixing = SimpleTracksMixing
#
#     @prompts(
#         name="Generate accompaniment music from user input text, keeping the given melody or track",
#         description="useful if you want to style transfer or remix music from a user input text with a given melody."
#                     "Unlike Text2MusicWithMelody, this tool will keep the given melody track instead of re-generate it."
#                     "Note that the user must assign a track (it must be one of `vocals`, `drums`, `bass`, `guitar`, `piano` or `other`) to keep."
#                     "like: keep the guitar track and remix the given music with text description, "
#                     "or generate accompaniment as text described with the given vocal track."
#                     "The input to this tool should be a comma separated string of three, "
#                     "representing the music_filename, track name, and the text description."
#     )
#
#     def inference(self, inputs):
#         music_filename, track_name, text = inputs.split(",")[0].strip(), inputs.split(",")[1].strip(), inputs.split(",")[2].strip()
#         print(f"Generating music from text with accompaniment condition, Input Text: {text}, Previous music: {music_filename}, Track: {track_name}.")
#         # separate the track
#         updated_main_track = self.ExtractTrack.inference(f"{music_filename}, {track_name}, extract")
#         # generate music
#         updated_new_music = self.Text2MusicWithMelody.inference(f"{updated_main_track}, {text}")
#         # remove the track in accompaniment
#         updated_accompaniment = self.ExtractTrack.inference(f"{updated_new_music}, {track_name}, remove")
#         # mix
#         updated_music_filename = self.SimpleTracksMixing.inference(f"{updated_main_track}, {updated_accompaniment}")
#         return updated_music_filename
5 changes: 4 additions & 1 deletion melodytalk/utils.py
@@ -149,6 +149,9 @@ def beat_tracking_with_clip(audio_path: str,
    audio_clip.export(output_path, format="wav")


def split_track_with_beat(input_track):
    pass
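
This stub backs the `# TODO: split tracks by beats` step in Text2MusicWithDrum. A minimal sketch of one possible implementation, assuming librosa beat tracking on a (channels, time) numpy array (the function name and approach here are hypothetical, not part of the commit):

import librosa
import numpy as np

def split_track_with_beat_sketch(input_track: np.ndarray, sr: int) -> list:
    # hypothetical sketch, not the committed implementation
    mono = input_track.mean(axis=0)  # beat tracking expects mono audio
    _, beats = librosa.beat.beat_track(y=mono, sr=sr, units="samples")
    # slice the multi-channel track between consecutive beat positions
    return [input_track[:, start:end] for start, end in zip(beats[:-1], beats[1:])]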

@torch.no_grad()
def CLAP_post_filter(clap_model,
                     text_description: str,
@@ -175,7 +178,7 @@ def CLAP_post_filter(clap_model,
    # resample the audio candidates to 48 kHz, which the CLAP model expects
    audio_candidates = resampy.resample(audio_candidates.numpy(), audio_sr, 48000, axis=-1)
    audio_candidates = torch.from_numpy(audio_candidates)
    # calculate the audio embedding
    audio_embedding = clap_model.get_audio_embedding_from_data(x=audio_candidates, use_tensor=True)  # (N, D)
    # calculate the text embedding
    text_embedding = clap_model.get_text_embedding([text_description])  # (1, D)
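
The rest of CLAP_post_filter is truncated in this view. A rough sketch of how the selection step presumably ends (cosine similarity between each candidate's audio embedding and the text embedding, then an argmax; the committed code may differ):

    # rough sketch of the truncated tail (assumption, not the committed code):
    # score every candidate against the text and return the best one with its score
    similarity = torch.nn.functional.cosine_similarity(
        audio_embedding, torch.as_tensor(text_embedding), dim=-1)  # (N,)
    best = int(similarity.argmax())
    # in practice the candidate at its original sample rate would be returned
    return audio_candidates[best], float(similarity[best])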
