From 8b044b7ac60b2b92d4cece8926d92c4aded2569c Mon Sep 17 00:00:00 2001 From: Jihwan Eom Date: Tue, 14 Mar 2023 10:27:08 +0900 Subject: [PATCH] [FEATURE] Add MoViNet template for action classification (#1742) * Add MoViNet template for action classification Co-authored-by: Nikita Savelyev --- .../adapters/mmaction/models/__init__.py | 7 +- .../mmaction/models/backbones/__init__.py | 5 + .../mmaction/models/backbones/movinet.py | 771 ++++++++++++++++++ .../mmaction/models/heads/__init__.py | 4 +- .../mmaction/models/heads/movinet_head.py | 79 ++ .../mmaction/models/recognizers/__init__.py | 8 + .../models/recognizers/movinet_recognizer.py | 43 + .../classification/movinet/__init__.py | 15 + .../classification/movinet/data_pipeline.py | 78 ++ .../classification/movinet/deployment.py | 3 + .../configs/classification/movinet/model.py | 65 ++ .../classification/movinet/template.yaml | 63 ++ pyproject.toml | 5 + .../models/backbones/test_action_movinet.py | 203 +++++ .../test_action_register_backbone.py | 4 +- .../models/heads/test_action_movinet_head.py | 32 + .../mmaction/models/recognizers/__init__.py | 3 + .../test_action_movinet_recognizer.py | 67 ++ 18 files changed, 1449 insertions(+), 6 deletions(-) create mode 100644 otx/algorithms/action/adapters/mmaction/models/backbones/movinet.py create mode 100644 otx/algorithms/action/adapters/mmaction/models/heads/movinet_head.py create mode 100644 otx/algorithms/action/adapters/mmaction/models/recognizers/__init__.py create mode 100644 otx/algorithms/action/adapters/mmaction/models/recognizers/movinet_recognizer.py create mode 100644 otx/algorithms/action/configs/classification/movinet/__init__.py create mode 100644 otx/algorithms/action/configs/classification/movinet/data_pipeline.py create mode 100644 otx/algorithms/action/configs/classification/movinet/deployment.py create mode 100644 otx/algorithms/action/configs/classification/movinet/model.py create mode 100644 otx/algorithms/action/configs/classification/movinet/template.yaml create mode 100644 tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_movinet.py create mode 100644 tests/unit/algorithms/action/adapters/mmaction/models/heads/test_action_movinet_head.py create mode 100644 tests/unit/algorithms/action/adapters/mmaction/models/recognizers/__init__.py create mode 100644 tests/unit/algorithms/action/adapters/mmaction/models/recognizers/test_action_movinet_recognizer.py diff --git a/otx/algorithms/action/adapters/mmaction/models/__init__.py b/otx/algorithms/action/adapters/mmaction/models/__init__.py index 919d80ace56..62a536df1e5 100644 --- a/otx/algorithms/action/adapters/mmaction/models/__init__.py +++ b/otx/algorithms/action/adapters/mmaction/models/__init__.py @@ -3,8 +3,9 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .backbones import register_action_backbones +from .backbones import OTXMoViNet, register_action_backbones from .detectors import AVAFastRCNN -from .heads import AVARoIHead +from .heads import AVARoIHead, MoViNetHead +from .recognizers import MoViNetRecognizer -__all__ = ["register_action_backbones", "AVAFastRCNN", "AVARoIHead"] +__all__ = ["register_action_backbones", "AVAFastRCNN", "OTXMoViNet", "MoViNetHead", "MoViNetRecognizer", "AVARoIHead"] diff --git a/otx/algorithms/action/adapters/mmaction/models/backbones/__init__.py b/otx/algorithms/action/adapters/mmaction/models/backbones/__init__.py index 787a75ea119..43ecc488014 100644 --- 
a/otx/algorithms/action/adapters/mmaction/models/backbones/__init__.py
+++ b/otx/algorithms/action/adapters/mmaction/models/backbones/__init__.py
@@ -6,7 +6,12 @@
 from mmaction.models.backbones.x3d import X3D
 from mmdet.models import BACKBONES as MMDET_BACKBONES
 
+from .movinet import OTXMoViNet
+
 
 def register_action_backbones():
     """Register action backbone to mmdetection backbones."""
     MMDET_BACKBONES.register_module()(X3D)
+
+
+__all__ = ["OTXMoViNet"]
diff --git a/otx/algorithms/action/adapters/mmaction/models/backbones/movinet.py b/otx/algorithms/action/adapters/mmaction/models/backbones/movinet.py
new file mode 100644
index 00000000000..2803d02cf56
--- /dev/null
+++ b/otx/algorithms/action/adapters/mmaction/models/backbones/movinet.py
@@ -0,0 +1,771 @@
+"""Code modified from: https://github.com/Atze00/MoViNet-pytorch/blob/main/movinets/models.py."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+from collections import OrderedDict
+from typing import Any, Callable, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from mmaction.models.builder import BACKBONES
+from mmcv.utils import Config
+from torch import Tensor, nn
+from torch.nn.modules.utils import _pair, _triple
+
+
+class Conv2dBNActivation(nn.Sequential):
+    """A base module that applies a 2D Conv-BN-Activation.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel.
+        padding (Union[int, Tuple[int, int]]): Size of the padding applied to the input.
+        stride (Union[int, Tuple[int, int]], optional): Stride of the convolution. Default: 1.
+        groups (int, optional): Number of groups in the convolution. Default: 1.
+        norm_layer (Optional[Callable[..., nn.Module]], optional): Normalization layer to use.
+            If None, identity is used. Default: None.
+        activation_layer (Optional[Callable[..., nn.Module]], optional): Activation layer to use.
+            If None, identity is used. Default: None.
+        **kwargs (Any): Additional keyword arguments passed to nn.Conv2d.
+
+    Attributes:
+        kernel_size (Tuple[int, int]): Size of the convolution kernel.
+        stride (Tuple[int, int]): Stride of the convolution.
+        out_channels (int): Number of output channels.
+
+    """
+
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        *,
+        kernel_size: Union[int, Tuple[int, int]],
+        padding: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None:
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        if norm_layer is None:
+            norm_layer = nn.Identity
+        if activation_layer is None:
+            activation_layer = nn.Identity
+        self.kernel_size = kernel_size
+        self.stride = stride
+        dict_layers = OrderedDict(
+            {
+                "conv2d": nn.Conv2d(
+                    in_planes,
+                    out_planes,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=groups,
+                    **kwargs,
+                ),
+                "norm": norm_layer(out_planes, eps=0.001),
+                "act": activation_layer(),
+            }
+        )
+
+        self.out_channels = out_planes
+        super().__init__(dict_layers)
+
+
+class Conv3DBNActivation(nn.Sequential):
+    """A base module that applies a 3D Conv-BN-Activation.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+ kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolution kernel. + padding (Union[int, Tuple[int, int, int]]): Size of the padding applied to the input. + stride (Union[int, Tuple[int, int, int]], optional): Stride of the convolution. Default: 1. + groups (int, optional): Number of groups in the convolution. Default: 1. + norm_layer (Optional[Callable[..., nn.Module]], optional): Normalization layer to use. + If None, identity is used. Default: None. + activation_layer (Optional[Callable[..., nn.Module]], optional): Activation layer to use. + If None, identity is used. Default: None. + **kwargs (Any): Additional keyword arguments passed to nn.Conv3d. + + Attributes: + kernel_size (Tuple[int, int, int]): Size of the convolution kernel. + stride (Tuple[int, int, int]): Stride of the convolution. + out_channels (int): Number of output channels. + + """ + + def __init__( + self, + in_planes: int, + out_planes: int, + *, + kernel_size: Union[int, Tuple[int, int, int]], + padding: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + if norm_layer is None: + norm_layer = nn.Identity + if activation_layer is None: + activation_layer = nn.Identity + self.kernel_size = kernel_size + self.stride = stride + + dict_layers = OrderedDict( + { + "conv3d": nn.Conv3d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + **kwargs, + ), + "norm": norm_layer(out_planes, eps=0.001), + "act": activation_layer(), + } + ) + + self.out_channels = out_planes + super().__init__(dict_layers) + + +class ConvBlock3D(nn.Module): + """A module that applies a 2+1D or 3D Conv-BN-activation sequential. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + kernel_size (Tuple[int, int, int]): Size of the convolution kernel. + tf_like (bool): Whether to use TensorFlow-like padding and convolution. + conv_type (str): Type of 3D convolution to use. Must be "2plus1d" or "3d". + padding (Tuple[int, int, int], optional): Size of the padding applied to the input. + Default: (0, 0, 0). + stride (Tuple[int, int, int], optional): Stride of the convolution. Default: (1, 1, 1). + norm_layer (Optional[Callable[..., nn.Module]], optional): Normalization layer to use. + If None, identity is used. Default: None. + activation_layer (Optional[Callable[..., nn.Module]], optional): Activation layer to use. + If None, identity is used. Default: None. + bias (bool, optional): Whether to use bias in the convolution. Default: False. + **kwargs (Any): Additional keyword arguments passed to nn.Conv2d or nn.Conv3d. + + Attributes: + conv_1 (Union[Conv2dBNActivation, Conv3DBNActivation]): Convolutional layer. + conv_2 (Optional[Conv2dBNActivation]): Convolutional layer for 2+1D convolution. + padding (Tuple[int, int, int]): Size of the padding applied to the input. + kernel_size (Tuple[int, int, int]): Size of the convolution kernel. + dim_pad (int): Padding along the temporal dimension. + stride (Tuple[int, int, int]): Stride of the convolution. + conv_type (str): Type of 3D convolution used. + tf_like (bool): Whether to use TensorFlow-like padding and convolution. 
+ + """ + + # pylint: disable=too-many-instance-attributes, too-many-arguments + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: Tuple[int, int, int], + tf_like: bool, + conv_type: str, + padding: Tuple[int, int, int] = (0, 0, 0), + stride: Tuple[int, int, int] = (1, 1, 1), + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, + bias: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + self.conv_2 = None + if tf_like: + # We need odd kernel to have even padding + # and stride == 1 to precompute padding, + if kernel_size[0] % 2 == 0: + raise ValueError("tf_like supports only odd" + " kernels for temporal dimension") + padding = ((kernel_size[0] - 1) // 2, 0, 0) + if stride[0] != 1: + raise ValueError("illegal stride value, tf like supports" + " only stride == 1 for temporal dimension") + if stride[1] > kernel_size[1] or stride[2] > kernel_size[2]: + # these values are not tested so should be avoided + raise ValueError("tf_like supports only" + " stride <= of the kernel size") + + if conv_type not in ["2plus1d", "3d"]: + raise ValueError("only 2plus2d or 3d are " + "allowed as 3d convolutions") + + if conv_type == "2plus1d": + self.conv_1 = Conv2dBNActivation( + in_planes, + out_planes, + kernel_size=(kernel_size[1], kernel_size[2]), + padding=(padding[1], padding[2]), + stride=(stride[1], stride[2]), + activation_layer=activation_layer, + norm_layer=norm_layer, + bias=bias, + **kwargs, + ) + if kernel_size[0] > 1: + self.conv_2 = Conv2dBNActivation( + in_planes, + out_planes, + kernel_size=(kernel_size[0], 1), + padding=(padding[0], 0), + stride=(stride[0], 1), + activation_layer=activation_layer, + norm_layer=norm_layer, + bias=bias, + **kwargs, + ) + elif conv_type == "3d": + self.conv_1 = Conv3DBNActivation( + in_planes, + out_planes, + kernel_size=kernel_size, + padding=padding, + activation_layer=activation_layer, + norm_layer=norm_layer, + stride=stride, + bias=bias, + **kwargs, + ) + self.padding = padding + self.kernel_size = kernel_size + self.dim_pad = self.kernel_size[0] - 1 + self.stride = stride + self.conv_type = conv_type + self.tf_like = tf_like + + def _forward(self, x: Tensor) -> Tensor: + shape_with_buffer = x.shape + if self.conv_type == "2plus1d": + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv_1(x) + if self.conv_type == "2plus1d": + x = rearrange(x, "(b t) c h w -> b c t h w", t=shape_with_buffer[2]) + if self.conv_2 is not None: + w = x.shape[-1] + x = rearrange(x, "b c t h w -> b c t (h w)") + x = self.conv_2(x) + x = rearrange(x, "b c t (h w) -> b c t h w", w=w) + return x + + def forward(self, x: Tensor) -> Tensor: + """Forward function of ConvBlock3D.""" + if self.tf_like: + x = same_padding( + x, + x.shape[-2], + x.shape[-1], + self.stride[-2], + self.stride[-1], + self.kernel_size[-2], + self.kernel_size[-1], + ) + x = self._forward(x) + return x + + +class SqueezeExcitation(nn.Module): + """Implements the Squeeze-and-Excitation (SE) block. + + Args: + input_channels (int): Number of input channels. + activation_2 (nn.Module): Activation function applied after the second convolutional block. + activation_1 (nn.Module): Activation function applied after the first convolutional block. + conv_type (str): Convolutional block type ("2plus1d" or "3d"). + squeeze_factor (int, optional): The reduction factor for the number of channels (default: 4). + bias (bool, optional): Whether to add a bias term to the convolutional blocks (default: True). 
+ """ + + def __init__( + self, + input_channels: int, + activation_2: nn.Module, + activation_1: nn.Module, + conv_type: str, + squeeze_factor: int = 4, + bias: bool = True, + ) -> None: + super().__init__() + se_multiplier = 1 + squeeze_channels = _make_divisible(input_channels // squeeze_factor * se_multiplier, 8) + self.fc1 = ConvBlock3D( + input_channels * se_multiplier, + squeeze_channels, + kernel_size=(1, 1, 1), + padding=(0, 0, 0), + tf_like=False, + conv_type=conv_type, + bias=bias, + ) + self.activation_1 = activation_1() + self.activation_2 = activation_2() + self.fc2 = ConvBlock3D( + squeeze_channels, + input_channels, + kernel_size=(1, 1, 1), + padding=(0, 0, 0), + tf_like=False, + conv_type=conv_type, + bias=bias, + ) + + def _scale(self, x: Tensor) -> Tensor: + """Computes the scaling factor for the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, time, height, width). + + Returns: + torch.Tensor: Scaling factor for the input tensor of shape (batch_size, channels, 1, 1, 1). + """ + scale = F.adaptive_avg_pool3d(x, 1) + scale = self.fc1(scale) + scale = self.activation_1(scale) + scale = self.fc2(scale) + return self.activation_2(scale) + + def forward(self, x: Tensor) -> Tensor: + """Forward function of SqueezeExcitation.""" + scale = self._scale(x) + return scale * x + + +def _make_divisible(value: float, divisor: int, min_value: Optional[int] = None) -> int: + if min_value is None: + min_value = divisor + new_v = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * value: + new_v += divisor + return new_v + + +def same_padding( + x: Tensor, in_height: int, in_width: int, stride_h: int, stride_w: int, filter_height: int, filter_width: int +) -> Tensor: + """Applies padding to the input tensor to ensure that the output tensor size is the same as the input tensor size. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, time, height, width). + in_height (int): Height of the input tensor. + in_width (int): Width of the input tensor. + stride_h (int): Stride in the height dimension. + stride_w (int): Stride in the width dimension. + filter_height (int): Height of the filter (kernel). + filter_width (int): Width of the filter (kernel). + + Returns: + torch.Tensor: Padded tensor of shape (batch_size, channels, time, height + pad_h, width + pad_w), where + pad_h and pad_w are the heights and widths of the top, bottom, left, and right padding applied to the tensor. + + """ + if in_height % stride_h == 0: + pad_along_height = max(filter_height - stride_h, 0) + else: + pad_along_height = max(filter_height - (in_height % stride_h), 0) + if in_width % stride_w == 0: + pad_along_width = max(filter_width - stride_w, 0) + else: + pad_along_width = max(filter_width - (in_width % stride_w), 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + padding_pad = (pad_left, pad_right, pad_top, pad_bottom) + return torch.nn.functional.pad(x, padding_pad) + + +class TFAvgPool3D(nn.Module): + """3D average pooling layer with padding.""" + + def __init__(self) -> None: + super().__init__() + self.avgf = nn.AvgPool3d((1, 3, 3), stride=(1, 2, 2)) + + def forward(self, x: Tensor) -> Tensor: + """Applies 3D average pooling with padding to the input tensor. 
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, channels, time, height, width).
+
+        Returns:
+            torch.Tensor: Pooled tensor of shape (batch_size, channels, time, height', width'), where
+            height' and width' are the heights and widths of the pooled tensor after padding is applied.
+
+        """
+        use_padding = x.shape[-1] % 2 != 0
+        if use_padding:
+            padding_pad = (0, 0, 0, 0)
+        else:
+            padding_pad = (0, 1, 0, 1)
+        x = torch.nn.functional.pad(x, padding_pad)
+        if use_padding:
+            x = torch.nn.functional.avg_pool3d(
+                x, (1, 3, 3), stride=(1, 2, 2), count_include_pad=False, padding=(0, 1, 1)
+            )
+        else:
+            x = self.avgf(x)
+            x[..., -1] = x[..., -1] * 9 / 6
+            x[..., -1, :] = x[..., -1, :] * 9 / 6
+        return x
+
+
+class BasicBneck(nn.Module):
+    """Basic bottleneck block of the MoViNet network.
+
+    Args:
+        cfg (Config): Configuration object containing the block's hyperparameters.
+        tf_like (bool): A boolean indicating whether to use TensorFlow-like convolution
+            padding or not.
+        conv_type (str): A string indicating the type of convolutional layer to use.
+            Can be "2plus1d" or "3d".
+        norm_layer (Callable[..., nn.Module], optional): A callable normalization layer
+            to use. Defaults to None.
+        activation_layer (Callable[..., nn.Module], optional): A callable activation
+            layer to use. Defaults to None.
+
+    Attributes:
+        expand (ConvBlock3D, optional): An optional expansion convolutional block.
+        deep (ConvBlock3D): A convolutional block with kernel size, stride, padding,
+            and groups as specified in the configuration object.
+        se (SqueezeExcitation): A squeeze-and-excitation block.
+        project (ConvBlock3D): A projection convolutional block.
+        res (nn.Sequential, optional): An optional residual convolutional block.
+        alpha (nn.Parameter): A learnable parameter used in the ReZero operation.
+
+    Raises:
+        AssertionError: If the stride in configuration is not a tuple.
+
+    """
+
+    def __init__(
+        self,
+        cfg: "Config",
+        tf_like: bool,
+        conv_type: str,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        assert isinstance(cfg.stride, tuple)
+        self.res = None
+
+        layers = []
+        if cfg.expanded_channels != cfg.out_channels:
+            self.expand = ConvBlock3D(
+                in_planes=cfg.input_channels,
+                out_planes=cfg.expanded_channels,
+                kernel_size=(1, 1, 1),
+                padding=(0, 0, 0),
+                conv_type=conv_type,
+                tf_like=tf_like,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        self.deep = ConvBlock3D(
+            in_planes=cfg.expanded_channels,
+            out_planes=cfg.expanded_channels,
+            kernel_size=cfg.kernel_size,
+            padding=cfg.padding,
+            stride=cfg.stride,  # type: ignore
+            groups=cfg.expanded_channels,
+            conv_type=conv_type,
+            tf_like=tf_like,
+            norm_layer=norm_layer,
+            activation_layer=activation_layer,
+        )
+        # pylint: disable=invalid-name
+        self.se = SqueezeExcitation(
+            cfg.expanded_channels,
+            activation_1=activation_layer,
+            activation_2=(nn.Sigmoid if conv_type == "3d" else nn.Hardsigmoid),
+            conv_type=conv_type,
+        )
+        self.project = ConvBlock3D(
+            cfg.expanded_channels,
+            cfg.out_channels,
+            kernel_size=(1, 1, 1),
+            padding=(0, 0, 0),
+            conv_type=conv_type,
+            tf_like=tf_like,
+            norm_layer=norm_layer,
+            activation_layer=nn.Identity,
+        )
+
+        if not (cfg.stride == (1, 1, 1) and cfg.input_channels == cfg.out_channels):
+            if cfg.stride != (1, 1, 1):
+                if tf_like:
+                    layers.append(TFAvgPool3D())
+                else:
+                    layers.append(nn.AvgPool3d((1, 3, 3), stride=cfg.stride, padding=cfg.padding_avg))
+            layers.append(
+                ConvBlock3D(
+                    in_planes=cfg.input_channels,
+                    out_planes=cfg.out_channels,
+                    kernel_size=(1, 1, 1),
+                    padding=(0, 0, 0),
+                    norm_layer=norm_layer,
+                    activation_layer=nn.Identity,
+                    conv_type=conv_type,
+                    tf_like=tf_like,
+                )
+            )
+            self.res = nn.Sequential(*layers)
+        # ReZero
+        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward function of BasicBneck."""
+        if self.res is not None:
+            residual = self.res(x)
+        else:
+            residual = x
+        if hasattr(self, "expand"):
+            x = self.expand(x)
+        x = self.deep(x)
+        x = self.se(x)
+        x = self.project(x)
+        result = residual + self.alpha * x
+        return result
+
+
+class MoViNet(nn.Module):
+    """MoViNet class used for video classification.
+
+    Args:
+        cfg (Config): Configuration object containing the network's hyperparameters.
+        conv_type (str, optional): A string indicating the type of convolutional layer
+            to use. Can be "2plus1d" or "3d". Defaults to "3d".
+        tf_like (bool, optional): A boolean indicating whether to use TensorFlow-like
+            convolution padding or not. Defaults to False.
+
+    Attributes:
+        conv1 (ConvBlock3D): A convolutional block for the first layer.
+        blocks (nn.Sequential): A sequence of basic bottleneck blocks.
+        conv7 (ConvBlock3D): A convolutional block for the final layer.
+
+    Methods:
+        avg(x: Tensor) -> Tensor: Returns the adaptive average pool
+            of the input tensor.
+        _init_weights(module): A private method that initializes the weights of the network's
+            convolutional, batch normalization, and linear layers.
+        forward(x: Tensor) -> Tensor: The forward pass of the network.
+ + """ + + def __init__( + self, + cfg: "Config", + conv_type: str = "3d", + tf_like: bool = False, + ) -> None: + super().__init__() + tf_like = True + blocks_dic = OrderedDict() + + norm_layer = nn.BatchNorm3d if conv_type == "3d" else nn.BatchNorm2d + activation_layer = nn.SiLU if conv_type == "3d" else nn.Hardswish + + self.conv1 = ConvBlock3D( + in_planes=cfg.conv1.input_channels, + out_planes=cfg.conv1.out_channels, + kernel_size=cfg.conv1.kernel_size, + stride=cfg.conv1.stride, + padding=cfg.conv1.padding, + conv_type=conv_type, + tf_like=tf_like, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + for i, block in enumerate(cfg.blocks): + for j, basicblock in enumerate(block): + blocks_dic[f"b{i}_l{j}"] = BasicBneck( + basicblock, + conv_type=conv_type, + tf_like=tf_like, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + self.blocks = nn.Sequential(blocks_dic) + self.conv7 = ConvBlock3D( + in_planes=cfg.conv7.input_channels, + out_planes=cfg.conv7.out_channels, + kernel_size=cfg.conv7.kernel_size, + stride=cfg.conv7.stride, + padding=cfg.conv7.padding, + conv_type=conv_type, + tf_like=tf_like, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + + def avg(self, x: Tensor) -> Tensor: + """Returns the adaptive average pool of the input tensor. + + Args: + x (Tensor): A tensor to be averaged. + + Returns: + Tensor: A tensor with the averaged values. + + """ + return F.adaptive_avg_pool3d(x, 1) + + @staticmethod + def _init_weights(module): + if isinstance(module, nn.Conv3d): + nn.init.kaiming_normal_(module.weight, mode="fan_out") + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm3d, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0, 0.01) + nn.init.zeros_(module.bias) + + def forward(self, x: Tensor) -> Tensor: + """Forward function of MoViNet.""" + x = self.conv1(x) + x = self.blocks(x) + x = self.conv7(x) + x = self.avg(x) + return x + + def init_weights(self): + """Initializes the weights of network.""" + self.apply(self._init_weights) + + +@BACKBONES.register_module() +class OTXMoViNet(MoViNet): + """MoViNet wrapper class for OTX.""" + + # pylint: disable=unused-argument + def __init__(self, **kwargs): + cfg = Config() + cfg.name = "A0" + cfg.conv1 = Config() + OTXMoViNet.fill_conv(cfg.conv1, 3, 8, (1, 3, 3), (1, 2, 2), (0, 1, 1)) + + cfg.blocks = [ + [Config()], + [Config() for _ in range(3)], + [Config() for _ in range(3)], + [Config() for _ in range(4)], + [Config() for _ in range(4)], + ] + + # block 2 + OTXMoViNet.fill_se_config(cfg.blocks[0][0], 8, 8, 24, (1, 5, 5), (1, 2, 2), (0, 2, 2), (0, 1, 1)) + + # block 3 + OTXMoViNet.fill_se_config(cfg.blocks[1][0], 8, 32, 80, (3, 3, 3), (1, 2, 2), (1, 0, 0), (0, 0, 0)) + OTXMoViNet.fill_se_config(cfg.blocks[1][1], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + OTXMoViNet.fill_se_config(cfg.blocks[1][2], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + + # block 4 + OTXMoViNet.fill_se_config(cfg.blocks[2][0], 32, 56, 184, (5, 3, 3), (1, 2, 2), (2, 0, 0), (0, 0, 0)) + OTXMoViNet.fill_se_config(cfg.blocks[2][1], 56, 56, 112, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + OTXMoViNet.fill_se_config(cfg.blocks[2][2], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + + # block 5 + OTXMoViNet.fill_se_config(cfg.blocks[3][0], 56, 56, 184, (5, 3, 3), (1, 1, 1), (2, 1, 1), (0, 1, 1)) + 
OTXMoViNet.fill_se_config(cfg.blocks[3][1], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1))
+        OTXMoViNet.fill_se_config(cfg.blocks[3][2], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1))
+        OTXMoViNet.fill_se_config(cfg.blocks[3][3], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1))
+
+        # block 6
+        OTXMoViNet.fill_se_config(cfg.blocks[4][0], 56, 104, 384, (5, 3, 3), (1, 2, 2), (2, 1, 1), (0, 1, 1))
+        OTXMoViNet.fill_se_config(cfg.blocks[4][1], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1))
+        OTXMoViNet.fill_se_config(cfg.blocks[4][2], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1))
+        OTXMoViNet.fill_se_config(cfg.blocks[4][3], 104, 104, 344, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1))
+
+        cfg.conv7 = Config()
+        OTXMoViNet.fill_conv(cfg.conv7, 104, 480, (1, 1, 1), (1, 1, 1), (0, 0, 0))
+
+        cfg.dense9 = Config(dict(hidden_dim=2048))
+        super().__init__(cfg)
+
+    @staticmethod
+    # pylint: disable=too-many-arguments
+    def fill_se_config(
+        conf,
+        input_channels,
+        out_channels,
+        expanded_channels,
+        kernel_size,
+        stride,
+        padding,
+        padding_avg,
+    ):
+        """Set the values of a given Config object for the SE module.
+
+        Args:
+            conf (Config): The Config object to be updated.
+            input_channels (int): The number of input channels.
+            out_channels (int): The number of output channels.
+            expanded_channels (int): The number of channels after expansion in the basic block.
+            kernel_size (tuple[int]): The size of the kernel.
+            stride (tuple[int]): The stride of the kernel.
+            padding (tuple[int]): The padding of the kernel.
+            padding_avg (tuple[int]): The padding for the average pooling operation.
+
+        Returns:
+            None.
+        """
+        conf.expanded_channels = expanded_channels
+        conf.padding_avg = padding_avg
+        OTXMoViNet.fill_conv(
+            conf,
+            input_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+        )
+
+    @staticmethod
+    def fill_conv(
+        conf,
+        input_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding,
+    ):
+        """Set the values of a given Config object for the conv layer.
+
+        Args:
+            conf (Config): The Config object to be updated.
+            input_channels (int): The number of input channels.
+            out_channels (int): The number of output channels.
+            kernel_size (tuple[int]): The size of the kernel.
+            stride (tuple[int]): The stride of the kernel.
+            padding (tuple[int]): The padding of the kernel.
+
+        Returns:
+            None.
+        """
+        conf.input_channels = input_channels
+        conf.out_channels = out_channels
+        conf.kernel_size = kernel_size
+        conf.stride = stride
+        conf.padding = padding
diff --git a/otx/algorithms/action/adapters/mmaction/models/heads/__init__.py b/otx/algorithms/action/adapters/mmaction/models/heads/__init__.py
index 23c67c6eab1..dacea417b3a 100644
--- a/otx/algorithms/action/adapters/mmaction/models/heads/__init__.py
+++ b/otx/algorithms/action/adapters/mmaction/models/heads/__init__.py
@@ -3,7 +3,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
+from .movinet_head import MoViNetHead
 from .roi_head import AVARoIHead
 
-__all__ = ["AVARoIHead"]
+__all__ = ["AVARoIHead", "MoViNetHead"]
diff --git a/otx/algorithms/action/adapters/mmaction/models/heads/movinet_head.py b/otx/algorithms/action/adapters/mmaction/models/heads/movinet_head.py
new file mode 100644
index 00000000000..7f3b62cd283
--- /dev/null
+++ b/otx/algorithms/action/adapters/mmaction/models/heads/movinet_head.py
@@ -0,0 +1,79 @@
+"""MoViNet head for OTX action recognition."""
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from mmaction.models.builder import HEADS
+from mmaction.models.heads.base import BaseHead
+from mmcv.cnn import normal_init
+from torch import nn
+
+from otx.algorithms.action.adapters.mmaction.models.backbones.movinet import ConvBlock3D
+
+
+@HEADS.register_module()
+class MoViNetHead(BaseHead):
+    """Classification head for MoViNet.
+
+    Args:
+        num_classes (int): Number of classes to be classified.
+        in_channels (int): Number of channels in input feature.
+        hidden_dim (int): Number of channels in hidden layer.
+        loss_cls (dict): Config for building the classification loss,
+            e.g. dict(type='CrossEntropyLoss').
+        tf_like (bool): If True, uses TensorFlow-style padding. Default: False.
+        conv_type (str): Type of convolutional layer. Default: '3d'.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        in_channels: int,
+        hidden_dim: int,
+        loss_cls: dict,
+        tf_like: bool = False,
+        conv_type: str = "3d",
+    ):
+        super().__init__(num_classes, in_channels, loss_cls)
+        self.init_std = 0.1
+        self.classifier = nn.Sequential(
+            ConvBlock3D(
+                in_channels,
+                hidden_dim,
+                kernel_size=(1, 1, 1),
+                tf_like=tf_like,
+                conv_type=conv_type,
+                bias=True,
+            ),
+            nn.SiLU(),
+            nn.Dropout(p=0.2, inplace=True),
+            ConvBlock3D(
+                hidden_dim,
+                num_classes,
+                kernel_size=(1, 1, 1),
+                tf_like=tf_like,
+                conv_type=conv_type,
+                bias=True,
+            ),
+        )
+
+    def init_weights(self):
+        """Initialize the parameters from scratch."""
+        normal_init(self.classifier, std=self.init_std)
+
+    def forward(self, x):
+        """Defines the computation performed at every call.
+
+        Args:
+            x (torch.Tensor): The input data.
+
+        Returns:
+            torch.Tensor: The classification scores for input samples.
+        """
+        # [N, in_channels, T, H, W]
+        cls_score = self.classifier(x)
+        cls_score = cls_score.flatten(1)
+        # [N, num_classes]
+        return cls_score
diff --git a/otx/algorithms/action/adapters/mmaction/models/recognizers/__init__.py b/otx/algorithms/action/adapters/mmaction/models/recognizers/__init__.py
new file mode 100644
index 00000000000..258cd9d3074
--- /dev/null
+++ b/otx/algorithms/action/adapters/mmaction/models/recognizers/__init__.py
@@ -0,0 +1,8 @@
+"""OTX Adapters for action recognizers - mmaction2."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .movinet_recognizer import MoViNetRecognizer
+
+__all__ = ["MoViNetRecognizer"]
diff --git a/otx/algorithms/action/adapters/mmaction/models/recognizers/movinet_recognizer.py b/otx/algorithms/action/adapters/mmaction/models/recognizers/movinet_recognizer.py
new file mode 100644
index 00000000000..148dcee1c3b
--- /dev/null
+++ b/otx/algorithms/action/adapters/mmaction/models/recognizers/movinet_recognizer.py
@@ -0,0 +1,43 @@
+"""MoViNet Recognizer for OTX compatibility."""
+# pylint: disable=unused-argument
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import functools
+
+from mmaction.models.builder import RECOGNIZERS
+from mmaction.models.recognizers.recognizer3d import Recognizer3D
+
+
+@RECOGNIZERS.register_module()
+class MoViNetRecognizer(Recognizer3D):
+    """MoViNet recognizer model framework for OTX compatibility."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Hooks that translate state_dict keys on save/load
+        self._register_state_dict_hook(self.state_dict_hook)
+        self._register_load_state_dict_pre_hook(functools.partial(self.load_state_dict_pre_hook, self))
+
+    @staticmethod
+    def state_dict_hook(module, state_dict, *args, **kwargs):
+        """Rename state_dict keys on save for OTX MoViNet compatibility."""
+        for key in list(state_dict.keys()):
+            val = state_dict.pop(key)
+            if "cls_head" in key:
+                key = key.replace("cls_head.", "")
+            else:
+                key = key.replace("backbone.", "")
+            state_dict[key] = val
+
+    @staticmethod
+    def load_state_dict_pre_hook(module, state_dict, prefix, *args, **kwargs):
+        """Rename incoming state_dict keys to the model's layout for OTX compatibility."""
+        for key in list(state_dict.keys()):
+            val = state_dict.pop(key)
+            if "classifier" in key:
+                key = key.replace("classifier", "cls_head.classifier")
+            else:
+                key = prefix + "backbone." + key[len(prefix) :]
+            state_dict[key] = val
diff --git a/otx/algorithms/action/configs/classification/movinet/__init__.py b/otx/algorithms/action/configs/classification/movinet/__init__.py
new file mode 100644
index 00000000000..bc73b4ae2e1
--- /dev/null
+++ b/otx/algorithms/action/configs/classification/movinet/__init__.py
@@ -0,0 +1,15 @@
+"""Initialization of MoViNet model for Action Classification Task."""
+
+# Copyright (C) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
diff --git a/otx/algorithms/action/configs/classification/movinet/data_pipeline.py b/otx/algorithms/action/configs/classification/movinet/data_pipeline.py
new file mode 100644
index 00000000000..869caa9291b
--- /dev/null
+++ b/otx/algorithms/action/configs/classification/movinet/data_pipeline.py
@@ -0,0 +1,78 @@
+"""Data Pipeline of MoViNet model for Action Classification Task."""
+
+# Copyright (C) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
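+# The pipelines below sample a single clip of 8 frames (frame interval 4)
+# per video and emit batches in NCTHW layout, which is what the 3D stem of
+# OTXMoViNet expects.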
+ +# pylint: disable=invalid-name +# dataset settings +seed = 2 +dataset_type = "RawframeDataset" + +img_norm_cfg = dict(mean=[0.0, 0.0, 0.0], std=[255.0, 255.0, 255.0], to_bgr=False) + +clip_len = 8 +frame_interval = 4 +train_pipeline = [ + dict(type="SampleFrames", clip_len=clip_len, frame_interval=frame_interval, num_clips=1), + dict(type="RawFrameDecode"), + dict(type="Resize", scale=(-1, 256)), + dict(type="RandomResizedCrop"), + dict(type="Resize", scale=(224, 224), keep_ratio=False), + dict(type="Flip", flip_ratio=0.5), + dict(type="Normalize", **img_norm_cfg), + dict(type="FormatShape", input_format="NCTHW"), + dict(type="Collect", keys=["imgs", "label"], meta_keys=[]), + dict(type="ToTensor", keys=["imgs", "label"]), +] + +val_pipeline = [ + dict(type="SampleFrames", clip_len=clip_len, frame_interval=frame_interval, num_clips=1, test_mode=True), + dict(type="RawFrameDecode"), + dict(type="Resize", scale=(-1, 256)), + dict(type="CenterCrop", crop_size=224), + dict(type="Normalize", **img_norm_cfg), + dict(type="FormatShape", input_format="NCTHW"), + dict(type="Collect", keys=["imgs", "label"], meta_keys=[]), + dict(type="ToTensor", keys=["imgs"]), +] +# TODO Delete label in meta key in test pipeline +test_pipeline = [ + dict(type="SampleFrames", clip_len=clip_len, frame_interval=frame_interval, num_clips=1, test_mode=True), + dict(type="RawFrameDecode"), + dict(type="Resize", scale=(-1, 256)), + dict(type="CenterCrop", crop_size=224), + dict(type="Normalize", **img_norm_cfg), + dict(type="FormatShape", input_format="NCTHW"), + dict(type="Collect", keys=["imgs"], meta_keys=[]), + dict(type="ToTensor", keys=["imgs"]), +] + +data = dict( + videos_per_gpu=10, + workers_per_gpu=0, + val_dataloader=dict(videos_per_gpu=1), + test_dataloader=dict(videos_per_gpu=1), + train=dict( + type=dataset_type, + pipeline=train_pipeline, + ), + val=dict( + type=dataset_type, + pipeline=val_pipeline, + ), + test=dict( + type=dataset_type, + pipeline=test_pipeline, + ), +) diff --git a/otx/algorithms/action/configs/classification/movinet/deployment.py b/otx/algorithms/action/configs/classification/movinet/deployment.py new file mode 100644 index 00000000000..37e92c99c2a --- /dev/null +++ b/otx/algorithms/action/configs/classification/movinet/deployment.py @@ -0,0 +1,3 @@ +"""MMDeploy config of MoViNet model for Action classification Task.""" + +_base_ = ["../base/base_classification_dynamic.py"] diff --git a/otx/algorithms/action/configs/classification/movinet/model.py b/otx/algorithms/action/configs/classification/movinet/model.py new file mode 100644 index 00000000000..a8b3e7cb89c --- /dev/null +++ b/otx/algorithms/action/configs/classification/movinet/model.py @@ -0,0 +1,65 @@ +"""Model configuration of MoViNet model for Action Classification Task.""" + +# Copyright (C) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
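+# MoViNet-A0 recognizer: OTXMoViNet backbone (480-channel features) plus
+# MoViNetHead, initialized from the MoViNet-pytorch A0 checkpoint via
+# `load_from` below.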
+
+# pylint: disable=invalid-name
+
+num_classes = 400
+model = dict(
+    type="MoViNetRecognizer",
+    backbone=dict(type="OTXMoViNet"),
+    cls_head=dict(
+        type="MoViNetHead",
+        in_channels=480,
+        hidden_dim=2048,
+        num_classes=num_classes,
+        loss_cls=dict(type="CrossEntropyLoss", loss_weight=1.0),
+    ),
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=dict(average_clips="prob"),
+)
+
+
+evaluation = dict(interval=1, metrics=["top_k_accuracy", "mean_class_accuracy"], final_metric="mean_class_accuracy")
+
+optimizer = dict(
+    type="AdamW",
+    lr=0.003,
+    weight_decay=0.0001,
+)
+
+optimizer_config = dict(grad_clip=dict(max_norm=40.0, norm_type=2))
+lr_config = dict(policy="CosineAnnealing", min_lr=0)
+total_epochs = 5
+
+# runtime settings
+checkpoint_config = dict(interval=1)
+log_config = dict(
+    interval=10,
+    hooks=[
+        dict(type="TextLoggerHook", ignore_last=False),
+    ],
+)
+log_level = "INFO"
+workflow = [("train", 1)]
+
+find_unused_parameters = False
+gpu_ids = range(0, 1)
+
+dist_params = dict(backend="nccl")
+resume_from = None
+load_from = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA0_statedict_v3?raw=true"
diff --git a/otx/algorithms/action/configs/classification/movinet/template.yaml b/otx/algorithms/action/configs/classification/movinet/template.yaml
new file mode 100644
index 00000000000..6fee18320db
--- /dev/null
+++ b/otx/algorithms/action/configs/classification/movinet/template.yaml
@@ -0,0 +1,63 @@
+# Description.
+model_template_id: Custom_Action_Classification_MoViNet
+name: MoViNet
+task_type: ACTION_CLASSIFICATION
+task_family: VISION
+instantiation: "CLASS"
+summary: Basic transfer learning template for MoViNet
+application: ~
+
+# Algo backend.
+framework: OTXAction v2.9.1
+
+# Task implementations.
+entrypoints:
+  base: otx.algorithms.action.tasks.ActionTrainTask
+  openvino: otx.algorithms.action.tasks.ActionOpenVINOTask
+
+# Capabilities.
+capabilities:
+  - compute_representations
+
+# Hyperparameters.
+hyper_parameters:
+  base_path: ../configuration.yaml
+  parameter_overrides:
+    learning_parameters:
+      batch_size:
+        default_value: 8
+        auto_hpo_state: POSSIBLE
+      learning_rate:
+        default_value: 0.003
+        auto_hpo_state: POSSIBLE
+      learning_rate_warmup_iters:
+        default_value: 3
+      num_iters:
+        default_value: 10
+    nncf_optimization:
+      enable_quantization:
+        default_value: true
+      enable_pruning:
+        default_value: false
+      pruning_supported:
+        default_value: true
+      maximal_accuracy_degradation:
+        default_value: 1.0
+    algo_backend:
+      train_type:
+        default_value: INCREMENTAL
+
+# Training resources.
+max_nodes: 1
+training_targets:
+  - GPU
+  - CPU
+
+# Stats.
+gigaflops: 2.71
+size: 3.1
+# # Inference options. Defined by OpenVINO capabilities, not Algo Backend or Platform.
+# inference_targets: +# - CPU +# - GPU +# - VPU diff --git a/pyproject.toml b/pyproject.toml index b26f42cab26..f4ed51f28fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,11 @@ good-names = [ "ys", "p", "f", + "b", + "c", + "t", + "w", + "h", ] [tool.pylint.imports] diff --git a/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_movinet.py b/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_movinet.py new file mode 100644 index 00000000000..0138b65d198 --- /dev/null +++ b/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_movinet.py @@ -0,0 +1,203 @@ +"""Unit test for otx.algorithms.action.adapters.mmaction.models.backbones.movinet""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +import pytest +import torch +from torch import nn + +from otx.algorithms.action.adapters.mmaction.models.backbones.movinet import ( + BasicBneck, + Conv2dBNActivation, + ConvBlock3D, + MoViNet, + OTXMoViNet, + SqueezeExcitation, + TFAvgPool3D, +) +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +class TestConv2dBNActivation: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.layer = Conv2dBNActivation(3, 16, kernel_size=3, padding=1) + + @e2e_pytest_unit + def test_conv2d_bn_activation_output_shape(self): + x = torch.Tensor(1, 3, 32, 32) + output = self.layer(x) + assert output.shape == (1, 16, 32, 32) + + @e2e_pytest_unit + def test_conv2d_bn_activation_attributes(self): + assert self.layer.kernel_size == (3, 3) + assert self.layer.stride == (1, 1) + assert self.layer.out_channels == 16 + + +class TestConvBlock3D: + @e2e_pytest_unit + def test_conv_block_3d_output_shape(self): + x = torch.Tensor(1, 3, 32, 32, 32) + layer = ConvBlock3D(3, 16, kernel_size=(3, 3, 3), tf_like=True, conv_type="3d") + output = layer(x) + assert output.shape == (1, 16, 32, 32, 32) + + @e2e_pytest_unit + @pytest.mark.parametrize("conv_type", ["3d", "2plus1d"]) + def test_conv_block_3d_attributes(self, conv_type): + layer = ConvBlock3D(3, 16, kernel_size=(3, 3, 3), tf_like=True, conv_type=conv_type) + assert layer.kernel_size == (3, 3, 3) + assert layer.stride == (1, 1, 1) + assert layer.dim_pad == 2 + assert layer.conv_type == conv_type + assert layer.tf_like + + +class TestSqueezeExcitation: + @pytest.fixture + def se_block(self): + return SqueezeExcitation(16, nn.ReLU, nn.Sigmoid, conv_type="2plus1d", squeeze_factor=4, bias=True) + + @e2e_pytest_unit + def test_scale_output_shape(self, se_block): + x = torch.Tensor(1, 16, 32, 32, 32) + scale = se_block._scale(x) + assert scale.shape == (1, 16, 1, 1, 1) + + @e2e_pytest_unit + def test_forward_output_shape(self, se_block): + x = torch.Tensor(1, 16, 32, 32, 32) + output = se_block(x) + assert output.shape == (1, 16, 32, 32, 32) + + @e2e_pytest_unit + def test_se_block_attributes(self, se_block): + assert se_block.fc1.kernel_size == (1, 1, 1) + assert se_block.fc2.kernel_size == (1, 1, 1) + assert se_block.fc1.conv_type == "2plus1d" + assert se_block.fc2.conv_type == "2plus1d" + + +class TestTFAvgPool3D: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.pool = TFAvgPool3D() + + @e2e_pytest_unit + def test_tf_avg_pool_output_shape(self): + x = torch.Tensor(1, 3, 32, 32, 32) + output = self.pool(x) + assert output.shape == (1, 3, 32, 16, 16) + + @e2e_pytest_unit + def test_tf_avg_pool_output_shape_odd(self): + x = torch.Tensor(1, 3, 31, 31, 31) + output = self.pool(x) + assert output.shape == (1, 3, 31, 16, 16) + 
+ @e2e_pytest_unit + def test_tf_avg_pool_output_shape_odd_padding(self): + x = torch.Tensor(1, 3, 30, 30, 30) + output = self.pool(x) + assert output.shape == (1, 3, 30, 15, 15) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class TestBasicBneck: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.config = AttrDict( + input_channels=64, + expanded_channels=64, + out_channels=64, + kernel_size=(3, 3, 3), + padding=(1, 1, 1), + stride=(1, 1, 1), + padding_avg=(1, 1, 1), + ) + + @e2e_pytest_unit + def test_basic_bneck_output_shape(self): + module = BasicBneck(self.config, tf_like=False, conv_type="3d", activation_layer=nn.ReLU) + x = torch.randn(1, self.config.input_channels, 32, 32, 32) + output = module(x) + assert output.shape == (1, self.config.out_channels, 32, 32, 32) + + +class TestMoViNet: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.cfg = AttrDict() + self.cfg.conv1 = AttrDict( + { + "input_channels": 3, + "out_channels": 16, + "kernel_size": (3, 5, 5), + "stride": (1, 1, 1), + "padding": (1, 2, 2), + } + ) + self.cfg.blocks = [ + [ + AttrDict( + { + "input_channels": 16, + "expanded_channels": 24, + "out_channels": 24, + "kernel_size": (3, 3, 3), + "stride": (1, 1, 1), + "padding": (1, 1, 1), + } + ), + ] + ] + self.cfg.conv7 = AttrDict( + { + "input_channels": 40, + "out_channels": 256, + "kernel_size": (1, 1, 1), + "stride": (1, 1, 1), + "padding": (0, 0, 0), + } + ) + + @e2e_pytest_unit + def test_movinet_output_shape(self): + module = MoViNet(self.cfg) + x = torch.randn(1, 3, 32, 32, 32) + module.conv1 = nn.Identity() + module.blocks = nn.Identity() + module.conv7 = nn.Identity() + output = module(x) + assert output.shape == (1, 3, 1, 1, 1) + + @e2e_pytest_unit + def test_init_weights(self): + module = MoViNet(self.cfg) + module.apply(module._init_weights) + for m in module.modules(): + if isinstance(m, nn.Conv3d): + if m.bias is not None: + assert m.bias.mean().item() == pytest.approx(0, abs=1e-2) + assert m.bias.std().item() == pytest.approx(0, abs=1e-2) + elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.GroupNorm)): + assert m.bias.mean().item() == pytest.approx(0, abs=1e-2) + assert m.bias.std().item() == pytest.approx(0, abs=1e-2) + elif isinstance(m, nn.Linear): + assert m.bias.mean().item() == pytest.approx(0, abs=1e-2) + assert m.bias.std().item() == pytest.approx(0, abs=1e-2) + + @e2e_pytest_unit + def test_OTXMoViNet(self): + model = OTXMoViNet() + input_tensor = torch.randn(1, 3, 32, 224, 224) + output_tensor = model(input_tensor) + assert output_tensor.shape == torch.Size([1, 480, 1, 1, 1]) diff --git a/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_register_backbone.py b/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_register_backbone.py index 4ee0c8a3018..0bdb5aeacec 100644 --- a/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_register_backbone.py +++ b/tests/unit/algorithms/action/adapters/mmaction/models/backbones/test_action_register_backbone.py @@ -3,7 +3,7 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # - +from mmaction.models import BACKBONES as MMACTION_BACKBONES from mmdet.models import BACKBONES as MMDET_BACKBONES from tests.test_suite.e2e_test_system import e2e_pytest_unit @@ -17,3 +17,5 @@ def test_register_action_backbones() -> None: """ assert "X3D" in MMDET_BACKBONES + assert "X3D" in 
MMACTION_BACKBONES + assert "OTXMoViNet" in MMACTION_BACKBONES diff --git a/tests/unit/algorithms/action/adapters/mmaction/models/heads/test_action_movinet_head.py b/tests/unit/algorithms/action/adapters/mmaction/models/heads/test_action_movinet_head.py new file mode 100644 index 00000000000..1f2601b4999 --- /dev/null +++ b/tests/unit/algorithms/action/adapters/mmaction/models/heads/test_action_movinet_head.py @@ -0,0 +1,32 @@ +"""Unit Test for otx.algorithms.action.adapters.mmaction.heads.movinet_head.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest +import torch + +from otx.algorithms.action.adapters.mmaction.models.heads.movinet_head import ( + MoViNetHead, +) +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +class TestMoViNetHead: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.movinet_head = MoViNetHead( + num_classes=400, + in_channels=480, + hidden_dim=2048, + loss_cls=dict(type="CrossEntropyLoss", loss_weight=1.0), + ) + + @e2e_pytest_unit + def test_forward(self) -> None: + """Test forward function.""" + sample_input = torch.randn(1, 480, 1, 1, 1) + with torch.no_grad(): + out = self.movinet_head(sample_input) + assert out.shape == (1, self.movinet_head.num_classes) diff --git a/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/__init__.py b/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/__init__.py new file mode 100644 index 00000000000..79931efa777 --- /dev/null +++ b/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2021-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/test_action_movinet_recognizer.py b/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/test_action_movinet_recognizer.py new file mode 100644 index 00000000000..0a487333203 --- /dev/null +++ b/tests/unit/algorithms/action/adapters/mmaction/models/recognizers/test_action_movinet_recognizer.py @@ -0,0 +1,67 @@ +"""Unit Test for otx.algorithms.action.adapters.mmaction.models.recognizers.movinet_recognizer.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +from copy import deepcopy + +import pytest +import torch +from mmaction.models.recognizers.recognizer3d import Recognizer3D + +from otx.algorithms.action.adapters.mmaction.models.recognizers.movinet_recognizer import ( + MoViNetRecognizer, +) +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +class MockOTXMoViNet: + pass + + +class MockModule: + def __init__(self): + self.backbone = MockOTXMoViNet() + self._state_dict = { + "classifier.0.conv_1.conv3d.weight": torch.rand(1, 1), + "conv1.conv_1.conv3d.weight": torch.rand(1, 1), + } + self.is_export = False + + def state_dict(self): + return self._state_dict + + +class TestMoViNetRecognizer: + @pytest.fixture(autouse=True) + def setup(self, mocker) -> None: + mocker.patch.object(Recognizer3D, "__init__", return_value=None) + MoViNetRecognizer._register_state_dict_hook = mocker.MagicMock() + MoViNetRecognizer._register_load_state_dict_pre_hook = mocker.MagicMock() + self.recognizer = MoViNetRecognizer() + self.prefix = "" + + @e2e_pytest_unit + def test_load_state_dict_pre_hook(self) -> None: + """Test load_state_dict_pre_hook function.""" + module = MockModule() + state_dict = module.state_dict() + self.recognizer.load_state_dict_pre_hook(module, state_dict, 
prefix=self.prefix) + + for key in state_dict: + if "classifier" in key: + assert "cls_head.classifier.0.conv_1.conv3d.weight" in state_dict + else: + assert "backbone.conv1.conv_1.conv3d.weight" in state_dict + + @e2e_pytest_unit + def test_state_dict_hook(self): + """Test state_dict_hook function.""" + module = MockModule() + state_dict = module.state_dict() + state_dict_copy = deepcopy(state_dict) + self.recognizer.load_state_dict_pre_hook(module, state_dict, prefix=self.prefix) + # backward state dict + self.recognizer.state_dict_hook(module, state_dict, prefix=self.prefix) + + assert state_dict.keys() == state_dict_copy.keys()
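
Note for reviewers: a minimal sketch of the checkpoint-key round trip performed by the two MoViNetRecognizer hooks above (this assumes the patched otx package and its mmaction2 dependencies are importable; the values are dummies, since the hooks only rename keys):

    from otx.algorithms.action.adapters.mmaction.models.recognizers.movinet_recognizer import (
        MoViNetRecognizer,
    )

    # Keys as stored in the original MoViNet checkpoint (no mmaction prefixes).
    state_dict = {
        "conv1.conv_1.conv3d.weight": 0,
        "classifier.0.conv_1.conv3d.weight": 0,
    }

    # On load, the pre-hook maps plain keys onto mmaction's module layout
    # (the module argument is unused, so None is fine here).
    MoViNetRecognizer.load_state_dict_pre_hook(None, state_dict, "")
    assert set(state_dict) == {
        "backbone.conv1.conv_1.conv3d.weight",
        "cls_head.classifier.0.conv_1.conv3d.weight",
    }

    # On save, state_dict_hook strips the prefixes again, restoring the
    # original MoViNet naming so exported checkpoints stay compatible.
    MoViNetRecognizer.state_dict_hook(None, state_dict)
    assert set(state_dict) == {
        "conv1.conv_1.conv3d.weight",
        "classifier.0.conv_1.conv3d.weight",
    }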