Commit
transfer laion_clap to a local dependency in order to fix the load_ckpt error
Showing 98 changed files with 10,835 additions and 27 deletions.
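The vendored package is what melodytalk now imports in place of the PyPI laion_clap. A minimal sketch of the intended call path (the checkpoint filename is illustrative; CLAP_Module and its load_ckpt method are the laion_clap entry points named in the commit message):

from melodytalk.dependencies.laion_clap import CLAP_Module

clap = CLAP_Module(enable_fusion=False)  # wrapper re-exported by the vendored package
clap.load_ckpt("music_audioset_epoch_15_esc_90.14.pt")  # illustrative checkpoint path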
32 files renamed without changes; 2 empty files added.
@@ -0,0 +1,5 @@
import os
import sys
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(dir_path)
from .hook import CLAP_Module
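Judging by the CLAP_Module re-export, this appears to be the vendored package's top-level __init__.py (the collapsed diff omits the filename). The sys.path.append presumably keeps laion_clap's internal absolute imports (e.g. import clap_module) resolving from the new location; a quick check, assuming the paths shown elsewhere in this commit:

import sys

import melodytalk.dependencies.laion_clap  # importing runs the sys.path shim
print(any(p.endswith("laion_clap") for p in sys.path))  # expected: True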
@@ -0,0 +1,8 @@
from .factory import list_models, create_model, create_model_and_transforms, add_model_config
from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model  # CLAPAudioCfp (sic) matches the upstream class name
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
    get_pretrained_url, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
@@ -0,0 +1,32 @@
from transformers import BertTokenizer, BertModel

# Distinct names per model so each helper binds to the right tokenizer/model;
# a shared module-level `tokenizer`/`model` would be shadowed by the later
# RoBERTa and BART assignments.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')


def bert_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = bert_tokenizer(text, return_tensors='pt')
    output = bert_model(**encoded_input)
    return output


from transformers import RobertaTokenizer, RobertaModel

roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaModel.from_pretrained('roberta-base')


def Roberta_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = roberta_tokenizer(text, return_tensors='pt')
    output = roberta_model(**encoded_input)
    return output


from transformers import BartTokenizer, BartModel

bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
bart_model = BartModel.from_pretrained('facebook/bart-base')


def bart_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = bart_tokenizer(text, return_tensors='pt')
    output = bart_model(**encoded_input)
    return output
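A quick sanity check for one of these helpers (the shape comment assumes the 768-dimensional base models):

out = bert_embeddings("a dog barking in the distance")
print(out.last_hidden_state.shape)  # torch.Size([1, <seq_len>, 768])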
Binary file added (BIN +1.29 MB, not shown): melodytalk/dependencies/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz
257 additions, 0 deletions: melodytalk/dependencies/laion_clap/clap_module/factory.py
@@ -0,0 +1,257 @@
import json
import logging
import os
import re
from copy import deepcopy
from pathlib import Path

import torch

from .model import CLAP, convert_weights_to_fp16
from .openai import load_openai_model
from .pretrained import get_pretrained_url, download_pretrained
from .transform import image_transform

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        with open(cf, "r") as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {
        k: v
        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
    }


_rescan_model_configs()  # initial populate of model config registry

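Only JSON files whose top-level keys include embed_dim, audio_cfg, and text_cfg are registered, keyed by file stem. An illustrative config created from Python (name and values hypothetical, not taken from the repo's model_configs/):

import json, pathlib

cfg = {
    "embed_dim": 512,
    "audio_cfg": {"model_type": "HTSAT", "model_name": "tiny"},
    "text_cfg": {"model_type": "roberta"},
}
pathlib.Path("my_configs").mkdir(exist_ok=True)
(pathlib.Path("my_configs") / "My-CLAP.json").write_text(json.dumps(cfg))
# add_model_config("my_configs") (defined below) would register it as "My-CLAP"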
def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if skip_params:
        # strip the "module." prefix that DataParallel/DistributedDataParallel adds
        if next(iter(state_dict.items()))[0].startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
    # for k in state_dict:
    #     if k.startswith('transformer'):
    #         v = state_dict.pop(k)
    #         state_dict['text_branch.' + k[12:]] = v
    return state_dict

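The module-prefix stripping in action on a toy state dict (illustrative keys, not a real checkpoint):

sd = {"module.audio_branch.w": 0, "module.text_branch.w": 1}
stripped = {k[7:]: v for k, v in sd.items()}
assert list(stripped) == ["audio_branch.w", "text_branch.w"]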
def create_model(
        amodel_name: str,
        tmodel_name: str,
        pretrained: str = "",
        precision: str = "fp32",
        device: torch.device = torch.device("cpu"),
        jit: bool = False,
        force_quick_gelu: bool = False,
        openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
        skip_params=True,
        pretrained_audio: str = "",
        pretrained_text: str = "",
        enable_fusion: bool = False,
        fusion_type: str = 'None'
        # pretrained_image: bool = False,
):
    amodel_name = amodel_name.replace(
        "/", "-"
    )  # for callers using old naming with / in ViT names
    pretrained_orig = pretrained
    pretrained = pretrained.lower()
    if pretrained == "openai":
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        logging.info("Loading pretrained ViT-B-16 text encoder from OpenAI.")
        # hard-code the text model name into the config
        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model = load_openai_model(
            "ViT-B-16",
            model_cfg,
            device=device,
            jit=jit,
            cache_dir=openai_model_cache_dir,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type
        )
        # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
        if precision == "amp" or precision == "fp32":
            model = model.float()
    else:
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        # if pretrained_image:
        #     if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
        #         # pretrained weight loading for timm models set via vision_cfg
        #         model_cfg['vision_cfg']['timm_model_pretrained'] = True
        #     else:
        #         assert False, 'pretrained image towers currently only supported for timm models'
        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model_cfg["enable_fusion"] = enable_fusion
        model_cfg["fusion_type"] = fusion_type
        model = CLAP(**model_cfg)

        if pretrained:
            checkpoint_path = ""
            url = get_pretrained_url(amodel_name, pretrained)
            if url:
                checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
            elif os.path.exists(pretrained_orig):
                checkpoint_path = pretrained_orig
            if checkpoint_path:
                logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).")
                ckpt = load_state_dict(checkpoint_path, skip_params=True)
                model.load_state_dict(ckpt)
                param_names = [n for n, p in model.named_parameters()]
                for n in param_names:
                    print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
            else:
                logging.warning(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )
                raise RuntimeError(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )

        if pretrained_audio:
            if amodel_name.startswith('PANN'):
                if 'Cnn14_mAP' in pretrained_audio:  # official checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                    audio_ckpt = audio_ckpt['model']
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key:
                            v = audio_ckpt.pop(key)
                            audio_ckpt['audio_branch.' + key] = v
                elif os.path.basename(pretrained_audio).startswith('PANN'):  # checkpoint trained via HTSAT codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                    audio_ckpt = audio_ckpt['state_dict']
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith('sed_model'):
                            v = audio_ckpt.pop(key)
                            audio_ckpt['audio_branch.' + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith('finetuned'):  # checkpoint trained via linear probe codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                else:
                    raise ValueError('Unknown audio checkpoint')
            elif amodel_name.startswith('HTSAT'):
                if 'HTSAT_AudioSet_Saved' in pretrained_audio:  # official checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                    audio_ckpt = audio_ckpt['state_dict']
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith('sed_model') and ('spectrogram_extractor' not in key
                                                            and 'logmel_extractor' not in key):
                            v = audio_ckpt.pop(key)
                            audio_ckpt['audio_branch.' + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith('HTSAT'):  # checkpoint trained via HTSAT codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                    audio_ckpt = audio_ckpt['state_dict']
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith('sed_model'):
                            v = audio_ckpt.pop(key)
                            audio_ckpt['audio_branch.' + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith('finetuned'):  # checkpoint trained via linear probe codebase
                    audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
                else:
                    raise ValueError('Unknown audio checkpoint')
            else:
                # a bare string cannot be raised in Python 3; raise an exception instead
                raise ValueError('This audio encoder pretrained checkpoint is not supported')

            model.load_state_dict(audio_ckpt, strict=False)
            logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).")
            param_names = [n for n, p in model.named_parameters()]
            for n in param_names:
                print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")

        model.to(device=device)
        if precision == "fp16":
            assert device.type != "cpu"
            convert_weights_to_fp16(model)

        if jit:
            model = torch.jit.script(model)

    return model, model_cfg

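A usage sketch for the non-OpenAI path; "HTSAT-tiny" and "roberta" are illustrative names, and list_models() reports what the JSON configs actually register:

import torch

model, model_cfg = create_model(
    "HTSAT-tiny",   # amodel_name: must match a registered config stem (illustrative)
    "roberta",      # tmodel_name: text branch type (illustrative)
    precision="fp32",
    device=torch.device("cpu"),
    enable_fusion=False,
)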
def create_model_and_transforms(
        model_name: str,
        tmodel_name: str,
        pretrained: str = "",
        precision: str = "fp32",
        device: torch.device = torch.device("cpu"),
        jit: bool = False,
        force_quick_gelu: bool = False,
        # pretrained_image: bool = False,
):
    # create_model expects (amodel_name, tmodel_name, ...) and returns
    # (model, model_cfg): pass the text model name through and unpack the tuple
    model, _ = create_model(
        model_name,
        tmodel_name,
        pretrained,
        precision,
        device,
        jit,
        force_quick_gelu=force_quick_gelu,
        # pretrained_image=pretrained_image
    )
    preprocess_train = image_transform(model.visual.image_size, is_train=True)
    preprocess_val = image_transform(model.visual.image_size, is_train=False)
    return model, preprocess_train, preprocess_val

def list_models():
    """enumerate available model architectures based on config files"""
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """add model config path or file and update registry"""
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
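Tying the registry helpers together, continuing the hypothetical config written earlier:

add_model_config("my_configs")  # appends the path and rescans the registry
print(list_models())            # [..., "My-CLAP"]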