Adding multi-layer perceptron in ops #6053

Merged: 9 commits merged on May 19, 2022
Changes from 6 commits
1 change: 1 addition & 0 deletions docs/source/ops.rst
@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers:
DeformConv2d
DropBlock2d
DropBlock3d
MLP
FrozenBatchNorm2d
SqueezeExcitation
StochasticDepth
2 changes: 1 addition & 1 deletion torchvision/models/swin_transformer.py
@@ -12,7 +12,7 @@
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param
from .convnext import Permute
from .vision_transformer import MLPBlock
from .vision_transformer import MLPBlock # TODO: remove this, patch the weights and use MLP instead


__all__ = [
53 changes: 40 additions & 13 deletions torchvision/models/vision_transformer.py
@@ -6,7 +6,7 @@
import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation
from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple):
activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(nn.Sequential):
class MLPBlock(MLP):
"""Transformer MLP block."""

def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)

nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)

def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)

if version is None or version < 2:
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
for i in range(2):
for type in ["weight", "bias"]:
old_key = f"{prefix}linear_{i+1}.{type}"
new_key = f"{prefix}{3*i}.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)

super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)


class EncoderBlock(nn.Module):
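The key remapping in ``_load_from_state_dict`` above converts legacy checkpoints to the new layout. A minimal sketch, assuming the Sequential layout Linear → GELU → Dropout → Linear → Dropout that the new MLP produces for these arguments (so the linear layers sit at indices 0 and 3, matching ``3*i``):

import torch

# Legacy MLPBlock checkpoints use attribute names; the new block uses positional keys.
old_to_new = {
    "linear_1.weight": "0.weight", "linear_1.bias": "0.bias",
    "linear_2.weight": "3.weight", "linear_2.bias": "3.bias",
}
# Dummy legacy state dict (shapes are illustrative only).
legacy = {
    "linear_1.weight": torch.zeros(3072, 768), "linear_1.bias": torch.zeros(3072),
    "linear_2.weight": torch.zeros(768, 3072), "linear_2.bias": torch.zeros(768),
}
for old_key, new_key in old_to_new.items():
    if old_key in legacy:
        legacy[new_key] = legacy.pop(old_key)
print(sorted(legacy))  # ['0.bias', '0.weight', '3.bias', '3.weight']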
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
@@ -19,7 +19,7 @@
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -61,6 +61,7 @@
"Conv2dNormActivation",
"Conv3dNormActivation",
"SqueezeExcitation",
"MLP",
"generalized_box_iou_loss",
"distance_box_iou_loss",
"complete_box_iou_loss",
46 changes: 44 additions & 2 deletions torchvision/ops/misc.py
@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -253,3 +253,45 @@ def _scale(self, input: Tensor) -> Tensor:
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input


class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.

Args:
in_channels (int): Number of channels of the input
hidden_channels (List[int]): List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
Member:
Annotations in docstring make me sad :'(

Contributor Author:
I know how you feel about this. :( I think the callables all over TorchVision are annotated like that to provide info on what they are supposed to return.

inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""

def __init__(
self,
in_channels: int,
hidden_channels: List[int],
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
inplace: Optional[bool] = True,
bias: bool = True,
dropout: float = 0.0,
):
params = {} if inplace is None else {"inplace": inplace}

layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim

layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))
thomasbbrunner (Aug 4, 2022):
It is not very clear to me why there's a Dropout layer after the last layer. I saw that it was present in the previous MLPBlock class, but no other implementation of MLP with dropout (that I could find) has a dropout layer on the output, including the one in the multimodal package.

Maybe this was something specific to the use case of MLPBlock? If so, it should not be in this class.

datumbox (Contributor Author, Aug 4, 2022):
You are right, there are various implementations of MLP: some don't have dropout at all, some have it in the middle but not at the end, and some have it everywhere. If you check the references, you will see that all patterns exist. Our implementation is like this because it replaces the MLP layers used in existing models such as ViT and Swin. We also try to support more complex variants with more than two linear layers. Your observation is correct, though: if one wanted to avoid having dropout at the end, the current implementation wouldn't allow it. Since that variant is also valid, perhaps it's worth making this update in a non-BC way, with a new boolean that controls whether the final Dropout appears. WDYT?

thomasbbrunner:
I think your suggestion sounds very nice. What would the default value for the boolean be? I guess that setting it to True (with dropout) would cause no breaking changes. At the same time, I would say that not having dropout on the last layer is the more common (default) configuration. Also, I'd be interested in working on this, whichever option is chosen.

datumbox (Contributor Author):
> I guess that setting it to True (with dropout) would cause no breaking changes.

Yes, you are right, we will need to maintain BC. Note that using True is the "default" setup in TorchVision at the moment, as literally all existing models require dropout everywhere.

> I'd be interested in working on this

Sounds great, let me recommend the following. Could you start an issue summarizing what you said here and providing a few references for the usage of MLP with a middle dropout but without the final one? Providing a few examples from real-world vision architectures will help build a stronger case. Once we clarify the details on the issue, we can discuss a potential PR. 😃

thomasbbrunner:
Ok! I am a bit short on time at the moment, but will have more time in the upcoming weeks. Nevertheless, I'm interested in this and will be working on it!
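For concreteness, a sketch of the kind of non-BC extension discussed in this thread; the ``final_dropout`` flag is a hypothetical name for illustration only and is not part of this PR or of ``torchvision.ops.MLP``:

from typing import List
import torch

def build_mlp_layers(in_channels: int, hidden_channels: List[int], dropout: float = 0.0,
                     final_dropout: bool = True) -> torch.nn.Sequential:
    # "final_dropout" is hypothetical; defaulting it to True would keep backwards
    # compatibility with the behaviour implemented in this PR (dropout after every linear).
    layers: List[torch.nn.Module] = []
    in_dim = in_channels
    for hidden_dim in hidden_channels[:-1]:
        layers += [torch.nn.Linear(in_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Dropout(dropout)]
        in_dim = hidden_dim
    layers.append(torch.nn.Linear(in_dim, hidden_channels[-1]))
    if final_dropout:
        layers.append(torch.nn.Dropout(dropout))
    return torch.nn.Sequential(*layers)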


super().__init__(*layers)
_log_api_usage_once(self)
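For reference, a minimal usage sketch of the new op; the shapes and hyper-parameters below are illustrative. With ``hidden_channels=[3072, 768]`` and the default ReLU activation, the constructor above produces the Sequential Linear → ReLU → Dropout → Linear → Dropout:

import torch
from torchvision.ops import MLP

# Transformer-style feed-forward block: 768 -> 3072 -> 768, dropout after each linear layer.
mlp = MLP(in_channels=768, hidden_channels=[3072, 768], dropout=0.1)
x = torch.randn(16, 197, 768)   # (batch, tokens, channels); nn.Linear acts on the last dim
print(mlp(x).shape)             # torch.Size([16, 197, 768])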