-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
98 lines (85 loc) · 4.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
import os
os.environ["OMP_NUM_THREADS"] = "1"
import pyDRESCALk.config as config
config.init(0)
import argparse
from pyDRESCALk.utils import *
from pyDRESCALk.pyDRESCALk import *
from pyDRESCALk.pyDRESCAL import *
from pyDRESCALk.dist_comm import *
import pandas as pd
from mpi4py import MPI
from scipy import sparse
from scipy.sparse import csr_matrix
def parser_pyRescal(parser):
    """Register the core pyRescal command-line options on *parser*.

    Covers the processor grid (``--p_r``/``--p_c``), data input
    (``--fpath``/``--ftype``/``--fname``/``--key``/``--precision``), solver
    settings (``--k``/``--init``/``--itr``/``--norm``/``--method``), the GPU
    switch and the results directory.

    Args:
        parser: An ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same ``parser`` object, so registration calls can be chained.
    """
    parser.add_argument('--p_r', type=int, required=True, help='Number of row processors')
    parser.add_argument('--p_c', type=int, required=True, help='Number of column processors')
    parser.add_argument('--k', type=int, required=False, help='Feature count k')
    # str2bool comes from pyDRESCALk.utils (star-imported at file top) so that
    # values like "True"/"False" parse as booleans rather than non-empty strings.
    parser.add_argument('--gpu', type=str2bool, default=False, help='Switch to turn on GPUs')
    parser.add_argument('--fpath', type=str, default='../Data/tmp/', help='data path to read(eg: tmp/)')
    parser.add_argument('--ftype', type=str, default='npy', help='Data type: npy/mat/folder/h5')
    parser.add_argument('--fname', type=str, default='A_', help='File name')
    parser.add_argument('--init', type=str, default='rand', help='RESCAL initialization: rand/nnsvd')
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument('--itr', type=int, default=10, help='Number of RESCAL iterations (default: %(default)s)')
    parser.add_argument('--norm', type=str, default='fro', help='Reconstruction norm for RESCAL to optimize: KL/FRO')
    parser.add_argument('--method', type=str, default='mu', help='RESCAL update method: MU/BCD/HALS')
    parser.add_argument('--verbose', type=str2bool, default=True)
    parser.add_argument('--results_path', type=str, default='Results/', help='Path for saving results')
    parser.add_argument('--precision', type=str, default='float32',
                        help='Precision of the data (float32/float64/float16).')
    parser.add_argument('--key', type=str, default='X',
                        help='Key for the data if stored in dictionary for mat/npy file')
    return parser
def parser_pyRescalk(parser):
    """Register the pyRescalk model-selection options on *parser*.

    These control the sweep over k (``--start_k``/``--end_k``/``--step_k``)
    and the perturbation/noise settings (``--perturbation``/``--noise_var``/
    ``--sampling``).

    Args:
        parser: An ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same ``parser`` object, so registration calls can be chained.
    """
    parser.add_argument('--perturbation', type=int, default=10, help='Number of perturbations for pyRescalk')
    parser.add_argument('--noise_var', type=float, default=0.03, help='Noise variance for pyRescalk')
    parser.add_argument('--start_k', type=int, default=1, help='Start value of k for pyRescalk')
    parser.add_argument('--end_k', type=int, default=10, help='End value of k for pyRescalk')
    # Previously this help text duplicated --start_k's ("Start index of K").
    parser.add_argument('--step_k', type=int, default=1, help='Step size of k for pyRescalk')
    parser.add_argument('--sampling', type=str, default='uniform',
                        help='Sampling noise for pyRescalk, i.e. uniform/poisson')
    return parser
# Accepted values for --process, mapped to the solver class they select.
# The CLI default and help advertise 'pyRescal'/'pyRescalk', while the solver
# classes are named pyDRESCAL/pyDRESCALk; both spellings are accepted.
_PROCESS_ALIASES = {
    'pyRescal': 'pyDRESCAL',
    'pyRescalk': 'pyDRESCALk',
    'pyDRESCAL': 'pyDRESCAL',
    'pyDRESCALk': 'pyDRESCALk',
}


def _resolve_process(name):
    """Return the canonical solver name ('pyDRESCAL' or 'pyDRESCALk') for *name*.

    Returns None when *name* is not a recognised --process value.
    """
    return _PROCESS_ALIASES.get(name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Arguments for pyRescal/pyRescalk. '
                    'To run the code for pyRescal: '
                    'mpirun -n 4 python main.py --p_r=2 --p_c=2 --k=4 --fpath=../Data/')
    parser.add_argument('--process', type=str, default='pyRescal',
                        help='pyRescal/pyRescalk (also accepts pyDRESCAL/pyDRESCALk)')
    parser = parser_pyRescal(parser)
    parser = parser_pyRescalk(parser)
    try:
        args = parser.parse_args()
    except SystemExit as exc:
        # argparse exits with code 0 after --help and non-zero on a usage
        # error. Show the full help only for errors, and re-raise so the
        # non-zero status is preserved for shells and job schedulers.
        if exc.code not in (0, None):
            parser.print_help()
        raise

    # Validate before any MPI setup or data loading. Previously an
    # unrecognised name (including the default 'pyRescal') fell through the
    # dispatch and the run ended silently after reading the data.
    process = _resolve_process(args.process)
    if process is None:
        parser.error('unknown --process %r; expected one of: %s'
                     % (args.process, ', '.join(_PROCESS_ALIASES)))

    # Comm initialization: build the grid and its row/column communicators
    # from --p_r and --p_c; the solvers read them from args.
    main_comm = MPI.COMM_WORLD
    rank = main_comm.rank
    comm = MPI_comm(main_comm, args.p_r, args.p_c)
    args.comm1 = comm.comm
    args.comm = comm
    args.col_comm = comm.cart_1d_column()
    args.row_comm = comm.cart_1d_row()

    # Array backend handed to the solvers via args.np: CuPy on GPU, NumPy otherwise.
    if args.gpu:
        import cupy as xp
        print("Using GPU:" + str(rank))
        # Assign ranks to the visible devices round-robin.
        gpu_ct = xp.cuda.runtime.getDeviceCount()
        gpu_id = rank % gpu_ct
        xp.cuda.device.Device(gpu_id).use()
    else:
        import numpy as xp
    args.np = xp

    # Data read
    if rank == 0:
        print('Reading data now')
    X_ij = data_read(args).read()
    X_ij = reorder_tensor(X_ij, args.precision)
    if rank == 0:
        print('Reading data complete')

    # pyDRESCAL / pyDRESCALk dispatch (process is guaranteed valid above).
    if process == 'pyDRESCALk':
        if rank == 0:
            print('Starting pyDRESCALk...')
        pyDRESCALk(X_ij, factors=None, params=args).fit()
        if rank == 0:
            print('pyDRESCALk done.')
    else:
        if rank == 0:
            print('Starting pyDRESCAL...')
        pyDRESCAL(X_ij, factors=None, params=args).fit()
        if rank == 0:
            print('pyDRESCAL done.')