diff --git a/models/losses.py b/models/losses.py
index a0e9df5c..8b271778 100644
--- a/models/losses.py
+++ b/models/losses.py
@@ -165,8 +165,12 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
     return (1 - tversky_loss)
 
 
-def ce_dice(true, pred, log=False, w1=1, w2=1):
-    pass
+def ce_dice(true, pred, weights=torch.tensor([0.5, 2])):
+    if weights is not None:
+        weights = torch.as_tensor(weights).to(pred.device)
+
+    return ce_loss(true, pred, weights) + \
+           dice_loss(true, pred)
 
 
 def ce_jaccard(true, pred, weights=torch.tensor([0.5, 2])):
diff --git a/models/mesh_classifier.py b/models/mesh_classifier.py
index ecf38183..fd382621 100644
--- a/models/mesh_classifier.py
+++ b/models/mesh_classifier.py
@@ -34,8 +34,12 @@ def __init__(self, opt):
         self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt,
                                               self.gpu_ids, opt.arch, opt.init_type, opt.init_gain)
         self.net.train(self.is_train)
-        from .losses import ce_jaccard
-        self.criterion = ce_jaccard#networks.define_loss(opt).to(self.device)
+        self.criterion = networks.define_loss(opt)
+
+        if self.is_train:
+            self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.scheduler = networks.get_scheduler(self.optimizer, opt)
+            print_network(self.net)
 
         if not self.is_train or opt.continue_train:
             self.load_network(opt.which_epoch)
diff --git a/models/networks.py b/models/networks.py
index 8f6fe4b1..816baa46 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from models.layers.mesh_pool import MeshPool
 from models.layers.mesh_unpool import MeshUnpool
-from .losses import ce_jaccard
+from .losses import ce_jaccard, dice_loss, jaccard_loss, ce_loss, ce_dice
 
 
 ###############################################################################
@@ -115,7 +115,18 @@ def define_loss(opt):
     if opt.dataset_mode == 'classification':
         loss = torch.nn.CrossEntropyLoss()
     elif opt.dataset_mode == 'segmentation':
-        loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))
+        device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
+        weights = torch.FloatTensor(opt.loss_weights).to(device)
+
+        losses = {
+            'ce': functools.partial(ce_loss, weights=weights),
+            'dice': dice_loss,
+            'jaccard': jaccard_loss,
+            'ce_dice': functools.partial(ce_dice, weights=weights),
+            'ce_jaccard': functools.partial(ce_jaccard, weights=weights)
+        }
+
+        loss = losses[opt.loss]
     return loss
 
 ##############################################################################
diff --git a/options/base_options.py b/options/base_options.py
index 09b9aa26..1364a6b1 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -26,6 +26,11 @@ def initialize(self):
         self.parser.add_argument('--num_groups', type=int, default=16, help='# of groups for groupnorm')
         self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
         self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+        self.parser.add_argument('--loss', type=str, default='ce_dice',
+                                 help='loss function; possible values: ce, dice, jaccard, ce_dice, ce_jaccard')
+        self.parser.add_argument('--loss_weights', nargs='+', default=[0.5, 2], type=float,
+                                 help='per-class weights for the cross-entropy term; used only with ce/ce_dice/ce_jaccard losses')
+
         # general params
         self.parser.add_argument('--num_threads', default=3, type=int, help='# threads for loading data')
         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
diff --git a/train_pl.py b/train_pl.py
index 048568aa..eda93bdc 100644
--- a/train_pl.py
+++ b/train_pl.py
@@ -27,7 +27,7 @@ def __init__(self, opt):
         if opt.from_pretrained is not None:
             print('Loaded pretrained weights:', opt.from_pretrained)
             self.model.load_weights(opt.from_pretrained)
-        self.criterion = ce_jaccard
+        self.criterion = self.model.criterion
         if self.training:
             self.train_metrics = torch.nn.ModuleList([
                 torchmetrics.Accuracy(num_classes=opt.nclasses, average='macro'),
@@ -40,46 +40,34 @@
                 torchmetrics.F1(num_classes=opt.nclasses, average='macro')
             ])
 
-    def training_step(self, batch, idx):
+    def step(self, batch, is_train=True):
         self.model.set_input(batch)
         out = self.model.forward()
         true, pred = postprocess(self.model.labels, out)
-        loss = self.criterion(true, pred, self.opt.class_weights)
+        loss = self.criterion(true, pred)
 
-        pred_class = out.data.max(1)[1]
-        not_padding = self.model.labels != -1
-        label_class = self.model.labels[not_padding]
-        pred_class = pred_class[not_padding]
+        true = true.view(-1)
+        pred = pred.argmax(1).view(-1)
 
-        for m in self.train_metrics:
-            val = m(pred_class, label_class)
+        prefix = '' if is_train else 'val_'
+        metrics = self.train_metrics if is_train else self.val_metrics
+        for m in metrics:
+            val = m(pred, true)
             metric_name = str(m).split('(')[0]
-            self.log(metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
-        self.log('loss', loss, on_epoch=True)
+            self.log(prefix + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
+        self.log(prefix + 'loss', loss, on_epoch=True)
         return loss
 
-    def validation_step(self, batch, idx):
-        self.model.set_input(batch)
-        out = self.model.forward()
-        true, pred = postprocess(self.model.labels, out)
-        loss = self.criterion(true, pred, self.opt.class_weights)
-
-        pred_class = out.data.max(1)[1]
-        not_padding = self.model.labels != -1
-        label_class = self.model.labels[not_padding]
-        pred_class = pred_class[not_padding]
-
-        for m in self.val_metrics:
-            val = m(pred_class, label_class)
-            metric_name = str(m).split('(')[0]
-            self.log('val_' + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
-        self.log('val_loss', loss, on_epoch=True)
-        return loss
+    def training_step(self, batch, idx):
+        return self.step(batch, is_train=True)
+
+    def validation_step(self, batch, idx):
+        return self.step(batch, is_train=False)
 
     def forward(self, image):
         return self.model(image)
 
-    def on_train_epoch_end(self, unused = None):
+    def on_train_epoch_end(self, unused=None):
         for m in self.train_metrics:
             m.reset()
 
diff --git a/util/util.py b/util/util.py
index 562c22f6..db42c733 100644
--- a/util/util.py
+++ b/util/util.py
@@ -66,3 +66,18 @@ def calculate_entropy(np_array):
         entropy -= a * np.log(a)
     entropy /= np.log(np_array.shape[0])
     return entropy
+
+
+def remove_padding(label_class, pred_class):
+    # label_class: (batch, edges) int labels, with -1 marking padded edges;
+    # pred_class: (batch, classes, edges) per-class scores
+    num_classes = pred_class.size(1)
+
+    not_padding = label_class != -1
+    label_class = label_class[not_padding].view(1, -1)
+
+    # move classes last so the mask aligns with (batch, edges), then restore (1, classes, kept)
+    pred_class = pred_class.permute(0, 2, 1)[not_padding]
+    pred_class = pred_class.t().unsqueeze(0)
+
+    return label_class, pred_class
\ No newline at end of file
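Note on define_loss: it binds weights= into ce_loss, ce_dice, and ce_jaccard through functools.partial, so each of those functions must accept a weights keyword, as ce_dice and ce_jaccard do above. ce_loss itself is not shown in this diff; below is a minimal hypothetical sketch of a compatible implementation, assuming it should mirror the torch.nn.CrossEntropyLoss(ignore_index=-1, weight=...) call it replaces:

    import torch.nn.functional as F

    def ce_loss(true, logits, weights=None):
        # hypothetical stand-in for the ce_loss defined in models/losses.py:
        # true is (batch, edges) with -1 padding, logits is (batch, classes, edges)
        return F.cross_entropy(logits, true, weight=weights, ignore_index=-1)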
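Because --loss_weights uses nargs='+', every whitespace-separated token after the flag is converted to a float and collected into a list. A self-contained check of that parsing behaviour with plain argparse, using the same argument definitions the patch adds:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--loss', type=str, default='ce_dice')
    parser.add_argument('--loss_weights', nargs='+', default=[0.5, 2], type=float)

    # equivalent to: train.py --loss ce_jaccard --loss_weights 0.5 2
    opt = parser.parse_args(['--loss', 'ce_jaccard', '--loss_weights', '0.5', '2'])
    assert opt.loss == 'ce_jaccard'
    assert opt.loss_weights == [0.5, 2.0]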
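The str(m).split('(')[0] trick in step works because torchmetrics metrics are nn.Module subclasses whose repr begins with the class name; a quick illustration:

    import torchmetrics

    m = torchmetrics.Accuracy(num_classes=2, average='macro')
    print(str(m).split('(')[0].lower())  # -> 'accuracy'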
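To pin down the tensor shapes remove_padding expects and returns, a toy example, assuming labels of shape (batch, edges) padded with -1 and scores of shape (batch, classes, edges):

    import torch
    from util.util import remove_padding

    labels = torch.tensor([[0, 1, -1, 1]])   # (1, 4); the third edge is padding
    logits = torch.randn(1, 2, 4)            # (1 mesh, 2 classes, 4 edges)

    true, pred = remove_padding(labels, logits)
    print(true.shape)   # torch.Size([1, 3])
    print(pred.shape)   # torch.Size([1, 2, 3])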