main.py
# -*- coding: utf-8 -*-
import os
import datetime
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from parameters import get_args
from wandb_utils import wandb_init
import pcode.create_dataset as create_dataset
import pcode.create_optimizer as create_optimizer
import pcode.create_metrics as create_metrics
import pcode.create_model as create_model
import pcode.create_scheduler as create_scheduler
import pcode.utils.topology as topology
import pcode.utils.checkpoint as checkpoint
import pcode.utils.op_paths as op_paths
import pcode.utils.stat_tracker as stat_tracker
import pcode.utils.logging as logging
from pcode.utils.timer import Timer


def init_distributed_world(conf, backend):
    if backend == "mpi":
        dist.init_process_group("mpi")
    elif backend == "nccl" or backend == "gloo":
        # init the process group.
        _tmp_path = os.path.join(conf.checkpoint, "tmp", conf.timestamp)
        op_paths.build_dirs(_tmp_path)

        dist_init_file = os.path.join(_tmp_path, "dist_init")
        torch.distributed.init_process_group(
            backend=backend,
            init_method="file://" + os.path.abspath(dist_init_file),
            timeout=datetime.timedelta(seconds=120),
            world_size=conf.n_mpi_process,
            rank=int(conf.local_rank),
        )
    else:
        raise NotImplementedError
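
# NOTE (added): the nccl/gloo branch above uses file-based rendezvous, which assumes
# every rank can reach the same `conf.checkpoint` directory (e.g. a single node or a
# shared mount). A minimal standalone sketch of the same init, with hypothetical
# path/rank/world-size values:
#
#   import torch.distributed as dist
#   dist.init_process_group(
#       backend="gloo",
#       init_method="file:///tmp/dist_init_example",  # hypothetical shared path
#       world_size=2,
#       rank=0,  # 1 on the other process
#   )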


def main(conf):
    try:
        init_distributed_world(conf, backend=conf.backend)
        conf.distributed = True and conf.n_mpi_process > 1
    except AttributeError as e:
        print(f"failed to init the distributed world: {e}.")
        conf.distributed = False

    # setup wandb
    if not conf.distributed or dist.get_rank() == 0:
        wandb_init(vars(conf), name=conf.optimizer + "-{}".format(conf.lr))

    # init the config.
    init_config(conf)
    if conf.stop_criteria == "iteration":
        conf.num_epochs = conf.eval_n_points

    # define the timer for different operations.
    # if we choose the `train_fast` mode, then we will not track the time.
    conf.timer = Timer(
        verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
        log_fn=conf.logger.log_metric,
        on_cuda=conf.on_cuda,
    )

    # create dataset.
    data_loader = create_dataset.define_dataset(conf, force_shuffle=True)

    # create model
    model = create_model.define_model(conf, data_loader=data_loader)

    # define the optimizer.
    optimizer = create_optimizer.define_optimizer(conf, model)

    # define the lr scheduler.
    scheduler = create_scheduler.Scheduler(conf)

    # add model with data-parallel wrapper.
    if conf.graph.on_cuda:
        if conf.n_sub_process > 1:
            model = torch.nn.DataParallel(model, device_ids=conf.graph.device)

    # (optional) reload checkpoint
    try:
        checkpoint.maybe_resume_from_checkpoint(conf, model, optimizer, scheduler)
    except RuntimeError as e:
        conf.logger.log(f"Resume Error: {e}")
        conf.resumed = False
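
    # NOTE (added): the block below dispatches on `conf.arch`:
    # - "rnn_lm" architectures use the NLP training loop with CrossEntropyLoss and a
    #   language-modeling metric, where lower is better (larger_is_better=False);
    # - "regression" architectures use the CV loop with MSELoss;
    # - everything else is treated as classification with CrossEntropyLoss.
    # In every branch the per-run best-performance tracker is handed to the scheduler.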
    # train and evaluate model.
    if "rnn_lm" in conf.arch:
        from pcode.distributed_running_nlp import train_and_validate

        # safety check.
        assert (
            conf.n_sub_process == 1
        ), "our current data-parallel wrapper does not support RNN."

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="language_modeling",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=False,
        )
        scheduler.set_best_tracker(best_tracker)

        # get train_and_validate_func
        train_and_validate_fn = train_and_validate
    elif "regression" in conf.arch:
        from pcode.distributed_running_cv import train_and_validate

        # define the criterion and metrics.
        criterion = nn.MSELoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="regression",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=True,
        )
        scheduler.set_best_tracker(best_tracker)

        # get train_and_validate_func
        train_and_validate_fn = train_and_validate
    else:
        from pcode.distributed_running_cv import train_and_validate

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="classification",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=True,
        )
        scheduler.set_best_tracker(best_tracker)

        # get train_and_validate_func
        train_and_validate_fn = train_and_validate

    # save arguments to disk.
    checkpoint.save_arguments(conf)

    # start training.
    train_and_validate_fn(
        conf,
        model=model,
        criterion=criterion,
        scheduler=scheduler,
        optimizer=optimizer,
        metrics=metrics,
        data_loader=data_loader,
    )
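
# NOTE (added): `init_config` below builds the communication topology via
# pcode.utils.topology.define_graph_topology; `graph_topology == "complete"` is
# treated as the centralized setting, and `er_p` is presumably the edge probability
# of an Erdos-Renyi random topology (an assumption; the flag is defined in
# parameters.py, which is not shown here).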


def init_config(conf):
    # define the graph for the computation.
    cur_rank = dist.get_rank() if conf.distributed else 0
    conf.rank = cur_rank
    conf.graph = topology.define_graph_topology(
        graph_topology=conf.graph_topology,
        world=conf.world,
        n_mpi_process=conf.n_mpi_process,  # the # of total main processes.
        n_sub_process=conf.n_sub_process,  # the # of subprocess for each main process.
        comm_device=conf.comm_device,
        on_cuda=conf.on_cuda,
        rank=cur_rank,
        p=conf.er_p,
    )
    conf.is_centralized = conf.graph_topology == "complete"

    # re-configure batch_size if sub_process > 1.
    if conf.n_sub_process > 1:
        conf.batch_size = conf.batch_size * conf.n_sub_process

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.manual_seed(conf.manual_seed)
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.device[0])
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True if conf.train_fast else False
    else:
        torch.manual_seed(conf.manual_seed)

    # define checkpoint for logging.
    checkpoint.init_checkpoint(conf)

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    logging.display_args(conf)


if __name__ == "__main__":
    conf = get_args()

    if conf.optimizer == "parallel_choco" or conf.optimizer == "docom":
        mp.set_start_method("forkserver", force=True)
        # mp.set_start_method("spawn", force=True)
        mp.set_sharing_strategy("file_system")

    main(conf)
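
# NOTE (added): a hypothetical launch sketch, assuming `parameters.get_args()` exposes
# flags matching the `conf` attributes used above (backend, n_mpi_process, local_rank,
# optimizer, lr, arch, ...); the exact flag names live in parameters.py:
#
#   # two ranks started as separate processes on one machine
#   python main.py --backend gloo --n_mpi_process 2 --local_rank 0 ...
#   python main.py --backend gloo --n_mpi_process 2 --local_rank 1 ...
#
# With `backend="mpi"`, the same script would instead be launched under mpirun.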