[RFC] torchvision.transforms revamp #5754

Open
6 of 7 tasks
pmeier opened this issue Apr 6, 2022 · 0 comments


pmeier commented Apr 6, 2022

🚀 The feature

Note: To track the progress of the project check out this board.

The current transforms in the torchvision.transforms namespace are limited to images. This makes them hard to use for tasks where the transform must be applied not only to the input image but also to the target. For example, in object detection, resizing or cropping the input image also affects the bounding boxes.

This project aims to resolve this by providing transforms that can handle the full sample, possibly including images, bounding boxes, segmentation masks, and so on, without the need for user intervention. The implementation of this project happens in the torchvision.prototype.transforms namespace. For example:

import torch

from torchvision.prototype import features, transforms

# this will be supplied by a dataset from torchvision.prototype.datasets
image = features.EncodedImage.from_path("test/assets/fakedata/logos/rgb_pytorch.png")
label = features.Label(0)

transform = transforms.Compose(
    transforms.DecodeImage(),
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.Resize((100, 300)),
)

transformed_image1 = transform(image)
transformed_image2, transformed_label = transform(image, label)

# whether we call the transform with label or not has no effect on the image transform
assert transformed_image1.eq(transformed_image2).all()
# the transform is a no-op for labels
assert transformed_label.eq(label).all()
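The dispatch behavior above (transform the image, pass labels through untouched) can be sketched with plain tensor subclasses. This is an illustrative mock, not the actual prototype API: the `Image`, `Label`, and `horizontal_flip` names below are hypothetical stand-ins for the real `features` types and transform kernels.

```python
import torch

# Hypothetical feature types: each is just a torch.Tensor subclass,
# so the transform can dispatch on the input's type.
class Image(torch.Tensor):
    pass

class Label(torch.Tensor):
    pass

def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    if isinstance(inpt, Image):
        # flip along the width (last) dimension
        return torch.flip(inpt, dims=[-1])
    # no-op for feature types the transform does not affect, e.g. labels
    return inpt

image = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3).as_subclass(Image)
label = torch.tensor(0).as_subclass(Label)

flipped = horizontal_flip(image)
assert isinstance(flipped, Image)            # the feature type is preserved
assert torch.equal(horizontal_flip(label), label)  # labels pass through unchanged
```

Because standard torch operations preserve tensor subclasses by default, the flipped result is still an `Image`, so downstream transforms can keep dispatching on its type.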

[before / after images of the transformed sample]

# this will be supplied by a dataset from torchvision.prototype.datasets
bounding_box = features.BoundingBox(
    [60, 30, 15, 15], format=features.BoundingBoxFormat.CXCYWH, image_size=(100, 100)
)

transformed_image, transformed_bounding_box = transform(image, bounding_box)
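For intuition, the effect of a horizontal flip on a box in CXCYWH format can be worked out by hand: only the center x-coordinate moves, mirrored around the image width, while the width and height stay the same. A minimal pure-torch sketch (the `hflip_cxcywh` helper is hypothetical, not part of the prototype API):

```python
import torch

def hflip_cxcywh(box: torch.Tensor, image_width: int) -> torch.Tensor:
    # box = [cx, cy, w, h]; a horizontal flip mirrors the center
    # x-coordinate around the image width, w and h are unchanged
    cx, cy, w, h = box.unbind(-1)
    return torch.stack([image_width - cx, cy, w, h], dim=-1)

box = torch.tensor([60.0, 30.0, 15.0, 15.0])  # same box as above
flipped = hflip_cxcywh(box, image_width=100)
# cx moves from 60 to 40; cy, w, h stay put
```

The real transforms additionally rescale boxes when the image is resized (as the `Resize((100, 300))` step above does); this sketch covers only the flip.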

[before / after images of the transformed image and bounding box]

# this will be supplied by a dataset from torchvision.prototype.datasets
segmentation_mask = torch.zeros((100, 100), dtype=torch.bool)
segmentation_mask[24:36, 55:66] = True
segmentation_mask = features.SegmentationMask(segmentation_mask)

transformed_image, transformed_segmentation_mask = transform(image, segmentation_mask)
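Segmentation masks transform with the same spatial operations as the image itself; for example, a horizontal flip is simply a flip along the width dimension. A minimal pure-torch sketch, independent of the prototype API:

```python
import torch

mask = torch.zeros((100, 100), dtype=torch.bool)
mask[24:36, 55:66] = True  # same mask as above

# a horizontal flip mirrors the mask along the width (last) dimension
flipped = torch.flip(mask, dims=[-1])

# the True columns [55, 66) map to [100 - 66, 100 - 55) = [34, 45)
assert flipped[24:36, 34:45].all()
assert flipped.sum() == mask.sum()  # the number of foreground pixels is preserved
```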

[before / after images of the transformed image and segmentation mask]

Classification

Detection

Segmentation

Other

cc @vfdev-5 @datumbox
