Skip to content

Commit

Permalink
[Fix] Fix reduction=mean in CELoss. (#7449)
Browse files Browse the repository at this point in the history
* [Fix] Fix ignore in CELoss.

* add ut

* fix and add comments

* add avg_non_ignore option

* bce avg

* fix lint
  • Loading branch information
RangiLyu authored Mar 25, 2022
1 parent 3f0f2a0 commit 3b2e965
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 20 deletions.
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/free_anchor_retina_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def loss(self,
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, _ = self.get_anchors(featmap_sizes, img_metas, device=device)
anchor_list, _ = self.get_anchors(
featmap_sizes, img_metas, device=device)
anchors = [torch.cat(anchor) for anchor in anchor_list]

# concatenate each level
Expand Down
61 changes: 52 additions & 9 deletions mmdet/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -13,7 +15,8 @@ def cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100):
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.
Args:
Expand All @@ -27,6 +30,8 @@ def cross_entropy(pred,
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss
Expand All @@ -41,6 +46,12 @@ def cross_entropy(pred,
reduction='none',
ignore_index=ignore_index)

# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()

# apply weights and do the reduction
if weight is not None:
weight = weight.float()
Expand Down Expand Up @@ -68,7 +79,7 @@ def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask

return bin_labels, bin_label_weights
return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
Expand All @@ -77,7 +88,8 @@ def binary_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100):
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.
Args:
Expand All @@ -95,19 +107,32 @@ def binary_cross_entropy(pred,
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index

if pred.dim() != label.dim():
label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
ignore_index)
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
else:
weight = valid_mask

# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()

# weighted element-wise losses
if weight is not None:
weight = weight.float()
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
Expand All @@ -123,7 +148,8 @@ def mask_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None):
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
Expand Down Expand Up @@ -177,7 +203,8 @@ def __init__(self,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0):
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.
Args:
Expand All @@ -192,6 +219,8 @@ def __init__(self,
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
Expand All @@ -201,6 +230,14 @@ def __init__(self,
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand All @@ -209,6 +246,11 @@ def __init__(self,
else:
self.cls_criterion = cross_entropy

def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s

def forward(self,
cls_score,
label,
Expand Down Expand Up @@ -251,5 +293,6 @@ def forward(self,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
6 changes: 5 additions & 1 deletion mmdet/models/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools

import mmcv
import torch
import torch.nn.functional as F


Expand Down Expand Up @@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
Expand Down
31 changes: 22 additions & 9 deletions tests/test_models/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,35 +135,48 @@ def test_GHMR_loss(loss_class, input_shape):


@pytest.mark.parametrize('use_sigmoid', [True, False])
def test_loss_with_ignore_index(use_sigmoid):
@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
@pytest.mark.parametrize('avg_non_ignore', [True, False])
def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
# Test cross_entropy loss
loss_class = CrossEntropyLoss(
use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
use_sigmoid=use_sigmoid,
use_mask=False,
ignore_index=255,
avg_non_ignore=avg_non_ignore)
pred = torch.rand((10, 5))
target = torch.randint(0, 5, (10, ))

ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
target[ignored_indices] = 255

# Test loss forward with default ignore
loss_with_ignore = loss_class(pred, target, reduction_override='sum')
loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
assert isinstance(loss_with_ignore, torch.Tensor)

# Test loss forward with forward ignore
target[ignored_indices] = 250
target[ignored_indices] = 255
loss_with_forward_ignore = loss_class(
pred, target, ignore_index=250, reduction_override='sum')
pred, target, ignore_index=255, reduction_override=reduction)
assert isinstance(loss_with_forward_ignore, torch.Tensor)

# Verify correctness
not_ignored_indices = (target != 250)
pred = pred[not_ignored_indices]
target = target[not_ignored_indices]
loss = loss_class(pred, target, reduction_override='sum')
if avg_non_ignore:
# manually remove the ignored elements
not_ignored_indices = (target != 255)
pred = pred[not_ignored_indices]
target = target[not_ignored_indices]
loss = loss_class(pred, target, reduction_override=reduction)

assert torch.allclose(loss, loss_with_ignore)
assert torch.allclose(loss, loss_with_forward_ignore)

# test ignore all target
pred = torch.rand((10, 5))
target = torch.ones((10, ), dtype=torch.long) * 255
loss = loss_class(pred, target, reduction_override=reduction)
assert loss == 0


@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
Expand Down

0 comments on commit 3b2e965

Please sign in to comment.