[fbsync] improve affine bounding box reference helper (#7884)
Summary: (Note: this ignores all push blocking failures!)

Reviewed By: matteobettini

Differential Revision: D48900391

fbshipit-source-id: 1d28073ec3059ff6815e9d7cdf5c3ea054acd71f
NicolasHug authored and facebook-github-bot committed Sep 6, 2023
1 parent afecea2 commit 5f97a97
Showing 2 changed files with 104 additions and 78 deletions.
178 changes: 102 additions & 76 deletions test/test_transforms_v2_refactored.py
@@ -351,45 +351,62 @@ def assert_warns_antialias_default_value():
     yield


-def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix):
-    def transform(bbox):
+def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
+    format = bounding_boxes.format
+    canvas_size = new_canvas_size or bounding_boxes.canvas_size
+
+    def affine_bounding_boxes(bounding_boxes):
+        dtype = bounding_boxes.dtype
+
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
-        in_dtype = bbox.dtype
-        if not torch.is_floating_point(bbox):
-            bbox = bbox.float()
-        bbox_xyxy = F.convert_bounding_box_format(
-            bbox.as_subclass(torch.Tensor),
+        input_xyxy = F.convert_bounding_box_format(
+            bounding_boxes.to(torch.float64, copy=True),
             old_format=format,
             new_format=datapoints.BoundingBoxFormat.XYXY,
             inplace=True,
         )
+        x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()
+
         points = np.array(
             [
-                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+                [x1, y1, 1.0],
+                [x2, y1, 1.0],
+                [x1, y2, 1.0],
+                [x2, y2, 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix.T)
-        out_bbox = torch.tensor(
+        transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
+
+        output_xyxy = torch.Tensor(
             [
-                np.min(transformed_points[:, 0]).item(),
-                np.min(transformed_points[:, 1]).item(),
-                np.max(transformed_points[:, 0]).item(),
-                np.max(transformed_points[:, 1]).item(),
-            ],
-            dtype=bbox_xyxy.dtype,
+                float(np.min(transformed_points[:, 0])),
+                float(np.min(transformed_points[:, 1])),
+                float(np.max(transformed_points[:, 0])),
+                float(np.max(transformed_points[:, 1])),
+            ]
         )
-        out_bbox = F.convert_bounding_box_format(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+
+        output = F.convert_bounding_box_format(
+            output_xyxy, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format
         )
-        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_boxes(out_bbox, format=format, canvas_size=canvas_size)
-        out_bbox = out_bbox.to(dtype=in_dtype)
-        return out_bbox

-    return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)
+        if clamp:
+            # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
+            output = F.clamp_bounding_boxes(
+                output,
+                format=format,
+                canvas_size=canvas_size,
+            ).to(dtype)
+
+        return output
+
+    return datapoints.BoundingBoxes(
+        torch.cat([affine_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
+            bounding_boxes.shape
+        ),
+        format=format,
+        canvas_size=canvas_size,
+    )
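
Note: the helper now derives `format` and `canvas_size` from the input boxes, so call sites pass only the 2x3 affine matrix, with `new_canvas_size` and `clamp` as optional overrides. As a minimal, hypothetical usage sketch (not code from this commit; `make_bounding_boxes` is the fixture used elsewhere in this test file):

import numpy as np

boxes = make_bounding_boxes(format=datapoints.BoundingBoxFormat.XYXY)
identity = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # 2x3 no-op affine matrix
out = reference_affine_bounding_boxes_helper(boxes, affine_matrix=identity)
assert out.format is boxes.format and out.canvas_size == boxes.canvas_size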


class TestResize:
@@ -580,16 +597,13 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
                 [new_width / old_width, 0, 0],
                 [0, new_height / old_height, 0],
             ],
-            dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
         )

-        expected_bboxes = reference_affine_bounding_boxes_helper(
+        return reference_affine_bounding_boxes_helper(
             bounding_boxes,
-            format=bounding_boxes.format,
-            canvas_size=(new_height, new_width),
             affine_matrix=affine_matrix,
+            new_canvas_size=(new_height, new_width),
         )
-        return datapoints.wrap(expected_bboxes, like=bounding_boxes, canvas_size=(new_height, new_width))

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
@@ -884,17 +898,9 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
                 [-1, 0, bounding_boxes.canvas_size[1]],
                 [0, 1, 0],
             ],
-            dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
         )

-        expected_bboxes = reference_affine_bounding_boxes_helper(
-            bounding_boxes,
-            format=bounding_boxes.format,
-            canvas_size=bounding_boxes.canvas_size,
-            affine_matrix=affine_matrix,
-        )
-
-        return datapoints.wrap(expected_bboxes, like=bounding_boxes)
+        return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize(
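
Note: the flip matrix maps x to W - x and leaves y alone, so a box's left and right edges swap; the helper's min/max over the four transformed corners restores a well-formed XYXY box. A quick standalone check (illustrative numbers, not part of the diff):

import numpy as np

W = 640  # canvas width
flip = np.array([[-1, 0, W], [0, 1, 0]], dtype=np.float64)
corners = np.array([[10, 20, 1], [30, 20, 1], [10, 40, 1], [30, 40, 1]], dtype=np.float64)
flipped = corners @ flip.T
x1, y1 = flipped[:, 0].min(), flipped[:, 1].min()
x2, y2 = flipped[:, 0].max(), flipped[:, 1].max()
assert (x1, y1, x2, y2) == (610.0, 20.0, 630.0, 40.0)  # x-range becomes [W - 30, W - 10]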
@@ -1129,26 +1135,19 @@ def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
         shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
         rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
         true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
-        return true_matrix
+        return true_matrix[:2, :]

     def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center):
         if center is None:
             center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]

-        affine_matrix = self._compute_affine_matrix(
-            angle=angle, translate=translate, scale=scale, shear=shear, center=center
-        )
-        affine_matrix = affine_matrix[:2, :]
-
-        expected_bboxes = reference_affine_bounding_boxes_helper(
+        return reference_affine_bounding_boxes_helper(
             bounding_boxes,
-            format=bounding_boxes.format,
-            canvas_size=bounding_boxes.canvas_size,
-            affine_matrix=affine_matrix,
+            affine_matrix=self._compute_affine_matrix(
+                angle=angle, translate=translate, scale=scale, shear=shear, center=center
+            ),
         )
-
-        return expected_bboxes

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
     @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
@@ -1347,17 +1346,9 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
                 [1, 0, 0],
                 [0, -1, bounding_boxes.canvas_size[0]],
             ],
-            dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
         )

-        expected_bboxes = reference_affine_bounding_boxes_helper(
-            bounding_boxes,
-            format=bounding_boxes.format,
-            canvas_size=bounding_boxes.canvas_size,
-            affine_matrix=affine_matrix,
-        )
-
-        return datapoints.wrap(expected_bboxes, like=bounding_boxes)
+        return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
@@ -1535,39 +1526,73 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
         mae = (actual.float() - expected.float()).abs().mean()
         assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6

-    def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
-        # FIXME
-        if expand:
-            raise ValueError("This reference currently does not support expand=True")
+    def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
+        if not expand:
+            return canvas_size, (0.0, 0.0)
+
+        input_height, input_width = canvas_size
+
+        input_image_frame = np.array(
+            [
+                [0.0, 0.0, 1.0],
+                [0.0, input_height, 1.0],
+                [input_width, input_height, 1.0],
+                [input_width, 0.0, 1.0],
+            ],
+            dtype=np.float64,
+        )
+        output_image_frame = np.matmul(input_image_frame, affine_matrix.astype(input_image_frame.dtype).T)
+
+        recenter_x = float(np.min(output_image_frame[:, 0]))
+        recenter_y = float(np.min(output_image_frame[:, 1]))
+
+        output_width = int(np.max(output_image_frame[:, 0]) - recenter_x)
+        output_height = int(np.max(output_image_frame[:, 1]) - recenter_y)
+
+        return (output_height, output_width), (recenter_x, recenter_y)
+
+    def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
+        x, y = recenter_xy
+        if bounding_boxes.format is datapoints.BoundingBoxFormat.XYXY:
+            translate = [x, y, x, y]
+        else:
+            translate = [x, y, 0.0, 0.0]
+        return datapoints.wrap(
+            (bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
+        )
+
+    def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
         if center is None:
             center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
+        cx, cy = center

         a = np.cos(angle * np.pi / 180.0)
         b = np.sin(angle * np.pi / 180.0)
-        cx = center[0]
-        cy = center[1]
         affine_matrix = np.array(
             [
                 [a, b, cx - cx * a - b * cy],
                 [-b, a, cy + cx * b - a * cy],
             ],
-            dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
         )

-        expected_bboxes = reference_affine_bounding_boxes_helper(
+        new_canvas_size, recenter_xy = self._compute_output_canvas_size(
+            expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
+        )
+
+        output = reference_affine_bounding_boxes_helper(
             bounding_boxes,
-            format=bounding_boxes.format,
-            canvas_size=bounding_boxes.canvas_size,
             affine_matrix=affine_matrix,
+            new_canvas_size=new_canvas_size,
+            clamp=False,
         )

-        return expected_bboxes
+        return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
+            bounding_boxes
+        )

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
-    # TODO: add support for expand=True in the reference
-    @pytest.mark.parametrize("expand", [False])
+    @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
         bounding_boxes = make_bounding_boxes(format=format)
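
Note: a worked instance of what `_compute_output_canvas_size` produces. Rotating a 480x640 canvas by 90 degrees about its center maps the image frame to a span whose top-left corner sits at (80, -80), so the expanded canvas becomes (640, 480) and the boxes are shifted back by that offset before clamping. The same arithmetic as a standalone sketch (illustrative numbers):

import numpy as np

cx, cy = 320.0, 240.0  # center of a (480, 640) canvas
a, b = np.cos(np.pi / 2), np.sin(np.pi / 2)  # angle = 90 degrees
affine_matrix = np.array([[a, b, cx - cx * a - b * cy], [-b, a, cy + cx * b - a * cy]])
frame = np.array([[0, 0, 1], [0, 480, 1], [640, 480, 1], [640, 0, 1]], dtype=np.float64)
out = frame @ affine_matrix.T
recenter = (out[:, 0].min(), out[:, 1].min())
new_size = (int(out[:, 1].max() - recenter[1]), int(out[:, 0].max() - recenter[0]))
assert np.allclose(recenter, (80.0, -80.0)) and new_size == (640, 480)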
@@ -1576,10 +1601,10 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)

         torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
-    # TODO: add support for expand=True in the reference
-    @pytest.mark.parametrize("expand", [False])
+    @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     @pytest.mark.parametrize("seed", list(range(5)))
     def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
@@ -1596,6 +1621,7 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)

         torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

     @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
     @pytest.mark.parametrize("seed", list(range(10)))
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -803,8 +803,8 @@ def _affine_bounding_boxes_with_expand(
         tr = torch.amin(new_points, dim=0, keepdim=True)
         # Translate bounding boxes
         out_bboxes.sub_(tr.repeat((1, 2)))
-        # Estimate meta-data for image with inverted=True and with center=[0,0]
-        affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
+        # Estimate meta-data for image with inverted=True
+        affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         canvas_size = (new_height, new_width)
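
Note: the library fix passes the true rotation `center` into `_get_inverse_affine_matrix` when estimating the expanded output size, instead of hard-coding `[0, 0]`. One way to see why the two are not interchangeable: rotating about the origin and rotating about the image center share a linear part but differ in the translation column, which plausibly explains why the expand=True tests above exposed the discrepancy. A toy demonstration (plain numpy, not torchvision's internals):

import numpy as np

def rotation_about(angle_deg, cx, cy):
    a, b = np.cos(np.radians(angle_deg)), np.sin(np.radians(angle_deg))
    return np.array([[a, b, cx - cx * a - b * cy], [-b, a, cy + cx * b - a * cy]])

m_origin = rotation_about(45, 0.0, 0.0)
m_center = rotation_about(45, 320.0, 240.0)
assert np.allclose(m_origin[:, :2], m_center[:, :2])  # identical linear part
assert not np.allclose(m_origin[:, 2], m_center[:, 2])  # different translation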
