Skip to content

Commit

Permalink
[fbsync] Adding Scale Jitter transform for detection (#5435)
Browse files Browse the repository at this point in the history
Summary:
* Adding Scale Jitter in references.

* Update documentation.

* Address review comments.

Reviewed By: jdsgomes

Differential Revision: D34475308

fbshipit-source-id: dcdb00685d4de39b7315ff7d4b9cfb2411218e5c
  • Loading branch information
prabhat00155 authored and facebook-github-bot committed Feb 25, 2022
1 parent d677be7 commit 793c3db
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torchvision
from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from torchvision.transforms import transforms as T, InterpolationMode


def _flip_coco_person_keypoints(kps, width):
Expand Down Expand Up @@ -282,3 +282,52 @@ def forward(
image = F.to_pil_image(image)

return image, target


class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes within the specified scale range.

    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    The output image has size ``(int(target_size[0] * scale), int(target_size[1] * scale))``.
    Boxes are rescaled per axis by the ratio between the output size and the input size,
    so they stay aligned with the image whatever the input resolution was.

    Args:
        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
        scale_range (tuple of floats): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
            range a <= scale < b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize ``image`` (and the annotations in ``target``) by a random scale.

        Args:
            image: Image to resize. Tensors must be 2- or 3-dimensional; a 2-D tensor
                gets a leading channel dimension added. Non-tensor inputs are passed to
                ``F.resize`` as-is and are expected to expose ``.size`` as (width, height).
            target: Optional annotation dict. ``target["boxes"]`` is rescaled and, when
                present, ``target["masks"]`` is resized with nearest interpolation.
                The dict is updated in place, but the original boxes tensor is not modified.

        Returns:
            The resized image and the (updated) target.

        Raises:
            ValueError: If ``image`` is a tensor that is not 2- or 3-dimensional.
        """
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)
            orig_height, orig_width = image.shape[-2:]
        else:
            # PIL images report their size as (width, height).
            orig_width, orig_height = image.size

        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        new_width = int(self.target_size[1] * r)
        new_height = int(self.target_size[0] * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Scale boxes by the ratio actually applied to each axis, not by ``r``:
            # the input image is not guaranteed to have had ``target_size``.
            # Work on a copy so the caller's original boxes tensor is left untouched.
            boxes = target["boxes"].clone()
            boxes[..., 0::2] *= new_width / orig_width
            boxes[..., 1::2] *= new_height / orig_height
            target["boxes"] = boxes
            if "masks" in target:
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target

0 comments on commit 793c3db

Please sign in to comment.