Skip to content

Commit

Permalink
Change IR scale factor to 1.0 if tile size is too big (#2337)
Browse files Browse the repository at this point in the history
* adaptive pooling 6 -> 1

* set IR scale factor to 1.0

* replace adaptive avg pool with gap
  • Loading branch information
eugene123tw authored Jul 17, 2023
1 parent ddf5b41 commit 2d8fd86
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/otx/algorithms/common/configs/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ class BaseTilingParameters(ParameterGroup):
"to crash or result in out-of-memory errors. It is recommended to "
"adjust the scale factor value carefully based on the available "
"hardware resources and the needs of the application.",
default_value=2.0,
default_value=1.0,
min_value=1.0,
max_value=4.0,
affects_outcome_of=ModelLifecycle.NONE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import torch
from mmcls.models.necks.gap import GlobalAveragePooling
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from mmdet.models.builder import DETECTORS
Expand All @@ -30,10 +31,15 @@ def __init__(self):
nn.MaxPool2d(kernel_size=3, stride=2),
ConvModule(192, 256, 3, padding=1, act_cfg=dict(type="ReLU")),
nn.MaxPool2d(kernel_size=3, stride=2),
ConvModule(256, 256, 3, padding=1, act_cfg=dict(type="ReLU")),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = torch.nn.AdaptiveAvgPool2d((6, 6))
# NOTE: Original Adaptive Avg Pooling is replaced with Global Avg Pooling
# due to ONNX tracing issues: https://github.com/openvinotoolkit/training_extensions/pull/2337

self.gap = GlobalAveragePooling()
self.classifier = torch.nn.Sequential(
torch.nn.Linear(256 * 6 * 6, 256),
torch.nn.Linear(256, 256),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(256, 256),
torch.nn.ReLU(inplace=True),
Expand All @@ -54,8 +60,7 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:
torch.Tensor: logits
"""
x = self.features(img)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.gap(x)
y = self.classifier(x)
return y

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ tiling_parameters:
header: OpenVINO IR Scale Factor
description: The purpose of the scale parameter is to optimize the performance and efficiency of tiling in OpenVINO IR during inference. By controlling the increase in tile size and input size, the scale parameter allows for more efficient parallelization of the workload and improve the overall performance and efficiency of the inference process on OpenVINO.
affects_outcome_of: TRAINING
default_value: 2.0
default_value: 1.0
min_value: 1.0
max_value: 4.0
type: FLOAT
Expand All @@ -602,7 +602,7 @@ tiling_parameters:
operator: AND
rules: []
type: UI_RULES
value: 2.0
value: 1.0
visible_in_ui: true
warning: null

Expand Down
1 change: 0 additions & 1 deletion src/otx/algorithms/detection/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ def format_list_to_str(value_lists: list):
return f"[{str_value[:-2]}]"


# TODO [Eugene] please add unit test for this function
def adaptive_tile_params(
tiling_parameters: DetectionConfig.BaseTilingParameters, dataset: DatasetEntity, object_tile_ratio=0.03, rule="avg"
):
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/algorithms/detection/tiling/test_tiling_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DEFAULT_ISEG_TEMPLATE_DIR,
init_environment,
)
from otx.algorithms.detection.utils.data import adaptive_tile_params


@DETECTORS.register_module(force=True)
Expand Down Expand Up @@ -428,3 +429,23 @@ def test_max_annotation(self, max_annotation=200):
assert len(data["gt_bboxes"].data[0][0]) <= max_annotation
assert len(data["gt_labels"].data[0][0]) <= max_annotation
assert len(data["gt_masks"].data[0][0]) <= max_annotation

@e2e_pytest_unit
def test_adaptive_tile_parameters(self):
model_template = parse_model_template(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "template.yaml"))
hp = create(model_template.hyper_parameters.data)

default_tile_size = hp.tiling_parameters.tile_size
default_tile_overlap = hp.tiling_parameters.tile_overlap
default_tile_max_number = hp.tiling_parameters.tile_max_number

adaptive_tile_params(hp.tiling_parameters, self.otx_dataset)

# check tile size is changed
assert hp.tiling_parameters.tile_size != default_tile_size

# check tile overlap is changed
assert hp.tiling_parameters.tile_overlap != default_tile_overlap

# check max output prediction size is changed
assert hp.tiling_parameters.tile_max_number != default_tile_max_number

0 comments on commit 2d8fd86

Please sign in to comment.