diff --git a/src/otx/algo/detection/heads/yolo_head.py b/src/otx/algo/detection/heads/yolo_head.py index 429f1ccf101..042ef4898a9 100644 --- a/src/otx/algo/detection/heads/yolo_head.py +++ b/src/otx/algo/detection/heads/yolo_head.py @@ -296,8 +296,8 @@ class YOLOHeadModule(BaseDenseHead): csp_args (dict[str, Any], optional): Arguments for CSP blocks. Defaults to None. aux_cfg (dict[str, Any], optional): Configuration for auxiliary head. Defaults to None. with_nms (bool, optional): Whether to use NMS. Defaults to True. - min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.05. - min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.9. + min_confidence (float, optional): Minimum confidence for NMS. Defaults to 0.1. + min_iou (float, optional): Minimum IoU for NMS. Defaults to 0.65. """ def __init__( @@ -311,8 +311,8 @@ def __init__( csp_args: dict[str, Any] | None = None, aux_cfg: dict[str, Any] | None = None, with_nms: bool = True, - min_confidence: float = 0.05, - min_iou: float = 0.9, + min_confidence: float = 0.1, + min_iou: float = 0.65, ) -> None: if len(csp_channels) - 1 != len(concat_sources): msg = ( diff --git a/src/otx/algo/detection/losses/yolov9_loss.py b/src/otx/algo/detection/losses/yolov9_loss.py index 6e13925940c..349043ff552 100644 --- a/src/otx/algo/detection/losses/yolov9_loss.py +++ b/src/otx/algo/detection/losses/yolov9_loss.py @@ -371,9 +371,9 @@ def __init__( loss_dfl: nn.Module | None = None, loss_iou: nn.Module | None = None, reg_max: int = 16, - cls_rate: float = 1.5, - dfl_rate: float = 7.5, - iou_rate: float = 0.5, + cls_rate: float = 0.5, + dfl_rate: float = 1.5, + iou_rate: float = 7.5, aux_rate: float = 0.25, ) -> None: super().__init__() @@ -394,7 +394,7 @@ def forward( main_preds: tuple[Tensor, Tensor, Tensor], targets: Tensor, aux_preds: tuple[Tensor, Tensor, Tensor] | None = None, - ) -> dict[str, Tensor]: + ) -> dict[str, Tensor] | None: """Forward pass of the YOLOv9 criterion module. Args: @@ -405,6 +405,10 @@ def forward( Returns: dict[str, Tensor]: The loss dictionary. """ + if targets.shape[1] == 0: + # TODO (sungchul): should this step be done here? + return None + main_preds = self.vec2box(main_preds) main_iou, main_dfl, main_cls = self._forward(main_preds, targets) diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py index d398a68c208..87784dadd7a 100644 --- a/src/otx/algo/detection/rtdetr.py +++ b/src/otx/algo/detection/rtdetr.py @@ -135,7 +135,7 @@ def _customize_inputs( def _customize_outputs( self, - outputs: list[torch.Tensor] | dict, + outputs: list[torch.Tensor] | dict, # type: ignore[override] inputs: DetBatchDataEntity, ) -> DetBatchPredEntity | OTXBatchLossEntity: if self.training: diff --git a/src/otx/algo/detection/yolov9.py b/src/otx/algo/detection/yolov9.py index bc61d52888b..092782bb3e4 100644 --- a/src/otx/algo/detection/yolov9.py +++ b/src/otx/algo/detection/yolov9.py @@ -19,6 +19,7 @@ from otx.core.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.detection import OTXDetectionModel +from otx.core.types.export import TaskLevelExportParameters if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -116,7 +117,7 @@ def _exporter(self) -> OTXModelExporter: std=self.std, resize_mode="fit_to_window_letterbox", pad_value=114, - swap_rgb=True, + swap_rgb=False, via_onnx=True, onnx_export_configuration={ "input_names": ["image"], @@ -135,6 +136,14 @@ def _exporter(self) -> OTXModelExporter: output_names=None, # TODO (someone): support XAI ) + @property + def _export_parameters(self) -> TaskLevelExportParameters: + """Defines parameters required to export a particular model implementation.""" + return super()._export_parameters.wrap( + confidence_threshold=self.model.bbox_head.min_confidence, + iou_threshold=self.model.bbox_head.min_iou, + ) + def to(self, *args, **kwargs) -> Self: """Sync device of the model and its components.""" ret = super().to(*args, **kwargs) diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 8424dd99a12..bc2604c7344 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -141,9 +141,13 @@ def __init__( # so that it can retrieve it from the checkpoint self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"]) - def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor: + def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None: """Step for model training.""" train_loss = self.forward(inputs=batch) + if train_loss is None: + # to skip current iteration + # TODO (sungchul): check this in distributed training + return None if self.trainer.world_size == 1 else torch.tensor(0.0, device=self.device) if isinstance(train_loss, Tensor): self.log( diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 5c038eea1bd..fb6acccf12d 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -137,12 +137,15 @@ def _customize_inputs( return inputs - def _customize_outputs( + def _customize_outputs( # type: ignore[override] self, - outputs: list[InstanceData] | dict, + outputs: list[InstanceData] | dict | None, inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity | OTXBatchLossEntity: + ) -> DetBatchPredEntity | OTXBatchLossEntity | None: if self.training: + if outputs is None: + return outputs + if not isinstance(outputs, dict): raise TypeError(outputs) diff --git a/src/otx/recipe/detection/yolov9_c.yaml b/src/otx/recipe/detection/yolov9_c.yaml index c87bc957060..dfb0ae43009 100644 --- a/src/otx/recipe/detection/yolov9_c.yaml +++ b/src/otx/recipe/detection/yolov9_c.yaml @@ -7,7 +7,7 @@ model: optimizer: class_path: torch.optim.SGD init_args: - lr: 0.001 + lr: 0.0001 momentum: 0.937 weight_decay: 0.0005 nesterov: true @@ -16,13 +16,13 @@ model: class_path: otx.core.schedulers.LinearWarmupSchedulerCallable init_args: num_warmup_steps: 3 - warmup_interval: epoch main_scheduler_callable: - class_path: torch.optim.lr_scheduler.LinearLR + class_path: lightning.pytorch.cli.ReduceLROnPlateau init_args: - total_iters: 200 - start_factor: 1 - end_factor: 0.01 + mode: max + factor: 0.1 + patience: 4 + monitor: val/map_50 engine: task: DETECTION @@ -42,23 +42,38 @@ overrides: input_size: - 640 - 640 - image_color_channel: BGR train_subset: - batch_size: 16 + batch_size: 10 transforms: - class_path: otx.core.data.transform_libs.torchvision.CachedMosaic init_args: random_pop: false max_cached_images: 20 img_scale: $(input_size) # (H, W) - - class_path: otx.core.data.transform_libs.torchvision.RandomCrop + - class_path: otx.core.data.transform_libs.torchvision.RandomAffine init_args: - crop_size: $(input_size) * 0.5 + scaling_ratio_range: + - 0.1 + - 2.0 + border: $(input_size) * -0.5 + - class_path: otx.core.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: $(input_size) # (H, W) + ratio_range: + - 1.0 + - 1.0 + prob: 0.5 + random_pop: false + max_cached_images: 10 + - class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: scale: $(input_size) keep_ratio: true transform_bbox: true + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 - class_path: otx.core.data.transform_libs.torchvision.Pad init_args: pad_to_square: true @@ -75,7 +90,7 @@ overrides: class_path: otx.algo.samplers.balanced_sampler.BalancedSampler val_subset: - batch_size: 16 + batch_size: 10 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: @@ -95,7 +110,7 @@ overrides: std: [255.0, 255.0, 255.0] test_subset: - batch_size: 16 + batch_size: 10 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: diff --git a/src/otx/recipe/detection/yolov9_m.yaml b/src/otx/recipe/detection/yolov9_m.yaml index 6cddf04c235..34daaa8f6ad 100644 --- a/src/otx/recipe/detection/yolov9_m.yaml +++ b/src/otx/recipe/detection/yolov9_m.yaml @@ -7,7 +7,7 @@ model: optimizer: class_path: torch.optim.SGD init_args: - lr: 0.001 + lr: 0.0001 momentum: 0.937 weight_decay: 0.0005 nesterov: true @@ -42,23 +42,38 @@ overrides: input_size: - 640 - 640 - image_color_channel: BGR train_subset: - batch_size: 16 + batch_size: 12 transforms: - class_path: otx.core.data.transform_libs.torchvision.CachedMosaic init_args: random_pop: false max_cached_images: 20 img_scale: $(input_size) # (H, W) - - class_path: otx.core.data.transform_libs.torchvision.RandomCrop + - class_path: otx.core.data.transform_libs.torchvision.RandomAffine init_args: - crop_size: $(input_size) * 0.5 + scaling_ratio_range: + - 0.1 + - 2.0 + border: $(input_size) * -0.5 + - class_path: otx.core.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: $(input_size) # (H, W) + ratio_range: + - 1.0 + - 1.0 + prob: 0.5 + random_pop: false + max_cached_images: 10 + - class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: scale: $(input_size) keep_ratio: true transform_bbox: true + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 - class_path: otx.core.data.transform_libs.torchvision.Pad init_args: pad_to_square: true @@ -75,7 +90,7 @@ overrides: class_path: otx.algo.samplers.balanced_sampler.BalancedSampler val_subset: - batch_size: 16 + batch_size: 12 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: @@ -95,7 +110,7 @@ overrides: std: [255.0, 255.0, 255.0] test_subset: - batch_size: 16 + batch_size: 12 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: diff --git a/src/otx/recipe/detection/yolov9_s.yaml b/src/otx/recipe/detection/yolov9_s.yaml index 47b32c6565f..4900338e796 100644 --- a/src/otx/recipe/detection/yolov9_s.yaml +++ b/src/otx/recipe/detection/yolov9_s.yaml @@ -5,24 +5,21 @@ model: label_info: 80 optimizer: - class_path: torch.optim.SGD + class_path: torch.optim.Adam init_args: - lr: 0.01 - momentum: 0.937 - weight_decay: 0.0005 - nesterov: true + lr: 0.001 scheduler: class_path: otx.core.schedulers.LinearWarmupSchedulerCallable init_args: num_warmup_steps: 3 - warmup_interval: epoch main_scheduler_callable: - class_path: torch.optim.lr_scheduler.LinearLR + class_path: lightning.pytorch.cli.ReduceLROnPlateau init_args: - total_iters: 200 - start_factor: 1 - end_factor: 0.01 + mode: max + factor: 0.1 + patience: 4 + monitor: val/map_50 engine: task: DETECTION @@ -42,23 +39,28 @@ overrides: input_size: - 640 - 640 - image_color_channel: BGR train_subset: - batch_size: 16 + batch_size: 14 transforms: - class_path: otx.core.data.transform_libs.torchvision.CachedMosaic init_args: random_pop: false max_cached_images: 20 img_scale: $(input_size) # (H, W) - - class_path: otx.core.data.transform_libs.torchvision.RandomCrop + - class_path: otx.core.data.transform_libs.torchvision.RandomAffine init_args: - crop_size: $(input_size) * 0.5 + scaling_ratio_range: + - 0.1 + - 2.0 + border: $(input_size) * -0.5 - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: scale: $(input_size) keep_ratio: true transform_bbox: true + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 - class_path: otx.core.data.transform_libs.torchvision.Pad init_args: pad_to_square: true @@ -75,7 +77,7 @@ overrides: class_path: otx.algo.samplers.balanced_sampler.BalancedSampler val_subset: - batch_size: 16 + batch_size: 14 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: @@ -95,7 +97,7 @@ overrides: std: [255.0, 255.0, 255.0] test_subset: - batch_size: 16 + batch_size: 14 transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize init_args: diff --git a/tests/unit/algo/detection/losses/test_yolov9_loss.py b/tests/unit/algo/detection/losses/test_yolov9_loss.py index 6267db3b284..ef228eca5b8 100644 --- a/tests/unit/algo/detection/losses/test_yolov9_loss.py +++ b/tests/unit/algo/detection/losses/test_yolov9_loss.py @@ -185,7 +185,7 @@ def test_forward(self, mocker, criterion: YOLOv9Criterion) -> None: return_value=(torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)), ) main_preds = torch.tensor(0.0) - targets = torch.tensor(0.0) + targets = torch.zeros(1, 1, 4) aux_preds = torch.tensor(0.0) loss_dict = criterion.forward(main_preds, targets, aux_preds) diff --git a/tests/unit/algo/detection/test_yolov9.py b/tests/unit/algo/detection/test_yolov9.py index 645b211cf09..de66c05a111 100644 --- a/tests/unit/algo/detection/test_yolov9.py +++ b/tests/unit/algo/detection/test_yolov9.py @@ -29,7 +29,7 @@ def test_exporter(self) -> None: otx_yolov9_s = YOLOv9(model_name="yolov9_s", label_info=3) otx_yolov9_s_exporter = otx_yolov9_s._exporter assert isinstance(otx_yolov9_s_exporter, OTXNativeModelExporter) - assert otx_yolov9_s_exporter.swap_rgb is True + assert otx_yolov9_s_exporter.swap_rgb is False def test_to(self) -> None: model = YOLOv9(model_name="yolov9_s", label_info=3) diff --git a/tests/unit/core/model/test_base.py b/tests/unit/core/model/test_base.py index 3a24908e99f..eea3821c73a 100644 --- a/tests/unit/core/model/test_base.py +++ b/tests/unit/core/model/test_base.py @@ -25,6 +25,24 @@ def test_init(self, monkeypatch): with pytest.raises(ValueError, match="Input size should be a multiple"): OTXModel(label_info=2, input_size=(1024, 1024)) + def test_training_step_none_loss(self, mocker: MockerFixture) -> None: + mock_trainer = mocker.create_autospec(spec=Trainer) + mock_trainer.world_size = 1 + with mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(3)) and mocker.patch.object( + OTXModel, + "forward", + return_value=None, + ): + current_model = OTXModel(label_info=3) + current_model.trainer = mock_trainer + + batch = {"input": torch.randn(2, 3)} + batch_idx = 0 + + results = current_model.training_step(batch, batch_idx) + + assert results is None + def test_smart_weight_loading(self, mocker) -> None: with mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(2)): prev_model = OTXModel(label_info=2)