From cf18e47da2f7ab585fdfb8f3b39d04619a2ff485 Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Thu, 25 Nov 2021 17:16:38 +0200
Subject: [PATCH 1/7] losses

---
 models/losses.py          | 187 ++++++++++++++++++++++++++++++++++++++
 models/mesh_classifier.py |   2 +-
 models/networks.py        |   5 +-
 3 files changed, 191 insertions(+), 3 deletions(-)
 create mode 100644 models/losses.py

diff --git a/models/losses.py b/models/losses.py
new file mode 100644
index 00000000..614e408b
--- /dev/null
+++ b/models/losses.py
@@ -0,0 +1,187 @@
+"""Common image segmentation losses.
+"""
+
+import torch
+
+from torch.nn import functional as F
+
+
+def bce_loss(true, logits, pos_weight=None):
+    """Computes the weighted binary cross-entropy loss.
+
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, 1, H, W]. Corresponds to
+            the raw output or logits of the model.
+        pos_weight: a scalar representing the weight attributed
+            to the positive class. This is especially useful for
+            an imbalanced dataset.
+
+    Returns:
+        bce_loss: the weighted binary cross-entropy loss.
+    """
+    bce_loss = F.binary_cross_entropy_with_logits(
+        logits.float(),
+        true.float(),
+        pos_weight=pos_weight,
+    )
+    return bce_loss
+
+
+def ce_loss(true, logits, weights, ignore=255):
+    """Computes the weighted multi-class cross-entropy loss.
+
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        weights: a tensor of shape [C,]. The weights attributed
+            to each class.
+        ignore: the class index to ignore.
+
+    Returns:
+        ce_loss: the weighted multi-class cross-entropy loss.
+    """
+    ce_loss = F.cross_entropy(
+        logits.float(),
+        true.long(),
+        ignore_index=ignore,
+        weight=weights,
+    )
+    return ce_loss
+
+
+def dice_loss(true, logits, eps=1e-7):
+    """Computes the Sørensen–Dice loss.
+
+    Note that PyTorch optimizers minimize a loss. In this
+    case, we would like to maximize the dice coefficient,
+    so we return one minus the dice coefficient.
+
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        eps: added to the denominator for numerical stability.
+
+    Returns:
+        dice_loss: the Sørensen–Dice loss.
+    """
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    cardinality = torch.sum(probas + true_1_hot, dims)
+    dice_loss = (2. * intersection / (cardinality + eps)).mean()
+    return (1 - dice_loss)
+
+
+def jaccard_loss(true, logits, eps=1e-7):
+    """Computes the Jaccard loss, a.k.a. the IoU loss.
+
+    Note that PyTorch optimizers minimize a loss. In this
+    case, we would like to maximize the jaccard index, so
+    we return one minus the jaccard index.
+
+    Args:
+        true: a tensor of shape [B, H, W] or [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        eps: added to the denominator for numerical stability.
+
+    Returns:
+        jacc_loss: the Jaccard loss.
+    """
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    cardinality = torch.sum(probas + true_1_hot, dims)
+    union = cardinality - intersection
+    jacc_loss = (intersection / (union + eps)).mean()
+    return (1 - jacc_loss)
+
+
+def tversky_loss(true, logits, alpha, beta, eps=1e-7):
+    """Computes the Tversky loss [1].
+
+    Args:
+        true: a tensor of shape [B, H, W] or [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        alpha: controls the penalty for false positives.
+        beta: controls the penalty for false negatives.
+        eps: added to the denominator for numerical stability.
+
+    Returns:
+        tversky_loss: the Tversky loss.
+
+    Notes:
+        alpha = beta = 0.5 => dice coeff
+        alpha = beta = 1 => tanimoto coeff
+        alpha + beta = 1 => F beta coeff
+
+    References:
+        [1]: https://arxiv.org/abs/1706.05721
+    """
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    fps = torch.sum(probas * (1 - true_1_hot), dims)
+    fns = torch.sum((1 - probas) * true_1_hot, dims)
+    num = intersection
+    denom = intersection + (alpha * fps) + (beta * fns)
+    tversky_loss = (num / (denom + eps)).mean()
+    return (1 - tversky_loss)
+
+
+def ce_dice(true, pred, log=False, w1=1, w2=1):
+    pass
+
+
+def ce_jaccard(true, pred, log=False, w1=1, w2=1):
+    pass
+
+
+def focal_loss(true, pred):
+    pass
\ No newline at end of file
diff --git a/models/mesh_classifier.py b/models/mesh_classifier.py
index d0596a62..2b9bcc26 100644
--- a/models/mesh_classifier.py
+++ b/models/mesh_classifier.py
@@ -34,7 +34,7 @@ def __init__(self, opt):
         self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt,
                                               self.gpu_ids, opt.arch, opt.init_type, opt.init_gain)
         self.net.train(self.is_train)
-        self.criterion = networks.define_loss(opt).to(self.device)
+        self.criterion = networks.define_loss(opt)
 
         if self.is_train:
             self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
diff --git a/models/networks.py b/models/networks.py
index 8f6fe4b1..29011afd 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from models.layers.mesh_pool import MeshPool
 from models.layers.mesh_unpool import MeshUnpool
-from .losses import ce_jaccard
+from .losses import ce_jaccard, dice_loss
 
 
 ###############################################################################
@@ -115,7 +115,8 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
+        # loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
+        loss = lambda out, labels: dice_loss(labels.unsqueeze(1).unsqueeze(-1), out.unsqueeze(-1))
     return loss
 
 ##############################################################################

From e37d1bf26f29bdb1606553f77ec7c4bf4bb8ca5f Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Fri, 26 Nov 2021 18:44:48 +0200
Subject: [PATCH 2/7] remove padding

---
 models/mesh_classifier.py |  5 +++--
 util/util.py              | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/models/mesh_classifier.py b/models/mesh_classifier.py
index 2b9bcc26..7007f08c 100644
--- a/models/mesh_classifier.py
+++ b/models/mesh_classifier.py
@@ -3,7 +3,7 @@
 from . import networks
 from os.path import join
-from util.util import seg_accuracy, print_network
+from util.util import seg_accuracy, print_network, remove_padding
 
 
 class ClassifierModel:
@@ -60,7 +60,8 @@ def forward(self):
         return out
 
     def backward(self, out):
-        self.loss = self.criterion(out, self.labels)
+        label_class, pred_class = remove_padding(self.labels, out)
+        self.loss = self.criterion(pred_class, label_class)
         self.loss.backward()
 
     def optimize_parameters(self):
diff --git a/util/util.py b/util/util.py
index 562c22f6..0767a21d 100644
--- a/util/util.py
+++ b/util/util.py
@@ -66,3 +66,15 @@ def calculate_entropy(np_array):
         entropy -= a * np.log(a)
     entropy /= np.log(np_array.shape[0])
     return entropy
+
+
+def remove_padding(label_class, pred_class):
+    not_padding = label_class != -1
+    label_class = label_class[not_padding]
+    label_class = label_class.unsqueeze(0)
+
+    not_padding = not_padding.repeat(2, 1)
+    not_padding = not_padding.unsqueeze(0)
+    pred_class = pred_class[not_padding]
+    pred_class = pred_class.reshape([1, 2, int(pred_class.size()[0] / 2)])
+    return label_class, pred_class
\ No newline at end of file

From 2682e07b174a8bbb3714fa1fafca4a0a61510a1b Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Mon, 29 Nov 2021 13:08:53 +0200
Subject: [PATCH 3/7] combined loss

---
 models/networks.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/models/networks.py b/models/networks.py
index 29011afd..9c8e9fea 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from models.layers.mesh_pool import MeshPool
 from models.layers.mesh_unpool import MeshUnpool
-from .losses import ce_jaccard, dice_loss
+from .losses import ce_jaccard, dice_loss, jaccard_loss, ce_loss, bce_loss
 
 
 ###############################################################################
@@ -115,8 +115,10 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        # loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
-        loss = lambda out, labels: dice_loss(labels.unsqueeze(1).unsqueeze(-1), out.unsqueeze(-1))
+        loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
+        loss_dice = lambda out, labels: dice_loss(labels.unsqueeze(1).unsqueeze(-1), out.unsqueeze(-1))
+        loss = lambda out, labels: loss_ce(out, labels) + loss_dice(out, labels)
+        # loss = lambda out, labels: ce_loss(labels.squeeze(), out.squeeze().transpose(0,1), weights=torch.FloatTensor([0.5, 2]))
     return loss
 
 ##############################################################################

From 363a662ae99eaf90dfb274c6e3dd6dc93eeab4ee Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Mon, 29 Nov 2021 14:53:12 +0200
Subject: [PATCH 4/7] combined loss fixed

---
 models/networks.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/models/networks.py b/models/networks.py
index 9c8e9fea..b0a7d97d 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -115,10 +115,11 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
+        # loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
         loss_dice = lambda out, labels: dice_loss(labels.unsqueeze(1).unsqueeze(-1), out.unsqueeze(-1))
+        device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
+        loss_ce = lambda out, labels: ce_loss(labels.squeeze(), out.squeeze().transpose(0,1), weights=torch.FloatTensor([0.5, 2]).to(device))
         loss = lambda out, labels: loss_ce(out, labels) + loss_dice(out, labels)
-        # loss = lambda out, labels: ce_loss(labels.squeeze(), out.squeeze().transpose(0,1), weights=torch.FloatTensor([0.5, 2]))
     return loss
 
 ##############################################################################

From 0e3292589736b63cb3b9bf8fc17e3af704ed2cae Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Wed, 1 Dec 2021 14:13:26 +0200
Subject: [PATCH 5/7] losses update

---
 models/losses.py   | 35 ++++++++++++++++++++++-------------
 models/networks.py | 17 +++++++++++++----
 util/util.py       | 11 +++++++----
 3 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/models/losses.py b/models/losses.py
index 614e408b..ef217278 100644
--- a/models/losses.py
+++ b/models/losses.py
@@ -28,12 +28,12 @@ def bce_loss(true, logits, pos_weight=None):
     return bce_loss
 
 
-def ce_loss(true, logits, weights, ignore=255):
+def ce_loss(logits, true, weights=None, ignore=255):
     """Computes the weighted multi-class cross-entropy loss.
 
     Args:
-        true: a tensor of shape [B, 1, H, W].
-        logits: a tensor of shape [B, C, H, W]. Corresponds to
+        true: a tensor of shape [1, N].
+        logits: a tensor of shape [1, C, N]. Corresponds to
             the raw output or logits of the model.
         weights: a tensor of shape [C,]. The weights attributed
             to each class.
         ignore: the class index to ignore.
@@ -42,6 +42,9 @@ def ce_loss(true, logits, weights, ignore=255):
     Returns:
         ce_loss: the weighted multi-class cross-entropy loss.
     """
+    true = true.squeeze()
+    logits = logits.squeeze().transpose(0,1)
+
     ce_loss = F.cross_entropy(
         logits.float(),
         true.long(),
@@ -51,7 +54,7 @@ def ce_loss(true, logits, weights, ignore=255):
     return ce_loss
 
 
-def dice_loss(true, logits, eps=1e-7):
+def dice_loss(logits, true, eps=1e-7):
     """Computes the Sørensen–Dice loss.
 
     Note that PyTorch optimizers minimize a loss. In this
@@ -59,14 +62,17 @@ def dice_loss(true, logits, eps=1e-7):
     so we return one minus the dice coefficient.
 
     Args:
-        true: a tensor of shape [B, 1, H, W].
-        logits: a tensor of shape [B, C, H, W]. Corresponds to
+        true: a tensor of shape [1, N].
+        logits: a tensor of shape [1, C, N]. Corresponds to
             the raw output or logits of the model.
         eps: added to the denominator for numerical stability.
 
     Returns:
         dice_loss: the Sørensen–Dice loss.
     """
+    true = true.unsqueeze(1).unsqueeze(-1)
+    logits = logits.unsqueeze(-1)
+
     num_classes = logits.shape[1]
     if num_classes == 1:
         true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
@@ -89,7 +95,7 @@ def dice_loss(true, logits, eps=1e-7):
     return (1 - dice_loss)
 
 
-def jaccard_loss(true, logits, eps=1e-7):
+def jaccard_loss(logits, true, eps=1e-7):
     """Computes the Jaccard loss, a.k.a. the IoU loss.
 
     Note that PyTorch optimizers minimize a loss. In this
@@ -97,14 +103,17 @@ def jaccard_loss(true, logits, eps=1e-7):
     we return one minus the jaccard index.
 
     Args:
-        true: a tensor of shape [B, H, W] or [B, 1, H, W].
-        logits: a tensor of shape [B, C, H, W]. Corresponds to
+        true: a tensor of shape [1, N].
+        logits: a tensor of shape [1, C, N]. Corresponds to
            the raw output or logits of the model.
         eps: added to the denominator for numerical stability.
 
     Returns:
         jacc_loss: the Jaccard loss.
     """
+    true = true.unsqueeze(1).unsqueeze(-1)
+    logits = logits.unsqueeze(-1)
+
     num_classes = logits.shape[1]
     if num_classes == 1:
         true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
@@ -175,12 +184,12 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
     return (1 - tversky_loss)
 
 
-def ce_dice(true, pred, log=False, w1=1, w2=1):
-    pass
+def ce_dice(logits, true, weights=None):
+    return ce_loss(logits, true, weights) + dice_loss(logits, true)
 
 
-def ce_jaccard(true, pred, log=False, w1=1, w2=1):
-    pass
+def ce_jaccard(logits, true, weights=None):
+    return ce_loss(logits, true, weights) + jaccard_loss(logits, true)
 
 
 def focal_loss(true, pred):
diff --git a/models/networks.py b/models/networks.py
index b0a7d97d..f5779240 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from models.layers.mesh_pool import MeshPool
 from models.layers.mesh_unpool import MeshUnpool
-from .losses import ce_jaccard, dice_loss, jaccard_loss, ce_loss, bce_loss
+from .losses import ce_jaccard, dice_loss, jaccard_loss, ce_loss, ce_dice
 
 
 ###############################################################################
@@ -116,10 +116,19 @@ def define_loss(opt):
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
         # loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
-        loss_dice = lambda out, labels: dice_loss(labels.unsqueeze(1).unsqueeze(-1), out.unsqueeze(-1))
+
+        # loss_dice = dice_loss
+        # loss_jaccard = jaccard_loss
+
         device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
-        loss_ce = lambda out, labels: ce_loss(labels.squeeze(), out.squeeze().transpose(0,1), weights=torch.FloatTensor([0.5, 2]).to(device))
-        loss = lambda out, labels: loss_ce(out, labels) + loss_dice(out, labels)
+        weights = torch.FloatTensor([0.5, 2]).to(device)
+
+        # loss_ce = functools.partial(ce_loss, weights=weights)
+        loss_ce_dice = functools.partial(ce_dice, weights=weights)
+        # loss_ce_jaccard = functools.partial(ce_jaccard, weights=weights)
+
+        loss = loss_ce_dice
+        # ToDo: loss option
     return loss
 
 ##############################################################################
diff --git a/util/util.py b/util/util.py
index 0767a21d..db42c733 100644
--- a/util/util.py
+++ b/util/util.py
@@ -69,12 +69,15 @@ def calculate_entropy(np_array):
 
 
 def remove_padding(label_class, pred_class):
+    num_classes = pred_class.size()[1]
+    label_class, pred_class = label_class.flatten(), pred_class.flatten()
+
     not_padding = label_class != -1
     label_class = label_class[not_padding]
-    label_class = label_class.unsqueeze(0)
+    label_class = label_class.view(1, -1)
 
-    not_padding = not_padding.repeat(2, 1)
-    not_padding = not_padding.unsqueeze(0)
+    not_padding = not_padding.repeat(num_classes)
     pred_class = pred_class[not_padding]
-    pred_class = pred_class.reshape([1, 2, int(pred_class.size()[0] / 2)])
+    pred_class = pred_class.view(1, num_classes, -1)
+
     return label_class, pred_class
\ No newline at end of file

From ba41de320af8b2dc0d46624aa957b8fcdf0f2510 Mon Sep 17 00:00:00 2001
From: mrybakova
Date: Wed, 1 Dec 2021 18:59:17 +0200
Subject: [PATCH 6/7] loss option

---
 models/networks.py      | 20 ++++++++++----------
 options/base_options.py |  5 +++++
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/models/networks.py b/models/networks.py
index f5779240..7430ba36 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -115,20 +115,20 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        # loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
-
-        # loss_dice = dice_loss
-        # loss_jaccard = jaccard_loss
+        # loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
 
         device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
-        weights = torch.FloatTensor([0.5, 2]).to(device)
+        weights = torch.FloatTensor(opt.loss_weights).to(device)
 
-        # loss_ce = functools.partial(ce_loss, weights=weights)
-        loss_ce_dice = functools.partial(ce_dice, weights=weights)
-        # loss_ce_jaccard = functools.partial(ce_jaccard, weights=weights)
+        losses = {
+            'ce': functools.partial(ce_loss, weights=weights),
+            'dice': dice_loss,
+            'jaccard': jaccard_loss,
+            'ce_dice': functools.partial(ce_dice, weights=weights),
+            'ce_jaccard': functools.partial(ce_jaccard, weights=weights)
+        }
 
-        loss = loss_ce_dice
-        # ToDo: loss option
+        loss = losses.get(opt.loss)
     return loss
 
 ##############################################################################
diff --git a/options/base_options.py b/options/base_options.py
index 09b9aa26..1364a6b1 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -26,6 +26,11 @@ def initialize(self):
         self.parser.add_argument('--num_groups', type=int, default=16, help='# of groups for groupnorm')
         self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
         self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+        self.parser.add_argument('--loss', type=str, default='ce_dice',
+                                 help='loss function; possible values: ce, dice, jaccard, ce_dice, ce_jaccard')
+        self.parser.add_argument('--loss_weights', nargs='+', default=[0.5, 2], type=float,
+                                 help='weights for loss function, used only with ce/ce_dice/ce_jaccard losses')
+
         # general params
         self.parser.add_argument('--num_threads', default=3, type=int, help='# threads for loading data')
         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')

From 3d2513db433bd0a143c0f04f96a97aaa4b59317e Mon Sep 17 00:00:00 2001
From: ihahanov
Date: Thu, 2 Dec 2021 12:41:20 +0200
Subject: [PATCH 7/7] revert dice loss

---
 models/losses.py   |  8 ++++++--
 models/networks.py |  2 --
 train_pl.py        | 40 ++++++++++++++--------------------------
 3 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/models/losses.py b/models/losses.py
index a0e9df5c..8b271778 100644
--- a/models/losses.py
+++ b/models/losses.py
@@ -165,8 +165,12 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
     return (1 - tversky_loss)
 
 
-def ce_dice(true, pred, log=False, w1=1, w2=1):
-    pass
+def ce_dice(true, pred, weights=torch.tensor([0.5, 2])):
+    if weights is not None:
+        weights = torch.tensor(weights).to(pred.device)
+
+    return ce_loss(true, pred, weights) + \
+        dice_loss(true, pred)
 
 
 def ce_jaccard(true, pred, weights=torch.tensor([0.5, 2])):
diff --git a/models/networks.py b/models/networks.py
index 7430ba36..816baa46 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -115,8 +115,6 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        # loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
-
         device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
         weights = torch.FloatTensor(opt.loss_weights).to(device)
 
diff --git a/train_pl.py b/train_pl.py
index 048568aa..eda93bdc 100644
--- a/train_pl.py
+++ b/train_pl.py
@@ -27,7 +27,7 @@ def __init__(self, opt):
         if opt.from_pretrained is not None:
             print('Loaded pretrained weights:', opt.from_pretrained)
             self.model.load_weights(opt.from_pretrained)
-        self.criterion = ce_jaccard
+        self.criterion = self.model.criterion
         if self.training:
             self.train_metrics = torch.nn.ModuleList([
                 torchmetrics.Accuracy(num_classes=opt.nclasses, average='macro'),
@@ -40,46 +40,34 @@ def __init__(self, opt):
                 torchmetrics.F1(num_classes=opt.nclasses, average='macro')
             ])
 
-    def training_step(self, batch, idx):
+    def step(self, batch, is_train=True):
         self.model.set_input(batch)
         out = self.model.forward()
         true, pred = postprocess(self.model.labels, out)
-        loss = self.criterion(true, pred, self.opt.class_weights)
+        loss = self.criterion(true, pred)
 
-        pred_class = out.data.max(1)[1]
-        not_padding = self.model.labels != -1
-        label_class = self.model.labels[not_padding]
-        pred_class = pred_class[not_padding]
+        true = true.view(-1)
+        pred = pred.argmax(1).view(-1)
 
+        prefix = '' if is_train else 'val_'
         for m in self.train_metrics:
-            val = m(pred_class, label_class)
+            val = m(pred, true)
             metric_name = str(m).split('(')[0]
-            self.log(metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
-        self.log('loss', loss, on_epoch=True)
+            self.log(prefix + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
+        self.log(prefix + 'loss', loss, on_epoch=True)
         return loss
 
-    def validation_step(self, batch, idx):
-        self.model.set_input(batch)
-        out = self.model.forward()
-        true, pred = postprocess(self.model.labels, out)
-        loss = self.criterion(true, pred, self.opt.class_weights)
+    def training_step(self, batch, idx):
 
-        pred_class = out.data.max(1)[1]
-        not_padding = self.model.labels != -1
-        label_class = self.model.labels[not_padding]
-        pred_class = pred_class[not_padding]
+        return self.step(batch, is_train=True)
 
-        for m in self.val_metrics:
-            val = m(pred_class, label_class)
-            metric_name = str(m).split('(')[0]
-            self.log('val_' + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
-        self.log('val_loss', loss, on_epoch=True)
-        return loss
+    def validation_step(self, batch, idx):
+        return self.step(batch, is_train=False)
 
     def forward(self, image):
         return self.model(image)
 
-    def on_train_epoch_end(self, unused = None):
+    def on_train_epoch_end(self, unused=None):
        for m in self.train_metrics:
            m.reset()
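
Note: the following is a minimal, self-contained sketch of the loss pipeline this series converges on: remove_padding (util/util.py, as of PATCH 5) strips the -1 padding while restoring the [1, N] / [1, C, N] shape contract, and the criterion sums weighted cross-entropy with one minus the dice coefficient (the ce_dice combination of models/losses.py, PATCH 5 argument order). The dummy batch, the edge count N=6, and the label values are illustrative assumptions; only the [0.5, 2] class weights come from the defaults above. Not a tested configuration of the repo.

import torch
from torch.nn import functional as F

# Dummy batch: binary edge segmentation, C=2 classes over N=6 mesh edges,
# with the last two label positions padded with -1 (hypothetical values).
logits = torch.randn(1, 2, 6)                  # [1, C, N], raw model output
labels = torch.tensor([[0, 1, 1, 0, -1, -1]])  # [1, N], -1 marks padding

# remove_padding: drop padded positions, keep the [1, N'] / [1, C, N'] layout.
num_classes = logits.size(1)
flat_labels, flat_logits = labels.flatten(), logits.flatten()
not_padding = flat_labels != -1
true = flat_labels[not_padding].view(1, -1)                        # [1, N']
pred = flat_logits[not_padding.repeat(num_classes)].view(1, num_classes, -1)

# ce_dice: weighted cross-entropy on the logits plus (1 - dice coefficient)
# computed on the softmax probabilities, mirroring losses.py after PATCH 5.
weights = torch.tensor([0.5, 2.0])
ce = F.cross_entropy(pred.squeeze().transpose(0, 1), true.squeeze().long(),
                     weight=weights)
probas = F.softmax(pred, dim=1)                                    # [1, C, N']
true_1_hot = F.one_hot(true, num_classes).permute(0, 2, 1).float()
intersection = (probas * true_1_hot).sum((0, 2))                   # per class
cardinality = (probas + true_1_hot).sum((0, 2))
dice = 1 - (2. * intersection / (cardinality + 1e-7)).mean()
print('ce_dice loss:', (ce + dice).item())

With --loss ce_dice and --loss_weights 0.5 2 (the PATCH 6 defaults), define_loss wires up the same combination via functools.partial.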