Commit

csr support

chenzhao committed Mar 11, 2022
1 parent 58cea66 commit 8ba96a7
Showing 6 changed files with 109 additions and 31 deletions.
4 changes: 2 additions & 2 deletions coo_graph/graph_utils.py
@@ -40,13 +40,13 @@ def sym_normalization(edge_index, num_nodes, faster_device='cuda:0'):
     return DAD.coalesce().to(original_device)


-def sparse_2d_split(st, split_size, split_dim=0, device='cpu'):
+def sparse_2d_split(st, split_size, split_dim=0):
     seps = list(range(0, st.size(split_dim), split_size)) + [st.size(split_dim)]
     parts = []
     split_idx = st.indices()[split_dim]
     other_idx = st.indices()[1 - split_dim]
     def make_2d_st(idx0, idx1, val, sz0, sz1):
-        return torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1), device=device).coalesce()
+        return torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()
     for lower, upper in zip(seps[:-1], seps[1:]):
         mask: torch.Tensor = (split_idx < upper) & (split_idx >= lower)
         if split_dim == 0:
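Aside (not part of the commit): a minimal usage sketch of the new device-free signature. The tensor below is illustrative, and the rebased (4, 2) block shapes assume the elided tail of the function builds each block the way the visible code suggests; callers now move the parts to a device themselves, as Parted_COO_Graph does below.

    import torch

    # a 4x4 sparse identity standing in for an adjacency block
    adj = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),
                                  torch.ones(4), (4, 4)).coalesce()

    parts = sparse_2d_split(adj, split_size=2, split_dim=1)  # two 4x2 column blocks
    for p in parts:
        print(p.size(), p.values().numel())  # expected: torch.Size([4, 2]) 2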
35 changes: 25 additions & 10 deletions coo_graph/parted_coo_graph.py
@@ -7,11 +7,10 @@


 class BasicGraph:
-    def __init__(self, d, name, device, half_enabled):
-        dtype=torch.float16 if half_enabled else torch.float
+    def __init__(self, d, name, device):
         self.name, self.device, self.attr_dict = name, device, d
-        self.adj = d['adj'].to(dtype=dtype)
-        self.features = d['features'].to(device, dtype=dtype)
+        self.adj = d['adj']
+        self.features = d['features']
         self.labels = d['labels'].to(device).to(torch.float if d['labels'].dim()==2 else torch.long)
         self.train_mask, self.val_mask, self.test_mask = (d[t].bool().to(device) for t in ("train_mask", 'val_mask', 'test_mask'))
         self.num_nodes, self.num_edges, self.num_classes = d["num_nodes"], d['num_edges'], d['num_classes']
@@ -50,7 +49,7 @@ def load_dict(path):


 class COO_Graph(BasicGraph):
-    def __init__(self, name, full_graph_cache_enabled=True, device='cpu', half_enabled=False, preprocess_for='GCN'):
+    def __init__(self, name, full_graph_cache_enabled=True, device='cpu', preprocess_for='GCN'):
         self.preprocess_for = preprocess_for
         self.cache_path = GraphCache.full_graph_path(name, preprocess_for)
         if full_graph_cache_enabled and os.path.exists(self.cache_path):
@@ -59,7 +58,7 @@ def __init__(self, name, full_graph_cache_enabled=True, device='cpu', half_enabl
             src_data = datasets.load_dataset(name)
             cached_attr_dict = graph_utils.preprocess(name, src_data, preprocess_for) # norm feat, remove edge_index, add adj
             GraphCache.save_dict(cached_attr_dict, self.cache_path)
-        super().__init__(cached_attr_dict, name, device, half_enabled)
+        super().__init__(cached_attr_dict, name, device)

     def partition(self, num_parts, padding=True):
         begin = datetime.datetime.now()
@@ -94,23 +93,39 @@ def partition(self, num_parts, padding=True):
         print(self.name, num_parts, 'partition done', datetime.datetime.now()-begin)


+def coo_to_csr(coo, device, dtype):
+    print('coo', coo.size())
+    csr = coo.to_sparse_csr()
+    print('csr', csr.size())
+    small_csr = torch.sparse_csr_tensor(csr.crow_indices().to(dtype=torch.int32),
+        csr.col_indices().to(dtype=torch.int32), csr.values().to(dtype=dtype), size=csr.size(), dtype=dtype, device=device)
+    print('small csr', small_csr.size())
+    return small_csr
+
 class Parted_COO_Graph(BasicGraph):
-    def __init__(self, name, rank, num_parts, device='cpu', half_enabled=False, preprocess_for='GCN'):
+    def __init__(self, name, rank, num_parts, device='cpu', half_enabled=False, csr_enabled=False, preprocess_for='GCN'):
         # self.full_g = COO_Graph(name, preprocess_for, True, 'cpu')
         self.rank, self.num_parts = rank, num_parts
         cache_path = GraphCache.parted_graph_path(name, preprocess_for, rank, num_parts)
         if not os.path.exists(cache_path):
             raise Exception('Not parted yet. Run COO_Graph.partition() first.', cache_path)
         cached_attr_dict = GraphCache.load_dict(cache_path)
-        super().__init__(cached_attr_dict, name, device, half_enabled)
+        super().__init__(cached_attr_dict, name, device)

         # adj and features are local already
         self.local_num_nodes = self.adj.size(0)
         self.local_num_edges = self.adj.values().size(0)
         self.local_labels = self.labels[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)]
         self.local_train_mask = self.train_mask[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)].bool()

-        self.adj_parts = graph_utils.sparse_2d_split(self.adj, self.local_num_nodes, split_dim=1, device=device)
+        dtype=torch.float16 if half_enabled else torch.float
+        self.features = self.features.to(device, dtype=dtype)
+
+        adj_parts = graph_utils.sparse_2d_split(self.adj, self.local_num_nodes, split_dim=1)
+        if csr_enabled:
+            self.adj_parts = [coo_to_csr(adj, device, dtype) for adj in adj_parts]
+        else:
+            self.adj_parts = [adj.to(device=device, dtype=dtype) for adj in adj_parts]

     def __repr__(self):
         local_g = f'<Local: {self.rank}, |V|: {self.local_num_nodes}, |E|: {self.local_num_edges}>'
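Aside (not part of the commit): a small sketch of what coo_to_csr produces, run on a CPU tensor with an invented 2x3 matrix. The point of the helper is the narrowing to int32 row offsets and column indices, which is what the CUSPARSE_INDEX_32I descriptors in spmm.cpp below expect.

    import torch

    coo = torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 2]]),
                                  torch.tensor([1.0, 2.0, 3.0]), (2, 3)).coalesce()
    csr = coo_to_csr(coo, device='cpu', dtype=torch.float)
    print(csr.crow_indices())  # tensor([0, 1, 3], dtype=torch.int32)
    print(csr.col_indices())   # tensor([1, 0, 2], dtype=torch.int32)
    print(csr.values())        # tensor([1., 2., 3.])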
8 changes: 7 additions & 1 deletion dist_train.py
@@ -58,10 +58,16 @@ def train(g, env, total_epoch):


 def main(env, args):
+    env.csr_enabled = False
+    env.csr_enabled = True
+
+    env.half_enabled = True
+    env.half_enabled = False
     env.logger.log('proc begin:', env)
     with env.timer.timing('total'):
-        g = Parted_COO_Graph(args.dataset, env.rank, env.world_size, env.device, env.half_enabled)
+        g = Parted_COO_Graph(args.dataset, env.rank, env.world_size, env.device, env.half_enabled, env.csr_enabled)
         env.logger.log('graph loaded', g)
+        env.logger.log('graph loaded', torch.cuda.memory_summary())
         train(g, env, total_epoch=args.epoch)
     env.logger.log(env.timer.summary_all(), rank=0)

1 change: 1 addition & 0 deletions main.py
@@ -20,6 +20,7 @@ def process_wrapper(rank, args, func):

     env = dist_utils.DistEnv(rank, args.nprocs, args.backend)
     env.half_enabled = True
+    env.csr_enabled = True
     func(env, args)

8 changes: 6 additions & 2 deletions models/cached_gcn.py
@@ -7,9 +7,13 @@
 import torch.distributed as dist

 try:
-    from spmm_cpp import spmm_cusparse
+    from spmm_cpp import spmm_cusparse_coo, spmm_cusparse_csr
     def spmm(A,B,C):
-        spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), A.size(1), \
+        if DistEnv.env.csr_enabled:
+            spmm_cusparse_csr(A.crow_indices().int(), A.col_indices().int(), A.values(), A.size(0), A.size(1), \
                 B, C, 1.0, 1.0, DistEnv.env.half_enabled)
+        else:
+            spmm_cusparse_coo(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), A.size(1), \
+                B, C, 1.0, 1.0, DistEnv.env.half_enabled)
 except ImportError as e:
     print('no spmm cpp:', e)
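Aside (not part of the commit): a hedged sketch of calling the new dispatch, assuming the extension imported, a CUDA device, and DistEnv.env.csr_enabled set as in main.py. With csr_enabled, A must be an int32-indexed CSR tensor (spmm_cusparse_csr reads crow_indices/col_indices); otherwise it stays COO. Either way alpha = beta = 1.0, so the wrapper accumulates: C += A @ B.

    import torch

    # identity adjacency block in int32 CSR form, as coo_to_csr above produces
    A = coo_to_csr(torch.eye(4).to_sparse().coalesce(), device='cuda:0', dtype=torch.float)
    B = torch.randn(4, 16, device='cuda:0')
    C = torch.zeros(4, 16, device='cuda:0')
    spmm(A, B, C)
    print(torch.allclose(C, B))  # identity block, C started at zero: expect True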
84 changes: 68 additions & 16 deletions spmm_cpp/spmm.cpp
@@ -4,6 +4,18 @@
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>

+
+#define CHECK_CUDA(func)                                                       \
+{                                                                              \
+    cudaError_t status = (func);                                               \
+    if (status != cudaSuccess) {                                               \
+        printf("CUDA API failed at line %d with error: %s (%d)\n",             \
+               __LINE__, cudaGetErrorString(status), status);                  \
+    }                                                                          \
+}
+
+
 #define CHECK_CUSPARSE(func)                                                   \
 {                                                                              \
     cusparseStatus_t status = (func);                                          \
@@ -14,37 +26,77 @@
 }

 typedef const at::Tensor& T;

-void spmm_cusparse(T A_row_idx, T A_col_idx, T A_values, int32_t A_row, int32_t A_col, T B, T C, float alpha, float beta, int use_half) {
+void spmm_cusparse_coo(T A_row_idx, T A_col_idx, T A_values, int32_t A_row, int32_t A_col, T B, T C, float alpha, float beta, int use_half) {
     cusparseSpMatDescr_t matA;
     cusparseDnMatDescr_t matB, matC; // mat from torch is row major
     if (use_half){
-        //printf("use half\n");
         CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_row, A_col, A_values.size(0), A_row_idx.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<at::Half>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) )
-        //printf("use half A ready\n");
         CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )
         CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )
-        //printf("use half BC ready\n");
-        CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_16F, CUSPARSE_SPMM_COO_ALG4, NULL));
     }else{
-        //printf("use float32\n");
         CHECK_CUSPARSE( cusparseCreateCoo(&matA, A_row, A_col, A_values.size(0), A_row_idx.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<float>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F) )

         CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )
         CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )
     }

+    CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_COO_ALG4, NULL)); // CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2, CUSPARSE_SPMM_COO_ALG4

     CHECK_CUSPARSE( cusparseDestroySpMat(matA) )
     CHECK_CUSPARSE( cusparseDestroyDnMat(matB) )
     CHECK_CUSPARSE( cusparseDestroyDnMat(matC) )
 }


+void spmm_cusparse_csr(T A_row_offsets, T A_col_idx, T A_values, int32_t A_row, int32_t A_col, T B, T C, float alpha, float beta, int use_half) {
+    cusparseSpMatDescr_t matA;
+    cusparseDnMatDescr_t matB, matC; // mat from torch is row major
+
+    void* dBuffer = NULL;
+    size_t bufferSize = 0;
+
+    cusparseHandle_t handle = NULL;
+    CHECK_CUSPARSE( cusparseCreate(&handle) )
+
+    if (use_half){
+        CHECK_CUSPARSE( cusparseCreateCsr(&matA, A_row, A_col, A_values.size(0), A_row_offsets.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<at::Half>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) )
+
+        CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )
+        CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<at::Half>(), CUDA_R_16F, CUSPARSE_ORDER_ROW) )
+    }else{
+        CHECK_CUSPARSE( cusparseCreateCsr(&matA, A_row, A_col, A_values.size(0), A_row_offsets.data_ptr<int>(), A_col_idx.data_ptr<int>(), A_values.data_ptr<float>(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F) )

+        CHECK_CUSPARSE( cusparseCreateDnMat(&matB, B.size(0), B.size(1), B.size(1), B.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )
+        CHECK_CUSPARSE( cusparseCreateDnMat(&matC, C.size(0), C.size(1), C.size(1), C.data_ptr<float>(), CUDA_R_32F, CUSPARSE_ORDER_ROW) )
+    }
+
+    CHECK_CUSPARSE( cusparseSpMM_bufferSize(
+                        handle,
+                        CUSPARSE_OPERATION_NON_TRANSPOSE,
+                        CUSPARSE_OPERATION_NON_TRANSPOSE,
+                        &alpha, matA, matB, &beta, matC, CUDA_R_32F,
+                        CUSPARSE_SPMM_CSR_ALG2, &bufferSize) )
+    CHECK_CUDA( cudaMalloc(&dBuffer, bufferSize) )
+
+    CHECK_CUSPARSE(cusparseSpMM(at::cuda::getCurrentCUDASparseHandle(), CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dBuffer)); // CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2, CUSPARSE_SPMM_COO_ALG4
+
+    CHECK_CUDA( cudaFree(dBuffer) )
+
+    CHECK_CUSPARSE( cusparseDestroySpMat(matA) )
+    CHECK_CUSPARSE( cusparseDestroyDnMat(matB) )
+    CHECK_CUSPARSE( cusparseDestroyDnMat(matC) )
-    //printf("cusparse spmm done\n");
+    CHECK_CUSPARSE( cusparseDestroy(handle) ) // release the handle created above
+}



 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("spmm_cusparse", &spmm_cusparse, "SpMM wrapper for cusparse");
+    m.def("spmm_cusparse_coo", &spmm_cusparse_coo, "SpMM wrapper for cusparse coo");
+    m.def("spmm_cusparse_csr", &spmm_cusparse_csr, "SpMM wrapper for cusparse csr");
 }
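Aside (not part of the commit): a hypothetical smoke test for the two bindings, assuming the extension builds and a CUDA device is available. It checks that the COO and CSR kernels agree on a random float32 block, using the argument order defined above (indices, values, rows, cols, B, C, alpha, beta, use_half).

    import torch
    from spmm_cpp import spmm_cusparse_coo, spmm_cusparse_csr

    dense = (torch.rand(8, 8, device='cuda:0') > 0.7).float()  # ~30% nonzeros
    coo = dense.to_sparse().coalesce()
    csr = coo.to_sparse_csr()
    B = torch.randn(8, 4, device='cuda:0')
    C_coo = torch.zeros(8, 4, device='cuda:0')
    C_csr = torch.zeros(8, 4, device='cuda:0')

    spmm_cusparse_coo(coo.indices()[0].int(), coo.indices()[1].int(), coo.values(),
                      8, 8, B, C_coo, 1.0, 1.0, 0)
    spmm_cusparse_csr(csr.crow_indices().int(), csr.col_indices().int(), csr.values(),
                      8, 8, B, C_csr, 1.0, 1.0, 0)
    print(torch.allclose(C_coo, C_csr, atol=1e-5))  # expect True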
