From 3b2e9655631a2edd28bb94c640bd6a74c0bfad55 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 25 Mar 2022 23:01:32 +0800
Subject: [PATCH] [Fix] Fix reduction=mean in CELoss. (#7449)

* [Fix] Fix ignore in CELoss.

* add ut

* fix and add comments

* add avg_non_ignore option

* bce avg

* fix lint
---
 .../dense_heads/free_anchor_retina_head.py    |  3 +-
 mmdet/models/losses/cross_entropy_loss.py     | 61 ++++++++++++++++---
 mmdet/models/losses/utils.py                  |  6 +-
 tests/test_models/test_loss.py                | 31 +++++++---
 4 files changed, 81 insertions(+), 20 deletions(-)

diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py
index fa4238974da..3acd25ecba4 100644
--- a/mmdet/models/dense_heads/free_anchor_retina_head.py
+++ b/mmdet/models/dense_heads/free_anchor_retina_head.py
@@ -79,7 +79,8 @@ def loss(self,
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == self.prior_generator.num_levels
         device = cls_scores[0].device
-        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas, device=device)
+        anchor_list, _ = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
         anchors = [torch.cat(anchor) for anchor in anchor_list]
 
         # concatenate each level
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 5777aebd290..97f12e50375 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,7 +15,8 @@ def cross_entropy(pred,
                   reduction='mean',
                   avg_factor=None,
                   class_weight=None,
-                  ignore_index=-100):
+                  ignore_index=-100,
+                  avg_non_ignore=False):
     """Calculate the CrossEntropy loss.
 
     Args:
@@ -27,6 +30,8 @@ def cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): The flag decides to whether the loss is
+            only averaged over non-ignored targets. Default: False.
 
     Returns:
         torch.Tensor: The calculated loss
@@ -41,6 +46,12 @@ def cross_entropy(pred,
         reduction='none',
         ignore_index=ignore_index)
 
+    # average loss over non-ignored elements
+    # pytorch's official cross_entropy average loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
+
     # apply weights and do the reduction
     if weight is not None:
         weight = weight.float()
@@ -68,7 +79,7 @@ def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
         bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
         bin_label_weights *= valid_mask
 
-    return bin_labels, bin_label_weights
+    return bin_labels, bin_label_weights, valid_mask
 
 
 def binary_cross_entropy(pred,
@@ -77,7 +88,8 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=-100):
+                         ignore_index=-100,
+                         avg_non_ignore=False):
     """Calculate the binary CrossEntropy loss.
 
     Args:
@@ -95,19 +107,32 @@ def binary_cross_entropy(pred,
         class_weight (list[float], optional): The weight for each class.
         ignore_index (int | None): The label index to be ignored.
             If None, it will be set to default value. Default: -100.
+        avg_non_ignore (bool): The flag decides to whether the loss is
+            only averaged over non-ignored targets. Default: False.
 
     Returns:
         torch.Tensor: The calculated loss.
     """
     # The default value of ignore_index is the same as F.cross_entropy
     ignore_index = -100 if ignore_index is None else ignore_index
+
     if pred.dim() != label.dim():
-        label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
-                                              ignore_index)
+        label, weight, valid_mask = _expand_onehot_labels(
+            label, weight, pred.size(-1), ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight *= valid_mask
+        else:
+            weight = valid_mask
+
+    # average loss over non-ignored elements
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = valid_mask.sum().item()
 
     # weighted element-wise losses
-    if weight is not None:
-        weight = weight.float()
+    weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
@@ -123,7 +148,8 @@ def mask_cross_entropy(pred,
                        reduction='mean',
                        avg_factor=None,
                        class_weight=None,
-                       ignore_index=None):
+                       ignore_index=None,
+                       **kwargs):
     """Calculate the CrossEntropy loss for masks.
 
     Args:
@@ -177,7 +203,8 @@ def __init__(self,
                  reduction='mean',
                  class_weight=None,
                  ignore_index=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 avg_non_ignore=False):
         """CrossEntropyLoss.
 
         Args:
@@ -192,6 +219,8 @@ def __init__(self,
             ignore_index (int | None): The label index to be ignored.
                 Defaults to None.
             loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+            avg_non_ignore (bool): The flag decides to whether the loss is
+                only averaged over non-ignored targets. Default: False.
         """
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
@@ -201,6 +230,14 @@ def __init__(self,
         self.loss_weight = loss_weight
         self.class_weight = class_weight
         self.ignore_index = ignore_index
+        self.avg_non_ignore = avg_non_ignore
+        if ((ignore_index is not None) and not self.avg_non_ignore
+                and self.reduction == 'mean'):
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore the certain label and average loss over non-ignore '
+                'labels, which is the same with PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -209,6 +246,11 @@ def __init__(self,
         else:
             self.cls_criterion = cross_entropy
 
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
     def forward(self,
                 cls_score,
                 label,
@@ -251,5 +293,6 @@ def forward(self,
             reduction=reduction,
             avg_factor=avg_factor,
             ignore_index=ignore_index,
+            avg_non_ignore=self.avg_non_ignore,
             **kwargs)
         return loss_cls
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
index a7ae7e215bc..778237ebfd5 100644
--- a/mmdet/models/losses/utils.py
+++ b/mmdet/models/losses/utils.py
@@ -2,6 +2,7 @@
 import functools
 
 import mmcv
+import torch
 import torch.nn.functional as F
 
 
@@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     else:
         # if reduction is mean, then average the loss by avg_factor
         if reduction == 'mean':
-            loss = loss.sum() / avg_factor
+            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
+            # i.e., all labels of an image belong to ignore index.
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index 380bc3263f7..280f3f6ddec 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -135,10 +135,15 @@ def test_GHMR_loss(loss_class, input_shape):
 
 
 @pytest.mark.parametrize('use_sigmoid', [True, False])
-def test_loss_with_ignore_index(use_sigmoid):
+@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
     # Test cross_entropy loss
     loss_class = CrossEntropyLoss(
-        use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
+        use_sigmoid=use_sigmoid,
+        use_mask=False,
+        ignore_index=255,
+        avg_non_ignore=avg_non_ignore)
     pred = torch.rand((10, 5))
     target = torch.randint(0, 5, (10, ))
 
@@ -146,24 +151,32 @@ def test_loss_with_ignore_index(use_sigmoid):
     target[ignored_indices] = 255
 
     # Test loss forward with default ignore
-    loss_with_ignore = loss_class(pred, target, reduction_override='sum')
+    loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
     assert isinstance(loss_with_ignore, torch.Tensor)
 
     # Test loss forward with forward ignore
-    target[ignored_indices] = 250
+    target[ignored_indices] = 255
     loss_with_forward_ignore = loss_class(
-        pred, target, ignore_index=250, reduction_override='sum')
+        pred, target, ignore_index=255, reduction_override=reduction)
     assert isinstance(loss_with_forward_ignore, torch.Tensor)
 
     # Verify correctness
-    not_ignored_indices = (target != 250)
-    pred = pred[not_ignored_indices]
-    target = target[not_ignored_indices]
-    loss = loss_class(pred, target, reduction_override='sum')
+    if avg_non_ignore:
+        # manually remove the ignored elements
+        not_ignored_indices = (target != 255)
+        pred = pred[not_ignored_indices]
+        target = target[not_ignored_indices]
+    loss = loss_class(pred, target, reduction_override=reduction)
 
     assert torch.allclose(loss, loss_with_ignore)
     assert torch.allclose(loss, loss_with_forward_ignore)
 
+    # test ignore all target
+    pred = torch.rand((10, 5))
+    target = torch.ones((10, ), dtype=torch.long) * 255
+    loss = loss_class(pred, target, reduction_override=reduction)
+    assert loss == 0
+
 
 @pytest.mark.parametrize('naive_dice', [True, False])
 def test_dice_loss(naive_dice):