Add DeiT template for OTX classification task #2093

Merged (27 commits, May 16, 2023)

Commits:
a7d6a5a  Add vit (JihwanEom, Apr 9, 2023)
5861abf  Upload experiments folder (JihwanEom, Apr 9, 2023)
d5e1d87  Merge branch 'develop' of https://github.com/JihwanEom/training_exten… (JihwanEom, Apr 12, 2023)
5c07701  Remove experiments files (JihwanEom, Apr 12, 2023)
8e6f98a  Revert clip codes (JihwanEom, Apr 12, 2023)
5c93b5b  Revert clip codes (JihwanEom, Apr 12, 2023)
32d01ae  Revert clip codes (JihwanEom, Apr 12, 2023)
7a4f9d5  Merge branch 'develop' of https://github.com/JihwanEom/training_exten… (JihwanEom, May 2, 2023)
ef95379  Enable DeiT tempalte (JihwanEom, May 2, 2023)
d8e359a  Fix typo (JihwanEom, May 2, 2023)
7406d38  Fix unit test (JihwanEom, May 2, 2023)
7162f94  Support multi-label & h-label (JihwanEom, May 2, 2023)
bffc974  Temporary hack for multi-label classification (JihwanEom, May 2, 2023)
933bb60  Merge branch 'develop' of https://github.com/JihwanEom/training_exten… (JihwanEom, May 3, 2023)
b43598c  Resolve pre-commit issue (JihwanEom, May 3, 2023)
79a375d  Clear codes (JihwanEom, May 4, 2023)
a81c12b  Fix typo (JihwanEom, May 4, 2023)
c89b91d  Skip FQ test cases (JihwanEom, May 4, 2023)
0d0a6b2  Update HPO config (JihwanEom, May 4, 2023)
3aaa7d5  Refactor heads (JihwanEom, May 4, 2023)
dbd4046  Resolve bugs (JihwanEom, May 4, 2023)
274275c  Update deit HPs (JihwanEom, May 6, 2023)
aa3c659  Fix pre-commit issue (JihwanEom, May 6, 2023)
02e57c7  Add unit test for OTXHeadMixin & avoid protected function (JihwanEom, May 6, 2023)
1a72350  Merge branch 'develop' of https://github.com/JihwanEom/training_exten… (JihwanEom, May 6, 2023)
d80a76a  Merge branch 'develop' of https://github.com/JihwanEom/training_exten… (JihwanEom, May 16, 2023)
43ff369  Reflect review comments (JihwanEom, May 16, 2023)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@ All notable changes to this project will be documented in this file.
- Add option to save images after inference in OTX CLI demo together with demo in exportable code (<https://github.com/openvinotoolkit/training_extensions/pull/2005>)
- Support storage cache in Apache Arrow using Datumaro for cls, det, seg tasks (<https://github.com/openvinotoolkit/training_extensions/pull/2009>)
- Add noisy label detection for multi-class classification task (<https://github.com/openvinotoolkit/training_extensions/pull/1985>, <https://github.com/openvinotoolkit/training_extensions/pull/2034>)
- Add DeiT template for classification tasks as experimental template (<https://github.com/openvinotoolkit/training_extensions/pull/2093>)

### Enhancements

2 changes: 0 additions & 2 deletions otx/algorithms/classification/adapters/mmcls/configurer.py
@@ -492,8 +492,6 @@ def configure_in_channel(cfg):
if layer.__class__.__name__ in TRANSFORMER_BACKBONES and isinstance(output, (tuple, list)):
# mmcls.VisionTransformer outputs Tuple[List[...]] and the last index of List is the final logit.
_, output = output
if cfg.model.head.type != "VisionTransformerClsHead":
raise ValueError(f"{layer.__class__.__name__ } needs VisionTransformerClsHead as head")

in_channels = output.shape[1]
if cfg.model.get("neck") is not None:
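For context, configure_in_channel infers the head's input width by running a dummy batch through the backbone; a minimal self-contained sketch of that probing logic (the helper name and shapes are illustrative, not OTX API):

import torch

def infer_in_channels(backbone, input_size=224):
    """Probe a backbone with a dummy batch to infer the head's in_channels."""
    output = backbone(torch.rand(2, 3, input_size, input_size))
    if isinstance(output, (tuple, list)):
        # mmcls.VisionTransformer returns Tuple[List[...]]; keep the final
        # entry, mirroring `_, output = output` above. The real code only
        # unpacks when the backbone class is in TRANSFORMER_BACKBONES.
        _, output = output
    return output.shape[1]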
@@ -22,6 +22,7 @@
from .custom_hierarchical_non_linear_cls_head import CustomHierarchicalNonLinearClsHead
from .custom_multi_label_linear_cls_head import CustomMultiLabelLinearClsHead
from .custom_multi_label_non_linear_cls_head import CustomMultiLabelNonLinearClsHead
from .custom_vision_transformer_head import CustomVisionTransformerClsHead
from .mmov_cls_head import MMOVClsHead
from .non_linear_cls_head import NonLinearClsHead
from .semisl_cls_head import SemiLinearClsHead, SemiNonLinearClsHead
@@ -41,6 +42,7 @@
"CustomMultiLabelNonLinearClsHead",
"SemiLinearMultilabelClsHead",
"SemiNonLinearMultilabelClsHead",
"CustomVisionTransformerClsHead",
"NonLinearClsHead",
"SemiLinearClsHead",
"SemiNonLinearClsHead",
@@ -9,9 +9,11 @@
from mmcv.cnn import normal_init
from torch import nn

from .mixin import OTXHeadMixin


@HEADS.register_module()
class CustomHierarchicalLinearClsHead(MultiLabelClsHead):
class CustomHierarchicalLinearClsHead(OTXHeadMixin, MultiLabelClsHead):
"""Custom Linear classification head for hierarchical classification task.

Args:
@@ -80,6 +82,7 @@ def forward(self, x):
def forward_train(self, cls_score, gt_label, **kwargs):
"""Forward_train fuction of CustomHierarchicalLinearClsHead class."""
img_metas = kwargs.get("img_metas", None)
cls_score = self.pre_logits(cls_score)
gt_label = gt_label.type_as(cls_score)
cls_score = self.fc(cls_score)

@@ -127,6 +130,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):

def simple_test(self, img):
"""Test without augmentation."""
img = self.pre_logits(img)
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
@@ -9,9 +9,13 @@
from mmcv.cnn import build_activation_layer, constant_init, normal_init
from torch import nn

from .mixin import OTXHeadMixin


@HEADS.register_module()
class CustomHierarchicalNonLinearClsHead(MultiLabelClsHead): # pylint: disable=too-many-instance-attributes
class CustomHierarchicalNonLinearClsHead(
OTXHeadMixin, MultiLabelClsHead
): # pylint: disable=too-many-instance-attributes
"""Custom NonLinear classification head for hierarchical classification task.

Args:
@@ -108,6 +112,7 @@ def forward(self, x):
def forward_train(self, cls_score, gt_label, **kwargs):
"""Forward_train fuction of CustomHierarchicalNonLinearClsHead class."""
img_metas = kwargs.get("img_metas", None)
cls_score = self.pre_logits(cls_score)
gt_label = gt_label.type_as(cls_score)
cls_score = self.classifier(cls_score)

@@ -155,6 +160,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):

def simple_test(self, img):
"""Test without augmentation."""
img = self.pre_logits(img)
cls_score = self.classifier(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
@@ -10,9 +10,11 @@
from mmcv.cnn import normal_init
from torch import nn

from .mixin import OTXHeadMixin


@HEADS.register_module()
class CustomMultiLabelLinearClsHead(MultiLabelClsHead):
class CustomMultiLabelLinearClsHead(OTXHeadMixin, MultiLabelClsHead):
"""Custom Linear classification head for multilabel task.

Args:
@@ -73,6 +75,7 @@ def forward(self, x):
def forward_train(self, cls_score, gt_label, **kwargs):
"""Forward_train fuction of CustomMultiLabelLinearClsHead."""
img_metas = kwargs.get("img_metas", False)
cls_score = self.pre_logits(cls_score)
gt_label = gt_label.type_as(cls_score)
cls_score = self.fc(cls_score) * self.scale

@@ -96,6 +99,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):

def simple_test(self, img):
"""Test without augmentation."""
img = self.pre_logits(img)
cls_score = self.fc(img) * self.scale
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
@@ -10,10 +10,11 @@
from torch import nn

from .custom_multi_label_linear_cls_head import AnglularLinear
from .mixin import OTXHeadMixin


@HEADS.register_module()
class CustomMultiLabelNonLinearClsHead(MultiLabelClsHead):
class CustomMultiLabelNonLinearClsHead(OTXHeadMixin, MultiLabelClsHead):
"""Non-linear classification head for multilabel task.

Args:
@@ -102,6 +103,7 @@ def forward(self, x):
def forward_train(self, cls_score, gt_label, **kwargs):
"""Forward_train fuction of CustomMultiLabelNonLinearClsHead."""
img_metas = kwargs.get("img_metas", False)
cls_score = self.pre_logits(cls_score)
gt_label = gt_label.type_as(cls_score)
cls_score = self.classifier(cls_score) * self.scale

@@ -125,6 +127,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):

def simple_test(self, img):
"""Test without augmentation."""
img = self.pre_logits(img)
cls_score = self.classifier(img) * self.scale
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
@@ -0,0 +1,33 @@
"""Module to define CustomVisionTransformerClsHead for classification task."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmcls.models.builder import HEADS
from mmcls.models.heads import VisionTransformerClsHead


@HEADS.register_module()
class CustomVisionTransformerClsHead(VisionTransformerClsHead):
"""Custom Vision Transformer classifier head which supports IBLoss loss calculation."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_type = kwargs.get("loss", dict(type="CrossEntropyLoss"))["type"]

def loss(self, cls_score, gt_label, feature=None):
"""Calculate loss for given cls_score/gt_label."""
num_samples = len(cls_score)
losses = dict()
# compute loss
if self.loss_type == "IBLoss":
loss = self.compute_loss(cls_score, gt_label, feature=feature)
else:
loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
if self.cal_acc:
# compute accuracy
acc = self.compute_accuracy(cls_score, gt_label)
assert len(acc) == len(self.topk)
losses["accuracy"] = {f"top-{k}": a for k, a in zip(self.topk, acc)}
losses["loss"] = loss
return losses
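A minimal usage sketch for the new head, assuming mmcls is installed and the OTX heads module has been imported so the HEADS registry knows the class (num_classes and the batch are arbitrary):

import torch
from mmcls.models import build_head

head = build_head(
    dict(
        type="CustomVisionTransformerClsHead",
        num_classes=10,
        in_channels=384,
        loss=dict(type="CrossEntropyLoss", loss_weight=1.0),
    )
)
cls_score = torch.randn(4, 10)           # a batch of 4 logit vectors
gt_label = torch.randint(0, 10, (4,))    # integer class labels
losses = head.loss(cls_score, gt_label)  # {"loss": tensor(...)}

With loss=dict(type="IBLoss", ...) the same call would take the feature= branch instead of passing avg_factor.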
@@ -0,0 +1,16 @@
"""Module defining Mix-in class of heads."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#


class OTXHeadMixin:
"""Mix-in class for OTX custom heads."""

@staticmethod
def pre_logits(x):
"""Preprocess logits before forward. Designed to support vision transformer output."""
if isinstance(x, list):
x = x[-1]
return x
return x
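In isolation the mixin simply unwraps list outputs; a short demonstration with assumed ViT-style shapes (the import path is an assumption):

import torch

from otx.algorithms.classification.adapters.mmcls.models.heads.mixin import OTXHeadMixin  # assumed path

vit_out = [torch.randn(4, 197, 192), torch.randn(4, 192)]  # ViT: [patch tokens, final feature]
cnn_out = torch.randn(4, 1280)

print(OTXHeadMixin.pre_logits(vit_out).shape)  # torch.Size([4, 192]) -> last list entry
print(OTXHeadMixin.pre_logits(cnn_out).shape)  # torch.Size([4, 1280]) -> unchanged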
@@ -11,8 +11,10 @@
NonLinearClsHead,
)

from .mixin import OTXHeadMixin

class SemiClsHead:

class SemiClsHead(OTXHeadMixin):
"""Classification head for Semi-SL.

Args:
@@ -76,6 +78,8 @@ def forward_train(self, x, gt_label, final_layer=None): # pylint: disable=too-m
"""
label_u, mask = None, None
if isinstance(x, dict):
for key in x.keys():
x[key] = self.pre_logits(x[key])
outputs = final_layer(x["labeled"]) # Logit of Labeled Img
batch_size = len(outputs)

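The loop added above normalizes every branch of the Semi-SL feature dict before the final layers run; a hedged illustration with dummy tensors (import path assumed, as before):

import torch

from otx.algorithms.classification.adapters.mmcls.models.heads.mixin import OTXHeadMixin  # assumed path

x = {
    "labeled": [torch.randn(2, 197, 192), torch.randn(2, 192)],  # ViT-style list output
    "unlabeled": torch.randn(2, 192),                            # CNN-style tensor output
}
for key in x.keys():
    x[key] = OTXHeadMixin.pre_logits(x[key])  # every value is now a plain 2-D tensor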
@@ -14,6 +14,8 @@
CustomMultiLabelNonLinearClsHead,
)

from .mixin import OTXHeadMixin


def generate_aux_mlp(aux_mlp_cfg: dict, in_channels: int):
"""Generate auxiliary MLP."""
@@ -96,7 +98,7 @@ def balance_losses(self, losses):
return total_loss


class SemiMultilabelClsHead:
class SemiMultilabelClsHead(OTXHeadMixin):
"""Multilabel Classification head for Semi-SL.

Args:
@@ -167,6 +169,8 @@ def forward_train_with_last_layers(self, x, gt_label, final_cls_layer, final_emb
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
for key in x.keys():
x[key] = self.pre_logits(x[key])
logits = final_cls_layer(x["labeled_weak"])
features_weak = torch.cat((final_emb_layer(x["labeled_weak"]), final_emb_layer(x["unlabeled_weak"])))
features_strong = torch.cat((final_emb_layer(x["labeled_strong"]), final_emb_layer(x["unlabeled_strong"])))
10 changes: 8 additions & 2 deletions otx/algorithms/classification/adapters/mmcls/task.py
@@ -29,6 +29,7 @@
from mmcv.runner import wrap_fp16_model
from mmcv.utils import Config, ConfigDict

from otx.algorithms import TRANSFORMER_BACKBONES
from otx.algorithms.classification.adapters.mmcls.utils.exporter import (
ClassificationExporter,
)
@@ -281,11 +282,16 @@ def hook(module, inp, outp):
model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(hook)

if not dump_saliency_map:
model_type = cfg.model.backbone.type.split(".")[-1] # mmcls.VisionTransformer => VisionTransformer
if (
not dump_saliency_map or model_type in TRANSFORMER_BACKBONES
): # TODO: remove latter "or" condition after resolving Issue#2098
forward_explainer_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
else:
forward_explainer_hook = ReciproCAMHook(feature_model)
if not dump_features:
if (
not dump_features or model_type in TRANSFORMER_BACKBONES
): # TODO: remove latter "or" condition after resolving Issue#2098
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
else:
feature_vector_hook = FeatureVectorHook(feature_model)
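A small sketch of the backbone-type check introduced above (the contents of TRANSFORMER_BACKBONES are an assumption here; OTX defines the real list in otx.algorithms):

TRANSFORMER_BACKBONES = ["VisionTransformer"]  # assumed contents, for illustration

model_type = "mmcls.VisionTransformer".split(".")[-1]  # -> "VisionTransformer"
skip_hooks = model_type in TRANSFORMER_BACKBONES
# When skip_hooks is True, ReciproCAMHook and FeatureVectorHook are replaced
# by nullcontext() until Issue#2098 is resolved.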
18 changes: 18 additions & 0 deletions otx/algorithms/classification/configs/base/models/deit.py
@@ -0,0 +1,18 @@
"""Base deit config."""

# model settings
model = dict(
type="ImageClassifier",
backbone=dict(type="mmcls.VisionTransformer", arch="deit-small", img_size=224, patch_size=16),
neck=None,
head=dict(
type="CustomVisionTransformerClsHead",
num_classes=1000,
in_channels=384,
loss=dict(type="CrossEntropyLoss", loss_weight=1.0),
),
init_cfg=[
dict(type="TruncNormal", layer="Linear", std=0.02),
dict(type="Constant", layer="LayerNorm", val=1.0, bias=0.0),
],
)
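Note that in_channels=384 matches the deit-small embedding width; templates that override arch (such as the deit-tiny template below) rely on configure_in_channel from the configurer diff above to re-probe the backbone and rewrite this value. For reference, the standard DeiT embedding widths (assumed to match the mmcls arch settings):

DEIT_EMBED_DIMS = {"deit-tiny": 192, "deit-small": 384, "deit-base": 768}  # assumed values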
15 changes: 15 additions & 0 deletions otx/algorithms/classification/configs/deit_tiny/__init__.py
@@ -0,0 +1,15 @@
"""Initialization of deit-tiny model for Classification Task."""

# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
@@ -0,0 +1,18 @@
"""Data Pipeline of deit-tiny model for Classification Task."""

# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

# pylint: disable=invalid-name
_base_ = ["../base/data/data_pipeline.py"]
11 changes: 11 additions & 0 deletions otx/algorithms/classification/configs/deit_tiny/deployment.py
@@ -0,0 +1,11 @@
"""deit-tiny for multi-class MMDeploy config."""

_base_ = ["../base/deployments/base_classification_dynamic.py"]

ir_config = dict(
output_names=["logits"],
)

backend_config = dict(
model_inputs=[dict(opt_shapes=dict(input=[-1, 3, 224, 224]))],
)
15 changes: 15 additions & 0 deletions otx/algorithms/classification/configs/deit_tiny/hpo_config.yaml
@@ -0,0 +1,15 @@
metric: accuracy
search_algorithm: asha
hp_space:
learning_parameters.learning_rate:
param_type: qloguniform
range:
- 0.00001
- 0.001
- 0.00001
learning_parameters.batch_size:
param_type: qloguniform
range:
- 42
- 96
- 2
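Each hp_space range reads as (lower bound, upper bound, quantization step). A hedged sketch of how a qloguniform value could be drawn; this helper is illustrative, not the OTX HPO implementation:

import math
import random

def sample_qloguniform(low, high, step):
    """Draw log-uniformly from [low, high], then round to the nearest step."""
    value = math.exp(random.uniform(math.log(low), math.log(high)))
    return round(value / step) * step

lr = sample_qloguniform(0.00001, 0.001, 0.00001)  # learning_parameters.learning_rate
batch_size = sample_qloguniform(42, 96, 2)        # learning_parameters.batch_size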
17 changes: 17 additions & 0 deletions otx/algorithms/classification/configs/deit_tiny/model.py
@@ -0,0 +1,17 @@
"""deit-tiny for multi-class config."""

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/deit.py"]
ckpt_url = "https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth"

model = dict(
type="SAMImageClassifier",
task="classification",
backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
)

fp16 = dict(loss_scale=512.0)

optimizer = dict(_delete_=True, type="AdamW", lr=0.01, weight_decay=0.05)
optimizer_config = dict(_delete_=True)
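The _delete_=True flags make AdamW fully replace the base recipe's optimizer settings instead of merging with them; a minimal sketch of that mmcv Config behavior (the SGD base values are assumptions):

from mmcv.utils import Config

base = Config(dict(optimizer=dict(type="SGD", lr=0.01, momentum=0.9)))
base.merge_from_dict(dict(optimizer=dict(_delete_=True, type="AdamW", lr=0.01, weight_decay=0.05)))
print(base.optimizer)  # {'type': 'AdamW', 'lr': 0.01, 'weight_decay': 0.05}; momentum is gone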