-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathQR_deformable_detr_head.py
337 lines (298 loc) · 14 KB
/
QR_deformable_detr_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob, constant_init
from mmcv.runner import force_fp32
from mmdet.core import multi_apply
from mmdet.models.utils.transformer import inverse_sigmoid
from ..builder import HEADS
from .detr_head import DETRHead
@HEADS.register_module()
class QRDeformableDETRHead(DETRHead):
    """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
    End Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_ .

    Compared with a plain per-layer head, ``forward`` maps every decoder
    output index to a prediction-branch index: one-to-one when the input
    features do not require grad, and through a fixed boundary table
    otherwise (several decoder outputs then share one branch).

    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool) : Whether to generate the proposal from
            the outputs of encoder.
        transformer (obj:`ConfigDict`): ConfigDict is used for building
            the Encoder and Decoder.
    """

    # First decoder-output index that belongs to branch 1, 2, ..., 6 when the
    # features require grad. Outputs below the first boundary use branch 0;
    # outputs at or above the last boundary use branch 6.
    _QID_STAGE_STARTS = (1, 3, 6, 11, 19, 32)

    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 **kwargs):
        # These flags are read by ``_init_layers``, so they are set before
        # the parent constructor runs.
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            # Forward the flag to the transformer config (mutates the
            # mapping passed by the caller).
            transformer['as_two_stage'] = self.as_two_stage

        super(QRDeformableDETRHead, self).__init__(
            *args, transformer=transformer, **kwargs)

    def _init_layers(self):
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)

        reg_layers = []
        for _ in range(self.num_reg_fcs):
            reg_layers.append(Linear(self.embed_dims, self.embed_dims))
            reg_layers.append(nn.ReLU())
        reg_layers.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_layers)

        def _get_clones(module, N):
            """Return a ModuleList holding ``N`` independent deep copies."""
            return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

        # last reg_branch is used to generate proposal from
        # encode feature map when as_two_stage is True.
        num_layers = self.transformer.decoder.num_layers
        num_pred = num_layers + 1 if self.as_two_stage else num_layers

        if self.with_box_refine:
            # Every prediction slot gets its own parameters.
            self.cls_branches = _get_clones(fc_cls, num_pred)
            self.reg_branches = _get_clones(reg_branch, num_pred)
        else:
            # The same module object is repeated, so all slots share
            # parameters.
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])

        if not self.as_two_stage:
            self.query_embedding = nn.Embedding(self.num_query,
                                                self.embed_dims * 2)

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for cls_branch in self.cls_branches:
                nn.init.constant_(cls_branch.bias, bias_init)
        # Zero the weight and bias of the last layer of each reg branch.
        for reg_branch in self.reg_branches:
            constant_init(reg_branch[-1], 0, bias=0)
        # Bias entries [2:] of the first branch's last layer start at -2.0.
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            # In two-stage mode those entries are reset to 0 on every branch.
            for reg_branch in self.reg_branches:
                nn.init.constant_(reg_branch[-1].bias.data[2:], 0.0)

    def forward(self, mlvl_feats, img_metas):
        """Forward function.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 4D-tensor with shape
                (N, C, H, W).
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should includes background. Here nb_dec \
                is the leading dimension of the decoder output returned \
                by the transformer.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, h). \
                Shape [nb_dec, bs, num_query, 4].
            enc_outputs_class (Tensor): The score of each point on encode \
                feature map, has shape (N, h*w, num_class). Only when \
                as_two_stage is True it would be returned, otherwise \
                `None` would be returned.
            enc_outputs_coord (Tensor): The proposal generate from the \
                encode feature map, has shape (N, h*w, 4). Only when \
                as_two_stage is True it would be returned, otherwise \
                `None` would be returned.
        """
        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        # Mask is 1 on padded pixels and 0 inside each image's valid area.
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0

        # Resize the padding mask to every feature level and build the
        # matching positional encodings.
        mlvl_masks = []
        mlvl_positional_encodings = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

        query_embeds = None
        if not self.as_two_stage:
            query_embeds = self.query_embedding.weight

        use_reg_branches = self.with_box_refine or self.as_two_stage
        hs, init_reference, inter_references, \
            enc_outputs_class, enc_outputs_coord = self.transformer(
                mlvl_feats,
                mlvl_masks,
                query_embeds,
                mlvl_positional_encodings,
                reg_branches=self.reg_branches if use_reg_branches else None,
                cls_branches=self.cls_branches if self.as_two_stage else None
            )
        hs = hs.permute(0, 2, 1, 3)

        # Features of the last level that do not require grad are treated as
        # test mode: decoder output ``qid`` then uses branch ``qid`` directly.
        # NOTE(review): this keys off autograd state rather than
        # ``self.training``; confirm it matches how the transformer switches
        # between its training and test paths.
        is_test = not mlvl_feats[-1].requires_grad
        stage_starts = self._QID_STAGE_STARTS

        outputs_classes = []
        outputs_coords = []
        for qid in range(hs.shape[0]):
            if is_test:
                lvl = qid
            else:
                # training mode sqr query-stage alignment: the branch index
                # is the number of stage boundaries that ``qid`` has reached.
                lvl = sum(qid >= start for start in stage_starts)

            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[qid - 1]
            reference = inverse_sigmoid(reference)

            outputs_class = self.cls_branches[lvl](hs[qid])
            tmp = self.reg_branches[lvl](hs[qid])
            # The regression output is an offset in logit space relative to
            # the reference (4-d boxes or 2-d points).
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()

            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)

        if self.as_two_stage:
            return outputs_classes, outputs_coords, \
                enc_outputs_class, \
                enc_outputs_coord.sigmoid()
        else:
            return outputs_classes, outputs_coords, \
                None, None

    # ``apply_to`` must name the actual parameters of the decorated method;
    # names that do not match any parameter are ignored by the decorator.
    @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
    def loss(self,
             all_cls_scores,
             all_bbox_preds,
             enc_cls_scores,
             enc_bbox_preds,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification score of all
                decoder layers, has shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds (Tensor): Sigmoid regression
                outputs of all decode layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            enc_cls_scores (Tensor): Classification scores of
                points on encode feature map , has shape
                (N, h*w, num_classes). Only be passed when as_two_stage is
                True, otherwise is None.
            enc_bbox_preds (Tensor): Regression results of each points
                on the encode feature map, has shape (N, h*w, 4). Only be
                passed when as_two_stage is True, otherwise is None.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'for gt_bboxes_ignore setting to None.'

        # Every decoder output is supervised with the same targets.
        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]

        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss of proposal generated from encode feature map.
        if enc_cls_scores is not None:
            # Encoder proposals are scored with every label set to 0.
            binary_labels_list = [
                torch.zeros_like(gt_labels_list[i])
                for i in range(len(img_metas))
            ]
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_single(enc_cls_scores, enc_bbox_preds,
                                 gt_bboxes_list, binary_labels_list,
                                 img_metas, gt_bboxes_ignore)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou

        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        aux_losses = zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1])
        for num_dec_layer, (loss_cls_i, loss_bbox_i,
                            loss_iou_i) in enumerate(aux_losses):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
        return loss_dict

    # ``apply_to`` must name the actual parameters of the decorated method;
    # names that do not match any parameter are ignored by the decorator.
    @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
    def get_bboxes(self,
                   all_cls_scores,
                   all_bbox_preds,
                   enc_cls_scores,
                   enc_bbox_preds,
                   img_metas,
                   rescale=False):
        """Transform network outputs for a batch into bbox predictions.

        Only the last entry along the first dimension of ``all_cls_scores``
        and ``all_bbox_preds`` is used; the encoder outputs are accepted for
        interface compatibility and are not read.

        Args:
            all_cls_scores (Tensor): Classification score of all
                decoder layers, has shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds (Tensor): Sigmoid regression
                outputs of all decode layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            enc_cls_scores (Tensor): Classification scores of
                points on encode feature map , has shape
                (N, h*w, num_classes). Only be passed when as_two_stage is
                True, otherwise is None.
            enc_bbox_preds (Tensor): Regression results of each points
                on the encode feature map, has shape (N, h*w, 4). Only be
                passed when as_two_stage is True, otherwise is None.
            img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If True, return boxes in original
                image space. Default False.

        Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
                The first item is an (n, 5) tensor, where the first 4 columns \
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
                5-th column is a score between 0 and 1. The second item is a \
                (n,) tensor where each item is the predicted class label of \
                the corresponding box.
        """
        cls_scores = all_cls_scores[-1]
        bbox_preds = all_bbox_preds[-1]

        result_list = []
        for img_id, img_meta in enumerate(img_metas):
            proposals = self._get_bboxes_single(cls_scores[img_id],
                                                bbox_preds[img_id],
                                                img_meta['img_shape'],
                                                img_meta['scale_factor'],
                                                rescale)
            result_list.append(proposals)
        return result_list