diff --git a/melodytalk/audiocraft/__init__.py b/melodytalk/audiocraft/__init__.py new file mode 100644 index 0000000..6b8594f --- /dev/null +++ b/melodytalk/audiocraft/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from . import data, modules, models + +__version__ = '0.0.2a2' diff --git a/melodytalk/audiocraft/data/__init__.py b/melodytalk/audiocraft/data/__init__.py new file mode 100644 index 0000000..708a3dc --- /dev/null +++ b/melodytalk/audiocraft/data/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from . import audio, audio_dataset diff --git a/melodytalk/audiocraft/data/audio.py b/melodytalk/audiocraft/data/audio.py new file mode 100644 index 0000000..2048df6 --- /dev/null +++ b/melodytalk/audiocraft/data/audio.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Audio IO methods are defined in this module (info, read, write), +We rely on av library for faster read when possible, otherwise on torchaudio. +""" + +from dataclasses import dataclass +from pathlib import Path +import logging +import typing as tp + +import numpy as np +import soundfile +import torch +from torch.nn import functional as F +import torchaudio as ta + +import av + +from .audio_utils import f32_pcm, i16_pcm, normalize_audio + + +_av_initialized = False + + +def _init_av(): + global _av_initialized + if _av_initialized: + return + logger = logging.getLogger('libav.mp3') + logger.setLevel(logging.ERROR) + _av_initialized = True + + +@dataclass(frozen=True) +class AudioFileInfo: + sample_rate: int + duration: float + channels: int + + +def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: + _init_av() + with av.open(str(filepath)) as af: + stream = af.streams.audio[0] + sample_rate = stream.codec_context.sample_rate + duration = float(stream.duration * stream.time_base) + channels = stream.channels + return AudioFileInfo(sample_rate, duration, channels) + + +def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: + info = soundfile.info(filepath) + return AudioFileInfo(info.samplerate, info.duration, info.channels) + + +def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: + # torchaudio no longer returns useful duration informations for some formats like mp3s. + filepath = Path(filepath) + if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info + # ffmpeg has some weird issue with flac. + return _soundfile_info(filepath) + else: + return _av_info(filepath) + + +def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]: + """FFMPEG-based audio file reading using PyAV bindings. + Soundfile cannot read mp3 and av_read is more efficient than torchaudio. + + Args: + filepath (str or Path): Path to audio file to read. + seek_time (float): Time at which to start reading in the file. + duration (float): Duration to read from the file. If set to -1, the whole file is read. 
+ Returns: + Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate + """ + _init_av() + with av.open(str(filepath)) as af: + stream = af.streams.audio[0] + sr = stream.codec_context.sample_rate + num_frames = int(sr * duration) if duration >= 0 else -1 + frame_offset = int(sr * seek_time) + # we need a small negative offset otherwise we get some edge artifact + # from the mp3 decoder. + af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream) + frames = [] + length = 0 + for frame in af.decode(streams=stream.index): + current_offset = int(frame.rate * frame.pts * frame.time_base) + strip = max(0, frame_offset - current_offset) + buf = torch.from_numpy(frame.to_ndarray()) + if buf.shape[0] != stream.channels: + buf = buf.view(-1, stream.channels).t() + buf = buf[:, strip:] + frames.append(buf) + length += buf.shape[1] + if num_frames > 0 and length >= num_frames: + break + assert frames + # If the above assert fails, it is likely because we seeked past the end of file point, + # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp. + # This will need proper debugging, in due time. + wav = torch.cat(frames, dim=1) + assert wav.shape[0] == stream.channels + if num_frames > 0: + wav = wav[:, :num_frames] + return f32_pcm(wav), sr + + +def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., + duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]: + """Read audio by picking the most appropriate backend tool based on the audio format. + + Args: + filepath (str or Path): Path to audio file to read. + seek_time (float): Time at which to start reading in the file. + duration (float): Duration to read from the file. If set to -1, the whole file is read. + pad (bool): Pad output audio if not reaching expected duration. + Returns: + Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate. + """ + fp = Path(filepath) + if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg + # There is some bug with ffmpeg and reading flac + info = _soundfile_info(filepath) + frames = -1 if duration <= 0 else int(duration * info.sample_rate) + frame_offset = int(seek_time * info.sample_rate) + wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32) + assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}" + wav = torch.from_numpy(wav).t().contiguous() + if len(wav.shape) == 1: + wav = torch.unsqueeze(wav, 0) + elif ( + fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() + and duration <= 0 and seek_time == 0 + ): + # Torchaudio is faster if we load an entire file at once. + wav, sr = ta.load(fp) + else: + wav, sr = _av_read(filepath, seek_time, duration) + if pad and duration > 0: + expected_frames = int(duration * sr) + wav = F.pad(wav, (0, expected_frames - wav.shape[-1])) + return wav, sr + + +def audio_write(stem_name: tp.Union[str, Path], + wav: torch.Tensor, sample_rate: int, + format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, + strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, + log_clipping: bool = True, make_parent_dir: bool = True, + add_suffix: bool = True) -> Path: + """Convenience function for saving audio to disk. Returns the filename the audio was written to. 
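+
+    Example (an illustrative sketch only, not part of the original docstring; assumes a
+    one-second 440 Hz mono tone and a torchaudio backend that can write wav):
+
+        import torch
+        sr = 32000
+        t = torch.arange(sr) / sr
+        wav = torch.sin(2 * torch.pi * 440 * t)[None]  # [1, T] float tensor
+        out_path = audio_write('demo_tone', wav, sr, format='wav', strategy='peak')
+        # expected to return Path('demo_tone.wav')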
+ + Args: + stem_name (str or Path): Filename without extension which will be added automatically. + format (str): Either "wav" or "mp3". + mp3_rate (int): kbps when using mp3s. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. + when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + make_parent_dir (bool): Make parent directory if it doesn't exist. + Returns: + Path: Path of the saved audio. + """ + assert wav.dtype.is_floating_point, "wav is not floating point" + if wav.dim() == 1: + wav = wav[None] + elif wav.dim() > 2: + raise ValueError("Input wav should be at most 2 dimension.") + assert wav.isfinite().all() + wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, + rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping, + sample_rate=sample_rate, stem_name=str(stem_name)) + kwargs: dict = {} + if format == 'mp3': + suffix = '.mp3' + kwargs.update({"compression": mp3_rate}) + elif format == 'wav': + wav = i16_pcm(wav) + suffix = '.wav' + kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16}) + else: + raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.") + if not add_suffix: + suffix = '' + path = Path(str(stem_name) + suffix) + if make_parent_dir: + path.parent.mkdir(exist_ok=True, parents=True) + try: + ta.save(path, wav, sample_rate, **kwargs) + except Exception: + if path.exists(): + # we do not want to leave half written files around. + path.unlink() + raise + return path diff --git a/melodytalk/audiocraft/data/audio_dataset.py b/melodytalk/audiocraft/data/audio_dataset.py new file mode 100644 index 0000000..cf21422 --- /dev/null +++ b/melodytalk/audiocraft/data/audio_dataset.py @@ -0,0 +1,525 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import copy +from concurrent.futures import ThreadPoolExecutor, Future +from dataclasses import dataclass, fields +from contextlib import ExitStack +import gzip +import json +import logging +import os +from pathlib import Path +import random +import sys +import typing as tp + +import torch +import torch.nn.functional as F + +from .audio import audio_read, audio_info +from .audio_utils import convert_audio +from .zip import PathInZip + +try: + import dora +except ImportError: + dora = None # type: ignore + + +@dataclass(order=True) +class BaseInfo: + + @classmethod + def _dict2fields(cls, dictionary: dict): + return { + field.name: dictionary[field.name] + for field in fields(cls) if field.name in dictionary + } + + @classmethod + def from_dict(cls, dictionary: dict): + _dictionary = cls._dict2fields(dictionary) + return cls(**_dictionary) + + def to_dict(self): + return { + field.name: self.__getattribute__(field.name) + for field in fields(self) + } + + +@dataclass(order=True) +class AudioMeta(BaseInfo): + path: str + duration: float + sample_rate: int + amplitude: tp.Optional[float] = None + weight: tp.Optional[float] = None + # info_path is used to load additional information about the audio file that is stored in zip files. + info_path: tp.Optional[PathInZip] = None + + @classmethod + def from_dict(cls, dictionary: dict): + base = cls._dict2fields(dictionary) + if 'info_path' in base and base['info_path'] is not None: + base['info_path'] = PathInZip(base['info_path']) + return cls(**base) + + def to_dict(self): + d = super().to_dict() + if d['info_path'] is not None: + d['info_path'] = str(d['info_path']) + return d + + +@dataclass(order=True) +class SegmentInfo(BaseInfo): + meta: AudioMeta + seek_time: float + n_frames: int # actual number of frames without padding + total_frames: int # total number of frames, padding included + sample_rate: int # actual sample rate + + +DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'] + +logger = logging.getLogger(__name__) + + +def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta: + """AudioMeta from a path to an audio file. + + Args: + file_path (str): Resolved path of valid audio file. + minimal (bool): Whether to only load the minimal set of metadata (takes longer if not). + Returns: + AudioMeta: Audio file path and its metadata. + """ + info = audio_info(file_path) + amplitude: tp.Optional[float] = None + if not minimal: + wav, sr = audio_read(file_path) + amplitude = wav.abs().max().item() + return AudioMeta(file_path, info.duration, info.sample_rate, amplitude) + + +def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: + """If Dora is available as a dependency, try to resolve potential relative paths + in list of AudioMeta. This method is expected to be used when loading meta from file. + + Args: + m (AudioMeta): Audio meta to resolve. + fast (bool): If True, uses a really fast check for determining if a file is already absolute or not. + Only valid on Linux/Mac. + Returns: + AudioMeta: Audio meta with resolved path. 
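+
+    Example (illustrative paths only): when dora is importable, an AudioMeta whose path is
+    'clips/song.wav' would have it rewritten to an absolute path such as
+    '/abs/checkout/clips/song.wav'; paths that are already absolute are left as they are.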
+ """ + def is_abs(m): + if fast: + return str(m)[0] == '/' + else: + os.path.isabs(str(m)) + + if not dora: + return m + + if not is_abs(m.path): + m.path = dora.git_save.to_absolute_path(m.path) + if m.info_path is not None and not is_abs(m.info_path.zip_path): + m.info_path.zip_path = dora.git_save.to_absolute_path(m.path) + return m + + +def find_audio_files(path: tp.Union[Path, str], + exts: tp.List[str] = DEFAULT_EXTS, + resolve: bool = True, + minimal: bool = True, + progress: bool = False, + workers: int = 0) -> tp.List[AudioMeta]: + """Build a list of AudioMeta from a given path, + collecting relevant audio files and fetching meta info. + + Args: + path (str or Path): Path to folder containing audio files. + exts (list of str): List of file extensions to consider for audio files. + minimal (bool): Whether to only load the minimal set of metadata (takes longer if not). + progress (bool): Whether to log progress on audio files collection. + workers (int): number of parallel workers, if 0, use only the current thread. + Returns: + List[AudioMeta]: List of audio file path and its metadata. + """ + audio_files = [] + futures: tp.List[Future] = [] + pool: tp.Optional[ThreadPoolExecutor] = None + with ExitStack() as stack: + if workers > 0: + pool = ThreadPoolExecutor(workers) + stack.enter_context(pool) + + if progress: + print("Finding audio files...") + for root, folders, files in os.walk(path, followlinks=True): + for file in files: + full_path = Path(root) / file + if full_path.suffix.lower() in exts: + audio_files.append(full_path) + if pool is not None: + futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal)) + if progress: + print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr) + + if progress: + print("Getting audio metadata...") + meta: tp.List[AudioMeta] = [] + for idx, file_path in enumerate(audio_files): + try: + if pool is None: + m = _get_audio_meta(str(file_path), minimal) + else: + m = futures[idx].result() + if resolve: + m = _resolve_audio_meta(m) + except Exception as err: + print("Error with", str(file_path), err, file=sys.stderr) + continue + meta.append(m) + if progress: + print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr) + meta.sort() + return meta + + +def load_audio_meta(path: tp.Union[str, Path], + resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]: + """Load list of AudioMeta from an optionally compressed json file. + + Args: + path (str or Path): Path to JSON file. + resolve (bool): Whether to resolve the path from AudioMeta (default=True). + fast (bool): activates some tricks to make things faster. + Returns: + List[AudioMeta]: List of audio file path and its total duration. + """ + open_fn = gzip.open if str(path).lower().endswith('.gz') else open + with open_fn(path, 'rb') as fp: # type: ignore + lines = fp.readlines() + meta = [] + for line in lines: + d = json.loads(line) + m = AudioMeta.from_dict(d) + if resolve: + m = _resolve_audio_meta(m, fast=fast) + meta.append(m) + return meta + + +def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]): + """Save the audio metadata to the file pointer as json. + + Args: + path (str or Path): Path to JSON file. + metadata (list of BaseAudioMeta): List of audio meta to save. 
+ """ + Path(path).parent.mkdir(exist_ok=True, parents=True) + open_fn = gzip.open if str(path).lower().endswith('.gz') else open + with open_fn(path, 'wb') as fp: # type: ignore + for m in meta: + json_str = json.dumps(m.to_dict()) + '\n' + json_bytes = json_str.encode('utf-8') + fp.write(json_bytes) + + +class AudioDataset: + """Base audio dataset. + + The dataset takes a list of AudioMeta and create a dataset composed of segments of audio + and potentially additional information, by creating random segments from the list of audio + files referenced in the metadata and applying minimal data pre-processing such as resampling, + mixing of channels, padding, etc. + + If no segment_duration value is provided, the AudioDataset will return the full wav for each + audio file. Otherwise, it will randomly sample audio files and create a segment of the specified + duration, applying padding if required. + + By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True + allows to return a tuple containing the torch Tensor and additional metadata on the segment and the + original audio meta. + + Args: + meta (tp.List[AudioMeta]): List of audio files metadata. + segment_duration (float): Optional segment duration of audio to load. + If not specified, the dataset will load the full audio segment from the file. + shuffle (bool): Set to `True` to have the data reshuffled at every epoch. + sample_rate (int): Target sample rate of the loaded audio samples. + channels (int): Target number of channels of the loaded audio samples. + sample_on_duration (bool): Set to `True` to sample segments with probability + dependent on audio file duration. This is only used if `segment_duration` is provided. + sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of + `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product + of the file duration and file weight. This is only used if `segment_duration` is provided. + min_segment_ratio (float): Minimum segment ratio to use when the audio file + is shorter than the desired segment. + max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset. + return_info (bool): Whether to return the wav only or return wav along with segment info and metadata. + min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided + audio shorter than this will be filtered out. + max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided + audio longer than this will be filtered out. + """ + def __init__(self, + meta: tp.List[AudioMeta], + segment_duration: tp.Optional[float] = None, + shuffle: bool = True, + num_samples: int = 10_000, + sample_rate: int = 48_000, + channels: int = 2, + pad: bool = True, + sample_on_duration: bool = True, + sample_on_weight: bool = True, + min_segment_ratio: float = 0.5, + max_read_retry: int = 10, + return_info: bool = False, + min_audio_duration: tp.Optional[float] = None, + max_audio_duration: tp.Optional[float] = None + ): + assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.' 
+ assert segment_duration is None or segment_duration > 0 + assert segment_duration is None or min_segment_ratio >= 0 + logging.debug(f'sample_on_duration: {sample_on_duration}') + logging.debug(f'sample_on_weight: {sample_on_weight}') + logging.debug(f'pad: {pad}') + logging.debug(f'min_segment_ratio: {min_segment_ratio}') + + self.segment_duration = segment_duration + self.min_segment_ratio = min_segment_ratio + self.max_audio_duration = max_audio_duration + self.min_audio_duration = min_audio_duration + if self.min_audio_duration is not None and self.max_audio_duration is not None: + assert self.min_audio_duration <= self.max_audio_duration + self.meta: tp.List[AudioMeta] = self._filter_duration(meta) + assert len(self.meta) # Fail fast if all data has been filtered. + self.total_duration = sum(d.duration for d in self.meta) + + if segment_duration is None: + num_samples = len(self.meta) + self.num_samples = num_samples + self.shuffle = shuffle + self.sample_rate = sample_rate + self.channels = channels + self.pad = pad + self.sample_on_weight = sample_on_weight + self.sample_on_duration = sample_on_duration + self.sampling_probabilities = self._get_sampling_probabilities() + self.max_read_retry = max_read_retry + self.return_info = return_info + + def __len__(self): + return self.num_samples + + def _get_sampling_probabilities(self, normalized: bool = True): + """Return the sampling probabilities for each file inside `self.meta`. + """ + scores: tp.List[float] = [] + for file_meta in self.meta: + score = 1. + if self.sample_on_weight and file_meta.weight is not None: + score *= file_meta.weight + if self.sample_on_duration: + score *= file_meta.duration + scores.append(score) + probabilities = torch.tensor(scores) + if normalized: + probabilities /= probabilities.sum() + return probabilities + + def sample_file(self, rng: torch.Generator) -> AudioMeta: + """Sample a given file from `self.meta`. Can be overriden in subclasses. + This is only called if `segment_duration` is not None. + + You must use the provided random number generator `rng` for reproducibility. 
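+
+        Illustrative override (a hedged sketch, not part of the original API): a subclass
+        could bypass weighted sampling entirely, e.g.
+
+            class FirstFileDataset(AudioDataset):
+                def sample_file(self, rng: torch.Generator) -> AudioMeta:
+                    # toy override: always pick the first file; `rng` is deliberately unused
+                    return self.meta[0]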
+ """ + if not self.sample_on_weight and not self.sample_on_duration: + file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item()) + else: + file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item()) + + return self.meta[file_index] + + def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]: + if self.segment_duration is None: + file_meta = self.meta[index] + out, sr = audio_read(file_meta.path) + out = convert_audio(out, sr, self.sample_rate, self.channels) + n_frames = out.shape[-1] + segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames, + sample_rate=self.sample_rate) + else: + rng = torch.Generator() + if self.shuffle: + # We use index, plus extra randomness + rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) + else: + # We only use index + rng.manual_seed(index) + + for retry in range(self.max_read_retry): + file_meta = self.sample_file(rng) + # We add some variance in the file position even if audio file is smaller than segment + # without ending up with empty segments + max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio) + seek_time = torch.rand(1, generator=rng).item() * max_seek + try: + out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False) + out = convert_audio(out, sr, self.sample_rate, self.channels) + n_frames = out.shape[-1] + target_frames = int(self.segment_duration * self.sample_rate) + if self.pad: + out = F.pad(out, (0, target_frames - n_frames)) + segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames, + sample_rate=self.sample_rate) + except Exception as exc: + logger.warning("Error opening file %s: %r", file_meta.path, exc) + if retry == self.max_read_retry - 1: + raise + else: + break + + if self.return_info: + # Returns the wav and additional information on the wave segment + return out, segment_info + else: + return out + + def collater(self, samples): + """The collater function has to be provided to the dataloader + if AudioDataset has return_info=True in order to properly collate + the samples of a batch. + """ + if self.segment_duration is None and len(samples) > 1: + assert self.pad, "Must allow padding when batching examples of different durations." + + # In this case the audio reaching the collater is of variable length as segment_duration=None. + to_pad = self.segment_duration is None and self.pad + if to_pad: + max_len = max([wav.shape[-1] for wav, _ in samples]) + + def _pad_wav(wav): + return F.pad(wav, (0, max_len - wav.shape[-1])) + + if self.return_info: + if len(samples) > 0: + assert len(samples[0]) == 2 + assert isinstance(samples[0][0], torch.Tensor) + assert isinstance(samples[0][1], SegmentInfo) + + wavs = [wav for wav, _ in samples] + segment_infos = [copy.deepcopy(info) for _, info in samples] + + if to_pad: + # Each wav could be of a different duration as they are not segmented. + for i in range(len(samples)): + # Determines the total legth of the signal with padding, so we update here as we pad. 
+ segment_infos[i].total_frames = max_len + wavs[i] = _pad_wav(wavs[i]) + + wav = torch.stack(wavs) + return wav, segment_infos + else: + assert isinstance(samples[0], torch.Tensor) + if to_pad: + samples = [_pad_wav(s) for s in samples] + return torch.stack(samples) + + def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: + """Filters out audio files with short durations. + Removes from meta files that have durations that will not allow to samples examples from them. + """ + orig_len = len(meta) + + # Filter data that is too short. + if self.min_audio_duration is not None: + meta = [m for m in meta if m.duration >= self.min_audio_duration] + + # Filter data that is too long. + if self.max_audio_duration is not None: + meta = [m for m in meta if m.duration <= self.max_audio_duration] + + filtered_len = len(meta) + removed_percentage = 100*(1-float(filtered_len)/orig_len) + msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage + if removed_percentage < 10: + logging.debug(msg) + else: + logging.warning(msg) + return meta + + @classmethod + def from_meta(cls, root: tp.Union[str, Path], **kwargs): + """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file. + + Args: + root (str or Path): Path to root folder containing audio files. + kwargs: Additional keyword arguments for the AudioDataset. + """ + root = Path(root) + if root.is_dir(): + if (root / 'data.jsonl').exists(): + root = root / 'data.jsonl' + elif (root / 'data.jsonl.gz').exists(): + root = root / 'data.jsonl.gz' + else: + raise ValueError("Don't know where to read metadata from in the dir. " + "Expecting either a data.jsonl or data.jsonl.gz file but none found.") + meta = load_audio_meta(root) + return cls(meta, **kwargs) + + @classmethod + def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True, + exts: tp.List[str] = DEFAULT_EXTS, **kwargs): + """Instantiate AudioDataset from a path containing (possibly nested) audio files. + + Args: + root (str or Path): Path to root folder containing audio files. + minimal_meta (bool): Whether to only load minimal metadata or not. + exts (list of str): Extensions for audio files. + kwargs: Additional keyword arguments for the AudioDataset. + """ + root = Path(root) + if root.is_file(): + meta = load_audio_meta(root, resolve=True) + else: + meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True) + return cls(meta, **kwargs) + + +def main(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + parser = argparse.ArgumentParser( + prog='audio_dataset', + description='Generate .jsonl files by scanning a folder.') + parser.add_argument('root', help='Root folder with all the audio files') + parser.add_argument('output_meta_file', + help='Output file to store the metadata, ') + parser.add_argument('--complete', + action='store_false', dest='minimal', default=True, + help='Retrieve all metadata, even the one that are expansive ' + 'to compute (e.g. 
normalization).') + parser.add_argument('--resolve', + action='store_true', default=False, + help='Resolve the paths to be absolute and with no symlinks.') + parser.add_argument('--workers', + default=10, type=int, + help='Number of workers.') + args = parser.parse_args() + meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True, + resolve=args.resolve, minimal=args.minimal, workers=args.workers) + save_audio_meta(args.output_meta_file, meta) + + +if __name__ == '__main__': + main() diff --git a/melodytalk/audiocraft/data/audio_utils.py b/melodytalk/audiocraft/data/audio_utils.py new file mode 100644 index 0000000..76d4bc2 --- /dev/null +++ b/melodytalk/audiocraft/data/audio_utils.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import typing as tp + +import julius +import torch +import torchaudio + + +def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor: + """Convert audio to the given number of channels. + + Args: + wav (torch.Tensor): Audio wave of shape [B, C, T]. + channels (int): Expected number of channels as output. + Returns: + torch.Tensor: Downmixed or unchanged audio wave [B, C, T]. + """ + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, and the stream has multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file has + # a single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file has + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav: torch.Tensor, from_rate: float, + to_rate: float, to_channels: int) -> torch.Tensor: + """Convert audio to new sample rate and number of audio channels. + """ + wav = julius.resample_frac(wav, int(from_rate), int(to_rate)) + wav = convert_audio_channels(wav, to_channels) + return wav + + +def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, energy_floor: float = 2e-3): + """Normalize an input signal to a user loudness in dB LKFS. + Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. + + Args: + wav (torch.Tensor): Input multichannel audio data. + sample_rate (int): Sample rate. + loudness_headroom_db (float): Target loudness of the output in dB LUFS. + loudness_compressor (bool): Uses tanh for soft clipping. + energy_floor (float): anything below that RMS level will not be rescaled. + Returns: + output (torch.Tensor): Loudness normalized output data. 
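+
+    Example (illustrative only; assumes one second of random stereo noise at 32 kHz):
+
+        wav = 0.5 * torch.randn(2, 32000)
+        out = normalize_loudness(wav, sample_rate=32000, loudness_headroom_db=14)
+        # `out` should land around -14 dB LUFS, clipping aside.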
+ """ + energy = wav.pow(2).mean().sqrt().item() + if energy < energy_floor: + return wav + transform = torchaudio.transforms.Loudness(sample_rate) + input_loudness_db = transform(wav).item() + # calculate the gain needed to scale to the desired loudness level + delta_loudness = -loudness_headroom_db - input_loudness_db + gain = 10.0 ** (delta_loudness / 20.0) + output = gain * wav + if loudness_compressor: + output = torch.tanh(output) + assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) + return output + + +def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: + """Utility function to clip the audio with logging if specified.""" + max_scale = wav.abs().max() + if log_clipping and max_scale > 1: + clamp_prob = (wav.abs() > 1).float().mean().item() + print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):", + clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr) + wav.clamp_(-1, 1) + + +def normalize_audio(wav: torch.Tensor, normalize: bool = True, + strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, log_clipping: bool = False, + sample_rate: tp.Optional[int] = None, + stem_name: tp.Optional[str] = None) -> torch.Tensor: + """Normalize the audio according to the prescribed strategy (see after). + + Args: + wav (torch.Tensor): Audio data. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): If True, uses tanh based soft clipping. + log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + sample_rate (int): Sample rate for the audio data (required for loudness). + stem_name (Optional[str]): Stem name for clipping logging. + Returns: + torch.Tensor: Normalized audio. + """ + scale_peak = 10 ** (-peak_clip_headroom_db / 20) + scale_rms = 10 ** (-rms_headroom_db / 20) + if strategy == 'peak': + rescaling = (scale_peak / wav.abs().max()) + if normalize or rescaling < 1: + wav = wav * rescaling + elif strategy == 'clip': + wav = wav.clamp(-scale_peak, scale_peak) + elif strategy == 'rms': + mono = wav.mean(dim=0) + rescaling = scale_rms / mono.pow(2).mean().sqrt() + if normalize or rescaling < 1: + wav = wav * rescaling + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + elif strategy == 'loudness': + assert sample_rate is not None, "Loudness normalization requires sample rate." 
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + else: + assert wav.abs().max() < 1 + assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" + return wav + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. + """ + if wav.dtype.is_floating_point: + return wav + else: + assert wav.dtype == torch.int16 + return wav.float() / 2**15 + + +def i16_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to int 16 bits PCM format. + + ..Warning:: There exist many formula for doing this convertion. None are perfect + due to the asymetry of the int16 range. One either have possible clipping, DC offset, + or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom, + it is possible that `i16_pcm(f32_pcm)) != Identity`. + """ + if wav.dtype.is_floating_point: + assert wav.abs().max() <= 1 + candidate = (wav * 2 ** 15).round() + if candidate.max() >= 2 ** 15: # clipping would occur + candidate = (wav * (2 ** 15 - 1)).round() + return candidate.short() + else: + assert wav.dtype == torch.int16 + return wav diff --git a/melodytalk/audiocraft/data/zip.py b/melodytalk/audiocraft/data/zip.py new file mode 100644 index 0000000..1f11542 --- /dev/null +++ b/melodytalk/audiocraft/data/zip.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing +import zipfile + +from dataclasses import dataclass +from functools import lru_cache +from typing_extensions import Literal + + +DEFAULT_SIZE = 32 +MODE = Literal['r', 'w', 'x', 'a'] + + +@dataclass(order=True) +class PathInZip: + """Class for holding a path of file within a zip file. + + Args: + path: The convention is : + Let's assume there is a zip file /some/location/foo.zip + and inside of it is a json file located at /data/file1.json, + Then we expect path = "/some/location/foo.zip:/data/file1.json" + """ + + INFO_PATH_SEP = ':' + zip_path: str + file_path: str + + def __init__(self, path: str) -> None: + split_path = path.split(self.INFO_PATH_SEP) + assert len(split_path) == 2 + self.zip_path, self.file_path = split_path + + @classmethod + def from_paths(cls, zip_path: str, file_path: str): + return cls(zip_path + cls.INFO_PATH_SEP + file_path) + + def __str__(self) -> str: + return self.zip_path + self.INFO_PATH_SEP + self.file_path + + +def _open_zip(path: str, mode: MODE = 'r'): + return zipfile.ZipFile(path, mode) + + +_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip) + + +def set_zip_cache_size(max_size: int): + """Sets the maximal LRU caching for zip file opening. + + Args: + max_size: the maximal LRU cache. + """ + global _cached_open_zip + _cached_open_zip = lru_cache(max_size)(_open_zip) + + +def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: + """Opens a file stored inside a zip and returns a file-like object. + + Args: + path_in_zip: A PathInZip object representing the file to return a file-like object of. + mode: The mode in which to open the file with. + Returns: + A file-like object for PathInZip. 
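+
+    Example (a hedged sketch; the archive path reuses the illustration from PathInZip):
+
+        p = PathInZip('/some/location/foo.zip:/data/file1.json')
+        with open_file_in_zip(p) as f:
+            data = f.read()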
+ """ + zf = _cached_open_zip(path_in_zip.zip_path) + return zf.open(path_in_zip.file_path) diff --git a/melodytalk/audiocraft/models/__init__.py b/melodytalk/audiocraft/models/__init__.py new file mode 100644 index 0000000..92c7a48 --- /dev/null +++ b/melodytalk/audiocraft/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .musicgen import MusicGen +from .lm import LMModel +from .encodec import CompressionModel, EncodecModel diff --git a/melodytalk/audiocraft/models/builders.py b/melodytalk/audiocraft/models/builders.py new file mode 100644 index 0000000..77ee5f9 --- /dev/null +++ b/melodytalk/audiocraft/models/builders.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +All the functions to build the relevant models and modules +from the Hydra config. +""" + +import typing as tp +import warnings + +import audiocraft +import omegaconf +import torch + +from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa +from .lm import LMModel +from ..modules.codebooks_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + ParallelPatternProvider, + UnrolledPatternProvider, + VALLEPattern, + MusicLMPattern, +) +from ..modules.conditioners import ( + BaseConditioner, + ConditioningProvider, + LUTConditioner, + T5Conditioner, + ConditionFuser, + ChromaStemConditioner, +) +from .. import quantization as qt +from ..utils.utils import dict_from_config + + +def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: + klass = { + 'no_quant': qt.DummyQuantizer, + 'rvq': qt.ResidualVectorQuantizer + }[quantizer] + kwargs = dict_from_config(getattr(cfg, quantizer)) + if quantizer != 'no_quant': + kwargs['dimension'] = dimension + return klass(**kwargs) + + +def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): + if encoder_name == 'seanet': + kwargs = dict_from_config(getattr(cfg, 'seanet')) + encoder_override_kwargs = kwargs.pop('encoder') + decoder_override_kwargs = kwargs.pop('decoder') + encoder_kwargs = {**kwargs, **encoder_override_kwargs} + decoder_kwargs = {**kwargs, **decoder_override_kwargs} + encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs) + decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) + return encoder, decoder + else: + raise KeyError(f'Unexpected compression model {cfg.compression_model}') + + +def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: + """Instantiate a compression model. + """ + if cfg.compression_model == 'encodec': + kwargs = dict_from_config(getattr(cfg, 'encodec')) + encoder_name = kwargs.pop('autoencoder') + quantizer_name = kwargs.pop('quantizer') + encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) + quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) + frame_rate = kwargs['sample_rate'] // encoder.hop_length + renormalize = kwargs.pop('renormalize', None) + renorm = kwargs.pop('renorm') + if renormalize is None: + renormalize = renorm is not None + warnings.warn("You are using a deprecated EnCodec model. 
Please migrate to new renormalization.") + return EncodecModel(encoder, decoder, quantizer, + frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) + else: + raise KeyError(f'Unexpected compression model {cfg.compression_model}') + + +def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: + """Instantiate a transformer LM. + """ + if cfg.lm_model == 'transformer_lm': + kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) + n_q = kwargs['n_q'] + q_modeling = kwargs.pop('q_modeling', None) + codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') + attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) + cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) + cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"] + fuser = get_condition_fuser(cfg) + condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) + if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically + kwargs['cross_attention'] = True + if codebooks_pattern_cfg.modeling is None: + assert q_modeling is not None, \ + 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling' + codebooks_pattern_cfg = omegaconf.OmegaConf.create( + {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} + ) + pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) + return LMModel( + pattern_provider=pattern_provider, + condition_provider=condition_provider, + fuser=fuser, + cfg_dropout=cfg_prob, + cfg_coef=cfg_coef, + attribute_dropout=attribute_dropout, + dtype=getattr(torch, cfg.dtype), + device=cfg.device, + **kwargs + ).to(cfg.device) + else: + raise KeyError(f'Unexpected LM model {cfg.lm_model}') + + +def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: + """Instantiate a conditioning model. + """ + device = cfg.device + duration = cfg.dataset.segment_duration + cfg = getattr(cfg, "conditioners") + cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg + conditioners: tp.Dict[str, BaseConditioner] = {} + with omegaconf.open_dict(cfg): + condition_provider_args = cfg.pop('args', {}) + for cond, cond_cfg in cfg.items(): + model_type = cond_cfg["model"] + model_args = cond_cfg[model_type] + if model_type == "t5": + conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) + elif model_type == "lut": + conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) + elif model_type == "chroma_stem": + model_args.pop('cache_path', None) + conditioners[str(cond)] = ChromaStemConditioner( + output_dim=output_dim, + duration=duration, + device=device, + **model_args + ) + else: + raise ValueError(f"unrecognized conditioning model: {model_type}") + conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) + return conditioner + + +def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: + """Instantiate a condition fuser object. + """ + fuser_cfg = getattr(cfg, "fuser") + fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] + fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} + kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} + fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) + return fuser + + +def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: + """Instantiate a codebooks pattern provider object. 
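+
+    For example (an assumed config, shown for illustration only), a cfg such as
+    {'modeling': 'delay', 'delay': {'delays': [0, 1, 2, 3]}} resolves to
+    DelayedPatternProvider(n_q, delays=[0, 1, 2, 3]).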
+ """ + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'valle': VALLEPattern, + 'musiclm': MusicLMPattern, + } + name = cfg.modeling + kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} + klass = pattern_providers[name] + return klass(n_q, **kwargs) + + +def get_debug_compression_model(device='cpu'): + """Instantiate a debug compression model to be used for unit tests. + """ + seanet_kwargs = { + 'n_filters': 4, + 'n_residual_layers': 1, + 'dimension': 32, + 'ratios': [10, 8, 16] # 25 Hz at 32kHz + } + encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) + decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) + quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) + init_x = torch.randn(8, 32, 128) + quantizer(init_x, 1) # initialize kmeans etc. + compression_model = EncodecModel( + encoder, decoder, quantizer, + frame_rate=25, sample_rate=32000, channels=1).to(device) + return compression_model.eval() + + +def get_debug_lm_model(device='cpu'): + """Instantiate a debug LM to be used for unit tests. + """ + pattern = DelayedPatternProvider(n_q=4) + dim = 16 + providers = { + 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"), + } + condition_provider = ConditioningProvider(providers) + fuser = ConditionFuser( + {'cross': ['description'], 'prepend': [], + 'sum': [], 'input_interpolate': []}) + lm = LMModel( + pattern, condition_provider, fuser, + n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, + cross_attention=True, causal=True) + return lm.to(device).eval() diff --git a/melodytalk/audiocraft/models/encodec.py b/melodytalk/audiocraft/models/encodec.py new file mode 100644 index 0000000..69621a6 --- /dev/null +++ b/melodytalk/audiocraft/models/encodec.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +import typing as tp + +from einops import rearrange +import torch +from torch import nn + +from .. import quantization as qt + + +class CompressionModel(ABC, nn.Module): + + @abstractmethod + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + ... + + @abstractmethod + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """See `EncodecModel.encode`""" + ... + + @abstractmethod + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """See `EncodecModel.decode`""" + ... + + @property + @abstractmethod + def channels(self) -> int: + ... + + @property + @abstractmethod + def frame_rate(self) -> int: + ... + + @property + @abstractmethod + def sample_rate(self) -> int: + ... + + @property + @abstractmethod + def cardinality(self) -> int: + ... + + @property + @abstractmethod + def num_codebooks(self) -> int: + ... + + @property + @abstractmethod + def total_codebooks(self) -> int: + ... + + @abstractmethod + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + ... + + +class EncodecModel(CompressionModel): + """Encodec model operating on the raw waveform. + + Args: + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. + quantizer (qt.BaseQuantizer): Quantizer network. + frame_rate (int): Frame rate for the latent representation. 
+ sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. + causal (bool): Whether to use a causal version of the model. + renormalize (bool): Whether to renormalize the audio before running the model. + """ + # we need assignement to override the property in the abstract class, + # I couldn't find a better way... + frame_rate: int = 0 + sample_rate: int = 0 + channels: int = 0 + + def __init__(self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: qt.BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.quantizer = quantizer + self.frame_rate = frame_rate + self.sample_rate = sample_rate + self.channels = channels + self.renormalize = renormalize + self.causal = causal + if self.causal: + # we force disabling here to avoid handling linear overlap of segments + # as supported in original EnCodec codebase. + assert not self.renormalize, 'Causal model does not support renormalize' + + @property + def total_codebooks(self): + """Total number of quantizer codebooks available. + """ + return self.quantizer.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer. + """ + return self.quantizer.num_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + self.quantizer.set_num_codebooks(n) + + @property + def cardinality(self): + """Cardinality of each codebook. + """ + return self.quantizer.bins + + def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + scale: tp.Optional[torch.Tensor] + if self.renormalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + return x, scale + + def postprocess(self, + x: torch.Tensor, + scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + if scale is not None: + assert self.renormalize + x = x * scale.view(-1, 1, 1) + return x + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + assert x.dim() == 3 + length = x.shape[-1] + x, scale = self.preprocess(x) + + emb = self.encoder(x) + q_res = self.quantizer(emb, self.frame_rate) + out = self.decoder(q_res.x) + + # remove extra padding added by the encoder and decoder + assert out.shape[-1] >= length, (out.shape[-1], length) + out = out[..., :length] + + q_res.x = self.postprocess(out, scale) + + return q_res + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Encode the given input tensor to quantized representation along with scale parameter. + + Args: + x (torch.Tensor): Float tensor of shape [B, C, T] + + Returns: + codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: + codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. + scale a float tensor containing the scale for audio renormalizealization. + """ + assert x.dim() == 3 + x, scale = self.preprocess(x) + emb = self.encoder(x) + codes = self.quantizer.encode(emb) + return codes, scale + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """Decode the given codes to a reconstructed representation, using the scale to perform + audio denormalization if needed. 
+ + Args: + codes (torch.Tensor): Int tensor of shape [B, K, T] + scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. + + Returns: + out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. + """ + emb = self.quantizer.decode(codes) + out = self.decoder(emb) + out = self.postprocess(out, scale) + # out contains extra padding added by the encoder and decoder + return out + + +class FlattenedCompressionModel(CompressionModel): + """Wraps a CompressionModel and flatten its codebooks, e.g. + instead of returning [B, K, T], return [B, S, T * (K // S)] with + S the number of codebooks per step, and `K // S` the number of 'virtual steps' + for each real time step. + + Args: + model (CompressionModel): compression model to wrap. + codebooks_per_step (int): number of codebooks to keep per step, + this must divide the number of codebooks provided by the wrapped model. + extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, + if each codebook has a cardinality N, then the first codebook will + use the range [0, N - 1], and the second [N, 2 N - 1] etc. + On decoding, this can lead to potentially invalid sequences. + Any invalid entry will be silently remapped to the proper range + with a modulo. + """ + def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, + extend_cardinality: bool = True): + super().__init__() + self.model = model + self.codebooks_per_step = codebooks_per_step + self.extend_cardinality = extend_cardinality + + @property + def total_codebooks(self): + return self.model.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer. + + ..Warning:: this reports the number of codebooks after the flattening + of the codebooks! + """ + assert self.model.num_codebooks % self.codebooks_per_step == 0 + return self.codebooks_per_step + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + + ..Warning:: this sets the number of codebooks **before** the flattening + of the codebooks. + """ + assert n % self.codebooks_per_step == 0 + self.model.set_num_codebooks(n) + + @property + def num_virtual_steps(self) -> int: + """Return the number of virtual steps, e.g. one real step + will be split into that many steps. + """ + return self.model.num_codebooks // self.codebooks_per_step + + @property + def frame_rate(self) -> int: + return self.model.frame_rate * self.num_virtual_steps + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def channels(self) -> int: + return self.model.channels + + @property + def cardinality(self): + """Cardinality of each codebook. 
+ """ + if self.extend_cardinality: + return self.model.cardinality * self.num_virtual_steps + else: + return self.model.cardinality + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + raise NotImplementedError("Not supported, use encode and decode.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + indices, scales = self.model.encode(x) + B, K, T = indices.shape + indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) + if self.extend_cardinality: + for virtual_step in range(1, self.num_virtual_steps): + indices[..., virtual_step] += self.model.cardinality * virtual_step + indices = rearrange(indices, 'b k t v -> b k (t v)') + return (indices, scales) + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + B, K, T = codes.shape + assert T % self.num_virtual_steps == 0 + codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) + # We silently ignore potential errors from the LM when + # using extend_cardinality. + codes = codes % self.model.cardinality + return self.model.decode(codes, scale) diff --git a/melodytalk/audiocraft/models/lm.py b/melodytalk/audiocraft/models/lm.py new file mode 100644 index 0000000..c8aad8f --- /dev/null +++ b/melodytalk/audiocraft/models/lm.py @@ -0,0 +1,527 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from functools import partial +import logging +import math +import typing as tp + +import torch +from torch import nn + +from ..utils import utils +from ..modules.streaming import StreamingModule, State +from ..modules.transformer import StreamingTransformer, create_norm_fn +from ..modules.conditioners import ( + ConditionFuser, + ClassifierFreeGuidanceDropout, + AttributeDropout, + ConditioningProvider, + ConditioningAttributes, + ConditionType, +) +from ..modules.codebooks_patterns import CodebooksPatternProvider +from ..modules.activations import get_activation_fn + + +logger = logging.getLogger(__name__) +ConditionTensors = tp.Dict[str, ConditionType] +CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] + + +def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): + """LM layer initialization. + Inspired from xlformers: https://github.com/fairinternal/xlformers + + Args: + method (str): Method name for init function. Valid options are: + 'gaussian', 'uniform'. + input_dim (int): Input dimension of the initialized module. + init_depth (Optional[int]): Optional init depth value used to rescale + the standard deviation if defined. + """ + # Compute std + std = 1 / math.sqrt(input_dim) + # Rescale with depth + if init_depth is not None: + std = std / math.sqrt(2 * init_depth) + + if method == 'gaussian': + return partial( + torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std + ) + elif method == 'uniform': + bound = math.sqrt(3) * std # ensure the standard deviation is `std` + return partial(torch.nn.init.uniform_, a=-bound, b=bound) + else: + raise ValueError("Unsupported layer initialization method") + + +def init_layer(m: nn.Module, + method: str, + init_depth: tp.Optional[int] = None, + zero_bias_init: bool = False): + """Wrapper around ``get_init_fn`` for proper initialization of LM modules. + + Args: + m (nn.Module): Module to initialize. 
+ method (str): Method name for the init function. + init_depth (Optional[int]): Optional init depth value used to rescale + the standard deviation if defined. + zero_bias_init (bool): Whether to initialize the bias to 0 or not. + """ + if isinstance(m, nn.Linear): + init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + if zero_bias_init and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Embedding): + init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + + +class ScaledEmbedding(nn.Embedding): + """Boost learning rate for embeddings (with `scale`). + """ + def __init__(self, *args, lr=None, **kwargs): + super().__init__(*args, **kwargs) + self.lr = lr + + def make_optim_group(self): + group = {"params": list(self.parameters())} + if self.lr is not None: + group["lr"] = self.lr + return group + + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + + +class LMModel(StreamingModule): + """Transformer-based language model on multiple streams of codes. + + Args: + pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. + condition_provider (MusicConditioningProvider): Conditioning provider from metadata. + fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. + n_q (int): Number of parallel streams to model. + card (int): Cardinality, vocabulary size. + dim (int): Dimension of the transformer encoder. + num_heads (int): Number of heads for the transformer encoder. + hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. + norm (str): Normalization method. + norm_first (bool): Use pre-norm instead of post-norm. + emb_lr (Optional[float]): Embedding-specific learning rate. + bias_proj (bool): Use bias for output projections. + weight_init (Optional[str]): Method for weight initialization. + depthwise_init (Optional[str]): Method for depthwise weight initialization. + zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. + cfg_dropout (float): Classifier-free guidance dropout. + cfg_coef (float): Classifier-free guidance coefficient. + attribute_dropout (dict): Attribute dropout probabilities. + two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. + **kwargs: Additional parameters for the transformer encoder. 
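+
+    Note: token embeddings are built over ``card + 1`` entries; the extra index
+    ``card`` is reserved as the special token used for masked positions by the
+    codebook interleaving patterns (see ``special_token_id``).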
+    """
+    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg_coef = cfg_coef
+        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+        self.att_dropout = AttributeDropout(p=attribute_dropout)
+        self.condition_provider = condition_provider
+        self.fuser = fuser
+        self.card = card
+        embed_dim = self.card + 1
+        self.n_q = n_q
+        self.dim = dim
+        self.pattern_provider = pattern_provider
+        self.two_step_cfg = two_step_cfg
+        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+        if 'activation' in kwargs:
+            kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        self.transformer = StreamingTransformer(
+            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+            norm=norm, norm_first=norm_first, **kwargs)
+        self.out_norm: tp.Optional[nn.Module] = None
+        if norm_first:
+            self.out_norm = create_norm_fn(norm, dim)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+        self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        self._fsdp: tp.Optional[nn.Module]
+        self.__dict__['_fsdp'] = None
+
+    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+        """Initialization of the transformer module weights.
+
+        Args:
+            weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
+            depthwise_init (Optional[str]): Depthwise initialization strategy. The following options are valid:
+                'current' where the depth corresponds to the current layer index or 'global' where the total number
+                of layers is used as depth. If not set, no depthwise initialization strategy is used.
+            zero_bias_init (bool): Whether to initialize the bias to zero or not.
+        """
+        assert depthwise_init is None or depthwise_init in ['current', 'global']
+        assert depthwise_init is None or weight_init is not None, \
+            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+ assert not zero_bias_init or weight_init is not None, \ + "If 'zero_bias_init', a 'weight_init' method should be provided" + + if weight_init is None: + return + + for emb_layer in self.emb: + init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + for layer_idx, tr_layer in enumerate(self.transformer.layers): + depth = None + if depthwise_init == 'current': + depth = layer_idx + 1 + elif depthwise_init == 'global': + depth = len(self.transformer.layers) + init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) + tr_layer.apply(init_fn) + + for linear in self.linears: + init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + @property + def special_token_id(self) -> int: + return self.card + + @property + def num_codebooks(self) -> int: + return self.n_q + + def forward(self, sequence: torch.Tensor, + conditions: tp.List[ConditioningAttributes], + condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor: + """Apply language model on sequence and conditions. + Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and + S the sequence steps, return the logits with shape [B, card, K, S]. + + Args: + indices (torch.Tensor): indices of the codes to model. + conditions (list[ConditioningAttributes]): conditionings to use when modeling + the given codes. Note that when evaluating multiple time with the same conditioning + you should pre-compute those and pass them as `condition_tensors`. + condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + tensors, see `conditions`. + Returns: + torch.Tensor: Logits. + """ + B, K, S = sequence.shape + assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks' + input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) + if condition_tensors is None: + assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." + # apply dropout modules + conditions = self.cfg_dropout(conditions) + conditions = self.att_dropout(conditions) + tokenized = self.condition_provider.tokenize(conditions) + # encode conditions and fuse, both have a streaming cache to not recompute when generating. + condition_tensors = self.condition_provider(tokenized) + else: + assert not conditions, "Shouldn't pass both conditions and condition_tensors." + + input_, cross_attention_input = self.fuser(input_, condition_tensors) + + out = self.transformer(input_, cross_attention_src=cross_attention_input) + if self.out_norm: + out = self.out_norm(out) + logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card] + + # remove the prefix from the model outputs + if len(self.fuser.fuse2cond['prepend']) > 0: + logits = logits[:, :, -S:] + + return logits # [B, K, S, card] + + def compute_predictions( + self, codes: torch.Tensor, + conditions: tp.List[ConditioningAttributes], + condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput: + """Given an input tensor of codes [B, K, T] and list of conditions, runs the model + forward using the specified codes interleaving pattern. + + Args: + codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, + K the number of codebooks and T the number of timesteps. + conditions (list[ConditioningAttributes]): conditionings to use when modeling + the given codes. 
Note that when evaluating multiple time with the same conditioning + you should pre-compute those and pass them as `condition_tensors`. + condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + tensors, see `conditions`. + Returns: + LMOutput: Language model outputs + logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, + i.e. the first item corresponds to logits to predict the first code, meaning that + no additional shifting of codes and logits is required. + mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. + Given the specified interleaving strategies, parts of the logits and codes should + not be considered as valid predictions because of invalid context. + """ + B, K, T = codes.shape + codes = codes.contiguous() + # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens + pattern = self.pattern_provider.get_pattern(T) + sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( + codes, self.special_token_id, keep_only_valid_steps=True + ) + # apply model on pattern sequence + model = self if self._fsdp is None else self._fsdp + logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card] + # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] + # and provide the corresponding mask over invalid positions of tokens + logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] + # note: we use nans as special token to make it obvious if we feed unexpected logits + logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=True + ) + logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] + logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] + return LMOutput(logits, logits_mask) + + def _sample_next_token(self, + sequence: torch.Tensor, + cfg_conditions: CFGConditions, + unconditional_state: State, + use_sampling: bool = False, + temp: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None) -> torch.Tensor: + """Sample next token from the model given a sequence and a set of conditions. The model supports + multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). + + Args: + sequence (torch.Tensor): Current sequence of shape [B, K, S] + with K corresponding to the number of codebooks and S the number of sequence steps. + S = 1 in streaming mode, except for the first step that contains a bigger prompt. + condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used, + should be twice the batch size, being the concatenation of the conditions + null conditions. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + cfg_coef (float): classifier free guidance coefficient + Returns: + next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. 
+ """ + B = sequence.shape[0] + cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef + model = self if self._fsdp is None else self._fsdp + if self.two_step_cfg and cfg_conditions != {}: + assert isinstance(cfg_conditions, tuple) + condition_tensors, null_condition_tensors = cfg_conditions + cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) + state = self.get_streaming_state() + self.set_streaming_state(unconditional_state) + uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors) + unconditional_state.update(self.get_streaming_state()) + self.set_streaming_state(state) + logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef + else: + assert isinstance(cfg_conditions, dict) + condition_tensors = cfg_conditions + if condition_tensors: + # Preparing for CFG, predicting both conditional and unconditional logits. + sequence = torch.cat([sequence, sequence], dim=0) + all_logits = model( + sequence, + conditions=[], condition_tensors=condition_tensors) + if condition_tensors: + cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef + else: + logits = all_logits + + logits = logits.permute(0, 1, 3, 2) # [B, K, card, T] + logits = logits[..., -1] # [B x K x card] + + # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. + if use_sampling and temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if top_p > 0.0: + next_token = utils.sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = utils.sample_top_k(probs, k=top_k) + else: + next_token = utils.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) + + return next_token + + @torch.no_grad() + def generate(self, + prompt: tp.Optional[torch.Tensor] = None, + conditions: tp.List[ConditioningAttributes] = [], + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + use_sampling: bool = True, + temp: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + two_step_cfg: bool = False, + remove_prompts: bool = False, + check: bool = False, + callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: + """Generate tokens sampling from the model given a prompt or unconditionally. Generation can + be perform in a greedy fashion or using sampling with top K and top P strategies. + + Args: + prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T]. + conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None. + num_samples (int or None): Number of samples to generate when no prompt and no conditions are given. + max_gen_len (int): Maximum generation length. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + remove_prompts (bool): Whether to remove prompts from generation or not. + Returns: + torch.Tensor: Generated tokens. + """ + assert not self.training, "generation shouldn't be used in training mode." + first_param = next(iter(self.parameters())) + device = first_param.device + + # Checking all input shapes are consistents. 
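+        # The batch size is inferred with the following priority: the explicit
+        # `num_samples` argument, then the prompt batch size, then the number of
+        # conditions; it defaults to 1 if none of these is provided.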
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create a set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
+        # the reason for that is two-fold:
+        # 1. it is about x2 faster than doing 2 forward passes
+        # 2. it avoids the streaming API treating the 2 passes as part of different time steps
+        # We also support doing two different passes, in particular to ensure that
+        # the padding structure is exactly the same between train and test.
+        # With a batch size of 1, this can be slower though.
+        cfg_conditions: CFGConditions
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if conditions:
+            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            if two_step_cfg:
+                cfg_conditions = (
+                    self.condition_provider(self.condition_provider.tokenize(conditions)),
+                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+                )
+            else:
+                conditions = conditions + null_conditions
+                tokenized = self.condition_provider.tokenize(conditions)
+                cfg_conditions = self.condition_provider(tokenized)
+        else:
+            cfg_conditions = {}
+
+        if prompt is None:
+            assert num_samples > 0
+            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+        B, K, T = prompt.shape
+        start_offset = T
+        assert start_offset < max_gen_len
+
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        # this token is used as default value for codes that are not generated yet
+        unknown_token = -1
+
+        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+        # filling the gen_codes with the prompt if needed
+        gen_codes[..., :start_offset] = prompt
+        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # retrieve the start_offset in the sequence:
+        # it is the first sequence step that contains the `start_offset` timestep
+        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        assert start_offset_sequence is not None
+
+        with self.streaming():
+            unconditional_state = self.get_streaming_state()
+            prev_offset = 0
+            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            for offset in range(start_offset_sequence, gen_sequence_len):
+                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+                curr_sequence = gen_sequence[..., prev_offset:offset]
+                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+                if check:
+                    # check coherence between mask and sequence
+                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                    # should never happen as gen_sequence is filled progressively
+                    assert not (curr_sequence == unknown_token).any()
+                # sample next token from the model, next token shape is [B, K, 1]
+                next_token = self._sample_next_token(
+                    curr_sequence, cfg_conditions, 
unconditional_state, use_sampling, temp, top_k, top_p, + cfg_coef=cfg_coef) + # ensure the tokens that should be masked are properly set to special_token_id + # as the model never output special_token_id + valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) + next_token[~valid_mask] = self.special_token_id + # ensure we don't overwrite prompt tokens, we only write over unknown tokens + # (then mask tokens should be left as is as well, which is correct) + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, gen_sequence[..., offset:offset+1] + ) + prev_offset = offset + if callback is not None: + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + unconditional_state.clear() + + # ensure sequence has been entirely filled + assert not (gen_sequence == unknown_token).any() + # ensure gen_sequence pattern and mask are matching + # which means the gen_sequence is valid according to the pattern + assert ( + gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) + ).all() + # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps + out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + out_start_offset = start_offset if remove_prompts else 0 + out_codes = out_codes[..., out_start_offset:max_gen_len] + + # ensure the returned codes are all valid + assert (out_codes >= 0).all() and (out_codes <= self.card).all() + return out_codes diff --git a/melodytalk/audiocraft/models/loaders.py b/melodytalk/audiocraft/models/loaders.py new file mode 100644 index 0000000..97c662c --- /dev/null +++ b/melodytalk/audiocraft/models/loaders.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped during training. This should be used + to rebuild the object using the audiocraft.models.builders functions, +- 'model_best_state': a readily loadable best state for the model, including + the conditioner. The model obtained from `xp.cfg` should be compatible + with this state dict. In the case of a LM, the encodec model would not be + bundled along but instead provided separately. + +Those functions also support loading from a remote location with the Torch Hub API. +They also support overriding some parameters, in particular the device and dtype +of the returned model. +""" + +from pathlib import Path +from huggingface_hub import hf_hub_download +import typing as tp +import os + +from omegaconf import OmegaConf +import torch + +from . 
import builders + + +HF_MODEL_CHECKPOINTS_MAP = { + "small": "facebook/musicgen-small", + "medium": "facebook/musicgen-medium", + "large": "facebook/musicgen-large", + "melody": "facebook/musicgen-melody", +} + + +def _get_state_dict( + file_or_url_or_id: tp.Union[Path, str], + filename: tp.Optional[str] = None, + device='cpu', + cache_dir: tp.Optional[str] = None, +): + # Return the state dict either from a file or url + file_or_url_or_id = str(file_or_url_or_id) + assert isinstance(file_or_url_or_id, str) + + if os.path.isfile(file_or_url_or_id): + return torch.load(file_or_url_or_id, map_location=device) + + if os.path.isdir(file_or_url_or_id): + file = f"{file_or_url_or_id}/{filename}" + return torch.load(file, map_location=device) + + elif file_or_url_or_id.startswith('https://'): + return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) + + elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP: + assert filename is not None, "filename needs to be defined if using HF checkpoints" + + repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id] + file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir) + return torch.load(file, map_location=device) + + else: + raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.") + + +def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): + pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + model = builders.get_compression_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + return model + + +def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): + pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + if cfg.device == 'cpu': + cfg.dtype = 'float32' + else: + cfg.dtype = 'float16' + model = builders.get_lm_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + model.cfg = cfg + return model diff --git a/melodytalk/audiocraft/models/musicgen.py b/melodytalk/audiocraft/models/musicgen.py new file mode 100644 index 0000000..007dd9e --- /dev/null +++ b/melodytalk/audiocraft/models/musicgen.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main model for using MusicGen. This will combine all the required components +and provide easy access to the generation API. +""" + +import os +import typing as tp + +import torch + +from .encodec import CompressionModel +from .lm import LMModel +from .builders import get_debug_compression_model, get_debug_lm_model +from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP +from ..data.audio_utils import convert_audio +from ..modules.conditioners import ConditioningAttributes, WavCondition +from ..utils.autocast import TorchAutocast + + +MelodyList = tp.List[tp.Optional[torch.Tensor]] +MelodyType = tp.Union[torch.Tensor, MelodyList] + + +class MusicGen: + """MusicGen main model with convenient generation API. + + Args: + name (str): name of the model. 
+ compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + """ + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, + max_duration: float = 30): + self.name = name + self.compression_model = compression_model + self.lm = lm + self.max_duration = max_duration + self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} + self.set_generation_params(duration=15) # 15 seconds by default + self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None + if self.device.type == 'cpu': + self.autocast = TorchAutocast(enabled=False) + else: + self.autocast = TorchAutocast( + enabled=True, device_type=self.device.type, dtype=torch.float16) + + @property + def frame_rate(self) -> int: + """Roughly the number of AR steps per seconds.""" + return self.compression_model.frame_rate + + @property + def sample_rate(self) -> int: + """Sample rate of the generated audio.""" + return self.compression_model.sample_rate + + @property + def audio_channels(self) -> int: + """Audio channels of the generated audio.""" + return self.compression_model.channels + + @staticmethod + def get_pretrained(name: str = 'melody', device=None): + """Return pretrained model, we provide four models: + - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small + - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium + - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody + - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large + """ + + if device is None: + if torch.cuda.device_count(): + device = 'cuda' + else: + device = 'cpu' + + if name == 'debug': + # used only for unit tests + compression_model = get_debug_compression_model(device) + lm = get_debug_lm_model(device) + return MusicGen(name, compression_model, lm) + + if name not in HF_MODEL_CHECKPOINTS_MAP: + if not os.path.isfile(name) and not os.path.isdir(name): + raise ValueError( + f"{name} is not a valid checkpoint name. " + f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}" + ) + + cache_dir = os.environ.get('MUSICGEN_ROOT', None) + compression_model = load_compression_model(name, device=device, cache_dir=cache_dir) + lm = load_lm_model(name, device=device, cache_dir=cache_dir) + if name == 'melody': + lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True + + return MusicGen(name, compression_model, lm) + + def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, + top_p: float = 0.0, temperature: float = 1.0, + duration: float = 30.0, cfg_coef: float = 3.0, + two_step_cfg: bool = False, extend_stride: float = 18): + """Set the generation parameters for MusicGen. + + Args: + use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. + top_k (int, optional): top_k used for sampling. Defaults to 250. + top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. + temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. + duration (float, optional): Duration of the generated waveform. Defaults to 30.0. + cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. 
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, + instead of batching together the two. This has some impact on how things + are padded but seems to have little impact in practice. + extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much + should we extend the audio each time. Larger values will mean less context is + preserved, and shorter value will require extra computations. + """ + assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." + self.extend_stride = extend_stride + self.duration = duration + self.generation_params = { + 'use_sampling': use_sampling, + 'temp': temperature, + 'top_k': top_k, + 'top_p': top_p, + 'cfg_coef': cfg_coef, + 'two_step_cfg': two_step_cfg, + } + + def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): + """Override the default progress callback.""" + self._progress_callback = progress_callback + + def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor: + """Generate samples in an unconditional manner. + + Args: + num_samples (int): Number of samples to be generated. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + descriptions: tp.List[tp.Optional[str]] = [None] * num_samples + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + return self._generate_tokens(attributes, prompt_tokens, progress) + + def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: + """Generate samples conditioned on text. + + Args: + descriptions (tp.List[str]): A list of strings used as text conditioning. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + assert prompt_tokens is None + return self._generate_tokens(attributes, prompt_tokens, progress) + + def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, + melody_sample_rate: int, progress: bool = False) -> torch.Tensor: + """Generate samples conditioned on text and melody. + + Args: + descriptions (tp.List[str]): A list of strings used as text conditioning. + melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as + melody conditioning. Should have shape [B, C, T] with B matching the description length, + C=1 or 2. It can be [C, T] if there is a single description. It can also be + a list of [C, T] tensors. + melody_sample_rate: (int): Sample rate of the melody waveforms. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if isinstance(melody_wavs, torch.Tensor): + if melody_wavs.dim() == 2: + melody_wavs = melody_wavs[None] + if melody_wavs.dim() != 3: + raise ValueError("Melody wavs should have a shape [B, C, T].") + melody_wavs = list(melody_wavs) + else: + for melody in melody_wavs: + if melody is not None: + assert melody.dim() == 2, "One melody in the list has the wrong number of dims." 
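+
+        # Resample every provided melody to the model's own sample rate and channel
+        # count before it is turned into a `self_wav` conditioning attribute.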
+ + melody_wavs = [ + convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels) + if wav is not None else None + for wav in melody_wavs] + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, + melody_wavs=melody_wavs) + assert prompt_tokens is None + return self._generate_tokens(attributes, prompt_tokens, progress) + + def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, + descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, + progress: bool = False) -> torch.Tensor: + """Generate samples conditioned on audio prompts. + + Args: + prompt (torch.Tensor): A batch of waveforms used for continuation. + Prompt should be [B, C, T], or [C, T] if only one sample is generated. + prompt_sample_rate (int): Sampling rate of the given audio waveforms. + descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if prompt.dim() == 2: + prompt = prompt[None] + if prompt.dim() != 3: + raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") + prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) + if descriptions is None: + descriptions = [None] * len(prompt) + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) + assert prompt_tokens is not None + return self._generate_tokens(attributes, prompt_tokens, progress) + + @torch.no_grad() + def _prepare_tokens_and_attributes( + self, + descriptions: tp.Sequence[tp.Optional[str]], + prompt: tp.Optional[torch.Tensor], + melody_wavs: tp.Optional[MelodyList] = None, + ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: + """Prepare model inputs. + + Args: + descriptions (tp.List[str]): A list of strings used as text conditioning. + prompt (torch.Tensor): A batch of waveforms used for continuation. + melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms + used as melody conditioning. Defaults to None. + """ + attributes = [ + ConditioningAttributes(text={'description': description}) + for description in descriptions] + + if melody_wavs is None: + for attr in attributes: + attr.wav['self_wav'] = WavCondition( + torch.zeros((1, 1), device=self.device), + torch.tensor([0], device=self.device), + path='null_wav') # type: ignore + else: + if self.name != "melody": + raise RuntimeError("This model doesn't support melody conditioning. " + "Use the `melody` model.") + assert len(melody_wavs) == len(descriptions), \ + f"number of melody wavs must match number of descriptions! " \ + f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}" + for attr, melody in zip(attributes, melody_wavs): + if melody is None: + attr.wav['self_wav'] = WavCondition( + torch.zeros((1, 1), device=self.device), + torch.tensor([0], device=self.device), + path='null_wav') # type: ignore + else: + attr.wav['self_wav'] = WavCondition( + melody.to(device=self.device), + torch.tensor([melody.shape[-1]], device=self.device)) + + if prompt is not None: + if descriptions is not None: + assert len(descriptions) == len(prompt), "Prompt and nb. 
descriptions doesn't match" + prompt = prompt.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt) + assert scale is None + else: + prompt_tokens = None + return attributes, prompt_tokens + + def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], + prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: + """Generate discrete audio tokens given audio prompt and/or conditions. + + Args: + attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody). + prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + Returns: + torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. + """ + total_gen_len = int(self.duration * self.frame_rate) + max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) + current_gen_offset: int = 0 + + def _progress_callback(generated_tokens: int, tokens_to_generate: int): + generated_tokens += current_gen_offset + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. + self._progress_callback(generated_tokens, total_gen_len) + else: + print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') + + if prompt_tokens is not None: + assert max_prompt_len >= prompt_tokens.shape[-1], \ + "Prompt is longer than audio to generate" + + callback = None + if progress: + callback = _progress_callback + + if self.duration <= self.max_duration: + # generate by sampling from LM, simple case. + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=total_gen_len, **self.generation_params) + + else: + # now this gets a bit messier, we need to handle prompts, + # melody conditioning etc. + ref_wavs = [attr.wav['self_wav'] for attr in attributes] + all_tokens = [] + if prompt_tokens is None: + prompt_length = 0 + else: + all_tokens.append(prompt_tokens) + prompt_length = prompt_tokens.shape[-1] + + stride_tokens = int(self.frame_rate * self.extend_stride) + + while current_gen_offset + prompt_length < total_gen_len: + time_offset = current_gen_offset / self.frame_rate + chunk_duration = min(self.duration - time_offset, self.max_duration) + max_gen_len = int(chunk_duration * self.frame_rate) + for attr, ref_wav in zip(attributes, ref_wavs): + wav_length = ref_wav.length.item() + if wav_length == 0: + continue + # We will extend the wav periodically if it not long enough. + # we have to do it here rather than in conditioners.py as otherwise + # we wouldn't have the full wav. 
+ initial_position = int(time_offset * self.sample_rate) + wav_target_length = int(self.max_duration * self.sample_rate) + print(initial_position / self.sample_rate, wav_target_length / self.sample_rate) + positions = torch.arange(initial_position, + initial_position + wav_target_length, device=self.device) + attr.wav['self_wav'] = WavCondition( + ref_wav[0][:, positions % wav_length], + torch.full_like(ref_wav[1], wav_target_length)) + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=max_gen_len, **self.generation_params) + if prompt_tokens is None: + all_tokens.append(gen_tokens) + else: + all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) + prompt_tokens = gen_tokens[:, :, stride_tokens:] + prompt_length = prompt_tokens.shape[-1] + current_gen_offset += stride_tokens + + gen_tokens = torch.cat(all_tokens, dim=-1) + + # generate audio + assert gen_tokens.dim() == 3 + with torch.no_grad(): + gen_audio = self.compression_model.decode(gen_tokens, None) + return gen_audio diff --git a/melodytalk/audiocraft/modules/__init__.py b/melodytalk/audiocraft/modules/__init__.py new file mode 100644 index 0000000..81ba30f --- /dev/null +++ b/melodytalk/audiocraft/modules/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .conv import ( + NormConv1d, + NormConv2d, + NormConvTranspose1d, + NormConvTranspose2d, + StreamableConv1d, + StreamableConvTranspose1d, + pad_for_conv1d, + pad1d, + unpad1d, +) +from .lstm import StreamableLSTM +from .seanet import SEANetEncoder, SEANetDecoder diff --git a/melodytalk/audiocraft/modules/activations.py b/melodytalk/audiocraft/modules/activations.py new file mode 100644 index 0000000..8bd6f29 --- /dev/null +++ b/melodytalk/audiocraft/modules/activations.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Union, Callable + + +class CustomGLU(nn.Module): + """Custom Gated Linear Unit activation. + Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half + of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation + function (i.e. sigmoid, swish, etc.). + + Args: + activation (nn.Module): The custom activation to apply in the Gated Linear Unit + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + Examples:: + >>> m = CustomGLU(nn.Sigmoid()) + >>> input = torch.randn(4, 2) + >>> output = m(input) + """ + def __init__(self, activation: nn.Module, dim: int = -1): + super(CustomGLU, self).__init__() + self.dim = dim + self.activation = activation + + def forward(self, x: Tensor): + assert x.shape[self.dim] % 2 == 0 # M = N / 2 + a, b = torch.chunk(x, 2, dim=self.dim) + return a * self.activation(b) + + +class SwiGLU(CustomGLU): + """SiLU Gated Linear Unit activation. + Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. 
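+    This activation can be selected by passing the string 'swiglu' to
+    ``get_activation_fn`` (see below).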
+ + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(SwiGLU, self).__init__(nn.SiLU(), dim) + + +class GeGLU(CustomGLU): + """GeLU Gated Linear Unit activation. + Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(GeGLU, self).__init__(nn.GELU(), dim) + + +class ReGLU(CustomGLU): + """ReLU Gated Linear Unit activation. + Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(ReGLU, self).__init__(nn.ReLU(), dim) + + +def get_activation_fn( + activation: Union[str, Callable[[Tensor], Tensor]] +) -> Union[str, Callable[[Tensor], Tensor]]: + """Helper function to map an activation string to the activation class. + If the supplied activation is not a string that is recognized, the activation is passed back. + + Args: + activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check + """ + if isinstance(activation, str): + if activation == "reglu": + return ReGLU() + elif activation == "geglu": + return GeGLU() + elif activation == "swiglu": + return SwiGLU() + return activation diff --git a/melodytalk/audiocraft/modules/codebooks_patterns.py b/melodytalk/audiocraft/modules/codebooks_patterns.py new file mode 100644 index 0000000..c5b35cb --- /dev/null +++ b/melodytalk/audiocraft/modules/codebooks_patterns.py @@ -0,0 +1,539 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. 
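+    (Here B is the batch size, K the number of codebooks, T the number of original
+    timesteps and S the number of interleaved sequence steps.)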
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + assert self.layout[0] == [] + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. 
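+
+        Returns:
+            list of (sequence step, LayoutCoord) tuples matching the query.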
+ """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (Union[torch.device, str]): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. 
+ Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (Union[torch.device, str]): Device for created tensors. + Returns: + torch.Tensor: Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output: + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. 
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total numer of timesteps. 
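+        Returns:
+            Pattern: the pattern to use for sequences of `timesteps` steps.
+
+        Example (illustrative sketch, using the ``DelayedPatternProvider`` defined below)::
+            >>> import torch
+            >>> provider = DelayedPatternProvider(n_q=3)
+            >>> pattern = provider.get_pattern(timesteps=4)
+            >>> z = torch.arange(3 * 4).view(1, 3, 4)
+            >>> values, indexes, mask = pattern.build_pattern_sequence(z, special_token=-1)
+            >>> values.shape  # [B, K, S] with S = 1 + T + max_delay for the default delays
+            torch.Size([1, 3, 7])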
+ """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (Optional[List[int]]): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + """ + def __init__(self, n_q: int): + super().__init__(n_q, [0] * n_q) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. 
Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. + delays (Optional[List[int]]): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total numer of timesteps. 
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class VALLEPattern(CodebooksPatternProvider):
+    """Almost VALL-E style pattern. We further allow some delays for the
+    codebooks other than the first one.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (Optional[List[int]]): Delay for each of the codebooks.
+            If delays are not defined, no delay is applied (all delays default to 0).
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
diff --git a/melodytalk/audiocraft/modules/conditioners.py b/melodytalk/audiocraft/modules/conditioners.py
new file mode 100644
index 0000000..e57fdc9
--- /dev/null
+++ b/melodytalk/audiocraft/modules/conditioners.py
@@ -0,0 +1,991 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
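Editor's note: before the conditioning module below, a minimal usage sketch of the pattern providers defined above. This is an illustration only, not part of the diff; the import path, the codebook cardinality, and the tensor sizes are assumptions.

    import torch
    from melodytalk.audiocraft.modules.codebooks_patterns import DelayedPatternProvider

    B, K, T = 2, 4, 10                         # batch, codebooks, timesteps (made-up sizes)
    codes = torch.randint(0, 2048, (B, K, T))  # stand-in for EnCodec tokens
    special_token = 2048                       # assumed to lie outside the codebook range

    provider = DelayedPatternProvider(n_q=K)
    pattern = provider.get_pattern(T)
    # Interleave: codebook k is delayed by k steps, gaps are filled with the special token.
    seq, idx, mask = pattern.build_pattern_sequence(codes, special_token)
    # Revert to the [B, K, T] layout; coordinates never produced by the pattern keep the special token.
    reverted, rev_idx, rev_mask = pattern.revert_pattern_sequence(seq, special_token)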
+ +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from itertools import chain +import logging +import math +import random +import re +import typing as tp +import warnings + +from einops import rearrange +from num2words import num2words +import spacy +from transformers import T5EncoderModel, T5Tokenizer # type: ignore +import torchaudio +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from .streaming import StreamingModule +from .transformer import create_sin_embedding +from ..data.audio_dataset import SegmentInfo +from ..utils.autocast import TorchAutocast +from ..utils.utils import hash_trick, length_to_mask, collate + + +logger = logging.getLogger(__name__) +TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) +ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask + + +class WavCondition(tp.NamedTuple): + wav: Tensor + length: Tensor + path: tp.List[tp.Optional[str]] = [] + + +def nullify_condition(condition: ConditionType, dim: int = 1): + """This function transforms an input condition to a null condition. + The way it is done by converting it to a single zero vector similarly + to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. + + Args: + condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor]) + dim (int): the dimension that will be truncated (should be the time dimension) + WARNING!: dim should not be the batch dimension! + Returns: + ConditionType: a tuple of null condition and mask + """ + assert dim != 0, "dim cannot be the batch dimension!" + assert type(condition) == tuple and \ + type(condition[0]) == Tensor and \ + type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!" + cond, mask = condition + B = cond.shape[0] + last_dim = cond.dim() - 1 + out = cond.transpose(dim, last_dim) + out = 0. * out[..., :1] + out = out.transpose(dim, last_dim) + mask = torch.zeros((B, 1), device=out.device).int() + assert cond.dim() == out.dim() + return out, mask + + +def nullify_wav(wav: Tensor) -> WavCondition: + """Create a nullified WavCondition from a wav tensor with appropriate shape. + + Args: + wav (Tensor): tensor of shape [B, T] + Returns: + WavCondition: wav condition with nullified wav. + """ + null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1) + return WavCondition( + wav=null_wav, + length=torch.tensor([0] * wav.shape[0], device=wav.device), + path=['null_wav'] * wav.shape[0] + ) + + +@dataclass +class ConditioningAttributes: + text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) + wav: tp.Dict[str, WavCondition] = field(default_factory=dict) + + def __getitem__(self, item): + return getattr(self, item) + + @property + def text_attributes(self): + return self.text.keys() + + @property + def wav_attributes(self): + return self.wav.keys() + + @property + def attributes(self): + return {"text": self.text_attributes, "wav": self.wav_attributes} + + def to_flat_dict(self): + return { + **{f"text.{k}": v for k, v in self.text.items()}, + **{f"wav.{k}": v for k, v in self.wav.items()}, + } + + @classmethod + def from_flat_dict(cls, x): + out = cls() + for k, v in x.items(): + kind, att = k.split(".") + out[kind][att] = v + return out + + +class SegmentWithAttributes(SegmentInfo): + """Base class for all dataclasses that are used for conditioning. 
+ All child classes should implement `to_condition_attributes` that converts + the existing attributes to a dataclass of type ConditioningAttributes. + """ + def to_condition_attributes(self) -> ConditioningAttributes: + raise NotImplementedError() + + +class Tokenizer: + """Base class for all tokenizers + (in case we want to introduce more advances tokenizers in the future). + """ + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + raise NotImplementedError() + + +class WhiteSpaceTokenizer(Tokenizer): + """This tokenizer should be used for natural language descriptions. + For example: + ["he didn't, know he's going home.", 'shorter sentence'] => + [[78, 62, 31, 4, 78, 25, 19, 34], + [59, 77, 0, 0, 0, 0, 0, 0]] + """ + PUNCTUATIONS = "?:!.,;" + + def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", + lemma: bool = True, stopwords: bool = True) -> None: + self.n_bins = n_bins + self.pad_idx = pad_idx + self.lemma = lemma + self.stopwords = stopwords + try: + self.nlp = spacy.load(language) + except IOError: + spacy.cli.download(language) # type: ignore + self.nlp = spacy.load(language) + + @tp.no_type_check + def __call__( + self, + texts: tp.List[tp.Optional[str]], + return_text: bool = False + ) -> tp.Tuple[Tensor, Tensor]: + """Take a list of strings and convert them to a tensor of indices. + + Args: + texts (tp.List[str]): List of strings. + return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. + Returns: + tp.Tuple[Tensor, Tensor]: + - Indices of words in the LUT. + - And a mask indicating where the padding tokens are + """ + output, lengths = [], [] + texts = deepcopy(texts) + for i, text in enumerate(texts): + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(Tensor([self.pad_idx])) + lengths.append(0) + continue + + # convert numbers to words + text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore + # normalize text + text = self.nlp(text) # type: ignore + # remove stopwords + if self.stopwords: + text = [w for w in text if not w.is_stop] # type: ignore + # remove punctuations + text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore + # lemmatize if needed + text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore + + texts[i] = " ".join(text) + lengths.append(len(text)) + # convert to tensor + tokens = Tensor([hash_trick(w, self.n_bins) for w in text]) + output.append(tokens) + + mask = length_to_mask(torch.IntTensor(lengths)).int() + padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t() + if return_text: + return padded_output, mask, texts # type: ignore + return padded_output, mask + + +class NoopTokenizer(Tokenizer): + """This tokenizer should be used for global conditioners such as: artist, genre, key, etc. + The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split + strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will + split it to ["Jeff", "Buckley"] and return an index per word. 
+ + For example: + ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] + ["Metal", "Rock", "Classical"] => [0, 223, 51] + """ + def __init__(self, n_bins: int, pad_idx: int = 0): + self.n_bins = n_bins + self.pad_idx = pad_idx + + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + output, lengths = [], [] + for text in texts: + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(self.pad_idx) + lengths.append(0) + else: + output.append(hash_trick(text, self.n_bins)) + lengths.append(1) + + tokens = torch.LongTensor(output).unsqueeze(1) + mask = length_to_mask(torch.IntTensor(lengths)).int() + return tokens, mask + + +class BaseConditioner(nn.Module): + """Base model for all conditioner modules. We allow the output dim to be different + than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large; + 2) make all condition dims consistent. + + Args: + dim (int): Hidden dim of the model (text-encoder/LUT). + output_dim (int): Output dim of the conditioner. + """ + def __init__(self, dim, output_dim): + super().__init__() + self.dim = dim + self.output_dim = output_dim + self.output_proj = nn.Linear(dim, output_dim) + + def tokenize(self, *args, **kwargs) -> tp.Any: + """Should be any part of the processing that will lead to a synchronization + point, e.g. BPE tokenization with transfer to the GPU. + + The returned value will be saved and return later when calling forward(). + """ + raise NotImplementedError() + + def forward(self, inputs: tp.Any) -> ConditionType: + """Gets input that should be used as conditioning (e.g, genre, description or a waveform). + Outputs a ConditionType, after the input data was embedded as a dense vector. + + Returns: + ConditionType: + - A tensor of size [B, T, D] where B is the batch size, T is the length of the + output embedding and D is the dimension of the embedding. + - And a mask indicating where the padding tokens. + """ + raise NotImplementedError() + + +class TextConditioner(BaseConditioner): + ... + + +class LUTConditioner(TextConditioner): + """Lookup table TextConditioner. + + Args: + n_bins (int): Number of bins. + dim (int): Hidden dim of the model (text-encoder/LUT). + output_dim (int): Output dim of the conditioner. + tokenizer (str): Name of the tokenizer. + pad_idx (int, optional): Index for padding token. Defaults to 0. + """ + def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): + super().__init__(dim, output_dim) + self.embed = nn.Embedding(n_bins, dim) + self.tokenizer: Tokenizer + if tokenizer == "whitespace": + self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) + elif tokenizer == "noop": + self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) + else: + raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + device = self.embed.weight.device + tokens, mask = self.tokenizer(x) + tokens, mask = tokens.to(device), mask.to(device) + return tokens, mask + + def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: + tokens, mask = inputs + embeds = self.embed(tokens) + embeds = self.output_proj(embeds) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class T5Conditioner(TextConditioner): + """T5-based TextConditioner. + + Args: + name (str): Name of the T5 model. + output_dim (int): Output dim of the conditioner. 
+        finetune (bool): Whether to fine-tune T5 at train time.
+        device (str): Device for T5 Conditioner.
+        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+        word_dropout (float, optional): Word dropout probability.
+        normalize_text (bool, optional): Whether to apply text normalization.
+    """
+    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+              "google/flan-t5-xl", "google/flan-t5-xxl"]
+    MODELS_DIMS = {
+        "t5-small": 512,
+        "t5-base": 768,
+        "t5-large": 1024,
+        "t5-3b": 1024,
+        "t5-11b": 1024,
+        "google/flan-t5-small": 512,
+        "google/flan-t5-base": 768,
+        "google/flan-t5-large": 1024,
+        "google/flan-t5-xl": 2048,
+        "google/flan-t5-xxl": 4096,
+    }
+
+    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+                 normalize_text: bool = False):
+        assert name in self.MODELS, f"unrecognized t5 model name (should be in {self.MODELS})"
+        super().__init__(self.MODELS_DIMS[name], output_dim)
+        self.device = device
+        self.name = name
+        self.finetune = finetune
+        self.word_dropout = word_dropout
+
+        if autocast_dtype is None or self.device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            if self.device != 'cpu':
+                logger.warning("T5 has no autocast, this might lead to NaN")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+        # thanks https://gist.github.com/simon-weber/7853144
+        previous_level = logging.root.manager.disable
+        logging.disable(logging.ERROR)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            finally:
+                logging.disable(previous_level)
+        if finetune:
+            self.t5 = t5
+        else:
+            # this makes sure that the T5 model is not part
+            # of the saved checkpoint
+            self.__dict__["t5"] = t5.to(device)
+
+        self.normalize_text = normalize_text
+        if normalize_text:
+            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+        # if current sample doesn't have a certain attribute, replace with empty string
+        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+        if self.normalize_text:
+            _, _, entries = self.text_normalizer(entries, return_text=True)
+        if self.word_dropout > 0. and self.training:
+            new_entries = []
+            for entry in entries:
+                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+                new_entries.append(" ".join(words))
+            entries = new_entries
+
+        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+        inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
+        mask = inputs["attention_mask"]
+        mask[empty_idx, :] = 0  # zero-out index where the input is non-existent
+        return inputs
+
+    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+        mask = inputs["attention_mask"]
+        with torch.set_grad_enabled(self.finetune), self.autocast:
+            embeds = self.t5(**inputs).last_hidden_state
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+
+class WaveformConditioner(BaseConditioner):
+    """Base class for all conditioners that take a waveform as input.
+    Classes that inherit must implement `_get_wav_embedding` that outputs
+    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+    factor of the embedding model.
+
+    Args:
+        dim (int): The internal representation dimension.
+        output_dim (int): Output dimension.
+        device (tp.Union[torch.device, str]): Device.
+    """
+    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+        super().__init__(dim, output_dim)
+        self.device = device
+
+    def tokenize(self, wav_length: WavCondition) -> WavCondition:
+        wav, length, path = wav_length
+        assert length is not None
+        return WavCondition(wav.to(self.device), length.to(self.device), path)
+
+    def _get_wav_embedding(self, wav: Tensor) -> Tensor:
+        """Gets as input a wav and returns a dense vector of conditions."""
+        raise NotImplementedError()
+
+    def _downsampling_factor(self):
+        """Returns the downsampling factor of the embedding model."""
+        raise NotImplementedError()
+
+    def forward(self, inputs: WavCondition) -> ConditionType:
+        """
+        Args:
+            inputs (WavCondition): Tuple of (waveform, lengths, paths).
+        Returns:
+            ConditionType: Dense vector representing the conditioning along with its mask.
+        """
+        wav, lengths, path = inputs
+        with torch.no_grad():
+            embeds = self._get_wav_embedding(wav)
+        embeds = embeds.to(self.output_proj.weight)
+        embeds = self.output_proj(embeds)
+
+        if lengths is not None:
+            lengths = lengths / self._downsampling_factor()
+            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+        else:
+            mask = torch.ones_like(embeds)
+        embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+        return embeds, mask
+
+
+class ChromaStemConditioner(WaveformConditioner):
+    """Chroma conditioner that uses DEMUCS to first filter out drums and bass, following the
+    insight that drums and bass often dominate the chroma, leaving the chroma with little
+    information about the melody.
+
+    Args:
+        output_dim (int): Output dimension for the conditioner.
+        sample_rate (int): Sample rate for the chroma extractor.
+        n_chroma (int): Number of chroma for the chroma extractor.
+        radix2_exp (int): Radix2 exponent for the chroma extractor.
+        duration (float): Duration used during training. This is later used for correct padding
+            in case we are using chroma as prefix.
+        match_len_on_eval (bool, optional): If True then all chromas are padded to the training
+            duration. Defaults to True.
+        eval_wavs (str, optional): Path to a JSON file listing waveforms; these waveforms are used as
+            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+            Defaults to None.
+        n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for the conditioner.
+        **kwargs: Additional parameters for the chroma extractor.
+    """
+    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+                 n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
+        from demucs import pretrained
+        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+        self.sample_rate = sample_rate
+        self.match_len_on_eval = match_len_on_eval
+        self.duration = duration
+        self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
+        self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
+        self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
+        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
+                                      device=device, **kwargs)
+        self.chroma_len = self._get_chroma_len()
+
+    def _downsampling_factor(self):
+        return self.chroma.winhop
+
+    def _get_chroma_len(self):
+        """Get length of chroma during training."""
+        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
+        dummy_chr = self.chroma(dummy_wav)
+        return dummy_chr.shape[1]
+
+    @torch.no_grad()
+    def _get_filtered_wav(self, wav):
+        from demucs.apply import apply_model
+        from demucs.audio import convert_audio
+        with self.autocast:
+            wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
+            stems = apply_model(self.demucs, wav, device=self.device)
+            stems = stems[:, self.stem_idx]  # extract stem
+            stems = stems.sum(1)  # merge extracted stems
+            stems = stems.mean(1, keepdim=True)  # mono
+            stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
+            return stems
+
+    @torch.no_grad()
+    def _get_wav_embedding(self, wav):
+        # avoid 0-size tensors when we are working with null conds
+        if wav.shape[-1] == 1:
+            return self.chroma(wav)
+        # stems = self._get_filtered_wav(wav)
+        stems = wav
+        chroma = self.chroma(stems)
+
+        if self.match_len_on_eval:
+            b, t, c = chroma.shape
+            if t > self.chroma_len:
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
+            elif t < self.chroma_len:
+                # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+                n_repeat = int(math.ceil(self.chroma_len / t))
+                chroma = chroma.repeat(1, n_repeat, 1)
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was repeated to match length! ({t} -> {chroma.shape[1]})')
+        return chroma
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction class, handles chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate.
+        n_chroma (int): Number of chroma to consider.
+        radix2_exp (int): Radix2 exponent.
+        nfft (tp.Optional[int], optional): Number of FFT bins.
+        winlen (tp.Optional[int], optional): Window length.
+        winhop (tp.Optional[int], optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+ device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu. + """ + def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, + nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, + argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"): + super().__init__() + from librosa import filters + self.device = device + self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32) + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sr = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.window = torch.hann_window(self.winlen).to(device) + self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)).to(device) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=True, + pad=0, normalized=True).to(device) + + def forward(self, wav): + with self.autocast: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}' + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, "b d t -> b t d") + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdims=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma + + +def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str): + """Utility function for nullifying an attribute inside an ConditioningAttributes object. + If the condition is of type "wav", then nullify it using "nullify_condition". + If the condition is of any other type, set its' value to None. + Works in-place. + """ + if condition_type not in ["text", "wav"]: + raise ValueError( + "dropout_condition got an unexpected condition type!" + f" expected 'wav' or 'text' but got '{condition_type}'" + ) + + if condition not in getattr(sample, condition_type): + raise ValueError( + "dropout_condition received an unexpected condition!" + f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" + f"but got '{condition}' of type '{condition_type}'!" + ) + + if condition_type == "wav": + wav, length, path = sample.wav[condition] + sample.wav[condition] = nullify_wav(wav) + else: + sample.text[condition] = None + + return sample + + +class DropoutModule(nn.Module): + """Base class for all dropout modules.""" + def __init__(self, seed: int = 1234): + super().__init__() + self.rng = torch.Generator() + self.rng.manual_seed(seed) + + +class AttributeDropout(DropoutModule): + """Applies dropout with a given probability per attribute. This is different from the behavior of + ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example, + "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout + where if "artist" is dropped "genre" must also be dropped. 
+ + Args: + p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: + ... + "genre": 0.1, + "artist": 0.5, + "wav": 0.25, + ... + active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. + seed (int, optional): Random seed. + """ + def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): + super().__init__(seed=seed) + self.active_on_eval = active_on_eval + # construct dict that return the values from p otherwise 0 + self.p = {} + for condition_type, probs in p.items(): + self.p[condition_type] = defaultdict(lambda: 0, probs) + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (tp.List[ConditioningAttributes]): List of conditions. + Returns: + tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None. + """ + if not self.training and not self.active_on_eval: + return samples + + samples = deepcopy(samples) + + for condition_type, ps in self.p.items(): # for condition types [text, wav] + for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) + if torch.rand(1, generator=self.rng).item() < p: + for sample in samples: + dropout_condition(sample, condition_type, condition) + + return samples + + def __repr__(self): + return f"AttributeDropout({dict(self.p)})" + + +class ClassifierFreeGuidanceDropout(DropoutModule): + """Applies Classifier Free Guidance dropout, meaning all attributes + are dropped with the same probability. + + Args: + p (float): Probability to apply condition dropout during training. + seed (int): Random seed. + """ + def __init__(self, p: float, seed: int = 1234): + super().__init__(seed=seed) + self.p = p + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (tp.List[ConditioningAttributes]): List of conditions. + Returns: + tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None. + """ + if not self.training: + return samples + + # decide on which attributes to drop in a batched fashion + drop = torch.rand(1, generator=self.rng).item() < self.p + if not drop: + return samples + + # nullify conditions of all attributes + samples = deepcopy(samples) + + for condition_type in ["wav", "text"]: + for sample in samples: + for condition in sample.attributes[condition_type]: + dropout_condition(sample, condition_type, condition) + + return samples + + def __repr__(self): + return f"ClassifierFreeGuidanceDropout(p={self.p})" + + +class ConditioningProvider(nn.Module): + """Main class to provide conditions given all the supported conditioners. + + Args: + conditioners (dict): Dictionary of conditioners. + merge_text_conditions_p (float, optional): Probability to merge all text sources + into a single text condition. Defaults to 0. + drop_desc_p (float, optional): Probability to drop the original description + when merging all text sources into a single text condition. Defaults to 0. + device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types. 
+ """ + def __init__( + self, + conditioners: tp.Dict[str, BaseConditioner], + merge_text_conditions_p: float = 0, + drop_desc_p: float = 0, + device: tp.Union[torch.device, str] = "cpu", + ): + super().__init__() + self.device = device + self.merge_text_conditions_p = merge_text_conditions_p + self.drop_desc_p = drop_desc_p + self.conditioners = nn.ModuleDict(conditioners) + + @property + def text_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] + + @property + def wav_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] + + @property + def has_wav_condition(self): + return len(self.wav_conditions) > 0 + + def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: + """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. + This should be called before starting any real GPU work to avoid synchronization points. + This will return a dict matching conditioner names to their arbitrary tokenized representations. + + Args: + inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing + text and wav conditions. + """ + assert all([type(x) == ConditioningAttributes for x in inputs]), \ + "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \ + f" but types were {set([type(x) for x in inputs])}" + + output = {} + text = self._collate_text(inputs) + wavs = self._collate_wavs(inputs) + + assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \ + f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}" + + for attribute, batch in chain(text.items(), wavs.items()): + output[attribute] = self.conditioners[attribute].tokenize(batch) + return output + + def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: + """Compute pairs of `(embedding, mask)` using the configured conditioners + and the tokenized representations. The output is for example: + + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } + + Args: + tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. + """ + output = {} + for attribute, inputs in tokenized.items(): + condition, mask = self.conditioners[attribute](inputs) + output[attribute] = (condition, mask) + return output + + def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: + """Given a list of ConditioningAttributes objects, compile a dictionary where the keys + are the attributes and the values are the aggregated input per attribute. 
+        For example:
+        Input:
+        [
+            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+        ]
+        Output:
+        {
+            "genre": ["Rock", "Hip-hop"],
+            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+        }
+        """
+        batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+
+        def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
+            def is_valid(k, v):
+                k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
+                v_valid = v is not None and isinstance(v, (int, float, str, list))
+                return k_valid and v_valid
+
+            def process_value(v):
+                if isinstance(v, (int, float, str)):
+                    return v
+                if isinstance(v, list):
+                    return ", ".join(v)
+                else:
+                    raise RuntimeError(f"unknown type for text value! ({type(v), v})")
+
+            desc = cond.text['description']
+            meta_data = ""
+            if random.uniform(0, 1) < merge_text_conditions_p:
+                meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
+                random.shuffle(meta_pairs)
+                meta_data = ". ".join(meta_pairs)
+                desc = desc if not random.uniform(0, 1) < drop_desc_p else None
+
+            if desc is None:
+                desc = meta_data if len(meta_data) > 1 else None
+            else:
+                desc = desc.rstrip('.') + ". " + meta_data
+            cond.text['description'] = desc.strip() if desc else None
+
+        if self.training and self.merge_text_conditions_p:
+            for sample in samples:
+                _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
+
+        texts = [x.text for x in samples]
+        for text in texts:
+            for condition in self.text_conditions:
+                batch_per_attribute[condition].append(text[condition])
+
+        return batch_per_attribute
+
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
+        """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attributes.
+
+        *Note*: by the time the samples reach this function, each sample should have some waveform
+        inside the "wav" attribute. It should be either:
+        1. A real waveform
+        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            dict: A dictionary mapping an attribute name to wavs.
+        """
+        wavs = defaultdict(list)
+        lens = defaultdict(list)
+        paths = defaultdict(list)
+        out = {}
+
+        for sample in samples:
+            for attribute in self.wav_conditions:
+                wav, length, path = sample.wav[attribute]
+                wavs[attribute].append(wav.flatten())
+                lens[attribute].append(length)
+                paths[attribute].append(path)
+
+        # stack all wavs to a single tensor
+        for attribute in self.wav_conditions:
+            stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
+                                          torch.cat(lens[attribute]), paths[attribute])  # type: ignore
+
+        return out
+
+
+class ConditionFuser(StreamingModule):
+    """Condition fuser handles the logic to combine the different conditions
+    into the actual model input.
+
+    Args:
+        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
+            each condition. For example:
+            {
+                "prepend": ["description"],
+                "sum": ["genre", "bpm"],
+                "cross": ["description"],
+            }
+        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
+    """
+    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                 cross_attention_pos_emb_scale: float = 1.0):
+        super().__init__()
+        assert all(
+            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
+        self.cross_attention_pos_emb = cross_attention_pos_emb
+        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+        self.cond2fuse: tp.Dict[str, str] = {}
+        for fuse_method, conditions in fuse2cond.items():
+            for condition in conditions:
+                self.cond2fuse[condition] = fuse_method
+
+    def forward(
+        self,
+        input: Tensor,
+        conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
+        """Fuse the conditions to the provided model input.
+
+        Args:
+            input (Tensor): Transformer input.
+            conditions (tp.Dict[str, ConditionType]): Dict of conditions.
+        Returns:
+            tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
+                after the conditions have been fused. The second output tensor is the tensor
+                used for cross-attention or None if no cross attention inputs exist.
+        """
+        B, T, _ = input.shape
+
+        if 'offsets' in self._streaming_state:
+            first_step = False
+            offsets = self._streaming_state['offsets']
+        else:
+            first_step = True
+            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+            f"given conditions contain unknown attributes for fuser, " \
+            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        cross_attention_output = None
+        for cond_type, (cond, cond_mask) in conditions.items():
+            op = self.cond2fuse[cond_type]
+            if op == "sum":
+                input += cond
+            elif op == "input_interpolate":
+                cond = rearrange(cond, "b t d -> b d t")
+                cond = F.interpolate(cond, size=input.shape[1])
+                input += rearrange(cond, "b d t -> b t d")
+            elif op == "prepend":
+                if first_step:
+                    input = torch.cat([cond, input], dim=1)
+            elif op == "cross":
+                if cross_attention_output is not None:
+                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                else:
+                    cross_attention_output = cond
+            else:
+                raise ValueError(f"unknown op ({op})")
+
+        if self.cross_attention_pos_emb and cross_attention_output is not None:
+            positions = torch.arange(
+                cross_attention_output.shape[1],
+                device=cross_attention_output.device
+            ).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return input, cross_attention_output
diff --git a/melodytalk/audiocraft/modules/conv.py b/melodytalk/audiocraft/modules/conv.py
new file mode 100644
index 0000000..972938a
--- /dev/null
+++ b/melodytalk/audiocraft/modules/conv.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
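Editor's note: before the convolution modules, a minimal usage sketch of the conditioning utilities defined above. This is an illustration only, not part of the diff; the import path, attribute names, and the dropout probability are made up.

    from melodytalk.audiocraft.modules.conditioners import (
        ConditioningAttributes, ClassifierFreeGuidanceDropout)

    # Two training samples carrying text attributes only (illustrative values).
    samples = [
        ConditioningAttributes(text={"description": "upbeat funk with slap bass", "genre": "funk"}),
        ConditioningAttributes(text={"description": "calm piano ballad", "genre": "classical"}),
    ]

    cfg_dropout = ClassifierFreeGuidanceDropout(p=0.3)
    cfg_dropout.train()   # the dropout is a no-op in eval mode (see forward above)
    dropped = cfg_dropout(samples)
    # With probability p, every attribute of every sample is nullified (text attributes set
    # to None, wav attributes replaced by null waveforms), which is what makes
    # classifier-free guidance possible at sampling time.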
+ +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d! 
+ """ + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class StreamableConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class StreamableConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/melodytalk/audiocraft/modules/lstm.py b/melodytalk/audiocraft/modules/lstm.py new file mode 100644 index 0000000..c086617 --- /dev/null +++ b/melodytalk/audiocraft/modules/lstm.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + + +class StreamableLSTM(nn.Module): + """LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/melodytalk/audiocraft/modules/rope.py b/melodytalk/audiocraft/modules/rope.py new file mode 100644 index 0000000..4b8c70b --- /dev/null +++ b/melodytalk/audiocraft/modules/rope.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from torch import nn +import torch + + +class XPos(nn.Module): + """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1). + This applies an exponential decay to the RoPE rotation matrix. + + Args: + dim (int): Embedding dimension. + smoothing (float): Smoothing factor applied to the decay rates. + base_scale (int): Base decay rate, given in terms of scaling time. + device (torch.device or None): Device on which to initialize the module. + dtype (torch.dtype): dtype to use to generate the embedding. + """ + def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, + device=None, dtype: torch.dtype = torch.float32): + super().__init__() + assert dim % 2 == 0 + assert dtype in [torch.float64, torch.float32] + self.dtype = dtype + self.base_scale = base_scale + + half_dim = dim // 2 + adim = torch.arange(half_dim, device=device, dtype=dtype) + decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing) + self.register_buffer("decay_rates", decay_rates) + self.decay: tp.Optional[torch.Tensor] = None + + def get_decay(self, start: int, end: int): + """Create complex decay tensor, cache values for fast computation. + """ + if self.decay is None or end > self.decay.shape[0]: + assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. + idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) + power = idx / self.base_scale + scale = self.decay_rates ** power.unsqueeze(-1) + self.decay = torch.polar(scale, torch.zeros_like(scale)) + return self.decay[start:end] # [T, C/2] + + +class RotaryEmbedding(nn.Module): + """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). + + Args: + dim (int): Embedding dimension (twice the number of frequencies). + max_period (float): Maximum period of the rotation frequencies. + xpos (bool): Use xPos, applies an exponential decay to rotation matrix. + scale (float): Scale of positional embedding, set to 0 to deactivate. + device (torch.device or None): Device on which to initialize the module. + dtype (torch.dtype): dtype to use to generate the embedding. 
+ """ + def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, + scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32): + super().__init__() + assert dim % 2 == 0 + self.scale = scale + assert dtype in [torch.float64, torch.float32] + self.dtype = dtype + + adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)] + frequencies = 1.0 / (max_period ** (adim / dim)) + self.register_buffer("frequencies", frequencies) + self.rotation: tp.Optional[torch.Tensor] = None + + self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None + + def get_rotation(self, start: int, end: int): + """Create complex rotation tensor, cache values for fast computation. + """ + if self.rotation is None or end > self.rotation.shape[0]: + assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. + idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) + angles = torch.outer(idx, self.frequencies) + self.rotation = torch.polar(torch.ones_like(angles), angles) + return self.rotation[start:end] + + def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False): + """Apply rope rotation to query or key tensor. + """ + T = x.shape[1] + rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2) + + if self.xpos: + decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2) + else: + decay = 1.0 + + if invert_decay: + decay = decay ** -1 + + x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) + scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) + x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2) + + return x_out.type_as(x) + + def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0): + """ Apply rope rotation to both query and key tensors. + Supports streaming mode, in which query and key are not expected to have the same shape. + In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but + query will be [C] (typically C == 1). + + Args: + query (torch.Tensor): Query to rotate. + key (torch.Tensor): Key to rotate. + start (int): Start index of the sequence for time offset. + """ + query_timesteps = query.shape[1] + key_timesteps = key.shape[1] + streaming_offset = key_timesteps - query_timesteps + + query_out = self.rotate(query, start + streaming_offset) + key_out = self.rotate(key, start, invert_decay=True) + + return query_out, key_out diff --git a/melodytalk/audiocraft/modules/seanet.py b/melodytalk/audiocraft/modules/seanet.py new file mode 100644 index 0000000..3e5998e --- /dev/null +++ b/melodytalk/audiocraft/modules/seanet.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import numpy as np +import torch.nn as nn + +from .conv import StreamableConv1d, StreamableConvTranspose1d +from .lstm import StreamableLSTM + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + + Args: + dim (int): Dimension of the input/output. + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. 
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection. + """ + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order. We use the decoder order as some models may only employ the decoder. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the encoder, it corresponds to the N first blocks. 
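+
+    A minimal usage sketch (waveform length and settings below are illustrative and match the defaults):
+
+        encoder = SEANetEncoder(channels=1, dimension=128, ratios=[8, 5, 4, 2])
+        wav = torch.randn(2, 1, 32000)   # [B, channels, T]
+        emb = encoder(wav)               # [B, 128, 100], since hop_length = prod(ratios) = 320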
+ """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, + disable_norm_outer_blocks: int = 0): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ + "Number of blocks for which to disable norm is invalid." \ + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + StreamableConv1d(channels, mult * n_filters, kernel_size, + norm='none' if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=block_norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + StreamableConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=block_norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + StreamableConv1d(mult * n_filters, dimension, last_kernel_size, + norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. 
+ residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple. + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the decoder, it corresponds to the N last blocks. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, + disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ + "Number of blocks for which to disable norm is invalid." \ + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
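+        # The decoder mirrors the encoder: `mult` starts at 2 ** len(ratios) and is halved
+        # after each upsampling stage, so the final convolution maps n_filters channels to
+        # the output audio channels.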
+ + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + StreamableConv1d(dimension, mult * n_filters, kernel_size, + norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm + # Add upsampling layers + model += [ + act(**activation_params), + StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, + kernel_size=ratio * 2, stride=ratio, + norm=block_norm, norm_kwargs=norm_params, + causal=causal, trim_right_ratio=trim_right_ratio), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + activation=activation, activation_params=activation_params, + norm=block_norm, norm_params=norm_params, causal=causal, + pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + StreamableConv1d(n_filters, channels, last_kernel_size, + norm='none' if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [ + final_act(**final_activation_params) + ] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y diff --git a/melodytalk/audiocraft/modules/streaming.py b/melodytalk/audiocraft/modules/streaming.py new file mode 100644 index 0000000..fdbdf5e --- /dev/null +++ b/melodytalk/audiocraft/modules/streaming.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Streaming module API that should be implemented by all Streaming components, +""" + +from contextlib import contextmanager +import typing as tp +from torch import nn +import torch + + +State = tp.Dict[str, torch.Tensor] + + +class StreamingModule(nn.Module): + """Common API for streaming components. + + Each streaming component has a streaming state, which is just a dict[str, Tensor]. + By convention, the first dim of each tensor must be the batch size. + Don't use dots in the key names, as this would clash with submodules + (like in state_dict). + + If `self._is_streaming` is True, the component should use and remember + the proper state inside `self._streaming_state`. + + To set a streaming component in streaming state, use + + with module.streaming(): + ... + + This will automatically reset the streaming state when exiting the context manager. + This also automatically propagates to all streaming children module. + + Some module might also implement the `StreamingModule.flush` method, although + this one is trickier, as all parents module must be StreamingModule and implement + it as well for it to work properly. See `StreamingSequential` after. 
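+
+    A typical chunked loop looks like this (minimal sketch; `audio_chunks` is a placeholder
+    for any iterator of input tensors):
+
+        with module.streaming():
+            for chunk in audio_chunks:
+                out = module(chunk)
+            state = module.get_streaming_state()  # flat dict, keys prefixed by submodule names
+        # leaving the context resets the state of `module` and of all its streaming children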
+ """ + def __init__(self) -> None: + super().__init__() + self._streaming_state: State = {} + self._is_streaming = False + + def _apply_named_streaming(self, fn: tp.Any): + for name, module in self.named_modules(): + if isinstance(module, StreamingModule): + fn(name, module) + + def _set_streaming(self, streaming: bool): + def _set_streaming(name, module): + module._is_streaming = streaming + self._apply_named_streaming(_set_streaming) + + @contextmanager + def streaming(self): + """Context manager to enter streaming mode. Reset streaming state on exit. + """ + self._set_streaming(True) + try: + yield + finally: + self._set_streaming(False) + self.reset_streaming() + + def reset_streaming(self): + """Reset the streaming state. + """ + def _reset(name: str, module: StreamingModule): + module._streaming_state.clear() + + self._apply_named_streaming(_reset) + + def get_streaming_state(self) -> State: + """Return the streaming state, including that of sub-modules. + """ + state: State = {} + + def _add(name: str, module: StreamingModule): + if name: + name += "." + for key, value in module._streaming_state.items(): + state[name + key] = value + + self._apply_named_streaming(_add) + return state + + def set_streaming_state(self, state: State): + """Set the streaming state, including that of sub-modules. + """ + state = dict(state) + + def _set(name: str, module: StreamingModule): + if name: + name += "." + module._streaming_state.clear() + for key, value in list(state.items()): + # complexity is not ideal here, but probably fine. + if key.startswith(name): + local_key = key[len(name):] + if '.' not in local_key: + module._streaming_state[local_key] = value + del state[key] + + self._apply_named_streaming(_set) + assert len(state) == 0, list(state.keys()) + + def flush(self, x: tp.Optional[torch.Tensor] = None): + """Flush any remaining outputs that were waiting for completion. + Typically, for convolutions, this will add the final padding + and process the last buffer. + + This should take an optional argument `x`, which will be provided + if a module before this one in the streaming pipeline has already + spitted out a flushed out buffer. + """ + if x is None: + return None + else: + return self(x) + + +class StreamingSequential(StreamingModule, nn.Sequential): + """A streaming compatible alternative of `nn.Sequential`. + """ + def flush(self, x: tp.Optional[torch.Tensor] = None): + for module in self: + if isinstance(module, StreamingModule): + x = module.flush(x) + elif x is not None: + x = module(x) + return x diff --git a/melodytalk/audiocraft/modules/transformer.py b/melodytalk/audiocraft/modules/transformer.py new file mode 100644 index 0000000..e69cca8 --- /dev/null +++ b/melodytalk/audiocraft/modules/transformer.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field. + +See `StreamingTransformer` for more information. + +Unlike regular PyTorch Transformer, we make the hard choice that batches are first. 
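+
+A minimal usage sketch (hyper-parameters below are illustrative):
+
+    transformer = StreamingTransformer(d_model=512, num_heads=8, num_layers=4, causal=True)
+    x = torch.randn(2, 100, 512)   # batch-first: [B, T, C]
+    y = transformer(x)             # same shape, with sinusoidal positions added by default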
+""" + +import typing as tp + +from einops import rearrange +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint as torch_checkpoint +from xformers import ops + +from .rope import RotaryEmbedding +from .streaming import StreamingModule + +_efficient_attention_backend: str = 'torch' + + +def set_efficient_attention_backend(backend: str = 'torch'): + # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster). + global _efficient_attention_backend + assert _efficient_attention_backend in ['xformers', 'torch'] + _efficient_attention_backend = backend + + +def _get_attention_time_dimension() -> int: + if _efficient_attention_backend == 'torch': + return 2 + else: + return 1 + + +def _is_profiled() -> bool: + # Return true if we are currently running with a xformers profiler activated. + try: + from xformers.profiler import profiler + except ImportError: + return False + return profiler._Profiler._CURRENT_PROFILER is not None + + +def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: + """Create normalization module for transformer encoder layer. + + Args: + norm_type (str): Normalization method. + dim (int): Dimension of the normalized layer. + **kwargs (dict): Additional parameters for normalization layer. + Returns: + nn.Module: Normalization module. + """ + if norm_type == 'layer_norm': + return nn.LayerNorm(dim, eps=1e-5, **kwargs) + else: + raise ValueError(f"Unknown norm type: {norm_type}") + + +def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Create sinusoidal positional embedding, with shape `[B, T, C]`. + + Args: + positions (torch.Tensor): LongTensor of positions. + dim (int): Dimension of the embedding. + max_period (float): Maximum period of the cosine/sine functions. + dtype (torch.dtype or str): dtype to use to generate the embedding. + Returns: + torch.Tensor: Sinusoidal positional embedding. + """ + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + positions = positions.to(dtype) + adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) + max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point + phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) + return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) + + +def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers""" + if n_rep == 1: + return x + if _efficient_attention_backend == 'torch': + bs, n_kv_heads, slen, head_dim = x.shape + return ( + x[:, :, None, :, :] + .expand(bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(bs, n_kv_heads * n_rep, slen, head_dim) + ) + else: + bs, slen, n_kv_heads, head_dim = x.shape + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly the residual outputs close to 0, with a learnt scale. + + Args: + channels (int): Number of channels. + init (float): Initial scale. + channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. + device (torch.device or None): Device on which to initialize the module. 
+ dtype (torch.dtype or None): dtype to use to initialize the module. + """ + def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, + device=None, dtype=None): + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter( + torch.full((channels,), init, + requires_grad=True, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class StreamingMultiheadAttention(StreamingModule): + """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation. + + Args: + embed_dim (int): Dimension to project to. + num_heads (int): Number of heads. + dropout (float): Dropout level. + bias (bool): Use bias in projections. + causal (bool): Causal mask applied automatically. + past_context (int or None): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + rope (`RotaryEmbedding` or None): Rope embedding to use. + cross_attention: Should be true when used as a cross attention. + All keys and values must be available at once, streaming is only for the queries. + Cannot be used with `causal` or `rope` (as it wouldn't make sens to + intepret the time steps in the keys relative to those in the queries). + safe_streaming (bool): Bug fix, will go away with xformers update. + qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. + kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). + This will lead to faster decoding time on A100 or other GPUs with tensorcore. + device (torch.device or None): Sevice on which to initialize. + dtype (torch.dtype or None): dtype to use. + """ + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, + causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, + memory_efficient: bool = False, attention_as_float32: bool = False, + rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False, + safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, + device=None, dtype=None): + super().__init__() + factory_kwargs = {'device': device, 'dtype': dtype} + if past_context is not None: + assert causal + + self.embed_dim = embed_dim + self.causal = causal + self.past_context = past_context + self.memory_efficient = memory_efficient + self.attention_as_float32 = attention_as_float32 + self.rope = rope + self.cross_attention = cross_attention + self.safe_streaming = safe_streaming + self.num_heads = num_heads + self.dropout = dropout + self.kv_repeat = kv_repeat + if cross_attention: + assert not causal, "Causal cannot work with cross attention." + assert rope is None, "Rope cannot work with cross attention." 
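+            # Keys and values then come entirely from the conditioning stream, so a causal
+            # mask or RoPE time offsets over them would not be meaningful.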
+ + if memory_efficient: + _verify_xformers_memory_efficient_compat() + + self.custom = _is_custom(custom, memory_efficient) + if self.custom: + out_dim = embed_dim + assert num_heads % kv_repeat == 0 + assert not cross_attention or kv_repeat == 1 + num_kv = num_heads // kv_repeat + kv_dim = (embed_dim // num_heads) * num_kv + out_dim += 2 * kv_dim + in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs) + # We try to follow the default PyTorch MHA convention, to easily compare results. + self.in_proj_weight = in_proj.weight + self.in_proj_bias = in_proj.bias + if bias: + self.in_proj_bias.data.zero_() # Following Pytorch convention + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + if bias: + self.out_proj.bias.data.zero_() + else: + assert not qk_layer_norm + assert kv_repeat == 1 + self.mha = nn.MultiheadAttention( + embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True, + **factory_kwargs) + self.qk_layer_norm = qk_layer_norm + if qk_layer_norm: + assert self.custom + assert kv_repeat == 1 + ln_dim = embed_dim + self.q_layer_norm = nn.LayerNorm(ln_dim) + self.k_layer_norm = nn.LayerNorm(ln_dim) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + if not self.custom: + # Support compat with regular MHA + keys = [n for n, _ in self.mha.named_parameters()] + for key in keys: + if prefix + key in state_dict: + state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype): + # Return a causal mask, accounting for potentially stored past keys/values + # We actually return a bias for the attention score, as this has the same + # convention both in the builtin MHA in Pytorch, and Xformers functions. + time_dim = _get_attention_time_dimension() + if self.memory_efficient: + from xformers.ops import LowerTriangularMask + if current_steps == 1: + # If we only have one step, then we do not need a mask. + return None + elif 'past_keys' in self._streaming_state: + raise RuntimeError('Not supported at the moment') + else: + # Then we can safely use a lower triangular mask + return LowerTriangularMask() + if self._streaming_state: + past_keys = self._streaming_state['past_keys'] + past_steps = past_keys.shape[time_dim] + else: + past_steps = 0 + + queries_pos = torch.arange( + past_steps, current_steps + past_steps, device=device).view(-1, 1) + keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1) + delta = queries_pos - keys_pos + valid = delta >= 0 + if self.past_context is not None: + valid &= (delta <= self.past_context) + return torch.where( + valid, + torch.zeros([], device=device, dtype=dtype), + torch.full([], float('-inf'), device=device, dtype=dtype)) + + def _complete_kv(self, k, v): + time_dim = _get_attention_time_dimension() + if self.cross_attention: + # With cross attention we assume all keys and values + # are already available, and streaming is with respect + # to the queries only. + return k, v + # Complete the key/value pair using the streaming state. 
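+        # When streaming, keys/values cached from previous calls are prepended along the
+        # time dimension; the cache is then truncated to at most `past_context` timesteps
+        # (when set) before being stored for the next step.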
+ if self._streaming_state: + pk = self._streaming_state['past_keys'] + nk = torch.cat([pk, k], dim=time_dim) + if v is k: + nv = nk + else: + pv = self._streaming_state['past_values'] + nv = torch.cat([pv, v], dim=time_dim) + else: + nk = k + nv = v + + assert nk.shape[time_dim] == nv.shape[time_dim] + offset = 0 + if self.past_context is not None: + offset = max(0, nk.shape[time_dim] - self.past_context) + if self._is_streaming: + self._streaming_state['past_keys'] = nk[:, offset:] + if v is not k: + self._streaming_state['past_values'] = nv[:, offset:] + if 'offset' in self._streaming_state: + self._streaming_state['offset'] += offset + else: + self._streaming_state['offset'] = torch.tensor(0) + return nk, nv + + def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): + # TODO: fix and verify layout. + assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.' + # Apply rope embeddings to query and key tensors. + assert self.rope is not None + if 'past_keys' in self._streaming_state: + past_keys_offset = self._streaming_state['past_keys'].shape[1] + else: + past_keys_offset = 0 + if 'offset' in self._streaming_state: + past_context_offset = int(self._streaming_state['offset'].item()) + else: + past_context_offset = 0 + streaming_offset = past_context_offset + past_keys_offset + return self.rope.rotate_qk(query, key, start=streaming_offset) + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + key_padding_mask=None, need_weights=False, attn_mask=None, + average_attn_weights=True, is_causal=False): + assert attn_mask is None + assert not is_causal, ("new param added in torch 2.0.1 not supported, " + "use the causal args in the constructor.") + + time_dim = _get_attention_time_dimension() + if time_dim == 2: + layout = "b h t d" + else: + layout = "b t h d" + dtype = query.dtype + if self._is_streaming: + assert self.causal or self.cross_attention, \ + "Streaming only available for causal or cross attention" + + if self.causal: + # At the moment we specialize only for the self-attention case. + assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value" + assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value" + attn_mask = self._get_mask(query.shape[1], query.device, query.dtype) + + if self.custom: + # custom implementation + assert need_weights is False + assert key_padding_mask is None + if self.cross_attention: + # Different queries, keys, values, we have to spit manually the weights + # before applying the linear. + dim = self.in_proj_weight.shape[0] // 3 + if self.in_proj_bias is None: + bias_q, bias_k, bias_v = None, None, None + else: + bias_q = self.in_proj_bias[:dim] + bias_k = self.in_proj_bias[dim: 2 * dim] + bias_v = self.in_proj_bias[2 * dim:] + q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q) + # todo: when streaming, we could actually save k, v and check the shape actually match. + k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k) + v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v) + if self.qk_layer_norm is True: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]] + else: + if not _is_profiled(): + # profiling breaks that propertysomehow. 
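+                    # A single fused in_proj produces q, k and v from one input, so this
+                    # specialized path requires query, key and value to be the same tensor.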
+ assert query is key, "specialized implementation" + assert value is key, "specialized implementation" + projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias) + if self.kv_repeat == 1: + if time_dim == 2: + bound_layout = "b h p t d" + else: + bound_layout = "b t p h d" + packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads) + q, k, v = ops.unbind(packed, dim=2) + else: + embed_dim = self.embed_dim + per_head_dim = (embed_dim // self.num_heads) + kv_heads = self.num_heads // self.kv_repeat + q = projected[:, :, :embed_dim] + start = embed_dim + end = start + per_head_dim * kv_heads + k = projected[:, :, start: end] + v = projected[:, :, end:] + q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads) + k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads) + v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads) + + if self.qk_layer_norm is True: + assert self.kv_repeat == 1 + q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]] + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]] + if self.rope: + q, k = self._apply_rope(q, k) + k, v = self._complete_kv(k, v) + if self.kv_repeat > 1: + k = expand_repeated_kv(k, self.kv_repeat) + v = expand_repeated_kv(v, self.kv_repeat) + if self.attention_as_float32: + q, k, v = [x.float() for x in [q, k, v]] + if self.memory_efficient: + p = self.dropout if self.training else 0 + if _efficient_attention_backend == 'torch': + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=attn_mask is not None, dropout_p=p) + else: + x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p) + else: + # We include the dot product as float32, for consistency + # with the other implementations that include that step + # as part of the attention. Note that when using `autocast`, + # the einsums would be done as bfloat16, but the softmax + # would be done as bfloat16, so `attention_as_float32` will + # extend a bit the range of operations done in float32, + # although this should make no difference. + q = q / q.shape[-1] ** 0.5 + key_layout = layout.replace('t', 'k') + query_layout = layout + if self._is_streaming and self.safe_streaming and q.device.type == 'cuda': + with torch.autocast(device_type=q.device.type, dtype=torch.float32): + pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) + else: + pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) + if attn_mask is not None: + pre_w = pre_w + attn_mask + w = torch.softmax(pre_w, dim=-1) + w = F.dropout(w, self.dropout, training=self.training).to(v) + # Key and value have the same format. + x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v) + x = x.to(dtype) + x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads) + x = self.out_proj(x) + else: + key, value = self._complete_kv(key, value) + if self.attention_as_float32: + query, key, value = [x.float() for x in [query, key, value]] + x, _ = self.mha( + query, key, value, key_padding_mask, + need_weights, attn_mask, average_attn_weights) + x = x.to(dtype) + + return x, None + + +class StreamingTransformerLayer(nn.TransformerEncoderLayer): + """TransformerLayer with Streaming / Causal support. + This also integrates cross_attention, when passing `cross_attention=True`, + rather than having two separate classes like in PyTorch. + + Args: + d_model (int): Dimension of the data. + num_heads (int): Number of heads. 
+ dim_feedforward (int): Intermediate dimension of FF module. + dropout (float): Dropout both for MHA and FF. + bias_ff (bool): Use bias for FF. + bias_attn (bool): Use bias for MHA. + causal (bool): Causal mask applied automatically. + past_context (int or None): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention. + qk_layer_norm_cross (bool): Same for the cross attention. + cross_attention (bool): If True, expect to get secondary input for cross-attention. + Cross attention will use the default MHA, as it typically won't require + special treatment. + layer_scale (float or None): If not None, LayerScale will be used with + the given value as initial scale. + rope (`RotaryEmbedding` or None): Rope embedding to use. + attention_dropout (float or None): If not None, separate the value of the dimension dropout + in FFN and of the attention dropout. + kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). + This will lead to faster decoding time on A100 or other GPUs with tensorcore. + device (torch.device or None): Device on which to initialize. + dtype (torch.dtype or None): dtype to use. + **kwargs: See `nn.TransformerEncoderLayer`. + """ + def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, + bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, + past_context: tp.Optional[int] = None, custom: bool = False, + memory_efficient: bool = False, attention_as_float32: bool = False, + qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False, + cross_attention: bool = False, layer_scale: tp.Optional[float] = None, + rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None, + kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs): + super().__init__(d_model, num_heads, dim_feedforward, dropout, + device=device, dtype=dtype, batch_first=True, **kwargs) + factory_kwargs = {'device': device, 'dtype': dtype} + # Redefine self_attn to our streaming multi-head attention + attn_kwargs: tp.Dict[str, tp.Any] = { + 'embed_dim': d_model, + 'num_heads': num_heads, + 'dropout': dropout if attention_dropout is None else attention_dropout, + 'bias': bias_attn, + 'custom': custom, + 'memory_efficient': memory_efficient, + 'attention_as_float32': attention_as_float32, + } + self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention( + causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm, + kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore + # Redefine feedforward layers to expose bias parameter + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs) + + self.layer_scale_1: nn.Module + self.layer_scale_2: nn.Module + if layer_scale is None: + self.layer_scale_1 = nn.Identity() + self.layer_scale_2 = nn.Identity() + else: + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) + + 
self.cross_attention: tp.Optional[nn.Module] = None + if cross_attention: + self.cross_attention = StreamingMultiheadAttention( + cross_attention=True, qk_layer_norm=qk_layer_norm_cross, + **attn_kwargs, **factory_kwargs) + # Norm and dropout + self.dropout_cross = nn.Dropout(dropout) + # eps value matching that used in PyTorch reference implementation. + self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs) + self.layer_scale_cross: nn.Module + if layer_scale is None: + self.layer_scale_cross = nn.Identity() + else: + self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs) + self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore + self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore + + def _cross_attention_block(self, src: torch.Tensor, + cross_attention_src: torch.Tensor) -> torch.Tensor: + assert self.cross_attention is not None + # queries are from src, keys and values from cross_attention_src. + x = self.cross_attention( + src, cross_attention_src, cross_attention_src, need_weights=False)[0] + return self.dropout_cross(x) # type: ignore + + def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore + src_key_padding_mask: tp.Optional[torch.Tensor] = None, + cross_attention_src: tp.Optional[torch.Tensor] = None): + if self.cross_attention is None: + assert cross_attention_src is None + else: + assert cross_attention_src is not None + x = src + if self.norm_first: + x = x + self.layer_scale_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)) + if cross_attention_src is not None: + x = x + self.layer_scale_cross( + self._cross_attention_block( + self.norm_cross(x), cross_attention_src)) + x = x + self.layer_scale_2(self._ff_block(self.norm2(x))) + else: + x = self.norm1(x + self.layer_scale_1( + self._sa_block(x, src_mask, src_key_padding_mask))) + if cross_attention_src is not None: + x = self.norm_cross( + x + self.layer_scale_cross( + self._cross_attention_block(src, cross_attention_src))) + x = self.norm2(x + self.layer_scale_2(self._ff_block(x))) + return x + + +class StreamingTransformer(StreamingModule): + """Transformer with Streaming / Causal support. + + Args: + d_model (int): Dimension of the data. + num_heads (int): Number of heads. + dim_feedforward (int): Intermediate dimension of FF module. + dropout (float): Dropout both for MHA and FF. + bias_ff (bool): Use bias for FF. + bias_attn (bool): Use bias for MHA. + causal (bool): Causal mask applied automatically. + past_context (int or None): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + cross_attention (bool): If True, expect to get secondary input for cross-attention. + layer_scale (float or None): If not None, LayerScale will be used + with the given value as initial scale. + positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). + max_period (float): Maximum period of the time embedding. + positional_scale (float): Scale of positional embedding, set to 0 to deactivate. + xpos (bool): Apply xpos exponential decay to positional embedding (rope only). + lr (float or None): learning rate override through the `make_optim_group` API. 
+ weight_decay (float or None): Weight_decay override through the `make_optim_group` API. + layer_class: (subclass of `StreamingTransformerLayer): class to use + to initialize the layers, allowing further customization outside of Audiocraft. + checkpointing (str): Checkpointing strategy to reduce memory usage. + No checkpointing if set to 'none'. Per layer checkpointing using PyTorch + if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, + minimal memory usage, but maximal runtime). Finally, `xformers_default` provide + a policy for opting-out some operations of the checkpointing like + linear layers and attention, providing a middle ground between speed and memory. + device (torch.device or None): Device on which to initialize. + dtype (torch.dtype or None): dtype to use. + **kwargs: See `nn.TransformerEncoderLayer`. + """ + def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, + dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, + causal: bool = False, past_context: tp.Optional[int] = None, + custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, + cross_attention: bool = False, layer_scale: tp.Optional[float] = None, + positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1., + xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None, + layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer, + checkpointing: str = 'none', device=None, dtype=None, **kwargs): + super().__init__() + assert d_model % num_heads == 0 + + self.positional_embedding = positional_embedding + self.max_period = max_period + self.positional_scale = positional_scale + self.weight_decay = weight_decay + self.lr = lr + + assert positional_embedding in ['sin', 'rope', 'sin_rope'] + self.rope: tp.Optional[RotaryEmbedding] = None + if self.positional_embedding in ['rope', 'sin_rope']: + assert _is_custom(custom, memory_efficient) + self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period, + xpos=xpos, scale=positional_scale, device=device) + + self.checkpointing = checkpointing + + assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm'] + if self.checkpointing.startswith('xformers'): + _verify_xformers_internal_compat() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + layer_class( + d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn, + causal=causal, past_context=past_context, custom=custom, + memory_efficient=memory_efficient, attention_as_float32=attention_as_float32, + cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope, + device=device, dtype=dtype, **kwargs)) + + if self.checkpointing != 'none': + for layer in self.layers: + # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the + # backward hook inside of FSDP... 
+ layer._magma_checkpointed = True # type: ignore + assert layer.layer_drop == 0., "Need further checking" # type: ignore + + def _apply_layer(self, layer, *args, **kwargs): + method = self.checkpointing + if method == 'none': + return layer(*args, **kwargs) + elif method == 'torch': + return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs) + elif method.startswith('xformers'): + from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy + if method == 'xformers_default': + # those operations will be saved, and not recomputed. + # According to Francisco we can get smarter policies but this is a good start. + allow_list = [ + "xformers.efficient_attention_forward_cutlass.default", + "xformers_flash.flash_fwd.default", + "aten.addmm.default", + "aten.mm.default", + ] + elif method == 'xformers_mm': + # those operations will be saved, and not recomputed. + # According to Francisco we can get smarter policies but this is a good start. + allow_list = [ + "aten.addmm.default", + "aten.mm.default", + ] + else: + raise ValueError(f"xformers checkpointing xformers policy {method} is not known.") + policy_fn = _get_default_policy(allow_list) + return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs) + else: + raise ValueError(f"Checkpointing method {method} is unknown.") + + def forward(self, x: torch.Tensor, *args, **kwargs): + B, T, C = x.shape + + if 'offsets' in self._streaming_state: + offsets = self._streaming_state['offsets'] + else: + offsets = torch.zeros(B, dtype=torch.long, device=x.device) + + if self.positional_embedding in ['sin', 'sin_rope']: + positions = torch.arange(T, device=x.device).view(1, -1, 1) + positions = positions + offsets.view(-1, 1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) + x = x + self.positional_scale * pos_emb + + for layer in self.layers: + x = self._apply_layer(layer, x, *args, **kwargs) + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return x + + def make_optim_group(self): + group = {"params": list(self.parameters())} + if self.lr is not None: + group["lr"] = self.lr + if self.weight_decay is not None: + group["weight_decay"] = self.weight_decay + return group + + +# special attention attention related function + +def _verify_xformers_memory_efficient_compat(): + try: + from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa + except ImportError: + raise ImportError( + "xformers is not installed. Please install it and try again.\n" + "To install on AWS and Azure, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" + "To install on FAIR Cluster, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") + + +def _verify_xformers_internal_compat(): + try: + from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa + except ImportError: + raise ImportError( + "Francisco's fairinternal xformers is not installed. 
Please install it and try again.\n" + "To install on AWS and Azure, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" + "To install on FAIR Cluster, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") + + +def _is_custom(custom: bool, memory_efficient: bool): + return custom or memory_efficient diff --git a/melodytalk/audiocraft/py.typed b/melodytalk/audiocraft/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/melodytalk/audiocraft/quantization/__init__.py b/melodytalk/audiocraft/quantization/__init__.py new file mode 100644 index 0000000..836d6eb --- /dev/null +++ b/melodytalk/audiocraft/quantization/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import ResidualVectorQuantizer +from .base import BaseQuantizer, DummyQuantizer, QuantizedResult diff --git a/melodytalk/audiocraft/quantization/base.py b/melodytalk/audiocraft/quantization/base.py new file mode 100644 index 0000000..1b16c13 --- /dev/null +++ b/melodytalk/audiocraft/quantization/base.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base class for all quantizers. +""" + +from dataclasses import dataclass, field +import typing as tp + +import torch +from torch import nn + + +@dataclass +class QuantizedResult: + x: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class BaseQuantizer(nn.Module): + """Base class for quantizers. + """ + + def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: + """ + Given input tensor x, returns first the quantized (or approximately quantized) + representation along with quantized codes, bandwidth, and any penalty term for the loss. + Finally, this returns a dict of metrics to update logging etc. + Frame rate must be passed so that the bandwidth is properly computed. + """ + raise NotImplementedError() + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + """ + raise NotImplementedError() + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + """ + raise NotImplementedError() + + @property + def total_codebooks(self): + """Total number of codebooks. + """ + raise NotImplementedError() + + @property + def num_codebooks(self): + """Number of active codebooks. + """ + raise NotImplementedError() + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks. + """ + raise NotImplementedError() + + +class DummyQuantizer(BaseQuantizer): + """Fake quantizer that actually does not perform any quantization. 
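+
+    A minimal usage sketch (shapes below are illustrative):
+
+        quantizer = DummyQuantizer()
+        x = torch.randn(2, 128, 50)                  # [B, D, T]
+        res = quantizer(x, frame_rate=50)
+        assert torch.equal(res.x, x)                 # the "quantized" output is the input itself
+        assert res.codes.shape == (2, 1, 128, 50)    # codes are the input with an extra codebook dim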
+ """ + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, frame_rate: int): + q = x.unsqueeze(1) + return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. + """ + return x.unsqueeze(1) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. + """ + return codes.squeeze(1) + + @property + def total_codebooks(self): + """Total number of codebooks. + """ + return 1 + + @property + def num_codebooks(self): + """Total number of codebooks. + """ + return self.total_codebooks + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks. + """ + raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") diff --git a/melodytalk/audiocraft/quantization/core_vq.py b/melodytalk/audiocraft/quantization/core_vq.py new file mode 100644 index 0000000..e1896bb --- /dev/null +++ b/melodytalk/audiocraft/quantization/core_vq.py @@ -0,0 +1,400 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from einops import rearrange, repeat +import flashy +import torch +from torch import nn, einsum +import torch.nn.functional as F + + +def exists(val: tp.Optional[tp.Any]) -> bool: + return val is not None + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if exists(val) else d + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +def orthgonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + n = t.shape[0] + 
normed_codes = l2norm(t) + identity = torch.eye(n, device=t.device) + cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) + return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.8, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + flashy.distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + flashy.distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) 
d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): + channels_last (bool): Channels are the last dimension in the input tensors. + commitment_weight (float): Weight for commitment loss. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider + for orthogonal regulariation. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.8, + epsilon: float = 1e-5, + kmeans_init: bool = False, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + channels_last: bool = False, + commitment_weight: float = 1., + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code) + self.codebook_size = codebook_size + + self.channels_last = channels_last + + @property + def codebook(self): + return self._codebook.embed + + @property + def inited(self): + return self._codebook.inited + + def _preprocess(self, x): + if not self.channels_last: + x = rearrange(x, "b d n -> b n d") + return x + + def _postprocess(self, quantize): + if not self.channels_last: + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def encode(self, x): + x = self._preprocess(x) + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + return quantize + + def forward(self, x): + device = x.device + x = self._preprocess(x) + + x = self.project_in(x) + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + if self.orthogonal_reg_weight > 0: + codebook = self.codebook + + if self.orthogonal_reg_active_codes_only: + # only calculate orthogonal loss for the activated codes for this batch + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[unique_code_ids] + + num_codes = codebook.shape[0] + if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] + codebook = codebook[rand_ids] + + orthogonal_reg_loss = orthgonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/melodytalk/audiocraft/quantization/vq.py b/melodytalk/audiocraft/quantization/vq.py new file mode 100644 index 0000000..f67c3a0 --- /dev/null +++ b/melodytalk/audiocraft/quantization/vq.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import torch + +from .base import BaseQuantizer, QuantizedResult +from .core_vq import ResidualVectorQuantization + + +class ResidualVectorQuantizer(BaseQuantizer): + """Residual Vector Quantizer. + + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + q_dropout (bool): Random quantizer drop out at train time. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. + for orthogonal regulariation. 
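+ Illustrative usage sketch (argument values are examples only; QuantizedResult
+ field names assumed to follow quantization.base):
+ rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
+ x = torch.randn(2, 256, 50)  # [B, D, T]
+ res = rvq(x, frame_rate=50)  # res.codes has shape [B, K, T] with K == n_q
+ codes = rvq.encode(x)  # [B, K, T]
+ x_hat = rvq.decode(codes)  # [B, D, T]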
+ """ + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + q_dropout: bool = False, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + self.max_n_q = n_q + self.n_q = n_q + self.q_dropout = q_dropout + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + orthogonal_reg_weight=self.orthogonal_reg_weight, + orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, + orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, + channels_last=False + ) + + def forward(self, x: torch.Tensor, frame_rate: int): + n_q = self.n_q + if self.training and self.q_dropout: + n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) + bw_per_q = math.log2(self.bins) * frame_rate / 1000 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified frame rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.n_q + codes = self.vq.encode(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + """ + # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. + codes = codes.transpose(0, 1) + quantized = self.vq.decode(codes) + return quantized + + @property + def total_codebooks(self): + return self.max_n_q + + @property + def num_codebooks(self): + return self.n_q + + def set_num_codebooks(self, n: int): + assert n > 0 and n <= self.max_n_q + self.n_q = n diff --git a/melodytalk/audiocraft/utils/__init__.py b/melodytalk/audiocraft/utils/__init__.py new file mode 100644 index 0000000..0952fcc --- /dev/null +++ b/melodytalk/audiocraft/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/melodytalk/audiocraft/utils/autocast.py b/melodytalk/audiocraft/utils/autocast.py new file mode 100644 index 0000000..ed64484 --- /dev/null +++ b/melodytalk/audiocraft/utils/autocast.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class TorchAutocast: + """TorchAutocast utility class. + Allows you to enable and disable autocast. This is specially useful + when dealing with different architectures and clusters with different + levels of support. + + Args: + enabled (bool): Whether to enable torch.autocast or not. + args: Additional args for torch.autocast. + kwargs: Additional kwargs for torch.autocast + """ + def __init__(self, enabled: bool, *args, **kwargs): + self.autocast = torch.autocast(*args, **kwargs) if enabled else None + + def __enter__(self): + if self.autocast is None: + return + try: + self.autocast.__enter__() + except RuntimeError: + device = self.autocast.device + dtype = self.autocast.fast_dtype + raise RuntimeError( + f"There was an error autocasting with dtype={dtype} device={device}\n" + "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" + ) + + def __exit__(self, *args, **kwargs): + if self.autocast is None: + return + self.autocast.__exit__(*args, **kwargs) diff --git a/melodytalk/audiocraft/utils/export.py b/melodytalk/audiocraft/utils/export.py new file mode 100644 index 0000000..b513b52 --- /dev/null +++ b/melodytalk/audiocraft/utils/export.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility to export a training checkpoint to a lightweight release checkpoint. +""" + +from pathlib import Path +import typing as tp + +from omegaconf import OmegaConf, DictConfig +import torch + + +def _clean_lm_cfg(cfg: DictConfig): + OmegaConf.set_struct(cfg, False) + # This used to be set automatically in the LM solver, need a more robust solution + # for the future. + cfg['transformer_lm']['card'] = 2048 + cfg['transformer_lm']['n_q'] = 4 + # Experimental params no longer supported. + bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', + 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] + for name in bad_params: + del cfg['transformer_lm'][name] + OmegaConf.set_struct(cfg, True) + return cfg + + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): + sig = Path(checkpoint_path).parent.name + assert len(sig) == 8, "Not a valid Dora signature" + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['ema']['state']['model'], + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + } + out_file = Path(out_folder) / f'{sig}.th' + torch.save(new_pkg, out_file) + return out_file + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): + sig = Path(checkpoint_path).parent.name + assert len(sig) == 8, "Not a valid Dora signature" + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['fsdp_best_state']['model'], + 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) + } + out_file = Path(out_folder) / f'{sig}.th' + torch.save(new_pkg, out_file) + return out_file diff --git a/melodytalk/audiocraft/utils/notebook.py b/melodytalk/audiocraft/utils/notebook.py new file mode 100644 index 0000000..019b9d1 --- /dev/null +++ b/melodytalk/audiocraft/utils/notebook.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +try: + import IPython.display as ipd # type: ignore +except ImportError: + # Note in a notebook... + pass + + +import torch + + +def display_audio(samples: torch.Tensor, sample_rate: int): + """Renders an audio player for the given audio samples. + + Args: + samples (torch.Tensor): a Tensor of decoded audio samples + with shapes [B, C, T] or [C, T] + sample_rate (int): sample rate audio should be displayed with. + """ + assert samples.dim() == 2 or samples.dim() == 3 + + samples = samples.detach().cpu() + if samples.dim() == 2: + samples = samples[None, ...] + + for audio in samples: + ipd.display(ipd.Audio(audio, rate=sample_rate)) diff --git a/melodytalk/audiocraft/utils/utils.py b/melodytalk/audiocraft/utils/utils.py new file mode 100644 index 0000000..86e1448 --- /dev/null +++ b/melodytalk/audiocraft/utils/utils.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ProcessPoolExecutor +from functools import wraps +import hashlib +import logging +import typing as tp + +import flashy +import flashy.distrib +import omegaconf +import torch +from torch.nn.utils.rnn import pad_sequence + + +logger = logging.getLogger(__name__) + + +def dict_from_config(cfg: omegaconf.DictConfig) -> dict: + """Convenience function to map an omegaconf configuration to a dictionary. + + Args: + cfg (omegaconf.DictConfig): Original configuration to map to dict. + Returns: + dict: Config as dictionary object. + """ + dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) + assert isinstance(dct, dict) + return dct + + +def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset: + if max_samples >= len(dataset): + return dataset + + generator = torch.Generator().manual_seed(seed) + perm = torch.randperm(len(dataset), generator=generator) + return torch.utils.data.Subset(dataset, perm[:max_samples].tolist()) + + +def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, + num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader: + """Convenience function to load dataset into a dataloader with optional subset sampling. + + Args: + dataset: Dataset to load. + num_samples (Optional[int]): Number of samples to limit subset size. + batch_size (int): Batch size. + num_workers (int): Number of workers for data loading. + seed (int): Random seed. + """ + if num_samples is not None: + dataset = random_subset(dataset, num_samples, seed) + + dataloader = flashy.distrib.loader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + **kwargs + ) + return dataloader + + +def get_dataset_from_loader(dataloader): + dataset = dataloader.dataset + if isinstance(dataset, torch.utils.data.Subset): + return dataset.dataset + else: + return dataset + + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. 
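+ Note (illustrative): for an input of shape [B, H, V], the draw is performed
+ independently for every [B, H] slice over the V candidates on the last
+ dimension, giving an output of shape [B, H, num_samples].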
+ Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +class DummyPoolExecutor: + """Dummy pool executor to use when we actually have only 1 worker. + (e.g. instead of ProcessPoolExecutor). + """ + class DummyResult: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + def __init__(self, workers, mp_context=None): + pass + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return + + +def get_pool_executor(num_workers: int, mp_context=None): + return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1) + + +def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor: + """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). + For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] + + Args: + lengths (torch.Tensor): tensor with lengths + max_len (int): can set the max length manually. Defaults to None. + Returns: + torch.Tensor: mask with 0s where there is pad tokens else 1s + """ + assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." 
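+ # The comparison below broadcasts a [1, T] position ramp against the [B, 1]
+ # lengths, yielding a [B, T] boolean mask that is True for valid positions
+ # and False over the padding.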
+ final_length = lengths.max().item() if not max_len else max_len + final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor + return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] + + +def hash_trick(word: str, vocab_size: int) -> int: + """Hash trick to pair each word with an index + + Args: + word (str): word we wish to convert to an index + vocab_size (int): size of the vocabulary + Returns: + int: index of the word in the embedding LUT + """ + hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16) + return hash % vocab_size + + +def with_rank_rng(base_seed: int = 1234): + """Decorator for a function so that the function will use a Random Number Generator + whose state depend on the GPU rank. The original RNG state is restored upon returning. + + Args: + base_seed (int): Random seed. + """ + def _decorator(fun: tp.Callable): + @wraps(fun) + def _decorated(*args, **kwargs): + state = torch.get_rng_state() + seed = base_seed ^ flashy.distrib.rank() + torch.manual_seed(seed) + logger.debug('Rank dependent seed set to %d', seed) + try: + return fun(*args, **kwargs) + finally: + torch.set_rng_state(state) + logger.debug('RNG state restored.') + return _decorated + return _decorator + + +def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get a list of tensors and collate them to a single tensor. according to the following logic: + - `dim` specifies the time dimension which will be stacked and padded. + - The output will contain 1 new dimension (dimension index 0) which will be the size of + of the original list. + + Args: + tensors (tp.List[torch.Tensor]): List of tensors to collate. + dim (int): Dimension which will be stacked and padded. + Returns: + tp.Tuple[torch.Tensor, torch.Tensor]: + torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension + (dimension index 0) which will be the size of the original list. + torch.Tensor: Tensor containing length of original tensor sizes (without padding). + """ + tensors = [x.transpose(0, dim) for x in tensors] + lens = torch.LongTensor([len(x) for x in tensors]) + padded_tensors = pad_sequence(tensors) + padded_tensors = padded_tensors.transpose(0, 1) + padded_tensors = padded_tensors.transpose(1, dim + 1) + return padded_tensors, lens diff --git a/melodytalk/main.py b/melodytalk/main.py index b711e48..41de164 100644 --- a/melodytalk/main.py +++ b/melodytalk/main.py @@ -71,7 +71,7 @@ Since MelodyTalk is a text language model, MelodyTalk must use tools to observe music rather than imagination. The thoughts and observations are only visible for MelodyTalk. New input: {input} -Thought: Do I need to use a tool? {agent_scratchpad} Let's think step by step. +Thought: Do I need to use a tool? {agent_scratchpad} You MUST strictly follow the format. 
""" # removed from melodytalk_suffix: @@ -128,7 +128,7 @@ class ConversationBot(object): def __init__(self): load_dict = {"Text2Music":"cuda:0", "ExtractTrack":"cuda:0", "Text2MusicWithMelody":"cuda:0", "SimpleTracksMixing":"cuda:0"} - template_dict = {"Accompaniment": "cuda:0"} + template_dict = { "Text2MusicwithChord": "cuda:0"} # "Accompaniment": "cuda:0", print(f"Initializing MelodyTalk, load_dict={load_dict}, template_dict={template_dict}") @@ -188,8 +188,8 @@ def run_text(self, text, state): if len(res['intermediate_steps']) > 0: audio_filename = res['intermediate_steps'][-1][1] state = state + [(None,(audio_filename,))] - print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n" - f"Current Memory: {self.agent.memory.buffer}") + # print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n" + # f"Current Memory: {self.agent.memory.buffer}") return state, state def run_audio(self, file, state, txt, lang): @@ -211,8 +211,8 @@ def run_audio(self, file, state, txt, lang): self.agent.memory.chat_memory.add_user_message(Human_prompt) self.agent.memory.chat_memory.add_ai_message(AI_prompt) state = state + [((music_filename,), AI_prompt)] - print(f"\nProcessed run_audio, Input music: {music_filename}\nCurrent state: {state}\n" - f"Current Memory: {self.agent.memory.buffer}") + # print(f"\nProcessed run_audio, Input music: {music_filename}\nCurrent state: {state}\n" + # f"Current Memory: {self.agent.memory.buffer}") return state, state def run_recording(self, file_path, state, txt, lang): diff --git a/melodytalk/modules.py b/melodytalk/modules.py index 2390557..f57ce12 100644 --- a/melodytalk/modules.py +++ b/melodytalk/modules.py @@ -10,12 +10,16 @@ # source separation import demucs.separate -from utils import prompts, get_new_audio_name, description_to_attributes, cut_dialogue_history +from utils import * +DURATION = 12 # Initialze common models -musicgen_model = MusicGen.get_pretrained('melody') -musicgen_model.set_generation_params(duration=8) +musicgen_model = MusicGen.get_pretrained('large') +musicgen_model.set_generation_params(duration=DURATION) + +musicgen_model_melody = MusicGen.get_pretrained('melody') +musicgen_model_melody.set_generation_params(duration=DURATION) class Text2Music(object): def __init__(self, device): @@ -44,12 +48,13 @@ class Text2MusicWithMelody(object): def __init__(self, device): print("Initializing Text2MusicWithMelody") self.device = device - self.model = musicgen_model + self.model = musicgen_model_melody @prompts( - name="Generate music from user input text with melody or track condition", + name="Generate music from user input text with given melody or track condition", description="useful if you want to generate, style transfer or remix music from a user input text with a given melody or track condition." "Unlike Accompaniment, this tool will also re-generate the given melody." + "You shall not use it when no previous music file in the history." "like: remix the given melody with text description, or doing style transfer as text described with the given melody." "The input to this tool should be a comma separated string of two, " "representing the music_filename and the text description." 
@@ -150,7 +155,7 @@ def inference(self, inputs): wav_2 = wav_2[:, :wav_1.shape[-1]] # mix assert wav_1.shape == wav_2.shape # channel, length - wav = torch.add(wav_1, wav_2) + wav = torch.add(wav_1, wav_2 * 0.7) # write audio_write(updated_music_filename[:-4], wav.cpu(), sr_1, strategy="loudness", loudness_compressor=True) @@ -158,6 +163,54 @@ def inference(self, inputs): return updated_music_filename +class MusicCaptioning(object): + def __init__(self): + raise NotImplementedError + +class Text2MusicwithChord(object): + template_model = True + def __init__(self, Text2Music): + print("Initializing Text2MusicwithChord") + self.Text2Music = Text2Music + + @prompts( + name="Generate music from user input text and chord description", + description="useful only if you want to generate music from a user input text and explicitly mention a chord description." + "Like: generate a pop love song with piano and a chord progression of C - F - G - C, or generate a sad music with a jazz chord progression." + "This tool will automatically extract chord information and generate music." + "The input to this tool should be the user input text. " + ) + + def inference(self, inputs): + music_filename = os.path.join("music", f"{str(uuid.uuid4())[:8]}.wav") + + chords_list = chord_generation(inputs) + preprocessed_input = description_to_attributes(inputs) + + for i, chord in enumerate(chords_list): + text = f"{preprocessed_input} key: {chord}." + self.Text2Music.model.set_generation_params(duration=(i + 1) * (DURATION / len(chords_list))) + if i == 0: + wav = self.Text2Music.model.generate([text], progress=False) + else: + wav = self.Text2Music.model.generate_continuation(wav, + self.Text2Music.model.sample_rate, + [text], + progress=False) + if i == len(chords_list) - 1: + wav = wav[0] # batch size is 1 + audio_write(music_filename[:-4], + wav.cpu(), self.Text2Music.model.sample_rate, strategy="loudness", loudness_compressor=True) + self.Text2Music.model.set_generation_params(duration=DURATION) + print(f"\nProcessed Text2Music, Input Text: {preprocessed_input}, Output Music: {music_filename}.") + return music_filename + + + +class MusicInpainting(object): + def __init__(self): + raise NotImplementedError + class Accompaniment(object): template_model = True def __init__(self, Text2MusicWithMelody, ExtractTrack, SimpleTracksMixing): @@ -183,7 +236,7 @@ def inference(self, inputs): # separate the track updated_main_track = self.ExtractTrack.inference(f"{music_filename}, {track_name}, extract") # generate music - updated_new_music = self.Text2MusicWithMelody.inference(f"{text}, {updated_main_track}") + updated_new_music = self.Text2MusicWithMelody.inference(f"{updated_main_track}, {text}") # remove the track in accompaniment updated_accompaniment = self.ExtractTrack.inference(f"{updated_new_music}, {track_name}, remove") # mix diff --git a/melodytalk/utils.py b/melodytalk/utils.py index f2035db..8f59379 100644 --- a/melodytalk/utils.py +++ b/melodytalk/utils.py @@ -61,14 +61,18 @@ def description_to_attributes(description: str) -> str: :return: """ - openai_prompt = f"""Please format the bpm and key attributes from the original description text and keep the rest unchanged. - If the description text does not mention it, do not add it. Here are two examples: + openai_prompt = f"""Please: 1. split and paste the bpm and key attributes from the original description text; + copy and paste the instrument attribute but leave the original text; delete the chord description if necessary. keep the rest unchanged. 
+ If the description text does not mention it, do not add it. Here are some examples: - Q: Generate a love pop song in C major of 120 bpm. - A: Generate a love pop song. bpm: 120. key: Cmaj. + Q: Generate a love pop song with piano and violin in C major of 120 bpm. + A: Generate a love pop song with piano and violin. bpm: 120. key: C major. instrument: piano, violin. - Q: love pop song in a minor, creating a romantic atmosphere. - A: love pop song, creating a romantic atmosphere. key: Amin. + Q: love pop song in A minor, creating a romantic atmosphere. + A: love pop song, creating a romantic atmosphere. key: A minor. + + Q: Generate a pop song with chord progression of C - G - Am - F, with piano and guitar. + A: Generate a pop song with piano and guitar. instrument: piano, guitar. Q: {description}. A: """ @@ -86,7 +90,7 @@ def description_to_attributes(description: str) -> str: return response.choices[0].text -def chord_generation(description: str, chord_num: int = 4) -> tp.List: +def chord_generation(description: str) -> tp.List: """ This function is a trick to generate chord sequence from the description. :param description: @@ -94,12 +98,17 @@ def chord_generation(description: str, chord_num: int = 4) -> tp.List: :return: """ - openai_prompt = f"""Please generate a chord sequence consists of according to the text description. Example: + openai_prompt = f"""Please analyse and generate a chord sequence (4 chords by default) according to the text description. Example: - Q: Generate a chord sequence for sad pop song. 4 chords. - A: Dm - Bb - F - C + # random generate chord sequence + Q: Generate a song which has a jazz chord progression, with piano and guitar. + A: D minor - Bb major - F major - C major + + # formalize chord description + Q: Generate a pop song which has a chord progression of I - IV - V - I, with piano and guitar. + A: C major - F major - G major - C major - Q: {description}. {chord_num} chords. + Q: {description} A: """ response = openai.Completion.create( @@ -112,6 +121,6 @@ def chord_generation(description: str, chord_num: int = 4) -> tp.List: presence_penalty=0.0, ) - chord_list = [i.strip() for i in response.choices[0].text.split('-')] + chord_list = [i.strip() for i in response.choices[0].text.split(' - ')] return chord_list
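+ # e.g. a completion of "D minor - Bb major - F major - C major" is parsed into
+ # ['D minor', 'Bb major', 'F major', 'C major'] by the ' - ' split above.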