diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py
index fa4238974da..3acd25ecba4 100644
--- a/mmdet/models/dense_heads/free_anchor_retina_head.py
+++ b/mmdet/models/dense_heads/free_anchor_retina_head.py
@@ -79,7 +79,8 @@ def loss(self,
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == self.prior_generator.num_levels

         device = cls_scores[0].device
-        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas, device=device)
+        anchor_list, _ = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
         anchors = [torch.cat(anchor) for anchor in anchor_list]
         # concatenate each level
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 5777aebd290..97f12e50375 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,7 +15,8 @@ def cross_entropy(pred,
                   reduction='mean',
                   avg_factor=None,
                   class_weight=None,
-                  ignore_index=-100):
+                  ignore_index=-100,
+                  avg_non_ignore=False):
     """Calculate the CrossEntropy loss.

     Args:
@@ -27,6 +30,8 @@ def cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): If True, the loss is averaged only over
+            non-ignored targets. Default: False.

     Returns:
         torch.Tensor: The calculated loss
@@ -41,6 +46,12 @@ def cross_entropy(pred,
         reduction='none',
         ignore_index=ignore_index)

+    # average loss over non-ignored elements
+    # PyTorch's official cross_entropy averages the loss over non-ignored elements,
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
+
     # apply weights and do the reduction
     if weight is not None:
         weight = weight.float()
@@ -68,7 +79,7 @@ def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
     bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
     bin_label_weights *= valid_mask

-    return bin_labels, bin_label_weights
+    return bin_labels, bin_label_weights, valid_mask


 def binary_cross_entropy(pred,
@@ -77,7 +88,8 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=-100):
+                         ignore_index=-100,
+                         avg_non_ignore=False):
     """Calculate the binary CrossEntropy loss.

     Args:
@@ -95,19 +107,32 @@ def binary_cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): If True, the loss is averaged only over
+            non-ignored targets. Default: False.

     Returns:
         torch.Tensor: The calculated loss.
""" # The default value of ignore_index is the same as F.cross_entropy ignore_index = -100 if ignore_index is None else ignore_index + if pred.dim() != label.dim(): - label, weight = _expand_onehot_labels(label, weight, pred.size(-1), - ignore_index) + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.size(-1), ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight *= valid_mask + else: + weight = valid_mask + + # average loss over non-ignored elements + if (avg_factor is None) and avg_non_ignore and reduction == 'mean': + avg_factor = valid_mask.sum().item() # weighted element-wise losses - if weight is not None: - weight = weight.float() + weight = weight.float() loss = F.binary_cross_entropy_with_logits( pred, label.float(), pos_weight=class_weight, reduction='none') # do the reduction for the weighted loss @@ -123,7 +148,8 @@ def mask_cross_entropy(pred, reduction='mean', avg_factor=None, class_weight=None, - ignore_index=None): + ignore_index=None, + **kwargs): """Calculate the CrossEntropy loss for masks. Args: @@ -177,7 +203,8 @@ def __init__(self, reduction='mean', class_weight=None, ignore_index=None, - loss_weight=1.0): + loss_weight=1.0, + avg_non_ignore=False): """CrossEntropyLoss. Args: @@ -192,6 +219,8 @@ def __init__(self, ignore_index (int | None): The label index to be ignored. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. """ super(CrossEntropyLoss, self).__init__() assert (use_sigmoid is False) or (use_mask is False) @@ -201,6 +230,14 @@ def __init__(self, self.loss_weight = loss_weight self.class_weight = class_weight self.ignore_index = ignore_index + self.avg_non_ignore = avg_non_ignore + if ((ignore_index is not None) and not self.avg_non_ignore + and self.reduction == 'mean'): + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') if self.use_sigmoid: self.cls_criterion = binary_cross_entropy @@ -209,6 +246,11 @@ def __init__(self, else: self.cls_criterion = cross_entropy + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + def forward(self, cls_score, label, @@ -251,5 +293,6 @@ def forward(self, reduction=reduction, avg_factor=avg_factor, ignore_index=ignore_index, + avg_non_ignore=self.avg_non_ignore, **kwargs) return loss_cls diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py index a7ae7e215bc..778237ebfd5 100644 --- a/mmdet/models/losses/utils.py +++ b/mmdet/models/losses/utils.py @@ -2,6 +2,7 @@ import functools import mmcv +import torch import torch.nn.functional as F @@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): else: # if reduction is mean, then average the loss by avg_factor if reduction == 'mean': - loss = loss.sum() / avg_factor + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. 
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index 380bc3263f7..280f3f6ddec 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -135,10 +135,15 @@ def test_GHMR_loss(loss_class, input_shape):


 @pytest.mark.parametrize('use_sigmoid', [True, False])
-def test_loss_with_ignore_index(use_sigmoid):
+@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
     # Test cross_entropy loss
     loss_class = CrossEntropyLoss(
-        use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
+        use_sigmoid=use_sigmoid,
+        use_mask=False,
+        ignore_index=255,
+        avg_non_ignore=avg_non_ignore)
     pred = torch.rand((10, 5))
     target = torch.randint(0, 5, (10, ))

@@ -146,24 +151,32 @@ def test_loss_with_ignore_index(use_sigmoid):
     target[ignored_indices] = 255

     # Test loss forward with default ignore
-    loss_with_ignore = loss_class(pred, target, reduction_override='sum')
+    loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
     assert isinstance(loss_with_ignore, torch.Tensor)

     # Test loss forward with forward ignore
-    target[ignored_indices] = 250
+    target[ignored_indices] = 255
     loss_with_forward_ignore = loss_class(
-        pred, target, ignore_index=250, reduction_override='sum')
+        pred, target, ignore_index=255, reduction_override=reduction)
     assert isinstance(loss_with_forward_ignore, torch.Tensor)

     # Verify correctness
-    not_ignored_indices = (target != 250)
-    pred = pred[not_ignored_indices]
-    target = target[not_ignored_indices]
-    loss = loss_class(pred, target, reduction_override='sum')
+    if avg_non_ignore:
+        # manually remove the ignored elements
+        not_ignored_indices = (target != 255)
+        pred = pred[not_ignored_indices]
+        target = target[not_ignored_indices]
+    loss = loss_class(pred, target, reduction_override=reduction)
     assert torch.allclose(loss, loss_with_ignore)
     assert torch.allclose(loss, loss_with_forward_ignore)

+    # test ignore all target
+    pred = torch.rand((10, 5))
+    target = torch.ones((10, ), dtype=torch.long) * 255
+    loss = loss_class(pred, target, reduction_override=reduction)
+    assert loss == 0
+

 @pytest.mark.parametrize('naive_dice', [True, False])
 def test_dice_loss(naive_dice):
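
For context, a minimal standalone sketch (not part of the patch) of the averaging semantics that ``avg_non_ignore`` toggles; the tensor sizes and the 255 ignore index are made-up toy values:

    import torch
    import torch.nn.functional as F

    ignore_index = 255
    pred = torch.rand(4, 3)                 # 4 samples, 3 classes (toy values)
    label = torch.tensor([0, 2, 255, 255])  # two valid targets, two ignored

    # Element-wise losses; ignored positions contribute 0.
    loss = F.cross_entropy(
        pred, label, reduction='none', ignore_index=ignore_index)

    # avg_non_ignore=False (previous mmdet behavior): divide by all labels.
    loss_all = loss.sum() / label.numel()

    # avg_non_ignore=True: divide by the non-ignored labels only, matching
    # F.cross_entropy(reduction='mean'). weight_reduce_loss now also adds
    # torch.finfo(torch.float32).eps to the divisor, so an image whose labels
    # are all ignored yields 0 instead of NaN.
    num_valid = label.numel() - (label == ignore_index).sum().item()
    loss_valid = loss.sum() / num_valid

    assert torch.allclose(
        loss_valid, F.cross_entropy(pred, label, ignore_index=ignore_index))
    assert loss_all < loss_valid  # the old mean is diluted by ignored entries

In a config the new behavior would be enabled through the loss constructor, e.g. something like ``loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)`` (an illustrative snippet, not taken from this patch).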