Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregator for LiteHRNet #2876

Merged
merged 17 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/otx/algo/segmentation/heads/custom_fcn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mmseg.registry import MODELS
from torch import Tensor, nn

from otx.algo.utils import IterativeAggregator

if TYPE_CHECKING:
from mmseg.utils import SampleList

Expand Down Expand Up @@ -96,3 +98,70 @@ class CustomFCNHead(ClassIncrementalMixin, FCNHead):

Custom FCNHead supports ignored label for class incremental learning cases.
"""

def __init__(
self,
enable_aggregator: bool = False,
aggregator_min_channels: int = 0,
aggregator_merge_norm: str | None = None,
aggregator_use_concat: bool = False,
in_channels: list[int] | int | None = None,
in_index: list[int] | int | None = None,
norm_cfg: dict | None = None,
conv_cfg: dict | None = None,
input_transform: list | None = None,
*args,
**kwargs,
):
if enable_aggregator: # Lite-HRNet aggregator
if in_channels is None or isinstance(in_channels, int):
msg = "'in_channels' should be List[int]."
raise ValueError(msg)
aggregator = IterativeAggregator(
in_channels=in_channels,
min_channels=aggregator_min_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
merge_norm=aggregator_merge_norm,
use_concat=aggregator_use_concat,
)

aggregator_min_channels = aggregator_min_channels if aggregator_min_channels is not None else 0
# change arguments temporarily
in_channels = max(in_channels[0], aggregator_min_channels)
input_transform = None
if isinstance(in_index, list):
in_index = in_index[0]
else:
aggregator = None

super().__init__(
*args,
in_index=in_index,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg,
input_transform=input_transform,
in_channels=in_channels,
**kwargs,
)

self.aggregator = aggregator
# re-define variables
self.in_channels = in_channels
self.input_transform = input_transform
self.in_index = in_index

if self.act_cfg:
self.convs[-1].with_activation = False
delattr(self.convs[-1], "activate")

def _transform_inputs(self, inputs: list[Tensor]) -> Tensor | list:
"""Transform inputs for decoder.

Args:
inputs (list[Tensor]): List of multi-level img features.

Returns:
Tensor: The transformed inputs
"""
return self.aggregator(inputs)[0] if self.aggregator is not None else super()._transform_inputs(inputs)
6 changes: 6 additions & 0 deletions src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def _obtain_ignored_scope(self) -> dict[str, Any]:
"/backbone/stage2/stage2.1/Add_6",
"/backbone/stage2/stage2.1/Add_7",
"/backbone/stage2/stage2.1/Add_11",
"/aggregator/Add",
"/aggregator/Add_1",
"/aggregator/Add_2",
"/backbone/stage2/stage2.1/Add",
]

Expand Down Expand Up @@ -202,6 +205,8 @@ def _obtain_ignored_scope(self) -> dict[str, Any]:
"/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2",
"/backbone/stage1/stage1.3/Add_2",
"/backbone/stage1/stage1.3/Add_5",
"/aggregator/Add",
"/aggregator/Add_1",
]

return {
Expand Down Expand Up @@ -371,6 +376,7 @@ def _obtain_ignored_scope(self) -> dict[str, Any]:

return {
"ignored_scope": {
"patterns": ["/aggregator/*"],
"names": ignored_scope_names,
},
"preset": "performance",
Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/segmentation/mmconfigs/litehrnet_18.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ decode_head:
- 1
- 2
- 3
input_transform: resize_concat
channels: 600
input_transform: "multiple_select"
channels: 40
enable_aggregator: True
kernel_size: 1
num_convs: 1
concat_input: false
Expand Down
7 changes: 5 additions & 2 deletions src/otx/algo/segmentation/mmconfigs/litehrnet_s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,14 @@ decode_head:
- 0
- 1
- 2
input_transform: resize_concat
channels: 420
input_transform: "multiple_select"
channels: 60
kernel_size: 1
num_convs: 1
concat_input: false
enable_aggregator: True
aggregator_merge_norm: None
aggregator_use_concat: False
dropout_ratio: -1
num_classes: 2
norm_cfg:
Expand Down
8 changes: 6 additions & 2 deletions src/otx/algo/segmentation/mmconfigs/litehrnet_x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ decode_head:
- 2
- 3
- 4
input_transform: resize_concat
channels: 638
input_transform: "multiple_select"
channels: 60
kernel_size: 1
num_convs: 1
concat_input: false
dropout_ratio: -1
num_classes: 2
enable_aggregator: True
aggregator_min_channels: 60
aggregator_merge_norm: None
aggregator_use_concat: False
norm_cfg:
type: BN
requires_grad: true
Expand Down
Loading
Loading