Skip to content

Commit

Permalink
🐛 [Fix] BoxMatcher for filter outsided bbox
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed Nov 21, 2024
1 parent c4cd90a commit 96da794
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,20 @@ def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Te
topk_masks = topk_targets > 0
return topk_targets, topk_masks

def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
"""
Filter the maximum suitability target index of each anchor.
Args:
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
Returns:
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
"""
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
topk_mask = torch.where(duplicates, max_idx, topk_mask)
topk_mask &= grid_mask
unique_indices = topk_mask.argmax(dim=1)
return unique_indices[..., None], topk_mask.sum(1), topk_mask

Expand Down Expand Up @@ -278,7 +279,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)

# delete one anchor pred assign to mutliple gts
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)

align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
Expand Down

0 comments on commit 96da794

Please sign in to comment.