Adding multi-layer perceptron in ops #6053
Changes from 7 commits
@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
-        activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
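For context, the padding default described in the docstring above can be checked with a short snippet. This is a minimal sketch, assuming the block is imported from torchvision.ops; the channel and spatial sizes are illustrative:

```python
import torch
from torchvision.ops import Conv2dNormActivation

# kernel_size=3, dilation=1 -> padding defaults to (3 - 1) // 2 * 1 = 1,
# so stride-1 convolutions preserve the spatial size.
block = Conv2dNormActivation(in_channels=3, out_channels=16, kernel_size=3)
x = torch.randn(1, 3, 32, 32)
print(block(x).shape)  # torch.Size([1, 16, 32, 32])
```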
@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
-        activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -253,3 +253,47 @@ def _scale(self, input: Tensor) -> Tensor:
     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input)
         return scale * input
+
+
+class MLP(torch.nn.Sequential):
+    """This block implements the multi-layer perceptron (MLP) module.
+
+    Args:
+        in_channels (int): Number of channels of the input
+        hidden_channels (List[int]): List of the hidden channel dimensions
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer wont be used. Default: ``None``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
Review comment: Annotations in docstring make me sad :'(

Reply: I know how you feel about this. :( I think all the callables all over TorchVision are added like that to provide info on what they are supposed to return.
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        bias (bool): Whether to use bias in the linear layer. Default ``True``
+        dropout (float): The probability for the dropout layer. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: List[int],
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        inplace: Optional[bool] = True,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
+        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
+        params = {} if inplace is None else {"inplace": inplace}
+
+        layers = []
+        in_dim = in_channels
+        for hidden_dim in hidden_channels[:-1]:
+            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+            if norm_layer is not None:
+                layers.append(norm_layer(hidden_dim))
+            layers.append(activation_layer(**params))
+            layers.append(torch.nn.Dropout(dropout, **params))
+            in_dim = hidden_dim
+
+        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
+        layers.append(torch.nn.Dropout(dropout, **params))
Review comment: It is not very clear to me why there is a Dropout layer after the last layer. I saw that it was present in the previous MLPBlock class, but no other implementation of MLP with dropout that I could find has a dropout layer on the output, including the one in the multimodal package. Maybe this was something specific to the use case of MLPBlock? If so, it should not be in this class.

Reply: You are right, there are various implementations of MLP: some have no dropout at all, some have it in the middle but not at the end, and some have it everywhere. If you check the references, you will see that all patterns exist. Our implementation is like this because it replaces MLP layers used in existing models such as ViT and Swin, and we also try to support more complex variations with more than two linear layers. Your observation is correct, though, that if one wanted to avoid having dropout at the end, the current implementation wouldn't allow it. Since that variant is also valid, perhaps it's worth making this update in a non-BC-breaking way with a new boolean that controls whether Dropout appears at the end. WDYT?

Reply: I think your suggestion sounds very nice. What would the default value of the boolean be? I guess that setting it to True (with dropout) would cause no breaking changes. At the same time, I would say that not having dropout in the last layer is the more common (default) configuration. Also, I'd be interested in working on this, whichever option is chosen.

Reply: Yes, you are right, we will need to maintain BC. Note that using True is the "default" setup in TorchVision at the moment, as literally all existing models require dropout everywhere. Sounds great; let me recommend the following: could you start an issue summarizing what you said here and providing a few references of MLP usage with a middle dropout but without the final one? Providing a few examples from real-world vision architectures will help build a stronger case. Once we clarify the details on the issue, we can discuss a potential PR. 😃

Reply: OK! I am a bit short on time at the moment, but will have more time in the upcoming weeks. Nevertheless, I'm interested in this and will be working on it!
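To make the proposal above concrete, here is a hedged sketch of what an opt-out flag for the trailing Dropout could look like. The helper name `build_mlp_layers` and the parameter name `dropout_last` are hypothetical and not part of this PR or of TorchVision:

```python
from typing import Callable, List

import torch


def build_mlp_layers(
    in_channels: int,
    hidden_channels: List[int],
    activation_layer: Callable[..., torch.nn.Module] = torch.nn.ReLU,
    bias: bool = True,
    dropout: float = 0.0,
    dropout_last: bool = True,  # hypothetical flag; True keeps the current behaviour for BC
) -> List[torch.nn.Module]:
    """Sketch of the layer-building loop with an optional trailing Dropout."""
    layers: List[torch.nn.Module] = []
    in_dim = in_channels
    for hidden_dim in hidden_channels[:-1]:
        layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
        layers.append(activation_layer())
        layers.append(torch.nn.Dropout(dropout))
        in_dim = hidden_dim
    layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
    if dropout_last:
        # Skipping this branch gives the common "no dropout on the output" variant.
        layers.append(torch.nn.Dropout(dropout))
    return layers


# With dropout_last=False there is no Dropout after the final Linear layer.
model = torch.nn.Sequential(*build_mlp_layers(8, [16, 4], dropout=0.2, dropout_last=False))
print(model)
```

Defaulting the flag to True would match the current behaviour and keep backward compatibility, while False would reproduce the variant discussed above.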
+
+        super().__init__(*layers)
+        _log_api_usage_once(self)
Review comment: This is a straight move from convnext to ops (no weight patching needed), but to avoid doing everything in a single PR I plan to do it in a follow-up.
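For illustration, a minimal usage sketch of the new block, assuming it ends up exported as torchvision.ops.MLP once the PR lands; the sizes below are arbitrary:

```python
import torch
from torchvision.ops import MLP  # assumes the class is exported from torchvision.ops after merge

# With norm_layer=None, hidden_channels=[64, 64, 10] expands to:
# Linear(32, 64) -> ReLU -> Dropout -> Linear(64, 64) -> ReLU -> Dropout -> Linear(64, 10) -> Dropout
mlp = MLP(in_channels=32, hidden_channels=[64, 64, 10], dropout=0.1)
x = torch.randn(8, 32)
print(mlp(x).shape)  # torch.Size([8, 10])
```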