From 95625ce71ebf4e0d17f201fe64851aeaa7293e90 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Thu, 17 Mar 2022 21:57:37 +0800
Subject: [PATCH 1/6] [Fix] Fix ignore in CELoss.

---
 mmdet/models/losses/cross_entropy_loss.py | 17 +++++++++++++++--
 tests/test_models/test_loss.py            | 13 +++++++------
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 5777aebd290..0b0bf1dda38 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -41,6 +41,9 @@ def cross_entropy(pred,
         reduction='none',
         ignore_index=ignore_index)
 
+    if reduction == 'mean' and avg_factor is None:
+        avg_factor = max(1, (label != ignore_index).sum())
+
     # apply weights and do the reduction
     if weight is not None:
         weight = weight.float()
@@ -101,13 +104,23 @@ def binary_cross_entropy(pred,
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index
+    if reduction == 'mean' and avg_factor is None:
+        avg_factor = max(1,
+                         label.numel() - (label == ignore_index).sum().item())
+
     if pred.dim() != label.dim():
         label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
                                               ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight *= valid_mask
+        else:
+            weight = valid_mask
 
     # weighted element-wise losses
-    if weight is not None:
-        weight = weight.float()
+    weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index 380bc3263f7..f566ccce8b3 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -135,7 +135,8 @@ def test_GHMR_loss(loss_class, input_shape):
 
 
 @pytest.mark.parametrize('use_sigmoid', [True, False])
-def test_loss_with_ignore_index(use_sigmoid):
+@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
+def test_loss_with_ignore_index(use_sigmoid, reduction):
     # Test cross_entropy loss
     loss_class = CrossEntropyLoss(
         use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
@@ -146,20 +147,20 @@ def test_loss_with_ignore_index(use_sigmoid):
     target[ignored_indices] = 255
 
     # Test loss forward with default ignore
-    loss_with_ignore = loss_class(pred, target, reduction_override='sum')
+    loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
     assert isinstance(loss_with_ignore, torch.Tensor)
 
     # Test loss forward with forward ignore
-    target[ignored_indices] = 250
+    target[ignored_indices] = 255
     loss_with_forward_ignore = loss_class(
-        pred, target, ignore_index=250, reduction_override='sum')
+        pred, target, ignore_index=255, reduction_override=reduction)
     assert isinstance(loss_with_forward_ignore, torch.Tensor)
 
     # Verify correctness
-    not_ignored_indices = (target != 250)
+    not_ignored_indices = (target != 255)
     pred = pred[not_ignored_indices]
     target = target[not_ignored_indices]
 
-    loss = loss_class(pred, target, reduction_override='sum')
+    loss = loss_class(pred, target, reduction_override=reduction)
     assert torch.allclose(loss, loss_with_ignore)
     assert torch.allclose(loss, loss_with_forward_ignore)
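A standalone sketch of the masking idea in the patch above (plain PyTorch rather than the mmdet code itself; the logits, labels and ignore value here are made up for illustration). Elements whose label is negative or equal to ignore_index get zero weight, so they contribute nothing to the binary cross-entropy and the mean runs over valid elements only:

    import torch
    import torch.nn.functional as F

    pred = torch.randn(6)  # binary logits, shape (N, )
    label = torch.tensor([0., 1., 1., 0., 255., 255.])
    ignore_index = 255

    # mask out ignored elements, as the patch does via valid_mask
    valid_mask = (label >= 0) & (label != ignore_index)
    # the clamp only keeps targets in [0, 1]; ignored entries are zeroed anyway
    raw = F.binary_cross_entropy_with_logits(
        pred, label.clamp(max=1), reduction='none')
    loss = (raw * valid_mask).sum() / max(1, valid_mask.sum())
    print(loss)  # averaged over the 4 non-ignored elements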
From cc8074337124f8bd97261291bc56c30a432fdbcd Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 18 Mar 2022 09:32:33 +0800
Subject: [PATCH 2/6] add ut

---
 tests/test_models/test_loss.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index f566ccce8b3..e17d22f94b1 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -165,6 +165,12 @@ def test_loss_with_ignore_index(use_sigmoid, reduction):
     assert torch.allclose(loss, loss_with_ignore)
     assert torch.allclose(loss, loss_with_forward_ignore)
 
+    # test ignore all target
+    pred = torch.rand((10, 5))
+    target = torch.ones((10, ), dtype=torch.long) * 255
+    loss = loss_class(pred, target, reduction_override=reduction)
+    assert loss == 0
+
 
 @pytest.mark.parametrize('naive_dice', [True, False])
 def test_dice_loss(naive_dice):

From ad9534eed1b2c04b341d63dec67763722fa895f8 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 18 Mar 2022 14:19:15 +0800
Subject: [PATCH 3/6] fix and add comments

---
 mmdet/models/losses/cross_entropy_loss.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 0b0bf1dda38..f71391cf52d 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -41,6 +41,7 @@ def cross_entropy(pred,
         reduction='none',
         ignore_index=ignore_index)
 
+    # loss is averaged over non-ignored targets
     if reduction == 'mean' and avg_factor is None:
         avg_factor = max(1, (label != ignore_index).sum())
 
@@ -104,9 +105,10 @@ def binary_cross_entropy(pred,
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index
+
+    # loss is averaged over non-ignored targets
     if reduction == 'mean' and avg_factor is None:
-        avg_factor = max(1,
-                         label.numel() - (label == ignore_index).sum().item())
+        avg_factor = max(1, ((label >= 0) & (label != ignore_index)).sum())
 
     if pred.dim() != label.dim():
         label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
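The denominator rule from the patch above can be checked in isolation (a minimal sketch with made-up labels, independent of mmdet). Only labels that are non-negative and different from ignore_index are counted, and max(1, ...) keeps an all-ignored batch, the case the new unit test covers, from dividing by zero:

    import torch

    ignore_index = 255
    label = torch.full((10, ), 255, dtype=torch.long)  # every target ignored
    valid = (label >= 0) & (label != ignore_index)
    avg_factor = max(1, valid.sum())
    print(int(avg_factor))  # 1, so a masked loss summing to 0 stays 0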
From 5c1bef5f2113a074e8e29a9f0a4e48a7f64c9f2f Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Thu, 24 Mar 2022 19:56:54 +0800
Subject: [PATCH 4/6] add avg_non_ignore option

---
 mmdet/models/losses/cross_entropy_loss.py | 49 ++++++++++++++++++-----
 mmdet/models/losses/utils.py              |  6 ++-
 tests/test_models/test_loss.py            | 16 +++++---
 3 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index f71391cf52d..2bcb597f5aa 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,7 +15,8 @@ def cross_entropy(pred,
                   reduction='mean',
                   avg_factor=None,
                   class_weight=None,
-                  ignore_index=-100):
+                  ignore_index=-100,
+                  avg_non_ignore=False):
     """Calculate the CrossEntropy loss.
 
     Args:
@@ -27,6 +30,8 @@ def cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): The flag decides whether the loss is
+            only averaged over non-ignored targets. Default: False.
 
     Returns:
         torch.Tensor: The calculated loss
@@ -41,9 +46,11 @@ def cross_entropy(pred,
         reduction='none',
         ignore_index=ignore_index)
 
-    # loss is averaged over non-ignored targets
-    if reduction == 'mean' and avg_factor is None:
-        avg_factor = max(1, (label != ignore_index).sum())
+    # average loss over non-ignored elements
+    # pytorch's official cross_entropy averages loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
 
     # apply weights and do the reduction
     if weight is not None:
@@ -81,7 +88,8 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=-100):
+                         ignore_index=-100,
+                         avg_non_ignore=False):
     """Calculate the binary CrossEntropy loss.
 
     Args:
@@ -99,6 +107,8 @@ def binary_cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): The flag decides whether the loss is
+            only averaged over non-ignored targets. Default: False.
 
     Returns:
         torch.Tensor: The calculated loss.
@@ -106,9 +116,10 @@ def binary_cross_entropy(pred,
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index
 
-    # loss is averaged over non-ignored targets
-    if reduction == 'mean' and avg_factor is None:
-        avg_factor = max(1, ((label >= 0) & (label != ignore_index)).sum())
+    # average loss over non-ignored elements
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - ((label >= 0) &
+                                      (label == ignore_index)).sum().item()
 
     if pred.dim() != label.dim():
         label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
@@ -138,7 +149,8 @@ def mask_cross_entropy(pred,
                        reduction='mean',
                        avg_factor=None,
                        class_weight=None,
-                       ignore_index=None):
+                       ignore_index=None,
+                       **kwargs):
     """Calculate the CrossEntropy loss for masks.
 
     Args:
@@ -192,7 +204,8 @@ def __init__(self,
                  reduction='mean',
                  class_weight=None,
                  ignore_index=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 avg_non_ignore=False):
         """CrossEntropyLoss.
 
         Args:
             use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                 or softmax. Defaults to False.
             use_mask (bool, optional): Whether to use mask cross entropy loss.
                 Defaults to False.
             reduction (str, optional): The method used to reduce the loss.
                 Options are "none", "mean" and "sum".
             class_weight (list[float], optional): Weight of each class.
                 Defaults to None.
             ignore_index (int | None): The label index to be ignored.
                 Defaults to None.
             loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+            avg_non_ignore (bool): The flag decides whether the loss is
+                only averaged over non-ignored targets. Default: False.
         """
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
@@ -216,6 +231,14 @@ def __init__(self,
         self.loss_weight = loss_weight
         self.class_weight = class_weight
         self.ignore_index = ignore_index
+        self.avg_non_ignore = avg_non_ignore
+        if ((ignore_index is not None) and not self.avg_non_ignore
+                and self.reduction == 'mean'):
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore certain labels and average the loss over '
+                'non-ignored labels, which is the same as PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -224,6 +247,11 @@ def __init__(self,
         else:
             self.cls_criterion = cross_entropy
 
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
     def forward(self,
                 cls_score,
                 label,
@@ -266,5 +294,6 @@ def forward(self,
             reduction=reduction,
             avg_factor=avg_factor,
             ignore_index=ignore_index,
+            avg_non_ignore=self.avg_non_ignore,
             **kwargs)
         return loss_cls
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
index a7ae7e215bc..778237ebfd5 100644
--- a/mmdet/models/losses/utils.py
+++ b/mmdet/models/losses/utils.py
@@ -2,6 +2,7 @@
 import functools
 
 import mmcv
+import torch
 import torch.nn.functional as F
 
 
@@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     else:
         # if reduction is mean, then average the loss by avg_factor
         if reduction == 'mean':
-            loss = loss.sum() / avg_factor
+            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
+            # i.e., all labels of an image belong to ignore index.
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index e17d22f94b1..280f3f6ddec 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -136,10 +136,14 @@ def test_GHMR_loss(loss_class, input_shape):
 
 @pytest.mark.parametrize('use_sigmoid', [True, False])
 @pytest.mark.parametrize('reduction', ['sum', 'mean', None])
-def test_loss_with_ignore_index(use_sigmoid, reduction):
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
     # Test cross_entropy loss
     loss_class = CrossEntropyLoss(
-        use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
+        use_sigmoid=use_sigmoid,
+        use_mask=False,
+        ignore_index=255,
+        avg_non_ignore=avg_non_ignore)
 
     pred = torch.rand((10, 5))
     target = torch.randint(0, 5, (10, ))
@@ -157,9 +161,11 @@ def test_loss_with_ignore_index(use_sigmoid, reduction):
     assert isinstance(loss_with_forward_ignore, torch.Tensor)
 
     # Verify correctness
-    not_ignored_indices = (target != 255)
-    pred = pred[not_ignored_indices]
-    target = target[not_ignored_indices]
+    if avg_non_ignore:
+        # manually remove the ignored elements
+        not_ignored_indices = (target != 255)
+        pred = pred[not_ignored_indices]
+        target = target[not_ignored_indices]
 
     loss = loss_class(pred, target, reduction_override=reduction)
     assert torch.allclose(loss, loss_with_ignore)
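A usage sketch of the new option (this assumes the series is applied, so that CrossEntropyLoss accepts avg_non_ignore; the shapes and ignore value are arbitrary). With reduction 'mean', the default behaviour divides the summed loss by all 10 elements, while avg_non_ignore=True divides by the 8 non-ignored ones, matching F.cross_entropy; constructing the default variant with an ignore_index also triggers the warning added above:

    import torch
    from mmdet.models.losses import CrossEntropyLoss

    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    target[:2] = 255  # mark two targets as ignored

    loss_all = CrossEntropyLoss(ignore_index=255)  # warns, averages over 10
    loss_valid = CrossEntropyLoss(ignore_index=255, avg_non_ignore=True)
    # the two results differ only in the denominator (10 vs. 8)
    print(loss_all(pred, target), loss_valid(pred, target))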
""" super(CrossEntropyLoss, self).__init__() assert (use_sigmoid is False) or (use_mask is False) @@ -216,6 +231,14 @@ def __init__(self, self.loss_weight = loss_weight self.class_weight = class_weight self.ignore_index = ignore_index + self.avg_non_ignore = avg_non_ignore + if ((ignore_index is not None) and not self.avg_non_ignore + and self.reduction == 'mean'): + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') if self.use_sigmoid: self.cls_criterion = binary_cross_entropy @@ -224,6 +247,11 @@ def __init__(self, else: self.cls_criterion = cross_entropy + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + def forward(self, cls_score, label, @@ -266,5 +294,6 @@ def forward(self, reduction=reduction, avg_factor=avg_factor, ignore_index=ignore_index, + avg_non_ignore=self.avg_non_ignore, **kwargs) return loss_cls diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py index a7ae7e215bc..778237ebfd5 100644 --- a/mmdet/models/losses/utils.py +++ b/mmdet/models/losses/utils.py @@ -2,6 +2,7 @@ import functools import mmcv +import torch import torch.nn.functional as F @@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): else: # if reduction is mean, then average the loss by avg_factor if reduction == 'mean': - loss = loss.sum() / avg_factor + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) # if reduction is 'none', then do nothing, otherwise raise an error elif reduction != 'none': raise ValueError('avg_factor can not be used with reduction="sum"') diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py index e17d22f94b1..280f3f6ddec 100644 --- a/tests/test_models/test_loss.py +++ b/tests/test_models/test_loss.py @@ -136,10 +136,14 @@ def test_GHMR_loss(loss_class, input_shape): @pytest.mark.parametrize('use_sigmoid', [True, False]) @pytest.mark.parametrize('reduction', ['sum', 'mean', None]) -def test_loss_with_ignore_index(use_sigmoid, reduction): +@pytest.mark.parametrize('avg_non_ignore', [True, False]) +def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore): # Test cross_entropy loss loss_class = CrossEntropyLoss( - use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255) + use_sigmoid=use_sigmoid, + use_mask=False, + ignore_index=255, + avg_non_ignore=avg_non_ignore) pred = torch.rand((10, 5)) target = torch.randint(0, 5, (10, )) @@ -157,9 +161,11 @@ def test_loss_with_ignore_index(use_sigmoid, reduction): assert isinstance(loss_with_forward_ignore, torch.Tensor) # Verify correctness - not_ignored_indices = (target != 255) - pred = pred[not_ignored_indices] - target = target[not_ignored_indices] + if avg_non_ignore: + # manually remove the ignored elements + not_ignored_indices = (target != 255) + pred = pred[not_ignored_indices] + target = target[not_ignored_indices] loss = loss_class(pred, target, reduction_override=reduction) assert torch.allclose(loss, loss_with_ignore) From 2847ff30599dccd9ccb9b4c9c2c0051eea5c8b64 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 25 Mar 2022 09:53:20 +0800 Subject: [PATCH 5/6] bce avg --- mmdet/models/losses/cross_entropy_loss.py | 15 +++++++-------- 1 file changed, 
From 4c2e19d831e8cf00f77df89bb01012cf0283075b Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 25 Mar 2022 22:56:18 +0800
Subject: [PATCH 6/6] fix lint

---
 mmdet/models/dense_heads/free_anchor_retina_head.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py
index fa4238974da..3acd25ecba4 100644
--- a/mmdet/models/dense_heads/free_anchor_retina_head.py
+++ b/mmdet/models/dense_heads/free_anchor_retina_head.py
@@ -79,7 +79,8 @@ def loss(self,
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == self.prior_generator.num_levels
         device = cls_scores[0].device
-        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas, device=device)
+        anchor_list, _ = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
         anchors = [torch.cat(anchor) for anchor in anchor_list]
         # concatenate each level
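End to end, the all-ignored case from the unit test now degrades gracefully (a sketch assuming the full series is applied). avg_factor becomes 0, and the eps added to weight_reduce_loss in utils.py turns the 0 / 0 average into 0 instead of a NaN or a ZeroDivisionError:

    import torch
    from mmdet.models.losses import CrossEntropyLoss

    loss = CrossEntropyLoss(ignore_index=255, avg_non_ignore=True)
    pred = torch.rand((10, 5))
    target = torch.full((10, ), 255, dtype=torch.long)
    print(loss(pred, target))  # tensor(0.)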