-
Notifications
You must be signed in to change notification settings - Fork 229
/
Copy pathtest_giou_loss.py
61 lines (46 loc) · 2.02 KB
/
test_giou_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# pyre-strict
import unittest
import numpy as np
import torch
from fvcore.nn import giou_loss
class TestGIoULoss(unittest.TestCase):
    """Unit tests for ``giou_loss``: hand-computed cases, reductions, empty input."""

    def setUp(self) -> None:
        """Seed the NumPy and PyTorch global RNGs before every test."""
        super().setUp()
        np.random.seed(42)
        # test_empty_inputs creates its tensors with torch.randn, which draws
        # from PyTorch's generator rather than NumPy's, so seed that one too.
        torch.manual_seed(42)

    def _assert_loss_close(self, loss: torch.Tensor, expected: float) -> None:
        """Assert that ``loss`` equals ``expected`` within ``np.allclose`` defaults.

        Args:
            loss: value returned by ``giou_loss``.
            expected: hand-computed reference value.

        Raises:
            AssertionError: if the values differ; the message carries both the
                computed and the expected value.
        """
        self.assertTrue(
            np.allclose(loss, [expected]),
            msg=f"giou_loss returned {loss}, expected {expected}",
        )

    def test_giou_loss(self) -> None:
        """Check the loss on hand-computed box pairs and on batched reductions."""
        # Identical boxes should have a loss of 0.
        box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32)
        self._assert_loss_close(giou_loss(box, box), 0.0)

        # Quarter-size box inside the other box: IoU of 0.25 (loss 0.75).
        box2 = torch.tensor([0, 0, 1, 1], dtype=torch.float32)
        self._assert_loss_close(giou_loss(box, box2), 0.75)

        # Two side-by-side boxes, enclosing area == union:
        # IoU=0 and GIoU=0 (loss 1.0).
        box3 = torch.tensor([0, 1, 1, 2], dtype=torch.float32)
        self._assert_loss_close(giou_loss(box2, box3), 1.0)

        # Two diagonally adjacent boxes, enclosing area == 2 * union:
        # IoU=0 and GIoU=-0.5 (loss 1.5).
        box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32)
        self._assert_loss_close(giou_loss(box2, box4), 1.5)

        # Batched loss and reductions: the pairs above give 1.0 and 1.5,
        # so "sum" is 2.5 and "mean" is 1.25.
        box1s = torch.stack([box2, box2], dim=0)
        box2s = torch.stack([box3, box4], dim=0)
        self._assert_loss_close(giou_loss(box1s, box2s, reduction="sum"), 2.5)
        self._assert_loss_close(giou_loss(box1s, box2s, reduction="mean"), 1.25)

    def test_empty_inputs(self) -> None:
        """Check zero-box inputs: zero mean loss, gradients set, empty raw loss."""
        box1 = torch.randn([0, 4], dtype=torch.float32).requires_grad_()
        box2 = torch.randn([0, 4], dtype=torch.float32).requires_grad_()

        # The "mean" reduction over zero boxes must be 0 and still be
        # differentiable with respect to both inputs.
        loss = giou_loss(box1, box2, reduction="mean")
        loss.backward()
        self.assertEqual(loss.item(), 0.0)
        self.assertIsNotNone(box1.grad)
        self.assertIsNotNone(box2.grad)

        # Without a reduction the result has no elements.
        loss = giou_loss(box1, box2, reduction="none")
        self.assertEqual(loss.numel(), 0)