Improve YOLOv9 performance #3953

Merged · 14 commits · Sep 26, 2024
8 changes: 4 additions & 4 deletions src/otx/algo/detection/heads/yolo_head.py
@@ -296,8 +296,8 @@ class YOLOHeadModule(BaseDenseHead):
csp_args (dict[str, Any], optional): Arguments for CSP blocks. Defaults to None.
aux_cfg (dict[str, Any], optional): Configuration for auxiliary head. Defaults to None.
with_nms (bool, optional): Whether to use NMS. Defaults to True.
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.05.
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.9.
min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.1.
min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.65.
"""

def __init__(
@@ -311,8 +311,8 @@ def __init__(
csp_args: dict[str, Any] | None = None,
aux_cfg: dict[str, Any] | None = None,
with_nms: bool = True,
min_confidence: float = 0.05,
min_iou: float = 0.9,
min_confidence: float = 0.1,
min_iou: float = 0.65,
) -> None:
if len(csp_channels) - 1 != len(concat_sources):
msg = (
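The new head defaults raise the score cut-off from 0.05 to 0.1 and lower the NMS IoU threshold from 0.9 to 0.65, so low-confidence boxes are dropped earlier and near-duplicate boxes are suppressed more aggressively. A minimal sketch of how such thresholds are typically applied with torchvision's NMS (an illustrative helper, not the OTX post-processing code):

```python
import torch
from torchvision.ops import nms


def filter_detections(
    boxes: torch.Tensor,   # (N, 4) boxes in xyxy format
    scores: torch.Tensor,  # (N,) confidence scores
    min_confidence: float = 0.1,
    min_iou: float = 0.65,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Confidence filter followed by NMS, mirroring the two head parameters."""
    keep = scores >= min_confidence            # drop low-confidence candidates
    boxes, scores = boxes[keep], scores[keep]
    # Suppress any box whose IoU with a higher-scoring box exceeds min_iou.
    kept = nms(boxes, scores, iou_threshold=min_iou)
    return boxes[kept], scores[kept]
```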
12 changes: 8 additions & 4 deletions src/otx/algo/detection/losses/yolov9_loss.py
@@ -371,9 +371,9 @@ def __init__(
loss_dfl: nn.Module | None = None,
loss_iou: nn.Module | None = None,
reg_max: int = 16,
cls_rate: float = 1.5,
dfl_rate: float = 7.5,
iou_rate: float = 0.5,
cls_rate: float = 0.5,
dfl_rate: float = 1.5,
iou_rate: float = 7.5,
aux_rate: float = 0.25,
) -> None:
super().__init__()
@@ -394,7 +394,7 @@ def forward(
main_preds: tuple[Tensor, Tensor, Tensor],
targets: Tensor,
aux_preds: tuple[Tensor, Tensor, Tensor] | None = None,
) -> dict[str, Tensor]:
) -> dict[str, Tensor] | None:
"""Forward pass of the YOLOv9 criterion module.

Args:
@@ -405,6 +405,10 @@
Returns:
dict[str, Tensor]: The loss dictionary.
"""
if targets.shape[1] == 0:
# TODO (sungchul): should this step be done here?
return None

main_preds = self.vec2box(main_preds)
main_iou, main_dfl, main_cls = self._forward(main_preds, targets)

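The criterion's rate defaults are reshuffled so the IoU term now carries the largest weight (iou 7.5, dfl 1.5, cls 0.5, versus the previous cls 1.5 / dfl 7.5 / iou 0.5), and `forward` returns `None` when the batch has no ground-truth boxes so the iteration can be skipped. A rough sketch of how such rates could combine the individual terms; the aggregation below is an assumption for illustration, not the actual `YOLOv9Criterion` internals:

```python
from torch import Tensor


def combine_losses(
    loss_iou: Tensor,
    loss_dfl: Tensor,
    loss_cls: Tensor,
    targets: Tensor,
    iou_rate: float = 7.5,
    dfl_rate: float = 1.5,
    cls_rate: float = 0.5,
) -> dict[str, Tensor] | None:
    """Hypothetical weighted sum mirroring the criterion's *_rate arguments."""
    if targets.shape[1] == 0:
        # No ground-truth boxes in this batch: signal the caller to skip it.
        return None
    total = iou_rate * loss_iou + dfl_rate * loss_dfl + cls_rate * loss_cls
    return {
        "total_loss": total,
        "loss_iou": loss_iou,
        "loss_dfl": loss_dfl,
        "loss_cls": loss_cls,
    }
```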
2 changes: 1 addition & 1 deletion src/otx/algo/detection/rtdetr.py
@@ -135,7 +135,7 @@ def _customize_inputs(

def _customize_outputs(
self,
outputs: list[torch.Tensor] | dict,
outputs: list[torch.Tensor] | dict, # type: ignore[override]
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity | OTXBatchLossEntity:
if self.training:
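The `# type: ignore[override]` is needed because the subclass annotates `outputs` differently from the base-class signature, which mypy reports as an incompatible override even though the behaviour is intentional. A generic illustration of the pattern (toy classes, not the OTX ones):

```python
class Base:
    def handle(self, outputs: object) -> None: ...


class Child(Base):
    # Without the ignore, mypy flags this as an incompatible override because
    # the parameter type is narrowed relative to the supertype.
    def handle(self, outputs: dict) -> None:  # type: ignore[override]
        ...
```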
11 changes: 10 additions & 1 deletion src/otx/algo/detection/yolov9.py
@@ -19,6 +19,7 @@
from otx.core.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import OTXDetectionModel
from otx.core.types.export import TaskLevelExportParameters

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -116,7 +117,7 @@ def _exporter(self) -> OTXModelExporter:
std=self.std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=True,
swap_rgb=False,
via_onnx=True,
onnx_export_configuration={
"input_names": ["image"],
@@ -135,6 +136,14 @@ def _exporter(self) -> OTXModelExporter:
output_names=None, # TODO (someone): support XAI
)

@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
return super()._export_parameters.wrap(
confidence_threshold=self.model.bbox_head.min_confidence,
iou_threshold=self.model.bbox_head.min_iou,
)

def to(self, *args, **kwargs) -> Self:
"""Sync device of the model and its components."""
ret = super().to(*args, **kwargs)
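Two export-side adjustments: `swap_rgb` is turned off, which lines up with the recipes now feeding the model BGR input directly (`image_color_channel: BGR`), and the NMS thresholds configured on the bbox head are forwarded into the export parameters so the deployed model post-processes with the same values. A minimal sketch of the wrap-style pattern using a stand-in dataclass rather than the real `TaskLevelExportParameters`:

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExportParams:
    """Stand-in for task-level export parameters (illustration only)."""

    confidence_threshold: float | None = None
    iou_threshold: float | None = None

    def wrap(self, **overrides: float) -> ExportParams:
        # Return a copy with the model-specific thresholds filled in.
        return replace(self, **overrides)


params = ExportParams().wrap(confidence_threshold=0.1, iou_threshold=0.65)
```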
6 changes: 5 additions & 1 deletion src/otx/core/model/base.py
@@ -141,9 +141,13 @@ def __init__(
# so that it can retrieve it from the checkpoint
self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"])

def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor:
def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None:
"""Step for model training."""
train_loss = self.forward(inputs=batch)
if train_loss is None:
# to skip current iteration
# TODO (sungchul): check this in distributed training
return None if self.trainer.world_size == 1 else torch.tensor(0.0, device=self.device)

if isinstance(train_loss, Tensor):
self.log(
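Returning `None` from `training_step` makes Lightning skip backpropagation and the optimizer step for that batch, which is the desired behaviour when the criterion produced no loss (empty targets). The zero-tensor fallback for `world_size > 1` is presumably there so every rank still returns a loss and collective operations stay aligned; the TODO notes this path still needs verification. A stripped-down single-process sketch with illustrative names:

```python
import torch
from lightning.pytorch import LightningModule


class DetectorModule(LightningModule):
    """Illustrative module that skips batches yielding no loss."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx: int) -> torch.Tensor | None:
        loss = self.model(batch)  # assumed to return None for empty targets
        if loss is None:
            # Lightning skips backward() and optimizer.step() for this batch.
            return None
        self.log("train/loss", loss, prog_bar=True)
        return loss
```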
9 changes: 6 additions & 3 deletions src/otx/core/model/detection.py
@@ -137,12 +137,15 @@ def _customize_inputs(

return inputs

def _customize_outputs(
def _customize_outputs( # type: ignore[override]
self,
outputs: list[InstanceData] | dict,
outputs: list[InstanceData] | dict | None,
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity | OTXBatchLossEntity:
) -> DetBatchPredEntity | OTXBatchLossEntity | None:
if self.training:
if outputs is None:
return outputs

if not isinstance(outputs, dict):
raise TypeError(outputs)

39 changes: 27 additions & 12 deletions src/otx/recipe/detection/yolov9_c.yaml
@@ -7,7 +7,7 @@ model:
optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.001
lr: 0.0001
momentum: 0.937
weight_decay: 0.0005
nesterov: true
@@ -16,13 +16,13 @@ model:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
warmup_interval: epoch
main_scheduler_callable:
class_path: torch.optim.lr_scheduler.LinearLR
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
total_iters: 200
start_factor: 1
end_factor: 0.01
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
@@ -42,23 +42,38 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
init_args:
img_scale: $(input_size) # (H, W)
ratio_range:
- 1.0
- 1.0
prob: 0.5
random_pop: false
max_cached_images: 10
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
@@ -75,7 +90,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
@@ -95,7 +110,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 10
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
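The YOLOv9-C recipe now starts from a lower base LR (1e-4), replaces the fixed LinearLR decay with ReduceLROnPlateau driven by `val/map_50` (mode `max`, factor 0.1, patience 4), switches the pipeline to BGR input, and trades RandomCrop for RandomAffine plus CachedMixUp and YOLOXHSVRandomAug at a smaller batch size. In plain PyTorch terms the new schedule behaves as below; the Lightning CLI variant used in the recipe is essentially this scheduler plus a `monitor` key:

```python
import torch

model = torch.nn.Linear(10, 2)  # placeholder model for the sketch
optimizer = torch.optim.SGD(
    model.parameters(), lr=1e-4, momentum=0.937, weight_decay=5e-4, nesterov=True
)
# Cut the learning rate by 10x whenever the monitored metric (val/map_50)
# fails to improve for 4 consecutive validation rounds.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=4
)

for epoch in range(10):
    val_map_50 = 0.5  # stand-in for the metric logged during validation
    scheduler.step(val_map_50)
```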
29 changes: 22 additions & 7 deletions src/otx/recipe/detection/yolov9_m.yaml
@@ -7,7 +7,7 @@ model:
optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.001
lr: 0.0001
momentum: 0.937
weight_decay: 0.0005
nesterov: true
@@ -42,23 +42,38 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.CachedMixUp
init_args:
img_scale: $(input_size) # (H, W)
ratio_range:
- 1.0
- 1.0
prob: 0.5
random_pop: false
max_cached_images: 10
- class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
@@ -75,7 +90,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
@@ -95,7 +110,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 12
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
34 changes: 18 additions & 16 deletions src/otx/recipe/detection/yolov9_s.yaml
@@ -5,24 +5,21 @@ model:
label_info: 80

optimizer:
class_path: torch.optim.SGD
class_path: torch.optim.Adam
init_args:
lr: 0.01
momentum: 0.937
weight_decay: 0.0005
nesterov: true
lr: 0.001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
warmup_interval: epoch
main_scheduler_callable:
class_path: torch.optim.lr_scheduler.LinearLR
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
total_iters: 200
start_factor: 1
end_factor: 0.01
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
@@ -42,23 +39,28 @@ overrides:
input_size:
- 640
- 640
image_color_channel: BGR
train_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.CachedMosaic
init_args:
random_pop: false
max_cached_images: 20
img_scale: $(input_size) # (H, W)
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
init_args:
crop_size: $(input_size) * 0.5
scaling_ratio_range:
- 0.1
- 2.0
border: $(input_size) * -0.5
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: true
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
@@ -75,7 +77,7 @@ overrides:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
@@ -95,7 +97,7 @@ overrides:
std: [255.0, 255.0, 255.0]

test_subset:
batch_size: 16
batch_size: 14
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
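For YOLOv9-S the optimizer itself changes: SGD with Nesterov momentum at lr 0.01 is replaced by Adam at lr 0.001, alongside the same plateau-based schedule, BGR input, and augmentation updates as the other variants. The equivalent construction in plain PyTorch, with a placeholder model:

```python
import torch

model = torch.nn.Linear(10, 2)  # placeholder for the YOLOv9-S network
# Previous recipe: SGD with Nesterov momentum.
sgd = torch.optim.SGD(
    model.parameters(), lr=0.01, momentum=0.937, weight_decay=5e-4, nesterov=True
)
# New recipe: Adam at a lower learning rate (default betas).
adam = torch.optim.Adam(model.parameters(), lr=0.001)
```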
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/losses/test_yolov9_loss.py
@@ -185,7 +185,7 @@ def test_forward(self, mocker, criterion: YOLOv9Criterion) -> None:
return_value=(torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)),
)
main_preds = torch.tensor(0.0)
targets = torch.tensor(0.0)
targets = torch.zeros(1, 1, 4)
aux_preds = torch.tensor(0.0)

loss_dict = criterion.forward(main_preds, targets, aux_preds)
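The loss test now passes targets with an explicit box dimension (`torch.zeros(1, 1, 4)`) so the non-empty path is exercised. The new empty-target early return could be covered by a companion test along these lines; the `criterion` fixture and the target layout are borrowed from the existing test, the rest is a sketch:

```python
import torch


def test_forward_with_empty_targets(criterion) -> None:
    """Targets with shape[1] == 0 should short-circuit the criterion to None."""
    main_preds = torch.tensor(0.0)
    targets = torch.zeros(1, 0, 4)  # no ground-truth boxes in the batch
    assert criterion.forward(main_preds, targets) is None
```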
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/test_yolov9.py
@@ -29,7 +29,7 @@ def test_exporter(self) -> None:
otx_yolov9_s = YOLOv9(model_name="yolov9_s", label_info=3)
otx_yolov9_s_exporter = otx_yolov9_s._exporter
assert isinstance(otx_yolov9_s_exporter, OTXNativeModelExporter)
assert otx_yolov9_s_exporter.swap_rgb is True
assert otx_yolov9_s_exporter.swap_rgb is False

def test_to(self) -> None:
model = YOLOv9(model_name="yolov9_s", label_info=3)