[proto] Added functional perspective_bounding_box/segmentation_mask ops #5888

Merged · 6 commits · May 9, 2022
Changes from 4 commits
196 changes: 192 additions & 4 deletions test/test_prototype_transforms_functional.py
@@ -11,8 +11,10 @@
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.functional_tensor import _max_value as get_max_value


make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


@@ -380,6 +382,37 @@ def pad_segmentation_mask():
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)


@register_kernel_info_from_sample_inputs_fn
def perspective_bounding_box():
for bounding_box, perspective_coeffs in itertools.product(
make_bounding_boxes(),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
],
):
yield SampleInput(
bounding_box,
format=bounding_box.format,
perspective_coeffs=perspective_coeffs,
)


@register_kernel_info_from_sample_inputs_fn
def perspective_segmentation_mask():
for mask, perspective_coeffs in itertools.product(
make_segmentation_masks(extra_dims=((), (4,))),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
],
):
yield SampleInput(
mask,
perspective_coeffs=perspective_coeffs,
)


@pytest.mark.parametrize(
"kernel",
[
@@ -985,7 +1018,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
],
)
def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
def _compute_expected(bbox, top_, left_, height_, width_, size_):
def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
# bbox should be xyxy
bbox[0] = (bbox[0] - left_) * size_[1] / width_
bbox[1] = (bbox[1] - top_) * size_[0] / height_
@@ -1001,7 +1034,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
]
expected_bboxes = []
for in_box in in_boxes:
expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device)

in_boxes = features.BoundingBox(
@@ -1027,7 +1060,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
],
)
def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
def _compute_expected(mask, top_, left_, height_, width_, size_):
def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
output = mask.clone()
output = output[:, top_ : top_ + height_, left_ : left_ + width_]
output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
@@ -1038,7 +1071,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
in_mask[0, 10:20, 10:20] = 1
in_mask[0, 5:15, 12:23] = 2

expected_mask = _compute_expected(in_mask, top, left, height, width, size)
expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
torch.testing.assert_close(output_mask, expected_mask)

@@ -1085,3 +1118,158 @@ def parse_padding():

expected_mask = _compute_expected_mask()
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"startpoints, endpoints",
[
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
],
)
def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
def _compute_expected_bbox(bbox, pcoeffs_):
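# m1 holds the two affine (numerator) rows of the homography and m2 repeats
# the projective (denominator) row, so numer / denom below yields (x', y')
# for all four corners in one broadcasted division.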
m1 = np.array(
[
[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
[pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
]
)
m2 = np.array(
[
[pcoeffs_[6], pcoeffs_[7], 1.0],
[pcoeffs_[6], pcoeffs_[7], 1.0],
]
)

bbox_xyxy = convert_bounding_box_format(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
]
)
numer = np.matmul(points, m1.T)
denom = np.matmul(points, m2.T)
transformed_points = numer / denom
out_bbox = [
np.min(transformed_points[:, 0]),
np.min(transformed_points[:, 1]),
np.max(transformed_points[:, 0]),
np.max(transformed_points[:, 1]),
]
out_bbox = features.BoundingBox(
out_bbox,
format=features.BoundingBoxFormat.XYXY,
image_size=bbox.image_size,
dtype=torch.float32,
device=bbox.device,
)
return convert_bounding_box_format(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
)

image_size = (32, 38)

pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

for bboxes in make_bounding_boxes(
image_sizes=[
image_size,
],
extra_dims=((4,),),
):
bboxes = bboxes.to(device)
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size

output_bboxes = F.perspective_bounding_box(
bboxes,
bboxes_format,
perspective_coeffs=pcoeffs,
)

if bboxes.ndim < 2:
bboxes = [bboxes]

expected_bboxes = []
for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)
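
As a sanity check on the coefficient direction this test relies on (a minimal sketch, assuming the behaviour of torchvision's private helper `_get_perspective_coeffs`; the point lists are copied from the parametrization above, and the loose tolerance allows for its float32 least-squares fit):

import numpy as np

from torchvision.transforms.functional import _get_perspective_coeffs

startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

# _get_perspective_coeffs(startpoints, endpoints) fits coefficients that map an
# *end point* (x, y) back to its *start point* (x_out, y_out), which is why the
# expected boxes above are computed with inv_pcoeffs rather than pcoeffs.
c = _get_perspective_coeffs(startpoints, endpoints)
for (sx, sy), (ex, ey) in zip(startpoints, endpoints):
    denom = c[6] * ex + c[7] * ey + 1.0
    x = (c[0] * ex + c[1] * ey + c[2]) / denom
    y = (c[3] * ex + c[4] * ey + c[5]) / denom
    np.testing.assert_allclose((x, y), (sx, sy), atol=1e-2)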


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"startpoints, endpoints",
[
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
],
)
def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
def _compute_expected_mask(mask, pcoeffs_):
assert mask.ndim == 3 and mask.shape[0] == 1
m1 = np.array(
[
[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
[pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
]
)
m2 = np.array(
[
[pcoeffs_[6], pcoeffs_[7], 1.0],
[pcoeffs_[6], pcoeffs_[7], 1.0],
]
)

expected_mask = torch.zeros_like(mask.cpu())
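# Brute-force reference: pcoeffs_ maps each output pixel back to its source
# location, so walk the output grid, project each pixel center, and copy the
# nearest input pixel whenever it lands inside the mask.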
for out_y in range(expected_mask.shape[1]):
for out_x in range(expected_mask.shape[2]):
output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])

numer = np.matmul(output_pt, m1.T)
denom = np.matmul(output_pt, m2.T)
input_pt = np.floor(numer / denom).astype(np.int32)

in_x, in_y = input_pt[:2]
if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
return expected_mask.to(mask.device)

pcoeffs = _get_perspective_coeffs(startpoints, endpoints)

for mask in make_segmentation_masks(extra_dims=((), (4,))):
mask = mask.to(device)

output_mask = F.perspective_segmentation_mask(
mask,
perspective_coeffs=pcoeffs,
)

if mask.ndim < 4:
masks = [mask]
else:
masks = [m for m in mask]

expected_masks = []
for mask in masks:
expected_mask = _compute_expected_mask(mask, pcoeffs)
expected_masks.append(expected_mask)
if len(expected_masks) > 1:
expected_masks = torch.stack(expected_masks)
else:
expected_masks = expected_masks[0]
torch.testing.assert_close(output_mask, expected_masks)
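
Relatedly, the analytic inversion that `perspective_bounding_box` performs in `_geometry.py` below should agree with simply refitting the coefficients with the point lists swapped, which is how `inv_pcoeffs` is built in the bounding-box test above. A hedged sketch under the same assumptions as before:

import numpy as np

from torchvision.transforms.functional import _get_perspective_coeffs

startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

c = _get_perspective_coeffs(startpoints, endpoints)
# Adjugate-based inverse, normalized so the bottom-right entry is 1
# (mirrors the inv_coeffs computation in perspective_bounding_box).
denom = c[0] * c[4] - c[1] * c[3]
inv_c = [
    (c[4] - c[5] * c[7]) / denom,
    (-c[1] + c[2] * c[7]) / denom,
    (c[1] * c[5] - c[2] * c[4]) / denom,
    (-c[3] + c[5] * c[6]) / denom,
    (c[0] - c[2] * c[6]) / denom,
    (-c[0] * c[5] + c[2] * c[3]) / denom,
    (-c[4] * c[6] + c[3] * c[7]) / denom,
    (-c[0] * c[7] + c[1] * c[6]) / denom,
]
np.testing.assert_allclose(inv_c, _get_perspective_coeffs(endpoints, startpoints), atol=1e-2)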
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -67,8 +67,10 @@
crop_image_tensor,
crop_image_pil,
crop_segmentation_mask,
perspective_bounding_box,
perspective_image_tensor,
perspective_image_pil,
perspective_segmentation_mask,
vertical_flip_image_tensor,
vertical_flip_image_pil,
vertical_flip_bounding_box,
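
For context, a minimal usage sketch of the two kernels exported here (hedged: the import paths and `features` API follow the prototype namespace as of this PR, and all values are illustrative):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

coeffs = [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018]

boxes = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0], [5.0, 12.0, 23.0, 15.0]],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 38),
)
out_boxes = F.perspective_bounding_box(boxes, boxes.format, perspective_coeffs=coeffs)

mask = torch.zeros(1, 32, 38, dtype=torch.long)
mask[0, 10:20, 10:20] = 1  # one square instance
out_mask = F.perspective_segmentation_mask(mask, perspective_coeffs=coeffs)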
89 changes: 89 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -472,6 +472,95 @@ def perspective_image_pil(
return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


def perspective_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
perspective_coeffs: List[float],
) -> torch.Tensor:

if len(perspective_coeffs) != 8:
raise ValueError("Argument perspective_coeffs should have 8 float values")

original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device

# perspective_coeffs are fitted as endpoint -> start point, so for bboxes we
# have to invert them. With (x, y) an end point and (x_out, y_out) the
# corresponding start point:
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
# and we would like to get:
# x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# and compute inv_coeffs in terms of coeffs

denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
if denom == 0:
raise RuntimeError(
f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
f"Denominator is zero, denom={denom}"
)

inv_coeffs = [
(perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
(perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
(-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
(perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
(-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]
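# inv_coeffs above is the adjugate of the homography matrix
# [[c0, c1, c2], [c3, c4, c5], [c6, c7, 1]] scaled so its bottom-right entry
# is 1 (denom is exactly that entry); the global projective scale cancels in
# the numer / denom division below.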

theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
device=device,
)

theta2 = torch.tensor(
[[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
)

# 1) Let's transform the bboxes into a tensor of 4 corner points
# (top-left, top-right, bottom-right, bottom-left). The tensor of points has
# shape (N * 4, 3), where N is the number of bboxes. The four points of each
# box are laid out as
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using perspective matrices
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

numer_points = torch.matmul(points, theta1.T)
denom_points = torch.matmul(points, theta2.T)
transformed_points = numer_points / denom_points

# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

# out_bboxes should be of shape [N boxes, 4]

return convert_bounding_box_format(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)


def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
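# Nearest interpolation so integer category ids are copied, never blended.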
return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST)


def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]