-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
63 lines (56 loc) · 3.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import argparse
import torch
import dist_utils
import dist_train
import torch.distributed as dist
# 定义一个函数,用于包装分布式训练的设置和启动
def process_wrapper(rank, args, func):
# for single machine
os.environ['MASTER_ADDR'] = '127.0.0.1' #设置分布式训练的主节点,127.0.0.1默认是本地节点
os.environ['MASTER_PORT'] = '29501' #端口号
os.environ['NCCL_SOCKET_IFNAME'] = 'lo' #NCCL网络接口名称,'lo'通常表示本地回环接口,GPU通信将通过本地主机进行
env = dist_utils.DistEnv(rank, args.nprocs, args.backend)
env.half_enabled = True
env.csr_enabled = True
# for multi machine
# os.environ['MASTER_ADDR'] = '202.199.6.34' #设置分布式训练的主节点,127.0.0.1默认是本地节点
# os.environ['MASTER_PORT'] = '29500' #端口号
# os.environ['NCCL_SOCKET_IFNAME'] = 'eno1' #通信接口
# # 创建 DistEnv 对象,该对象封装了分布式训练的环境信息
# # rank = 0 #多机下,要手动指定rank值
# env = dist_utils.DistEnv(rank, args.nnodes, args.backend)
# env.half_enabled = True
# env.csr_enabled = True
# 调用传入的 func 函数,开始分布式训练
func(env, args)
if __name__ == "__main__":
num_GPUs = torch.cuda.device_count()
parser = argparse.ArgumentParser()
# parser.add_argument("--nprocs", type=int, default=num_GPUs if num_GPUs>1 else 8)
#single GPU
parser.add_argument("--nprocs", type=int, default=2)
parser.add_argument("--chunk", type=int, default=32)
parser.add_argument("--nnodes", type=int, default=1)
parser.add_argument("--nlayers", type=int, default=2)
parser.add_argument("--hidden", type=int, default=128)
parser.add_argument("--epoch", type=int, default=20)
# parser.add_argument("--backend", type=str, default='gloo')
parser.add_argument("--backend", type=str, default='nccl' if num_GPUs>1 else 'gloo')
# parser.add_argument("--dataset", type=str, default='ogbn-100m')
# parser.add_argument("--dataset", type=str, default='friendster')
# parser.add_argument("--dataset", type=str, default='reddit')
parser.add_argument("--dataset", type=str, default='cora')
# parser.add_argument("--model", type=str, default='DecoupleGCN')
# parser.add_argument("--model", type=str, default='GCN')
# parser.add_argument("--model", type=str, default='TensplitGCN')
# parser.add_argument("--model", type=str, default='TensplitGCNLARGE')
# parser.add_argument("--model", type=str, default='TensplitGCNSWAP')
# parser.add_argument("--model", type=str, default='TensplitGCNCPU')
# parser.add_argument("--model", type=str, default='TensplitGATLARGE')
# parser.add_argument("--model", type=str, default='GAT')
parser.add_argument("--model", type=str, default='TensplitGAT')
args = parser.parse_args()
process_args = (args, dist_train.main)
# 启动多个进程进行分布式训练
torch.multiprocessing.spawn(process_wrapper, process_args, args.nprocs)