Skip to content

Commit

Permalink
Optimize data preprocessing time and enhance overall performance in s…
Browse files Browse the repository at this point in the history
…emantic segmentation (#2020)
  • Loading branch information
supersoob authored Apr 19, 2023
1 parent d2c1dfe commit d2acd51
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ All notable changes to this project will be documented in this file.
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)
- Segmentation task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1977>)
- Action task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1993>)
- Optimize data preprocessing time and enhance overall performance in semantic segmentation (<https://github.com/openvinotoolkit/training_extensions/pull/2020>)

### Bug fixes

Expand Down
16 changes: 10 additions & 6 deletions otx/algorithms/segmentation/adapters/mmseg/configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def configure_task(
"""Patch config to support training algorithm."""
if "task_adapt" in cfg:
logger.info(f"task config!!!!: training={training}")
cfg["task_adapt"].get("op", "REPLACE")

# Task classes
self.configure_classes(cfg)
Expand Down Expand Up @@ -421,6 +420,8 @@ def configure_samples_per_gpu(self, cfg: Config, subset: str) -> None:
# batch size of 1 is a runtime error for training batch normalization layer
if subset in ("train", "unlabeled") and dataset_len % samples_per_gpu == 1:
dataloader_cfg["drop_last"] = True
else:
dataloader_cfg["drop_last"] = False

cfg.data[f"{subset}_dataloader"] = dataloader_cfg

Expand Down Expand Up @@ -501,18 +502,21 @@ def configure_task(self, cfg: ConfigDict, training: bool) -> None:
"""Patch config to support incremental learning."""
super().configure_task(cfg, training)

new_classes: List[str] = np.setdiff1d(self.model_classes, self.org_model_classes).tolist()

# Check if new classes are added
has_new_class: bool = len(new_classes) > 0
# TODO: Revisit this part when removing bg label -> it should be 1 because of 'background' label
if len(set(self.org_model_classes) & set(self.model_classes)) == 1 or set(self.org_model_classes) == set(
self.model_classes
):
is_cls_incr = False
else:
is_cls_incr = True

# Update TaskAdaptHook (use incremental sampler)
task_adapt_hook = ConfigDict(
type="TaskAdaptHook",
src_classes=self.org_model_classes,
dst_classes=self.model_classes,
model_type=cfg.model.type,
sampler_flag=has_new_class,
sampler_flag=is_cls_incr,
efficient_mode=cfg["task_adapt"].get("efficient_mode", False),
)
update_or_add_custom_hook(cfg, task_adapt_hook)
Expand Down
1 change: 1 addition & 0 deletions otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def build_nncf_segmentor( # noqa: C901 # pylint: disable=too-many-locals,too-m
device: Union[str, torch.device] = "cpu",
cfg_options: Optional[Union[Config, ConfigDict]] = None,
distributed=False,
**kwargs
):
"""A function to build NNCF wrapped mmcls model."""

Expand Down
2 changes: 1 addition & 1 deletion otx/algorithms/segmentation/adapters/mmseg/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _train_model(
)

# Model
model = self.build_model(cfg, fp16=cfg.get("fp16", False))
model = self.build_model(cfg, fp16=cfg.get("fp16", False), is_training=self._is_training)
model.train()
model.CLASSES = target_classes

Expand Down
8 changes: 6 additions & 2 deletions otx/algorithms/segmentation/adapters/mmseg/utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def build_segmentor(
device: Union[str, torch.device] = "cpu",
cfg_options: Optional[Union[Config, ConfigDict]] = None,
from_scratch: bool = False,
is_training: bool = False,
) -> torch.nn.Module:
"""A builder function for mmseg model.
Expand All @@ -58,9 +59,12 @@ def build_segmentor(
model = origin_build_segmentor(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
model = model.to(device)

checkpoint = checkpoint if checkpoint else config.pop("load_from", None)
checkpoint = checkpoint if checkpoint else config.get("load_from", None)
config.load_from = checkpoint

if checkpoint is not None and not from_scratch:
load_checkpoint(model, checkpoint, map_location=device)
config.load_from = checkpoint
if is_training is True:
config.load_from = None # To prevent the repeated ckpt loading in mmseg.apis.train_segmentor

return model
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@
dict(type="Resize", img_scale=__img_scale, ratio_range=(0.5, 2.0), keep_ratio=False),
dict(type="RandomCrop", crop_size=__crop_size, cat_max_ratio=0.75),
dict(type="RandomFlip", prob=0.5, direction="horizontal"),
dict(
type="MaskCompose",
prob=0.5,
lambda_limits=(4, 16),
keep_original=False,
transforms=[
dict(type="PhotoMetricDistortion"),
],
),
dict(type="Normalize", **__img_norm_cfg),
dict(type="Pad", size=__crop_size, pad_val=0, seg_pad_val=255),
dict(type="RandomRotate", prob=0.5, degree=30, pad_val=0, seg_pad_val=255),
Expand Down
9 changes: 0 additions & 9 deletions otx/recipes/stages/_base_/data/pipelines/incr_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@
dict(type="Resize", img_scale=__img_scale, ratio_range=(0.5, 2.0), keep_ratio=False),
dict(type="RandomCrop", crop_size=__crop_size, cat_max_ratio=0.75),
dict(type="RandomFlip", prob=0.5, direction="horizontal"),
dict(
type="MaskCompose",
prob=0.5,
lambda_limits=(4, 16),
keep_original=False,
transforms=[
dict(type="PhotoMetricDistortion"),
],
),
dict(type="Normalize", **__img_norm_cfg),
dict(type="Pad", size=__crop_size, pad_val=0, seg_pad_val=255),
dict(type="RandomRotate", prob=0.5, degree=30, pad_val=0, seg_pad_val=255),
Expand Down

0 comments on commit d2acd51

Please sign in to comment.