From 7b3b11b63fcd7e7fa2dbfa1659a7271b2a036de1 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 14 May 2024 15:44:37 -0700
Subject: [PATCH] Support loading of paligemma weights into GAP variants of
 SigLIP ViT. Minor tweak to npz loading for packed transformer weights.

---
 timm/models/_builder.py           |  10 ++-
 timm/models/_hub.py               |   7 ++
 timm/models/vision_transformer.py | 122 +++++++++++++++++++++++++-----
 3 files changed, 116 insertions(+), 23 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index f248fbd310..7741cf94bf 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -10,7 +10,8 @@
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
+    load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -185,7 +186,12 @@ def load_pretrained(
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
-            state_dict = load_state_dict_from_hf(*pretrained_loc)
+            custom_load = pretrained_cfg.get('custom_load', False)
+            if isinstance(custom_load, str) and custom_load == 'hf':
+                load_custom_from_hf(*pretrained_loc, model)
+                return
+            else:
+                state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc)
     else:
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 55ab04bfa1..a36321bf94 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     return torch.load(cached_file, map_location='cpu')
 
 
+def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+    assert has_hf_hub(True)
+    hf_model_id, hf_revision = hf_split(model_id)
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    return model.load_pretrained(cached_file)
+
+
 def save_config_for_hf(
     model,
     config_path: str,
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 4f2a623f51..ace8e532e4 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np
 
-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:
@@ -955,21 +957,28 @@ def _n2p(w, t=True):
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
 
 def _convert_openai_clip(
@@ -1769,6 +1778,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 
 @register_model
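
Note (not part of the patch): a minimal usage sketch of the new GAP SigLIP variants, assuming the 'pali_pt' pretrained tag added above resolves on the Hugging Face Hub and the npz weights download succeeds:

    import timm
    import torch

    # custom_load='hf' routes loading through load_custom_from_hf(), which
    # downloads the paligemma npz and hands it to model.load_pretrained().
    model = timm.create_model('vit_so400m_patch14_siglip_gap_224.pali_pt', pretrained=True)
    model.eval()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # expected torch.Size([1, 1152]); num_classes=0 returns pooled features

With num_classes=0 and global_pool='avg', the forward pass yields globally average-pooled features rather than classification logits.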
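For reference, a sketch of what the new `idx` argument to `_n2p` handles: PaLI-style big_vision checkpoints keep all transformer blocks under a single 'encoderblock' key, with the block dimension packed along the leading axis, so each block's tensor is recovered by indexing that axis. The shapes and key below are illustrative assumptions, not taken from an actual checkpoint:

    import numpy as np
    import torch

    depth, dim = 27, 1152  # matches the SO400M model_args above
    prefix = ''
    # packed layout: one array holding the LayerNorm scale for every block
    w = {f'{prefix}Transformer/encoderblock/LayerNorm_0/scale': np.ones((depth, dim), dtype=np.float32)}

    def _n2p(w, t=True, idx=None):
        # simplified stand-in for the helper in the patch: pick one block, then convert
        if idx is not None:
            w = w[idx]
        return torch.from_numpy(w)

    scale_block_3 = _n2p(w[f'{prefix}Transformer/encoderblock/LayerNorm_0/scale'], idx=3)
    print(scale_block_3.shape)  # torch.Size([1152])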