Distance IoU #5786
Changes from 23 commits
@@ -1258,6 +1258,97 @@ def test_giou_jit(self) -> None:
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestDistanceBoxIoU(BoxTestBase):
    def _target_fn(self):
        return (True, ops.distance_box_iou)

    def _generate_int_input():
        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]

    def _generate_int_expected():
        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

    def _generate_float_input():
        return [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]

    def _generate_float_expected():
        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "test_input, dtypes, tolerance, expected",
        [
            pytest.param(
                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
            ),
            pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
        ],
    )
    def test_distance_iou(self, test_input, dtypes, tolerance, expected):
        self._run_test(test_input, dtypes, tolerance, expected)

    def test_distance_iou_jit(self):
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_distance_iou_loss(dtype, device):
    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

    box1s = torch.stack(
        [box2, box2],
        dim=0,
    )
    box2s = torch.stack(
        [box3, box4],
        dim=0,
    )

    def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
        output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
        # TODO: When passing the dtype, the torch.half fails as usual.
        expected_output = torch.tensor(expected_output, device=device)
        tol = 1e-5 if dtype != torch.half else 1e-3
        torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

    assert_distance_iou_loss(box1, box1, 0.0)

    assert_distance_iou_loss(box1, box2, 0.8125)

    assert_distance_iou_loss(box1, box3, 1.1923)

    assert_distance_iou_loss(box1, box4, 1.2500)

    assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean")
    assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(dtype, device) -> None:
    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()

    loss = ops.distance_box_iou_loss(box1, box2, reduction="mean")
    loss.backward()

    tol = 1e-3 if dtype is torch.half else 1e-5
    torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol)
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    loss = ops.distance_box_iou_loss(box1, box2, reduction="none")
    assert loss.numel() == 0, "diou_loss for two empty boxes should be empty"

class TestCompleteBoxIou(BoxTestBase):
    def _target_fn(self) -> Tuple[bool, Callable]:
        return (True, ops.complete_box_iou)

@@ -1676,6 +1767,7 @@ def test_ciou_loss(self, dtype, device):
        def assert_ciou_loss(box1, box2, expected_output, reduction="none"):

            output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
            # TODO: When passing the dtype, the torch.half test doesn't pass...

Review comment: Is it still valid?

Review comment: Can we provide a bit more info on what doesn't pass here and what exactly the issue is?

Review comment: @oke-aditya I think so. @datumbox I read the cIoU code since it was passing the torch.half tests, and I found out that the dtype wasn't passed, so the test wasn't correct for torch.half. For now, I have removed the dtype to have the same code as cIoU, but I think we should investigate this further (or maybe we can't do anything since we use the _upcast function? 🤔). Let me know if this is clear enough; I can provide more details.
            expected_output = torch.tensor(expected_output, device=device)
            tol = 1e-5 if dtype != torch.half else 1e-3
            torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
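
For context on the _upcast remark in the thread above: torchvision's box ops promote tensors to a wider dtype before multiplying, so float16 inputs are actually computed in float32, and low-precision failures tend to come from the comparison tolerances rather than the op itself. A paraphrased sketch of the helper, assuming the version in torchvision/ops/boxes.py at the time of this PR:

import torch
from torch import Tensor

def _upcast(t: Tensor) -> Tensor:
    # Protect from numerical overflow in multiplications by upcasting to an
    # equivalent wider type (e.g. float16 -> float32, int16 -> int32).
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    return t if t.dtype in (torch.int32, torch.int64) else t.int()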
@@ -0,0 +1,86 @@
import torch

from ..utils import _log_api_usage_once
from .boxes import _upcast

Review comment: Commenting above the function might be better? Wouldn't a comment inside the function body mean a small execution of commented code every time we call it? (Is there a subtle performance difference? Not sure, but I have always had this doubt.)
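
For what it's worth, Python discards comments at compile time, so a comment inside a function body adds no per-call cost. A quick way to check (illustrative sketch, not part of this PR):

import dis

def with_comment():
    # this comment disappears at compile time
    return 1

def without_comment():
    return 1

# Both functions compile to identical instruction sequences,
# so the comment is free at call time.
assert [i.opname for i in dis.get_instructions(with_comment)] == [
    i.opname for i in dis.get_instructions(without_comment)
]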
def distance_box_iou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    distance between the boxes' centers is non-zero. For two exactly overlapping
    boxes, the penalty vanishes and the distance IoU loss equals the IoU loss.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two sets should have the
    same shape.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[N, 4]): second set of boxes
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor: Loss tensor with the reduction option applied.

    Reference:
        Zhaohui Zheng et al.: Distance Intersection over Union Loss:
        https://arxiv.org/abs/1911.08287
    """

    # Original implementation: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
Review comment: I meant this comment.

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou_loss)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    intsct = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
    iou = intsct / union

    # Smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    # The diagonal distance of the smallest enclosing box, squared
    diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

    # Centers of the boxes
    x_p = (x2 + x1) / 2
    y_p = (y2 + y1) / 2
    x_g = (x1g + x2g) / 2
    y_g = (y1g + y2g) / 2
    # The distance between the boxes' centers, squared
    centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

    # The distance IoU is the IoU penalized by a normalized
    # squared distance between the boxes' centers.
    loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
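
To make the new loss concrete, here is a small usage sketch based on the implementation above; the numbers follow directly from loss = 1 - IoU + d^2 / c^2, where d is the distance between the box centers and c the diagonal of the smallest enclosing box (assumes a torchvision build that includes this PR):

import torch
from torchvision import ops

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])

# IoU = 1/7, d^2 = (1-2)^2 + (1-2)^2 = 2, c^2 = 3^2 + 3^2 = 18
# loss = 1 - 1/7 + 2/18 ≈ 0.9683
loss = ops.distance_box_iou_loss(boxes1, boxes2, reduction="none")
print(loss)  # tensor([0.9683]), up to rounding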
Review comment: This has nothing to do with this PR, but I want to make it visible. Although this works, there are two issues with this structure introduced in #5380, cc @datumbox. Among them: the helper methods are defined without self as first parameter. Given that we don't need it here, we should use the @staticmethod decorator.
Review comment: Any suggestions to improve this, @pmeier? I guess we should keep this for another PR? Other thoughts?
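
For illustration, the suggested fix could look like this (a sketch of the possible follow-up, not code from this PR):

class TestDistanceBoxIoU(BoxTestBase):
    @staticmethod
    def _generate_int_input():
        # Now a proper static method: no implicit self, and callable from the
        # parametrize decorator without an instance.
        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]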
Review comment: Something else: on close inspection, I have noticed that in the BoxTestBase class we never test the box operation on two different batches of boxes (i.e. T[N, 4] vs T[M, 4] where N <> M). Should we change this base class, or is it better to add new tests for this case?
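
Such a test might look roughly like the sketch below; test_distance_iou_different_batch_sizes is a hypothetical name, and the (N, M) output shape follows torchvision's pairwise box-IoU convention:

import torch
from torchvision import ops

def test_distance_iou_different_batch_sizes():
    boxes_n = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50]], dtype=torch.float)  # N = 2
    boxes_m = torch.tensor([[200, 200, 300, 300]], dtype=torch.float)  # M = 1
    iou = ops.distance_box_iou(boxes_n, boxes_m)
    assert iou.shape == (2, 1)  # pairwise ops return a Tensor[N, M] matrix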
Review comment: +1 @yassineAlouini, even I was thinking about this.

Review comment: 👍
Review comment: No, this is not blocking for this PR. I would appreciate it if either @yassineAlouini or @abhi-glitchhg could send a follow-up PR fixing this. Whoever wants to pick this up, please comment here to avoid duplicate work, and resolve this comment afterwards.

Review comment: Hey @pmeier, sorry :( I have uni exams starting next week, so I will not be able to devote much time to this.

Review comment: No worries @abhi-glitchhg, good luck with that. @yassineAlouini, would you be able to pick this up? Otherwise, I'll open an issue so we don't forget and can do this later.

Review comment: @pmeier Yes, I have time this afternoon, so I will take this on in the next PR. 👌

Review comment: And good luck with your exams, @abhi-glitchhg. 👍