Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring the SOTA(DINOv2) model to the classification task #2708

Merged
merged 20 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ exclude = [

# it will be cleaned up later
"src/otx/core/engine/utils/*",
"src/otx/algo/classification/model/backbones/*",
"src/otx/algo/classification/backbones/*",
]

# Same as Black.
Expand Down
7 changes: 4 additions & 3 deletions src/otx/algo/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX classification."""
"""Module for OTX classification models."""

from . import model
from . import backbones
from .otx_dino_v2 import DINOv2RegisterClassifier

__all__ = ["model"]
__all__ = ["backbones", "DINOv2RegisterClassifier"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved

"""EfficientNetV2 model.

Original papers:
Expand Down
8 changes: 0 additions & 8 deletions src/otx/algo/classification/model/__init__.py

This file was deleted.

138 changes: 138 additions & 0 deletions src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""DINO-V2 model for the OTX classification."""

from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING, Any

import torch
from torch import nn

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
from otx.core.model.entity.classification import OTXClassificationModel

if TYPE_CHECKING:
from omegaconf import DictConfig

class DINOv2RegisterClassifier(OTXClassificationModel):
"""DINO-v2 Classification Model."""
def __init__(self, config: DictConfig) -> None:
self.config = config
super().__init__() # create the model

self.loss = nn.CrossEntropyLoss()

# NOTE,
# We've decided to use MMpretrain's pipeline for this model
# It's hard to use ClsDataPreprocessor since the model is not related to MMpretrain
# That's the reason why I implemented the below preprocess things
self.data_preprocess_cfg = self.config.data_preprocess
self.register_buffer(
'mean', torch.tensor(self.data_preprocess_cfg.mean).view(-1, 1, 1), False,
)
self.register_buffer(
'std', torch.tensor(self.data_preprocess_cfg.std).view(-1, 1, 1), False,
)

def _create_model(self) -> nn.Module:
"""Create the model."""
self.backbone = torch.hub.load(
repo_or_dir="facebookresearch/dinov2",
model=self.config.backbone.name,
)
if self.config.backbone.frozen:
self._freeze_backbone(self.backbone)

self.head = nn.Linear(
self.config.head.in_channels,
self.config.head.num_classes,
)

return nn.Sequential(
OrderedDict([
("backbone", self.backbone),
("head", self.head),
]),
)
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved

def _freeze_backbone(self, backbone: nn.Module) -> None:
"""Freeze the backbone."""
for _, v in backbone.named_parameters():
v.requires_grad = False

def _preprocess_img(self, imgs: torch.Tensor) -> torch.Tensor:
"""Control normalize and BGR/RGB conversion."""
# BGR -> RGB
if self.data_preprocess_cfg.to_rgb and imgs.size(1) == 3:
imgs = imgs.flip(1)

# Normalization
imgs = imgs.float()
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
return (imgs - self.mean) / self.std

def forward(
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
self,
inputs: MulticlassClsBatchDataEntity,
) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
"""Forward function.

The output of the forward function should be the loss during training
and MulticlassBatchPredEntity during validation.
"""
customized_inputs = self._customize_inputs(inputs)
feats = self.model(customized_inputs["x"])
if self.training:
outputs = self.loss(feats, customized_inputs["labels"])
else:
pred_scores = nn.functional.softmax(feats, dim=1)
max_pred_elements, max_pred_idxs = torch.max(pred_scores, dim=1)
pred_scores = max_pred_elements.cpu().detach().numpy()
pred_labels = max_pred_idxs.cpu().detach().numpy()
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved

outputs = {
"pred_scores": pred_scores,
"pred_labels": pred_labels,
}
return self._customize_outputs(outputs, inputs)

def _customize_inputs(self, entity: MulticlassClsBatchDataEntity) -> dict[str, Any]:
"""Customize the inputs for the model."""
inputs: dict[str, Any] = {}
inputs["x"] = self._preprocess_img(
torch.stack([torch.as_tensor(image) for image in entity.images]),
)
inputs["labels"] = torch.cat([torch.as_tensor(label) for label in entity.labels])
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
return inputs

def _customize_outputs(
self,
outputs: Any, # noqa: ANN401
inputs: MulticlassClsBatchDataEntity,
) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
"""Customize the outputs for the model."""
if self.training:
if not isinstance(outputs, torch.Tensor):
raise TypeError(outputs)

losses = OTXBatchLossEntity()
losses["loss"] = outputs
return losses

batch_size = outputs["pred_labels"].shape[0]
scores = []
labels = []
for b in range(batch_size):
scores.append(outputs["pred_scores"][b])
labels.append(outputs["pred_labels"][b])
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved

return MulticlassClsBatchPredEntity(
batch_size=batch_size,
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
labels=labels,
)
12 changes: 12 additions & 0 deletions src/otx/config/model/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
optimizer:
_target_: torch.optim.Adam
harimkang marked this conversation as resolved.
Show resolved Hide resolved
_partial_: true
lr: 1e-3
weight_decay: 0.0

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 10
16 changes: 3 additions & 13 deletions src/otx/config/model/mmdet.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
_target_: otx.core.model.module.detection.OTXDetectionLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 1e-3
weight_decay: 0.0
defaults:
- default

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 10
_target_: otx.core.model.module.detection.OTXDetectionLitModule

otx_model:
_target_: otx.core.model.entity.detection.MMDetCompatibleModel
Expand Down
16 changes: 3 additions & 13 deletions src/otx/config/model/mmpretrain.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
_target_: otx.core.model.module.classification.OTXClassificationLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 1e-3
weight_decay: 0.0
defaults:
- default

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 10
_target_: otx.core.model.module.classification.OTXClassificationLitModule

otx_model:
_target_: otx.core.model.entity.classification.MMPretrainCompatibleModel
Expand Down
11 changes: 11 additions & 0 deletions src/otx/config/model/torch_classification.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
- default

_target_: otx.core.model.module.classification.OTXClassificationLitModule

otx_model:
_target_: otx.core.model.entity.classification.OTXClassificationModel
config: ???

# compile model for faster training with pytorch 2.0
torch_compile: false
68 changes: 68 additions & 0 deletions src/otx/recipe/classification/otx_dino_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# @package _global_
defaults:
- override /base: classification
- override /data: mmpretrain
- override /model: torch_classification
data:
train_subset:
batch_size: 64
transforms:
- type: LoadImageFromFile
to_float32: true
- backend: cv2
scale: 224
type: RandomResizedCrop
- type: PackInputs
val_subset:
batch_size: 64
transforms:
- type: LoadImageFromFile
to_float32: true
- backend: cv2
edge: short
scale: 256
type: ResizeEdge
- crop_size: 224
type: CenterCrop
- type: PackInputs
test_subset:
batch_size: 64
transforms:
- type: LoadImageFromFile
to_float32: true
- backend: cv2
edge: short
scale: 256
type: ResizeEdge
- crop_size: 224
type: CenterCrop
- type: PackInputs
model:
otx_model:
_target_: otx.algo.classification.otx_dino_v2.DINOv2RegisterClassifier
config:
backbone:
name: dinov2_vits14_reg
frozen: true
head:
in_channels: 384
num_classes: 1000
data_preprocess:
mean:
- 123.675
- 116.28
- 103.53
std:
- 58.395
- 57.12
- 57.375
to_rgb: True
optimizer:
_target_: torch.optim.SGD
_partial_: true
lr: 0.007
momentum: 0.9
weight_decay: 0.0001
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
scheduler:
factor: 0.5
patience: 1
2 changes: 2 additions & 0 deletions tests/unit/algo/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
51 changes: 51 additions & 0 deletions tests/unit/algo/classification/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations

from unittest.mock import MagicMock

import pytest
import torch
from omegaconf import DictConfig
from torchvision import tv_tensors

from src.otx.core.data.entity.base import ImageInfo
from src.otx.core.data.entity.classification import MulticlassClsBatchDataEntity


@pytest.fixture()
def fxt_multiclass_cls_batch_data_entity() -> MulticlassClsBatchDataEntity:
batch_size = 2
random_tensor = torch.randn((batch_size, 3, 224, 224))
tv_tensor = tv_tensors.Image(data=random_tensor)
img_infos = [ImageInfo(
img_idx=i,
img_shape=(224, 224),
ori_shape=(224, 224),
pad_shape=(0, 0),
scale_factor=(1.0, 1.0),
) for i in range(batch_size)]
return MulticlassClsBatchDataEntity(
batch_size=2,
images=tv_tensor,
imgs_info=img_infos,
labels=[torch.tensor([0]), torch.tensor([1])],
)

@pytest.fixture()
def fxt_config_mock() -> DictConfig:
config_mock = MagicMock(spec=DictConfig)
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
config_mock.backbone = MagicMock(spec=DictConfig)
config_mock.backbone.name = "dinov2_vits14_reg"
config_mock.backbone.frozen = False

config_mock.head = MagicMock(spec=DictConfig)
config_mock.head.in_channels = 384
config_mock.head.num_classes = 2

config_mock.data_preprocess = MagicMock(spec=DictConfig)
config_mock.data_preprocess.to_rgb = True
config_mock.data_preprocess.mean = [1, 1, 1]
config_mock.data_preprocess.std = [1, 1, 1]
return config_mock
Loading
Loading