Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/losses #2

Merged
merged 8 commits into from
Dec 2, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'master' into feature/losses
# Conflicts:
#	models/losses.py
#	models/mesh_classifier.py
  • Loading branch information
ihahanov committed Dec 2, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit bb0d6b1e0b2286f1416d7f2d5ef0fee66c64b56d
67 changes: 32 additions & 35 deletions models/losses.py
Original file line number Diff line number Diff line change
@@ -8,15 +8,13 @@

def bce_loss(true, logits, pos_weight=None):
"""Computes the weighted binary cross-entropy loss.

Args:
true: a tensor of shape [B, 1, H, W].
logits: a tensor of shape [B, 1, H, W]. Corresponds to
the raw output or logits of the model.
pos_weight: a scalar representing the weight attributed
to the positive class. This is especially useful for
an imbalanced dataset.

Returns:
bce_loss: the weighted binary cross-entropy loss.
"""
@@ -28,22 +26,20 @@ def bce_loss(true, logits, pos_weight=None):
return bce_loss


def ce_loss(logits, true, weights=None, ignore=255):
def ce_loss(true, logits, weights, ignore=255):
"""Computes the weighted multi-class cross-entropy loss.

Args:
true: a tensor of shape [1, N].
logits: a tensor of shape [1, C, N]. Corresponds to
true: a tensor of shape [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
weight: a tensor of shape [C,]. The weights attributed
to each class.
ignore: the class index to ignore.

Returns:
ce_loss: the weighted multi-class cross-entropy loss.
"""
true = true.squeeze()
logits = logits.squeeze().transpose(0,1)
true = true.squeeze(-1).squeeze(1)
logits = logits.squeeze(-1)

ce_loss = F.cross_entropy(
logits.float(),
@@ -54,25 +50,19 @@ def ce_loss(logits, true, weights=None, ignore=255):
return ce_loss


def dice_loss(logits, true, eps=1e-7):
def dice_loss(true, logits, eps=1e-7):
"""Computes the Sørensen–Dice loss.

Note that PyTorch optimizers minimize a loss. In this
case, we would like to maximize the dice loss so we
return the negated dice loss.

Args:
true: a tensor of shape [1, N].
logits: a tensor of shape [1, C, N]. Corresponds to
true: a tensor of shape [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
eps: added to the denominator for numerical stability.

Returns:
dice_loss: the Sørensen–Dice loss.
"""
true = true.unsqueeze(1).unsqueeze(-1)
logits = logits.unsqueeze(-1)

num_classes = logits.shape[1]
if num_classes == 1:
true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
@@ -95,24 +85,19 @@ def dice_loss(logits, true, eps=1e-7):
return (1 - dice_loss)


def jaccard_loss(logits, true, eps=1e-7):
def jaccard_loss(true, logits, eps=1e-7):
"""Computes the Jaccard loss, a.k.a the IoU loss.

Note that PyTorch optimizers minimize a loss. In this
case, we would like to maximize the jaccard loss so we
return the negated jaccard loss.

Args:
true: a tensor of shape [1, N].
logits: a tensor of shape [1, C, N]. Corresponds to
true: a tensor of shape [B, H, W] or [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
eps: added to the denominator for numerical stability.

Returns:
jacc_loss: the Jaccard loss.
"""
true = true.unsqueeze(1).unsqueeze(-1)
logits = logits.unsqueeze(-1)

num_classes = logits.shape[1]
if num_classes == 1:
@@ -139,23 +124,19 @@ def jaccard_loss(logits, true, eps=1e-7):

def tversky_loss(true, logits, alpha, beta, eps=1e-7):
"""Computes the Tversky loss [1].

Args:
true: a tensor of shape [B, H, W] or [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
alpha: controls the penalty for false positives.
beta: controls the penalty for false negatives.
eps: added to the denominator for numerical stability.

Returns:
tversky_loss: the Tversky loss.

Notes:
alpha = beta = 0.5 => dice coeff
alpha = beta = 1 => tanimoto coeff
alpha + beta = 1 => F beta coeff

References:
[1]: https://arxiv.org/abs/1706.05721
"""
@@ -184,13 +165,29 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
return (1 - tversky_loss)


def ce_dice(logits, true, weights=None):
return ce_loss(logits, true, weights) + dice_loss(logits, true)
def ce_dice(true, pred, log=False, w1=1, w2=1):
    """Combined cross-entropy + Dice loss — NOT IMPLEMENTED (placeholder stub).

    Returns None; callers must not rely on this until a body is written.

    Args:
        true: ground-truth label tensor (shape contract TBD).
        pred: raw logits tensor (shape contract TBD).
        log: presumably toggles a log-Dice variant — TODO confirm when implementing.
        w1: presumably the weight of the CE term — TODO confirm when implementing.
        w2: presumably the weight of the Dice term — TODO confirm when implementing.
    """
    pass


def ce_jaccard(true, pred, weights=(0.5, 2)):
    """Weighted cross-entropy plus Jaccard (IoU) loss.

    Args:
        true: ground-truth label tensor, in whatever layout ``ce_loss`` and
            ``jaccard_loss`` expect (see their docstrings).
        pred: raw logits tensor from the model.
        weights: per-class weights for the cross-entropy term, or None to
            apply no class weighting. Defaults to (0.5, 2) — presumably
            down-weighting the background class; confirm against training
            config.

    Returns:
        Scalar loss: ce_loss(true, pred, weights) + jaccard_loss(true, pred).
    """
    # NOTE: the default used to be a module-level torch.tensor, which is
    # allocated once at import time and then re-wrapped by torch.tensor()
    # on every call (emitting a copy warning). An immutable tuple default
    # plus torch.as_tensor() is behavior-equivalent and silent.
    if weights is not None:
        weights = torch.as_tensor(weights).to(pred.device)
    return ce_loss(true, pred, weights) + jaccard_loss(true, pred)


def focal_loss(true, pred):
    """Focal loss — NOT IMPLEMENTED (placeholder stub).

    Returns None; callers must not rely on this until a body is written.

    Args:
        true: ground-truth label tensor (shape contract TBD).
        pred: raw logits tensor (shape contract TBD).
    """
    # The diff rendering showed two `pass` statements (removed old line +
    # added new line); the actual function body is a single `pass`.
    pass


def postprocess(true, pred):
    """Strip padding entries and reshape (true, pred) for the loss functions.

    Flattens the labels, drops every position whose label is -1 (the
    padding marker), filters the matching prediction columns, and restores
    the 4-D layout the losses above expect.

    Args:
        true: label tensor; flattened to 1-D, so any input shape with the
            same element count works. Entries equal to -1 are padding.
        pred: logits tensor with the class axis at dim 1; viewed as
            (num_classes, -1), so its element count must be
            num_classes * true.numel().

    Returns:
        Tuple (true, pred) with padding removed, shaped
        [1, 1, N, 1] and [1, num_classes, N, 1] respectively.
    """
    num_classes = pred.shape[1]  # fixed typo: was 'num_classses'
    true = true.view(-1)
    pred = pred.view(num_classes, -1)
    not_padding = true != -1  # -1 marks padded positions
    true = true[not_padding]
    pred = pred[:, not_padding]
    true = true.view(1, 1, -1, 1)
    pred = pred.view(1, num_classes, -1, 1)
    return true, pred
5 changes: 2 additions & 3 deletions models/mesh_classifier.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

from . import networks
from os.path import join
from util.util import seg_accuracy, print_network, remove_padding
from util.util import seg_accuracy, print_network


class ClassifierModel:
@@ -60,8 +60,7 @@ def forward(self):
return out

def backward(self, out):
label_class, pred_class = remove_padding(self.labels, out)
self.loss = self.criterion(pred_class, label_class)
self.loss = self.criterion(self.labels, out)
self.loss.backward()

def optimize_parameters(self):
You are viewing a condensed version of this merge commit. You can view the full changes here.