-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlosses.py
72 lines (52 loc) · 2.24 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as Func
class ContrastiveLoss(nn.Module):
    """Contrastive learning loss over (anchor, positive, negative) triples.

    Per sample the loss is ``d(a, p)**2 + max(margin - d(a, n), 0)**2``, where
    ``d`` is ``torch.nn.functional.pairwise_distance`` (Euclidean distance over
    the last dimension). The result is averaged over the batch.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize the class.

        Args:
            margin (float): distance the negative sample should keep from the
                anchor; negatives farther away than this contribute zero loss.
                Defaults to ``1.0``. Must be a number: ``None`` is not
                supported because ``forward`` subtracts from it.
        """
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean loss for a batch of embedding triples.

        Args:
            anchor (Tensor): anchor embeddings; distances are taken over the
                last dimension (presumably (batch_size, dim) — confirm with
                callers).
            positive (Tensor): embeddings that should be close to ``anchor``.
            negative (Tensor): embeddings that should be at least ``margin``
                away from ``anchor``.

        Return:
            loss (Tensor): scalar tensor, the loss averaged over all samples.
        """
        pos_dist = Func.pairwise_distance(anchor, positive)
        neg_dist = Func.pairwise_distance(anchor, negative)
        # Squared distance pulls positives in; the squared hinge pushes
        # negatives out to `margin` and ignores those already beyond it.
        pos_term = pos_dist ** 2
        neg_term = torch.clamp(self.margin - neg_dist, min=0.0) ** 2
        return torch.mean(pos_term + neg_term)
# Ready-to-use loss instance with a margin of 1.0.
contrastive_loss = ContrastiveLoss(1.0)
def dice_loss(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0,
              *, label_offset: int = 1) -> torch.Tensor:
    '''
    Calculate the soft Dice loss for multi-class segmentation.

    Args:
        pred (Tensor): raw logits from the model, shape
            (batch_size, num_classes, *spatial), e.g.
            (batch_size, num_classes, width, height).
        target (Tensor): integer class label for each pixel, shape
            (batch_size, *spatial). Labels must lie in
            ``[label_offset, label_offset + num_classes)``.
        smooth (float): smoothing term added to numerator and denominator,
            which keeps the score defined for classes absent from both
            ``pred`` and ``target``. Defaults to ``1.0``.
        label_offset (int): value subtracted from ``target`` to obtain
            zero-based class indices. Defaults to ``1`` (labels 1..num_classes,
            e.g. a trimap with values 1-3).
    Return:
        dice_loss (Tensor): scalar tensor, ``1 - mean Dice score`` over all
        (batch, class) pairs.
    Raises:
        ValueError: if ``pred`` has fewer than 3 dimensions or ``target``
            contains labels outside the valid range.
    '''
    if pred.dim() < 3:
        # torch.sum(dim=()) would silently reduce over everything.
        raise ValueError('pred must have shape (batch_size, num_classes, *spatial)')
    num_classes = pred.shape[1]
    spatial_dims = tuple(range(2, pred.dim()))

    # Shift labels to zero-based indices. `.long()` is required by one_hot and
    # avoids unsigned wrap-around (uint8 0 - 1 == 255).
    target = target.long() - label_offset
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if target.min() < 0 or target.max() >= num_classes:
        raise ValueError('target contains invalid class indices')

    # Convert logits to per-class probabilities.
    probs = Func.softmax(pred, dim=1)

    # One-hot encode labels:
    # (batch_size, *spatial) -> (batch_size, *spatial, num_classes)
    # -> (batch_size, num_classes, *spatial)
    targets_one_hot = Func.one_hot(target, num_classes).movedim(-1, 1)
    targets_one_hot = targets_one_hot.type_as(probs)

    # Dice score for each (batch, class), summing over the spatial dims.
    # "union" here is |pred| + |target| (sum of cardinalities).
    intersection = torch.sum(probs * targets_one_hot, dim=spatial_dims)
    union = (torch.sum(probs, dim=spatial_dims)
             + torch.sum(targets_one_hot, dim=spatial_dims))
    dice_score = (2. * intersection + smooth) / (union + smooth)

    # Average over batch and classes.
    return 1. - dice_score.mean()