Add mmcls.VisionTransformer backbone support #1908

Merged
2 changes: 1 addition & 1 deletion otx/algorithms/classification/configs/configuration.yaml
@@ -10,7 +10,7 @@ learning_parameters:
stable. A larger batch size has higher memory requirements.
editable: true
header: Batch size
max_value: 512
max_value: 2048
min_value: 1
type: INTEGER
ui_rules:
2 changes: 1 addition & 1 deletion otx/algorithms/common/configs/training_base.py
@@ -65,7 +65,7 @@ class BaseLearningParameters(ParameterGroup):
batch_size = configurable_integer(
default_value=5,
min_value=1,
max_value=512,
max_value=2048,
header="Batch size",
description="The number of training samples seen in each iteration of training. Increasing thisvalue "
"improves training time and may make the training more stable. A larger batch size has higher "
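Note: both the YAML schema and the Python parameter definition raise the batch-size ceiling from 512 to 2048. A minimal sketch of how a requested batch size could be validated against the YAML bounds; the `learning_parameters.batch_size` key layout is assumed from the hunk above, not taken from the full file:

```python
import yaml  # pyyaml

def validate_batch_size(config_path: str, requested: int) -> int:
    """Check a requested batch size against the min/max bounds in the YAML schema."""
    with open(config_path) as f:
        schema = yaml.safe_load(f)
    # Assumed nesting: learning_parameters -> batch_size -> {min_value, max_value, ...}
    bounds = schema["learning_parameters"]["batch_size"]
    lo, hi = bounds["min_value"], bounds["max_value"]  # now 1 and 2048
    if not lo <= requested <= hi:
        raise ValueError(f"batch_size {requested} is outside [{lo}, {hi}]")
    return requested

# e.g. validate_batch_size("otx/algorithms/classification/configs/configuration.yaml", 1024)
```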
8 changes: 4 additions & 4 deletions otx/cli/builder/supported_backbone/mmcls.json
@@ -11,7 +11,7 @@
"options": {
"arch": ["tiny", "small", "base"]
},
"available": []
"available": ["CLASSIFICATION"]
},
"mmcls.ConvMixer": {
"required": ["arch"],
@@ -287,7 +287,7 @@
"mmcls.T2T_ViT": {
"required": [],
"options": {},
"available": []
"available": ["CLASSIFICATION"]
},
"mmcls.TIMMBackbone": {
"required": ["model_name"],
@@ -299,7 +299,7 @@
"options": {
"arch": ["base", "small"]
},
"available": []
"available": ["CLASSIFICATION"]
},
"mmcls.PCPVT": {
"required": ["arch"],
@@ -341,7 +341,7 @@
"deit-base"
]
},
"available": []
"available": ["CLASSIFICATION"]
}
}
}
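Note: these entries flip `available` from `[]` to `["CLASSIFICATION"]`, which is what marks the transformer backbones as selectable for classification tasks. A hedged sketch of how the registry could be queried for classification-capable backbones; the top-level `backbones` key is an assumption based on the file's nesting, not confirmed by this diff:

```python
import json
from typing import List

def classification_backbones(registry_path: str) -> List[str]:
    """List backbone names whose 'available' field includes CLASSIFICATION."""
    with open(registry_path) as f:
        registry = json.load(f)
    # Fall back to the root dict if there is no "backbones" wrapper (assumption).
    backbones = registry.get("backbones", registry)
    return [
        name
        for name, meta in backbones.items()
        if "CLASSIFICATION" in meta.get("available", [])
    ]

# e.g. classification_backbones("otx/cli/builder/supported_backbone/mmcls.json")
# should now include the transformer backbones enabled in this diff.
```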
7 changes: 7 additions & 0 deletions otx/mpa/cls/stage.py
@@ -15,6 +15,8 @@

logger = get_logger()

TRANSFORMER_BACKBONES = ["VisionTransformer", "T2T_ViT", "TNT", "Conformer"]


class ClsStage(Stage):
MODEL_BUILDER = build_classifier
@@ -89,6 +91,11 @@ def configure_in_channel(cfg):
output = layer(torch.rand([1] + list(input_shape)))
if isinstance(output, (tuple, list)):
output = output[-1]

if layer.__class__.__name__ in TRANSFORMER_BACKBONES and isinstance(output, (tuple, list)):
# mmcls.VisionTransformer outputs a Tuple[List[...]]; the last element is the final logit.
_, output = output

in_channels = output.shape[1]
if cfg.model.get("neck") is not None:
if cfg.model.neck.get("in_channels") is not None:
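Note: the new branch handles transformer backbones whose probed output is a `(patch_token, cls_token)` pair rather than a plain feature map, so `in_channels` is taken from the class token. A self-contained sketch of the probing logic; the `VisionTransformer` class below is a toy stand-in (named that way only so the class-name check fires), not the real `mmcls` implementation, and the token shapes are made up:

```python
import torch
import torch.nn as nn

TRANSFORMER_BACKBONES = ["VisionTransformer", "T2T_ViT", "TNT", "Conformer"]

class VisionTransformer(nn.Module):
    """Toy stand-in for a ViT-style backbone: returns [(patch_tokens, cls_token)]."""
    def __init__(self, embed_dim: int = 192):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x):
        b = x.shape[0]
        patch_tokens = torch.rand(b, 196, self.embed_dim)
        cls_token = torch.rand(b, self.embed_dim)
        return [(patch_tokens, cls_token)]  # one entry per output stage

def probe_in_channels(layer: nn.Module, input_shape=(3, 224, 224)) -> int:
    """Run a dummy forward pass and infer in_channels, mirroring the patched logic."""
    output = layer(torch.rand(1, *input_shape))
    if isinstance(output, (tuple, list)):
        output = output[-1]  # keep the last stage
    if layer.__class__.__name__ in TRANSFORMER_BACKBONES and isinstance(output, (tuple, list)):
        _, output = output   # drop patch tokens, keep the cls token
    return output.shape[1]

# probe_in_channels(VisionTransformer())  # -> 192
```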