Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/losses #2

Merged
merged 8 commits into from
Dec 2, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
revert dice loss
  • Loading branch information
ihahanov committed Dec 2, 2021
commit 3d2513db433bd0a143c0f04f96a97aaa4b59317e
8 changes: 6 additions & 2 deletions models/losses.py
Original file line number Diff line number Diff line change
@@ -165,8 +165,12 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
return (1 - tversky_loss)


def ce_dice(true, pred, log=False, w1=1, w2=1):
pass
def ce_dice(true, pred, weights=torch.tensor([0.5, 2])):
    """Combined cross-entropy + dice loss.

    Args:
        true: ground-truth labels.
        pred: predicted logits; ``pred.device`` decides where the class
            weights are placed.
        weights: per-class weights for the cross-entropy term, or None to
            disable weighting. NOTE(review): the default tensor is built
            once at import time and shared across calls — harmless while
            it is never mutated, but worth keeping in mind.

    Returns:
        Sum of the weighted cross-entropy loss and the dice loss.
    """
    if weights is not None:
        # torch.as_tensor avoids the extra copy (and the UserWarning)
        # that torch.tensor() emits when `weights` is already a tensor,
        # e.g. when the default argument is used.
        weights = torch.as_tensor(weights).to(pred.device)

    return ce_loss(true, pred, weights) + dice_loss(true, pred)


def ce_jaccard(true, pred, weights=torch.tensor([0.5, 2])):
2 changes: 0 additions & 2 deletions models/networks.py
Original file line number Diff line number Diff line change
@@ -115,8 +115,6 @@ def define_loss(opt):
if opt.dataset_mode == 'classification':
loss = torch.nn.CrossEntropyLoss()
elif opt.dataset_mode == 'segmentation':
# loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=torch.tensor([0.5, 2]))

device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
weights = torch.FloatTensor(opt.loss_weights).to(device)

40 changes: 14 additions & 26 deletions train_pl.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ def __init__(self, opt):
if opt.from_pretrained is not None:
print('Loaded pretrained weights:', opt.from_pretrained)
self.model.load_weights(opt.from_pretrained)
self.criterion = ce_jaccard
self.criterion = self.model.criterion
if self.training:
self.train_metrics = torch.nn.ModuleList([
torchmetrics.Accuracy(num_classes=opt.nclasses, average='macro'),
@@ -40,46 +40,34 @@ def __init__(self, opt):
torchmetrics.F1(num_classes=opt.nclasses, average='macro')
])

def training_step(self, batch, idx):
def step(self, batch, is_train=True):
    """Shared train/validation step: forward pass, loss, metric logging.

    Args:
        batch: input batch handed to the wrapped model via ``set_input``.
        is_train: True for a training step, False for validation; selects
            the metric collection and the logged-name prefix.

    Returns:
        The scalar loss for this batch.
    """
    self.model.set_input(batch)
    out = self.model.forward()
    true, pred = postprocess(self.model.labels, out)
    loss = self.criterion(true, pred)

    # Flatten targets and reduce logits to per-element class predictions.
    true = true.view(-1)
    pred = pred.argmax(1).view(-1)

    prefix = '' if is_train else 'val_'
    # BUG FIX: validation must update `val_metrics`; the merged step used
    # `train_metrics` for both phases, polluting the training statistics.
    metrics = self.train_metrics if is_train else self.val_metrics
    for m in metrics:
        val = m(pred, true)
        # torchmetrics objects repr as "Name(...)" — take the class name.
        metric_name = str(m).split('(')[0]
        self.log(prefix + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
    self.log(prefix + 'loss', loss, on_epoch=True)
    return loss

def validation_step(self, batch, idx):
self.model.set_input(batch)
out = self.model.forward()
true, pred = postprocess(self.model.labels, out)
loss = self.criterion(true, pred, self.opt.class_weights)
def training_step(self, batch, idx):
    """Lightning hook: delegate to the shared step in training mode.

    The interleaved remnants of the pre-refactor implementation (dead
    diff residue) are removed; only the clean delegation remains.
    """
    return self.step(batch, is_train=True)

for m in self.val_metrics:
val = m(pred_class, label_class)
metric_name = str(m).split('(')[0]
self.log('val_' + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True)
self.log('val_loss', loss, on_epoch=True)
return loss
def validation_step(self, batch, idx):
    """Lightning hook: delegate to the shared step in validation mode."""
    return self.step(batch, is_train=False)

def forward(self, image):
    """Run inference by delegating directly to the wrapped model."""
    return self.model(image)

def on_train_epoch_end(self, unused = None):
def on_train_epoch_end(self, unused=None):
    """Clear the accumulated training-metric state at the end of each epoch."""
    for metric in self.train_metrics:
        metric.reset()