diff --git a/otx/algorithms/classification/tasks/inference.py b/otx/algorithms/classification/tasks/inference.py index 4c618105650..4f46afcb8a7 100644 --- a/otx/algorithms/classification/tasks/inference.py +++ b/otx/algorithms/classification/tasks/inference.py @@ -8,6 +8,7 @@ from typing import Optional import numpy as np +import torch.distributed as dist from mmcv.utils import ConfigDict from otx.algorithms.classification.configs import ClassificationConfig @@ -454,7 +455,11 @@ def patch_color_conversion(pipeline): # In train dataset, when sample size is smaller than batch size if subset == "train" and self._data_cfg: train_data_cfg = Stage.get_data_cfg(self._data_cfg, "train") - if len(train_data_cfg.get("otx_dataset", [])) < self._recipe_cfg.data.get("samples_per_gpu", 2): + num_worlds = dist.get_world_size() if dist.is_initialized() else 1 + if ( + len(train_data_cfg.get("otx_dataset", [])) + < self._recipe_cfg.data.get("samples_per_gpu", 2) * num_worlds + ): cfg.drop_last = False cfg.domain = domain diff --git a/otx/mpa/cls/trainer.py b/otx/mpa/cls/trainer.py index cb6a0845516..c42f125b9ec 100644 --- a/otx/mpa/cls/trainer.py +++ b/otx/mpa/cls/trainer.py @@ -8,6 +8,7 @@ import warnings import mmcv +import torch.distributed as dist from mmcls import __version__ from mmcls.core import DistOptimizerHook from mmcls.datasets import build_dataloader, build_dataset @@ -105,7 +106,11 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs): drop_last = False dataset_len = len(otx_dataset) if otx_dataset else 0 # if task == h-label & dataset size is bigger than batch size - if train_data_cfg.get("hierarchical_info", None) and dataset_len > cfg.data.get("samples_per_gpu", 2): + num_worlds = dist.get_world_size() if self.distributed else 1 + if ( + train_data_cfg.get("hierarchical_info", None) + and dataset_len > cfg.data.get("samples_per_gpu", 2) * num_worlds + ): drop_last = True # updated to adapt list of dataset for the 'train' data_loaders = [ diff --git a/tests/integration/cli/classification/test_classification.py b/tests/integration/cli/classification/test_classification.py index 6a62e3a4bad..b8c1dca1e53 100644 --- a/tests/integration/cli/classification/test_classification.py +++ b/tests/integration/cli/classification/test_classification.py @@ -464,7 +464,7 @@ def test_otx_multi_gpu_train(self, template, tmp_dir_path): "--learning_parameters.num_iters", "2", "--learning_parameters.batch_size", - "2", + "4", ], }