Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix reduction=mean in CELoss. #7449

Merged
merged 6 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/free_anchor_retina_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def loss(self,
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, _ = self.get_anchors(featmap_sizes, img_metas, device=device)
anchor_list, _ = self.get_anchors(
featmap_sizes, img_metas, device=device)
anchors = [torch.cat(anchor) for anchor in anchor_list]

# concatenate each level
Expand Down
61 changes: 52 additions & 9 deletions mmdet/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -13,7 +15,8 @@ def cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100):
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.

Args:
Expand All @@ -27,6 +30,8 @@ def cross_entropy(pred,
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): Whether the loss is averaged only over
non-ignored targets. Default: False.

Returns:
torch.Tensor: The calculated loss
Expand All @@ -41,6 +46,12 @@ def cross_entropy(pred,
reduction='none',
ignore_index=ignore_index)

# average loss over non-ignored elements
# PyTorch's official cross_entropy averages the loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()

# apply weights and do the reduction
if weight is not None:
weight = weight.float()
Expand Down Expand Up @@ -68,7 +79,7 @@ def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask

return bin_labels, bin_label_weights
return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
Expand All @@ -77,7 +88,8 @@ def binary_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100):
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.

Args:
Expand All @@ -95,19 +107,32 @@ def binary_cross_entropy(pred,
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): Whether the loss is averaged only over
non-ignored targets. Default: False.

Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index

if pred.dim() != label.dim():
label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
ignore_index)
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
else:
weight = valid_mask

# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()

# weighted element-wise losses
if weight is not None:
weight = weight.float()
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
Expand All @@ -123,7 +148,8 @@ def mask_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None):
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.

Args:
Expand Down Expand Up @@ -177,7 +203,8 @@ def __init__(self,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0):
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.

Args:
Expand All @@ -192,6 +219,8 @@ def __init__(self,
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): Whether the loss is averaged only over
non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
Expand All @@ -201,6 +230,14 @@ def __init__(self,
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand All @@ -209,6 +246,11 @@ def __init__(self,
else:
self.cls_criterion = cross_entropy

def extra_repr(self):
"""Return the extra repr string, ``avg_non_ignore=<value>``."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s

def forward(self,
cls_score,
label,
Expand Down Expand Up @@ -251,5 +293,6 @@ def forward(self,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
6 changes: 5 additions & 1 deletion mmdet/models/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools

import mmcv
import torch
import torch.nn.functional as F


Expand Down Expand Up @@ -48,7 +49,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# Avoid dividing by zero (which yields a NaN/inf loss for tensors)
# when avg_factor is 0.0, i.e., all labels of an image belong to
# ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
Expand Down
31 changes: 22 additions & 9 deletions tests/test_models/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,35 +135,48 @@ def test_GHMR_loss(loss_class, input_shape):


@pytest.mark.parametrize('use_sigmoid', [True, False])
def test_loss_with_ignore_index(use_sigmoid):
@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
@pytest.mark.parametrize('avg_non_ignore', [True, False])
def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
# Test cross_entropy loss
loss_class = CrossEntropyLoss(
use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
use_sigmoid=use_sigmoid,
use_mask=False,
ignore_index=255,
avg_non_ignore=avg_non_ignore)
pred = torch.rand((10, 5))
target = torch.randint(0, 5, (10, ))

ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
target[ignored_indices] = 255

# Test loss forward with default ignore
loss_with_ignore = loss_class(pred, target, reduction_override='sum')
loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
assert isinstance(loss_with_ignore, torch.Tensor)

# Test loss forward with forward ignore
target[ignored_indices] = 250
target[ignored_indices] = 255
loss_with_forward_ignore = loss_class(
pred, target, ignore_index=250, reduction_override='sum')
pred, target, ignore_index=255, reduction_override=reduction)
assert isinstance(loss_with_forward_ignore, torch.Tensor)

# Verify correctness
not_ignored_indices = (target != 250)
pred = pred[not_ignored_indices]
target = target[not_ignored_indices]
loss = loss_class(pred, target, reduction_override='sum')
if avg_non_ignore:
# manually remove the ignored elements
not_ignored_indices = (target != 255)
pred = pred[not_ignored_indices]
target = target[not_ignored_indices]
loss = loss_class(pred, target, reduction_override=reduction)

assert torch.allclose(loss, loss_with_ignore)
assert torch.allclose(loss, loss_with_forward_ignore)

# test ignore all target
pred = torch.rand((10, 5))
target = torch.ones((10, ), dtype=torch.long) * 255
loss = loss_class(pred, target, reduction_override=reduction)
assert loss == 0


@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
Expand Down