Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Permute layer to ops #6055

Merged
merged 1 commit into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ TorchVision provides commonly used building blocks as layers:
DeformConv2d
DropBlock2d
DropBlock3d
MLP
FrozenBatchNorm2d
MLP
Permute
SqueezeExcitation
StochasticDepth

Expand Down
11 changes: 1 addition & 10 deletions torchvision/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import nn, Tensor
from torch.nn import functional as F

from ..ops.misc import Conv2dNormActivation
from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
Expand Down Expand Up @@ -35,15 +35,6 @@ def forward(self, x: Tensor) -> Tensor:
return x


class Permute(nn.Module):
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims

def forward(self, x):
return torch.permute(x, self.dims)


class CNBlock(nn.Module):
def __init__(
self,
Expand Down
3 changes: 1 addition & 2 deletions torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import MLP
from ..ops.misc import MLP, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param
from .convnext import Permute # TODO: move Permute on ops


__all__ = [
Expand Down
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
Expand Down Expand Up @@ -62,6 +62,7 @@
"Conv3dNormActivation",
"SqueezeExcitation",
"MLP",
"Permute",
"generalized_box_iou_loss",
"distance_box_iou_loss",
"complete_box_iou_loss",
Expand Down
16 changes: 15 additions & 1 deletion torchvision/ops/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
interpolate = torch.nn.functional.interpolate


# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
Expand Down Expand Up @@ -297,3 +296,18 @@ def __init__(

super().__init__(*layers)
_log_api_usage_once(self)


class Permute(torch.nn.Module):
"""This module returns a view of the tensor input with its dimensions permuted.

Args:
dims (List[int]): The desired ordering of dimensions
"""

def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims

def forward(self, x: Tensor) -> Tensor:
return torch.permute(x, self.dims)