Commit d863fa3

update

ldzhangyx committed Jul 13, 2023
1 parent 5d92919 commit d863fa3
Showing 6 changed files with 65 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -153,5 +153,6 @@ cython_debug/

 output/
 assets/
+melodytalk/music/

 .DS_Store
54 changes: 35 additions & 19 deletions main.py
@@ -1,34 +1,50 @@
 import os

 import torchaudio
-from audiocraft.models import MusicGen
-from audiocraft.data.audio import audio_write
+from melodytalk.audiocraft.models import MusicGen
+from melodytalk.audiocraft.data.audio import audio_write
 from datetime import datetime
+import torch

 MODEL_NAME = 'melody'
-DURATION = 8
+DURATION = 40
 CFG_COEF = 3
 SAMPLES = 5
-PROMPT = 'love pop song with violin, piano arrangement, creating a romantic atmosphere'
-# PROMPT = 'music loop. Passionate love song with guitar rhythms, electric piano chords, drums pattern. instrument: guitar, piano, drum.'
+PROMPT = "rock music loop with saxophone solo. bpm: 90. instrument: saxophone, guitar, drum."
+melody_conditioned = True

-melody_conditioned = False
 os.environ["CUDA_VISIBLE_DEVICES"] = "3"

-model = MusicGen.get_pretrained(MODEL_NAME)
+model = MusicGen.get_pretrained(MODEL_NAME, device='cuda')

+DURATION_1 = min(DURATION, 30)  # MusicGen caps a single pass at 30 seconds
+DURATION_2 = max(DURATION - 30, 0)  # remainder generated as a continuation
+OVERLAP = 8  # seconds re-fed as the continuation prompt

-model.set_generation_params(duration=DURATION,
-                            cfg_coef=CFG_COEF)  # generate 8 seconds.
+model.set_generation_params(duration=DURATION_1,
+                            cfg_coef=CFG_COEF)  # generate the first chunk (up to 30 seconds)
 # wav = model.generate_unconditional(4)  # generates 4 unconditional audio samples
 descriptions = [PROMPT] * SAMPLES
 # 'A slow and heartbreaking love song at tempo of 60',
 # 'A slow and heartbreaking love song with cello instrument']

-if not melody_conditioned:
-    wav = model.generate(descriptions, progress=True)  # generates 3 samples.
-else:
-    melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/1625.wav')
-    wav = model.generate_with_chroma(descriptions, melody[None].expand(SAMPLES, -1, -1), sr, progress=True)

-for idx, one_wav in enumerate(wav):
-    # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
-    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
-    audio_write(f'output/{current_time}_{idx}',
-                one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+def generate():
+    if not melody_conditioned:
+        wav = model.generate(descriptions, progress=True)  # generates SAMPLES samples
+    else:
+        melody, sr = torchaudio.load('/home/intern-2023-02/melodytalk/assets/20230705-155518_3.wav')
+        wav = model.generate_continuation(melody[None].expand(SAMPLES, -1, -1), sr, descriptions, progress=True)
+    if DURATION_2 > 0:
+        # stitch past the 30-second cap: re-prompt with the last OVERLAP seconds
+        wav_ = wav[:, :, -OVERLAP * model.sample_rate:]
+        model.set_generation_params(duration=(OVERLAP + DURATION_2))
+        wav_2 = model.generate_continuation(wav_, model.sample_rate, descriptions, progress=True)[..., OVERLAP * model.sample_rate:]
+        wav = torch.cat([wav, wav_2], dim=-1)
+
+    for idx, one_wav in enumerate(wav):
+        # Will save under {idx}.wav, with loudness normalization at -14 dB LUFS.
+        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
+        audio_write(f'output/{current_time}_{idx}',
+                    one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+
+generate()
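
Aside on the new flow above: MusicGen generates at most 30 seconds per pass, so main.py now produces DURATION_1 seconds first, then re-prompts the model with the final OVERLAP seconds of audio and keeps only the non-overlapping tail of the continuation. A minimal standalone sketch of the same stitching idea, generalized to a loop (the generate_long name and the loop itself are illustrative, not part of this commit):

    import typing as tp

    import torch
    from melodytalk.audiocraft.models import MusicGen

    def generate_long(model: MusicGen, descriptions: tp.List[str],
                      total_s: int, overlap_s: int = 8) -> torch.Tensor:
        # First pass: up to the model's 30-second cap.
        first_s = min(total_s, 30)
        model.set_generation_params(duration=first_s)
        wav = model.generate(descriptions, progress=True)  # (batch, channels, samples)
        remaining_s = total_s - first_s
        while remaining_s > 0:
            chunk_s = min(remaining_s, 30 - overlap_s)
            # Re-prompt with the tail of what we have so far.
            prompt = wav[:, :, -overlap_s * model.sample_rate:]
            model.set_generation_params(duration=overlap_s + chunk_s)
            cont = model.generate_continuation(prompt, model.sample_rate,
                                               descriptions, progress=True)
            # The continuation starts with the prompt itself; drop that overlap.
            wav = torch.cat([wav, cont[..., overlap_s * model.sample_rate:]], dim=-1)
            remaining_s -= chunk_s
        return wav

The overlap exists so the continuation has audible context at the seam, which helps keep tempo and timbre consistent across chunks.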
7 changes: 5 additions & 2 deletions melodytalk/audiocraft/models/musicgen.py
@@ -196,7 +196,7 @@ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyTy

     def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                               descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
-                              progress: bool = False) -> torch.Tensor:
+                              progress: bool = False, high_pass_filter: bool = True) -> torch.Tensor:
         """Generate samples conditioned on audio prompts.

         Args:
@@ -205,6 +205,7 @@ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
             prompt_sample_rate (int): Sampling rate of the given audio waveforms.
             descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
             progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+            high_pass_filter (bool, optional): Whether to apply a high-pass filter to the prompt. Defaults to True.
         """
         if prompt.dim() == 2:
             prompt = prompt[None]
@@ -213,7 +214,7 @@ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
         prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
         if descriptions is None:
             descriptions = [None] * len(prompt)
-        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt, high_pass_filter=high_pass_filter)
         assert prompt_tokens is not None
         return self._generate_tokens(attributes, prompt_tokens, progress)

@@ -223,6 +224,7 @@ def _prepare_tokens_and_attributes(
             descriptions: tp.Sequence[tp.Optional[str]],
             prompt: tp.Optional[torch.Tensor],
             melody_wavs: tp.Optional[MelodyList] = None,
+            high_pass_filter: bool = True
     ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
         """Prepare model inputs.
@@ -231,6 +233,7 @@ def _prepare_tokens_and_attributes(
             prompt (torch.Tensor): A batch of waveforms used for continuation.
             melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
                 used as melody conditioning. Defaults to None.
+            high_pass_filter (bool, optional): Whether to apply a high-pass filter to the prompt. Defaults to True.
         """
         attributes = [
             ConditioningAttributes(text={'description': description})
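
The behavioral change here is the new high_pass_filter flag, which generate_continuation simply forwards to _prepare_tokens_and_attributes; how the flag is consumed downstream falls outside the shown hunks. A hypothetical call site, assuming this fork's import path and an illustrative file name:

    import torchaudio
    from melodytalk.audiocraft.models import MusicGen

    model = MusicGen.get_pretrained('melody', device='cuda')
    model.set_generation_params(duration=16)

    # Load an audio prompt to continue from (path illustrative).
    prompt, sr = torchaudio.load('output/seed_clip.wav')

    # Default is high_pass_filter=True; pass False to condition on the raw,
    # unfiltered prompt audio instead.
    wav = model.generate_continuation(prompt, sr, ['rock music loop'],
                                      progress=True, high_pass_filter=False)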
1 change: 0 additions & 1 deletion melodytalk/audiocraft/modules/conditioners.py
@@ -528,7 +528,6 @@ def _get_wav_embedding(self, wav):
         if wav.shape[-1] == 1:
             return self.chroma(wav)
         stems = self._get_filtered_wav(wav)
-        stems = wav
         chroma = self.chroma(stems)

         if self.match_len_on_eval:
4 changes: 2 additions & 2 deletions melodytalk/modules.py
@@ -29,8 +29,8 @@ class GlobalAttributes(object):
     # metadata
     key: str = None
     bpm: int = None
-    genre: str = None
-    mood: str = None
+    # genre: str = None
+    # mood: str = None
     instrument: str = None
     # text description cache
     description: str = None
22 changes: 22 additions & 0 deletions melodytalk/utils.py
@@ -6,6 +6,8 @@
 import openai
 import typing as tp
 import madmom
+import torchaudio
+from pydub import AudioSegment

 openai.api_key = os.getenv("OPENAI_API_KEY")

@@ -125,3 +127,23 @@ def chord_generation(description: str) -> tp.List:
     chord_list = [i.strip() for i in response.choices[0].text.split(' - ')]

     return chord_list
+
+def beat_tracking_with_clip(audio_path: str,
+                            output_path: str = None,
+                            offset: int = 0,
+                            bar: int = 4,
+                            beat_per_bar: int = 4):
+    # madmom's DBN downbeat tracker lives in features.downbeats and expects
+    # activations from RNNDownBeatProcessor rather than a raw file path.
+    act = madmom.features.downbeats.RNNDownBeatProcessor()(audio_path)
+    proc = madmom.features.downbeats.DBNDownBeatTrackingProcessor(beats_per_bar=beat_per_bar, fps=100)
+    beats = proc(act)  # array of (time, position-in-bar) pairs
+    # we cut the audio to only bar * beat_per_bar beats, and shift the first beat by offset
+    first_beat_time = beats[0][0]
+    last_beat_time = beats[bar * beat_per_bar][0]  # the beginning of the next bar
+    begin_time_with_offset = first_beat_time + offset
+    end_time = last_beat_time + offset
+    # cut the audio clip (pydub slices in milliseconds)
+    audio = AudioSegment.from_wav(audio_path)
+    audio_clip = audio[int(begin_time_with_offset * 1000): int(end_time * 1000)]
+    if output_path is None:
+        output_path = audio_path
+    audio_clip.export(output_path, format="wav")
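
A hypothetical usage of the new helper, trimming a generated loop to exactly four bars of 4/4 starting at the first tracked beat (file names illustrative):

    from melodytalk.utils import beat_tracking_with_clip

    beat_tracking_with_clip('output/20230713-120000_0.wav',
                            output_path='output/20230713-120000_0_4bars.wav',
                            offset=0, bar=4, beat_per_bar=4)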
