Skip to content

Commit

Permalink
Improve YOLOv9 performance (#3953)
Browse files Browse the repository at this point in the history
* Fix image channel

* Fix loss weights

* Update recipe

* Update min_confidence and min_iou

* Update to use fixed cfg for export

* Update recipe

* Skip batch without ground truth

* Update recipes

* Add comment

* Fix unit test

* Add unit test

* Update export parameter

* Fix unit test
  • Loading branch information
sungchul2 authored Sep 26, 2024
1 parent 88bb571 commit 28f6529
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 51 deletions.
8 changes: 4 additions & 4 deletions src/otx/algo/detection/heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ class YOLOHeadModule(BaseDenseHead):
csp_args (dict[str, Any], optional): Arguments for CSP blocks. Defaults to None.
aux_cfg (dict[str, Any], optional): Configuration for auxiliary head. Defaults to None.
with_nms (bool, optional): Whether to use NMS. Defaults to True.
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.05.
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.9.
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.1.
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.65.
"""

def __init__(
Expand All @@ -311,8 +311,8 @@ def __init__(
csp_args: dict[str, Any] | None = None,
aux_cfg: dict[str, Any] | None = None,
with_nms: bool = True,
min_confidence: float = 0.05,
min_iou: float = 0.9,
min_confidence: float = 0.1,
min_iou: float = 0.65,
) -> None:
if len(csp_channels) - 1 != len(concat_sources):
msg = (
Expand Down
12 changes: 8 additions & 4 deletions src/otx/algo/detection/losses/yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ def __init__(
loss_dfl: nn.Module | None = None,
loss_iou: nn.Module | None = None,
reg_max: int = 16,
cls_rate: float = 1.5,
dfl_rate: float = 7.5,
iou_rate: float = 0.5,
cls_rate: float = 0.5,
dfl_rate: float = 1.5,
iou_rate: float = 7.5,
aux_rate: float = 0.25,
) -> None:
super().__init__()
Expand All @@ -394,7 +394,7 @@ def forward(
main_preds: tuple[Tensor, Tensor, Tensor],
targets: Tensor,
aux_preds: tuple[Tensor, Tensor, Tensor] | None = None,
) -> dict[str, Tensor]:
) -> dict[str, Tensor] | None:
"""Forward pass of the YOLOv9 criterion module.
Args:
Expand All @@ -405,6 +405,10 @@ def forward(
Returns:
dict[str, Tensor]: The loss dictionary.
"""
if targets.shape[1] == 0:
# TODO (sungchul): should this step be done here?
return None

main_preds = self.vec2box(main_preds)
main_iou, main_dfl, main_cls = self._forward(main_preds, targets)

Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/detection/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _customize_inputs(

def _customize_outputs(
self,
outputs: list[torch.Tensor] | dict,
outputs: list[torch.Tensor] | dict, # type: ignore[override]
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity | OTXBatchLossEntity:
if self.training:
Expand Down
11 changes: 10 additions & 1 deletion src/otx/algo/detection/yolov9.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from otx.core.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import OTXDetectionModel
from otx.core.types.export import TaskLevelExportParameters

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -116,7 +117,7 @@ def _exporter(self) -> OTXModelExporter:
std=self.std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=True,
swap_rgb=False,
via_onnx=True,
onnx_export_configuration={
"input_names": ["image"],
Expand All @@ -135,6 +136,14 @@ def _exporter(self) -> OTXModelExporter:
output_names=None, # TODO (someone): support XAI
)

@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
return super()._export_parameters.wrap(
confidence_threshold=self.model.bbox_head.min_confidence,
iou_threshold=self.model.bbox_head.min_iou,
)

def to(self, *args, **kwargs) -> Self:
"""Sync device of the model and its components."""
ret = super().to(*args, **kwargs)
Expand Down
6 changes: 5 additions & 1 deletion src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,13 @@ def __init__(
# so that it can retrieve it from the checkpoint
self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"])

def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor:
def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None:
"""Step for model training."""
train_loss = self.forward(inputs=batch)
if train_loss is None:
# to skip current iteration
# TODO (sungchul): check this in distributed training
return None if self.trainer.world_size == 1 else torch.tensor(0.0, device=self.device)

if isinstance(train_loss, Tensor):
self.log(
Expand Down
9 changes: 6 additions & 3 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,15 @@ def _customize_inputs(

return inputs

def _customize_outputs(
def _customize_outputs( # type: ignore[override]
self,
outputs: list[InstanceData] | dict,
outputs: list[InstanceData] | dict | None,
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity | OTXBatchLossEntity:
) -> DetBatchPredEntity | OTXBatchLossEntity | None:
if self.training:
if outputs is None:
return outputs

if not isinstance(outputs, dict):
raise TypeError(outputs)

Expand Down
39 changes: 27 additions & 12 deletions src/otx/recipe/detection/yolov9_c.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ model:
optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.001
lr: 0.0001
momentum: 0.937
weight_decay: 0.0005
nesterov: true
Expand All @@ -16,13 +16,13 @@ model:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
warmup_interval: epoch
main_scheduler_callable:
class_path: torch.optim.lr_scheduler.LinearLR
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
total_iters: 200
start_factor: 1
end_factor: 0.01
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
Expand All @@ -42,23 +42,38 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
init_args:
img_scale: $(input_size) # (H, W)
ratio_range:
- 1.0
- 1.0
prob: 0.5
random_pop: false
max_cached_images: 10
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
Expand All @@ -75,7 +90,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand All @@ -95,7 +110,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand Down
29 changes: 22 additions & 7 deletions src/otx/recipe/detection/yolov9_m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ model:
optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.001
lr: 0.0001
momentum: 0.937
weight_decay: 0.0005
nesterov: true
Expand Down Expand Up @@ -42,23 +42,38 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
init_args:
img_scale: $(input_size) # (H, W)
ratio_range:
- 1.0
- 1.0
prob: 0.5
random_pop: false
max_cached_images: 10
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
Expand All @@ -75,7 +90,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand All @@ -95,7 +110,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand Down
34 changes: 18 additions & 16 deletions src/otx/recipe/detection/yolov9_s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,21 @@ model:
label_info: 80

optimizer:
class_path: torch.optim.SGD
class_path: torch.optim.Adam
init_args:
lr: 0.01
momentum: 0.937
weight_decay: 0.0005
nesterov: true
lr: 0.001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
warmup_interval: epoch
main_scheduler_callable:
class_path: torch.optim.lr_scheduler.LinearLR
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
total_iters: 200
start_factor: 1
end_factor: 0.01
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
Expand All @@ -42,23 +39,28 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
Expand All @@ -75,7 +77,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand All @@ -95,7 +97,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/losses/test_yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_forward(self, mocker, criterion: YOLOv9Criterion) -> None:
return_value=(torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)),
)
main_preds = torch.tensor(0.0)
targets = torch.tensor(0.0)
targets = torch.zeros(1, 1, 4)
aux_preds = torch.tensor(0.0)

loss_dict = criterion.forward(main_preds, targets, aux_preds)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/test_yolov9.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_exporter(self) -> None:
otx_yolov9_s = YOLOv9(model_name="yolov9_s", label_info=3)
otx_yolov9_s_exporter = otx_yolov9_s._exporter
assert isinstance(otx_yolov9_s_exporter, OTXNativeModelExporter)
assert otx_yolov9_s_exporter.swap_rgb is True
assert otx_yolov9_s_exporter.swap_rgb is False

def test_to(self) -> None:
model = YOLOv9(model_name="yolov9_s", label_info=3)
Expand Down
Loading

0 comments on commit 28f6529

Please sign in to comment.