-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
63 lines (53 loc) · 1.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch.nn as nn
import torch.nn.functional as F
import torch
class SoftDiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    ``pred`` is passed through a sigmoid, then prediction and target are
    flattened per sample (first dimension) and scored with

        dice = (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)

    The returned loss is ``1 - mean(dice)`` over the first dimension.

    Args:
        weight: Accepted for API compatibility; unused.
        size_average: Accepted for API compatibility; unused.
        smooth: Additive smoothing constant. It is added once to both the
            numerator and the denominator, so dice stays in [0, 1] (for
            targets in [0, 1]) and equals 1 when prediction and target are
            both empty. Defaults to 1.0.
    """

    def __init__(self, weight=None, size_average=True, smooth=1.0):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar soft Dice loss.

        Args:
            pred: Logits; the first dimension indexes samples.
            target: Ground-truth mask with the same number of elements per
                sample as ``pred`` (assumed to hold values in [0, 1]).

        Returns:
            0-dim tensor ``1 - mean(dice)``; 0 for a perfect prediction.
        """
        num = target.size(0)
        probs = torch.sigmoid(pred)
        # reshape (not view) so non-contiguous inputs, e.g. after permute(),
        # are accepted.
        m1 = probs.reshape(num, -1)
        m2 = target.reshape(num, -1)
        intersection = (m1 * m2).sum(1)
        # Smoothing is applied symmetrically; the previous form
        # 2 * (I + 1) / (A + B + 1) let dice exceed 1 and the loss go negative.
        score = (2.0 * intersection + self.smooth) / (
            m1.sum(1) + m2.sum(1) + self.smooth
        )
        return 1 - score.sum() / num
def soft_dice(pred, target, size_average=True, batch_average=True):
    """Functional shortcut for ``SoftDiceLoss()(pred, target)``.

    ``size_average`` and ``batch_average`` are accepted only for signature
    compatibility and are ignored; a fresh ``SoftDiceLoss`` with default
    settings is built on every call.
    """
    return SoftDiceLoss()(pred, target)
def load_pretrain_model(net, weights):
    """Copy tensors from ``weights`` into ``net`` by position, matching on shape.

    The entries of ``net.state_dict()`` and ``weights`` are walked in order.
    When the current ``net`` entry has the same shape as the current
    ``weights`` entry, the values are copied in place and both cursors
    advance. Otherwise the ``net`` entry is skipped (left untouched) and only
    the ``net`` cursor advances. The walk stops when either sequence is
    exhausted. Key names are never compared, so this relies purely on the
    ordering of the two mappings.

    Args:
        net: ``torch.nn.Module`` to load into (modified in place).
        weights: Mapping of name -> tensor, e.g. a loaded state dict.

    Returns:
        ``net`` itself.
    """
    # state_dict() builds a new OrderedDict on each call; fetch it once rather
    # than twice per iteration. Its tensors share storage with the module's
    # parameters/buffers, so copy_ into them updates ``net``.
    net_state = net.state_dict()
    net_keys = list(net_state.keys())
    weights_keys = list(weights.keys())
    i = 0
    j = 0
    with torch.no_grad():
        while i < len(net_keys) and j < len(weights_keys):
            target = net_state[net_keys[i]]
            source = weights[weights_keys[j]]
            if target.shape == source.shape:
                # copy_ performs any needed device/dtype conversion itself.
                target.copy_(source)
                j += 1
            # The net cursor advances whether or not this entry was loaded.
            i += 1
    return net
def load_pretrain_model_fast(net, weights):
    """Copy the leading shape-compatible run of ``weights`` into ``net``.

    Entries of ``net.state_dict()`` and ``weights`` are paired up by position
    (key names are never compared). Values are copied in place while the
    paired shapes agree. Loading stops at the first shape mismatch, or when
    either sequence runs out.

    Args:
        net: ``torch.nn.Module`` to load into (modified in place).
        weights: Mapping of name -> tensor, e.g. a loaded state dict.

    Returns:
        ``net`` itself.
    """
    # Fetch state_dict() once; its tensors share storage with the module's
    # parameters/buffers, so copy_ into them updates ``net``. Iterating the
    # mappings directly also avoids indexing dict key views, which raises
    # TypeError on Python 3.
    net_state = net.state_dict()
    with torch.no_grad():
        for target, source in zip(net_state.values(), weights.values()):
            if target.shape != source.shape:
                break
            # copy_ performs any needed device/dtype conversion itself.
            target.copy_(source)
    return net