Skip to content

Commit

Permalink
✨ [Update] BoxMatcher matching criteria (#125)
Browse files Browse the repository at this point in the history
* ✨ [Update] BoxMatcher matching criteria

Added an additional validity criterion in get_valid_matrix, which masks out anchors for targets that are too large to predict with the given reg_max and stride values.

Implemented a new function, ensure_one_anchor, which assigns the single best-suited anchor to valid targets that have no valid anchors. It is a fallback mechanism that lets targets that are too small or too large still be learned and predicted, even if not perfectly.

Fixed the filter_duplicates function to use the topk-masked iou_mat for the selection; previously it sometimes matched invalid targets to anchors with duplicates.

Updated docstrings across the BoxMatcher functions to match the changes.

* 🔨 [Update] F.one_hot calls in BoxMatcher

to a more efficient solution, without using torch.nn.functional.

torch.nn.functional.one_hot always returns a long tensor, which consumes a lot of memory for tensors that are only used as masks.
  • Loading branch information
Adamusen authored Jan 3, 2025
1 parent da4f0bf commit f080104
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 34 deletions.
2 changes: 1 addition & 1 deletion yolo/tools/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80,
self.dfl = DFLoss(vec2box, reg_max)
self.iou = BoxLoss()

self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)

def separate_anchor(self, anchors):
"""
Expand Down
104 changes: 71 additions & 33 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, tensor
from torchmetrics.detection import MeanAveragePrecision
Expand Down Expand Up @@ -143,28 +142,35 @@ def generate_anchors(image_size: List[int], strides: List[int]):


class BoxMatcher:
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None:
    """
    Store the matcher settings.
    Args:
        cfg: Matcher configuration; every key it yields is copied onto the instance as an
            attribute (other methods read e.g. `self.iou`, `self.topk` and `self.factor`).
        class_num: Number of classes.
        vec2box: Object providing `anchor_grid` and `scaler`, read by `get_valid_matrix`.
        reg_max: Regression range limit used by `get_valid_matrix`.
    """
    self.class_num, self.vec2box, self.reg_max = class_num, vec2box, reg_max
    # Config entries are applied last, so a key named like one of the attributes above overrides it.
    for key in cfg:
        setattr(self, key, cfg[key])

def get_valid_matrix(self, target_bbox: Tensor) -> Tensor:
    """
    Get a boolean mask that indicates whether each anchor lies inside each target bounding box
    and is able to predict that box with the available reg_max value.
    Args:
        target_bbox [batch x targets x 4]: The bounding box of each target as
            (x_min, y_min, x_max, y_max).
    Returns:
        [batch x targets x anchors]: A boolean tensor that is True where the anchor lies inside
        (or on the border of) the target box and every scaled anchor-to-edge distance fits
        into the regression range.
    """
    x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3)
    anchors = self.vec2box.anchor_grid[None, None]  # add an axis at first, second dimension
    anchors_x, anchors_y = anchors.unbind(dim=3)
    # Distances from every anchor to the left, top, right and bottom edge of every target.
    targets_dist = torch.stack(
        (anchors_x - x_min, anchors_y - y_min, x_max - anchors_x, y_max - anchors_y), dim=-1
    )
    # Out-of-place division: an in-place `/=` raises when the boxes and anchors are integer
    # tensors, because the float result cannot be written back into an integer buffer.
    # NOTE(review): `scaler` is presumably the per-anchor stride - confirm against Vec2Box.
    targets_dist = targets_dist / self.vec2box.scaler[None, None, :, None]  # (1, 1, anchors, 1)
    min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1)
    # A negative distance means the anchor is outside the box on that side.
    target_on_anchor = min_reg_dist >= 0
    # NOTE(review): the 1.01 margin presumably keeps the largest distance strictly below the
    # last regression bin (reg_max - 1) - confirm against the DFL implementation.
    target_in_reg_max = max_reg_dist <= self.reg_max - 1.01
    return target_on_anchor & target_in_reg_max

def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -194,40 +200,68 @@ def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
"""
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)

def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
    """
    Keep, for each target, only its top-k most suitable valid anchors.
    Args:
        target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
        grid_mask [batch x targets x anchors]: The match validity for each target to anchors
        topk (int, optional): Number of top scores to retain per target. It is clamped to the
            number of anchors, so a small anchor set does not raise.
    Returns:
        topk_targets [batch x targets x anchors]: The suitability scores with everything
            except each target's top-k valid anchors set to zero.
        topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
    """
    # Invalid target-anchor pairs get a score of zero, so they can never end up in the mask below.
    masked_target_matrix = grid_mask * target_matrix
    # torch.topk raises if k exceeds the size of the reduced dimension.
    k = min(topk, masked_target_matrix.size(-1))
    values, indices = masked_target_matrix.topk(k, dim=-1)
    topk_targets = torch.zeros_like(target_matrix)
    topk_targets.scatter_(dim=-1, index=indices, src=values)
    # Zero-scored entries that were picked only to fill up k are dropped here.
    topk_mask = topk_targets > 0
    return topk_targets, topk_mask

def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: Tensor) -> Tensor:
    """
    Ensures each valid target gets at least one anchor matched based on the unmasked target matrix,
    which enables an otherwise invalid match. This enables too small or too large targets to be
    learned as well, even if they can't be predicted perfectly.
    Args:
        target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
        topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
    Returns:
        topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
    """
    values, indices = target_matrix.max(dim=-1)
    # One-hot mask of every target's single best anchor; a scalar scatter avoids building a
    # second full-size tensor just to provide the constant True.
    best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool)
    best_anchor_mask.scatter_(-1, indices[..., None], value=True)
    # Only targets with a positive best score are rescued; rows whose scores are all zero
    # stay unmatched.
    matched_anchor_num = torch.sum(topk_mask, dim=-1)
    target_without_anchor = (matched_anchor_num == 0) & (values > 0)
    return torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)

def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Filter the maximum suitability target index of each anchor based on IoU.
    Args:
        iou_mat [batch x targets x anchors]: The IoU for each targets-anchors
        topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
    Returns:
        unique_indices [batch x anchors x 1]: The index of the best targets for each anchors.
            Anchors without any target get index 0, so always combine it with valid_mask.
        valid_mask [batch x anchors]: Mask indicating the validity of each anchor
        topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
    """
    # Anchors claimed by more than one target; kept at [batch x 1 x anchors] and broadcast by
    # torch.where instead of being repeated along the target axis.
    duplicates = topk_mask.sum(1, keepdim=True) > 1
    # Only targets that actually selected the anchor may win it.
    masked_iou_mat = topk_mask * iou_mat
    best_indices = masked_iou_mat.argmax(1, keepdim=True)
    best_target_mask = torch.zeros_like(topk_mask, dtype=torch.bool)
    best_target_mask.scatter_(1, best_indices, value=True)
    topk_mask = torch.where(duplicates, best_target_mask, topk_mask)
    # NOTE(review): the uint8 cast is presumably there because argmax is not available for
    # bool tensors on every backend - confirm before removing it.
    unique_indices = topk_mask.to(torch.uint8).argmax(dim=1)
    return unique_indices[..., None], topk_mask.any(dim=1), topk_mask

def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
"""Matches each target to the most suitable anchor.
Expand Down Expand Up @@ -273,17 +307,21 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
# get cls matrix (cls prob with each gt class and each predict class)
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)

target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])

# choose topk
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk)

# match best anchor to valid targets without valid anchors
topk_mask = self.ensure_one_anchor(target_matrix, topk_mask)

# delete one anchor pred assign to mutliple gts
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)

align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
align_cls = F.one_hot(align_cls, self.class_num)
align_cls_indices = torch.gather(target_cls, 1, unique_indices)
align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num)
align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls)

# normalize class ditribution
iou_mat *= topk_mask
Expand All @@ -294,7 +332,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask.bool()
return anchor_matched_targets, valid_mask


class Vec2Box:
Expand All @@ -305,7 +343,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

anchor_grid, scaler = generate_anchors(image_size, self.strides)
Expand Down Expand Up @@ -358,7 +396,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

self.head_num = len(anchor_cfg.anchor)
Expand Down

0 comments on commit f080104

Please sign in to comment.