Skip to content

Commit

Permalink
siglip2 weights on hub, fix forward_intermediates when no prefix toke…
Browse files Browse the repository at this point in the history
…ns (& return prefix selected)
  • Loading branch information
rwightman committed Feb 21, 2025
1 parent f63a11c commit a667d3d
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,11 +769,14 @@ def forward_intermediates(
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
else:
prefix_tokens = None

if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
# return_prefix not support in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))

Expand Down Expand Up @@ -1889,17 +1892,17 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

'vit_base_patch32_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_256.webli': _cfg(
Expand All @@ -1911,49 +1914,49 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_512.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_378.webli': _cfg(
Expand All @@ -1965,42 +1968,42 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
#hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
#hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),

'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_256.webli': _cfg(
Expand All @@ -2012,43 +2015,43 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_gap_512.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
Expand All @@ -2071,7 +2074,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
# custom_load='hf',
# num_classes=0),
'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
Expand Down Expand Up @@ -2147,27 +2150,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
# input_size=(3, 896, 896), crop_pct=1.0,
# num_classes=0),
'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),

Expand Down

0 comments on commit a667d3d

Please sign in to comment.