Skip to content

Commit

Permalink
added ViT-L to vision_transformers.py + improved img_size argument flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
clemsgrs committed Nov 26, 2024
1 parent e058a80 commit 8769b17
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions dino/config/knn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ label_name: 'label'

model:
arch: vit_small
input_size: 256
patch_size: 16
pretrained_weights:
checkpoint_key: 'teacher'
Expand Down
1 change: 1 addition & 0 deletions dino/config/patch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ resume_from_checkpoint: 'latest.pth'

model:
arch: 'vit_small'
input_size: 256
patch_size: 16
out_dim: 65536 # dimensionality of the DINO head output. For complex and large datasets large values (like 65k) work well
norm_last_layer: False # whether or not to weight normalize the last layer of the DINO head ; not normalizing leads to better performance but can make the training unstable.
Expand Down
2 changes: 1 addition & 1 deletion dino/config/region.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ seed: 0

model:
arch: 'hvit_xs'
region_size: 4096
input_size: 4096
patch_size: 256
out_dim: 65536 # dimensionality of the DINO head output. For complex and large datasets large values (like 65k) work well
norm_last_layer: False # whether or not to weight normalize the last layer of the DINO head ; not normalizing leads to better performance but can make the training unstable.
Expand Down
4 changes: 3 additions & 1 deletion dino/eval/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def extract_feature_pipeline(
test_df: pd.DataFrame,
features_dir: str,
arch: str,
input_size: int,
patch_size: int,
pretrained_weights: str,
checkpoint_key: str,
Expand All @@ -157,7 +158,7 @@ def extract_feature_pipeline(
)

# ============ building network ... ============
model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
model = vits.__dict__[arch](img_size=input_size, patch_size=patch_size, num_classes=0)
print(f"Model {arch} {patch_size}x{patch_size} built.")
model.cuda()
print("Loading pretrained weights...")
Expand Down Expand Up @@ -471,6 +472,7 @@ def main(cfg: DictConfig):
test_df,
cfg.features_dir,
cfg.model.arch,
cfg.model.input_size,
cfg.model.patch_size,
cfg.model.pretrained_weights,
cfg.model.checkpoint_key,
Expand Down
4 changes: 3 additions & 1 deletion dino/eval_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def extract_feature_pipeline(
test_df: pd.DataFrame,
features_dir: str,
arch: str,
input_size: int,
patch_size: int,
pretrained_weights: str,
checkpoint_key: str,
Expand All @@ -141,7 +142,7 @@ def extract_feature_pipeline(
)

# ============ building network ... ============
model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
model = vits.__dict__[arch](img_size=input_size, patch_size=patch_size, num_classes=0)
print(f"Model {arch} {patch_size}x{patch_size} built.")
model.cuda()
print("Loading pretrained weights...")
Expand Down Expand Up @@ -384,6 +385,7 @@ def main(cfg: DictConfig):
test_df,
features_dir,
cfg.model.arch,
cfg.model.input_size,
cfg.model.patch_size,
cfg.model.pretrained_weights,
cfg.model.checkpoint_key,
Expand Down
20 changes: 20 additions & 0 deletions dino/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,26 @@ def vit_base(
return model


def vit_large(
    img_size: int = 256,
    patch_size: int = 16,
    embed_dim: int = 1024,
    **kwargs,
):
    """Construct a ViT-Large backbone.

    Fixed architecture hyper-parameters follow the standard ViT-L recipe
    (depth 24, 16 attention heads, MLP ratio 4); image size, patch size,
    and embedding width are configurable. Extra keyword arguments are
    forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


class HierarchicalVisionTransformer(nn.Module):
"""Hierarchical Vision Transformer"""

Expand Down
3 changes: 2 additions & 1 deletion dino/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ def main(cfg: DictConfig):
if is_main_process():
print("Building student and teacher networks...")
student = vits.__dict__[cfg.model.arch](
img_size=cfg.model.input_size,
patch_size=cfg.model.patch_size,
drop_path_rate=cfg.model.drop_path_rate,
)
teacher = vits.__dict__[cfg.model.arch](patch_size=cfg.model.patch_size)
teacher = vits.__dict__[cfg.model.arch](img_size=cfg.model.input_size, patch_size=cfg.model.patch_size)
embed_dim = student.embed_dim

# multi-crop wrapper handles forward with inputs of different resolutions
Expand Down
6 changes: 3 additions & 3 deletions dino/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(cfg: DictConfig):
cfg.aug.global_crops_scale,
cfg.aug.local_crops_number,
cfg.aug.local_crops_scale,
cfg.model.region_size,
cfg.model.input_size,
cfg.model.patch_size,
)

Expand Down Expand Up @@ -122,12 +122,12 @@ def main(cfg: DictConfig):
if is_main_process():
print("Building student and teacher networks...")
student = vits.__dict__[cfg.model.arch](
img_size=cfg.model.region_size,
img_size=cfg.model.input_size,
patch_size=cfg.model.patch_size,
drop_path_rate=cfg.model.drop_path_rate,
)
teacher = vits.__dict__[cfg.model.arch](
img_size=cfg.model.region_size, patch_size=cfg.model.patch_size
img_size=cfg.model.input_size, patch_size=cfg.model.patch_size
)
embed_dim = student.embed_dim

Expand Down

0 comments on commit 8769b17

Please sign in to comment.