forked from mlcommons/GaNDLF
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgandlf_run
155 lines (134 loc) · 4.64 KB
/
gandlf_run
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""GaNDLF command-line entry point for training and inference.

Run with ``-h`` to list the available options.
"""
from __future__ import print_function, division
import os
from GANDLF.utils import *

# NOTE(review): fix_paths is deliberately called before the remaining imports;
# per the original note it adds the relevant vips path, presumably so that the
# modules imported below can locate it — confirm against GANDLF.utils.fix_paths.
# It is keyed on the current working directory, not on this file's location.
fix_paths(os.getcwd())  # add relevant vips path
import argparse
import sys
from pathlib import Path
from datetime import date
from GANDLF.training_manager import *
from GANDLF.inference_manager import InferenceManager
from GANDLF.parseConfig import parseConfig
from GANDLF.utils import populate_header_in_parameters
from GANDLF import version
def main():
copyrightMessage = (
"Contact: gandlf@cbica.upenn.edu\n\n"
+ "This program is NOT FDA/CE approved and NOT intended for clinical use.\nCopyright (c) "
+ str(date.today().year)
+ " University of Pennsylvania. All rights reserved."
)
parser = argparse.ArgumentParser(
prog="GANDLF",
formatter_class=argparse.RawTextHelpFormatter,
description="Image Semantic Segmentation and Regression using Deep Learning.\n\n"
+ copyrightMessage,
)
parser.add_argument(
"-config",
type=str,
help="The configuration file (contains all the information related to the training/inference session), this is read from 'output' during inference",
required=True,
)
parser.add_argument(
"-data",
type=str,
help="Data csv file that is used for training/inference; can also take a comma-separate training-validatation pre-split CSV",
required=True,
)
parser.add_argument(
"-output",
type=str,
help="Output directory to save intermediate files and model weights",
required=True,
)
parser.add_argument(
"-train",
type=int,
help="1: training and 0: inference; for 0, there needs to be a compatible model saved in '-output'",
required=True,
)
parser.add_argument(
"-device",
default="cuda",
type=str,
help="Device to perform requested session on 'cpu' or 'cuda'; for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
required=True,
)
parser.add_argument(
"-reset_prev",
default=False,
type=bool,
help="Whether the previous run in the output directory will be discarded or not",
required=False,
)
parser.add_argument(
"-v",
"--version",
action="version",
version="%(prog)s v{}".format(version) + "\n\n" + copyrightMessage,
help="Show program's version number and exit.",
)
args = parser.parse_args()
file_data_full = args.data
model_parameters = args.config
parameters = parseConfig(model_parameters)
device = args.device
parameters["output_dir"] = args.output
# fixme: for some reason, the 'bool' type is not working for train, needs to be checked
if args.train == 0:
args.train = False
else:
args.train = True
reset_prev = args.reset_prev
if "-1" in device:
device = "cpu"
if args.train: # train mode
Path(args.output).mkdir(parents=True, exist_ok=True)
# parse training CSV
if "," in file_data_full:
# training and validation pre-split
data_full = None
both_csvs = file_data_full.split(",")
data_train, headers_train = parseTrainingCSV(both_csvs[0], train=args.train)
data_validation, headers_validation = parseTrainingCSV(
both_csvs[1], train=args.train
)
if headers_train != headers_validation:
sys.exit(
"The training and validation CSVs do not have the same header information."
)
parameters = populate_header_in_parameters(parameters, headers_train)
# if we are here, it is assumed that the user wants to do training
TrainingManager_split(
dataframe_train=data_train,
dataframe_validation=data_validation,
outputDir=args.output,
parameters=parameters,
device=device,
reset_prev=reset_prev,
)
else:
data_full, headers = parseTrainingCSV(file_data_full, train=args.train)
parameters = populate_header_in_parameters(parameters, headers)
# # start computation - either training or inference
if args.train: # training mode
TrainingManager(
dataframe=data_full,
outputDir=args.output,
parameters=parameters,
device=device,
reset_prev=reset_prev,
)
else:
InferenceManager(
dataframe=data_full,
outputDir=args.output,
parameters=parameters,
device=device,
)
print("Finished.")
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())