diff --git a/melodytalk/dependencies/vampnet/__init__.py b/melodytalk/dependencies/vampnet/__init__.py new file mode 100644 index 0000000..2a9dd07 --- /dev/null +++ b/melodytalk/dependencies/vampnet/__init__.py @@ -0,0 +1,6 @@ + +from . import modules +from . import scheduler +from .interface import Interface + +__version__ = "0.0.1" diff --git a/melodytalk/dependencies/vampnet/beats.py b/melodytalk/dependencies/vampnet/beats.py new file mode 100644 index 0000000..2b03a4e --- /dev/null +++ b/melodytalk/dependencies/vampnet/beats.py @@ -0,0 +1,250 @@ +import json +import logging +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import List +from typing import Tuple +from typing import Union + +import librosa +import torch +import numpy as np +from audiotools import AudioSignal + + +logging.basicConfig(level=logging.INFO) + +################### +# beat sync utils # +################### + +AGGREGATOR_REGISTRY = { + "mean": np.mean, + "median": np.median, + "max": np.max, + "min": np.min, +} + + +def list_aggregators() -> list: + return list(AGGREGATOR_REGISTRY.keys()) + + +@dataclass +class TimeSegment: + start: float + end: float + + @property + def duration(self): + return self.end - self.start + + def __str__(self) -> str: + return f"{self.start} - {self.end}" + + def find_overlapping_segment( + self, segments: List["TimeSegment"] + ) -> Union["TimeSegment", None]: + """Find the first segment that overlaps with this segment, or None if no segment overlaps""" + for s in segments: + if s.start <= self.start and s.end >= self.end: + return s + return None + + +def mkdir(path: Union[Path, str]) -> Path: + p = Path(path) + p.mkdir(parents=True, exist_ok=True) + return p + + + +################### +# beat data # +################### +@dataclass +class BeatSegment(TimeSegment): + downbeat: bool = False # if there's a downbeat on the start_time + + +class Beats: + def __init__(self, beat_times, downbeat_times): + if isinstance(beat_times, np.ndarray): + beat_times = beat_times.tolist() + if isinstance(downbeat_times, np.ndarray): + downbeat_times = downbeat_times.tolist() + self._beat_times = beat_times + self._downbeat_times = downbeat_times + self._use_downbeats = False + + def use_downbeats(self, use_downbeats: bool = True): + """use downbeats instead of beats when calling beat_times""" + self._use_downbeats = use_downbeats + + def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]: + """ + segments a song into time segments corresponding to beats. + the first segment starts at 0 and ends at the first beat time. + the last segment starts at the last beat time and ends at the end of the song. 
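+        each segment's downbeat flag is True when the segment starts on a downbeat.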
+ """ + beat_times = self._beat_times.copy() + downbeat_times = self._downbeat_times + beat_times.insert(0, 0) + beat_times.append(signal.signal_duration) + + downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[ + 1 + ] + is_downbeat = [ + True if i in downbeat_ids else False for i in range(len(beat_times)) + ] + segments = [ + BeatSegment(start_time, end_time, downbeat) + for start_time, end_time, downbeat in zip( + beat_times[:-1], beat_times[1:], is_downbeat + ) + ] + return segments + + def get_beats(self) -> np.ndarray: + """returns an array of beat times, in seconds + if downbeats is True, returns an array of downbeat times, in seconds + """ + return np.array( + self._downbeat_times if self._use_downbeats else self._beat_times + ) + + @property + def beat_times(self) -> np.ndarray: + """return beat times""" + return np.array(self._beat_times) + + @property + def downbeat_times(self) -> np.ndarray: + """return downbeat times""" + return np.array(self._downbeat_times) + + def beat_times_to_feature_frames( + self, signal: AudioSignal, features: np.ndarray + ) -> np.ndarray: + """convert beat times to frames, given an array of time-varying features""" + beat_times = self.get_beats() + beat_frames = ( + beat_times * signal.sample_rate / signal.signal_length * features.shape[-1] + ).astype(np.int64) + return beat_frames + + def sync_features( + self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median" + ) -> np.ndarray: + """sync features to beats""" + if aggregate not in AGGREGATOR_REGISTRY: + raise ValueError(f"unknown aggregation method {aggregate}") + + return librosa.util.sync( + features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate] + ) + + def to_json(self) -> dict: + """return beats and downbeats as json""" + return { + "beats": self._beat_times, + "downbeats": self._downbeat_times, + "use_downbeats": self._use_downbeats, + } + + @classmethod + def from_dict(cls, data: dict): + """load beats and downbeats from json""" + inst = cls(data["beats"], data["downbeats"]) + inst.use_downbeats(data["use_downbeats"]) + return inst + + def save(self, output_dir: Path): + """save beats and downbeats to json""" + mkdir(output_dir) + with open(output_dir / "beats.json", "w") as f: + json.dump(self.to_json(), f) + + @classmethod + def load(cls, input_dir: Path): + """load beats and downbeats from json""" + beats_file = Path(input_dir) / "beats.json" + with open(beats_file, "r") as f: + data = json.load(f) + return cls.from_dict(data) + + +################### +# beat tracking # +################### + + +class BeatTracker: + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """extract beats from an audio signal""" + raise NotImplementedError + + def __call__(self, signal: AudioSignal) -> Beats: + """extract beats from an audio signal + NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio, + it is discarded. This is to avoid empty bins with no beat synced features in the first beat. 
+ Args: + signal (AudioSignal): signal to beat track + Returns: + Tuple[np.ndarray, np.ndarray]: beats and downbeats + """ + beats, downbeats = self.extract_beats(signal) + return Beats(beats, downbeats) + + +class WaveBeat(BeatTracker): + def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"): + from wavebeat.dstcn import dsTCNModel + + model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device)) + model.eval() + + self.device = device + self.model = model + + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """returns beat and downbeat times, in seconds""" + # extract beats + beats, downbeats = self.model.predict_beats_from_array( + audio=signal.audio_data.squeeze(0), + sr=signal.sample_rate, + use_gpu=self.device != "cpu", + ) + + return beats, downbeats + + +class MadmomBeats(BeatTracker): + def __init__(self): + raise NotImplementedError + + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """returns beat and downbeat times, in seconds""" + pass + + +BEAT_TRACKER_REGISTRY = { + "wavebeat": WaveBeat, + "madmom": MadmomBeats, +} + + +def list_beat_trackers() -> list: + return list(BEAT_TRACKER_REGISTRY.keys()) + + +def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker: + if beat_tracker not in BEAT_TRACKER_REGISTRY: + raise ValueError( + f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}" + ) + + return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs) \ No newline at end of file diff --git a/melodytalk/dependencies/vampnet/interface.py b/melodytalk/dependencies/vampnet/interface.py new file mode 100644 index 0000000..0826c27 --- /dev/null +++ b/melodytalk/dependencies/vampnet/interface.py @@ -0,0 +1,423 @@ +import os +from pathlib import Path +import math + +import torch +import numpy as np +from audiotools import AudioSignal +import tqdm + +from .modules.transformer import VampNet +from .beats import WaveBeat +from .mask import * + +# from dac.model.dac import DAC +from lac.model.lac import LAC as DAC + + +def signal_concat( + audio_signals: list, +): + audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1) + + return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate) + + +def _load_model( + ckpt: str, + lora_ckpt: str = None, + device: str = "cpu", + chunk_size_s: int = 10, +): + # we need to set strict to False if the model has lora weights to add later + model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False) + + # load lora weights if needed + if lora_ckpt is not None: + if not Path(lora_ckpt).exists(): + should_cont = input( + f"lora checkpoint {lora_ckpt} does not exist. continue? 
(y/n) " + ) + if should_cont != "y": + raise Exception("aborting") + else: + model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False) + + model.to(device) + model.eval() + model.chunk_size_s = chunk_size_s + return model + + + +class Interface(torch.nn.Module): + def __init__( + self, + coarse_ckpt: str = None, + coarse_lora_ckpt: str = None, + coarse2fine_ckpt: str = None, + coarse2fine_lora_ckpt: str = None, + codec_ckpt: str = None, + wavebeat_ckpt: str = None, + device: str = "cpu", + coarse_chunk_size_s: int = 10, + coarse2fine_chunk_size_s: int = 3, + ): + super().__init__() + assert codec_ckpt is not None, "must provide a codec checkpoint" + self.codec = DAC.load(Path(codec_ckpt)) + self.codec.eval() + self.codec.to(device) + + assert coarse_ckpt is not None, "must provide a coarse checkpoint" + self.coarse = _load_model( + ckpt=coarse_ckpt, + lora_ckpt=coarse_lora_ckpt, + device=device, + chunk_size_s=coarse_chunk_size_s, + ) + + # check if we have a coarse2fine ckpt + if coarse2fine_ckpt is not None: + self.c2f = _load_model( + ckpt=coarse2fine_ckpt, + lora_ckpt=coarse2fine_lora_ckpt, + device=device, + chunk_size_s=coarse2fine_chunk_size_s, + ) + else: + self.c2f = None + + if wavebeat_ckpt is not None: + print(f"loading wavebeat from {wavebeat_ckpt}") + self.beat_tracker = WaveBeat(wavebeat_ckpt) + self.beat_tracker.model.to(device) + else: + self.beat_tracker = None + + self.device = device + + def lora_load( + self, + coarse_ckpt: str = None, + c2f_ckpt: str = None, + full_ckpts: bool = False, + ): + if full_ckpts: + if coarse_ckpt is not None: + self.coarse = _load_model( + ckpt=coarse_ckpt, + device=self.device, + chunk_size_s=self.coarse.chunk_size_s, + ) + if c2f_ckpt is not None: + self.c2f = _load_model( + ckpt=c2f_ckpt, + device=self.device, + chunk_size_s=self.c2f.chunk_size_s, + ) + else: + if coarse_ckpt is not None: + self.coarse.to("cpu") + state_dict = torch.load(coarse_ckpt, map_location="cpu") + + self.coarse.load_state_dict(state_dict, strict=False) + self.coarse.to(self.device) + if c2f_ckpt is not None: + self.c2f.to("cpu") + state_dict = torch.load(c2f_ckpt, map_location="cpu") + + self.c2f.load_state_dict(state_dict, strict=False) + self.c2f.to(self.device) + + + def s2t(self, seconds: float): + """seconds to tokens""" + if isinstance(seconds, np.ndarray): + return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length) + else: + return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length) + + def s2t2s(self, seconds: float): + """seconds to tokens to seconds""" + return self.t2s(self.s2t(seconds)) + + def t2s(self, tokens: int): + """tokens to seconds""" + return tokens * self.codec.hop_length / self.codec.sample_rate + + def to(self, device): + self.device = device + self.coarse.to(device) + self.codec.to(device) + + if self.c2f is not None: + self.c2f.to(device) + + if self.beat_tracker is not None: + self.beat_tracker.model.to(device) + return self + + def to_signal(self, z: torch.Tensor): + return self.coarse.to_signal(z, self.codec) + + def preprocess(self, signal: AudioSignal): + signal = ( + signal.clone() + .resample(self.codec.sample_rate) + .to_mono() + .normalize(-24) + .ensure_max_of_audio(1.0) + ) + return signal + + @torch.inference_mode() + def encode(self, signal: AudioSignal): + signal = self.preprocess(signal).to(self.device) + z = self.codec.encode(signal.samples, signal.sample_rate)["codes"] + return z + + def snap_to_beats( + self, + signal: AudioSignal + ): + assert hasattr(self, 
"beat_tracker"), "No beat tracker loaded" + beats, downbeats = self.beat_tracker.extract_beats(signal) + + # trim the signa around the first beat time + samples_begin = int(beats[0] * signal.sample_rate ) + samples_end = int(beats[-1] * signal.sample_rate) + print(beats[0]) + signal = signal.clone().trim(samples_begin, signal.length - samples_end) + + return signal + + def make_beat_mask(self, + signal: AudioSignal, + before_beat_s: float = 0.0, + after_beat_s: float = 0.02, + mask_downbeats: bool = True, + mask_upbeats: bool = True, + downbeat_downsample_factor: int = None, + beat_downsample_factor: int = None, + dropout: float = 0.0, + invert: bool = True, + ): + """make a beat synced mask. that is, make a mask that + places 1s at and around the beat, and 0s everywhere else. + """ + assert self.beat_tracker is not None, "No beat tracker loaded" + + # get the beat times + beats, downbeats = self.beat_tracker.extract_beats(signal) + + # get the beat indices in z + beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats) + + # remove downbeats from beats + beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))] + beats_z = beats_z.tolist() + downbeats_z = downbeats_z.tolist() + + # make the mask + seq_len = self.s2t(signal.duration) + mask = torch.zeros(seq_len, device=self.device) + + mask_b4 = self.s2t(before_beat_s) + mask_after = self.s2t(after_beat_s) + + if beat_downsample_factor is not None: + if beat_downsample_factor < 1: + raise ValueError("mask_beat_downsample_factor must be >= 1 or None") + else: + beat_downsample_factor = 1 + + if downbeat_downsample_factor is not None: + if downbeat_downsample_factor < 1: + raise ValueError("mask_beat_downsample_factor must be >= 1 or None") + else: + downbeat_downsample_factor = 1 + + beats_z = beats_z[::beat_downsample_factor] + downbeats_z = downbeats_z[::downbeat_downsample_factor] + print(f"beats_z: {len(beats_z)}") + print(f"downbeats_z: {len(downbeats_z)}") + + if mask_upbeats: + for beat_idx in beats_z: + _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after) + num_steps = mask[_slice[0]:_slice[1]].shape[0] + _m = torch.ones(num_steps, device=self.device) + _m_mask = torch.bernoulli(_m * (1 - dropout)) + _m = _m * _m_mask.long() + + mask[_slice[0]:_slice[1]] = _m + + if mask_downbeats: + for downbeat_idx in downbeats_z: + _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after) + num_steps = mask[_slice[0]:_slice[1]].shape[0] + _m = torch.ones(num_steps, device=self.device) + _m_mask = torch.bernoulli(_m * (1 - dropout)) + _m = _m * _m_mask.long() + + mask[_slice[0]:_slice[1]] = _m + + mask = mask.clamp(0, 1) + if invert: + mask = 1 - mask + + mask = mask[None, None, :].bool().long() + if self.c2f is not None: + mask = mask.repeat(1, self.c2f.n_codebooks, 1) + else: + mask = mask.repeat(1, self.coarse.n_codebooks, 1) + return mask + + def coarse_to_fine( + self, + z: torch.Tensor, + mask: torch.Tensor = None, + **kwargs + ): + assert self.c2f is not None, "No coarse2fine model loaded" + length = z.shape[-1] + chunk_len = self.s2t(self.c2f.chunk_size_s) + n_chunks = math.ceil(z.shape[-1] / chunk_len) + + # zero pad to chunk_len + if length % chunk_len != 0: + pad_len = chunk_len - (length % chunk_len) + z = torch.nn.functional.pad(z, (0, pad_len)) + mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None + + n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1] + if n_codebooks_to_append > 0: + z = torch.cat([ + z, + torch.zeros(z.shape[0], 
n_codebooks_to_append, z.shape[-1]).long().to(self.device) + ], dim=1) + + # set the mask to 0 for all conditioning codebooks + if mask is not None: + mask = mask.clone() + mask[:, :self.c2f.n_conditioning_codebooks, :] = 0 + + fine_z = [] + for i in range(n_chunks): + chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len] + mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None + + chunk = self.c2f.generate( + codec=self.codec, + time_steps=chunk_len, + start_tokens=chunk, + return_signal=False, + mask=mask_chunk, + **kwargs + ) + fine_z.append(chunk) + + fine_z = torch.cat(fine_z, dim=-1) + return fine_z[:, :, :length].clone() + + def coarse_vamp( + self, + z, + mask, + return_mask=False, + gen_fn=None, + **kwargs + ): + # coarse z + cz = z[:, : self.coarse.n_codebooks, :].clone() + assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}" + + mask = mask[:, : self.coarse.n_codebooks, :] + + cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token) + cz_masked = cz_masked[:, : self.coarse.n_codebooks, :] + + gen_fn = gen_fn or self.coarse.generate + c_vamp = gen_fn( + codec=self.codec, + time_steps=cz.shape[-1], + start_tokens=cz, + mask=mask, + return_signal=False, + **kwargs + ) + + # add the fine codes back in + c_vamp = torch.cat( + [c_vamp, z[:, self.coarse.n_codebooks :, :]], + dim=1 + ) + + if return_mask: + return c_vamp, cz_masked + + return c_vamp + + +if __name__ == "__main__": + import audiotools as at + import logging + logger = logging.getLogger() + logger.setLevel(logging.INFO) + torch.set_printoptions(threshold=10000) + at.util.seed(42) + + interface = Interface( + coarse_ckpt="./models/vampnet/coarse.pth", + coarse2fine_ckpt="./models/vampnet/c2f.pth", + codec_ckpt="./models/vampnet/codec.pth", + device="cuda", + wavebeat_ckpt="./models/wavebeat.pth" + ) + + + sig = at.AudioSignal('assets/example.wav') + + z = interface.encode(sig) + breakpoint() + + # mask = linear_random(z, 1.0) + # mask = mask_and( + # mask, periodic_mask( + # z, + # 32, + # 1, + # random_roll=True + # ) + # ) + + # mask = interface.make_beat_mask( + # sig, 0.0, 0.075 + # ) + # mask = dropout(mask, 0.0) + # mask = codebook_unmask(mask, 0) + + mask = inpaint(z, n_prefix=100, n_suffix=100) + + zv, mask_z = interface.coarse_vamp( + z, + mask=mask, + sampling_steps=36, + temperature=8.0, + return_mask=True, + gen_fn=interface.coarse.generate + ) + + + use_coarse2fine = True + if use_coarse2fine: + zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask) + breakpoint() + + mask = interface.to_signal(mask_z).cpu() + + sig = interface.to_signal(zv).cpu() + print("done") + + \ No newline at end of file diff --git a/melodytalk/dependencies/vampnet/mask.py b/melodytalk/dependencies/vampnet/mask.py new file mode 100644 index 0000000..1302fd1 --- /dev/null +++ b/melodytalk/dependencies/vampnet/mask.py @@ -0,0 +1,219 @@ +from typing import Optional + +import torch +from audiotools import AudioSignal + +from .util import scalar_to_batch_tensor + +def _gamma(r): + return (r * torch.pi / 2).cos().clamp(1e-10, 1.0) + +def _invgamma(y): + if not torch.is_tensor(y): + y = torch.tensor(y)[None] + return 2 * y.acos() / torch.pi + +def full_mask(x: torch.Tensor): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + return torch.ones_like(x).long() + +def empty_mask(x: torch.Tensor): + assert x.ndim == 3, "x must be 
(batch, n_codebooks, seq)" + return torch.zeros_like(x).long() + +def apply_mask( + x: torch.Tensor, + mask: torch.Tensor, + mask_token: int + ): + assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}" + assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}" + assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}" + assert ~torch.any(mask > 1), "mask must be binary" + assert ~torch.any(mask < 0), "mask must be binary" + + fill_x = torch.full_like(x, mask_token) + x = x * (1 - mask) + fill_x * mask + + return x, mask + +def random( + x: torch.Tensor, + r: torch.Tensor +): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + if not isinstance(r, torch.Tensor): + r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device) + + r = _gamma(r)[:, None, None] + probs = torch.ones_like(x) * r + + mask = torch.bernoulli(probs) + mask = mask.round().long() + + return mask + +def linear_random( + x: torch.Tensor, + r: torch.Tensor, +): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + if not isinstance(r, torch.Tensor): + r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float() + + probs = torch.ones_like(x).to(x.device).float() + # expand to batch and codebook dims + probs = probs.expand(x.shape[0], x.shape[1], -1) + probs = probs * r + + mask = torch.bernoulli(probs) + mask = mask.round().long() + + return mask + +def inpaint(x: torch.Tensor, + n_prefix, + n_suffix, +): + assert n_prefix is not None + assert n_suffix is not None + + mask = full_mask(x) + + # if we have a prefix or suffix, set their mask prob to 0 + if n_prefix > 0: + if not isinstance(n_prefix, torch.Tensor): + n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device) + for i, n in enumerate(n_prefix): + if n > 0: + mask[i, :, :n] = 0.0 + if n_suffix > 0: + if not isinstance(n_suffix, torch.Tensor): + n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device) + for i, n in enumerate(n_suffix): + if n > 0: + mask[i, :, -n:] = 0.0 + + + return mask + +def periodic_mask(x: torch.Tensor, + period: int, width: int = 1, + random_roll=False, + ): + mask = full_mask(x) + if period == 0: + return mask + + if not isinstance(period, torch.Tensor): + period = scalar_to_batch_tensor(period, x.shape[0]) + for i, factor in enumerate(period): + if factor == 0: + continue + for j in range(mask.shape[-1]): + if j % factor == 0: + # figure out how wide the mask should be + j_start = max(0, j - width // 2 ) + j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1 + # flip a coin for each position in the mask + j_mask = torch.bernoulli(torch.ones(j_end - j_start)) + assert torch.all(j_mask == 1) + j_fill = torch.ones_like(j_mask) * (1 - j_mask) + assert torch.all(j_fill == 0) + # fill + mask[i, :, j_start:j_end] = j_fill + if random_roll: + # add a random offset to the mask + offset = torch.randint(0, period[0], (1,)) + mask = torch.roll(mask, offset.item(), dims=-1) + + return mask + +def codebook_unmask( + mask: torch.Tensor, + n_conditioning_codebooks: int +): + if n_conditioning_codebooks == None: + return mask + # if we have any conditioning codebooks, set their mask to 0 + mask = mask.clone() + mask[:, :n_conditioning_codebooks, :] = 0 + return mask + +def mask_and( + mask1: torch.Tensor, + mask2: torch.Tensor +): + assert mask1.shape == mask2.shape, "masks must be same shape" + return torch.min(mask1, mask2) + +def dropout( + mask: torch.Tensor, + p: float, +): + assert 0 <= p <= 1, "p must be between 0 and 
1" + assert mask.max() <= 1, "mask must be binary" + assert mask.min() >= 0, "mask must be binary" + mask = (~mask.bool()).float() + mask = torch.bernoulli(mask * (1 - p)) + mask = ~mask.round().bool() + return mask.long() + +def mask_or( + mask1: torch.Tensor, + mask2: torch.Tensor +): + assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}" + assert mask1.max() <= 1, "mask1 must be binary" + assert mask2.max() <= 1, "mask2 must be binary" + assert mask1.min() >= 0, "mask1 must be binary" + assert mask2.min() >= 0, "mask2 must be binary" + return (mask1 + mask2).clamp(0, 1) + +def time_stretch_mask( + x: torch.Tensor, + stretch_factor: int, +): + assert stretch_factor >= 1, "stretch factor must be >= 1" + c_seq_len = x.shape[-1] + x = x.repeat_interleave(stretch_factor, dim=-1) + + # trim cz to the original length + x = x[:, :, :c_seq_len] + + mask = periodic_mask(x, stretch_factor, width=1) + return mask + +def onset_mask( + sig: AudioSignal, + z: torch.Tensor, + interface, + width: int = 1 +): + import librosa + + onset_indices = librosa.onset.onset_detect( + y=sig.clone().to_mono().samples.cpu().numpy()[0, 0], + sr=sig.sample_rate, + hop_length=interface.codec.hop_length, + backtrack=True, + ) + + # create a mask, set onset + mask = torch.ones_like(z) + n_timesteps = z.shape[-1] + + for onset_index in onset_indices: + onset_index = min(onset_index, n_timesteps - 1) + onset_index = max(onset_index, 0) + mask[:, :, onset_index - width:onset_index + width] = 0.0 + + print(mask) + + return mask + + + +if __name__ == "__main__": + torch.set_printoptions(threshold=10000) + diff --git a/melodytalk/dependencies/vampnet/modules/__init__.py b/melodytalk/dependencies/vampnet/modules/__init__.py new file mode 100644 index 0000000..3f4c8c2 --- /dev/null +++ b/melodytalk/dependencies/vampnet/modules/__init__.py @@ -0,0 +1,6 @@ +import audiotools + +audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"] +audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"] + +from .transformer import VampNet \ No newline at end of file diff --git a/melodytalk/dependencies/vampnet/modules/activations.py b/melodytalk/dependencies/vampnet/modules/activations.py new file mode 100644 index 0000000..c013c63 --- /dev/null +++ b/melodytalk/dependencies/vampnet/modules/activations.py @@ -0,0 +1,55 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo + (identical to OpenAI GPT). 
Also see the Gaussian Error Linear Units + paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + +class GatedGELU(nn.Module): + def __init__(self): + super().__init__() + self.gelu = NewGELU() + + def forward(self, x, dim: int = -1): + p1, p2 = x.chunk(2, dim=dim) + return p1 * self.gelu(p2) + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(channels)) + + def forward(self, x): + return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2) + +def get_activation(name: str = "relu"): + if name == "relu": + return nn.ReLU + elif name == "gelu": + return NewGELU + elif name == "geglu": + return GatedGELU + elif name == "snake": + return Snake1d + else: + raise ValueError(f"Unrecognized activation {name}") \ No newline at end of file diff --git a/melodytalk/dependencies/vampnet/modules/layers.py b/melodytalk/dependencies/vampnet/modules/layers.py new file mode 100644 index 0000000..0c7df97 --- /dev/null +++ b/melodytalk/dependencies/vampnet/modules/layers.py @@ -0,0 +1,164 @@ +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +def num_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def recurse_children(module, fn): + for child in module.children(): + if isinstance(child, nn.ModuleList): + for c in child: + yield recurse_children(c, fn) + if isinstance(child, nn.ModuleDict): + for c in child.values(): + yield recurse_children(c, fn) + + yield recurse_children(child, fn) + yield fn(child) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class SequentialWithFiLM(nn.Module): + """ + handy wrapper for nn.Sequential that allows FiLM layers to be + inserted in between other layers. 
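+    at call time, any child layer that contains a FiLM module is called as layer(x, cond);
+    all other layers are called as layer(x).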
+ """ + + def __init__(self, *layers): + super().__init__() + self.layers = nn.ModuleList(layers) + + @staticmethod + def has_film(module): + mod_has_film = any( + [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))] + ) + return mod_has_film + + def forward(self, x, cond): + for layer in self.layers: + if self.has_film(layer): + x = layer(x, cond) + else: + x = layer(x) + return x + + +class FiLM(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if input_dim > 0: + self.beta = nn.Linear(input_dim, output_dim) + self.gamma = nn.Linear(input_dim, output_dim) + + def forward(self, x, r): + if self.input_dim == 0: + return x + else: + beta, gamma = self.beta(r), self.gamma(r) + beta, gamma = ( + beta.view(x.size(0), self.output_dim, 1), + gamma.view(x.size(0), self.output_dim, 1), + ) + x = x * (gamma + 1) + beta + return x + + +class CodebookEmbedding(nn.Module): + def __init__( + self, + vocab_size: int, + latent_dim: int, + n_codebooks: int, + emb_dim: int, + special_tokens: Optional[Tuple[str]] = None, + ): + super().__init__() + self.n_codebooks = n_codebooks + self.emb_dim = emb_dim + self.latent_dim = latent_dim + self.vocab_size = vocab_size + + if special_tokens is not None: + for tkn in special_tokens: + self.special = nn.ParameterDict( + { + tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim)) + for tkn in special_tokens + } + ) + self.special_idxs = { + tkn: i + vocab_size for i, tkn in enumerate(special_tokens) + } + + self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1) + + def from_codes(self, codes: torch.Tensor, codec): + """ + get a sequence of continuous embeddings from a sequence of discrete codes. + unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens + necessary for the language model, like . 
+ """ + n_codebooks = codes.shape[1] + latent = [] + for i in range(n_codebooks): + c = codes[:, i, :] + + lookup_table = codec.quantizer.quantizers[i].codebook.weight + if hasattr(self, "special"): + special_lookup = torch.cat( + [self.special[tkn][i : i + 1] for tkn in self.special], dim=0 + ) + lookup_table = torch.cat([lookup_table, special_lookup], dim=0) + + l = F.embedding(c, lookup_table).transpose(1, 2) + latent.append(l) + + latent = torch.cat(latent, dim=1) + return latent + + def forward(self, latents: torch.Tensor): + """ + project a sequence of latents to a sequence of embeddings + """ + x = self.out_proj(latents) + return x + diff --git a/melodytalk/dependencies/vampnet/modules/transformer.py b/melodytalk/dependencies/vampnet/modules/transformer.py new file mode 100644 index 0000000..afe0c7d --- /dev/null +++ b/melodytalk/dependencies/vampnet/modules/transformer.py @@ -0,0 +1,942 @@ +import math +import logging +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import loralib as lora +import audiotools as at + +from .activations import get_activation +from .layers import CodebookEmbedding +from .layers import FiLM +from .layers import SequentialWithFiLM +from .layers import WNConv1d +from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten +from ..mask import _gamma + +LORA_R = 8 + +# def log(t, eps=1e-20): +# return torch.log(t + eps) + + +def gumbel_noise_like(t): + noise = torch.zeros_like(t).uniform_(1e-20, 1) + return -torch.log(-torch.log(noise)) + + +def gumbel_sample(t, temperature=1.0, dim=-1): + return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.var_eps = eps + + def forward(self, x): + """Returns root mean square normalized version of input `x` + # T5 uses a layer_norm which only scales and doesn't shift, which is also known + # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467 + # thus varience is calculated w/o mean and there is no bias + Parameters + ---------- + x : Tensor[B x T x D] + Returns + ------- + Tensor[B x T x D] + """ + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.var_eps) + + return self.weight * x + + +class FeedForward(nn.Module): + def __init__( + self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu" + ): + super().__init__() + factor = 2 if activation == "geglu" else 1 + self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R) + self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R) + self.drop = nn.Dropout(dropout) + self.act = get_activation(activation)() + + def forward(self, x): + """Computes position-wise feed-forward layer + Parameters + ---------- + x : Tensor[B x T x D] + Returns + ------- + Tensor[B x T x D] + """ + x = self.w_1(x) + x = self.act(x) + x = self.drop(x) + x = self.w_2(x) + return x + + +class MultiHeadRelativeAttention(nn.Module): + def __init__( + self, + n_head: int = 8, + d_model: int = 512, + dropout: float = 0.1, + bidirectional: bool = True, + has_relative_attention_bias: bool = True, + attention_num_buckets: int = 32, + attention_max_distance: int = 128, + ): + super().__init__() + d_head = d_model // n_head + self.n_head = n_head + self.d_head = d_head + self.bidirectional = bidirectional + 
self.has_relative_attention_bias = has_relative_attention_bias + self.attention_num_buckets = attention_num_buckets + self.attention_max_distance = attention_max_distance + + # Create linear query, key, value projections + self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + self.w_ks = nn.Linear(d_model, d_model, bias=False) + self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + + # Create linear final output projection + self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + + # Dropout for attention output weights + self.dropout = nn.Dropout(dropout) + + # Create relative positional embeddings (if turned on) + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head) + + def _relative_position_bucket(self, relative_position): + """Converts unbounded relative position into bounded set of buckets + with half "exact" buckets (1 position = 1 bucket) and half "log-spaced" + buckets + Parameters + ---------- + relative_position : Tensor[T_q x T_kv] + Relative positions between queries and key_value items + Returns + ------- + Tensor[T_q x T_kv] + Input relative positions converted into buckets + """ + relative_buckets = 0 + num_buckets = self.attention_num_buckets + max_distance = self.attention_max_distance + + # Convert relative position for (-inf, inf) to [0, inf] + # Negative relative positions correspond to past + # Positive relative positions correspond to future + if self.bidirectional: + # use half buckets for each side (past / future) + num_buckets //= 2 + + # Shift the position positions by `num_buckets` to wrap around + # negative positions + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + # If not bidirectional, ignore positive positions and wrap + # negative positions to positive + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + + # Allocate half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to `max_distance` + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + + # Clip the max relative position to `num_buckets - 1` + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_large, num_buckets - 1), + ) + + # Choose relative buckets based on small or large positions + relative_buckets += torch.where( + is_small, relative_position, relative_postion_if_large + ) + + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Computes a position bias scalar for each index in query_length x key_length + Parameters + ---------- + query_length : int + key_length : int + Returns + ------- + Tensor[heads x 1 x T_q x T_kv] + Position bias to be applied on attention logits + """ + + query_position = torch.arange(query_length, dtype=torch.long)[:, None] + key_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = key_position - query_position + + # Convert relative position to buckets + relative_position_bucket = self._relative_position_bucket(relative_position) + relative_position_bucket = relative_position_bucket.to( + self.relative_attention_bias.weight.device + ) + + # 
Index attention bias values + values = self.relative_attention_bias(relative_position_bucket) + values = rearrange(values, "q k h -> h 1 q k") + + return values + + def forward(self, q, k, v, mask=None, position_bias=None): + """Computes attention over (keys, values) for every timestep in query + Parameters + ---------- + q : Tensor[B x T_q x d_model] + Query vectors + k : Tensor[B x T_kv x d_model] + Key vectors to compute attention over + v : Tensor[B x T_kv x d_model] + Value vectors corresponding to the keys + mask : Tensor[B x T_q x T_kv], optional + position_bias: Tensor[head x 1 x T_q x T_kv] + Returns + ------- + Tensor[B x T_q x d_model] + Outputs after attending (key, value) using queries + """ + # Compute query, key, value projections + q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head) + k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head) + v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head) + + # Compute attention matrix + attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1]) + + # Add relative position bias to attention scores + if position_bias is None: + if self.has_relative_attention_bias: + position_bias = self.compute_bias(q.size(-2), k.size(-2)) + else: + position_bias = torch.zeros_like(attn) + attn += position_bias + + # Apply mask to attention scores to prevent looking up invalid locations + if mask is not None: + attn = attn.masked_fill(mask[None] == 0, -1e9) + + # Normalize attention scores and add dropout + attn = torch.softmax(attn, dim=3) + attn = self.dropout(attn) + + # Compute attended outputs (product of attention matrix and values) + output = torch.einsum("hblt,hbtv->hblv", [attn, v]) + output = rearrange(output, "head b l v -> b l (head v)") + output = self.fc(output) + + return output, position_bias + + +class TransformerLayer(nn.Module): + def __init__( + self, + d_model: int = 512, + d_cond: int = 64, + n_heads: int = 8, + bidirectional: bool = True, + is_decoder: bool = False, + has_relative_attention_bias: bool = False, + flash_attn: bool = False, + dropout: float = 0.1, + ): + super().__init__() + # Store args + self.is_decoder = is_decoder + + # Create self-attention layer + self.norm_1 = RMSNorm(d_model) + self.film_1 = FiLM(d_cond, d_model) + self.flash_attn = flash_attn + + if flash_attn: + from flash_attn.flash_attention import FlashMHA + self.self_attn = FlashMHA( + embed_dim=d_model, + num_heads=n_heads, + attention_dropout=dropout, + causal=False, + ) + else: + self.self_attn = MultiHeadRelativeAttention( + n_heads, d_model, dropout, bidirectional, has_relative_attention_bias + ) + + # (Optional) Create cross-attention layer + if is_decoder: + self.norm_2 = RMSNorm(d_model) + self.film_2 = FiLM(d_cond, d_model) + self.cross_attn = MultiHeadRelativeAttention( + n_heads, + d_model, + dropout, + bidirectional=True, + has_relative_attention_bias=False, + ) + + # Create last feed-forward layer + self.norm_3 = RMSNorm(d_model) + self.film_3 = FiLM(d_cond, d_model) + self.feed_forward = FeedForward(d_model=d_model, dropout=dropout) + + # Create dropout + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x, + x_mask, + cond, + src=None, + src_mask=None, + position_bias=None, + encoder_decoder_position_bias=None, + ): + """Computes one transformer layer consisting of self attention, (op) cross attention + and feedforward layer + Parameters + ---------- + x : Tensor[B x T_q x D] + x_mask : Tensor[B x T_q] + src : Tensor[B x T_kv x D], optional + 
src_mask : Tensor[B x T_kv x D], optional + position_bias : Tensor[heads x B x T_q x T_q], optional + Relative position bias for self attention layer + encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional + Relative position bias for cross attention layer + Returns + ------- + Tensor[B x T_q x D] + """ + y = self.norm_1(x) + y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1) + if self.flash_attn: + with torch.autocast(y.device.type, dtype=torch.bfloat16): + y = self.self_attn(y)[0] + else: + y, position_bias = self.self_attn(y, y, y, x_mask, position_bias) + x = x + self.dropout(y) + + if self.is_decoder: + y = self.norm_2(x) + y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1) + y, encoder_decoder_position_bias = self.cross_attn( + y, src, src, src_mask, encoder_decoder_position_bias + ) + x = x + self.dropout(y) + + y = self.norm_3(x) + y = self.film_3( + y.permute( + 0, + 2, + 1, + ), + cond, + ).permute(0, 2, 1) + y = self.feed_forward(y) + x = x + self.dropout(y) + + return x, position_bias, encoder_decoder_position_bias + +def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0): + x = np.linspace(0, 1, n_steps) + a = (0.5 - min_temp) / (max_temp - min_temp) + + x = (x * 12) - 6 + x0 = np.log((1 / a - 1) + 1e-5) / k + y = (1 / (1 + np.exp(- k *(x-x0))))[::-1] + + return y + +class TransformerStack(nn.Module): + def __init__( + self, + d_model: int = 512, + d_cond: int = 64, + n_heads: int = 8, + n_layers: int = 8, + last_layer: bool = True, + bidirectional: bool = True, + flash_attn: bool = False, + is_decoder: bool = False, + dropout: float = 0.1, + ): + super().__init__() + # Store args + self.bidirectional = bidirectional + self.is_decoder = is_decoder + + # Create transformer layers + # In T5, relative attention bias is shared by all layers in the stack + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model, + d_cond, + n_heads, + bidirectional, + is_decoder, + has_relative_attention_bias=True if (i == 0) else False, + flash_attn=flash_attn, + dropout=dropout, + ) + for i in range(n_layers) + ] + ) + + # Perform last normalization + self.norm = RMSNorm(d_model) if last_layer else None + + def subsequent_mask(self, size): + return torch.ones(1, size, size).tril().bool() + + def forward(self, x, x_mask, cond=None, src=None, src_mask=None): + """Computes a full transformer stack + Parameters + ---------- + x : Tensor[B x T_q x D] + x_mask : Tensor[B x T_q] + src : Tensor[B x T_kv x D], optional + src_mask : Tensor[B x T_kv], optional + Returns + ------- + Tensor[B x T_q x D] + """ + + # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking + if self.is_decoder: + src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2) + + # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking + x_mask = x_mask.unsqueeze(-2) + if not self.bidirectional: + x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device) + + # Initialize position biases + position_bias = None + encoder_decoder_position_bias = None + + # Compute transformer layers + for layer in self.layers: + x, position_bias, encoder_decoder_position_bias = layer( + x=x, + x_mask=x_mask, + cond=cond, + src=src, + src_mask=src_mask, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + ) + + return self.norm(x) if self.norm is not None else x + + +class VampNet(at.ml.BaseModel): + def __init__( + self, + n_heads: int = 20, + n_layers: int = 16, + r_cond_dim: int = 64, + n_codebooks: int = 9, + 
n_conditioning_codebooks: int = 0, + latent_dim: int = 8, + embedding_dim: int = 1280, + vocab_size: int = 1024, + flash_attn: bool = True, + noise_mode: str = "mask", + dropout: float = 0.1 + ): + super().__init__() + self.n_heads = n_heads + self.n_layers = n_layers + self.r_cond_dim = r_cond_dim + self.n_codebooks = n_codebooks + self.n_conditioning_codebooks = n_conditioning_codebooks + self.embedding_dim = embedding_dim + self.vocab_size = vocab_size + self.latent_dim = latent_dim + self.flash_attn = flash_attn + self.noise_mode = noise_mode + + assert self.noise_mode == "mask", "deprecated" + + self.embedding = CodebookEmbedding( + latent_dim=latent_dim, + n_codebooks=n_codebooks, + vocab_size=vocab_size, + emb_dim=embedding_dim, + special_tokens=["MASK"], + ) + self.mask_token = self.embedding.special_idxs["MASK"] + + self.transformer = TransformerStack( + d_model=embedding_dim, + d_cond=r_cond_dim, + n_heads=n_heads, + n_layers=n_layers, + last_layer=True, + bidirectional=True, + flash_attn=flash_attn, + is_decoder=False, + dropout=dropout, + ) + + # Add final conv layer + self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks + self.classifier = SequentialWithFiLM( + WNConv1d( + embedding_dim, + vocab_size * self.n_predict_codebooks, + kernel_size=1, + padding="same", + # groups=self.n_predict_codebooks, + ), + ) + + def forward(self, x, cond): + x = self.embedding(x) + x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1) + + cond = self.r_embed(cond) + + x = rearrange(x, "b d n -> b n d") + out = self.transformer(x=x, x_mask=x_mask, cond=cond) + out = rearrange(out, "b n d -> b d n") + + out = self.classifier(out, cond) + + out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks) + + return out + + def r_embed(self, r, max_positions=10000): + if self.r_cond_dim > 0: + dtype = r.dtype + + r = _gamma(r) * max_positions + half_dim = self.r_cond_dim // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + + if self.r_cond_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype) + else: + return r + + @torch.no_grad() + def to_signal(self, z, codec): + """ + convert a sequence of latents to a signal. 
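+        time steps that still hold the MASK token are zeroed out (rendered as silence)
+        in the decoded audio.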
+ """ + assert z.ndim == 3 + + signal = at.AudioSignal( + codec.decode( + codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0] + )["audio"], + codec.sample_rate, + ) + + # find where the mask token is and replace it with silence in the audio + for tstep in range(z.shape[-1]): + if torch.any(z[:, :, tstep] == self.mask_token): + sample_idx_0 = tstep * codec.hop_length + sample_idx_1 = sample_idx_0 + codec.hop_length + signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0 + + return signal + + + @torch.no_grad() + def generate( + self, + codec, + time_steps: int = 300, + sampling_steps: int = 36, + start_tokens: Optional[torch.Tensor] = None, + sampling_temperature: float = 1.0, + mask: Optional[torch.Tensor] = None, + mask_temperature: float = 10.5, + typical_filtering=False, + typical_mass=0.2, + typical_min_tokens=1, + top_p=None, + return_signal=True, + seed: int = None, + sample_cutoff: float = 0.5 + ): + if seed is not None: + at.util.seed(seed) + logging.debug(f"beginning generation with {sampling_steps} steps") + + + + ##################### + # resolve initial z # + ##################### + z = start_tokens + + if z is None: + z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to( + self.device + ) + + logging.debug(f"created z with shape {z.shape}") + + + ################# + # resolve mask # + ################# + + if mask is None: + mask = torch.ones_like(z).to(self.device).int() + mask[:, : self.n_conditioning_codebooks, :] = 0.0 + if mask.ndim == 2: + mask = mask[:, None, :].repeat(1, z.shape[1], 1) + # init_mask = mask.clone() + + logging.debug(f"created mask with shape {mask.shape}") + + + ########### + # set up # + ########## + # apply the mask to z + z_masked = z.masked_fill(mask.bool(), self.mask_token) + # logging.debug(f"z_masked: {z_masked}") + + # how many mask tokens to begin with? + num_mask_tokens_at_start = (z_masked == self.mask_token).sum() + logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}") + + # how many codebooks are we inferring vs conditioning on? 
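+        # the inferred codebooks are iteratively re-masked and re-sampled in the loop below,
+        # while conditioning codebooks are stripped before sampling and concatenated back unchanged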
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks + logging.debug(f"n infer codebooks: {n_infer_codebooks}") + + ################# + # begin sampling # + ################# + t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature) + + for i in range(sampling_steps): + logging.debug(f"step {i} of {sampling_steps}") + + # our current schedule step + r = scalar_to_batch_tensor( + (i + 1) / sampling_steps, + z.shape[0] + ).to(z.device) + logging.debug(f"r: {r}") + + # get latents + latents = self.embedding.from_codes(z_masked, codec) + logging.debug(f"computed latents with shape: {latents.shape}") + + + # infer from latents + # NOTE: this collapses the codebook dimension into the sequence dimension + logits = self.forward(latents, r) # b, prob, seq + logits = logits.permute(0, 2, 1) # b, seq, prob + b = logits.shape[0] + + logging.debug(f"permuted logits with shape: {logits.shape}") + + sampled_z, selected_probs = sample_from_logits( + logits, sample=( + (i / sampling_steps) <= sample_cutoff + ), + temperature=t_sched[i], + typical_filtering=typical_filtering, typical_mass=typical_mass, + typical_min_tokens=typical_min_tokens, + top_k=None, top_p=top_p, return_probs=True, + ) + + logging.debug(f"sampled z with shape: {sampled_z.shape}") + + # flatten z_masked and mask, so we can deal with the sampling logic + # we'll unflatten them at the end of the loop for the next forward pass + # remove conditioning codebooks, we'll add them back at the end + z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :]) + + mask = (z_masked == self.mask_token).int() + + # update the mask, remove conditioning codebooks from the mask + logging.debug(f"updated mask with shape: {mask.shape}") + # add z back into sampled z where the mask was false + sampled_z = torch.where( + mask.bool(), sampled_z, z_masked + ) + logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}") + + # ignore any tokens that weren't masked + selected_probs = torch.where( + mask.bool(), selected_probs, torch.inf + ) + + # get the num tokens to mask, according to the schedule + num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long() + logging.debug(f"num to mask: {num_to_mask}") + + if i != (sampling_steps - 1): + num_to_mask = torch.maximum( + torch.tensor(1), + torch.minimum( + mask.sum(dim=-1, keepdim=True) - 1, + num_to_mask + ) + ) + + + # get our new mask + mask = mask_by_random_topk( + num_to_mask, selected_probs, mask_temperature * (1-r) + ) + + # update the mask + z_masked = torch.where( + mask.bool(), self.mask_token, sampled_z + ) + logging.debug(f"updated z_masked with shape: {z_masked.shape}") + + z_masked = codebook_unflatten(z_masked, n_infer_codebooks) + mask = codebook_unflatten(mask, n_infer_codebooks) + logging.debug(f"unflattened z_masked with shape: {z_masked.shape}") + + # add conditioning codebooks back to z_masked + z_masked = torch.cat( + (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1 + ) + logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}") + + + # add conditioning codebooks back to sampled_z + sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks) + sampled_z = torch.cat( + (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1 + ) + + logging.debug(f"finished sampling") + + if return_signal: + return self.to_signal(sampled_z, codec) + else: + return sampled_z + +def sample_from_logits( + logits, + sample: bool = True, + temperature: float = 1.0, + top_k: int 
= None, + top_p: float = None, + typical_filtering: bool = False, + typical_mass: float = 0.2, + typical_min_tokens: int = 1, + return_probs: bool = False + ): + """Convenience function to sample from a categorial distribution with input as + unnormalized logits. + + Parameters + ---------- + logits : Tensor[..., vocab_size] + config: SamplingConfig + The set of hyperparameters to be used for sampling + sample : bool, optional + Whether to perform multinomial sampling, by default True + temperature : float, optional + Scaling parameter when multinomial samping, by default 1.0 + top_k : int, optional + Restricts sampling to only `top_k` values acc. to probability, + by default None + top_p : float, optional + Restricts sampling to only those values with cumulative + probability = `top_p`, by default None + + Returns + ------- + Tensor[...] + Sampled tokens + """ + shp = logits.shape[:-1] + + if typical_filtering: + typical_filter(logits, + typical_mass=typical_mass, + typical_min_tokens=typical_min_tokens + ) + + # Apply top_k sampling + if top_k is not None: + v, _ = logits.topk(top_k) + logits[logits < v[..., [-1]]] = -float("inf") + + # Apply top_p (nucleus) sampling + if top_p is not None and top_p < 1.0: + v, sorted_indices = logits.sort(descending=True) + cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + # Right shift indices_to_remove to keep 1st token over threshold + sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[ + ..., :-1 + ] + + # Compute indices_to_remove in unsorted array + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + + logits[indices_to_remove] = -float("inf") + + # Perform multinomial sampling after normalizing logits + probs = ( + F.softmax(logits / temperature, dim=-1) + if temperature > 0 + else logits.softmax(dim=-1) + ) + token = ( + probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp) + if sample + else logits.argmax(-1) + ) + + if return_probs: + token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1) + return token, token_probs + else: + return token + + + +def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0): + """ + Args: + num_to_mask (int): number of tokens to mask + probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq) + temperature (float, optional): temperature. Defaults to 1.0. 
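+    Returns:
+        torch.Tensor: boolean mask of shape (batch, seq), True at the lowest-confidence
+            positions that should be re-masked.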
+ """ + logging.debug(f"masking by random topk") + logging.debug(f"num to mask: {num_to_mask}") + logging.debug(f"probs shape: {probs.shape}") + logging.debug(f"temperature: {temperature}") + logging.debug("") + + confidence = torch.log(probs) + temperature * gumbel_noise_like(probs) + logging.debug(f"confidence shape: {confidence.shape}") + + sorted_confidence, sorted_idx = confidence.sort(dim=-1) + logging.debug(f"sorted confidence shape: {sorted_confidence.shape}") + logging.debug(f"sorted idx shape: {sorted_idx.shape}") + + # get the cut off threshold, given the mask length + cut_off = torch.take_along_dim( + sorted_confidence, num_to_mask, axis=-1 + ) + logging.debug(f"cut off shape: {cut_off.shape}") + + # mask out the tokens + mask = confidence < cut_off + logging.debug(f"mask shape: {mask.shape}") + + return mask + +def typical_filter( + logits, + typical_mass: float = 0.95, + typical_min_tokens: int = 1,): + nb, nt, _ = logits.shape + x_flat = rearrange(logits, "b t l -> (b t ) l") + x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1) + x_flat_norm_p = torch.exp(x_flat_norm) + entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True) + + c_flat_shifted = torch.abs((-x_flat_norm) - entropy) + c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False) + x_flat_cumsum = ( + x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1) + ) + + last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1) + sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather( + 1, last_ind.view(-1, 1) + ) + if typical_min_tokens > 1: + sorted_indices_to_remove[..., :typical_min_tokens] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, x_flat_indices, sorted_indices_to_remove + ) + x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf")) + logits = rearrange(x_flat, "(b t) l -> b t l", t=nt) + return logits + + +if __name__ == "__main__": + # import argbind + from .layers import num_params + + VampNet = argbind.bind(VampNet) + + @argbind.bind(without_prefix=True) + def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0): + seq_len = int(32000 / 512 * seq_len_s) + + model = VampNet().to(device) + + z = torch.randint( + 0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len) + ).to(device) + + r = torch.zeros(batch_size).to(device) + + z_mask_latent = torch.rand( + batch_size, model.latent_dim * model.n_codebooks, seq_len + ).to(device) + z_hat = model(z_mask_latent, r) + + pred = z_hat.argmax(dim=1) + pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks) + + print(f"model has {num_params(model)/1e6:<.3f}M parameters") + print(f"prediction has shape {pred.shape}") + breakpoint() + + args = argbind.parse_args() + with argbind.scope(args): + try_model() + + diff --git a/melodytalk/dependencies/vampnet/scheduler.py b/melodytalk/dependencies/vampnet/scheduler.py new file mode 100644 index 0000000..a57108c --- /dev/null +++ b/melodytalk/dependencies/vampnet/scheduler.py @@ -0,0 +1,47 @@ +import copy +from typing import List + +import torch + +class NoamScheduler: + """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf + Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + d_model: int = 512, + factor: float = 1.0, + warmup: int = 4000, + ): + # Store hparams + self.warmup = warmup + self.factor = factor + self.d_model = d_model + + # Initialize 
variables `lr` and `steps` + self.lr = None + self.steps = 0 + + # Store the optimizer + self.optimizer = optimizer + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def step(self): + self.steps += 1 + self.lr = self.factor * ( + self.d_model ** (-0.5) + * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5)) + ) + + for p in self.optimizer.param_groups: + p["lr"] = self.lr + diff --git a/melodytalk/dependencies/vampnet/util.py b/melodytalk/dependencies/vampnet/util.py new file mode 100644 index 0000000..8fbf8fb --- /dev/null +++ b/melodytalk/dependencies/vampnet/util.py @@ -0,0 +1,46 @@ +import tqdm + +import torch +from einops import rearrange + +def scalar_to_batch_tensor(x, batch_size): + return torch.tensor(x).repeat(batch_size) + + +def parallelize( + fn, + *iterables, + parallel: str = "thread_map", + **kwargs + ): + if parallel == "thread_map": + from tqdm.contrib.concurrent import thread_map + return thread_map( + fn, + *iterables, + **kwargs + ) + elif parallel == "process_map": + from tqdm.contrib.concurrent import process_map + return process_map( + fn, + *iterables, + **kwargs + ) + elif parallel == "single": + return [fn(x) for x in tqdm.tqdm(*iterables)] + else: + raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}") + +def codebook_flatten(tokens: torch.Tensor): + """ + flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time) + """ + return rearrange(tokens, "b c t -> b (t c)") + +def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None): + """ + unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time) + """ + tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c) + return tokens
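+
+
+# illustrative round-trip check for the two helpers above:
+#   tokens = torch.randint(0, 1024, (2, 4, 10))              # (batch, codebook, time)
+#   flat = codebook_flatten(tokens)                          # shape (2, 40)
+#   assert torch.equal(codebook_unflatten(flat, n_c=4), tokens)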