remove meta modification
eunwoosh committed Apr 17, 2023
1 parent 9bf1fd9 commit 6ded090
Showing 6 changed files with 17 additions and 20 deletions.
otx/algorithms/action/adapters/mmaction/task.py (2 additions, 2 deletions)

@@ -268,8 +268,8 @@ def _train_model(
         validate = bool(cfg.data.get("val", None))

         if auto_adapt_bs:
-            train_func = partial(train_model, model=deepcopy(model), distributed=False)
-            adapt_batch_size(train_func, cfg, meta, datasets, validate)
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, validate)

         train_model(
             model,
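The change in every task backend follows the same pattern: rather than threading `meta` through `adapt_batch_size` (which then had to deep-copy and modify it), a deep copy of `meta` is bound into the training callable up front with `functools.partial`, so the batch-size probe can call it without ever touching the caller's `meta`. Below is a minimal, self-contained sketch of that pattern; the `train_model` and `adapt_batch_size` bodies here are simplified stand-ins for illustration, not the OTX or mmaction implementations.

# Minimal sketch (not the OTX implementation) of the pattern above: bind a
# deep-copied `meta` into the training callable so the batch-size probe never
# has to copy or mutate the caller's `meta` itself.
from copy import deepcopy
from functools import partial


def train_model(model, dataset, cfg, meta=None, validate=False, distributed=False):
    # Stand-in for the real training entry point; just reports what it received.
    print(f"batch size {cfg['samples_per_gpu']}, meta keys {sorted(meta)}")


def adapt_batch_size(train_func, cfg, datasets, validate=False):
    # Stand-in for OTX's adapt_batch_size: probe one trial batch size.
    trial_cfg = deepcopy(cfg)
    trial_cfg["samples_per_gpu"] = 2  # pretend only 2 samples per GPU fit
    train_func(dataset=datasets, cfg=trial_cfg, validate=validate)


meta = {"exp_name": "demo"}
cfg = {"samples_per_gpu": 8}
train_func = partial(train_model, meta=deepcopy(meta), model=None, distributed=False)
adapt_batch_size(train_func, cfg, datasets=[], validate=False)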
otx/algorithms/classification/adapters/mmcls/task.py (2 additions, 2 deletions)

@@ -408,8 +408,8 @@ def _train_model(
         )

         if auto_adapt_bs:
-            train_func = partial(train_model, model=deepcopy(model), distributed=False)
-            adapt_batch_size(train_func, cfg, meta, datasets, False)
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, False)

         train_model(
             model,
otx/algorithms/common/adapters/mmcv/runner.py (0 additions, 3 deletions)

@@ -56,9 +56,6 @@ def stop(self) -> bool:
         This method supports distributed training by broadcasting should_stop to other ranks
         :return: a cancellation bool
         """
-        if self.meta.get("run_single_iter", False):
-            return True
-
         broadcast_obj = [False]
         if self.rank == 0 and self.should_stop:
             broadcast_obj = [True]
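With the `run_single_iter` early return gone, `stop()` relies only on the rank-0 broadcast shown in the surrounding context: rank 0 decides whether training should be cancelled, and every other rank receives that decision. A hedged sketch of that broadcast pattern follows; the helper name and standalone form are illustrative, not the actual runner code, and it assumes `torch.distributed` as the backend.

# Illustrative helper, not the OTX runner: rank 0's cancellation flag is
# broadcast so every worker stops at the same point.
import torch.distributed as dist


def should_stop_everywhere(rank: int, local_should_stop: bool) -> bool:
    broadcast_obj = [False]
    if rank == 0 and local_should_stop:
        broadcast_obj = [True]
    if dist.is_available() and dist.is_initialized():
        # broadcast_object_list overwrites the list in place with rank 0's value
        dist.broadcast_object_list(broadcast_obj, src=0)
    return broadcast_obj[0]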
otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py (9 additions, 9 deletions)

@@ -25,7 +25,7 @@
 logger = get_logger()


-def adapt_batch_size(train_func: Callable, cfg, meta: Dict, datasets: List, validate: bool = False):
+def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False):
     """Decrease batch size if default batch size isn't fit to current GPU device.

     This function just setup for single iteration training to reduce time for adapting.
@@ -41,7 +41,6 @@ def adapt_batch_size(train_func: Callable, cfg, meta: Dict, datasets: List, vali
     """
     def train_func_single_iter(batch_size):
         copied_cfg = deepcopy(cfg)
-        copied_meta = deepcopy(meta)
         _set_batch_size(copied_cfg, batch_size)

         # setup for training a single iter to reduce time
@@ -56,7 +55,6 @@ def train_func_single_iter(batch_size):
         train_func(
             dataset=new_datasets,
             cfg=copied_cfg,
-            meta=copied_meta,
             validate=validate,
         )

@@ -68,7 +66,7 @@ def train_func_single_iter(batch_size):
     )
     _set_batch_size(cfg, available_bs)
     cfg.optimizer.lr *= available_bs / default_bs
-    logger.info(f"Adpating batch size : {default_bs} -> {available_bs}")
+    logger.info(f"Result of the adpated batch size : {default_bs} -> {available_bs}")


 def _get_batch_size(cfg) -> int:
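One unchanged line in the hunk above is worth noting: after the search, the learning rate is rescaled linearly with the batch size (`cfg.optimizer.lr *= available_bs / default_bs`), which keeps the per-sample step roughly constant when fewer samples fit on the GPU. A quick worked example with made-up numbers:

# Made-up numbers, only to show the linear scaling rule used above.
default_bs, available_bs = 8, 2
lr = 0.01
lr *= available_bs / default_bs
print(lr)  # 0.0025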
@@ -85,15 +83,17 @@ def _set_batch_size(cfg, batch_size: int):


 class SubDataset:
-    """Wrapper class for DatasetEntity of dataset. It's used to make subset during HPO.
+    """Wrapper class to make dataset pretend to have specified number of images.

     Args:
-        fullset: full dataset
-        config (Optional[Dict[str, Any]], optional): hyper parameter trial config
-        indices (Optional[List[int]]): dataset index. Defaults to None.
+        fullset: Original dataset.
+        num_samples (int): Number of images to pretend to have. It should be positive.
     """

-    def __init__(self, fullset, num_sampels: Optional[int] = None):
+    def __init__(self, fullset, num_sampels: int):
+        if num_sampels <= 0:
+            raise ValueError(f"num_sampels should be positive. But, current value is {num_sampels}.")
+
         self.fullset = fullset
         self.num_sampels = num_sampels
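The reworked `SubDataset` docstring and constructor describe a wrapper that merely pretends the dataset has a given number of images, so the single probe iteration stays cheap, and it now rejects non-positive sizes up front. A simplified, hypothetical stand-in is sketched below; the class name `TinyView` and its methods are illustrative only, while the real OTX class wraps an mm-style dataset and forwards more of its attributes.

# Hypothetical stand-in for the SubDataset idea, not the OTX class itself.
class TinyView:
    def __init__(self, fullset, num_samples: int):
        if num_samples <= 0:
            raise ValueError(f"num_samples should be positive. But, current value is {num_samples}.")
        self.fullset = fullset
        self.num_samples = num_samples

    def __len__(self) -> int:
        # Pretend to be only num_samples long so one "epoch" is a single short pass.
        return min(self.num_samples, len(self.fullset))

    def __getitem__(self, index):
        return self.fullset[index]


probe_set = TinyView(list(range(1000)), num_samples=4)
assert len(probe_set) == 4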
otx/algorithms/detection/adapters/mmdet/task.py (2 additions, 2 deletions)

@@ -285,8 +285,8 @@ def _train_model(
         validate = bool(cfg.data.get("val", None))

         if auto_adapt_bs:
-            train_func = partial(train_detector, model=deepcopy(model), distributed=False)
-            adapt_batch_size(train_func, cfg, meta, datasets, False)
+            train_func = partial(train_detector, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, False)

         train_detector(
             model,
otx/algorithms/segmentation/adapters/mmseg/task.py (2 additions, 2 deletions)

@@ -372,8 +372,8 @@ def _train_model(
         validate = bool(cfg.data.get("val", None))

         if auto_adapt_bs:
-            train_func = partial(train_segmentor, model=deepcopy(model), distributed=False)
-            adapt_batch_size(train_func, cfg, meta, datasets, False)
+            train_func = partial(train_segmentor, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, False)

         train_segmentor(
             model,
