Refactoring ConvModule by removing norm_cfg #3816

Merged: 27 commits merged into develop from remove-norm_cfg on Aug 19, 2024. Changes shown from 26 commits.
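In short, this PR replaces the `norm_cfg` config dict at `ConvModule`/`Conv3dModule` call sites with a `normalization` callable (and renames `activation_callable` to `activation`), completing the earlier removal of `conv_cfg` and `act_cfg`. A rough before/after sketch based on the hunks below; the channel counts are illustrative, and the keyword names are taken from the diff:

```python
from torch import nn

from otx.algo.modules.activation import build_activation_layer
from otx.algo.modules.conv_module import Conv3dModule
from otx.algo.modules.norm import build_norm_layer

# Before this PR: normalization was configured via a dict.
conv_old = Conv3dModule(
    3, 24, kernel_size=3,
    norm_cfg={"type": "BN3d", "requires_grad": True},
    activation_callable=nn.ReLU,
)

# After this PR: a pre-built (or resolvable) normalization module is passed in.
conv_new = Conv3dModule(
    3, 24, kernel_size=3,
    normalization=build_norm_layer(nn.BatchNorm3d, num_features=24),
    activation=build_activation_layer(nn.ReLU),
)
```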
Commits (27):

d193ffc  Update `build_norm_layer` to use `norm_callable` (sungchul2, Aug 8, 2024)
92c2cf6  WIP (sungchul2, Aug 8, 2024)
3cf69d9  Replace `norm_cfg` with `norm_callable` (sungchul2, Aug 8, 2024)
d51c1df  Update `activation_callable` docstring (sungchul2, Aug 8, 2024)
d26dc3b  Update `CHANGELOG` (sungchul2, Aug 8, 2024)
e53ac57  Enable using pre-assigned nn.Module (sungchul2, Aug 13, 2024)
49da71e  Update to use pre-assigned norm layer in `ConvModule` (sungchul2, Aug 13, 2024)
8e267dd  Fix (sungchul2, Aug 13, 2024)
f22b380  Enable `partial(build_norm_layer, ...)` (sungchul2, Aug 13, 2024)
90b6ff5  Fix unit test (sungchul2, Aug 13, 2024)
1872278  Fix typo (sungchul2, Aug 14, 2024)
78719bc  Merge branch 'develop' into remove-norm_cfg (sungchul2, Aug 14, 2024)
8af87ad  Fix unit test (sungchul2, Aug 14, 2024)
ea2a195  Fix unit test (sungchul2, Aug 14, 2024)
ce26ab7  Restore `build_activation_layer` and update `ConvModule` to use preas… (sungchul2, Aug 14, 2024)
45a4cb1  Update to use `build_activation_layer` (sungchul2, Aug 14, 2024)
4dea934  Fix (sungchul2, Aug 14, 2024)
bc92646  Enable to get nn.Module (sungchul2, Aug 14, 2024)
2958de4  Remove `callable` in arg name (sungchul2, Aug 14, 2024)
4d95822  Fix unit test (sungchul2, Aug 14, 2024)
5605a90  Merge branch 'develop' into remove-norm_cfg (sungchul2, Aug 14, 2024)
c937cba  Fix rtdetr18 (sungchul2, Aug 14, 2024)
fb3a39f  Fix (sungchul2, Aug 14, 2024)
f9b9b95  precommit (sungchul2, Aug 14, 2024)
66fea80  Merge branch 'develop' into remove-norm_cfg (sungchul2, Aug 16, 2024)
42bfd9d  Fix (sungchul2, Aug 16, 2024)
20f8979  Merge branch 'develop' into remove-norm_cfg (sungchul2, Aug 17, 2024)
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -41,8 +41,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/3759>)
- Enable to use polygon and bitmap mask as prompt inputs for zero-shot learning
(<https://github.com/openvinotoolkit/training_extensions/pull/3769>)
-- Refactoring `ConvModule` by removing `conv_cfg` and `act_cfg`
-  (<https://github.com/openvinotoolkit/training_extensions/pull/3783>, <https://github.com/openvinotoolkit/training_extensions/pull/3809>)
+- Refactoring `ConvModule` by removing `conv_cfg`, `norm_cfg`, and `act_cfg`
+  (<https://github.com/openvinotoolkit/training_extensions/pull/3783>, <https://github.com/openvinotoolkit/training_extensions/pull/3816>, <https://github.com/openvinotoolkit/training_extensions/pull/3809>)

### Bug fixes

89 changes: 45 additions & 44 deletions src/otx/algo/action_classification/backbones/x3d.py
@@ -13,8 +13,9 @@
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

-from otx.algo.modules.activation import Swish
+from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv3dModule
+from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.weight_init import constant_init, kaiming_init

@@ -72,10 +73,10 @@ class BlockX3D(nn.Module):
unit. If set as None, it means not using SE unit. Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
-norm_cfg (dict): Config for norm layers. required keys are ``type``,
-Default: ``dict(type='BN3d')``.
-activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-Defaults to `nn.ReLU`.
+normalization (Callable[..., nn.Module] | None): Normalization layer module.
+Defaults to None.
+activation (Callable[..., nn.Module] | None): Activation layer module.
+Defaults to ``nn.ReLU``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
"""
@@ -89,8 +90,8 @@ def __init__(
downsample: nn.Module | None = None,
se_ratio: float | None = None,
use_swish: bool = True,
-norm_cfg: dict | None = None,
-activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+normalization: Callable[..., nn.Module] | None = None,
+activation: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
):
super().__init__()
@@ -102,8 +103,8 @@ def __init__(
self.downsample = downsample
self.se_ratio = se_ratio
self.use_swish = use_swish
-self.norm_cfg = norm_cfg
-self.activation_callable = activation_callable
+self.normalization = normalization
+self.activation = activation
self.with_cp = with_cp

self.conv1 = Conv3dModule(
@@ -113,8 +114,8 @@ def __init__(
stride=1,
padding=0,
bias=False,
-norm_cfg=self.norm_cfg,
-activation_callable=self.activation_callable,
+normalization=build_norm_layer(normalization, num_features=planes),
+activation=build_activation_layer(activation),
)
# Here we use the channel-wise conv
self.conv2 = Conv3dModule(
@@ -125,8 +126,8 @@ def __init__(
padding=1,
groups=planes,
bias=False,
-norm_cfg=self.norm_cfg,
-activation_callable=None,
+normalization=build_norm_layer(normalization, num_features=planes),
+activation=None,
)

self.swish = Swish()
@@ -138,14 +139,14 @@ def __init__(
stride=1,
padding=0,
bias=False,
-norm_cfg=self.norm_cfg,
-activation_callable=None,
+normalization=build_norm_layer(normalization, num_features=outplanes),
+activation=None,
)

if self.se_ratio is not None:
self.se_module = SEModule(planes, self.se_ratio)

-self.relu = self.activation_callable() if self.activation_callable else nn.ReLU(inplace=True)
+self.relu = self.activation() if self.activation else nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
"""Defines the computation performed at every call."""
@@ -195,11 +196,10 @@ class X3DBackbone(nn.Module):
unit. If set as None, it means not using SE unit. Default: 1 / 16.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
-norm_cfg (dict): Config for norm layers. required keys are ``type`` and
-``requires_grad``.
-Default: ``dict(type='BN3d', requires_grad=True)``.
-activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-Defaults to `nn.ReLU`.
+normalization (Callable[..., nn.Module] | None): Normalization layer module.
+Defaults to None.
+activation (Callable[..., nn.Module] | None): Activation layer module.
+Defaults to ``nn.ReLU``.
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var). Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
@@ -223,8 +223,8 @@ def __init__(
se_style: str = "half",
se_ratio: float = 1 / 16,
use_swish: bool = True,
-norm_cfg: dict | None = None,
-activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+normalization: Callable[..., nn.Module] | None = None,
+activation: Callable[..., nn.Module] | None = nn.ReLU,
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = True,
@@ -266,8 +266,8 @@ def __init__(
raise ValueError(msg)
self.use_swish = use_swish

-self.norm_cfg = norm_cfg
-self.activation_callable = activation_callable
+self.normalization = normalization
+self.activation = activation
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
@@ -293,8 +293,8 @@ def __init__(
se_style=self.se_style,
se_ratio=self.se_ratio,
use_swish=self.use_swish,
-norm_cfg=self.norm_cfg,
-activation_callable=self.activation_callable,
+normalization=self.normalization,
+activation=self.activation,
with_cp=with_cp,
**kwargs,
)
@@ -311,8 +311,8 @@ def __init__(
stride=1,
padding=0,
bias=False,
-norm_cfg=self.norm_cfg,
-activation_callable=self.activation_callable,
+normalization=build_norm_layer(self.normalization, num_features=int(self.feat_dim * self.gamma_b)),
+activation=build_activation_layer(self.activation),
)
self.feat_dim = int(self.feat_dim * self.gamma_b)

@@ -349,8 +349,8 @@ def make_res_layer(
se_style: str = "half",
se_ratio: float | None = None,
use_swish: bool = True,
-norm_cfg: dict | None = None,
-activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+normalization: Callable[..., nn.Module] | None = None,
+activation: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
**kwargs,
) -> nn.Module:
@@ -375,9 +375,10 @@ def make_res_layer(
Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
-norm_cfg (dict | None): Config for norm layers. Default: None.
-activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-Defaults to `nn.ReLU`.
+normalization (Callable[..., nn.Module] | None): Normalization layer module.
+Defaults to None.
+activation (Callable[..., nn.Module] | None): Activation layer module.
+Defaults to ``nn.ReLU``.
with_cp (bool | None): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
@@ -394,8 +395,8 @@ def make_res_layer(
stride=(1, spatial_stride, spatial_stride),
padding=0,
bias=False,
-norm_cfg=norm_cfg,
-activation_callable=None,
+normalization=build_norm_layer(normalization, num_features=inplanes),
+activation=None,
)

use_se = [False] * blocks
@@ -416,8 +417,8 @@ def make_res_layer(
downsample=downsample,
se_ratio=se_ratio if use_se[0] else None,
use_swish=use_swish,
-norm_cfg=norm_cfg,
-activation_callable=activation_callable,
+normalization=normalization,
+activation=activation,
with_cp=with_cp,
**kwargs,
),
@@ -432,8 +433,8 @@ def make_res_layer(
spatial_stride=1,
se_ratio=se_ratio if use_se[i] else None,
use_swish=use_swish,
-norm_cfg=norm_cfg,
-activation_callable=activation_callable,
+normalization=normalization,
+activation=activation,
with_cp=with_cp,
**kwargs,
),
@@ -450,8 +451,8 @@ def _make_stem_layer(self) -> None:
stride=(1, 2, 2),
padding=(0, 1, 1),
bias=False,
-norm_cfg=None,
-activation_callable=None,
+normalization=None,
+activation=None,
)
self.conv1_t = Conv3dModule(
self.base_channels,
@@ -461,8 +462,8 @@ def _make_stem_layer(self) -> None:
padding=(2, 0, 0),
groups=self.base_channels,
bias=False,
-norm_cfg=self.norm_cfg,
-activation_callable=self.activation_callable,
+normalization=build_norm_layer(self.normalization, num_features=self.base_channels),
+activation=build_activation_layer(self.activation),
)

def _freeze_stages(self) -> None:
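Throughout these hunks, `normalization` is resolved by `build_norm_layer`, which per the call sites (and commits e53ac57, f22b380, and bc92646) has to accept None, a pre-assigned `nn.Module`, an `nn.Module` class such as `nn.BatchNorm3d`, or a `partial` of itself. Below is a minimal stand-in consistent with that contract; it is a sketch only, and the actual implementation in `otx.algo.modules.norm` may differ:

```python
from functools import partial

from torch import nn


def build_norm_layer(normalization, num_features=None, requires_grad=True, **kwargs):
    """Resolve `normalization` into an nn.Module instance (sketch only)."""
    if normalization is None:
        return None
    if isinstance(normalization, nn.Module):
        # Pre-assigned layer: pass through unchanged (commit e53ac57).
        return normalization
    if isinstance(normalization, partial):
        # partial(build_norm_layer, ...): bind the channel count known at the
        # call site (commit f22b380).
        return normalization(num_features=num_features)
    # An nn.Module class, e.g. nn.BatchNorm3d -> nn.BatchNorm3d(num_features).
    layer = normalization(num_features, **kwargs)
    for param in layer.parameters():
        param.requires_grad = requires_grad
    return layer


def build_activation_layer(activation):
    """Instantiate the activation callable, if one is given (sketch only)."""
    return activation() if activation is not None else None
```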
5 changes: 3 additions & 2 deletions src/otx/algo/action_classification/x3d.py
@@ -13,6 +13,7 @@
from otx.algo.action_classification.backbones.x3d import X3DBackbone
from otx.algo.action_classification.heads.x3d_head import X3DHead
from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer
+from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
@@ -67,8 +68,8 @@ def _build_model(self, num_classes: int) -> nn.Module:
gamma_b=2.25,
gamma_d=2.2,
gamma_w=1,
-norm_cfg={"type": "BN3d", "requires_grad": True},
-activation_callable=partial(nn.ReLU, inplace=True),
+normalization=partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True),
+activation=partial(nn.ReLU, inplace=True),
),
cls_head=X3DHead(
num_classes=num_classes,
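The `normalization=partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)` pattern in the last hunk defers the channel count: the backbone completes the partial wherever the feature width is known. Roughly:

```python
from functools import partial

from torch import nn

from otx.algo.modules.norm import build_norm_layer

normalization = partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)
# Each Conv3dModule later binds the actual width, e.g.:
norm = normalization(num_features=24)  # ~ nn.BatchNorm3d(24) with grads enabled
```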