[OTX] consider multi GPU when setting drop last (#1520)
* when setting drop_last to true, consider multi GPU

* bugfix and revert TC batch size

* change num_gpus to num_worlds
eunwoosh authored Jan 13, 2023
1 parent 4ac5842 commit b509c11
Showing 3 changed files with 13 additions and 3 deletions.
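
The motivation, sketched with hypothetical numbers (the dataset size and GPU count below are illustrative, not taken from the repository): with a distributed sampler each process draws samples_per_gpu items per iteration, so the effective global batch is samples_per_gpu * world_size. If drop_last is True and the dataset cannot fill even one global batch, every sample lands in the incomplete final batch and the loader yields nothing.

    # Illustrative numbers only; they are not values from the repository.
    samples_per_gpu = 2
    world_size = 2                                  # e.g. two GPUs / processes
    dataset_len = 3

    global_batch = samples_per_gpu * world_size     # 4
    # With drop_last=True the incomplete final batch is discarded, so a dataset
    # smaller than one global batch would produce zero training batches.
    num_batches = dataset_len // global_batch       # 0
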
7 changes: 6 additions & 1 deletion otx/algorithms/classification/tasks/inference.py
@@ -8,6 +8,7 @@
 from typing import Optional
 
 import numpy as np
+import torch.distributed as dist
 from mmcv.utils import ConfigDict
 
 from otx.algorithms.classification.configs import ClassificationConfig
@@ -454,7 +455,11 @@ def patch_color_conversion(pipeline):
         # In train dataset, when sample size is smaller than batch size
         if subset == "train" and self._data_cfg:
             train_data_cfg = Stage.get_data_cfg(self._data_cfg, "train")
-            if len(train_data_cfg.get("otx_dataset", [])) < self._recipe_cfg.data.get("samples_per_gpu", 2):
+            num_worlds = dist.get_world_size() if dist.is_initialized() else 1
+            if (
+                len(train_data_cfg.get("otx_dataset", []))
+                < self._recipe_cfg.data.get("samples_per_gpu", 2) * num_worlds
+            ):
                 cfg.drop_last = False
 
         cfg.domain = domain
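
The hunk above keeps drop_last off whenever the training set cannot fill one global batch. A minimal standalone sketch of that check, assuming torch.distributed is the process-group API in use (the function name and arguments are illustrative, not part of the repository):

    import torch.distributed as dist

    def disable_drop_last_if_needed(cfg, dataset_len: int, samples_per_gpu: int) -> None:
        # Each process consumes samples_per_gpu items per step, so the global
        # batch size is samples_per_gpu multiplied by the number of processes.
        num_worlds = dist.get_world_size() if dist.is_initialized() else 1
        if dataset_len < samples_per_gpu * num_worlds:
            cfg.drop_last = False
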
7 changes: 6 additions & 1 deletion otx/mpa/cls/trainer.py
@@ -8,6 +8,7 @@
 import warnings
 
 import mmcv
+import torch.distributed as dist
 from mmcls import __version__
 from mmcls.core import DistOptimizerHook
 from mmcls.datasets import build_dataloader, build_dataset
@@ -105,7 +106,11 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
         drop_last = False
         dataset_len = len(otx_dataset) if otx_dataset else 0
         # if task == h-label & dataset size is bigger than batch size
-        if train_data_cfg.get("hierarchical_info", None) and dataset_len > cfg.data.get("samples_per_gpu", 2):
+        num_worlds = dist.get_world_size() if self.distributed else 1
+        if (
+            train_data_cfg.get("hierarchical_info", None)
+            and dataset_len > cfg.data.get("samples_per_gpu", 2) * num_worlds
+        ):
             drop_last = True
         # updated to adapt list of dataset for the 'train'
         data_loaders = [
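
The trainer takes the complementary path: drop_last is only switched on for hierarchical-label training, and only when the dataset is strictly larger than one global batch. A rough sketch of the same decision (names are illustrative; the distributed argument stands in for the trainer's self.distributed flag):

    import torch.distributed as dist

    def want_drop_last(is_hlabel: bool, dataset_len: int, samples_per_gpu: int, distributed: bool) -> bool:
        # Only h-label training drops the incomplete last batch, and only when
        # at least one full global batch (samples_per_gpu * world size) remains.
        num_worlds = dist.get_world_size() if distributed else 1
        return bool(is_hlabel) and dataset_len > samples_per_gpu * num_worlds
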
@@ -464,7 +464,7 @@ def test_otx_multi_gpu_train(self, template, tmp_dir_path):
                 "--learning_parameters.num_iters",
                 "2",
                 "--learning_parameters.batch_size",
-                "2",
+                "4",
             ],
         }
 
