-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathboxes.py
271 lines (213 loc) · 8.8 KB
/
boxes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import torch
from torch import Tensor
from typing import Tuple
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision
from torchvision.extension import _assert_has_ops
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an IoU greater
    than ``iou_threshold`` with another (higher scoring) box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is not
    guaranteed to be the same between CPU and GPU. This is similar to the
    behavior of argsort in PyTorch when repeated values are present.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on. They are expected to be in
        (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping boxes with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    # Presumably raises if the compiled torchvision ops are unavailable
    # (helper defined in torchvision.extension) -- TODO confirm.
    _assert_has_ops()
    keep = torch.ops.torchvision.nms(boxes, scores, iou_threshold)
    return keep
@torch.jit._script_if_tracing
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format.
        Coordinates may be negative.
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
        The category separation below assumes integer-valued indices.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Strategy: to perform NMS independently per category, shift all boxes by
    # an offset that depends only on the category index. The step between
    # consecutive categories is the full coordinate span plus one, so boxes of
    # different categories land in disjoint coordinate ranges and can never
    # overlap. A step of only ``max + 1`` guarantees this just when every
    # coordinate is greater than -1; with negative coordinates, boxes from
    # neighbouring categories could overlap and wrongly suppress each other.
    coordinate_span = boxes.max() - boxes.min()
    offsets = idxs.to(boxes) * (coordinate_span + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove boxes which have at least one side smaller than ``min_size``.

    Args:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum size

    Returns:
        keep (Tensor[K]): indices of the boxes whose width and height are
            both greater than or equal to ``min_size``
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    big_enough = (widths >= min_size) & (heights >= min_size)
    # Single-argument torch.where returns a tuple of index tensors; take dim 0.
    return torch.where(big_enough)[0]
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clip boxes so that they lie inside an image of size ``size``.

    Args:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4]): boxes with x coordinates clamped to
            [0, width] and y coordinates clamped to [0, height]
    """
    img_height, img_width = size
    # Even columns of the last dim hold x coordinates, odd columns hold y.
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]
    if not torchvision._is_tracing():
        xs = xs.clamp(min=0, max=img_width)
        ys = ys.clamp(min=0, max=img_height)
    else:
        # While tracing, bound against tensors instead of Python scalars
        # (presumably so the limits are recorded in the trace -- TODO confirm).
        lower = torch.tensor(0, dtype=boxes.dtype, device=boxes.device)
        upper_x = torch.tensor(img_width, dtype=boxes.dtype, device=boxes.device)
        upper_y = torch.tensor(img_height, dtype=boxes.dtype, device=boxes.device)
        xs = torch.min(torch.max(xs, lower), upper_x)
        ys = torch.min(torch.max(ys, lower), upper_y)
    # Interleave x/y back into (x1, y1, x2, y2) order via a new trailing dim.
    stacked = torch.stack((xs, ys), dim=boxes.dim())
    return stacked.reshape(boxes.shape)
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts boxes from given in_fmt to out_fmt.

    Supported in_fmt and out_fmt are:

    'xyxy': boxes are represented via corners, x1, y1 being top left and
    x2, y2 being bottom right.

    'xywh': boxes are represented via corner, width and height, x1, y1 being
    top left, w, h being width and height.

    'cxcywh': boxes are represented via centre, width and height, cx, cy being
    center of box, w, h being width and height.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].

    Returns:
        boxes (Tensor[N, 4]): Boxes into converted format.

    Raises:
        ValueError: if either format is not one of the supported formats.
    """
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

    # Same format: still return a copy so callers never alias the input.
    if in_fmt == out_fmt:
        return boxes.clone()

    # Step 1: normalise to 'xyxy' (nothing to do if the input already is).
    if in_fmt == "xywh":
        boxes = _box_xywh_to_xyxy(boxes)
    elif in_fmt == "cxcywh":
        boxes = _box_cxcywh_to_xyxy(boxes)

    # Step 2: convert from 'xyxy' to the requested format ('xyxy' needs nothing).
    if out_fmt == "xywh":
        boxes = _box_xyxy_to_xywh(boxes)
    elif out_fmt == "cxcywh":
        boxes = _box_xyxy_to_cxcywh(boxes)

    return boxes
def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Low-precision inputs are upcast before the subtraction and multiplication
    so the result cannot overflow. Floating point types other than
    float32/float64 (e.g. float16, bfloat16) are computed in float32. Integer
    types other than int32/int64 (e.g. uint8, int8, int16) are computed in int32.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    # width * height quickly exceeds the range of small dtypes: a 300x300
    # float16 box would give inf and a 20x20 uint8 box would wrap to 144.
    if boxes.is_floating_point():
        if boxes.dtype not in (torch.float32, torch.float64):
            boxes = boxes.float()
    elif boxes.dtype not in (torch.int32, torch.int64):
        boxes = boxes.int()
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Args:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values
            for every element in boxes1 and boxes2
    """
    # Broadcasting boxes1 over a new middle axis pairs every box in boxes1
    # with every box in boxes2, giving [N, M, 2] corner tensors.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    # Negative extents mean the pair does not overlap; clamp them to zero.
    extent = (bottom_right - top_left).clamp(min=0)  # [N, M, 2]
    intersection = extent[:, :, 0] * extent[:, :, 1]  # [N, M]
    union = box_area(boxes1)[:, None] + box_area(boxes2) - intersection
    return intersection / union
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return generalized intersection-over-union (GIoU) of boxes.

    GIoU is the plain IoU minus the fraction of the smallest enclosing box
    that is not covered by the union of the two boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Args:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        generalized_iou (Tensor[N, M]): the NxM matrix containing the pairwise
            generalized IoU values for every element in boxes1 and boxes2
    """
    # Boxes with x2 < x1 or y2 < y1 give inf / nan results, so check early.
    # NOTE: plain asserts are skipped when Python runs with -O.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Pairwise intersection rectangles, [N, M, 2] corners.
    inter_lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    inter_rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    inter_wh = (inter_rb - inter_lt).clamp(min=0)
    inter = inter_wh[:, :, 0] * inter_wh[:, :, 1]  # [N, M]
    union = area1[:, None] + area2 - inter
    iou = inter / union

    # Smallest box enclosing each pair, [N, M, 2] corners.
    hull_lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = (hull_rb - hull_lt).clamp(min=0)
    hull_area = hull_wh[:, :, 0] * hull_wh[:, :, 1]  # [N, M]

    return iou - (hull_area - union) / hull_area