-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathloss.py
59 lines (46 loc) · 1.9 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_msssim
import numpy as np
# Multi-scale SSIM function from pytorch_msssim. Despite the name, Fusionloss
# uses the returned value as a similarity score (it computes 1 - value to get a
# loss), so this alias is presumably a similarity, not a loss -- confirm
# against the pytorch_msssim version in use.
ssim_loss = pytorch_msssim.msssim
class Sobelxy(nn.Module):
    """Sobel gradient-magnitude operator: |Gx * x| + |Gy * x|.

    Applies fixed 3x3 horizontal and vertical Sobel kernels (cross-correlation
    via ``F.conv2d``) with zero padding of 1, so the output has the same
    spatial size as the input. Each input channel is filtered independently
    (depthwise convolution); for a 1-channel input this is exactly a single
    1->1 convolution. The kernels are constant and never trained.
    """

    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        # Shape (1, 1, 3, 3): (out_channels, in_channels / groups, kH, kW).
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        # Non-persistent buffers: they follow module.to()/.cuda(), are not
        # trainable parameters, and (like the original plain-tensor
        # attributes) do not appear in state_dict. The previous
        # `nn.Parameter(...).cuda()` hard-coded CUDA and crashed on CPU-only
        # machines.
        self.register_buffer("weightx", kernelx, persistent=False)
        self.register_buffer("weighty", kernely, persistent=False)

    def forward(self, x):
        """Return the Sobel gradient magnitude of ``x``.

        Args:
            x: image tensor whose channel axis is the third-from-last
                dimension, e.g. (N, C, H, W).

        Returns:
            Tensor of the same shape as ``x``: |sobel_x| + |sobel_y| per channel.
        """
        channels = x.shape[-3]
        # Match the input's device/dtype so the operator works whether or not
        # the owning module was moved; this is a no-op when they already match.
        weightx = self.weightx.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1)
        weighty = self.weighty.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1)
        sobelx = F.conv2d(x, weightx, padding=1, groups=channels)
        sobely = F.conv2d(x, weighty, padding=1, groups=channels)
        return torch.abs(sobelx) + torch.abs(sobely)
class Fusionloss(nn.Module):
    """Weighted SSIM + intensity + gradient loss for image fusion.

    loss_total = w_ssim * loss_ssim + w_in * loss_in + w_grad * loss_grad, where

    * loss_ssim = wb0 * (1 - ssim_loss(fused, vis_y))
                + wb1 * (1 - ssim_loss(fused, ir)),
      using the module-level ``ssim_loss`` (pytorch_msssim.msssim);
    * loss_in   = MSE between the fused image and the element-wise maximum
      of the two source images;
    * loss_grad = L1 between the fused image's Sobel magnitude and the
      element-wise maximum of the sources' Sobel magnitudes.

    Args:
        w_ssim: weight of the SSIM term (original hard-coded value: 10).
        w_in: weight of the intensity term (original hard-coded value: 10).
        w_grad: weight of the gradient term (original hard-coded value: 1).
        wb0: SSIM mix weight for the visible/first source (original: 0.5).
        wb1: SSIM mix weight for the second source (original: 0.5).

    The defaults reproduce the previously hard-coded behaviour exactly.
    """

    def __init__(self, w_ssim=10, w_in=10, w_grad=1, wb0=0.5, wb1=0.5):
        super(Fusionloss, self).__init__()
        self.sobelconv = Sobelxy()
        self.w_ssim = w_ssim
        self.w_in = w_in
        self.w_grad = w_grad
        self.wb0 = wb0
        self.wb1 = wb1

    def forward(self, image_vis, image_ir, labels, generate_img, i):
        """Compute the fusion loss.

        Args:
            image_vis: 4-D source image; only its first channel is used
                (presumably the Y/luma channel -- confirm with the caller).
            image_ir: second source image, presumably single-channel
                infrared; it must broadcast against the first channel of
                ``image_vis`` for ``torch.max``.
            labels: unused; kept for interface compatibility.
            generate_img: fused image produced by the network.
            i: unused; kept for interface compatibility.

        Returns:
            Tuple ``(loss_total, loss_in, loss_ssim, loss_grad)``.
        """
        image_y = image_vis[:, :1, :, :]
        # Intensity target: the brighter source pixel at each location.
        x_in_max = torch.max(image_y, image_ir)

        # ssim_loss returns a similarity, so 1 - value turns it into a loss.
        ssim_vis = ssim_loss(generate_img, image_y, normalize=True)
        ssim_ir = ssim_loss(generate_img, image_ir, normalize=True)
        loss_ssim = self.wb0 * (1 - ssim_vis) + self.wb1 * (1 - ssim_ir)

        loss_in = F.mse_loss(x_in_max, generate_img)

        # Gradient target: the stronger source edge at each location.
        y_grad = self.sobelconv(image_y)
        ir_grad = self.sobelconv(image_ir)
        generate_img_grad = self.sobelconv(generate_img)
        x_grad_joint = torch.max(y_grad, ir_grad)
        loss_grad = F.l1_loss(x_grad_joint, generate_img_grad)

        loss_total = (self.w_ssim * loss_ssim) + (self.w_in * loss_in) + (self.w_grad * loss_grad)
        return loss_total, loss_in, loss_ssim, loss_grad
# Loss weights used for CT-MRI fusion: loss_in = 10, loss_ssim = 10, loss_grad = 1.