# basic libs:
import argparse
from MainSemi import trainBPL
from MainSup import trainSup
# Integer flags (0 or 1) are used in place of True/False for the boolean options in this argparse setup.
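# A brief illustration of why (an assumption about usage, not taken from this repo's training code):
# argparse's type=bool converts any non-empty string, including "0", to True, so integer flags are
# easier to pass correctly from the command line and can be checked explicitly, e.g.
#
#     if args.gaussian == 1:
#         add_gaussian_noise(image)  # hypothetical helper, for illustration only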
def main():
    parser = argparse.ArgumentParser(description='Training for semi-supervised segmentation with Bayesian pseudo labels.')
    # paths to the training data
    parser.add_argument('--data', type=str, default='/home/moucheng/projects_data/HipCT/COVID_ML_data/hip_covid', help='Data path')
    # parser.add_argument('--data', type=str, default='/home/moucheng/projects_data/Pulmonary_data/airway', help='Data path')
    parser.add_argument('--log_tag', type=str, default='hip_ct_sup_aug', help='experiment tag for the record')
    # hyperparameters for training (both supervised and semi-supervised):
    parser.add_argument('--input_dim', type=int, help='channel dimension of the input image, e.g. 1 for CT, 3 for RGB, and more for 3D inputs', default=1)
    parser.add_argument('--output_dim', type=int, help='channel dimension of the output, e.g. 1 for binary segmentation, 3 for 3 classes', default=1)
    parser.add_argument('--iterations', type=int, help='number of training iterations', default=5000)
    parser.add_argument('--lr', type=float, help='learning rate', default=0.01)
    parser.add_argument('--width', type=int, help='number of filters in the first conv block of the encoder', default=16)
    parser.add_argument('--depth', type=int, help='number of downsampling stages', default=4)
    parser.add_argument('--batch', type=int, help='training batch size', default=2)
    parser.add_argument('--temp', '--t', type=float, help='temperature scaling on the output logits before sigmoid or softmax', default=2.0)
    parser.add_argument('--l2', type=float, help='l2 regularisation weight', default=0.01)
    parser.add_argument('--seed', type=int, help='random seed', default=1128)
    parser.add_argument('--ema_saving_starting', type=int, help='iteration at which saving of the averaged (EMA) model starts', default=100)
    parser.add_argument('--patience', type=int, help='patience on validation accuracy', default=100)
    parser.add_argument('--validate_no', type=int, help='number of batches used for validation, since full validation is too time-consuming', default=1)
    # hyperparameters for training (specific to semi-supervised learning)
    parser.add_argument('--unlabelled', type=int, help='SSL, ratio between unlabelled and labelled data in one batch, 0 for supervised learning', default=0)
    parser.add_argument('--mu', type=float, help='SSL, prior Gaussian mean', default=0.9)  # mu
    parser.add_argument('--alpha', type=float, help='SSL, weight on the unsupervised learning part', default=1.0)
    parser.add_argument('--beta', type=float, help='SSL, weight on the KL loss part', default=0.1)
    parser.add_argument('--warmup', type=float, help='SSL, ratio between the warm-up iterations and the total training iterations', default=0.1)
    parser.add_argument('--warmup_start', type=int, help='SSL, iteration at which warm-up of the unsupervised loss weight starts', default=100)
    # flags for data preprocessing and augmentation in the data loader:
    parser.add_argument('--gaussian', type=int, help='1 to add random Gaussian noise', default=1)
    parser.add_argument('--zoom', type=int, help='1 to use random zoom augmentation', default=1)
    parser.add_argument('--cutout', type=int, help='1 to randomly cut out patches', default=1)
    parser.add_argument('--contrast', type=int, help='1 to use random contrast via histogram equalisation with random bins', default=1)
    parser.add_argument('--full_orthogonal', type=int, help='1 to use all three orthogonal planes together in each iteration', default=1)
    parser.add_argument('--new_size_h', type=int, help='new height of the resized image', default=448)
    parser.add_argument('--new_size_w', type=int, help='new width of the resized image', default=448)
    # flags for fine-tuning a trained model:
    parser.add_argument('--resume', type=int, help='1 to resume training from an existing model', default=0)
    parser.add_argument('--checkpoint_path', type=str, help='path to the checkpoint model')

    # parse the arguments and dispatch to supervised or semi-supervised training:
    global args
    args = parser.parse_args()

    if args.unlabelled == 0:
        trainSup(args)
    else:
        trainBPL(args)


if __name__ == '__main__':
    main()
# disabled options:
# parser.add_argument('--lung_window', type=int, help='1 when we apply lung window on data', default=0)
# parser.add_argument('--sampling', type=int, help='weight for sampling the slices along each axis of 3d volume for training, '
# 'highest weights at the edges and lowest at the middle', default=0)
# parser.add_argument('--norm', type=int, help='1 when normalise each case individually', default=1)
# parser.add_argument('--saving_frequency', type=int, help='number of interval of iterations when it starts to save', default=50)
# parser.add_argument('--cutout', type=int, help='1 when randomly cutout some patches', default=0)
# parser.add_argument('--detach', type=int, help='SSL, 1 when we cut the gradients in consistency regularisation or 0', default=0)
# parser.add_argument('--sigma', type=float, help='SSL, prior Gaussian std', default=0.1) # sigma
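
# Example invocations (paths, tags and flag values below are placeholders for illustration,
# not settings from the original experiments; adjust them to your own data layout):
#
#   fully supervised training (no unlabelled data in the batch):
#     python Run.py --data /path/to/data --log_tag sup_exp --unlabelled 0
#
#   semi-supervised training with Bayesian pseudo labels, e.g. two unlabelled samples
#   per labelled sample in each batch:
#     python Run.py --data /path/to/data --log_tag ssl_exp --unlabelled 2 --alpha 1.0 --beta 0.1
#
#   fine-tuning / resuming from an existing checkpoint:
#     python Run.py --data /path/to/data --log_tag ft_exp --resume 1 --checkpoint_path /path/to/checkpoint.pt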