
Commit

float16 training
chenzhao committed Mar 11, 2022
1 parent 26740e0 commit d9f9129
Showing 8 changed files with 60 additions and 35 deletions.
5 changes: 3 additions & 2 deletions coo_graph/graph_utils.py
@@ -40,12 +40,13 @@ def sym_normalization(edge_index, num_nodes, faster_device='cuda:0'):
return DAD.coalesce().to(original_device)


def sparse_2d_split(st, split_size, split_dim=0):
def sparse_2d_split(st, split_size, split_dim=0, device='cpu'):
seps = list(range(0, st.size(split_dim), split_size)) + [st.size(split_dim)]
parts = []
split_idx = st.indices()[split_dim]
other_idx = st.indices()[1 - split_dim]
make_2d_st = lambda idx0, idx1, val, sz0, sz1: torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()
def make_2d_st(idx0, idx1, val, sz0, sz1):
return torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1), device=device).coalesce()
for lower, upper in zip(seps[:-1], seps[1:]):
mask: torch.Tensor = (split_idx < upper) & (split_idx >= lower)
if split_dim == 0:
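For context, sparse_2d_split slices a 2-D sparse COO tensor into fixed-width blocks along one dimension; the new device argument lets each coalesced block be created directly on the target device instead of defaulting to CPU. A minimal usage sketch, assuming the coo_graph package is importable and using a toy 4x4 adjacency that is not from the repository:

    import torch
    from coo_graph import graph_utils

    # Toy 4x4 adjacency in COO format; indices and values are made up.
    indices = torch.tensor([[0, 1, 2, 3],
                            [1, 2, 3, 0]])
    adj = torch.sparse_coo_tensor(indices, torch.ones(4), (4, 4)).coalesce()

    # Split along the column dimension into blocks two columns wide; each block
    # is built and coalesced directly on `device`.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    parts = graph_utils.sparse_2d_split(adj, split_size=2, split_dim=1, device=device)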
18 changes: 10 additions & 8 deletions coo_graph/parted_coo_graph.py
@@ -7,10 +7,11 @@


class BasicGraph:
def __init__(self, d, name, device):
def __init__(self, d, name, device, half_enabled):
dtype=torch.float16 if half_enabled else torch.float
self.name, self.device, self.attr_dict = name, device, d
self.adj = d['adj'].to(device)
self.features = d['features'].to(device)
self.adj = d['adj'].to(dtype=dtype)
self.features = d['features'].to(device, dtype=dtype)
self.labels = d['labels'].to(device).to(torch.float if d['labels'].dim()==2 else torch.long)
self.train_mask, self.val_mask, self.test_mask = (d[t].bool().to(device) for t in ("train_mask", 'val_mask', 'test_mask'))
self.num_nodes, self.num_edges, self.num_classes = d["num_nodes"], d['num_edges'], d['num_classes']
@@ -49,7 +50,7 @@ def load_dict(path)


class COO_Graph(BasicGraph):
def __init__(self, name, preprocess_for='GCN', full_graph_cache_enabled=True, device='cpu'):
def __init__(self, name, full_graph_cache_enabled=True, device='cpu', half_enabled=False, preprocess_for='GCN'):
self.preprocess_for = preprocess_for
self.cache_path = GraphCache.full_graph_path(name, preprocess_for)
if full_graph_cache_enabled and os.path.exists(self.cache_path):
@@ -58,7 +59,7 @@ def __init__(self, name, preprocess_for='GCN', full_graph_cache_enabled=True, de
src_data = datasets.load_dataset(name)
cached_attr_dict = graph_utils.preprocess(name, src_data, preprocess_for) # norm feat, remove edge_index, add adj
GraphCache.save_dict(cached_attr_dict, self.cache_path)
super().__init__(cached_attr_dict, name, device)
super().__init__(cached_attr_dict, name, device, half_enabled)

def partition(self, num_parts, padding=True):
begin = datetime.datetime.now()
@@ -94,21 +95,22 @@ def partition(self, num_parts, padding=True):


class Parted_COO_Graph(BasicGraph):
def __init__(self, name, rank, num_parts, preprocess_for='GCN', device='cpu'):
def __init__(self, name, rank, num_parts, device='cpu', half_enabled=False, preprocess_for='GCN'):
# self.full_g = COO_Graph(name, preprocess_for, True, 'cpu')
self.rank, self.num_parts = rank, num_parts
cache_path = GraphCache.parted_graph_path(name, preprocess_for, rank, num_parts)
if not os.path.exists(cache_path):
raise Exception('Not parted yet. Run COO_Graph.partition() first.', cache_path)
cached_attr_dict = GraphCache.load_dict(cache_path)
super().__init__(cached_attr_dict, name, device)
super().__init__(cached_attr_dict, name, device, half_enabled)

# adj and features are local already
self.local_num_nodes = self.adj.size(0)
self.local_num_edges = self.adj.values().size(0)
self.local_labels = self.labels[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)]
self.local_train_mask = self.train_mask[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)].bool()

self.adj_parts = graph_utils.sparse_2d_split(self.adj, self.local_num_nodes, split_dim=1)
self.adj_parts = graph_utils.sparse_2d_split(self.adj, self.local_num_nodes, split_dim=1, device=device)

def __repr__(self):
local_g = f'<Local: {self.rank}, |V|: {self.local_num_nodes}, |E|: {self.local_num_edges}>'
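The core of the BasicGraph change is the dtype switch: with half_enabled set, the adjacency values and the node features are stored as torch.float16, halving their memory footprint, while labels and masks keep their integer and boolean types. A standalone sketch of the same pattern, with made-up tensor shapes:

    import torch

    half_enabled = True
    dtype = torch.float16 if half_enabled else torch.float

    # Stand-ins for entries of the cached attribute dict `d`.
    features = torch.randn(8, 16)        # d['features'], stored as fp32 on disk
    labels = torch.randint(0, 4, (8,))   # d['labels'], class indices

    features = features.to('cpu', dtype=dtype)
    labels = labels.to(torch.float if labels.dim() == 2 else torch.long)
    print(features.dtype, labels.dtype)  # torch.float16 torch.long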
18 changes: 10 additions & 8 deletions dist_train.py
@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
from sklearn.metrics import f1_score
from dist_utils import DistEnv
@@ -34,13 +35,14 @@ def train(g, env, total_epoch):
loss_func = nn.BCEWithLogitsLoss(reduction='mean')
for epoch in range(total_epoch):
with env.timer.timing('epoch'):
outputs = model(g.features)
optimizer.zero_grad()
if g.local_labels[g.local_train_mask].size(0) > 0:
loss = loss_func(outputs[g.local_train_mask], g.local_labels[g.local_train_mask])
else:
env.logger.log('Warning: no training nodes in this partition! Backward fake loss.')
loss = (outputs * 0).sum()
with autocast(env.half_enabled):
outputs = model(g.features)
optimizer.zero_grad()
if g.local_labels[g.local_train_mask].size(0) > 0:
loss = loss_func(outputs[g.local_train_mask], g.local_labels[g.local_train_mask])
else:
env.logger.log('Warning: no training nodes in this partition! Backward fake loss.')
loss = (outputs * 0).sum()
loss.backward()
optimizer.step()
env.logger.log("Epoch {:05d} | Loss {:.4f}".format(epoch, loss.item()), rank=0)
@@ -58,7 +60,7 @@ def train(g, env, total_epoch):
def main(env, args):
env.logger.log('proc begin:', env)
with env.timer.timing('total'):
g = Parted_COO_Graph(args.dataset, rank=env.rank, num_parts=env.world_size, device=env.device)
g = Parted_COO_Graph(args.dataset, env.rank, env.world_size, env.device, env.half_enabled)
env.logger.log('graph loaded', g)
train(g, env, total_epoch=args.epoch)
env.logger.log(env.timer.summary_all(), rank=0)
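torch.cuda.amp.autocast takes enabled as its first argument, so autocast(env.half_enabled) switches mixed precision on and off per run; eligible ops inside the block run in float16, and loss.backward() stays outside it, as in the loop above. A self-contained sketch of the same pattern (model, shapes, and data are illustrative, a CUDA device is required, and no GradScaler is shown because the commit does not use one):

    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast

    model = nn.Linear(16, 4).cuda()
    x = torch.randn(64, 16, device='cuda')
    y = torch.randint(0, 2, (64, 4), device='cuda').float()  # multi-label targets
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_func = nn.BCEWithLogitsLoss(reduction='mean')

    half_enabled = True
    for epoch in range(3):
        optimizer.zero_grad()
        with autocast(half_enabled):   # forward pass and loss in mixed precision
            outputs = model(x)
            loss = loss_func(outputs, y)
        loss.backward()                # backward runs outside the autocast region
        optimizer.step()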
1 change: 1 addition & 0 deletions main.py
@@ -19,6 +19,7 @@ def process_wrapper(rank, args, func):
# os.environ['NCCL_MAX_NCHANNELS'] = '1'

env = dist_utils.DistEnv(rank, args.nprocs, args.backend)
env.half_enabled = True
func(env, args)


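half_enabled is attached to the shared DistEnv here and read downstream by autocast(env.half_enabled) in dist_train.py and by DistEnv.env.half_enabled in the spmm wrapper of models/cached_gcn.py. If the flag were exposed on the command line instead of being hard-coded, a hypothetical sketch (the --half flag below is not part of the repository) could look like:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--half', action='store_true', help='enable float16 training')
    args = parser.parse_args()

    # inside process_wrapper, replacing the hard-coded assignment:
    # env.half_enabled = args.half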
Binary file added models/.cached_gcn.py.swp
Binary file not shown.
18 changes: 12 additions & 6 deletions models/cached_gcn.py
@@ -8,16 +8,20 @@

try:
from spmm_cpp import spmm_cusparse
spmm = lambda A,B,C: spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), A.size(1), B, C, 1, 1)
except ImportError:
def spmm(A,B,C):
spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), A.size(1), \
B, C, 1.0, 1.0, DistEnv.env.half_enabled)
except ImportError as e:
print('no spmm cpp:', e)
spmm = lambda A,B,C: C.addmm_(A,B)


from collections import defaultdict
g_cache = defaultdict(dict)
g_cache_enabled = {'ForwardL1': True, 'ForwardL2': True,
'BackwardL1': False, 'BackwardL2': False
}
'BackwardL1': False, 'BackwardL2': False }
g_cache_enabled = {'ForwardL1': False, 'ForwardL2': False,
'BackwardL1': False, 'BackwardL2': False }

g_bcast_counter = defaultdict(lambda: defaultdict(int))
g_epoch_counter = defaultdict(int)
Expand All @@ -35,6 +39,7 @@ def cached_broadcast(local_adj_parts, local_feature, tag):
env = DistEnv.env
z_loc = torch.zeros_like(local_feature)
feature_bcast = torch.zeros_like(local_feature)
# print('bcast feature', feature_bcast)
g_epoch_counter[tag] += 1

for src in range(env.world_size):
@@ -46,7 +51,8 @@ def cached_broadcast(local_adj_parts, local_feature, tag):
with env.timer.timing_cuda(f'broadcast {tag} {src}'):
dist.broadcast(feature_bcast, src=src)
g_bcast_counter[tag][src] += 1
g_cache[tag][src] = feature_bcast.clone()
if g_cache_enabled[tag]:
g_cache[tag][src] = feature_bcast.clone()
# env.logger.log('not cached', tag, src, 'counter', g_bcast_counter[tag][src])
else:
# env.logger.log('cached', tag, src)
@@ -72,7 +78,7 @@ def backward(ctx, grad_output):
features, weight = ctx.saved_tensors
ag = cached_broadcast(ctx.adj_parts, grad_output, 'Backward'+ctx.tag)
with DistEnv.env.timer.timing_cuda('mm'):
grad_features = torch.mm(ag, weight.t())
grad_features = torch.mm(ag.to(dtype=torch.float), weight.t())
grad_weight = torch.mm(features.t(), ag)
with DistEnv.env.timer.timing_cuda('all_reduce'):
DistEnv.env.all_reduce_sum(grad_weight)
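Two details above are worth a small illustration: the pure-PyTorch fallback computes C = C + A @ B in place through addmm_, which accepts a sparse first operand, and the backward pass casts the aggregated gradient to float32 before the matmul with the weight, which keeps the operand dtypes aligned since torch.mm requires both arguments to share a dtype. A minimal sketch with made-up shapes:

    import torch

    # Dense fallback for C += A @ B when the cuSPARSE extension is unavailable.
    A = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]), torch.ones(2), (2, 2))
    B = torch.randn(2, 3)
    C = torch.zeros(2, 3)
    C.addmm_(A, B)   # in place: C = C + A @ B

    # Mixed-precision caveat mirrored by backward(): torch.mm needs matching dtypes,
    # hence the cast of the fp16 gradient before multiplying by the fp32 weight.
    ag = torch.randn(2, 3, dtype=torch.float16)   # aggregated gradient (illustrative)
    weight = torch.randn(4, 3)                    # fp32 layer weight (illustrative)
    grad_features = torch.mm(ag.to(dtype=torch.float), weight.t())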
8 changes: 4 additions & 4 deletions prepare_data.py
@@ -5,15 +5,15 @@


def main():
cached = False
# r = COO_Graph('cora')
cached = True
r = coo_graph.COO_Graph('reddit')
# r = coo_graph.COO_Graph('cora', full_graph_cache_enabled=cached)
# r = coo_graph.COO_Graph('flickr', full_graph_cache_enabled=cached)
# r = coo_graph.COO_Graph('reddit', full_graph_cache_enabled=cached)
# r = coo_graph.COO_Graph('ogbn-arxiv', full_graph_cache_enabled=cached)
# r.partition(8)
# r.partition(4)
# return
r.partition(4)
return
# for name in ['amazon-products', 'ogbn-products']:
for name in ['ogbn-arxiv', 'ogbn-products']:
r = coo_graph.COO_Graph(name, full_graph_cache_enabled=cached)
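prepare_data.py is the offline step: it builds (or loads from cache) the preprocessed full graph and writes one partition per training process, which Parted_COO_Graph later loads in dist_train.py; running it first avoids the 'Not parted yet' error raised above. The commit's active path boils down to:

    import coo_graph

    g = coo_graph.COO_Graph('reddit', full_graph_cache_enabled=True)
    g.partition(4)   # one part per worker in a 4-process run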
27 changes: 20 additions & 7 deletions spmm_cpp/spmm.cpp
@@ -14,22 +14,35 @@
}

typedef const at::Tensor& T;
void spmm_cusparse(T A_row_idx, T A_col_idx, T A_values, int32_t A_row, int32_t A_col, T B, T C, float alpha, float beta) {
void spmm_cusparse(T A_row_idx, T A_col_idx, T A_values, int32_t A_row, int32_t A_col, T B, T C, float alpha, float beta, int use_half) {
cusparseSpMatDescr_t matA;
CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_row, A_col, A_values.size(0), A_row_idx.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<float>(),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F) )

cusparseDnMatDescr_t matB, matC; // mat from torch is row major
if (use_half){
//printf("use half\n");
CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_row, A_col, A_values.size(0), A_row_idx.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<at::Half>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) )

//printf("use half A ready\n");
CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )
CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )

//printf("use half BC ready\n");
CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_16F, CUSPARSE_SPMM_COO_ALG4, NULL)); //CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2 , CUSPARSE_SPMM_COO_ALG4

}else{
//printf("use float32\n");
CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_row, A_col, A_values.size(0), A_row_idx.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<float>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F) )

CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )
CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )

CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, matA, matB, &beta, matC,
CUDA_R_32F, CUSPARSE_SPMM_COO_ALG4, NULL)); //CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2 , CUSPARSE_SPMM_COO_ALG4
CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_COO_ALG4, NULL)); //CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2 , CUSPARSE_SPMM_COO_ALG4
}


CHECK_CUSPARSE( cusparseDestroySpMat(matA) )
CHECK_CUSPARSE( cusparseDestroyDnMat(matB) )
CHECK_CUSPARSE( cusparseDestroyDnMat(matC) )
//printf("cusparse spmm done\n");
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
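On the Python side the extended binding is called exactly as in models/cached_gcn.py: the trailing flag selects the FP16 branch, and A, B, and C must already hold half-precision data on the GPU. A hedged usage sketch (shapes are made up; it needs the compiled spmm_cpp extension, a CUDA device, and a cuSPARSE build that provides CUSPARSE_SPMM_COO_ALG4):

    import torch
    from spmm_cpp import spmm_cusparse

    # Toy 2x2 sparse A and dense B, C, all float16 on the GPU.
    A = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]), torch.ones(2),
                                (2, 2)).to('cuda', dtype=torch.float16).coalesce()
    B = torch.randn(2, 3, device='cuda', dtype=torch.float16)
    C = torch.zeros(2, 3, device='cuda', dtype=torch.float16)

    # C = 1.0 * A @ B + 1.0 * C, dispatched to the CUDA_R_16F branch above.
    spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(),
                  A.size(0), A.size(1), B, C, 1.0, 1.0, True)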
