From a0780a8cf36cb4bd924c0307134ccffbd6d8f240 Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Mon, 20 Nov 2023 13:30:58 +0900 Subject: [PATCH 1/2] Make `max_num_detections` configurable (#2647) * Make max_num_detections configurable * Fix RCNN case with integration test * Apply max_num_detections to train_cfg, too --------- Signed-off-by: Songki Choi --- CHANGELOG.md | 1 + .../common/configs/training_base.py | 25 +++++++++---------- .../detection/adapters/mmdet/configurer.py | 22 ++++++++++++++-- .../detection/adapters/mmdet/task.py | 20 ++++++--------- .../detection/adapters/openvino/task.py | 15 ++--------- .../detection/configs/base/configuration.py | 15 ++--------- .../configs/detection/configuration.yaml | 19 ++++++++++++++ .../instance_segmentation/configuration.yaml | 19 ++++++++++++++ .../convnext_maskrcnn/model.py | 4 +-- .../resnet50_maskrcnn/model.py | 17 +++---------- .../rotated_detection/configuration.yaml | 19 ++++++++++++++ .../resnet50_maskrcnn/model.py | 2 +- src/otx/algorithms/detection/task.py | 24 ++++++++---------- .../cli/detection/test_detection.py | 10 +++++++- .../cli/detection/test_tiling_detection.py | 4 ++- .../test_instance_segmentation.py | 10 +++++++- .../test_tiling_instseg.py | 4 ++- .../adapters/mmdet/test_configurer.py | 6 +++-- 18 files changed, 145 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9116e8b6adb..0746c1f194b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file. - Update ModelAPI configuration() - Add Anomaly modelAPI changes () - Update Image numpy access () +- Make max_num_detections configurable () ### Bug fixes diff --git a/src/otx/algorithms/common/configs/training_base.py b/src/otx/algorithms/common/configs/training_base.py index 40f9b7a7529..6690b6e1c3e 100644 --- a/src/otx/algorithms/common/configs/training_base.py +++ b/src/otx/algorithms/common/configs/training_base.py @@ -1,18 +1,7 @@ """Base Configuration of OTX Common Algorithms.""" -# Copyright (C) 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# Copyright (C) 2022-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 from sys import maxsize @@ -227,6 +216,16 @@ class BasePostprocessing(ParameterGroup): affects_outcome_of=ModelLifecycle.INFERENCE, ) + max_num_detections = configurable_integer( + header="Maximum number of detection per image", + description="Extra detection outputs will be discared in non-maximum suppression process. " + "Defaults to 0, which means per-model default value.", + default_value=0, + min_value=0, + max_value=10000, + affects_outcome_of=ModelLifecycle.INFERENCE, + ) + use_ellipse_shapes = configurable_boolean( default_value=False, header="Use ellipse shapes", diff --git a/src/otx/algorithms/detection/adapters/mmdet/configurer.py b/src/otx/algorithms/detection/adapters/mmdet/configurer.py index 0e4d63b1b99..3f1b624954e 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/configurer.py +++ b/src/otx/algorithms/detection/adapters/mmdet/configurer.py @@ -64,13 +64,14 @@ def configure( ir_options=None, data_classes=None, model_classes=None, + max_num_detections=0, ): """Create MMCV-consumable config from given inputs.""" logger.info(f"configure!: training={training}") self.configure_base(cfg, data_cfg, data_classes, model_classes) self.configure_device(cfg, training) - self.configure_model(cfg, ir_options) + self.configure_model(cfg, ir_options, max_num_detections) self.configure_ckpt(cfg, model_ckpt) self.configure_data(cfg, training, data_cfg) self.configure_regularization(cfg, training) @@ -113,7 +114,7 @@ def configure_base(self, cfg, data_cfg, data_classes, model_classes): new_classes = np.setdiff1d(data_classes, model_classes).tolist() train_data_cfg["new_classes"] = new_classes - def configure_model(self, cfg, ir_options): # noqa: C901 + def configure_model(self, cfg, ir_options, max_num_detections=0): # noqa: C901 """Patch config's model. Change model type to super type @@ -149,6 +150,23 @@ def is_mmov_model(key, value): {"model_path": ir_model_path, "weight_path": ir_weight_path, "init_weight": ir_weight_init}, ) + # Test config + if max_num_detections > 0: + logger.info(f"Model max_num_detections: {max_num_detections}") + test_cfg = cfg.model.test_cfg + test_cfg.max_per_img = max_num_detections + test_cfg.nms_pre = max_num_detections * 10 + # Special cases for 2-stage detectors (e.g. MaskRCNN) + if hasattr(test_cfg, "rpn"): + test_cfg.rpn.nms_pre = max_num_detections * 20 + test_cfg.rpn.max_per_img = max_num_detections * 10 + if hasattr(test_cfg, "rcnn"): + test_cfg.rcnn.max_per_img = max_num_detections + train_cfg = cfg.model.train_cfg + if hasattr(train_cfg, "rpn_proposal"): + train_cfg.rpn_proposal.nms_pre = max_num_detections * 20 + train_cfg.rpn_proposal.max_per_img = max_num_detections * 10 + def configure_data(self, cfg, training, data_cfg): # noqa: C901 """Patch cfg.data. diff --git a/src/otx/algorithms/detection/adapters/mmdet/task.py b/src/otx/algorithms/detection/adapters/mmdet/task.py index 592a63e4eb8..e1c229630b2 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/task.py +++ b/src/otx/algorithms/detection/adapters/mmdet/task.py @@ -1,18 +1,7 @@ """Task of OTX Detection using mmdetection training backend.""" # Copyright (C) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 import glob import io @@ -206,6 +195,7 @@ def configure(self, training=True, subset="train", ir_options=None, train_datase ir_options, data_classes, model_classes, + self.max_num_detections, ) if should_cluster_anchors(self._recipe_cfg): if train_dataset is not None: @@ -513,6 +503,12 @@ def _export_model( assert len(self._precision) == 1 export_options["precision"] = str(self._precision[0]) export_options["type"] = str(export_format) + if self.max_num_detections > 0: + logger.info(f"Export max_num_detections: {self.max_num_detections}") + post_proc_cfg = export_options["deploy_cfg"]["codebase_config"]["post_processing"] + post_proc_cfg["max_output_boxes_per_class"] = self.max_num_detections + post_proc_cfg["keep_top_k"] = self.max_num_detections + post_proc_cfg["pre_top_k"] = self.max_num_detections * 10 export_options["deploy_cfg"]["dump_features"] = dump_features if dump_features: diff --git a/src/otx/algorithms/detection/adapters/openvino/task.py b/src/otx/algorithms/detection/adapters/openvino/task.py index 08b5423eef2..1a8376e453e 100644 --- a/src/otx/algorithms/detection/adapters/openvino/task.py +++ b/src/otx/algorithms/detection/adapters/openvino/task.py @@ -1,18 +1,7 @@ """Openvino Task of Detection.""" -# Copyright (C) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# Copyright (C) 2021-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 import copy import io diff --git a/src/otx/algorithms/detection/configs/base/configuration.py b/src/otx/algorithms/detection/configs/base/configuration.py index 0e258f7ed8c..c2481a1633b 100644 --- a/src/otx/algorithms/detection/configs/base/configuration.py +++ b/src/otx/algorithms/detection/configs/base/configuration.py @@ -1,18 +1,7 @@ """Configuration file of OTX Detection.""" -# Copyright (C) 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# Copyright (C) 2022-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 from attr import attrs diff --git a/src/otx/algorithms/detection/configs/detection/configuration.yaml b/src/otx/algorithms/detection/configs/detection/configuration.yaml index 2784be12db9..9172db3ff4e 100644 --- a/src/otx/algorithms/detection/configs/detection/configuration.yaml +++ b/src/otx/algorithms/detection/configs/detection/configuration.yaml @@ -258,6 +258,25 @@ postprocessing: value: 0.01 visible_in_ui: true warning: null + max_num_detections: + affects_outcome_of: INFERENCE + default_value: 0 + description: + Extra detection outputs will be discared in non-maximum suppression process. + Defaults to 0, which means per-model default values. + editable: true + header: Maximum number of detections per image + max_value: 10000 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + value: 0 + visible_in_ui: true + warning: null use_ellipse_shapes: affects_outcome_of: INFERENCE default_value: false diff --git a/src/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml b/src/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml index 34c63e88ae0..9b906969a51 100644 --- a/src/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml +++ b/src/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml @@ -258,6 +258,25 @@ postprocessing: value: 0.01 visible_in_ui: true warning: null + max_num_detections: + affects_outcome_of: INFERENCE + default_value: 0 + description: + Extra detection outputs will be discared in non-maximum suppression process. + Defaults to 0, which means per-model default values. + editable: true + header: Maximum number of detections per image + max_value: 10000 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + value: 0 + visible_in_ui: true + warning: null use_ellipse_shapes: affects_outcome_of: INFERENCE default_value: false diff --git a/src/otx/algorithms/detection/configs/instance_segmentation/convnext_maskrcnn/model.py b/src/otx/algorithms/detection/configs/instance_segmentation/convnext_maskrcnn/model.py index e35799ed7e0..84d89dd682f 100644 --- a/src/otx/algorithms/detection/configs/instance_segmentation/convnext_maskrcnn/model.py +++ b/src/otx/algorithms/detection/configs/instance_segmentation/convnext_maskrcnn/model.py @@ -115,9 +115,7 @@ nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0, ), - rcnn=dict( - score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5, max_num=100), max_per_img=100, mask_thr_binary=0.5 - ), + rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5), ), ) diff --git a/src/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py b/src/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py index 6832028e425..22a3bdf4b9e 100644 --- a/src/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py +++ b/src/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py @@ -1,18 +1,7 @@ """Model configuration of Resnet50-MaskRCNN model for Instance-Seg Task.""" -# Copyright (C) 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# Copyright (C) 2022-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # pylint: disable=invalid-name @@ -149,7 +138,7 @@ ), rcnn=dict( score_thr=0.05, - nms=dict(type="nms", iou_threshold=0.5, max_num=100), + nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5, ), diff --git a/src/otx/algorithms/detection/configs/rotated_detection/configuration.yaml b/src/otx/algorithms/detection/configs/rotated_detection/configuration.yaml index aa1f41e9ec7..3b5059e0155 100644 --- a/src/otx/algorithms/detection/configs/rotated_detection/configuration.yaml +++ b/src/otx/algorithms/detection/configs/rotated_detection/configuration.yaml @@ -277,6 +277,25 @@ postprocessing: warning: null type: PARAMETER_GROUP visible_in_ui: true + max_num_detections: + affects_outcome_of: INFERENCE + default_value: 0 + description: + Extra detection outputs will be discared in non-maximum suppression process. + Defaults to 0, which means per-model default values. + editable: true + header: Maximum number of detections per image + max_value: 10000 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + value: 0 + visible_in_ui: true + warning: null algo_backend: description: parameters for algo backend header: Algo backend parameters diff --git a/src/otx/algorithms/detection/configs/rotated_detection/resnet50_maskrcnn/model.py b/src/otx/algorithms/detection/configs/rotated_detection/resnet50_maskrcnn/model.py index eee17a545c7..c406c97bd00 100644 --- a/src/otx/algorithms/detection/configs/rotated_detection/resnet50_maskrcnn/model.py +++ b/src/otx/algorithms/detection/configs/rotated_detection/resnet50_maskrcnn/model.py @@ -139,7 +139,7 @@ ), rcnn=dict( score_thr=0.05, - nms=dict(type="nms", iou_threshold=0.5, max_num=100), + nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5, ), diff --git a/src/otx/algorithms/detection/task.py b/src/otx/algorithms/detection/task.py index ee94af104c2..92fed6fc3c5 100644 --- a/src/otx/algorithms/detection/task.py +++ b/src/otx/algorithms/detection/task.py @@ -1,18 +1,7 @@ """Task of OTX Detection.""" # Copyright (C) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 import io import os @@ -83,11 +72,13 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] ) self._anchors: Dict[str, int] = {} + self.confidence_threshold = 0.0 + self.max_num_detections = 0 if hasattr(self._hyperparams, "postprocessing"): if hasattr(self._hyperparams.postprocessing, "confidence_threshold"): self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold - else: - self.confidence_threshold = 0.0 + if hasattr(self._hyperparams.postprocessing, "max_num_detections"): + self.max_num_detections = self._hyperparams.postprocessing.max_num_detections if task_environment.model is not None: self._load_model() @@ -112,6 +103,11 @@ def _load_postprocessing(self, model_data): hparams.use_ellipse_shapes = loaded_postprocessing["use_ellipse_shapes"]["value"] else: hparams.use_ellipse_shapes = False + if "max_num_detections" in loaded_postprocessing: + trained_max_num_detections = loaded_postprocessing["max_num_detections"]["value"] + # Prefer new hparam value set by user (>0) intentionally than trained value + if self.max_num_detections == 0: + self.max_num_detections = trained_max_num_detections def _load_tiling_parameters(self, model_data): """Load tiling parameters from PyTorch model. diff --git a/tests/integration/cli/detection/test_detection.py b/tests/integration/cli/detection/test_detection.py index b9ba80011ba..a18625bcf43 100644 --- a/tests/integration/cli/detection/test_detection.py +++ b/tests/integration/cli/detection/test_detection.py @@ -36,7 +36,15 @@ "--val-data-roots": "tests/assets/car_tree_bug", "--test-data-roots": "tests/assets/car_tree_bug", "--input": "tests/assets/car_tree_bug/images/train", - "train_params": ["params", "--learning_parameters.num_iters", "1", "--learning_parameters.batch_size", "4"], + "train_params": [ + "params", + "--learning_parameters.num_iters", + "1", + "--learning_parameters.batch_size", + "4", + "--postprocessing.max_num_detections", + "200", + ], } args_semisl = { diff --git a/tests/integration/cli/detection/test_tiling_detection.py b/tests/integration/cli/detection/test_tiling_detection.py index a37ab5c1b45..02ba75a8a9c 100644 --- a/tests/integration/cli/detection/test_tiling_detection.py +++ b/tests/integration/cli/detection/test_tiling_detection.py @@ -1,5 +1,5 @@ """Tests for MPA Class-Incremental Learning for object detection with OTX CLI""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # import os @@ -36,6 +36,8 @@ "1", "--tiling_parameters.enable_adaptive_params", "1", + "--postprocessing.max_num_detections", + "200", ], } diff --git a/tests/integration/cli/instance_segmentation/test_instance_segmentation.py b/tests/integration/cli/instance_segmentation/test_instance_segmentation.py index c2ba3b9c3df..93133e09dd1 100644 --- a/tests/integration/cli/instance_segmentation/test_instance_segmentation.py +++ b/tests/integration/cli/instance_segmentation/test_instance_segmentation.py @@ -32,7 +32,15 @@ "--val-data-roots": "tests/assets/car_tree_bug", "--test-data-roots": "tests/assets/car_tree_bug", "--input": "tests/assets/car_tree_bug/images/train", - "train_params": ["params", "--learning_parameters.num_iters", "1", "--learning_parameters.batch_size", "2"], + "train_params": [ + "params", + "--learning_parameters.num_iters", + "1", + "--learning_parameters.batch_size", + "2", + "--postprocessing.max_num_detections", + "200", + ], } # Training params for resume, num_iters*2 diff --git a/tests/integration/cli/instance_segmentation/test_tiling_instseg.py b/tests/integration/cli/instance_segmentation/test_tiling_instseg.py index 5b8d528e3c0..d39b4be372e 100644 --- a/tests/integration/cli/instance_segmentation/test_tiling_instseg.py +++ b/tests/integration/cli/instance_segmentation/test_tiling_instseg.py @@ -1,5 +1,5 @@ """Tests for MPA Class-Incremental Learning for instance segmentation with OTX CLI""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # import copy @@ -41,6 +41,8 @@ "1", "--tiling_parameters.enable_adaptive_params", "1", + "--postprocessing.max_num_detections", + "200", ], } diff --git a/tests/unit/algorithms/detection/adapters/mmdet/test_configurer.py b/tests/unit/algorithms/detection/adapters/mmdet/test_configurer.py index 38dada73c0a..a582a266111 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/test_configurer.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/test_configurer.py @@ -43,10 +43,12 @@ def test_configure(self, mocker): model_cfg = copy.deepcopy(self.model_cfg) data_cfg = copy.deepcopy(self.data_cfg) - returned_value = self.configurer.configure(model_cfg, self.det_dataset, "", data_cfg, True) + returned_value = self.configurer.configure( + model_cfg, self.det_dataset, "", data_cfg, True, max_num_detections=100 + ) mock_cfg_base.assert_called_once_with(model_cfg, data_cfg, None, None) mock_cfg_device.assert_called_once_with(model_cfg, True) - mock_cfg_model.assert_called_once_with(model_cfg, None) + mock_cfg_model.assert_called_once_with(model_cfg, None, 100) mock_cfg_ckpt.assert_called_once_with(model_cfg, "") mock_cfg_regularization.assert_called_once_with(model_cfg, True) mock_cfg_task.assert_called_once_with(model_cfg, self.det_dataset, True) From aceebdaa6dff910264a18aa301a52f4e4483f9ec Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Tue, 21 Nov 2023 09:57:20 +0900 Subject: [PATCH 2/2] Fix CPU training issue on non-CUDA system (#2655) Fix bug that auto adaptive batch size raises an error if CUDA isn't available (#2410) --------- Co-authored-by: Sungman Cho Co-authored-by: Eunwoo Shin --- .../common/adapters/mmcv/utils/automatic_bs.py | 5 +++++ .../common/adapters/mmcv/utils/test_automatic_bs.py | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py b/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py index c01cf8abc76..9b95d58195f 100644 --- a/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py +++ b/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, List import numpy as np +from torch.cuda import is_available as cuda_available from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo from otx.algorithms.common.utils.logger import get_logger @@ -53,6 +54,10 @@ def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = not_increase (bool) : Whether adapting batch size to larger value than default value or not. """ + if not cuda_available(): + logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.") + return + def train_func_single_iter(batch_size): copied_cfg = deepcopy(cfg) _set_batch_size(copied_cfg, batch_size) diff --git a/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py b/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py index f9e1eb7d231..d590b74faf2 100644 --- a/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py +++ b/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py @@ -109,6 +109,19 @@ def test_adapt_batch_size( assert len(mock_train_func.call_args_list[0].kwargs["cfg"].custom_hooks) == 1 +def test_adapt_batch_size_no_gpu(mocker, common_cfg, mock_dataset): + # prepare + mock_train_func = mocker.MagicMock() + mock_config = set_mock_cfg_not_action(common_cfg) + mocker.patch.object(automatic_bs, "cuda_available", return_value=False) + + # execute + adapt_batch_size(mock_train_func, mock_config, mock_dataset, False, True) + + # check train function ins't called. + mock_train_func.assert_not_called() + + class TestSubDataset: @pytest.fixture(autouse=True) def set_up(self, mocker):