diff --git a/docs/source/ops.rst b/docs/source/ops.rst
index d045334ce3c..472c45fbab4 100644
--- a/docs/source/ops.rst
+++ b/docs/source/ops.rst
@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers:
     DeformConv2d
     DropBlock2d
     DropBlock3d
+    MLP
     FrozenBatchNorm2d
     SqueezeExcitation
     StochasticDepth
diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 6e001c1d2dd..148bfa1c4a2 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -5,14 +5,14 @@
 import torch.nn.functional as F
 from torch import nn, Tensor

+from ..ops.misc import MLP
 from ..ops.stochastic_depth import StochasticDepth
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param
-from .convnext import Permute
-from .vision_transformer import MLPBlock
+from .convnext import Permute  # TODO: move Permute to ops


 __all__ = [
@@ -263,7 +263,13 @@ def __init__(
         )
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
         self.norm2 = norm_layer(dim)
-        self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout)
+        self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
+
+        for m in self.mlp.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, std=1e-6)

     def forward(self, x: Tensor):
         x = x + self.stochastic_depth(self.attn(self.norm1(x)))
@@ -412,7 +418,7 @@ def _swin_transformer(

 class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
+        url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
         transforms=partial(
             ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
         ),
@@ -435,7 +441,7 @@ class Swin_T_Weights(WeightsEnum):

 class Swin_S_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_s-30134662.pth",
+        url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
         transforms=partial(
             ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
         ),
@@ -458,7 +464,7 @@ class Swin_S_Weights(WeightsEnum):

 class Swin_B_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
+        url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
         transforms=partial(
             ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
         ),
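Note on the swin_transformer.py change above: the new MLP call is meant to be a drop-in equivalent of the removed MLPBlock. A minimal sketch of the equivalence (illustrative values; assumes a torchvision build that already contains this patch):

    import torch
    from torch import nn
    from torchvision.ops import MLP

    # MLP(dim, [hidden, dim], activation_layer=nn.GELU, inplace=None, dropout=p)
    # builds the same stack the old MLPBlock assembled attribute by attribute:
    #   Linear(dim, hidden) -> GELU -> Dropout(p) -> Linear(hidden, dim) -> Dropout(p)
    dim, mlp_ratio, dropout = 96, 4.0, 0.0
    mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

    x = torch.randn(1, 56, 56, dim)  # Swin blocks keep channels in the last dimension
    assert mlp(x).shape == x.shape   # nn.Linear operates on the last dimension

`inplace=None` is required because `nn.GELU` takes no `inplace` argument, and the init loop reproduces the Xavier-uniform/normal initialization the old block applied to its two Linear layers.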
diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index dad2804e626..063d51749b4 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -6,7 +6,7 @@
 import torch
 import torch.nn as nn

-from ..ops.misc import Conv2dNormActivation
+from ..ops.misc import Conv2dNormActivation, MLP
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple):
     activation_layer: Callable[..., nn.Module] = nn.ReLU


-class MLPBlock(nn.Sequential):
+class MLPBlock(MLP):
     """Transformer MLP block."""

     def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
-        super().__init__()
-        self.linear_1 = nn.Linear(in_dim, mlp_dim)
-        self.act = nn.GELU()
-        self.dropout_1 = nn.Dropout(dropout)
-        self.linear_2 = nn.Linear(mlp_dim, in_dim)
-        self.dropout_2 = nn.Dropout(dropout)
-
-        nn.init.xavier_uniform_(self.linear_1.weight)
-        nn.init.xavier_uniform_(self.linear_2.weight)
-        nn.init.normal_(self.linear_1.bias, std=1e-6)
-        nn.init.normal_(self.linear_2.bias, std=1e-6)
+        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
+
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, std=1e-6)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
+            for i in range(2):
+                for type in ["weight", "bias"]:
+                    old_key = f"{prefix}linear_{i+1}.{type}"
+                    new_key = f"{prefix}{3*i}.{type}"
+                    if old_key in state_dict:
+                        state_dict[new_key] = state_dict.pop(old_key)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )


 class EncoderBlock(nn.Module):
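The `_load_from_state_dict` override above keeps previously published checkpoints loadable: when no version metadata is present (or the version is below 2), the legacy attribute-based keys are renamed to the positional keys that `nn.Sequential` assigns, i.e. `linear_1.*` becomes `0.*` and `linear_2.*` becomes `3.*` (indices 1 and 2 are the GELU and the first Dropout). A small sketch of the behavior, with hypothetical ViT-B/16 sizes:

    import torch
    from torchvision.models.vision_transformer import MLPBlock

    block = MLPBlock(in_dim=768, mlp_dim=3072, dropout=0.0)

    # A checkpoint produced before this patch stored the old attribute names:
    legacy = {
        "linear_1.weight": torch.randn(3072, 768),
        "linear_1.bias": torch.zeros(3072),
        "linear_2.weight": torch.randn(768, 3072),
        "linear_2.bias": torch.zeros(768),
    }
    # The plain dict carries no version metadata, so the override rewrites the
    # keys to 0.weight/0.bias/3.weight/3.bias before the normal load runs.
    block.load_state_dict(legacy)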
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index d3f27ef1657..333e9246401 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -19,7 +19,7 @@
 from .feature_pyramid_network import FeaturePyramidNetwork
 from .focal_loss import sigmoid_focal_loss
 from .giou_loss import generalized_box_iou_loss
-from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
+from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
 from .poolers import MultiScaleRoIAlign
 from .ps_roi_align import ps_roi_align, PSRoIAlign
 from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -61,6 +61,7 @@
     "Conv2dNormActivation",
     "Conv3dNormActivation",
     "SqueezeExcitation",
+    "MLP",
     "generalized_box_iou_loss",
     "distance_box_iou_loss",
     "complete_box_iou_loss",
diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py
index a4635099215..2e4816c9f22 100644
--- a/torchvision/ops/misc.py
+++ b/torchvision/ops/misc.py
@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
-        activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
-        activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -253,3 +253,47 @@ def _scale(self, input: Tensor) -> Tensor:
     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input)
         return scale * input
+
+
+class MLP(torch.nn.Sequential):
+    """This block implements the multi-layer perceptron (MLP) module.
+
+    Args:
+        in_channels (int): Number of channels of the input
+        hidden_channels (List[int]): List of the hidden channel dimensions
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        bias (bool): Whether to use bias in the linear layer. Default ``True``
+        dropout (float): The probability for the dropout layer. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: List[int],
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        inplace: Optional[bool] = True,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        # The addition of `norm_layer` is inspired by the implementation of TorchMultimodal:
+        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
+        params = {} if inplace is None else {"inplace": inplace}
+
+        layers = []
+        in_dim = in_channels
+        for hidden_dim in hidden_channels[:-1]:
+            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+            if norm_layer is not None:
+                layers.append(norm_layer(hidden_dim))
+            layers.append(activation_layer(**params))
+            layers.append(torch.nn.Dropout(dropout, **params))
+            in_dim = hidden_dim
+
+        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
+        layers.append(torch.nn.Dropout(dropout, **params))
+
+        super().__init__(*layers)
+        _log_api_usage_once(self)
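Finally, a short usage sketch of the new `torchvision.ops.MLP` (again assuming a build that includes this patch):

    import torch
    from torchvision.ops import MLP

    # 128 input channels, hidden layers of 512 and 256 channels, 10 output
    # channels. Each hidden Linear is followed by norm/activation/dropout;
    # the last Linear is followed by a Dropout only.
    mlp = MLP(
        in_channels=128,
        hidden_channels=[512, 256, 10],
        norm_layer=torch.nn.LayerNorm,
        activation_layer=torch.nn.ReLU,
        dropout=0.1,
    )
    x = torch.randn(32, 128)
    print(mlp(x).shape)  # torch.Size([32, 10])

Because MLP subclasses `torch.nn.Sequential`, its parameter names are positional (`0.weight`, `3.weight`, ...), which is exactly what the `MLPBlock._load_from_state_dict` shim above compensates for.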