From f0801042a87077d8820b2ae9d8d961a18bd140ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kun=C3=A1k?= <38215643+Adamusen@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:29:44 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20[Update]=20BoxMatcher=20matching=20?= =?UTF-8?q?criteria=20(#125)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ [Update] BoxMatcher matching criteria Added an additional validity criterium in get_valid_matrix, which masks out anchors from targets, that are too large to predict with the given reg_max and stride values. Implemented a new function: ensure_one_anchor, which adds a single best suited anchor for valid targets without valid anchors. It is a fallback mechanism, which enables too small or too large targets to be trained to be predicted as well, even if not perfectly. Fixed the filter_duplicate function to use the topk_masked iou_mat for the selection, which previously sometimes matched invalid targets to anchors with duplicates. Updated docsstrings across the BoxMatcher functions to match the changes. * 🔨 [Update] F.one_hot calls in BoxMatcher to a more efficient solution, without using torch.nn.functional. torch.nn.functional.one_hot always returns a long tensor, consuming a lot of memory for tensors, which are only used as masks. --- yolo/tools/loss_functions.py | 2 +- yolo/utils/bounding_box_utils.py | 104 +++++++++++++++++++++---------- 2 files changed, 72 insertions(+), 34 deletions(-) diff --git a/yolo/tools/loss_functions.py b/yolo/tools/loss_functions.py index 54fd7cf..79fe1cf 100644 --- a/yolo/tools/loss_functions.py +++ b/yolo/tools/loss_functions.py @@ -75,7 +75,7 @@ def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80, self.dfl = DFLoss(vec2box, reg_max) self.iou = BoxLoss() - self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid) + self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max) def separate_anchor(self, anchors): """ diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 1d25bcb..63896e3 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -import torch.nn.functional as F from einops import rearrange from torch import Tensor, tensor from torchmetrics.detection import MeanAveragePrecision @@ -143,28 +142,35 @@ def generate_anchors(image_size: List[int], strides: List[int]): class BoxMatcher: - def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None: + def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None: self.class_num = class_num - self.anchors = anchors + self.vec2box = vec2box + self.reg_max = reg_max for attr_name in cfg: setattr(self, attr_name, cfg[attr_name]) def get_valid_matrix(self, target_bbox: Tensor): """ - Get a boolean mask that indicates whether each target bounding box overlaps with each anchor. + Get a boolean mask that indicates whether each target bounding box overlaps with each anchor + and is able to correctly predict it with the available reg_max value. Args: - target_bbox [batch x targets x 4]: The bounding box of each targets. + target_bbox [batch x targets x 4]: The bounding box of each target. Returns: - [batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors. + [batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps + with the anchors, and the anchor is able to predict the target. """ - Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3) - anchors = self.anchors[None, None] # add a axis at first, second dimension + x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3) + anchors = self.vec2box.anchor_grid[None, None] # add a axis at first, second dimension anchors_x, anchors_y = anchors.unbind(dim=3) - target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax) - target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax) - target_on_anchor = target_in_x & target_in_y - return target_on_anchor + x_min_dist, x_max_dist = anchors_x - x_min, x_max - anchors_x + y_min_dist, y_max_dist = anchors_y - y_min, y_max - anchors_y + targets_dist = torch.stack((x_min_dist, y_min_dist, x_max_dist, y_max_dist), dim=-1) + targets_dist /= self.vec2box.scaler[None, None, :, None] # (1, 1, anchors, 1) + min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1) + target_on_anchor = min_reg_dist >= 0 + target_in_reg_max = max_reg_dist <= self.reg_max - 1.01 + return target_on_anchor & target_in_reg_max def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor: """ @@ -194,40 +200,68 @@ def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor: """ return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1) - def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]: + def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]: """ Filter the top-k suitability of targets for each anchor. Args: target_matrix [batch x targets x anchors]: The suitability for each targets-anchors + grid_mask [batch x targets x anchors]: The match validity for each target to anchors topk (int, optional): Number of top scores to retain per anchor. Returns: topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor - topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions. + topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions. """ - values, indices = target_matrix.topk(topk, dim=-1) + masked_target_matrix = grid_mask * target_matrix + values, indices = masked_target_matrix.topk(topk, dim=-1) topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device) topk_targets.scatter_(dim=-1, index=indices, src=values) - topk_masks = topk_targets > 0 - return topk_targets, topk_masks + topk_mask = topk_targets > 0 + return topk_targets, topk_mask - def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor): + def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor: """ - Filter the maximum suitability target index of each anchor. + Ensures each valid target gets at least one anchor matched based on the unmasked target matrix, + which enables an otherwise invalid match. This enables too small or too large targets to be + learned as well, even if they can't be predicted perfectly. Args: - iou_mat [batch x targets x anchors]: The suitability for each targets-anchors + target_matrix [batch x targets x anchors]: The suitability for each targets-anchors + topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions. + + Returns: + topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions. + """ + values, indices = target_matrix.max(dim=-1) + best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool) + best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask) + matched_anchor_num = torch.sum(topk_mask, dim=-1) + target_without_anchor = (matched_anchor_num == 0) & (values > 0) + topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask) + return topk_mask + + def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor): + """ + Filter the maximum suitability target index of each anchor based on IoU. + + Args: + iou_mat [batch x targets x anchors]: The IoU for each targets-anchors + topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions. Returns: unique_indices [batch x anchors x 1]: The index of the best targets for each anchors + valid_mask [batch x anchors]: Mask indicating the validity of each anchor + topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions. """ duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1]) - max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1) - topk_mask = torch.where(duplicates, max_idx, topk_mask) - topk_mask &= grid_mask - unique_indices = topk_mask.argmax(dim=1) - return unique_indices[..., None], topk_mask.sum(1), topk_mask + masked_iou_mat = topk_mask * iou_mat + best_indices = masked_iou_mat.argmax(1)[:, None, :] + best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool) + best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask) + topk_mask = torch.where(duplicates, best_target_mask, topk_mask) + unique_indices = topk_mask.to(torch.uint8).argmax(dim=1) + return unique_indices[..., None], topk_mask.any(dim=1), topk_mask def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: """Matches each target to the most suitable anchor. @@ -273,17 +307,21 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens # get cls matrix (cls prob with each gt class and each predict class) cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls) - target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"]) + target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"]) # choose topk - topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk) + topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk) + + # match best anchor to valid targets without valid anchors + topk_mask = self.ensure_one_anchor(target_matrix, topk_mask) # delete one anchor pred assign to mutliple gts - unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask) + unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask) align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4)) - align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1) - align_cls = F.one_hot(align_cls, self.class_num) + align_cls_indices = torch.gather(target_cls, 1, unique_indices) + align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num) + align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls) # normalize class ditribution iou_mat *= topk_mask @@ -294,7 +332,7 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices) align_cls = align_cls * normalize_term * valid_mask[:, :, None] anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1) - return anchor_matched_targets, valid_mask.bool() + return anchor_matched_targets, valid_mask class Vec2Box: @@ -305,7 +343,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}") self.strides = anchor_cfg.strides else: - logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size") + logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size") self.strides = self.create_auto_anchor(model, image_size) anchor_grid, scaler = generate_anchors(image_size, self.strides) @@ -358,7 +396,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}") self.strides = anchor_cfg.strides else: - logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size") + logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size") self.strides = self.create_auto_anchor(model, image_size) self.head_num = len(anchor_cfg.anchor)