Distance IoU #5786

Merged
merged 26 commits into from
May 9, 2022
Changes from 20 commits
Commits
135763c
[FEAT] Add distance IoU and distance IoU loss + some tests (WIP for t…
Apr 7, 2022
ec599d2
[FIX] Remove URL from docstring + remove assert since it causes a big…
Apr 7, 2022
41703e6
[FIX] eps isn't None.
Apr 7, 2022
51616ed
[TEST] Update existing box dIoU test + add dIoU loss tests (inspired …
Apr 13, 2022
ee37c8d
Merge branch 'main' into dIoU
Apr 13, 2022
7631ab7
[ENH] Some pre-commit fixes + remove print + mypy.
Apr 13, 2022
b744d6d
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 13, 2022
8ceffcc
[ENH] Pass the device in the assertion for the dIoU loss test.
Apr 13, 2022
bc65b83
Merge branch 'main' into dIoU
Apr 14, 2022
a4e58b7
[FIX] Remove type hints from the dIoU box test.
Apr 14, 2022
4ba5cdc
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 14, 2022
0ead2c3
[ENH] Refactor box and loss for dIoU functions + fix half tests.
Apr 21, 2022
a2702f8
[FIX] Precommits fix.
Apr 21, 2022
27894ef
Merge branch 'main' into dIoU
Apr 21, 2022
d4bd825
Merge branch 'main' of github.com:yassineAlouini/vision-1 into dIoU
Apr 26, 2022
497a7c1
[ENH] Some improvement for the distance IoU tests thanks to code review.
Apr 26, 2022
d8b7f35
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 26, 2022
a054032
[ENH] Upcast in distance boxes computation to avoid overflow.
Apr 26, 2022
d7baa67
[ENH] Revert the refactor of distance IoU loss back since it introduc…
Apr 26, 2022
4213ee4
Precommit fix.
Apr 26, 2022
1a2d6ab
Merge main and fix conflicts + make code iso with cIoU.
May 2, 2022
2856947
[FIX] Few changes introduced by merge conflict.
May 2, 2022
3a9d3d7
Add code reference
datumbox May 9, 2022
13fa495
Merge branch 'main' into dIoU
datumbox May 9, 2022
1b2f1e6
Fix test
datumbox May 9, 2022
ab44428
Merge branch 'main' into dIoU
datumbox May 9, 2022
2 changes: 2 additions & 0 deletions docs/source/ops.rst
@@ -24,6 +24,8 @@ Operators
    drop_block3d
    generalized_box_iou
    generalized_box_iou_loss
    distance_box_iou
    distance_box_iou_loss
    masks_to_boxes
    nms
    ps_roi_align
72 changes: 72 additions & 0 deletions test/test_ops.py
@@ -1258,6 +1258,78 @@ def test_giou_jit(self) -> None:
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestDistanceBoxIoU(BoxTestBase):
    def _target_fn(self):
        return (True, ops.distance_box_iou)

    def _generate_int_input():
Collaborator

This has nothing to do with this PR, but I want to make this visible. Although this works, there are two issues with this structure, which was introduced in #5380 (cc @datumbox):

  1. We shouldn't have regular methods without self as first parameter. Given that we don't need it here, we should use the @staticmethod decorator.
  2. It makes little sense to separate the input and expected values into two methods given that they are always used together. This just makes it harder to read.
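
A minimal sketch of both points, assuming a hypothetical `BoxCases` class rather than the actual `BoxTestBase` hierarchy: the helper is declared with `@staticmethod`, and each case returns its input together with the expected values. The single-box case is used only because its result is unambiguous (a box compared with itself).

```python
class BoxCases:
    """Hypothetical container for test cases; not the real BoxTestBase."""

    @staticmethod
    def identity_case():
        # Input boxes and expected pairwise values are defined side by side.
        boxes = [[0, 0, 100, 100]]
        expected = [[1.0]]  # a box compared with itself
        return boxes, expected


# No instance is needed to call a static method.
boxes, expected = BoxCases.identity_case()
```

One caveat: calling a `staticmethod` object directly inside the class body, the way the parametrization in this diff calls the plain functions, only works on Python 3.10 and later. On older versions the call has to go through the class.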

Contributor Author

Any suggestions to improve this @pmeier? I guess we should keep this for another PR? Other thoughts?

Contributor Author

Something else: on closer inspection, I noticed that in the BoxTestBase class we never test the box operation on two batches of different sizes (i.e. T[N, 4] vs T[M, 4] where N != M). Should we change this base class, or would it be better to add new tests for this case?
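
One way such a check might look, sketched with plain `torch` and a hypothetical `pairwise_iou` helper standing in for the op under test (it is not the torchvision implementation). With N = 2 and M = 3, a shape or broadcasting mistake that a square N = M input would hide shows up immediately.

```python
import torch


def pairwise_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the box op under test; returns an [N, M] matrix."""
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N]
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M]
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])  # [N, M, 2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])  # [N, M, 2]
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]  # [N, M]
    return inter / (area1[:, None] + area2[None, :] - inter)


boxes1 = torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 50.0, 50.0]])  # N = 2
boxes2 = torch.tensor(
    [[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 50.0, 50.0], [200.0, 200.0, 300.0, 300.0]]
)  # M = 3
result = pairwise_iou(boxes1, boxes2)
```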

Contributor

+1 @yassineAlouini, I was thinking about this too.

Collaborator

I guess we should keep this for another PR?

👍

Collaborator

No, this is not blocking for this PR. I would appreciate it if either @yassineAlouini or @abhi-glitchhg could send a follow-up PR fixing this. Whoever wants to pick this up, please comment here to avoid duplicate work, and resolve this comment afterwards.

Contributor

Hey @pmeier, sorry :(
I have uni exams starting next week, so I will not be able to devote much time to this.

Collaborator

No worries @abhi-glitchhg. Good luck with that. @yassineAlouini would you be able to pick this up? Otherwise, I'll open an issue so we don't forget and we can do this later.

Contributor Author

@pmeier Yes, I have time this afternoon, so I will take this on in the next PR. 👌

Contributor Author

And good luck with your exams, @abhi-glitchhg. 👍

        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]

    def _generate_int_expected():
        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

    def _generate_float_input():
        return [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]

    def _generate_float_expected():
        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "test_input, dtypes, tolerance, expected",
        [
            pytest.param(
                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
            ),
            pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
        ],
    )
    def test_distance_iou(self, test_input, dtypes, tolerance, expected):
        self._run_test(test_input, dtypes, tolerance, expected)

    def test_distance_iou_jit(self):
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
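
For reference, the integer case can be worked by hand from the DIoU definition in the paper cited in the loss docstring: IoU minus the squared distance between centers, divided by the squared diagonal of the smallest enclosing box. The sketch below uses exact fractions and omits `eps`; the `diou` helper is illustrative only. For these boxes the off-diagonal entries come out as 3/16, -4/9 and -9/16, whereas the expected integer values in this revision of the test appear to coincide with plain IoU.

```python
from fractions import Fraction


def diou(a, b):
    """Exact DIoU of two integer (x1, y1, x2, y2) boxes; illustrative helper."""
    iw = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    # Squared diagonal of the smallest box enclosing both inputs.
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    diag2 = cw * cw + ch * ch
    # Differences between the centers, kept exact.
    dx = Fraction(a[0] + a[2] - b[0] - b[2], 2)
    dy = Fraction(a[1] + a[3] - b[1] - b[3], 2)
    return Fraction(inter, union) - (dx * dx + dy * dy) / diag2


boxes = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
matrix = [[diou(a, b) for b in boxes] for a in boxes]
```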


@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_distance_iou_loss(dtype, device):
    box1 = torch.tensor([[-1, -1, 1, 1]], dtype=dtype, device=device)
    box2 = torch.tensor([[0, 0, 1, 1]], dtype=dtype, device=device)
    box3 = torch.tensor([[0, 1, 1, 2]], dtype=dtype, device=device)
    box4 = torch.tensor([[1, 1, 2, 2]], dtype=dtype, device=device)

    box1s = torch.stack(
        [box2, box2],
        dim=0,
    )
    box2s = torch.stack(
        [box3, box4],
        dim=0,
    )

    def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
        output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
        expected_output = torch.tensor(expected_output, dtype=dtype, device=device)
Contributor

I think this will pick up the dtype automatically if dtype is passed as a function argument. It will also get parametrized.

Suggested change
-        expected_output = torch.tensor(expected_output, dtype=dtype, device=device)
+        expected_output = torch.tensor(expected_output, device=device)

Notice that this works here:

https://github.com/pytorch/vision/pull/5792/files#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156R1587

Contributor Author
@yassineAlouini, Apr 14, 2022

Not exactly the same. I made the nested function into a staticmethod so it doesn't have access to the same scope.

Collaborator

Yeah, but why is it a static method in the first place? IIUC, we are only using it inside test_distance_iou_loss, correct? If yes, we can simply inline it, which also removes the need to pass the device and dtype.
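
The scoping point can be illustrated without torch. In this sketch (all names hypothetical), the nested helper reads `dtype` and `device` from the enclosing function, so neither has to be passed explicitly; a `staticmethod` defined on a class would not see them.

```python
def run_test(dtype, device):
    def describe(expected_output):
        # dtype and device are captured from the enclosing function's scope.
        return {"value": expected_output, "dtype": dtype, "device": device}

    return describe([0.0])


result = run_test("float32", "cpu")
```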

        tol = 1e-5 if dtype != torch.half else 1e-3
        torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

    assert_distance_iou_loss(box1, box1, [0.0])

    assert_distance_iou_loss(box1, box2, [0.8125])

    assert_distance_iou_loss(box1, box3, [1.1923])

    assert_distance_iou_loss(box1, box4, [1.2500])

    assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean")
    assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum")
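
The expected values in this test can be reproduced by hand. The following pure-Python sketch mirrors the steps of `distance_box_iou_loss` for a single pair of boxes; `diou_loss` is an illustrative helper, not part of torchvision.

```python
def diou_loss(a, b, eps=1e-7):
    """Scalar DIoU loss for two (x1, y1, x2, y2) boxes; illustrative helper."""
    x1, y1, x2, y2 = a
    x1g, y1g, x2g, y2g = b
    # Intersection area, zero when the boxes do not overlap.
    iw = max(0.0, min(x2, x2g) - max(x1, x1g))
    ih = max(0.0, min(y2, y2g) - max(y1, y1g))
    inter = iw * ih
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - inter + eps
    # Squared diagonal of the smallest enclosing box.
    diag2 = (max(x2, x2g) - min(x1, x1g)) ** 2 + (max(y2, y2g) - min(y1, y1g)) ** 2 + eps
    # Squared distance between the two centers.
    dist2 = ((x1 + x2) / 2 - (x1g + x2g) / 2) ** 2 + ((y1 + y2) / 2 - (y1g + y2g) / 2) ** 2
    return 1 - (inter / union - dist2 / diag2)


box1, box2, box3, box4 = [-1, -1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 2], [1, 1, 2, 2]
single = [diou_loss(box1, b) for b in (box1, box2, box3, box4)]
pairs = [diou_loss(box2, box3), diou_loss(box2, box4)]
mean_loss = sum(pairs) / len(pairs)
sum_loss = sum(pairs)
```

For example, for `box1` against `box2` the IoU is 1/4, the squared center distance is 0.5 and the squared enclosing diagonal is 8, so the loss is 1 - (0.25 - 0.0625) = 0.8125.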


class TestMasksToBoxes:
    def test_masks_box(self):
        def masks_box_check(masks, expected, tolerance=1e-4):
4 changes: 4 additions & 0 deletions torchvision/ops/__init__.py
@@ -7,10 +7,12 @@
    box_area,
    box_iou,
    generalized_box_iou,
    distance_box_iou,
    masks_to_boxes,
)
from .boxes import box_convert
from .deform_conv import deform_conv2d, DeformConv2d
from .diou_loss import distance_box_iou_loss
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
@@ -38,6 +40,7 @@
    "box_area",
    "box_iou",
    "generalized_box_iou",
    "distance_box_iou",
    "roi_align",
    "RoIAlign",
    "roi_pool",
@@ -56,6 +59,7 @@
    "Conv3dNormActivation",
    "SqueezeExcitation",
    "generalized_box_iou_loss",
    "distance_box_iou_loss",
    "drop_block2d",
    "DropBlock2d",
    "drop_block3d",
44 changes: 44 additions & 0 deletions torchvision/ops/boxes.py
@@ -311,6 +311,50 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    return iou - (areai - union) / areai


def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Return distance intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union

    lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    whi = _upcast(rbi - lti).clamp(min=0)  # [N,M,2]
    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

    # centers of boxes
    x_p = boxes1[:, None, :2].sum() / 2
    y_p = boxes1[:, None, 2:].sum() / 2
    x_g = boxes2[:, :2].sum() / 2
    y_g = boxes2[:, 2:].sum() / 2
Contributor

Hey @yassineAlouini, I think there is a problem with this implementation. The calculation of the box centers looks incorrect to me. We should be adding only x1 and x2, and y1 and y2 (ref).
But the current implementation adds x1 and y1, and x2 and y2. (The bounding box has the form [x1, y1, x2, y2].)

This can also be checked by calculating distance_box_iou_loss and distance_box_iou on sample tensors.

import torch
from torchvision.ops import distance_box_iou, distance_box_iou_loss

box1 = torch.tensor([[-1, -1, 1, 1]], )
box2 = torch.tensor([[0, 0, 1, 1]],)

1-distance_box_iou(box1, box2)[0] == distance_box_iou_loss(box1, box2)

The last statement returns False. Ideally it should return True.

I suggest the following changes.

Suggested change
-    # centers of boxes
-    x_p = boxes1[:, None, :2].sum() / 2
-    y_p = boxes1[:, None, 2:].sum() / 2
-    x_g = boxes2[:, :2].sum() / 2
-    y_g = boxes2[:, 2:].sum() / 2
+    # centers of boxes
+    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
+    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
+    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
+    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2

@datumbox,
Please correct me if I'm wrong.
Thanks.

Contributor

Good catch. I haven't yet reviewed the correctness of the implementation, as we are still discussing the structure/API.

I think that's probably a typo and @yassineAlouini intended to write something like:

x_p = boxes1[:, 0::2].sum() / 2
y_p = boxes1[:, 1::2].sum() / 2
...

Contributor Author

Yes indeed, I think I went too quickly over this and thought that the bounding box was in the x1x2y1y2 format. Thanks for pointing this out and your suggestions. 👍
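
For the pairwise `[N, M]` op, the per-box centers also need to be broadcast against each other before subtracting, otherwise inputs with N != M cannot be combined. A sketch of that step under this assumption (the helper name is hypothetical, and the merged code may differ):

```python
import torch


def pairwise_center_distance_squared(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Squared distance between box centers as an [N, M] matrix; illustrative helper."""
    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2  # [N]
    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2  # [N]
    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2  # [M]
    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2  # [M]
    # Broadcasting [N, 1] against [1, M] yields every pair of centers.
    return (x_p[:, None] - x_g[None, :]) ** 2 + (y_p[:, None] - y_g[None, :]) ** 2


boxes1 = torch.tensor([[-1.0, -1.0, 1.0, 1.0]])  # N = 1
boxes2 = torch.tensor([[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 2.0, 2.0]])  # M = 2
dist2 = pairwise_center_distance_squared(boxes1, boxes2)
```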

    # The distance between boxes' centers squared.
    centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)

    # The distance IoU is the IoU penalized by a normalized
    # distance between boxes' centers squared.
    return iou - (centers_distance_squared / diagonal_distance_squared)
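
The `_upcast` calls guard against overflow when coordinates are squared in a narrow dtype. The body of `_upcast` is not shown in this diff, so the helper below is only an assumption about what such a function might do: widen half-precision floats to `float32` and small integers to `int32`.

```python
import torch


def upcast(t: torch.Tensor) -> torch.Tensor:
    # Assumed behavior, for illustration only; not copied from torchvision.
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    return t if t.dtype in (torch.int32, torch.int64) else t.int()


side = torch.tensor([300], dtype=torch.int16)
squared = upcast(side) ** 2  # 300 ** 2 = 90000 does not fit in int16
```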


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.
85 changes: 85 additions & 0 deletions torchvision/ops/diou_loss.py
@@ -0,0 +1,85 @@
import torch

from ..utils import _log_api_usage_once
from .boxes import _upcast


Contributor
@oke-aditya, May 9, 2022

Would it be better to put the comment above the function? Do comments inside the function body add a small cost every time the function is called? (Is there a subtle performance difference? I'm not sure, but I have always had this doubt.)
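
In CPython, comments are discarded when the source is compiled, so they carry no per-call cost. A quick way to check this is to compare the bytecode of two otherwise identical functions (both functions here are made up for the comparison):

```python
def with_comment(x):
    # This comment never reaches the bytecode.
    return x + 1


def without_comment(x):
    return x + 1


# co_code holds the raw bytecode; line-number information is stored separately.
same_bytecode = with_comment.__code__.co_code == without_comment.__code__.co_code
```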

def distance_box_iou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
    boxes, the distance IoU is the same as the IoU loss.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two sets of boxes should have
    the same dimensions.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[N, 4]): second set of boxes
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[]: Loss tensor with the reduction option applied.

    Reference:
        Zhaohui Zheng et al.: Distance Intersection over Union Loss:
        https://arxiv.org/abs/1911.08287
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou_loss)

    # TODO: Removing the _upcast call makes the torch.half tests in test_ops pass
    # but we might get overflow problems... How to fix without casting at the end?
    # boxes1 = _upcast(boxes1)
    # boxes2 = _upcast(boxes2)
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    intsct = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
    iou = intsct / union

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    # The diagonal distance of the smallest enclosing box squared
    diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

    # centers of boxes
    x_p = (x2 + x1) / 2
    y_p = (y2 + y1) / 2
    x_g = (x1g + x2g) / 2
    y_g = (y1g + y2g) / 2
    # The distance between boxes' centers squared.
    centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

    # The distance IoU is the IoU penalized by a normalized
    # distance between boxes' centers squared.
    diou = iou - (centers_distance_squared / diagonal_distance_squared)
    loss = 1 - diou
    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
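
The `'mean'` branch guards against empty inputs. A small sketch of why, using only `torch`: the mean of an empty tensor is NaN, whereas `0.0 * loss.sum()` gives zero and stays connected to the autograd graph.

```python
import torch

# An empty per-box loss tensor, e.g. an image with no matched boxes.
loss = torch.empty(0, requires_grad=True)

naive_mean = loss.mean()  # NaN: averaging over zero elements
guarded_mean = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
```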
2 changes: 1 addition & 1 deletion torchvision/ops/giou_loss.py
@@ -36,7 +36,7 @@ def generalized_box_iou_loss(
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
-        eps (float, optional): small number to prevent division by zero. Default: 1e-7
+        eps (float): small number to prevent division by zero. Default: 1e-7

    Reference:
        Hamid Rezatofighi et al.: Generalized Intersection over Union: