Adding multi-layer perceptron in ops (#6053)
* Adding an MLP block.

* Adding documentation

* Update typos.

* Fix inplace for Dropout.

* Apply recommendations from code review.

* Making changes on pre-trained models.

* Fix linter
datumbox authored May 19, 2022
1 parent e65372e commit 77cad12
Showing 5 changed files with 101 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/source/ops.rst
@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers:
DeformConv2d
DropBlock2d
DropBlock3d
MLP
FrozenBatchNorm2d
SqueezeExcitation
StochasticDepth
18 changes: 12 additions & 6 deletions torchvision/models/swin_transformer.py
@@ -5,14 +5,14 @@
import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import MLP
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param
from .convnext import Permute
from .vision_transformer import MLPBlock
from .convnext import Permute # TODO: move Permute on ops


__all__ = [
@@ -263,7 +263,13 @@ def __init__(
)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout)
self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

for m in self.mlp.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)

def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.attn(self.norm1(x)))
@@ -412,7 +418,7 @@ def _swin_transformer(

class Swin_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
),
@@ -435,7 +441,7 @@ class Swin_T_Weights(WeightsEnum):

class Swin_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_s-30134662.pth",
url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
),
@@ -458,7 +464,7 @@ class Swin_S_Weights(WeightsEnum):

class Swin_B_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
),
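Note: the Swin block above now builds its MLP from the shared ops.MLP helper instead of vision_transformer.MLPBlock, and applies the Xavier-uniform / small-normal initialization itself, since MLP leaves its Linear layers at the PyTorch defaults. A minimal sketch of the equivalent construction; the dim, mlp_ratio and dropout values are illustrative, not taken from the diff:

import torch
from torch import nn
from torchvision.ops import MLP

# Illustrative sizes; the real block derives them from the model config.
dim, mlp_ratio, dropout = 96, 4.0, 0.0

mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

# Expected layout: Linear -> GELU -> Dropout -> Linear -> Dropout,
# i.e. the same stack the removed MLPBlock used to build.
print([type(m).__name__ for m in mlp])

# Initialization now happens at the call site, as in the diff above.
for m in mlp.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias, std=1e-6)

out = mlp(torch.randn(2, 10, dim))
print(out.shape)  # torch.Size([2, 10, 96])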
53 changes: 40 additions & 13 deletions torchvision/models/vision_transformer.py
@@ -6,7 +6,7 @@
import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation
from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple):
activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(nn.Sequential):
class MLPBlock(MLP):
"""Transformer MLP block."""

def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)

nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)

def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)

if version is None or version < 2:
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
for i in range(2):
for type in ["weight", "bias"]:
old_key = f"{prefix}linear_{i+1}.{type}"
new_key = f"{prefix}{3*i}.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)

super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)


class EncoderBlock(nn.Module):
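Note: the _load_from_state_dict override above keeps checkpoints saved before this change loadable: the legacy linear_1.* / linear_2.* keys are moved to positions 0 and 3 of the new Sequential, while positions 1, 2 and 4 hold the parameter-free GELU/Dropout layers. A standalone sketch of that remapping on a dummy state dict; the module prefix and tensor shapes are illustrative:

import torch

# Dummy entries for one pre-change MLPBlock(in_dim=4, mlp_dim=8).
prefix = "encoder.layers.encoder_layer_0.mlp."
state_dict = {
    f"{prefix}linear_1.weight": torch.zeros(8, 4),
    f"{prefix}linear_1.bias": torch.zeros(8),
    f"{prefix}linear_2.weight": torch.zeros(4, 8),
    f"{prefix}linear_2.bias": torch.zeros(4),
}

# Same key translation as in MLPBlock._load_from_state_dict:
# linear_1 -> index 0, linear_2 -> index 3 of the new Sequential.
for i in range(2):
    for kind in ["weight", "bias"]:
        old_key = f"{prefix}linear_{i + 1}.{kind}"
        new_key = f"{prefix}{3 * i}.{kind}"
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict))  # ...mlp.0.bias, ...mlp.0.weight, ...mlp.3.bias, ...mlp.3.weight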
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
@@ -19,7 +19,7 @@
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -61,6 +61,7 @@
"Conv2dNormActivation",
"Conv3dNormActivation",
"SqueezeExcitation",
"MLP",
"generalized_box_iou_loss",
"distance_box_iou_loss",
"complete_box_iou_loss",
48 changes: 46 additions & 2 deletions torchvision/ops/misc.py
@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -253,3 +253,47 @@ def _scale(self, input: Tensor) -> Tensor:
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input


class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.
Args:
in_channels (int): Number of channels of the input
hidden_channels (List[int]): List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""

def __init__(
self,
in_channels: int,
hidden_channels: List[int],
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
inplace: Optional[bool] = True,
bias: bool = True,
dropout: float = 0.0,
):
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}

layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim

layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))

super().__init__(*layers)
_log_api_usage_once(self)
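
Note: with the export added in torchvision/ops/__init__.py above, the block is importable as torchvision.ops.MLP. A minimal usage sketch based only on the signature shown in this diff; the layer sizes, LayerNorm choice and dropout value are illustrative:

import torch
from torch import nn
from torchvision.ops import MLP

# 2048 -> 512 -> 10 head: Linear -> LayerNorm -> ReLU -> Dropout -> Linear -> Dropout
head = MLP(
    in_channels=2048,
    hidden_channels=[512, 10],
    norm_layer=nn.LayerNorm,
    activation_layer=nn.ReLU,
    inplace=None,  # None -> no inplace kwarg is passed to ReLU/Dropout, as the Swin/ViT call sites do
    dropout=0.2,
)

x = torch.randn(32, 2048)
print(head(x).shape)  # torch.Size([32, 10])

# A Dropout also follows the last Linear, so keep dropout at the 0.0 default
# if the final output should not be dropped.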
