From 5d9291962be4f2ec3d4ec531f9bdf341c6a6f357 Mon Sep 17 00:00:00 2001 From: Yixiao Zhang Date: Tue, 4 Jul 2023 16:52:25 +0900 Subject: [PATCH] add global attributes --- melodytalk/audiocraft/modules/conditioners.py | 2 +- melodytalk/modules.py | 20 ++++++++++++++++++- melodytalk/utils.py | 1 + 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/melodytalk/audiocraft/modules/conditioners.py b/melodytalk/audiocraft/modules/conditioners.py index e57fdc9..7db09e5 100644 --- a/melodytalk/audiocraft/modules/conditioners.py +++ b/melodytalk/audiocraft/modules/conditioners.py @@ -527,7 +527,7 @@ def _get_wav_embedding(self, wav): # avoid 0-size tensors when we are working with null conds if wav.shape[-1] == 1: return self.chroma(wav) - # stems = self._get_filtered_wav(wav) + stems = self._get_filtered_wav(wav) stems = wav chroma = self.chroma(stems) diff --git a/melodytalk/modules.py b/melodytalk/modules.py index f57ce12..49db3cd 100644 --- a/melodytalk/modules.py +++ b/melodytalk/modules.py @@ -3,6 +3,8 @@ import torch from shutil import copyfile import torchaudio +from dataclasses import dataclass +import typing as tp # text2music from audiocraft.models import MusicGen @@ -12,7 +14,7 @@ from utils import * -DURATION = 12 +DURATION = 15 # Initialze common models musicgen_model = MusicGen.get_pretrained('large') @@ -21,6 +23,22 @@ musicgen_model_melody = MusicGen.get_pretrained('melody') musicgen_model_melody.set_generation_params(duration=DURATION) + +@dataclass +class GlobalAttributes(object): + # metadata + key: str = None + bpm: int = None + genre: str = None + mood: str = None + instrument: str = None + # text description cache + description: str = None + # tracks cache + mix: torch.Tensor = None + stems: tp.Dict[str, torch.Tensor] = None + + class Text2Music(object): def __init__(self, device): print("Initializing Text2Music") diff --git a/melodytalk/utils.py b/melodytalk/utils.py index 8f59379..96e6680 100644 --- a/melodytalk/utils.py +++ b/melodytalk/utils.py @@ -5,6 +5,7 @@ import os import openai import typing as tp +import madmom openai.api_key = os.getenv("OPENAI_API_KEY")