remove spatial_size #7734
Merged · 6 commits · Aug 1, 2023
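As the files below show, the change renames the spatial_size keyword and attribute of datapoints.BoundingBoxes to canvas_size, swaps F.get_spatial_size(img) for F.get_size(img), and updates the gallery, test utilities, and prototype-transform tests accordingly. A minimal before/after sketch of the call-site change; the image tensor is a placeholder rather than something from the diff, and only one of the two constructions runs on any given torchvision checkout:

import torch
from torchvision import datapoints

image = torch.rand(3, 512, 384)  # any C x H x W tensor stands in for a decoded image

# before this PR
boxes = datapoints.BoundingBoxes(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
)

# after this PR: the same call with the keyword renamed
boxes = datapoints.BoundingBoxes(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)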
6 changes: 3 additions & 3 deletions gallery/plot_datapoints.py
@@ -80,7 +80,7 @@
# corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBoxes(
[17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
[17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(bounding_box)

@@ -108,7 +108,7 @@ def __getitem__(self, item):
target["boxes"] = datapoints.BoundingBoxes(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
canvas_size=F.get_size(img),
)
target["labels"] = labels
target["masks"] = datapoints.Mask(masks)
@@ -129,7 +129,7 @@ def __call__(self, img, target):
target["boxes"] = datapoints.BoundingBoxes(
target["boxes"],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
canvas_size=F.get_size(img),
)
target["masks"] = datapoints.Mask(target["masks"])
return img, target
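The same gallery file also replaces the functional helper used to derive the canvas from an image: F.get_spatial_size(img) becomes F.get_size(img). A short sketch of the new pattern for wrapping a detection target, assuming the transforms.v2 functional namespace of that era and placeholder image/box tensors:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 480, 640)  # placeholder image tensor
boxes = torch.tensor([[10.0, 20.0, 100.0, 200.0]])

target = {
    "boxes": datapoints.BoundingBoxes(
        boxes,
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=F.get_size(img),  # previously spatial_size=F.get_spatial_size(img)
    ),
}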
2 changes: 1 addition & 1 deletion gallery/plot_transforms_v2.py
@@ -30,7 +30,7 @@ def load_data():
masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))

bounding_boxes = datapoints.BoundingBoxes(
masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)

return path, image, bounding_boxes, masks, labels
86 changes: 26 additions & 60 deletions test/common_utils.py
@@ -412,7 +412,7 @@ def load(self, device="cpu"):
)


def _parse_spatial_size(size, *, name="size"):
def _parse_canvas_size(size, *, name="size"):
if size == "random":
raise ValueError("This should never happen")
elif isinstance(size, int) and size > 0:
@@ -467,12 +467,13 @@ def load(self, device):

@dataclasses.dataclass
class ImageLoader(TensorLoader):
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format
canvas_size: Tuple[int, int] = dataclasses.field(init=False)

def __post_init__(self):
self.spatial_size = self.shape[-2:]
self.spatial_size = self.canvas_size = self.shape[-2:]
self.num_channels = self.shape[-3]

def load(self, device):
@@ -538,7 +539,7 @@ def make_image_loader(
):
if not constant_alpha:
raise ValueError("This should never happen")
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device, memory_format):
@@ -578,7 +579,7 @@ def make_image_loaders(
def make_image_loader_for_interpolation(
size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
):
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device, memory_format):
@@ -623,43 +624,20 @@ def make_image_loaders_for_interpolation(
class BoundingBoxesLoader(TensorLoader):
format: datapoints.BoundingBoxFormat
spatial_size: Tuple[int, int]
canvas_size: Tuple[int, int] = dataclasses.field(init=False)

def __post_init__(self):
self.canvas_size = self.spatial_size


def make_bounding_box(
size=None,
canvas_size=DEFAULT_SIZE,
*,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=None,
batch_dims=(),
dtype=None,
device="cpu",
):
"""
size: Size of the actual bounding box, i.e.
- (box[3] - box[1], box[2] - box[0]) for XYXY
- (H, W) for XYWH and CXCYWH
spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
returned datapoints.BoundingBoxes

To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
functions, e.g.

.. code::

image = make_image=(size=size)
bounding_boxes = make_bounding_box(spatial_size=size)
assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)

For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
other maker functions, e.g.

.. code::

image = make_image=()
bounding_boxes = make_bounding_box()
assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
"""

def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
@@ -668,28 +646,16 @@ def sample_position(values, max_value):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

if spatial_size is None:
if size is None:
spatial_size = DEFAULT_SIZE
else:
height, width = size
height_margin, width_margin = torch.randint(10, (2,)).tolist()
spatial_size = (height + height_margin, width + width_margin)

dtype = dtype or torch.float32

if any(dim == 0 for dim in batch_dims):
return datapoints.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)

if size is None:
h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
else:
h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]

y = sample_position(h, spatial_size[0])
x = sample_position(w, spatial_size[1])
h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])

if format is datapoints.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
@@ -706,37 +672,37 @@ def sample_position(values, max_value):
raise ValueError(f"Format {format} is not supported")

return datapoints.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)


def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")

def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape
if num_coordinates != 4:
raise pytest.UsageError()

return make_bounding_box(
format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
)

return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)


def make_bounding_box_loaders(
*,
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat),
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
yield make_bounding_box_loader(**params, canvas_size=canvas_size)


make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -761,7 +727,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp

def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)

def fn(shape, dtype, device):
*batch_dims, num_objects, height, width = shape
@@ -802,15 +768,15 @@ def make_segmentation_mask_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
spatial_size = _parse_spatial_size(size)
canvas_size = _parse_canvas_size(size)

def fn(shape, dtype, device):
*batch_dims, height, width = shape
return make_segmentation_mask(
(height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
)

return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)


def make_segmentation_mask_loaders(
@@ -860,7 +826,7 @@ def make_video_loader(
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)

def fn(shape, dtype, device, memory_format):
*batch_dims, num_frames, _, height, width = shape
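With these changes, the test helper make_bounding_box takes canvas_size (defaulting to DEFAULT_SIZE) instead of the old size/spatial_size pair, and box extents are always sampled randomly inside that canvas. A hedged usage sketch, assuming it runs inside the test suite where the common_utils helpers are importable:

from common_utils import make_bounding_box, make_image  # torchvision test-suite helpers, not public API

canvas_size = (32, 32)
image = make_image(size=canvas_size, color_space="RGB")
boxes = make_bounding_box(canvas_size=canvas_size, format="XYXY", batch_dims=(4,))

assert boxes.canvas_size == canvas_size  # joint sample: the boxes live on the image's canvas
assert boxes.shape == (4, 4)  # four boxes in XYXY layout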
4 changes: 2 additions & 2 deletions test/test_datapoints.py
@@ -27,7 +27,7 @@ def test_mask_instance(data):
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
@@ -164,7 +164,7 @@ def test_wrap_like():
[
datapoints.Image(torch.rand(3, 16, 16)),
datapoints.Video(torch.rand(2, 3, 16, 16)),
datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)),
datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
],
)
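test_wrap_like above exercises Datapoint.wrap_like, which after this PR carries canvas_size (rather than spatial_size) over from the reference datapoint to the newly wrapped tensor. A hedged sketch of that behaviour, assuming the wrap_like classmethod of the datapoints API of that era:

import torch
from torchvision import datapoints

boxes = datapoints.BoundingBoxes(
    [[0.0, 1.0, 2.0, 3.0]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)
)

# e.g. the result of a plain-tensor op that dropped the datapoint subclass
shifted = boxes.as_subclass(torch.Tensor) + 1

wrapped = datapoints.BoundingBoxes.wrap_like(boxes, shifted)
assert wrapped.format is boxes.format
assert wrapped.canvas_size == boxes.canvas_size  # metadata copied from `boxes`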
34 changes: 17 additions & 17 deletions test/test_prototype_transforms.py
@@ -164,7 +164,7 @@ def test__copy_paste(self, label_type):
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": BoundingBoxes(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(masks),
"labels": label_type(labels),
@@ -179,7 +179,7 @@ def test__copy_paste(self, label_type):
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": BoundingBoxes(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(paste_masks),
"labels": label_type(paste_labels),
@@ -210,13 +210,13 @@ class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
spatial_size = (11, 5)
canvas_size = (11, 5)

transform = transforms.FixedSizeCrop(size=crop_size)

flat_inputs = [
make_image(size=spatial_size, color_space="RGB"),
make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
make_image(size=canvas_size, color_space="RGB"),
make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
]
params = transform._get_params(flat_inputs)

@@ -295,7 +295,7 @@ def test__transform(self, mocker, needs):

def test__transform_culling(self, mocker):
batch_size = 10
spatial_size = (10, 10)
canvas_size = (10, 10)

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
@@ -304,17 +304,17 @@ def test__transform_culling(self, mocker):
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=is_valid,
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
@@ -334,23 +334,23 @@ def test__transform_bounding_boxes_clamping(self, mocker):

def test__transform_bounding_boxes_clamping(self, mocker):
batch_size = 3
spatial_size = (10, 10)
canvas_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")

@@ -496,7 +496,7 @@ def make_datapoints():

pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
@@ -505,7 +505,7 @@ def make_datapoints():

tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
@@ -514,7 +514,7 @@ def make_datapoints():

datapoint_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}