-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
160 lines (131 loc) · 8.47 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os, sys
import time
from glob import glob
import argparse
import torch
import wandb
from utils.tools import draw_eval_score
# Log in to Weights & Biases as soon as the module is loaded, before the
# wildcard import of `utils` below.
# NOTE(review): test() starts its run with mode="offline" and ablation() with
# mode="disabled" -- confirm that a login is still required for this script.
wandb.login()
from utils import *
import warnings
# Install a blanket 'ignore' filter: no warning of any category is shown for
# the rest of the process.
warnings.filterwarnings('ignore')
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when many data-loading workers are used -- confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def config():
    """
    This function is for parsing commandline arguments.

    It builds the argument parser for the run (wandb mode, training
    hyper-parameters, dataset/checkpoint/output paths and the structure
    parameters of the network modules) and parses ``sys.argv``.

    The warping-iteration option was originally spelled ``--interation``.
    That spelling is kept, ``--iteration`` is accepted as an alias, and the
    parsed value is exposed under both ``args.interation`` and
    ``args.iteration`` so that either attribute name can be read.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()

    # mode parameters
    parser.add_argument("--mode", type=str, default="offline",
                        help="choose the mode for wandb, can be 'disabled', 'offline', 'online'")
    parser.add_argument("--_4d", action="store_true",
                        help="toggle to train on 4D image data")
    parser.add_argument("--_mr", action="store_true",
                        help="toggle to ONLY use MR data for training")
    parser.add_argument("--save_on", type=str, default="cap",
                        help="the dataset for validation, can be 'cap' or 'sct'")
    parser.add_argument("--target", type=str, default=None,
                        help="the target dataset for test, particularly for a different process on acdc data")
    parser.add_argument("--template_mesh_dir", type=str,
                        default="./template/template_mesh-myo.obj",
                        help="the path to your initial meshes")

    # training parameters
    parser.add_argument("--max_epochs", type=int, default=20,
                        help="the maximum number of epochs for training")
    parser.add_argument("--pretrain_epochs", type=int, default=10,
                        help="the number of epochs to train the segmentation UNet")
    parser.add_argument("--train_epochs", type=int, default=12,
                        help="the number of epochs to train the distance field prediction ResNet")
    parser.add_argument("--reduce_count_down", type=int, default=-1,
                        help="the count down for reduce the mesh face numbers.")
    parser.add_argument("--val_interval", type=int, default=1,
                        help="the interval of validation")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="the learning rate for training")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="the batch size for training")
    parser.add_argument("--cache_rate", type=float, default=1.0,
                        help="the cache rate for training, see MONAI document for more details")
    parser.add_argument("--crop_window_size", type=int, nargs='+', default=[128, 128, 128],
                        help="the size of the crop window for training")
    parser.add_argument("--pixdim", type=float, nargs='+', default=[4, 4, 4],
                        help="the pixel dimension of downsampled images")
    parser.add_argument("--lambda_0", type=float, default=1.06,
                        help="the loss coefficients for Chamfer verts distance term")
    parser.add_argument("--lambda_1", type=float, default=1.05,
                        help="the loss coefficients for point to mesh distance term")
    # "--interation" is the historical (misspelled) flag; "--iteration" is an
    # alias. dest is pinned so the original attribute name does not change.
    parser.add_argument("--interation", "--iteration", dest="interation", type=int, default=5,
                        help="the interations for the distance field warping")

    # data parameters
    parser.add_argument("--ct_ratio", type=float, default=0.8,
                        help="the portion of CT data for training")
    parser.add_argument("--ct_json_dir", type=str,
                        default="./dataset/dataset_task20_f0.json",
                        help="the path to the json file with named list of CT train/valid/test sets")
    parser.add_argument("--mr_json_dir", type=str,
                        # default="./dataset/dataset_task11_f0.json",     # less data less burden
                        default="./dataset/dataset_task10_f0.json",     # use only for 4d
                        help="the path to the json file with named list of MR train/valid/test sets")
    parser.add_argument("--ct_data_dir", type=str,
                        default="/mnt/data/Experiment/Data/MorphiNet-MR_CT/Dataset020_SCOTHEART",
                        help="the path to your processed images, must be in nifti format")
    parser.add_argument("--mr_data_dir", type=str,
                        # default="/mnt/data/Experiment/Data/MorphiNet-MR_CT/Dataset011_CAP_SAX",
                        default="/mnt/data/Experiment/Data/MorphiNet-MR_CT/Dataset010_CAP_SAX_NRRD",
                        help="the path to your processed images")
    parser.add_argument("--ckpt_dir", type=str,
                        default="/mnt/data/Experiment/MorphiNet/Checkpoint",
                        help="the path to your checkpoint directory, for holding trained models and wandb logs")
    parser.add_argument("--out_dir", type=str,
                        default="/mnt/data/Experiment/MorphiNet/Result",
                        help="the path to your output directory, for saving outputs")

    # path to the pretrained modules
    parser.add_argument("--use_ckpt", type=str,
                        # default=None,
                        default="/mnt/data/Experiment/MorphiNet/Checkpoint/dynamic/sct--myo--f0--2024-07-30-1649/",
                        help="the path to the pretrained models")

    # structure parameters for df-predict module
    parser.add_argument("--num_classes", type=int, default=4,
                        help="the number of segmentation classes including the background")
    parser.add_argument("--kernel_size", type=int, default=(3, 3, 3, 3, 3), nargs='+',
                        help="the kernel size of the convolutional layer in the encoder")
    parser.add_argument("--strides", type=int, default=(1, 2, 2, 2, 2), nargs='+',
                        help="the stride of the convolutional layer in the encoder")
    parser.add_argument("--filters", type=int, default=(8, 16, 32, 64, 128), nargs='+',
                        help="the number of output channels in each layer of the encoder")
    parser.add_argument("--layers", type=int, default=(1, 2, 2, 4), nargs='+',
                        help="the number of layers in each residual block of the decoder")
    parser.add_argument("--block_inplanes", type=int, default=(8, 16, 32, 64), nargs='+',
                        help="the number of intermedium channels in each residual block")

    # structure parameters for subdiv module
    parser.add_argument("--subdiv_levels", type=int, default=2,
                        help="the number of subdivision levels for the mesh (should be an integer larger than 0, where 0 means no subdivision)")
    parser.add_argument("--hidden_features_gsn", type=int, default=16,
                        help="the number of hidden features for the graph subdivide network")

    # run_id for wandb, will create automatically if not specified for training
    parser.add_argument("--run_id", type=str, default=None,
                        help="the run name for wandb and local machine")

    # the best epoch for testing
    parser.add_argument("--best_epoch", type=int, default=None,
                        help="the best epoch for testing")

    args = parser.parse_args()
    # Mirror the value under the correctly spelled name so that code reading
    # `args.iteration` sees the command-line value as well.
    args.iteration = args.interation
    return args
def test(super_params):
    """
    Run the test stage of the pipeline for the given parameters.

    A wandb run is started in offline mode under the "MorphiNet-test"
    project (``super_params.mode`` is not consulted here), a
    ``TrainPipeline`` is built with ``is_training=False``, its data warper is
    set up without rotation, and ``pipeline.test`` is called with
    ``super_params.save_on``.

    The wandb run name is ``super_params.run_id`` with every "sct" replaced
    by ``super_params.target``. Both values default to ``None`` in the
    argument parser; in that case the substitution is skipped and the run id
    is passed through unchanged instead of raising on ``None``.

    Args:
        super_params: the parsed argument namespace (see ``config``).
    """
    run_id = super_params.run_id
    target = super_params.target
    # str.replace() needs a str receiver and a str replacement; only rewrite
    # the name when both are actually set.
    if run_id is not None and target is not None:
        run_name = run_id.replace("sct", target)
    else:
        run_name = run_id

    wandb.init(config=super_params, mode="offline", project="MorphiNet-test", name=run_name)
    pipeline = TrainPipeline(
        super_params=super_params,
        seed=42, num_workers=19,
        is_training=False,
        # only the "acdc" target is forwarded; any other value becomes None
        target="acdc" if target == "acdc" else None
    )
    pipeline._data_warper(rotation=False)
    pipeline.test(super_params.save_on)
def ablation(super_params):
    """
    Run the ablation study of the pipeline for the given parameters.

    wandb is initialised in "disabled" mode, a ``TrainPipeline`` is built
    with ``is_training=False``, its data warper is set up without rotation,
    and ``ablation_study`` is called with ``super_params.save_on``.

    Args:
        super_params: the parsed argument namespace (see ``config``).
    """
    wandb.init(mode="disabled")

    pipeline_kwargs = dict(
        super_params=super_params,
        seed=42,
        num_workers=19,
        is_training=False,
    )
    runner = TrainPipeline(**pipeline_kwargs)
    runner._data_warper(rotation=False)

    dataset_name = super_params.save_on
    runner.ablation_study(dataset_name)
if __name__ == '__main__':
    super_params = config()

    # Pull in the pipeline module that matches the data setting. These are
    # wildcard imports, so they have to stay at module level.
    # super_params._mr = True
    if not super_params._mr:
        from run import *
    else:
        from run_mr import *

    # Hard-coded experiment settings; they overwrite whatever was parsed from
    # the command line.
    ckpt = "cap--myo--f0--2024-08-20-2312"
    target = "4d"
    overrides = {
        # checkpoint info
        "_4d": True,
        "save_on": "cap",
        "best_epoch": "best",
        # NOTE(review): the command-line flag is spelled `--interation`;
        # this sets a separate `iteration` attribute -- confirm which of the
        # two the pipeline actually reads.
        "iteration": 10,
        # data info
        "target": target,
        "ct_json_dir": "/home/yd21/Documents/MorphiNet/dataset/dataset_task20_f0.json",
        "ct_data_dir": "/mnt/data/Experiment/Data/MorphiNet-MR_CT/Dataset020_SCOTHEART",
        "mr_json_dir": "/home/yd21/Documents/MorphiNet/dataset/dataset_task11_f0.json",
        "mr_data_dir": "/mnt/data/Experiment/Data/MorphiNet-MR_CT/Dataset011_CAP_SAX",
        # output info
        "out_dir": f"/mnt/data/Experiment/TMI_2024/{target}/MorphiNet/myo/f0/",
        # alternative: f"/mnt/data/Experiment/TMI_2024/{target}/MorphiNet/myo/80/"
        # model info
        "run_id": ckpt,
        "ckpt_dir": f"/mnt/data/Experiment/MorphiNet/Checkpoint/dynamic/{ckpt}/trained_weights",
        "template_mesh_dir": "/home/yd21/Documents/MorphiNet/template/template_mesh-myo.obj",
    }
    for attr_name, attr_value in overrides.items():
        setattr(super_params, attr_name, attr_value)

    test(super_params)

    # To run the ablation study instead:
    # super_params.out_dir = f"/mnt/data/Experiment/TMI_2024/{super_params.target}/MorphiNet/myo/f0/"
    # ablation(super_params)