diff --git a/adet/modeling/blendmask/blender.py b/adet/modeling/blendmask/blender.py
index 46a3f6c6a..009d8c8e7 100644
--- a/adet/modeling/blendmask/blender.py
+++ b/adet/modeling/blendmask/blender.py
@@ -35,20 +35,13 @@ def __call__(self, bases, proposals, gt_instances):
         if gt_instances is not None:
             # training
             # reshape attns
-            extras = proposals["extras"]
-            attns = proposals["top_feats"]
-            pos_inds = extras["pos_inds"]
+            dense_info = proposals["instances"]
+            attns = dense_info.top_feats
+            pos_inds = dense_info.pos_inds
             if pos_inds.numel() == 0:
                 return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0}
 
-            gt_inds = extras["gt_inds"]
-            attns = cat(
-                [
-                    # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
-                    x.permute(0, 2, 3, 1).reshape(-1, self.attn_len)
-                    for x in attns
-                ], dim=0,)
-            attns = attns[pos_inds]
+            gt_inds = dense_info.gt_inds
 
             rois = self.pooler(bases, [x.gt_boxes for x in gt_instances])
             rois = rois[gt_inds]
@@ -68,8 +61,8 @@ def __call__(self, bases, proposals, gt_instances):
             N = gt_masks.size(0)
             gt_masks = gt_masks.view(N, -1)
 
-            gt_ctr = extras["gt_ctr"]
-            loss_denorm = extras["loss_denorm"]
+            gt_ctr = dense_info.gt_ctrs
+            loss_denorm = proposals["loss_denorm"]
             mask_losses = F.binary_cross_entropy_with_logits(
                 pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none")
             mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum()
diff --git a/adet/modeling/fcos/fcos.py b/adet/modeling/fcos/fcos.py
index f0f33b04b..73bc3d257 100644
--- a/adet/modeling/fcos/fcos.py
+++ b/adet/modeling/fcos/fcos.py
@@ -8,6 +8,7 @@
 from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY
 
 from adet.layers import DFConv2d, NaiveGroupNorm
+from adet.utils.comm import compute_locations
 from .fcos_outputs import FCOSOutputs
 
 
@@ -74,7 +75,8 @@ def forward(self, images, features, gt_instances=None, top_module=None):
         features = [features[f] for f in self.in_features]
         locations = self.compute_locations(features)
         logits_pred, reg_pred, ctrness_pred, top_feats, bbox_towers = self.fcos_head(
-            features, top_module, self.yield_proposal)
+            features, top_module, self.yield_proposal
+        )
 
         results = {}
         if self.yield_proposal:
@@ -83,60 +85,37 @@ def forward(self, images, features, gt_instances=None, top_module=None):
             }
 
         if self.training:
-            losses, extras = self.fcos_outputs.losses(
+            results, losses = self.fcos_outputs.losses(
                 logits_pred, reg_pred, ctrness_pred,
-                locations, gt_instances
+                locations, gt_instances, top_feats
             )
 
-            if top_module is not None:
-                results["extras"] = extras
-                results["top_feats"] = top_feats
             if self.yield_proposal:
                 with torch.no_grad():
                     results["proposals"] = self.fcos_outputs.predict_proposals(
-                        top_feats, logits_pred, reg_pred,
-                        ctrness_pred, locations, images.image_sizes
+                        logits_pred, reg_pred, ctrness_pred,
+                        locations, images.image_sizes, top_feats
                     )
+            return results, losses
         else:
-            losses = {}
-            with torch.no_grad():
-                proposals = self.fcos_outputs.predict_proposals(
-                    top_feats, logits_pred, reg_pred,
-                    ctrness_pred, locations, images.image_sizes
-                )
-            if self.yield_proposal:
-                results["proposals"] = proposals
-            else:
-                results = proposals
+            results = self.fcos_outputs.predict_proposals(
+                logits_pred, reg_pred, ctrness_pred,
+                locations, images.image_sizes, top_feats
+            )
 
-        return results, losses
+            return results, {}
 
     def compute_locations(self, features):
         locations = []
         for level, feature in enumerate(features):
             h, w = feature.size()[-2:]
-            locations_per_level = self.compute_locations_per_level(
+            locations_per_level = compute_locations(
                 h, w, self.fpn_strides[level],
                 feature.device
             )
             locations.append(locations_per_level)
         return locations
 
-    def compute_locations_per_level(self, h, w, stride, device):
-        shifts_x = torch.arange(
-            0, w * stride, step=stride,
-            dtype=torch.float32, device=device
-        )
-        shifts_y = torch.arange(
-            0, h * stride, step=stride,
-            dtype=torch.float32, device=device
-        )
-        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
-        shift_x = shift_x.reshape(-1)
-        shift_y = shift_y.reshape(-1)
-        locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
-        return locations
-
 
 class FCOSHead(nn.Module):
     def __init__(self, cfg, input_shape: List[ShapeSpec]):
@@ -246,4 +225,4 @@ def forward(self, x, top_module=None, yield_bbox_towers=False):
             bbox_reg.append(F.relu(reg))
             if top_module is not None:
                 top_feats.append(top_module(bbox_tower))
-        return logits, bbox_reg, ctrness, top_feats, bbox_towers
+        return logits, bbox_reg, ctrness, top_feats, bbox_towers
\ No newline at end of file
diff --git a/adet/modeling/fcos/fcos_outputs.py b/adet/modeling/fcos/fcos_outputs.py
index d46637a82..ae05d096f 100644
--- a/adet/modeling/fcos/fcos_outputs.py
+++ b/adet/modeling/fcos/fcos_outputs.py
@@ -115,11 +115,21 @@ def _get_ground_truth(self, locations, gt_instances):
             locations, gt_instances, loc_to_size_range, num_loc_list
         )
 
+        training_targets["locations"] = [locations.clone() for _ in range(len(gt_instances))]
+        training_targets["im_inds"] = [
+            locations.new_ones(locations.size(0), dtype=torch.long) * i for i in range(len(gt_instances))
+        ]
+
         # transpose im first training_targets to level first ones
         training_targets = {
             k: self._transpose(v, num_loc_list) for k, v in training_targets.items()
         }
 
+        training_targets["fpn_levels"] = [
+            loc.new_ones(len(loc), dtype=torch.long) * level
+            for level, loc in enumerate(training_targets["locations"])
+        ]
+
         # we normalize reg_targets by FPN's strides here
         reg_targets = training_targets["reg_targets"]
         for l in range(len(reg_targets)):
@@ -127,13 +137,28 @@ def _get_ground_truth(self, locations, gt_instances):
 
         return training_targets
 
-    def get_sample_region(self, gt, strides, num_loc_list, loc_xs, loc_ys, radius=1):
-        num_gts = gt.shape[0]
+    def get_sample_region(self, boxes, strides, num_loc_list, loc_xs, loc_ys, bitmasks=None, radius=1):
+        if bitmasks is not None:
+            _, h, w = bitmasks.size()
+
+            ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+            xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+
+            m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+            m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+            m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+            center_x = m10 / m00
+            center_y = m01 / m00
+        else:
+            center_x = boxes[..., [0, 2]].sum(dim=-1) * 0.5  # per-box center, shape (num_gts,)
+            center_y = boxes[..., [1, 3]].sum(dim=-1) * 0.5  # reduce over the coord axis only
+
+        num_gts = boxes.shape[0]
         K = len(loc_xs)
-        gt = gt[None].expand(K, num_gts, 4)
-        center_x = (gt[..., 0] + gt[..., 2]) / 2
-        center_y = (gt[..., 1] + gt[..., 3]) / 2
-        center_gt = gt.new_zeros(gt.shape)
+        boxes = boxes[None].expand(K, num_gts, 4)
+        center_x = center_x[None].expand(K, num_gts)
+        center_y = center_y[None].expand(K, num_gts)
+        center_gt = boxes.new_zeros(boxes.shape)
         # no gt
         if center_x.numel() == 0 or center_x[..., 0].sum() == 0:
             return loc_xs.new_zeros(loc_xs.shape, dtype=torch.uint8)
@@ -146,10 +171,10 @@ def get_sample_region(self, gt, strides, num_loc_list, loc_xs, loc_ys, radius=1)
             xmax = center_x[beg:end] + stride
             ymax = center_y[beg:end] + stride
             # limit sample region in gt
-            center_gt[beg:end, :, 0] = torch.where(xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0])
-            center_gt[beg:end, :, 1] = torch.where(ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1])
-            center_gt[beg:end, :, 2] = torch.where(xmax > gt[beg:end, :, 2], gt[beg:end, :, 2], xmax)
-            center_gt[beg:end, :, 3] = torch.where(ymax > gt[beg:end, :, 3], gt[beg:end, :, 3], ymax)
+            center_gt[beg:end, :, 0] = torch.where(xmin > boxes[beg:end, :, 0], xmin, boxes[beg:end, :, 0])
+            center_gt[beg:end, :, 1] = torch.where(ymin > boxes[beg:end, :, 1], ymin, boxes[beg:end, :, 1])
+            center_gt[beg:end, :, 2] = torch.where(xmax > boxes[beg:end, :, 2], boxes[beg:end, :, 2], xmax)
+            center_gt[beg:end, :, 3] = torch.where(ymax > boxes[beg:end, :, 3], boxes[beg:end, :, 3], ymax)
             beg = end
         left = loc_xs[:, None] - center_gt[..., 0]
         right = center_gt[..., 2] - loc_xs[:, None]
@@ -187,9 +212,13 @@ def compute_targets_for_locations(self, locations, targets, size_ranges, num_loc
             reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
 
             if self.center_sample:
+                if targets_per_im.has("gt_bitmasks_full"):
+                    bitmasks = targets_per_im.gt_bitmasks_full
+                else:
+                    bitmasks = None
                 is_in_boxes = self.get_sample_region(
-                    bboxes, self.strides, num_loc_list,
-                    xs, ys, radius=self.radius
+                    bboxes, self.strides, num_loc_list, xs, ys,
+                    bitmasks=bitmasks, radius=self.radius
                 )
             else:
                 is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0
@@ -225,7 +254,7 @@ def compute_targets_for_locations(self, locations, targets, size_ranges, num_loc
             "target_inds": target_inds
         }
 
-    def losses(self, logits_pred, reg_pred, ctrness_pred, locations, gt_instances):
+    def losses(self, logits_pred, reg_pred, ctrness_pred, locations, gt_instances, top_feats=None):
         """
         Return the losses from a set of FCOS predictions and their associated ground-truth.
 
@@ -235,62 +264,59 @@ def losses(self, logits_pred, reg_pred, ctrness_pred, locations, gt_instances):
 
         training_targets = self._get_ground_truth(locations, gt_instances)
 
-        labels = training_targets["labels"]
-        reg_targets = training_targets["reg_targets"]
-        gt_inds = training_targets["target_inds"]
-
         # Collect all logits and regression predictions over feature maps
         # and images to arrive at the same shape as the labels and targets
         # The final ordering is L, N, H, W from slowest to fastest axis.
 
-        logits_pred = cat(
-            [
-                # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
-                x.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
-                for x in logits_pred
-            ], dim=0,)
-        reg_pred = cat(
-            [
-                # Reshape: (N, B, Hi, Wi) -> (N, Hi, Wi, B) -> (N*Hi*Wi, B)
-                x.permute(0, 2, 3, 1).reshape(-1, 4)
-                for x in reg_pred
-            ], dim=0,)
-        ctrness_pred = cat(
-            [
-                # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
-                x.permute(0, 2, 3, 1).reshape(-1) for x in ctrness_pred
-            ], dim=0,)
-        labels = cat(
-            [
-                # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
-                x.reshape(-1) for x in labels
-            ], dim=0,)
+        instances = Instances((0, 0))
+        instances.labels = cat([
+            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
+            x.reshape(-1) for x in training_targets["labels"]
+        ], dim=0)
+        instances.gt_inds = cat([
+            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
+            x.reshape(-1) for x in training_targets["target_inds"]
+        ], dim=0)
+        instances.im_inds = cat([
+            x.reshape(-1) for x in training_targets["im_inds"]
+        ], dim=0)
+        instances.reg_targets = cat([
+            # Reshape: (N, Hi, Wi, 4) -> (N*Hi*Wi, 4)
+            x.reshape(-1, 4) for x in training_targets["reg_targets"]
+        ], dim=0,)
+        instances.locations = cat([
+            x.reshape(-1, 2) for x in training_targets["locations"]
+        ], dim=0)
+        instances.fpn_levels = cat([
+            x.reshape(-1) for x in training_targets["fpn_levels"]
+        ], dim=0)
+
+        instances.logits_pred = cat([
+            # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
+            x.permute(0, 2, 3, 1).reshape(-1, self.num_classes) for x in logits_pred
+        ], dim=0,)
+        instances.reg_pred = cat([
+            # Reshape: (N, B, Hi, Wi) -> (N, Hi, Wi, B) -> (N*Hi*Wi, B)
+            x.permute(0, 2, 3, 1).reshape(-1, 4) for x in reg_pred
+        ], dim=0,)
+        instances.ctrness_pred = cat([
+            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
+            x.permute(0, 2, 3, 1).reshape(-1) for x in ctrness_pred
+        ], dim=0,)
 
-        gt_inds = cat(
-            [
-                # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
-                x.reshape(-1) for x in gt_inds
+        if top_feats is not None and len(top_feats) > 0:  # top_feats defaults to None
+            instances.top_feats = cat([
+                # Reshape: (N, -1, Hi, Wi) -> (N*Hi*Wi, -1)
+                x.permute(0, 2, 3, 1).reshape(-1, x.size(1)) for x in top_feats
             ], dim=0,)
 
-        reg_targets = cat(
-            [
-                # Reshape: (N, Hi, Wi, 4) -> (N*Hi*Wi, 4)
-                x.reshape(-1, 4) for x in reg_targets
-            ], dim=0,)
+        return self.fcos_losses(instances)
 
-        return self.fcos_losses(
-            labels, reg_targets, logits_pred,
-            reg_pred, ctrness_pred, gt_inds
-        )
-
-    def fcos_losses(
-        self, labels, reg_targets, logits_pred,
-        reg_pred, ctrness_pred, gt_inds,
-    ):
-        num_classes = logits_pred.size(1)
+    def fcos_losses(self, instances):
+        num_classes = instances.logits_pred.size(1)
         assert num_classes == self.num_classes
 
-        labels = labels.flatten()
+        labels = instances.labels.flatten()
 
         pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
         num_pos_local = pos_inds.numel()
@@ -299,41 +325,40 @@ def fcos_losses(
         num_pos_avg = max(total_num_pos / num_gpus, 1.0)
 
         # prepare one_hot
-        class_target = torch.zeros_like(logits_pred)
+        class_target = torch.zeros_like(instances.logits_pred)
         class_target[pos_inds, labels[pos_inds]] = 1
 
         class_loss = sigmoid_focal_loss_jit(
-            logits_pred,
+            instances.logits_pred,
             class_target,
             alpha=self.focal_loss_alpha,
             gamma=self.focal_loss_gamma,
             reduction="sum",
         ) / num_pos_avg
 
-        reg_pred = reg_pred[pos_inds]
-        reg_targets = reg_targets[pos_inds]
-        ctrness_pred = ctrness_pred[pos_inds]
-        gt_inds = gt_inds[pos_inds]
+        instances = instances[pos_inds]
+        instances.pos_inds = pos_inds
 
-        ctrness_targets = compute_ctrness_targets(reg_targets)
+        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
         ctrness_targets_sum = ctrness_targets.sum()
         loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
+        instances.gt_ctrs = ctrness_targets
 
         if pos_inds.numel() > 0:
             reg_loss = self.loc_loss_func(
-                reg_pred,
-                reg_targets,
+                instances.reg_pred,
+                instances.reg_targets,
                 ctrness_targets
             ) / loss_denorm
 
             ctrness_loss = F.binary_cross_entropy_with_logits(
-                ctrness_pred,
+                instances.ctrness_pred,
                 ctrness_targets,
                 reduction="sum"
             ) / num_pos_avg
         else:
-            reg_loss = reg_pred.sum() * 0
-            ctrness_loss = ctrness_pred.sum() * 0
+            reg_loss = instances.reg_pred.sum() * 0
+            ctrness_loss = instances.ctrness_pred.sum() * 0
 
         losses = {
             "loss_fcos_cls": class_loss,
@@ -341,16 +366,14 @@ def fcos_losses(
             "loss_fcos_ctr": ctrness_loss
         }
         extras = {
-            "pos_inds": pos_inds,
-            "gt_inds": gt_inds,
-            "gt_ctr": ctrness_targets,
+            "instances": instances,
             "loss_denorm": loss_denorm
         }
-        return losses, extras
+        return extras, losses
 
     def predict_proposals(
-            self, top_feats, logits_pred, reg_pred,
-            ctrness_pred, locations, image_sizes
+            self, logits_pred, reg_pred, ctrness_pred,
+            locations, image_sizes, top_feats=None
     ):
         if self.training:
             self.pre_nms_thresh = self.pre_nms_thresh_train
@@ -389,6 +412,11 @@ def predict_proposals(
                 )
             )
 
+            for per_im_sampled_boxes in sampled_boxes[-1]:
+                per_im_sampled_boxes.fpn_levels = l.new_ones(
+                    len(per_im_sampled_boxes), dtype=torch.long
+                ) * i
+
         boxlists = list(zip(*sampled_boxes))
         boxlists = [Instances.cat(boxlist) for boxlist in boxlists]
         boxlists = self.select_over_all_levels(boxlists)
@@ -396,9 +424,8 @@ def predict_proposals(
         return boxlists
 
     def forward_for_single_feature_map(
-            self, locations, logits_pred,
-            reg_pred, ctrness_pred,
-            image_sizes, top_feat=None
+            self, locations, logits_pred, reg_pred,
+            ctrness_pred, image_sizes, top_feat=None
     ):
         N, C, H, W = logits_pred.shape
 
@@ -489,4 +516,4 @@ def select_over_all_levels(self, boxlists):
                 keep = torch.nonzero(keep).squeeze(1)
                 result = result[keep]
             results.append(result)
-        return results
+        return results
\ No newline at end of file
diff --git a/adet/utils/comm.py b/adet/utils/comm.py
index 3632d6bf4..802bf76da 100644
--- a/adet/utils/comm.py
+++ b/adet/utils/comm.py
@@ -1,4 +1,7 @@
+import torch
+import torch.nn.functional as F
 import torch.distributed as dist
+
 from detectron2.utils.comm import get_world_size
 
 
@@ -9,3 +12,44 @@ def reduce_sum(tensor):
     tensor = tensor.clone()
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def aligned_bilinear(tensor, factor):
+    assert tensor.dim() == 4
+    assert factor >= 1
+    assert int(factor) == factor
+
+    if factor == 1:
+        return tensor
+
+    h, w = tensor.size()[2:]
+    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
+    oh = factor * h + 1
+    ow = factor * w + 1
+    tensor = F.interpolate(
+        tensor, size=(oh, ow),
+        mode='bilinear',
+        align_corners=True
+    )
+    tensor = F.pad(
+        tensor, pad=(factor // 2, 0, factor // 2, 0),
+        mode="replicate"
+    )
+
+    return tensor[:, :, :oh - 1, :ow - 1]
+
+
+def compute_locations(h, w, stride, device):
+    shifts_x = torch.arange(
+        0, w * stride, step=stride,
+        dtype=torch.float32, device=device
+    )
+    shifts_y = torch.arange(
+        0, h * stride, step=stride,
+        dtype=torch.float32, device=device
+    )
+    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+    shift_x = shift_x.reshape(-1)
+    shift_y = shift_y.reshape(-1)
+    locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
+    return locations
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 98293d5f9..93e66422d 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ def get_extensions():
 
 setup(
     name="AdelaiDet",
-    version="0.1.1",
+    version="0.2.0",
     author="Adelaide Intelligent Machines",
     url="https://github.com/stanstarks/AdelaiDet",
     description="AdelaiDet is AIM's research "