From 3b8ceff6524082bdd03f4dc31625edddb671cfd2 Mon Sep 17 00:00:00 2001
From: Yixiao Zhang
Date: Fri, 14 Jul 2023 14:49:41 +0900
Subject: [PATCH] add CLAP filter

---
 main.py               |  4 ++--
 melodytalk/modules.py | 12 ++++++++++--
 melodytalk/utils.py   | 41 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index f1d6174..13edd55 100644
--- a/main.py
+++ b/main.py
@@ -9,9 +9,9 @@
 MODEL_NAME = 'melody'
 DURATION = 40
 CFG_COEF = 3
-SAMPLES = 5
+SAMPLES = 10
 # PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
-PROMPT = "rock music loop with saxophone solo. bpm: 90. instrument: saxophone, guitar, drum."
+PROMPT = "rock music loop with rhythmic, quick and technic saxophone solo. bpm: 90. instrument: saxophone, guitar, drum."
 melody_conditioned = True
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "3"
diff --git a/melodytalk/modules.py b/melodytalk/modules.py
index ff037dc..f30b50b 100644
--- a/melodytalk/modules.py
+++ b/melodytalk/modules.py
@@ -11,18 +11,26 @@
 from audiocraft.data.audio import audio_write
 # source separation
 import demucs.separate
+# CLAP
+import laion_clap
 
 from utils import *
 
 DURATION = 15
 
 # Initialze common models
-musicgen_model = MusicGen.get_pretrained('large')
-musicgen_model.set_generation_params(duration=DURATION)
+# musicgen_model = MusicGen.get_pretrained('large')
+# musicgen_model.set_generation_params(duration=DURATION)
 
 musicgen_model_melody = MusicGen.get_pretrained('melody')
 musicgen_model_melody.set_generation_params(duration=DURATION)
 
+# for acceleration, reuse the melody model instead of loading a second checkpoint
+musicgen_model = musicgen_model_melody
+
+# Initialize CLAP post filter
+CLAP_model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", tmodel="roberta", device="cuda")
+CLAP_model.load_ckpt("/home/intern-2023-02/melodytalk/melodytalk/pretrained/music_audioset_epoch_15_esc_90.14.pt")
 
 @dataclass
 class GlobalAttributes(object):
diff --git a/melodytalk/utils.py b/melodytalk/utils.py
index c8dc65a..01bca35 100644
--- a/melodytalk/utils.py
+++ b/melodytalk/utils.py
@@ -6,6 +6,7 @@
 import openai
 import typing as tp
 import madmom
+import resampy
 import torchaudio
 
 from pydub import AudioSegment
@@ -147,3 +148,43 @@ def beat_tracking_with_clip(audio_path: str,
         output_path = audio_path
 
     audio_clip.export(output_path, format="wav")
+
+@torch.no_grad()
+def CLAP_post_filter(clap_model,
+                     text_description: str,
+                     audio_candidates: tp.Union[tp.List[torch.Tensor], torch.Tensor],
+                     audio_sr: int) \
+        -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """ Post filter based on the CLAP model. It takes a text description and audio candidates as input,
+    and returns the candidate most similar to the text, along with its similarity score.
+
+    args:
+        clap_model: the pre-loaded CLAP model
+        text_description: the text description of the target audio
+        audio_candidates: the audio candidates, as a list of tensors or a single (N, L) tensor
+        audio_sr: the sample rate of the audio candidates
+
+    return:
+        audio: the audio candidate that best matches the text description
+        similarity: its similarity score
+    """
+
+    # transform the audio_candidates to a torch.Tensor with shape (N, L)
+    if isinstance(audio_candidates, list):
+        audio_candidates = torch.stack(audio_candidates)
+    # resample the audio_candidates to 48 kHz, the sample rate the CLAP model expects
+    audio_candidates = resampy.resample(audio_candidates.numpy(), audio_sr, 48000, axis=-1)
+    audio_candidates = torch.from_numpy(audio_candidates)
+    # calculate the audio embeddings
+    audio_embedding = clap_model.get_audio_embedding_from_data(x=audio_candidates, use_tensor=True)  # (N, D)
+    # calculate the text embedding
+    text_embedding = clap_model.get_text_embedding([text_description], use_tensor=True)  # (1, D)
+    # calculate the similarity by dot product
+    similarity = torch.matmul(text_embedding, audio_embedding.T)  # (1, N)
+    # get the index of the most similar audio
+    index = torch.argmax(similarity)
+    # return the best candidate and its similarity score
+    return audio_candidates[index], similarity[0, index]
+
+
+
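Usage note (not part of this patch): a minimal sketch of how CLAP_post_filter could be wired into generation, assuming the candidates come from a plain text-conditioned MusicGen batch; aside from PROMPT, SAMPLES, musicgen_model and CLAP_model defined above, the names below are illustrative.

    import torch

    with torch.no_grad():
        # generate SAMPLES candidates for the same prompt; MusicGen returns (N, 1, T) at its native sample rate
        wavs = musicgen_model.generate([PROMPT] * SAMPLES)
    candidates = wavs.squeeze(1).cpu()  # mono candidates, shape (N, T), moved to CPU for resampy
    best_wav, score = CLAP_post_filter(CLAP_model, PROMPT, candidates, musicgen_model.sample_rate)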