-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
71 lines (51 loc) · 2.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse, os
import torch
from model import UNet, ResUNet, UNet_LRes, ResUNet_LRes
from utils import DataLoaderVal
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
# Command-line options for the inference script. The parsed namespace is
# published as the module-level ``opt`` and read by main().
parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
parser.add_argument("--gpuID", type=int, default=0, help="index of the GPU to expose through CUDA_VISIBLE_DEVICES")
parser.add_argument("--img_size", type=int, default=256, help="size of image")
# Was action="store_true", which turned this into a boolean flag: a dataset
# name could not be passed, and a bare --dataset produced True instead of a
# directory name.
parser.add_argument("--dataset", type=str, default='MRI-PET', help="name of dataset")
parser.add_argument("--numOfChannel_allSource", type=int, default=3, help="# of channels for a 2D patch for all the concatenated modalities (default: 3)")
# main() reads opt.numOfChannel_singleSource and opt.whichNet; without these
# two declarations every run failed with AttributeError.
parser.add_argument("--numOfChannel_singleSource", type=int, default=None, help="# of channels for a 2D patch of a single modality (default: same as --numOfChannel_allSource)")
# NOTE(review): default 4 is chosen because the default --modelName looks like
# a ResUNet_LRes checkpoint ("resunet2d_..._lres_...") -- confirm.
parser.add_argument("--whichNet", type=int, default=4, choices=(1, 2, 3, 4), help="network type: 1 UNet, 2 ResUNet, 3 UNet_LRes, 4 ResUNet_LRes")
parser.add_argument("--modelName", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_dp_pet_BatchAug_sNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0628_200000.pt", type=str, help="modelname")
parser.add_argument("--inputKey", default="MRI", type=str, help="input modality")
parser.add_argument("--targetKey", default="PET", type=str, help="target modality")

opt = parser.parse_args()
# Resolve the "same as all-source" default here so consumers always see an int.
# NOTE(review): assumes a single input modality, so the per-modality channel
# count equals the concatenated one -- confirm.
if opt.numOfChannel_singleSource is None:
    opt.numOfChannel_singleSource = opt.numOfChannel_allSource
def main():
    """Run a trained generator over the test split and save one image per sample.

    All settings come from the module-level ``opt`` namespace. The function
    builds the network selected by ``opt.whichNet``, restores its weights from
    the ``'model'`` entry of the checkpoint at ``checkpoints/<opt.modelName>``,
    feeds every sample of ``../dataset/<opt.dataset>/test`` through it and
    writes each output to ``result/<file name>``.

    Raises:
        ValueError: if the selected network type is not 1, 2, 3 or 4.
    """
    os.makedirs('result', exist_ok=True)

    # Fall back to sensible values when the parser does not declare these
    # options, so a missing flag definition does not abort with AttributeError.
    which_net = getattr(opt, 'whichNet', 4)
    single_source_channels = getattr(opt, 'numOfChannel_singleSource', None)
    if single_source_channels is None:
        # NOTE(review): assumes one input modality, i.e. per-modality channel
        # count == concatenated channel count -- confirm.
        single_source_channels = opt.numOfChannel_allSource

    architectures = {1: UNet, 2: ResUNet, 3: UNet_LRes, 4: ResUNet_LRes}
    if which_net not in architectures:
        # The original if/elif chain had no fallback and failed later with an
        # unbound netG; fail early with an explicit message instead.
        raise ValueError(
            'unsupported whichNet=%r; expected one of %s'
            % (which_net, sorted(architectures))
        )
    # Types 3 and 4 take the middle input slice as a second forward argument.
    uses_residual = which_net in (3, 4)

    netG = architectures[which_net](in_channel=opt.numOfChannel_allSource, n_classes=1)
    netG.cuda()

    # os.path.join discards 'checkpoints' when modelName is an absolute path.
    checkpoint = torch.load(os.path.join('checkpoints', opt.modelName))
    netG.load_state_dict(checkpoint['model'])
    # Inference mode: eval() switches layers such as batch norm and dropout to
    # their inference behaviour.
    netG.eval()

    test_dataset = DataLoaderVal(
        os.path.join('../dataset', opt.dataset, 'test'),
        opt.inputKey,
        opt.targetKey,
        {'w': opt.img_size, 'h': opt.img_size},
    )
    testloader = DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=16,
        drop_last=False,
        pin_memory=True,
    )

    mid_slice = single_source_channels // 2

    # no_grad() avoids building autograd graphs that are never used here.
    with torch.no_grad():
        for data in tqdm(testloader):
            inputs = data[0]
            # batch_size is 1, so the first entry is this sample's file name.
            f_name = data[2][0]
            source = inputs.cuda()
            # The target tensor is not needed for inference. The original
            # `labels = labels.cuda()` read a never-assigned local and raised
            # UnboundLocalError on the first batch, so it is dropped.
            if uses_residual:
                residual_source = inputs[:, mid_slice, ...].cuda()
                outputG = netG(source, residual_source)
            else:
                outputG = netG(source)
            save_image(outputG, os.path.join('result', f_name))
if __name__ == '__main__':
    # Script entry point: publish the GPU index chosen on the command line via
    # the CUDA_VISIBLE_DEVICES environment variable, then run inference.
    visible_gpu = str(opt.gpuID)
    os.environ.update(CUDA_VISIBLE_DEVICES=visible_gpu)
    main()