Distance IoU #5786
@@ -1258,6 +1258,78 @@ def test_giou_jit(self) -> None:
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestDistanceBoxIoU(BoxTestBase):
    def _target_fn(self):
        return (True, ops.distance_box_iou)

    def _generate_int_input():
        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]

    def _generate_int_expected():
        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

    def _generate_float_input():
        return [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]

    def _generate_float_expected():
        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "test_input, dtypes, tolerance, expected",
        [
            pytest.param(
                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
            ),
            pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
        ],
    )
    def test_distance_iou(self, test_input, dtypes, tolerance, expected):
        self._run_test(test_input, dtypes, tolerance, expected)

    def test_distance_iou_jit(self):
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_distance_iou_loss(dtype, device):
    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

    box1s = torch.stack([box2, box2], dim=0)
    box2s = torch.stack([box3, box4], dim=0)

    def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
        output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
        expected_output = torch.tensor(expected_output, dtype=dtype, device=device)
Review thread on this line:

- I think this will automatically take the dtype if you pass dtype as a function argument. It will also get parametrized. Notice that this works!
- Not exactly the same. I made the nested function into a static method.
- Yeah, but why is it a static method in the first place? IIUC, we only use it inside this test.
        tol = 1e-5 if dtype != torch.half else 1e-3
        torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

    assert_distance_iou_loss(box1, box1, 0.0)
    assert_distance_iou_loss(box1, box2, 0.8750)
    assert_distance_iou_loss(box1, box3, 1.1923)
    assert_distance_iou_loss(box1, box4, 1.2778)
    assert_distance_iou_loss(box1s, box2s, 1.9000, reduction="mean")
    assert_distance_iou_loss(box1s, box2s, 3.8000, reduction="sum")
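As a rough sketch of the refactor suggested in the review thread on `assert_distance_iou_loss` above (illustrative only, not the code merged in this PR): if the assertion helper takes `dtype` and `device` explicitly instead of closing over the test's locals, it no longer needs to be nested and could live at module level or as a `@staticmethod`. `ops` is assumed to be `torchvision.ops`, as in the rest of the test file.

```python
import torch
from torchvision import ops


def assert_distance_iou_loss(box1, box2, expected_output, dtype, device, reduction="none"):
    # dtype/device arrive as arguments instead of being captured from an
    # enclosing test function, so the helper does not have to be nested.
    output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
    expected_output = torch.tensor(expected_output, dtype=dtype, device=device)
    tol = 1e-5 if dtype != torch.half else 1e-3
    torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
```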
class TestMasksToBoxes:
    def test_masks_box(self):
        def masks_box_check(masks, expected, tolerance=1e-4):
@@ -0,0 +1,51 @@
import torch

from .boxes import distance_box_iou

Review thread:

- Commenting above the function might be better? Comments inside the function body feel like a small execution of commented code every time we call it. (Is there a subtle performance difference? Not sure, but I've always had this doubt.)
def distance_box_iou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    distance between the boxes' centers isn't zero. Indeed, for two exactly overlapping
    boxes, the distance IoU is the same as the IoU loss.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two boxes should have the
    same dimensions.

    Args:
        boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes
        boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[]: Loss tensor with the reduction option applied.

    Reference:
        Zhaohui Zheng et al.: Distance Intersection over Union Loss:
        https://arxiv.org/abs/1911.08287
    """
    if boxes1.dim() == 1 and boxes2.dim() == 1:
        batch_boxes1 = boxes1.unsqueeze(0)
        batch_boxes2 = boxes2.unsqueeze(0)
        diou = distance_box_iou(batch_boxes1, batch_boxes2, eps)[0, 0]
    else:
        diou = distance_box_iou(boxes1, boxes2, eps)[0]
Review thread:

- @yassineAlouini I'm not sure this approach is equivalent to what we had earlier. Please correct me if I'm wrong, but I understand that [...]. What I had in mind is to try to refactor the code at [...]. cc @abhi-glitchhg because you follow a similar approach on the other PR.
- BTW, if you both prefer to revert to your earlier versions of the code (which didn't reuse ops.boxes) and tackle this on separate future PRs, I'm happy to go down that route. The PRs for cIoU and dIoU have been dragging for a while and I appreciate that this can become frustrating at some point. Let me know your preference so that we make this a more enjoyable experience for both of you. Thanks a bunch for your work so far. :)
- @datumbox, I agree with your concerns. This is not computationally efficient.
- Yeah, sounds good. :)
- That's a very good point, and I might have introduced a bug by going quickly on my refactor, so thanks for pointing this out. I can revert to the previous code or keep working on this here (in this PR); both work for me. 👌
- Thanks for the flexibility! Let's revert and use the previously vetted code on the loss. We can investigate refactoring all losses to share code with [...].
- Sounds great! I have also found a fix for the [...]. There are many options: [...] What are your thoughts @datumbox? 🤔
- @yassineAlouini From what I see, you've reverted the code that reused estimations from [...]. Concerning the casting question, note that in [...]. So I think the safe thing to do here is to follow the same approach as in gIoU and upcast. I think we should move the method [...]. cc @fmassa for visibility in case I stated something incorrect here.
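To make the concern in the thread above concrete, here is a small illustration (not part of the PR): `ops.distance_box_iou` returns the full `[N, M]` pairwise matrix, so indexing it with `[0]` only keeps the comparisons of the first box in `boxes1` against every box in `boxes2`, whereas an elementwise loss needs the matched pairs, i.e. the diagonal.

```python
import torch
from torchvision import ops

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [4.0, 4.0, 6.0, 6.0]])
boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [4.0, 4.0, 6.0, 6.0]])

pairwise = ops.distance_box_iou(boxes1, boxes2)  # shape [2, 2]: every box1 against every box2
row0 = pairwise[0]             # box1[0] vs all of boxes2 -- not a per-pair result
matched = pairwise.diagonal()  # box1[i] vs box2[i] -- what a paired loss would need
```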
||
loss = 1 - diou | ||
if reduction == "mean": | ||
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() | ||
elif reduction == "sum": | ||
loss = loss.sum() | ||
# Cast the loss to the same dtype as the input boxes | ||
loss = loss.to(boxes1.dtype) | ||
Review thread on this line:

- Why is this needed? Shouldn't all operations preserve the dtype anyway?
- Agreed, this shouldn't be needed. I believe the casting here happens to return the loss in the same dtype as the original boxes (due to upcastings), but I don't think we should do this.
- I agree that it shouldn't be necessary, but I had to do it to preserve the "half" dtype. Not sure what casting happens so that it is turned into float. I will give this another look and let you know if I need help.
||
return loss |
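For reference, the direction the threads above converge on (a self-contained, elementwise loss that upcasts its inputs the way the gIoU loss does, instead of casting the result back at the end) could look roughly like the sketch below. This is only a sketch of the standard DIoU formulation from the paper, not the code merged in this PR; `_upcast` is assumed to be the private helper from `torchvision.ops.boxes`.

```python
import torch
from torchvision.ops.boxes import _upcast  # private helper, assumed available as in ops.boxes


def diou_loss_sketch(boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = "none", eps: float = 1e-7):
    # boxes1 and boxes2 are both [N, 4] and matched row by row (elementwise, not pairwise).
    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Elementwise intersection over union.
    xi1, yi1 = torch.max(x1, x1g), torch.max(y1, y1g)
    xi2, yi2 = torch.min(x2, x2g), torch.min(y2, y2g)
    intersection = (xi2 - xi1).clamp(min=0) * (yi2 - yi1).clamp(min=0)
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intersection + eps
    iou = intersection / union

    # Squared diagonal of the smallest enclosing box and squared distance between centers.
    diag_sq = (torch.max(x2, x2g) - torch.min(x1, x1g)) ** 2 + (torch.max(y2, y2g) - torch.min(y1, y1g)) ** 2 + eps
    center_dist_sq = ((x1 + x2 - x1g - x2g) ** 2 + (y1 + y2 - y1g - y2g) ** 2) / 4

    loss = 1 - iou + center_dist_sq / diag_sq
    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
```

With this shape of code the result naturally comes out in the upcasted dtype (e.g. float32 for half inputs), which is the behaviour debated in the casting thread above; identical boxes give a loss of approximately 0, matching the first assertion in the test.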
Review thread:

- This has nothing to do with this PR, but I want to make this visible. Although this works, there are two issues with this structure introduced in #5380 cc @datumbox: `self` as first parameter: given that we don't need it here, we should use the `@staticmethod` decorator. [...]
- Any suggestions to improve this @pmeier? I guess we should keep this for another PR? Other thoughts?
- Something else: on close inspection, I have noticed that in the `BoxTestBase` class we never test the box operation on two different batches of boxes (i.e. T[N, 4] vs T[M, 4] where N <> M). Should we change this base class, or is it better to add new tests for this case?
- +1 @yassineAlouini, I was thinking about this as well.
- 👍
- No, this is not blocking for this PR. I would appreciate it if either @yassineAlouini or @abhi-glitchhg could send a follow-up PR fixing this. Whoever wants to pick this up, please comment here to avoid duplicate work and resolve this comment afterwards.
- Hey @pmeier, sorry :( I have uni exams starting next week, so I will not be able to devote much time to this.
- No worries @abhi-glitchhg. Good luck with that. @yassineAlouini would you be able to pick this up? Otherwise, I'll open an issue so we don't forget and we can do this later.
- @pmeier Yes, I have time this afternoon, so I will take this on in the next PR. 👌
- And good luck @abhi-glitchhg with your exams. 👍
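For the follow-up discussed in this thread, a test exercising two differently sized batches could look roughly like the sketch below (the test name and expected shape are illustrative; how it would be wired into `BoxTestBase` is left open):

```python
import torch
from torchvision import ops


def test_distance_box_iou_different_batch_sizes():
    # 2 boxes against 3 boxes: the pairwise result should have shape [2, 3].
    boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50]], dtype=torch.float32)
    boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float32)
    out = ops.distance_box_iou(boxes1, boxes2)
    assert out.shape == (2, 3)
    # Matching boxes give a distance IoU of 1 (full overlap, zero center distance).
    torch.testing.assert_close(out[0, 0], torch.tensor(1.0))
    torch.testing.assert_close(out[1, 1], torch.tensor(1.0))
```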