Add inference API of AMR #40

Merged · 20 commits · Oct 22, 2024
Binary file added api_example/1a-ODBWMUAE.wav
65 changes: 23 additions & 42 deletions lighthouse/feature_extractor/audio_encoder.py
@@ -1,9 +1,9 @@
import torch
import librosa
import numpy as np

from lighthouse.feature_extractor.base_encoder import BaseEncoder
from lighthouse.feature_extractor.audio_encoders.pann import Cnn14
from lighthouse.feature_extractor.audio_encoders.pann import PANN, PANNConfig
from lighthouse.feature_extractor.audio_encoders.clap_a import CLAPAudio, CLAPAudioConfig

from typing import Optional, Tuple

@@ -25,14 +25,6 @@


class AudioEncoder(BaseEncoder):
SAMPLE_RATE: int = 32000
WINDOW_SIZE: int = 1024
HOP_SIZE: int = 320
MEL_BINS: int = 64
FMIN: int = 50
FMAX: int = 14000
CLASSES_NUM: int = 527

def __init__(
self,
feature_name: str,
@@ -41,42 +33,31 @@ def __init__(
self._feature_name = feature_name
self._device = device
self._pann_path = pann_path
self._audio_encoders = self._select_audio_encoders()

self._model = Cnn14(sample_rate=self.SAMPLE_RATE, window_size=self.WINDOW_SIZE,
hop_size=self.HOP_SIZE, mel_bins=self.MEL_BINS,
fmin=self.FMIN, fmax=self.FMAX, classes_num=self.CLASSES_NUM)

if pann_path is not None:
checkpoint = torch.load(pann_path, map_location=device)
self._model.load_state_dict(checkpoint['model'])
self._model.eval()
else:
raise TypeError('pann_path should not be None when using AudioEncoder.')

def _move_data_to_device(
self,
x: np.ndarray) -> torch.Tensor:
if 'float' in str(x.dtype):
return torch.Tensor(x).to(self._device)
elif 'int' in str(x.dtype):
return torch.LongTensor(x).to(self._device)
else:
raise ValueError('The input x cannot be cast into float or int.')
def _select_audio_encoders(self):
audio_encoders = {
'pann': [PANN],
'clap': [CLAPAudio]
}

config_dict = {
'pann': [PANNConfig(dict(model_path=self._pann_path))],
'clap': [CLAPAudioConfig()],
}

audio_encoders = [encoder(self._device, cfg)
for encoder, cfg in zip(audio_encoders[self._feature_name], config_dict[self._feature_name])]
return audio_encoders

@torch.no_grad()
def encode(
self,
video_path: str,
feature_time: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
(audio, _) = librosa.core.load(video_path, sr=self.SAMPLE_RATE, mono=True)
time = audio.shape[-1] / self.SAMPLE_RATE
batches = int(time // feature_time)
clip_sr = round(self.SAMPLE_RATE * feature_time)
assert clip_sr >= 9920, 'clip_sr = round(sampling_rate * feature_time) should be larger than 9920.'
audio = audio[:batches * clip_sr]
audio_clips = np.reshape(audio, [batches, clip_sr])
audio_clip_tensor = self._move_data_to_device(audio_clips)
output_dict = self._model(audio_clip_tensor, None)
audio_mask = torch.ones(1, len(output_dict['embedding'])).to(self._device)
return output_dict['embedding'].unsqueeze(0), audio_mask
audio, sr = librosa.core.load(video_path, sr=None, mono=True)

outputs = [encoder(audio, sr) for encoder in self._audio_encoders]
audio_features = torch.cat([o[0] for o in outputs])
audio_masks = torch.cat([o[1] for o in outputs])
return audio_features, audio_masks
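For reviewers, here is a minimal usage sketch of the refactored `AudioEncoder`, which now dispatches on `feature_name` instead of hard-coding `Cnn14`. The middle of the constructor signature is collapsed in this hunk, so the `device`/`pann_path` keyword names are assumptions, and the PANN checkpoint filename is a placeholder; the wav path is the example file added in this PR.

```python
import torch
from lighthouse.feature_extractor.audio_encoder import AudioEncoder

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# PANN backend: PANN.__init__ raises unless a checkpoint path is supplied.
# 'Cnn14_mAP=0.431.pth' is a placeholder filename.
pann_enc = AudioEncoder('pann', device=device, pann_path='Cnn14_mAP=0.431.pth')
pann_feats, pann_mask = pann_enc.encode('api_example/1a-ODBWMUAE.wav')

# CLAP backend: msclap fetches its own weights, so no pann_path is required.
clap_enc = AudioEncoder('clap', device=device, pann_path=None)
clap_feats, clap_mask = clap_enc.encode('api_example/1a-ODBWMUAE.wav')

print(pann_feats.shape, clap_feats.shape)  # each is (1, n_clips, embed_dim)
```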
65 changes: 65 additions & 0 deletions lighthouse/feature_extractor/audio_encoders/clap_a.py
@@ -0,0 +1,65 @@
from typing import Optional

import numpy as np
import torch
import torchaudio.transforms as T
from msclap import CLAP
from torch.nn import functional as F


class CLAPAudioConfig:
def __init__(self, cfg: Optional[dict] = None):
self.sample_rate: int = 44100
self.window_sec: float = 1.0
self.version: str = '2023'
self.feature_time: float = 1.0

if cfg is not None:
self.update(cfg)

def update(self, cfg: dict):
self.__dict__.update(cfg)


class CLAPAudio(torch.nn.Module):
def __init__(self, device: str, cfg: CLAPAudioConfig):
super(CLAPAudio, self).__init__()
use_cuda = True if device == 'cuda' else False
self.clap = CLAP(use_cuda=use_cuda, version=cfg.version)
self.sample_rate = cfg.sample_rate
self.window_sec = cfg.window_sec
self.feature_time = cfg.feature_time
self._device = device

def _preprocess(self, audio: np.ndarray, sr: int) -> torch.Tensor:
audio_tensor = self._move_data_to_device(audio)
audio_tensor = T.Resample(sr, self.sample_rate)(audio_tensor)  # resample to the CLAP sample rate, as in the original msclap implementation

win_length = int(round(self.window_sec * self.sample_rate))
hop_length = int(round(self.feature_time * self.sample_rate))

# Pad half a window on each side so clips are centered on each hop
# (note: this differs from the PANN preprocessing, which truncates instead)
half_win = win_length // 2
audio_tensor = F.pad(audio_tensor, (half_win, half_win), mode="constant", value=0)

audio_clip = audio_tensor.unfold(0, win_length, hop_length)

return audio_clip

def _move_data_to_device(
self,
x: np.ndarray) -> torch.Tensor:
if 'float' in str(x.dtype):
return torch.Tensor(x).to(self._device)
elif 'int' in str(x.dtype):
return torch.LongTensor(x).to(self._device)
else:
raise ValueError('The input x cannot be cast into float or int.')

def forward(self, audio: np.ndarray, sr: int):
audio_clip = self._preprocess(audio, sr)
output_dict = self.clap.clap.audio_encoder.base(audio_clip)
audio_mask = torch.ones(1, len(output_dict['embedding'])).to(self._device)
x = output_dict['embedding'].unsqueeze(0)
return x, audio_mask
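A small standalone sanity-check sketch for this module, using a synthetic one-second tone; instantiating `CLAP` downloads the msclap weights on first use, so this is illustrative rather than a unit test.

```python
import numpy as np
import torch
from lighthouse.feature_extractor.audio_encoders.clap_a import CLAPAudio, CLAPAudioConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'
cfg = CLAPAudioConfig(dict(window_sec=1.0, feature_time=1.0))  # same as the defaults, shown explicitly
encoder = CLAPAudio(device, cfg)

sr = 16000
t = np.linspace(0.0, 1.0, sr, endpoint=False, dtype=np.float32)
audio = 0.1 * np.sin(2.0 * np.pi * 440.0 * t)  # 1 s, 440 Hz tone

with torch.no_grad():
    feats, mask = encoder(audio, sr)  # resampled to 44.1 kHz, windowed into 1 s clips
print(feats.shape, mask.shape)  # (1, n_clips, embed_dim) and (1, n_clips)
```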
84 changes: 81 additions & 3 deletions lighthouse/feature_extractor/audio_encoders/pann.py
@@ -19,13 +19,92 @@
https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np

from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation



class PANNConfig:
def __init__(self, cfg: Optional[dict] = None):
self.sample_rate: int = 32000
self.window_size: int = 1024
self.hop_size: int = 320
self.mel_bins: int = 64
self.fmin: int = 50
self.fmax: int = 14000
self.classes_num: int = 527
self.model_path: Optional[str] = None
self.feature_time: float = 2.0

if cfg is not None:
self.update(cfg)

def update(self, cfg: dict):
self.__dict__.update(cfg)


class PANN(torch.nn.Module):
def __init__(self, device: str, cfg: PANNConfig):
super(PANN, self).__init__()
self._device: str = device
self.sample_rate: int = cfg.sample_rate
self.feature_time: float = cfg.feature_time
self._model = Cnn14(
sample_rate=cfg.sample_rate,
window_size=cfg.window_size,
hop_size=cfg.hop_size,
mel_bins=cfg.mel_bins,
fmin=cfg.fmin,
fmax=cfg.fmax,
classes_num=cfg.classes_num,
)

if cfg.model_path is not None:
checkpoint = torch.load(cfg.model_path, map_location=device)
self._model.load_state_dict(checkpoint['model'])
self._model.eval()
self._model.to(device)
else:
raise TypeError('pann_path should not be None when using AudioEncoder.')

def _preprocess(self, audio: np.ndarray, sr: int) -> torch.Tensor:
audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

time = audio.shape[-1] / self.sample_rate
batches = int(time // self.feature_time)
clip_sr = round(self.sample_rate * self.feature_time)
assert clip_sr >= 9920, 'clip_sr = round(sampling_rate * feature_time) should be at least 9920.'
audio = audio[:batches * clip_sr] # Truncate audio to fit the clip_sr
audio_clip = audio.reshape([batches, clip_sr])

audio_clip_tensor = self._move_data_to_device(audio_clip)
return audio_clip_tensor

def _move_data_to_device(
self,
x: np.ndarray) -> torch.Tensor:
if 'float' in str(x.dtype):
return torch.Tensor(x).to(self._device)
elif 'int' in str(x.dtype):
return torch.LongTensor(x).to(self._device)
else:
raise ValueError('The input x cannot be cast into float or int.')

def forward(self, audio: np.ndarray, sr: int):
audio_clip = self._preprocess(audio, sr)
output_dict = self._model(audio_clip, None) # audio_clip: (batch_size, clip_samples)
audio_mask = torch.ones(1, len(output_dict['embedding'])).to(self._device)
x = output_dict['embedding'].unsqueeze(0)
return x, audio_mask



def do_mixup(x, mixup_lambda):
out = x[0::2].transpose(0, -1) * mixup_lambda[0::2] + \
@@ -36,7 +115,6 @@ def do_mixup(x, mixup_lambda):
def init_layer(layer):
"""Initialize a Linear or Convolutional layer. """
nn.init.xavier_uniform_(layer.weight)

if hasattr(layer, 'bias'):
if layer.bias is not None:
layer.bias.data.fill_(0.)
@@ -182,4 +260,4 @@ def forward(self, input, mixup_lambda=None):

output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

return output_dict
return output_dict
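For symmetry with the CLAP branch, a minimal sketch of running the wrapped PANN encoder on its own; the checkpoint filename is a placeholder for a pretrained Cnn14 checkpoint, which must be downloaded separately.

```python
import numpy as np
import torch
from lighthouse.feature_extractor.audio_encoders.pann import PANN, PANNConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model_path is a placeholder; PANN.__init__ raises if it is None.
cfg = PANNConfig(dict(model_path='Cnn14_mAP=0.431.pth', feature_time=2.0))
encoder = PANN(device, cfg)

sr = 16000
audio = np.random.randn(sr * 4).astype(np.float32)  # 4 s of noise -> two 2 s clips

with torch.no_grad():
    feats, mask = encoder(audio, sr)  # resampled to 32 kHz, reshaped into 2 s clips
print(feats.shape, mask.shape)  # (1, n_clips, embed_dim) and (1, n_clips)
```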
5 changes: 4 additions & 1 deletion lighthouse/feature_extractor/text_encoder.py
@@ -3,6 +3,7 @@
from lighthouse.feature_extractor.base_encoder import BaseEncoder
from lighthouse.feature_extractor.text_encoders.clip_t import CLIPText
from lighthouse.feature_extractor.text_encoders.glove import GloVe
from lighthouse.feature_extractor.text_encoders.clap_t import CLAPText

from typing import Tuple

@@ -37,13 +38,15 @@ def _select_text_encoders(self):
'clip': [CLIPText],
'clip_slowfast': [CLIPText],
'clip_slowfast_pann': [CLIPText],
'clap': [CLAPText],
}

model_path_dict = {
'resnet_glove': ['glove.6B.300d'],
'clip': ['ViT-B/32'],
'clip_slowfast': ['ViT-B/32'],
'clip_slowfast_pann': ['ViT-B/32'],
'clap': ['2023'],
}

text_encoders = [encoder(self._device, model_path)
@@ -56,4 +59,4 @@ def encode(
outputs = [encoder(query) for encoder in self._text_encoders]
text_features = torch.cat([o[0] for o in outputs])
text_masks = torch.cat([o[1] for o in outputs])
return text_features, text_masks
return text_features, text_masks
43 changes: 43 additions & 0 deletions lighthouse/feature_extractor/text_encoders/clap_t.py
@@ -0,0 +1,43 @@
from typing import Tuple

import torch
from msclap import CLAP

"""
Copyright $today.year LY Corporation

LY Corporation licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""


class CLAPText:
def __init__(
self,
device: str,
model_path: str,
) -> None:
self._model_path: str = model_path
self._device: str = device
use_cuda = True if self._device == 'cuda' else False
self._clap_extractor = CLAP(use_cuda=use_cuda, version=model_path)
self._preprocessor = self._clap_extractor.preprocess_text
self._text_encoder = self._clap_extractor.clap.caption_encoder

def __call__(self, query: str) -> Tuple[torch.Tensor, torch.Tensor]:
preprocessed = self._preprocessor([query])
mask = preprocessed['attention_mask']

out = self._text_encoder.base(**preprocessed)
x = out[0] # out[1] is pooled output

return x, mask
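And a matching sketch for the text side; `CLAP` again downloads its weights on first use, and the `'2023'` version string mirrors the entry added to `text_encoder.py` above.

```python
import torch
from lighthouse.feature_extractor.text_encoders.clap_t import CLAPText

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = CLAPText(device, model_path='2023')

with torch.no_grad():
    feats, mask = encoder('a person is playing the guitar')
print(feats.shape, mask.shape)  # token-level features and attention mask from the CLAP caption encoder
```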