Commit

add CLAP filter

ldzhangyx committed Jul 14, 2023
1 parent d863fa3 commit 3b8ceff
Showing 3 changed files with 53 additions and 4 deletions.
main.py — 4 changes: 2 additions & 2 deletions
@@ -9,9 +9,9 @@
 MODEL_NAME = 'melody'
 DURATION = 40
 CFG_COEF = 3
-SAMPLES = 5
+SAMPLES = 10
 # PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
-PROMPT = "rock music loop with saxophone solo. bpm: 90. instrument: saxophone, guitar, drum."
+PROMPT = "rock music loop with a rhythmic, quick and technical saxophone solo. bpm: 90. instrument: saxophone, guitar, drum."
 melody_conditioned = True
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "3"
melodytalk/modules.py — 12 changes: 10 additions & 2 deletions
@@ -11,18 +11,26 @@
 from audiocraft.data.audio import audio_write
 # source separation
 import demucs.separate
+# CLAP
+import laion_clap
 
 from utils import *
 
 DURATION = 15
 
 # Initialize common models
-musicgen_model = MusicGen.get_pretrained('large')
-musicgen_model.set_generation_params(duration=DURATION)
+# musicgen_model = MusicGen.get_pretrained('large')
+# musicgen_model.set_generation_params(duration=DURATION)
 
 musicgen_model_melody = MusicGen.get_pretrained('melody')
 musicgen_model_melody.set_generation_params(duration=DURATION)
 
+# for acceleration, reuse the melody model as the shared model
+musicgen_model = musicgen_model_melody
+
+# Initialize CLAP post filter
+CLAP_model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", tmodel="roberta", device="cuda")
+CLAP_model.load_ckpt("/home/intern-2023-02/melodytalk/melodytalk/pretrained/music_audioset_epoch_15_esc_90.14.pt")
 
 @dataclass
 class GlobalAttributes(object):
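As a quick sanity check of the loaded checkpoint, the CLAP module can embed a pair of text prompts. The snippet below is an illustrative sketch, not part of this commit, and assumes the checkpoint uses CLAP's 512-dimensional joint embedding space:

    # Hypothetical smoke test for the loaded CLAP model (not in this commit).
    emb = CLAP_model.get_text_embedding(["rock music loop", "piano ballad"], use_tensor=True)
    print(emb.shape)  # expected: torch.Size([2, 512])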
melodytalk/utils.py — 41 changes: 41 additions & 0 deletions
@@ -6,6 +6,7 @@
 import openai
 import typing as tp
 import madmom
+import resampy
 import torchaudio
 from pydub import AudioSegment
 
@@ -147,3 +148,43 @@ def beat_tracking_with_clip(audio_path: str,
     output_path = audio_path
     audio_clip.export(output_path, format="wav")
 
+
+@torch.no_grad()
+def CLAP_post_filter(clap_model,
+                     text_description: str,
+                     audio_candidates: tp.Union[tp.List[torch.Tensor], torch.Tensor],
+                     audio_sr: int) -> tp.Tuple[torch.Tensor, float]:
+    """Post filter based on the CLAP model. Takes a text description and a batch of audio
+    candidates, and returns the candidate most similar to the text plus its similarity score.
+
+    args:
+        clap_model: CLAP model
+        text_description: the text description of the audio
+        audio_candidates: the audio candidates
+        audio_sr: the sample rate of the audio candidates
+    return:
+        best_audio: the audio candidate most similar to the text (resampled to 48 kHz)
+        similarity: the similarity score
+    """
+    # transform audio_candidates into a torch.Tensor with shape (N, L)
+    if isinstance(audio_candidates, list):
+        audio_candidates = torch.stack(audio_candidates)
+    # resample the audio candidates to 48 kHz, the sample rate the CLAP model expects
+    audio_candidates = resampy.resample(audio_candidates.numpy(), audio_sr, 48000, axis=-1)
+    audio_candidates = torch.from_numpy(audio_candidates)
+    # calculate the audio embeddings
+    audio_embedding = clap_model.get_audio_embedding_from_data(x=audio_candidates, use_tensor=True)  # (N, D)
+    # calculate the text embedding
+    text_embedding = clap_model.get_text_embedding([text_description], use_tensor=True)  # (1, D)
+    # calculate the similarity by dot product
+    similarity = torch.matmul(text_embedding, audio_embedding.T)  # (1, N)
+    # get the index of the most similar audio
+    index = torch.argmax(similarity).item()
+    # return the best candidate and its similarity score
+    return audio_candidates[index], similarity[0, index].item()
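For context, a minimal sketch of how this filter might be called after batch generation. The call site is hypothetical and not part of this commit; musicgen_model, CLAP_model, PROMPT, and SAMPLES refer to the objects defined in the files above, and the squeeze assumes single-channel MusicGen output:

    # Hypothetical wiring of the CLAP post filter (illustrative, not in this commit).
    wavs = musicgen_model.generate([PROMPT] * SAMPLES)  # (N, C, L) at the model's sample rate
    wavs = wavs.squeeze(1).cpu()                        # assumes mono output -> (N, L)
    best_wav, score = CLAP_post_filter(CLAP_model, PROMPT, wavs, musicgen_model.sample_rate)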