Skip to content

Commit

Permalink
Add the 256x256 in1k ft of the so150m, add an alternate so150m def
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 18, 2025
1 parent 2a84d68 commit 3677f67
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2152,15 +2152,20 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
'vit_base_patch16_reg4_gap_256.untrained': _cfg(
input_size=(3, 256, 256)),

'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
input_size=(3, 256, 256), crop_pct=0.95),
'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
input_size=(3, 256, 256)),
'vit_so150m2_patch16_reg1_gap_256.untrained': _cfg(
input_size=(3, 256, 256), crop_pct=0.95),

'vit_intern300m_patch14_448.ogvl_dist': _cfg(
hf_hub_id='timm/',
Expand Down Expand Up @@ -3467,6 +3472,7 @@ def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionT

@register_model
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" SO150M (shape optimized, but diff than paper def, optimized for GPU) """
model_args = dict(
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
class_token=False, reg_tokens=4, global_pool='map',
Expand All @@ -3478,6 +3484,7 @@ def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> Visio

@register_model
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" SO150M (shape optimized, but diff than paper def, optimized for GPU) """
model_args = dict(
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
Expand All @@ -3489,6 +3496,7 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio

@register_model
def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" SO150M (shape optimized, but diff than paper def, optimized for GPU) """
model_args = dict(
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
Expand All @@ -3498,6 +3506,18 @@ def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> Visio
return model


@register_model
def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
model_args = dict(
patch_size=16, embed_dim=896, depth=20, num_heads=14, mlp_ratio=2.429, init_values=1e-5,
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
)
model = _create_vision_transformer(
'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
Expand Down

0 comments on commit 3677f67

Please sign in to comment.