Skip to content

Commit

Permalink
revert some changes on affine bounding box helper
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Aug 28, 2023
1 parent 48f7859 commit 99bbbb2
Showing 1 changed file with 81 additions and 42 deletions.
123 changes: 81 additions & 42 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,17 @@ def assert_warns_antialias_default_value():
yield


def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, expand=False):
def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
format = bounding_boxes.format
canvas_size = new_canvas_size or bounding_boxes.canvas_size

def affine_bounding_boxes(bounding_boxes):
format = bounding_boxes.format
height, width = bounding_boxes.canvas_size
dtype = bounding_boxes.dtype

# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
input_xyxy = F.convert_bounding_box_format(
bounding_boxes.to(torch.float64, copy=True),
old_format=format,
new_format=datapoints.BoundingBoxFormat.XYXY,
inplace=True,
)
Expand All @@ -359,49 +361,39 @@ def affine_bounding_boxes(bounding_boxes):
[x2, y1, 1.0],
[x1, y2, 1.0],
[x2, y2, 1.0],
# image frame
[0.0, 0.0, 1.0],
[0.0, height, 1.0],
[width, height, 1.0],
[width, 0.0, 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)

output_xyxy = [
float(np.min(transformed_points[:4, 0])),
float(np.min(transformed_points[:4, 1])),
float(np.max(transformed_points[:4, 0])),
float(np.max(transformed_points[:4, 1])),
]
if expand:
x_translate = float(np.min(transformed_points[4:, 0]))
y_translate = float(np.min(transformed_points[4:, 1]))

output_xyxy[0] -= x_translate
output_xyxy[1] -= y_translate
output_xyxy[2] -= x_translate
output_xyxy[3] -= y_translate

width = int(np.max(transformed_points[4:, 0]) - x_translate)
height = int(np.max(transformed_points[4:, 1]) - y_translate)
output_xyxy = datapoints.BoundingBoxes(
output_xyxy, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(height, width)
output_xyxy = torch.Tensor(
[
float(np.min(transformed_points[:, 0])),
float(np.min(transformed_points[:, 1])),
float(np.max(transformed_points[:, 0])),
float(np.max(transformed_points[:, 1])),
]
)

# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
return F.clamp_bounding_boxes(F.convert_bounding_box_format(output_xyxy, new_format=format)).to(dtype)

outputs = [
affine_bounding_boxes(datapoints.wrap(b, like=bounding_boxes)) for b in bounding_boxes.reshape(-1, 4).unbind()
]

canvas_sizes = [o.canvas_size for o in outputs]
canvas_size = canvas_sizes[0]
assert all(s == canvas_size for s in canvas_sizes[1:])
output = F.convert_bounding_box_format(
output_xyxy, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format
)

return datapoints.wrap(
torch.cat(outputs, dim=0).reshape(bounding_boxes.shape), like=bounding_boxes, canvas_size=canvas_size
if clamp:
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
output = F.clamp_bounding_boxes(
output,
format=format,
canvas_size=canvas_size,
).to(dtype)

return output

return datapoints.BoundingBoxes(
torch.cat([affine_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
bounding_boxes.shape
),
format=format,
canvas_size=canvas_size,
)


Expand Down Expand Up @@ -593,7 +585,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
return reference_affine_bounding_boxes_helper(
bounding_boxes,
affine_matrix=affine_matrix,
expand=True,
new_canvas_size=(new_height, new_width),
)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
Expand Down Expand Up @@ -1517,8 +1509,42 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
mae = (actual.float() - expected.float()).abs().mean()
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6

def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
if not expand:
return canvas_size, (0.0, 0.0)

input_height, input_width = canvas_size

input_image_frame = np.array(
[
[0.0, 0.0, 1.0],
[0.0, input_height, 1.0],
[input_width, input_height, 1.0],
[input_width, 0.0, 1.0],
],
dtype=np.float64,
)
output_image_frame = np.matmul(input_image_frame, affine_matrix.astype(input_image_frame.dtype).T)

recenter_x = float(np.min(output_image_frame[:, 0]))
recenter_y = float(np.min(output_image_frame[:, 1]))

output_width = int(np.max(output_image_frame[:, 0]) - recenter_x)
output_height = int(np.max(output_image_frame[:, 1]) - recenter_y)

return (output_height, output_width), (recenter_x, recenter_y)

def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
    """Shift bounding boxes by the recenter offset of an expanded canvas.

    Args:
        bounding_boxes: Boxes to shift. Their ``format``, ``dtype`` and ``device`` are read.
        recenter_xy: ``(x, y)`` offset that is subtracted from the box coordinates.

    Returns:
        The shifted boxes, cast back to the input dtype and wrapped like ``bounding_boxes``.
    """
    x, y = recenter_xy
    if bounding_boxes.format is datapoints.BoundingBoxFormat.XYXY:
        # All four entries are absolute coordinates, so both corners move.
        translate = [x, y, x, y]
    else:
        # Only the first two entries are shifted for every other format
        # (presumably because the last two are width / height -- confirm against BoundingBoxFormat).
        translate = [x, y, 0.0, 0.0]

    # Create the offset as float64 on the device of the boxes. A bare `torch.tensor(translate)` infers the default
    # floating dtype (float32 unless reconfigured) on the CPU, which rounds the offset before the float64
    # subtraction and cannot be combined with boxes living on another device.
    offset = torch.tensor(translate, dtype=torch.float64, device=bounding_boxes.device)
    shifted = bounding_boxes.to(torch.float64) - offset
    return datapoints.wrap(shifted.to(bounding_boxes.dtype), like=bounding_boxes)

def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
if center is None:
center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
cx, cy = center
Expand All @@ -1532,7 +1558,20 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
],
)

return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix, expand=expand)
new_canvas_size, recenter_xy = self._compute_output_canvas_size(
expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
)

output = reference_affine_bounding_boxes_helper(
bounding_boxes,
affine_matrix=affine_matrix,
new_canvas_size=new_canvas_size,
clamp=False,
)

return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
bounding_boxes
)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
Expand Down

0 comments on commit 99bbbb2

Please sign in to comment.