
Commit

…nsions into v2
kprokofi committed Feb 18, 2024
2 parents 8b69f62 + d9f7e15 commit 7efa031
Showing 99 changed files with 2,425 additions and 547 deletions.
16 changes: 10 additions & 6 deletions .github/workflows/pre_merge.yaml
@@ -68,21 +68,25 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - python-version: "3.10"
-            tox-env: "py310"
-    name: Integration-Test-Py${{ matrix.python-version }}
+          - task: "action"
+          - task: "classification"
+          - task: "detection"
+          - task: "instance_segmentation"
+          - task: "semantic_segmentation"
+          - task: "visual_prompting"
+    name: Integration-Test-${{ matrix.task }}-py310
     # This is what will cancel the job concurrency
     concurrency:
-      group: ${{ github.workflow }}-Integration-${{ github.event.pull_request.number || github.ref }}
+      group: ${{ github.workflow }}-Integration-${{ github.event.pull_request.number || github.ref }}-${{ matrix.task }}
       cancel-in-progress: true
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3
       - name: Install Python
         uses: actions/setup-python@v4
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: "3.10"
       - name: Install tox
         run: python -m pip install tox
       - name: Run Integration Test
-        run: tox -vv -e integration-test
+        run: tox -vv -e integration-test-${{ matrix.task }}
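The reworked matrix replaces the single Python-version axis with a per-task axis: each OTX task gets its own integration-test job, all pinned to Python 3.10, each with its own concurrency group so a re-push only cancels the matching task's run. A minimal sketch of the resulting fan-out (the tox environment names are taken from the diff; the script itself is purely illustrative):

import itertools  # not required; shown only to stress this is plain Python

tasks = [
    "action",
    "classification",
    "detection",
    "instance_segmentation",
    "semantic_segmentation",
    "visual_prompting",
]

for task in tasks:
    # One job per matrix entry, each running a task-scoped tox environment.
    print(f"Integration-Test-{task}-py310: tox -vv -e integration-test-{task}")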
6 changes: 3 additions & 3 deletions src/otx/algo/detection/heads/custom_anchor_generator.py
@@ -34,10 +34,10 @@ def __init__(
         self.centers = [(stride / 2.0, stride / 2.0) for stride in strides]

         self.center_offset = 0
-        self.base_anchors = self.gen_base_anchors()
+        self.gen_base_anchors()
         self.use_box_type = False

-    def gen_base_anchors(self) -> list[torch.Tensor]:
+    def gen_base_anchors(self) -> None:
         """Generate base anchor for SSD."""
         multi_level_base_anchors = []
         for widths, heights, centers in zip(self.widths, self.heights, self.centers):
@@ -47,7 +47,7 @@ def gen_base_anchors(self) -> list[torch.Tensor]:
                 center=torch.Tensor(centers),
             )
             multi_level_base_anchors.append(base_anchors)
-        return multi_level_base_anchors
+        self.base_anchors = multi_level_base_anchors

     def gen_single_level_base_anchors(
         self,
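The signature change matters downstream: gen_base_anchors() now mutates self.base_anchors in place, so callers can refresh the anchors after editing widths/heights with a single call and no assignment. A standalone sketch of the new contract (a hypothetical mirror class, not the OTX implementation):

import torch

class AnchorGeneratorSketch:
    """Hypothetical mirror of the changed contract, not the OTX class."""

    def __init__(self, widths: list[list[float]], heights: list[list[float]]) -> None:
        self.widths = widths
        self.heights = heights
        self.base_anchors: list[torch.Tensor] = []
        self.gen_base_anchors()  # result is stored on the instance

    def gen_base_anchors(self) -> None:
        multi_level = []
        for ws, hs in zip(self.widths, self.heights):
            w, h = torch.tensor(ws), torch.tensor(hs)
            # (x1, y1, x2, y2) anchors centered at the origin, one per (w, h) pair.
            multi_level.append(torch.stack([-w / 2, -h / 2, w / 2, h / 2], dim=-1))
        self.base_anchors = multi_level

gen = AnchorGeneratorSketch(widths=[[16.0, 32.0]], heights=[[16.0, 32.0]])
gen.widths = [[24.0, 48.0]]  # e.g. after dataset-driven re-clustering
gen.gen_base_anchors()  # no return value; gen.base_anchors is refreshed in place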
124 changes: 123 additions & 1 deletion src/otx/algo/detection/ssd.py
@@ -5,20 +5,31 @@

 from __future__ import annotations

+import logging
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Literal

+import numpy as np
+from datumaro.components.annotation import Bbox
+
 from otx.algo.utils.mmconfig import read_mmconfig
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
 from otx.core.model.entity.detection import MMDetCompatibleModel
+from otx.core.utils.build import build_mm_model, modify_num_classes

 if TYPE_CHECKING:
+    import torch
+    from lightning import Trainer
+    from mmdet.models.task_modules.prior_generators.anchor_generator import AnchorGenerator
+    from mmengine.registry import Registry
     from omegaconf import DictConfig
     from torch import device, nn

+    from otx.core.data.dataset.base import OTXDataset
+
+
+logger = logging.getLogger()
+

 class SSD(MMDetCompatibleModel):
     """Detection model class for SSD."""
@@ -28,6 +39,7 @@ def __init__(self, num_classes: int, variant: Literal["mobilenetv2"]) -> None:
         config = read_mmconfig(model_name=model_name)
         super().__init__(num_classes=num_classes, config=config)
         self.image_size = (1, 3, 864, 864)
+        self._register_load_state_dict_pre_hook(self._set_anchors_hook)

     def _create_model(self) -> nn.Module:
         from mmdet.models.data_preprocessors import (
@@ -52,6 +64,94 @@ def device(self) -> device:
         self.classification_layers = self.get_classification_layers(self.config, MODELS, "model.")
         return build_mm_model(self.config, MODELS, self.load_from)

+    def setup_callback(self, trainer: Trainer) -> None:
+        """Callback to set up the OTX model.
+
+        OTX SSD requires automatic anchor generation w.r.t. the training dataset for better accuracy.
+        This callback provides the training dataset to the model's anchor generator.
+
+        Args:
+            trainer (Trainer): Lightning trainer containing OTXLitModule and OTXDatamodule.
+        """
+        if trainer.training:
+            anchor_generator = self.model.bbox_head.anchor_generator
+            dataset = trainer.datamodule.train_dataloader().dataset
+            new_anchors = self._get_new_anchors(dataset, anchor_generator)
+            if new_anchors is not None:
+                logger.warning("Anchor will be updated by Dataset's statistics")
+                logger.warning(f"{anchor_generator.widths} -> {new_anchors[0]}")
+                logger.warning(f"{anchor_generator.heights} -> {new_anchors[1]}")
+                anchor_generator.widths = new_anchors[0]
+                anchor_generator.heights = new_anchors[1]
+                anchor_generator.gen_base_anchors()
+
+    def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: AnchorGenerator) -> tuple | None:
+        """Get new anchors for SSD from OTXDataset."""
+        from mmdet.datasets.transforms import Resize
+
+        target_wh = None
+        if isinstance(dataset.transforms, list):
+            for transform in dataset.transforms:
+                if isinstance(transform, Resize):
+                    target_wh = transform.scale
+        if target_wh is None:
+            target_wh = (864, 864)
+            msg = f"Cannot get target_wh from the dataset. Assign it with the default value: {target_wh}"
+            logger.warning(msg)
+        group_as = [len(width) for width in anchor_generator.widths]
+        wh_stats = self._get_sizes_from_dataset_entity(dataset, list(target_wh))
+
+        if len(wh_stats) < sum(group_as):
+            logger.warning(
+                f"There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be "
+                f"at least {sum(group_as)}. Anchor box clustering was skipped.",
+            )
+            return None
+
+        return self._get_anchor_boxes(wh_stats, group_as)
+
+    @staticmethod
+    def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> list[tuple[int, int]]:
+        """Get the width and height of items in OTXDataset.
+
+        Args:
+            dataset (OTXDataset): OTXDataset in which to get statistics.
+            target_wh (list[int]): Target width and height of the dataset.
+
+        Returns:
+            list[tuple[int, int]]: Tuples with the width and height of each instance.
+        """
+        wh_stats: list[tuple[int, int]] = []
+        for item in dataset.dm_subset:
+            for ann in item.annotations:
+                if isinstance(ann, Bbox):
+                    x1, y1, x2, y2 = ann.points
+                    x1 = x1 / item.media.size[1] * target_wh[0]
+                    y1 = y1 / item.media.size[0] * target_wh[1]
+                    x2 = x2 / item.media.size[1] * target_wh[0]
+                    y2 = y2 / item.media.size[0] * target_wh[1]
+                    wh_stats.append((x2 - x1, y2 - y1))
+        return wh_stats
+
+    @staticmethod
+    def _get_anchor_boxes(wh_stats: list[tuple[int, int]], group_as: list[int]) -> tuple:
+        """Get new anchor box widths & heights using KMeans."""
+        from sklearn.cluster import KMeans
+
+        kmeans = KMeans(init="k-means++", n_clusters=sum(group_as), random_state=0).fit(wh_stats)
+        centers = kmeans.cluster_centers_
+
+        areas = np.sqrt(np.prod(centers, axis=1))
+        idx = np.argsort(areas)
+
+        widths = centers[idx, 0]
+        heights = centers[idx, 1]
+
+        group_as = np.cumsum(group_as[:-1])
+        widths, heights = np.split(widths, group_as), np.split(heights, group_as)
+        widths = [width.tolist() for width in widths]
+        heights = [height.tolist() for height in heights]
+        return widths, heights

     @staticmethod
     def get_classification_layers(
         config: DictConfig,
@@ -95,6 +195,19 @@ def get_classification_layers(
             classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors}
         return classification_layers

+    def state_dict(self, *args, **kwargs) -> dict[str, Any]:
+        """Return the state dictionary of the model entity with anchor information.
+
+        Returns:
+            A dictionary containing the SSD state.
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        anchor_generator = self.model.bbox_head.anchor_generator
+        anchors = {"heights": anchor_generator.heights, "widths": anchor_generator.widths}
+        state_dict["model.model.anchors"] = anchors
+        return state_dict
+
     def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
         """Modify input state_dict according to class name matching before weight loading."""
         model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)
@@ -138,6 +251,15 @@ def _export_parameters(self) -> dict[str, Any]:

         return export_params

+    def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None:
+        """Pre-hook that pops anchor statistics from the checkpoint state_dict."""
+        anchors = state_dict.pop("model.model.anchors", None)
+        if anchors is not None:
+            anchor_generator = self.model.bbox_head.anchor_generator
+            anchor_generator.widths = anchors["widths"]
+            anchor_generator.heights = anchors["heights"]
+            anchor_generator.gen_base_anchors()
+
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
         """Load the previous OTX ckpt according to OTX2.0."""
-        return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
+        return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix)
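_get_anchor_boxes re-estimates the SSD anchor shapes by KMeans-clustering the dataset's box sizes and splitting the area-sorted cluster centers across pyramid levels. A self-contained sketch of that step, mirroring the diff's sklearn usage on synthetic data (the (width, height) statistics and group sizes below are made up):

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
wh_stats = rng.uniform(8, 300, size=(200, 2))  # synthetic (width, height) pairs
group_as = [4, 5]  # anchors per level, i.e. [len(w) for w in anchor_generator.widths]

kmeans = KMeans(init="k-means++", n_clusters=sum(group_as), random_state=0).fit(wh_stats)
centers = kmeans.cluster_centers_

# Sort cluster centers by sqrt(area) so smaller anchors land on earlier levels.
order = np.argsort(np.sqrt(np.prod(centers, axis=1)))
widths, heights = centers[order, 0], centers[order, 1]

# Split the sorted widths/heights into per-level groups of the requested sizes.
splits = np.cumsum(group_as[:-1])
new_widths = [w.tolist() for w in np.split(widths, splits)]
new_heights = [h.tolist() for h in np.split(heights, splits)]
print(new_widths, new_heights)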
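state_dict() and _set_anchors_hook together make the clustered anchors survive a checkpoint round-trip: the widths/heights are stashed under "model.model.anchors" on save and popped back out by the load pre-hook. A minimal sketch of the same pattern on a hypothetical standalone module (the real model nests the anchor generator inside its bbox head):

import torch

class AnchorStore(torch.nn.Module):
    """Hypothetical minimal module illustrating the save/load pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.widths = [[16.0, 32.0]]
        self.heights = [[16.0, 32.0]]
        self._register_load_state_dict_pre_hook(self._set_anchors_hook)

    def state_dict(self, *args, **kwargs) -> dict:
        sd = super().state_dict(*args, **kwargs)
        sd["anchors"] = {"widths": self.widths, "heights": self.heights}  # stash anchors
        return sd

    def _set_anchors_hook(self, state_dict: dict, *args, **kwargs) -> None:
        # Pop the entry so load_state_dict never sees the non-tensor key.
        anchors = state_dict.pop("anchors", None)
        if anchors is not None:
            self.widths, self.heights = anchors["widths"], anchors["heights"]

src = AnchorStore()
src.widths = [[24.0, 48.0]]  # pretend these came from clustering
dst = AnchorStore()
dst.load_state_dict(src.state_dict())
assert dst.widths == [[24.0, 48.0]]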
14 changes: 11 additions & 3 deletions src/otx/algo/hooks/recording_forward_hook.py
@@ -10,7 +10,7 @@
 import numpy as np
 import torch

-from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
+from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI

 if TYPE_CHECKING:
     from torch.utils.hooks import RemovableHandle
@@ -409,7 +409,11 @@ def create_and_register_hook(cls, num_classes: int) -> BaseRecordingForwardHook:
         """Create this object and register it to the module forward hook."""
         return cls(num_classes)

-    def func(self, preds: list[InstanceSegBatchPredEntity], _: int = -1) -> list[np.array]:
+    def func(
+        self,
+        preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI],
+        _: int = -1,
+    ) -> list[np.array]:
         """Generate saliency maps from predicted masks by averaging and normalizing them per-class.

         Args:
@@ -428,7 +432,11 @@ def func(self, preds: list[InstanceSegBatchPredEntity], _: int = -1) -> list[np.
         return batch_saliency_maps

     @classmethod
-    def average_and_normalize(cls, pred: InstanceSegBatchPredEntity, num_classes: int) -> np.array:
+    def average_and_normalize(
+        cls,
+        pred: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI,
+        num_classes: int,
+    ) -> np.array:
         """Average and normalize masks in prediction per-class.

         Args:
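The widened signatures only touch typing (the XAI prediction entities are now accepted); the underlying behavior, averaging predicted instance masks per class and normalizing the result into saliency maps, is unchanged. A rough sketch of that idea on synthetic arrays (the real method works on prediction entities, not raw masks, and its exact normalization may differ):

import numpy as np

def average_and_normalize_sketch(masks: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray:
    """masks: (N, H, W) float masks; labels: (N,) class ids; returns (C, H, W) uint8 maps."""
    _, h, w = masks.shape
    saliency = np.zeros((num_classes, h, w), dtype=np.float32)
    for cls in range(num_classes):
        cls_masks = masks[labels == cls]
        if len(cls_masks):
            saliency[cls] = cls_masks.mean(axis=0)  # average masks per class
    peak = saliency.max(axis=(1, 2), keepdims=True)
    peak[peak == 0] = 1.0  # avoid division by zero for classes with no masks
    return (saliency / peak * 255).astype(np.uint8)  # normalize each map to [0, 255]

maps = average_and_normalize_sketch(np.random.rand(5, 8, 8), np.array([0, 0, 1, 1, 1]), 2)
print(maps.shape)  # (2, 8, 8)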
4 changes: 2 additions & 2 deletions src/otx/algo/schedulers/__init__.py
@@ -3,6 +3,6 @@
 #
 """Custom schedulers for the OTX2.0."""

-from .warmup_schedulers import WarmupReduceLROnPlateau
+from .warmup_schedulers import LinearWarmupScheduler

-__all__ = ["WarmupReduceLROnPlateau"]
+__all__ = ["LinearWarmupScheduler"]