feat: add functional pad on segmentation mask #5866

Merged
merged 6 commits into from
May 2, 2022

Conversation

federicopozzi33
Contributor

@federicopozzi33 federicopozzi33 commented Apr 22, 2022

Part of #5782

Results on synthetic images/bboxes/segm mask:

Code
import numpy as np

import torch
import torchvision
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import pad_bounding_box, pad_segmentation_mask, pad_image_tensor

size = (64, 76)
# xyxy format
in_boxes = [
    [10, 15, 25, 35],
    [50, 5, 70, 22],
    [45, 46, 56, 62],
]
labels = [1, 2, 3]

im1 = 255 * np.ones(size + (3, ), dtype=np.uint8)
mask = np.zeros(size, dtype=np.int64)
for in_box, label in zip(in_boxes, labels):
    im1[in_box[1]:in_box[3], in_box[0]:in_box[2], :] = (127, 127, 127)
    mask[in_box[1]:in_box[3], in_box[0]:in_box[2]] = label
    
t_im1 = torch.tensor(im1).permute(2, 0, 1).view(1, 3, *size)

in_boxes = features.BoundingBox(
    in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size
)
in_mask = features.SegmentationMask(torch.tensor(mask)).view(1, *size)
    
padding = [5, 5, 5, 5]

out_boxes = pad_bounding_box(in_boxes, padding, in_boxes.format)
print(out_boxes)

out_mask = pad_segmentation_mask(in_mask, padding)

t_im2 = pad_image_tensor(t_im1, padding)


import cv2
import matplotlib.pyplot as plt
%matplotlib inline


plt.figure(figsize=(14, 10))

plt.subplot(2,3,1)
plt.title("Input image + bboxes")
r1 = t_im1[0, ...].permute(1, 2, 0).contiguous().cpu().numpy()
for in_box in in_boxes:    
    r1 = cv2.rectangle(r1, (in_box[0].item(), in_box[1].item()), (in_box[2].item(), in_box[3].item()), (255, 127, 0))
plt.imshow(r1)


plt.subplot(2,3,2)
plt.title("Input segm mask")
plt.imshow(in_mask[0, :, :].cpu().numpy())


plt.subplot(2,3,3)
plt.title("Input image + bboxes + segm mask")
plt.imshow(r1, alpha=0.5)
plt.imshow(in_mask[0, :, :].cpu().numpy(), alpha=0.75)


plt.subplot(2,3,4)
plt.title("Output image + bboxes")
r2 = t_im2[0, ...].permute(1, 2, 0).contiguous().cpu().numpy()
for out_box in out_boxes:
    out_box = np.round(out_box.cpu().numpy()).astype("int32")
    r2 = cv2.rectangle(r2, (out_box[0], out_box[1]), (out_box[2], out_box[3]), (255, 127, 0), 0)
plt.imshow(r2)


plt.subplot(2,3,5)
plt.title("Output segm mask")
plt.imshow(out_mask[0, :, :].cpu().numpy())

plt.subplot(2,3,6)
plt.title("Output image + bboxes + segm mask")
plt.imshow(r2, alpha=0.5)
plt.imshow(out_mask[0, :, :].cpu().numpy(), alpha=0.75)

[Image: pad_viz]
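As a rough cross-check of the box coordinates printed by the script above: padding moves the image origin, so an XYXY box should be shifted by the left and top padding. The helper below is a plain-Python sketch, not the torchvision kernel; `shift_xyxy_box` is a hypothetical name, and the `[left, top, right, bottom]` padding order is an assumption based on the usual torchvision convention.

```python
from typing import List, Sequence


def shift_xyxy_box(box: Sequence[int], padding: Sequence[int]) -> List[int]:
    """Shift an XYXY box by the left/top padding (hypothetical helper)."""
    # Assumed padding order: [left, top, right, bottom].
    left, top, _right, _bottom = padding
    x1, y1, x2, y2 = box
    # Right/bottom padding only enlarges the canvas; it does not move the box.
    return [x1 + left, y1 + top, x2 + left, y2 + top]


in_boxes = [[10, 15, 25, 35], [50, 5, 70, 22], [45, 46, 56, 62]]
out_boxes = [shift_xyxy_box(box, [5, 5, 5, 5]) for box in in_boxes]
print(out_boxes)
```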

@federicopozzi33
Contributor Author

federicopozzi33 commented Apr 22, 2022

Hi @vfdev-5,

I have some questions about this PR.

  1. Delving a little bit into the pad function implementation, it seems that typing is not correct.

def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:

In particular, padding and fill should be Union[int, List[int], Tuple[int]] and Union[int, float] respectively.

  2. Implementing the pad_segmentation_mask, I asked myself if I should expose fill and padding_mode parameters. In particular, segmentation masks shouldn't be binary? If so, should I check the fill value to be in [0, 1]?

@federicopozzi33 federicopozzi33 marked this pull request as draft April 23, 2022 09:35
@vfdev-5
Collaborator

vfdev-5 commented Apr 23, 2022

Delving a little bit into the pad function implementation, it seems that typing is not correct.

@federicopozzi33 this may be due to a TorchScript limitation: torch.jit.script only started supporting Union after pad was written for tensors.
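To make the typing question concrete, here is a minimal plain-Python sketch of how a wider `Union[int, Sequence[int]]` argument could be normalized to a fixed-length list before it reaches a kernel that only accepts `List[int]`. `normalize_padding` is a hypothetical helper, not a torchvision function, and the meaning of 1, 2 and 4 values is assumed to follow the documented torchvision padding convention.

```python
from typing import List, Sequence, Union


def normalize_padding(padding: Union[int, Sequence[int]]) -> List[int]:
    """Expand padding to [left, top, right, bottom] (hypothetical helper)."""
    if isinstance(padding, int):
        # A single int pads all four borders equally.
        return [padding] * 4
    values = list(padding)
    if len(values) == 1:
        return values * 4
    if len(values) == 2:
        # Two values: left/right first, then top/bottom.
        left_right, top_bottom = values
        return [left_right, top_bottom, left_right, top_bottom]
    if len(values) == 4:
        return values
    raise ValueError(f"padding must have 1, 2 or 4 elements, got {len(values)}")


print(normalize_padding(5))
print(normalize_padding([1, 2]))
```

Normalizing once at the boundary keeps the inner function's signature narrow, which is one way to stay within what the scripting compiler accepts.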

Implementing the pad_segmentation_mask, I asked myself if I should expose fill and padding_mode parameters. In particular, segmentation masks shouldn't be binary? If so, should I check the fill value to be in [0, 1]?

As for fill, this can be a feature request later if users would like to have a background class other than 0. For existing ops with fill, I think we do not expose it for masks.

def affine_segmentation_mask(
    img: torch.Tensor,
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:

As for padding_mode, I think it should be supported. For example, modes like "replicate" recreate usable image data on images, and it would make sense to do the same on masks.
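To illustrate what the non-constant modes would mean for class labels, the sketch below pads a single row of a mask in plain Python. It is only an illustration of the usual definitions of these modes, not the torchvision or torch kernel; `pad_row` is a hypothetical name, and the mode names are assumed to match the ones used in the tests discussed in this PR.

```python
from typing import List


def pad_row(row: List[int], left: int, right: int, mode: str = "constant") -> List[int]:
    """Pad a non-empty 1-D row of class labels (illustrative only)."""
    n = len(row)
    if mode == "constant":
        # Fill with the background class 0.
        before, after = [0] * left, [0] * right
    elif mode == "edge":
        # Repeat the border label.
        before, after = [row[0]] * left, [row[-1]] * right
    elif mode == "reflect":
        # Mirror around the border label without repeating it.
        if max(left, right) >= n:
            raise ValueError("reflect padding must be smaller than the row length")
        before = [row[i] for i in range(left, 0, -1)]
        after = [row[n - 1 - i] for i in range(1, right + 1)]
    elif mode == "symmetric":
        # Mirror around the border, repeating the border label.
        if max(left, right) > n:
            raise ValueError("symmetric padding must not exceed the row length")
        before = [row[i] for i in range(left - 1, -1, -1)]
        after = [row[n - 1 - i] for i in range(right)]
    else:
        raise ValueError(f"unknown padding mode: {mode}")
    return before + row + after


row = [1, 2, 3]  # three class labels along one mask row
for mode in ("constant", "edge", "reflect", "symmetric"):
    print(mode, pad_row(row, 2, 2, mode))
```

Every non-constant mode only copies labels that already exist in the row, so no new class values are introduced by the padding.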

segmentation masks shouldn't be binary?

I'd say no as it could have in general N classes

@federicopozzi33
Contributor Author

federicopozzi33 commented Apr 24, 2022

segmentation masks shouldn't be binary?

I'd say no as it could have in general N classes

Ok, I was just thinking of one mask per class, in which 0 is the "background" class, and 1 represents class pixels.
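The two readings are compatible: an N-class label mask can be split into one binary mask per class. A minimal plain-Python sketch, using nested lists instead of tensors; `to_binary_masks` is a hypothetical helper and class 0 is assumed to be the background.

```python
from typing import Dict, List

Mask = List[List[int]]


def to_binary_masks(label_mask: Mask) -> Dict[int, Mask]:
    """Split an N-class label mask into one 0/1 mask per non-background class."""
    # Class 0 is treated as background and gets no mask of its own.
    classes = sorted({value for row in label_mask for value in row if value != 0})
    return {
        cls: [[1 if value == cls else 0 for value in row] for row in label_mask]
        for cls in classes
    }


label_mask = [
    [0, 1, 1],
    [2, 2, 0],
]
binary_masks = to_binary_masks(label_mask)
print(binary_masks)
```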

Anyway, I think you can start looking at the PR.

Meanwhile, I'll try to figure out how to test padding modes other than constant, which does not seem so straightforward. It sounds like I have to re-implement the pad operation to test the pad function... which is at least strange/incorrect (that's because I'm using random segmentation masks).

@federicopozzi33 federicopozzi33 changed the title WIP: add functional pad on segmentation mask feat: add functional pad on segmentation mask Apr 24, 2022
@federicopozzi33 federicopozzi33 marked this pull request as ready for review April 24, 2022 22:24
@federicopozzi33 federicopozzi33 force-pushed the feat/5782-proto-mask-pad branch from 6f3f37e to 8b42851 Compare April 24, 2022 22:26
@datumbox
Contributor

FYI, pad is busted and took TorchVision's CI down with it (see #5873). So any failures you see on the CI are not necessarily related to this PR.

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks for the PR @federicopozzi33
I left a few questions. Let me think about how to test this op for all the other params.

Comment on lines 362 to 384
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
    continue
if (
    padding_mode == "edge"
    and len(padding) == 2
    and mask.ndim not in [2, 3]
    or len(padding) == 4
    and mask.ndim not in [4, 3]
    or len(padding) == 1
):
    continue
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
    continue
Collaborator

@federicopozzi33 can you comment on why we skip so many test cases?

I printed locally which cases it skips and I'm not quite sure why it is so. For example:

print(padding_mode, padding, mask.ndim, mask.shape)
>
constant [1] 3 torch.Size([1, 16, 16])

Contributor Author

@federicopozzi33 federicopozzi33 Apr 25, 2022

I fixed it. I was skipping too many tests, you're right.

The cases to be skipped are the following.

First case:

if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]:
    continue

Reason:
RuntimeError: Symmetric padding of N-D tensors are not supported yet in torchvision.transforms.functional_tensor._pad_symmetric

Second case:

if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]:
    continue

Reason:
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now in torch.nn.functional.pad

So why am I also skipping 5D?

You have to keep in mind that len(padding) is always equal to 4 (since it's parsed before passing it to the torch function), and with that padding the only allowed dimensions are 3 and 4.

So why am I not skipping 2D?

That's because in torchvision.transforms.functional_tensor.pad, the mask is unsqueezed:

if img.ndim < 4:
    img = img.unsqueeze(dim=0)
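The two skip conditions above can be folded into a single predicate. This is a sketch of the rules exactly as stated in this comment, not the code that ended up in the test suite; `should_skip` is a hypothetical name.

```python
def should_skip(padding_mode: str, ndim: int) -> bool:
    """Return True for (mode, ndim) pairs the test skips (hypothetical helper)."""
    # Masks with ndim < 4 are unsqueezed once before padding, so a 2-D mask
    # reaches the backend as 3-D and is still accepted.
    supported_ndims = (2, 3, 4)
    if padding_mode in ("symmetric", "edge", "reflect"):
        return ndim not in supported_ndims
    # "constant" padding is not restricted by the rules quoted above.
    return False


for mode in ("constant", "edge", "reflect", "symmetric"):
    print(mode, [ndim for ndim in range(2, 7) if should_skip(mode, ndim)])
```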

Collaborator

mask.ndim not in [2, 3, 4]:

I think mask can't have ndim == 2 ? Minimally, it is (1, H, W). cc @pmeier

For N-D cases e.g. (N, M, K, L, 1, H, W) we could reshape it into 4D before padding (N * M * K * L, 1, H, W) and reshape back inside padding op ?
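The shape arithmetic behind this suggestion can be sketched in plain Python on shape tuples alone, with no tensors involved. `collapse_leading_dims` is a hypothetical helper; it uses starred unpacking, which is fine in ordinary Python but says nothing about what a scripted function would accept.

```python
import math
from typing import Tuple


def collapse_leading_dims(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
    """Map (..., C, H, W) to (prod(...), C, H, W) (hypothetical helper)."""
    if len(shape) < 3:
        raise ValueError("expected at least 3 dimensions (C, H, W)")
    *batch, channels, height, width = shape
    # math.prod of an empty list is 1, so a 3-D shape gains a batch dim of 1.
    return (math.prod(batch), channels, height, width)


print(collapse_leading_dims((2, 3, 4, 5, 1, 16, 16)))
```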

Contributor Author

@federicopozzi33 federicopozzi33 Apr 26, 2022

Do you mean something like this?

(The code can be further refactored).

def pad_segmentation_mask(
    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
    *extra_dims, b, h, w = segmentation_mask.shape

    # Collapse any extra leading dims so the mask is 4D before padding.
    mask_4d = segmentation_mask.reshape(-1, b, h, w) if len(extra_dims) > 0 else segmentation_mask
    padded = pad_image_tensor(img=mask_4d, padding=padding, fill=0, padding_mode=padding_mode)

    # Restore the leading dims using the padded height and width.
    *_, new_c, new_h, new_w = padded.shape
    return padded.reshape(*extra_dims, new_c, new_h, new_w) if len(extra_dims) > 0 else padded

Seems that .shape operation is not supported by torch.jit...

cannot statically infer the expected size of a list in this context:
  File "/Users/fpozzi/Documents/OpenSource/my_repos/vision/torchvision/prototype/transforms/functional/_geometry.py", line 402
    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
    *extra_dims, b, h, w = segmentation_mask.shape 
                           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
    mask_4d = segmentation_mask.reshape(-1, b, h, w) if len(extra_dims) > 0 else segmentation_mask

I also tried with .size() but I got the same exception.

Any suggestions?

Collaborator

.shape is supported by torch script but *extra_dims is not.
You can do something like:

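import torch

mask = torch.rand(5, 4, 1, 32, 32)

def foo(mask):
    # Index the trailing dims one by one instead of using starred unpacking.
    c, h, w = mask.shape[-3], mask.shape[-2], mask.shape[-1]
    m = mask.view(-1, c, h, w)  # collapse the leading dims into one batch dim
    return m.view(mask.shape)  # restore the original shape

torch.jit.script(foo)(mask).shape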

Collaborator

@pmeier pmeier Apr 27, 2022

I think mask can't have ndim == 2 ? Minimally, it is (1, H, W).

Masks have gotten very little love so far, so there is no final decision. But I agree, boolean masks should have shape (N, H, W), where N is the number of segmentations, whether they are individual objects (instance segmentation) or a crowd (semantic segmentation).

For colored masks (not sure if there is a better name for them) the shape should be (3, H, W) and probably uint8 as dtype. But we should probably have a different feature for them so we don't confuse the two. Let's stick to boolean masks for now.

Do you mean something like this?

Have a look how we are doing it for the bounding boxes:

def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape
    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)
    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
    return convert_bounding_box_format(
        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
    ).view(shape)

The only difference is that the "data dimensions" are not fixed. For masks this would look like

@torch.jit.script
def pad_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
    shape = segmentation_mask.shape

    num_masks, height, width = shape[-3:]
    segmentation_mask = segmentation_mask.view(-1, num_masks, height, width)
    
    # kernel
    
    return segmentation_mask.view(shape)

Contributor Author

@federicopozzi33 federicopozzi33 Apr 27, 2022

The only difference is that the "data dimensions" are not fixed. For masks this would look like

@torch.jit.script
def pad_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
    shape = segmentation_mask.shape

    num_masks, height, width = shape[-3:]
    segmentation_mask = segmentation_mask.view(-1, num_masks, height, width)
    
    # kernel
    
    return segmentation_mask.view(shape)

Yeah, the point is that the mask shape before padding is different from mask shape after padding.
So I can't simply do:

return segmentation_mask.view(shape)

Mask can have any num_dims > 2. So I need to collect dimensions before padding: in particular, I need the extra dimensions, num_masks, height and width.

# Let's suppose that segmentation_mask has shape (3, 4, 1, 5, 5).
*extra_dims, num_masks, h, w = segmentation_mask.shape
print(extra_dims, num_masks, h, w)
> [3, 4] 1 5 5

num_masks, height and width are used to reshape the mask into a 4D tensor before padding.

segmentation_mask.view(-1, num_masks, height, width)

extra_dims is needed to reshape the mask back: however, I haven't figured out yet how to use the .view method without the unpacking operator.

# with unpacking operator
return padded_mask.view(*extra_dims, num_masks, height, width)

Collaborator

Yeah, the point is that the mask shape before padding is different from mask shape after padding.

🤦 My bad, sorry. resize_image_tensor is a better reference:

@torch.jit.script
def pad_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
    num_masks, old_height, old_width = segmentation_mask.shape[-3:]
    batch_shape = segmentation_mask.shape[:-3]
    segmentation_mask = segmentation_mask.view((-1, num_masks, old_height, old_width))

    # kernel
    new_height, new_width = old_height, old_width

    return segmentation_mask.view(batch_shape + (num_masks, new_height, new_width))

Contributor Author

@federicopozzi33 federicopozzi33 Apr 28, 2022

Yeah, the point is that the mask shape before padding is different from mask shape after padding.

🤦 My bad, sorry. resize_image_tensor is a better reference:

@torch.jit.script
def pad_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
    num_masks, old_height, old_width = segmentation_mask.shape[-3:]
    batch_shape = segmentation_mask.shape[:-3]
    segmentation_mask = segmentation_mask.view((-1, num_masks, old_height, old_width))

    # kernel
    new_height, new_width = old_height, old_width

    return segmentation_mask.view(batch_shape + (num_masks, new_height, new_width))

Thank you!

I was missing some things:
First one:

size = torch.ones(3, 3).shape[:]
> torch.Size([3, 3])

size = np.ones((3, 3)).shape[:]
> (3, 3)

Second one:

torch.Size([1, 2]) + (3, 4)
> torch.Size([1, 2, 3, 4])
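The same two ideas, slicing a shape and concatenating it with a tuple, are enough to predict the output shape of the padded mask. The sketch below works on plain Python tuples rather than torch.Size; `padded_mask_shape` is a hypothetical helper and the `[left, top, right, bottom]` padding order is an assumption.

```python
from typing import Sequence, Tuple


def padded_mask_shape(shape: Tuple[int, ...], padding: Sequence[int]) -> Tuple[int, ...]:
    """Shape of a (..., N, H, W) mask after padding (hypothetical helper)."""
    # Assumed padding order: [left, top, right, bottom].
    left, top, right, bottom = padding
    batch_shape = shape[:-3]  # slicing keeps the tuple type
    num_masks, height, width = shape[-3:]
    new_height = height + top + bottom
    new_width = width + left + right
    # Tuple concatenation rebuilds the full shape without the * operator.
    return batch_shape + (num_masks, new_height, new_width)


print(padded_mask_shape((3, 4, 1, 5, 5), [1, 2, 3, 4]))
```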

test/test_prototype_transforms_functional.py (outdated, resolved)
@federicopozzi33 federicopozzi33 force-pushed the feat/5782-proto-mask-pad branch from 8b42851 to b425177 Compare April 25, 2022 17:50
@federicopozzi33 federicopozzi33 requested a review from vfdev-5 April 25, 2022 17:52
Comment on lines +1089 to +1087
out_mask = F.pad_segmentation_mask(mask, padding, "constant")

expected_mask = _compute_expected_mask()
torch.testing.assert_close(out_mask, expected_mask)
Collaborator

@vfdev-5 vfdev-5 Apr 26, 2022

Thinking of how to test all other padding modes. Maybe we could just check output value from out_mask for two lines: one horizontal and one vertical instead of constructing full expected mask.
While checking the lines we still need to route the checks according to the padding mode.
@federicopozzi33 What do you think ?

Contributor Author

@federicopozzi33 federicopozzi33 Apr 27, 2022

Seems pretty reasonable.

Anyway, the code needed for constructing the two lines IMO still "mimics" (or re-implements) the padding operation (which is what I wanted to avoid, but it seems that there are no other options).

Collaborator

Thinking of how to test all other padding modes.

I don't think we need a correctness check for them. Internally, torch.nn.functional.pad does the heavy lifting. Thus, if we rely on that giving us the correct behavior there is no need to check if the values of the padding are correct.

Collaborator

@vfdev-5 vfdev-5 May 2, 2022

@pmeier if we do not check the correctness of the output, another option is to mock torch_pad and ensure that it is called with the correct configuration; otherwise the code is not covered. I'm talking about all the other untested padding options.
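A mock-based check along these lines could look roughly like the sketch below. Everything in it is hypothetical: `pad_mask` stands in for the op under test and takes the backend as an argument so it can be replaced with a `unittest.mock.Mock`, whereas a real test would more likely patch the function torchvision calls internally. The `[left, right, top, bottom]` order passed to the backend is the one torch.nn.functional.pad uses for the last two dimensions.

```python
from typing import Callable, Sequence
from unittest import mock


def pad_mask(mask, padding: Sequence[int], padding_mode: str, backend: Callable):
    """Hypothetical wrapper: reorder the padding and delegate to `backend`."""
    left, top, right, bottom = padding
    # The backend is assumed to expect [left, right, top, bottom].
    return backend(mask, [left, right, top, bottom], mode=padding_mode)


# Replace the backend with a mock and verify the configuration it receives.
fake_backend = mock.Mock(return_value="padded")
result = pad_mask("mask", [1, 2, 3, 4], "reflect", backend=fake_backend)
fake_backend.assert_called_once_with("mask", [1, 3, 2, 4], mode="reflect")
print(result)
```

This covers the argument-handling code for every padding mode without re-implementing the padding itself.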

@federicopozzi33 federicopozzi33 force-pushed the feat/5782-proto-mask-pad branch from 33ec8c1 to 5b0d597 Compare April 28, 2022 18:35
Collaborator

@pmeier pmeier left a comment

I don't see anything that I would object to. I'll leave it to @vfdev-5 to approve, since he shepherded the PR so far.

Collaborator

@vfdev-5 vfdev-5 left a comment

@federicopozzi33 thanks for your patience and your work on this PR! I suggest we merge this PR now and send a follow-up with improved tests (once we have decided what to do and how).

@vfdev-5 vfdev-5 merged commit 104073c into pytorch:main May 2, 2022
@github-actions

github-actions bot commented May 2, 2022

Hey @vfdev-5!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@vfdev-5
Collaborator

vfdev-5 commented May 5, 2022

@federicopozzi33 can you please update the description with a few visualizations of how padding works on images, bboxes and segm masks? You can take inspiration from #5613 or other similar PRs. Thanks!

facebook-github-bot pushed a commit that referenced this pull request May 6, 2022
Summary:
* feat: add functional pad on segmentation mask

* test: add basic correctness test with random masks

* test: add all padding options

* fix: pr comments

* fix: tests

* refactor: reshape tensor in 4d, then pad

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095691

fbshipit-source-id: 1e31988216fea1664c1fd48ee39598d28bac8308

Co-authored-by: Federico Pozzi <federico.pozzi@argo.vision>
@federicopozzi33
Contributor Author

@federicopozzi33 can you please update the description with a few visualizations of how padding works on images, bboxes and segm masks? You can take inspiration from #5613 or other similar PRs. Thanks!

Done.
