Skip to content

Commit

Permalink
Merge pull request #2178 from huggingface/pali_siglip
Browse files Browse the repository at this point in the history
Support loading of PaliGemma weights into GAP variants of SigLIP ViT.
  • Loading branch information
rwightman authored May 15, 2024
2 parents 04462f5 + 7b3b11b commit 6653747
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 23 deletions.
10 changes: 8 additions & 2 deletions timm/models/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
load_custom_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
Expand Down Expand Up @@ -185,7 +186,12 @@ def load_pretrained(
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(custom_load, str) and custom_load == 'hf':
load_custom_from_hf(*pretrained_loc, model)
return
else:
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
Expand Down
7 changes: 7 additions & 0 deletions timm/models/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
return torch.load(cached_file, map_location='cpu')


def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
return model.load_pretrained(cached_file)


def save_config_for_hf(
model,
config_path: str,
Expand Down
122 changes: 101 additions & 21 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
"""
import numpy as np

def _n2p(w, t=True):
def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
Expand Down Expand Up @@ -955,21 +957,28 @@ def _n2p(w, t=True):

mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))


def _convert_openai_clip(
Expand Down Expand Up @@ -1769,6 +1778,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
input_size=(3, 384, 384),
num_classes=0),

'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/ViT-SO400M-14-SigLIP',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-224-jax',
hf_hub_filename='paligemma-3b-mix-224.npz',
custom_load='hf',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-224-jax',
hf_hub_filename='paligemma-3b-pt-224.npz',
custom_load='hf',
num_classes=0),
'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 384, 384), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-448-jax',
hf_hub_filename='paligemma-3b-mix-448.npz',
custom_load='hf',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-448-jax',
hf_hub_filename='paligemma-3b-pt-448.npz',
custom_load='hf',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-896-jax',
hf_hub_filename='paligemma-3b-pt-896.npz',
custom_load='hf',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),

'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
Expand Down Expand Up @@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
return model


# @register_model
# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
# model_args = dict(
# patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
# no_embed_class=True, reg_tokens=4,
# )
# model = _create_vision_transformer(
# 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
# return model
@register_model
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
Expand Down

0 comments on commit 6653747

Please sign in to comment.