Skip to content

Commit

Permalink
Merge pull request aim-uofa#93 from aim-uofa/v0.2.0
Browse files (browse the repository at this point in the history)
merge fcos extra information
  • Loading branch information
tianzhi0549 authored Jun 6, 2020
2 parents e65d9eb + 7efe7e0 commit d423e3d
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 131 deletions.
19 changes: 6 additions & 13 deletions adet/modeling/blendmask/blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,13 @@ def __call__(self, bases, proposals, gt_instances):
if gt_instances is not None:
# training
# reshape attns
extras = proposals["extras"]
attns = proposals["top_feats"]
pos_inds = extras["pos_inds"]
dense_info = proposals["instances"]
attns = dense_info.top_feats
pos_inds = dense_info.pos_inds
if pos_inds.numel() == 0:
return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0}

gt_inds = extras["gt_inds"]
attns = cat(
[
# Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
x.permute(0, 2, 3, 1).reshape(-1, self.attn_len)
for x in attns
], dim=0,)
attns = attns[pos_inds]
gt_inds = dense_info.gt_inds

rois = self.pooler(bases, [x.gt_boxes for x in gt_instances])
rois = rois[gt_inds]
Expand All @@ -68,8 +61,8 @@ def __call__(self, bases, proposals, gt_instances):
N = gt_masks.size(0)
gt_masks = gt_masks.view(N, -1)

gt_ctr = extras["gt_ctr"]
loss_denorm = extras["loss_denorm"]
gt_ctr = dense_info.gt_ctrs
loss_denorm = proposals["loss_denorm"]
mask_losses = F.binary_cross_entropy_with_logits(
pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none")
mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum()
Expand Down
51 changes: 15 additions & 36 deletions adet/modeling/fcos/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY

from adet.layers import DFConv2d, NaiveGroupNorm
from adet.utils.comm import compute_locations
from .fcos_outputs import FCOSOutputs


Expand Down Expand Up @@ -74,7 +75,8 @@ def forward(self, images, features, gt_instances=None, top_module=None):
features = [features[f] for f in self.in_features]
locations = self.compute_locations(features)
logits_pred, reg_pred, ctrness_pred, top_feats, bbox_towers = self.fcos_head(
features, top_module, self.yield_proposal)
features, top_module, self.yield_proposal
)

results = {}
if self.yield_proposal:
Expand All @@ -83,60 +85,37 @@ def forward(self, images, features, gt_instances=None, top_module=None):
}

if self.training:
losses, extras = self.fcos_outputs.losses(
results, losses = self.fcos_outputs.losses(
logits_pred, reg_pred, ctrness_pred,
locations, gt_instances
locations, gt_instances, top_feats
)

if top_module is not None:
results["extras"] = extras
results["top_feats"] = top_feats
if self.yield_proposal:
with torch.no_grad():
results["proposals"] = self.fcos_outputs.predict_proposals(
top_feats, logits_pred, reg_pred,
ctrness_pred, locations, images.image_sizes
logits_pred, reg_pred, ctrness_pred,
locations, images.image_sizes, top_feats
)
return results, losses
else:
losses = {}
with torch.no_grad():
proposals = self.fcos_outputs.predict_proposals(
top_feats, logits_pred, reg_pred,
ctrness_pred, locations, images.image_sizes
)
if self.yield_proposal:
results["proposals"] = proposals
else:
results = proposals
results = self.fcos_outputs.predict_proposals(
logits_pred, reg_pred, ctrness_pred,
locations, images.image_sizes, top_feats
)

return results, losses
return results, {}

def compute_locations(self, features):
locations = []
for level, feature in enumerate(features):
h, w = feature.size()[-2:]
locations_per_level = self.compute_locations_per_level(
locations_per_level = compute_locations(
h, w, self.fpn_strides[level],
feature.device
)
locations.append(locations_per_level)
return locations

def compute_locations_per_level(self, h, w, stride, device):
    """Return the (x, y) coordinate of every cell in an ``h`` x ``w`` grid.

    Grid cells are spaced ``stride`` apart along both axes, and each
    coordinate is then shifted by ``stride // 2``.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        stride: spacing between neighbouring cells; also determines the
            ``stride // 2`` offset added to every coordinate.
        device: torch device on which the tensors are created.

    Returns:
        A float32 tensor of shape ``(h * w, 2)``. Each row is ``(x, y)``;
        rows are ordered row by row over the grid, with x varying fastest.
    """
    def axis_offsets(cells):
        # Start coordinate of each cell along one axis: 0, stride, 2*stride, ...
        return torch.arange(
            0, cells * stride, step=stride,
            dtype=torch.float32, device=device
        )

    # Passing the y axis first yields two (h, w) grids: the first holds the
    # row (y) offset of each cell, the second holds the column (x) offset.
    grid_y, grid_x = torch.meshgrid(axis_offsets(h), axis_offsets(w))

    # Flatten both grids and pair them up as (x, y) rows.
    xy = torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)
    return xy + stride // 2


class FCOSHead(nn.Module):
def __init__(self, cfg, input_shape: List[ShapeSpec]):
Expand Down Expand Up @@ -246,4 +225,4 @@ def forward(self, x, top_module=None, yield_bbox_towers=False):
bbox_reg.append(F.relu(reg))
if top_module is not None:
top_feats.append(top_module(bbox_tower))
return logits, bbox_reg, ctrness, top_feats, bbox_towers
return logits, bbox_reg, ctrness, top_feats, bbox_towers
Loading

0 comments on commit d423e3d

Please sign in to comment.