diff --git a/coo_graph/__init__.py b/coo_graph/__init__.py
index efa105e..5755534 100644
--- a/coo_graph/__init__.py
+++ b/coo_graph/__init__.py
@@ -1 +1 @@
-from .parted_coo_graph import Parted_COO_Graph
+from .parted_coo_graph import COO_Graph, Parted_COO_Graph
diff --git a/coo_graph/datasets.py b/coo_graph/datasets.py
index 1877afe..deca514 100644
--- a/coo_graph/datasets.py
+++ b/coo_graph/datasets.py
@@ -2,15 +2,15 @@
 # All rights reserved.
 
 import os
-import numpy
-import scipy
-import scipy.sparse
 import torch
-import json
+
 
 data_root = os.path.join(os.path.dirname(__file__), '..', 'data')
 dgl_root = os.path.join(data_root, 'dgl_datasets')
 pyg_root = os.path.join(data_root, 'pyg_datasets')
+for path in [data_root, dgl_root, pyg_root]:
+    os.makedirs(path, exist_ok=True)
+
 
 def save_dataset(edge_index, features, labels, train_mask, val_mask, test_mask, num_nodes, num_edges, num_classes, name):
     if name.startswith('a_quarter'):
@@ -26,8 +26,8 @@ def save_dataset(edge_index, features, labels, train_mask, val_mask, test_mask,
         num_nodes = max_node
         num_edges = edge_index.size(1)
     path = os.path.join(data_root, name+'.torch')
-    torch.save({"edge_index": edge_index, "features": features, "labels": labels,
-                "train_mask": train_mask, 'val_mask': val_mask, 'test_mask': test_mask,
+    torch.save({"edge_index": edge_index.int(), "features": features, "labels": labels.char(),
+                "train_mask": train_mask.bool(), 'val_mask': val_mask.bool(), 'test_mask': test_mask.bool(),
                 "num_nodes": num_nodes, 'num_edges': num_edges, 'num_classes': num_classes}, path)
 
 
@@ -38,46 +38,45 @@ def load_dataset(name):
     return torch.load(path)
 
 
-def prepare_dgl_dataset(source, name):
-    dgl_dataset: dgl.data.DGLDataset = source(raw_dir=dgl_root)
+def prepare_dgl_dataset(dgl_name, tag):
+    import dgl
+    dataset_sources = {'cora': dgl.data.CoraGraphDataset, 'reddit': dgl.data.RedditDataset}
+    dgl_dataset: dgl.data.DGLDataset = dataset_sources[dgl_name](raw_dir=dgl_root)
     g = dgl_dataset[0]
     edge_index = torch.stack(g.adj_sparse('coo'))
     save_dataset(edge_index, g.ndata['feat'], g.ndata['label'], g.ndata['train_mask'], g.ndata['val_mask'], g.ndata['test_mask'],
-                 g.num_nodes(), g.num_edges(), dgl_dataset.num_classes, name)
+                 g.num_nodes(), g.num_edges(), dgl_dataset.num_classes, tag)
 
 
-def prepare_pyg_dataset(source, name):
-    pyg_dataset: torch_geometric.data.Dataset = source(root=os.path.join(pyg_root, name))
+def prepare_pyg_dataset(pyg_name, tag):
+    import torch_geometric
+    dataset_sources = {'reddit': torch_geometric.datasets.Reddit,
+                       'flickr': torch_geometric.datasets.Flickr,
+                       'yelp': torch_geometric.datasets.Yelp,
+                       'amazon-products': torch_geometric.datasets.AmazonProducts}
+    pyg_dataset: torch_geometric.data.Dataset = dataset_sources[pyg_name](root=os.path.join(pyg_root, pyg_name))
     data: torch_geometric.data.Data = pyg_dataset[0]
     save_dataset(data.edge_index, data.x, data.y, data.val_mask, data.val_mask, data.test_mask,
-                 data.num_nodes, data.num_edges, pyg_dataset.num_classes, name)
-
-
-def prepare_dataset(name):
-    import dgl
-    import torch_geometric
-    dataset_source_mapping = {'cora': dgl.data.CoraGraphDataset,
-                              'reddit_reorder': dgl.data.RedditDataset,
-                              'reddit': torch_geometric.datasets.Reddit,
-                              'a_quarter_reddit': torch_geometric.datasets.Reddit,
-                              'flickr': torch_geometric.datasets.Flickr,
-                              'yelp': torch_geometric.datasets.Yelp}
-
-    for path in [data_root, dgl_root, pyg_root]:
-        os.makedirs(path, exist_ok=True)
-    try:
-        source_class = dataset_source_mapping[name]
-    except KeyError:
-        raise Exception('no source for such dataset', name)
-    if source_class.__module__.startswith('dgl.'):
-        prepare_dgl_dataset(source_class, name)
-    elif source_class.__module__.startswith('torch_geometric.'):
-        prepare_pyg_dataset(source_class, name)
-    else:  # other libs TODO
-        pass
-
+                 data.num_nodes, data.num_edges, pyg_dataset.num_classes, tag)
+
+
+def prepare_dataset(tag):
+    if tag=='reddit':
+        return prepare_pyg_dataset('reddit', tag)
+    elif tag=='flickr':
+        return prepare_pyg_dataset('flickr', tag)
+    elif tag == 'yelp':  # graphsaints
+        return prepare_pyg_dataset('yelp', tag)
+    elif tag=='amazon-products':  # graphsaints
+        return prepare_pyg_dataset('amazon-products', tag)
+    elif tag=='cora':
+        return prepare_dgl_dataset('cora', tag)
+    elif tag=='reddit_reorder':
+        return prepare_dgl_dataset('reddit', tag)
+    elif tag=='a_quarter_reddit':
+        return prepare_pyg_dataset('reddit', tag)
 
 
 def check_edges(edge_index, num_nodes):
@@ -100,26 +99,8 @@ def check_edges(edge_index, num_nodes):
 
 
 def main():
-    r = load_dataset('reddit')
-    check_edges(r['edge_index'], r['num_nodes'])
-
-    data = numpy.load(os.path.join(dgl_root, 'reddit', 'reddit_data.npz'))
-    x = torch.from_numpy(data['feature']).to(torch.float)
-    y = torch.from_numpy(data['label']).to(torch.long)
-    # split = torch.from_numpy(data['node_types'])
-
-    adj = scipy.sparse.load_npz(os.path.join(dgl_root, 'reddit', 'reddit_graph.npz'))
-    row = torch.from_numpy(adj.row).to(torch.long)
-    col = torch.from_numpy(adj.col).to(torch.long)
-    edge_index = torch.stack([row, col], dim=0)
-    check_edges(edge_index, x.size(0))
-
-
-
-    return
-    for dataset_name in ['cora', 'reddit', 'flickr', 'yelp']:
+    for dataset_name in ['cora', 'reddit', 'flickr', 'yelp', 'a_quarter_reddit','amazon-products']:
         prepare_dataset(dataset_name)
-    pass
 
 
 if __name__ == '__main__':
diff --git a/coo_graph/graph_utils.py b/coo_graph/graph_utils.py
index 7de3f48..0c58689 100644
--- a/coo_graph/graph_utils.py
+++ b/coo_graph/graph_utils.py
@@ -1,11 +1,18 @@
-import os
-import os.path
 import torch
-import math
 import datetime
-import random
-import numpy as np
+
+def preprocess(attr_dict, preprocess_for):  # normalize feature and make adj matrix from edge index
+    begin = datetime.datetime.now()
+    print(preprocess_for, 'preprocess begin', begin)
+    attr_dict["features"] = attr_dict["features"] / attr_dict["features"].sum(1, keepdim=True).clamp(min=1)
+    if preprocess_for == 'GCN':  # make the coo format sym lap matrix
+        attr_dict['adj'] = sym_normalization(attr_dict['edge_index'], attr_dict['num_nodes'])
+    elif preprocess_for == 'GAT':
+        attr_dict['adj'] = attr_dict['edge_index']
+    attr_dict.pop('edge_index')
+    print(preprocess_for, 'preprocess done', datetime.datetime.now() - begin)
+    return attr_dict
 
 
 def add_self_loops(edge_index, num_nodes):  # from pyg
@@ -16,83 +23,31 @@ def sym_normalization(edge_index, num_nodes, faster_device='cuda:0'):
     original_device = edge_index.device
-    # begin = datetime.datetime.now()
+    begin = datetime.datetime.now()
     edge_index = add_self_loops(edge_index, num_nodes)
-    # A = torch.sparse_coo_tensor(edge_index, torch.ones(len(edge_index[0])), (num_nodes, num_nodes)).coalesce()
-    # return A
     A = torch.sparse_coo_tensor(edge_index, torch.ones(len(edge_index[0])), (num_nodes, num_nodes), device=faster_device).coalesce()
     degree_vec = torch.sparse.sum(A, 0).pow(-0.5).to_dense()
     I_edge_index = torch.stack((torch.arange(num_nodes), torch.arange(num_nodes)))
     D_rsqrt = torch.sparse_coo_tensor(I_edge_index, degree_vec, (num_nodes, num_nodes), device=faster_device)
     DA = torch.sparse.mm(D_rsqrt, A)
     del A  # to save GPU mem
-    # print(DA)
     DAD = torch.sparse.mm(DA, D_rsqrt)
     del DA
-    # end = datetime.datetime.now()
-    # print('sym norm done', end - begin)
+    end = datetime.datetime.now()
+    print('sym norm done', end - begin)
     return DAD.coalesce().to(original_device)
 
 
-def save_cache_dict(d, path):
-    if os.path.exists(path):
-        print(f'warning: cache file {path} is overwritten.')
-    torch.save(d, path)
-
-
-def to(v, device):
-    if type(v) == torch.Tensor:
-        return v.to(device)
-    if type(v) in (list,tuple) and type(v[0]) == torch.Tensor:
-        return [i.to(device) for i in v]
-    return v
-
-def load_cache_dict(path):
-    if not os.path.exists(path):
-        raise Exception('no such file: '+path)
-    d = torch.load(path)
-    updated_d = {}
-    for k,v in d.items():
-        if type(v) == torch.Tensor and v.is_sparse:
-            updated_d[k] = v.coalesce()
-        if type(v) == list and type(v[0]) == torch.Tensor and v[0].is_sparse:
-            updated_d[k] = [i.coalesce() for i in v]
-    d.update(updated_d)
-    return d
-
-def split_2D_coo_by_size(split_idx, other_idx, values, split_size):  # 2D tensors only
-    coo_parts = []
-    while len(split_idx) > 0:
-        mask: torch.Tensor = split_idx < split_size
-        coo_parts.append( (split_idx[mask], other_idx[mask], values[mask], split_size) )  # padding? TODO
-        split_idx = split_idx[mask.logical_not()] - split_size
-    return coo_parts
-
-
-def split_2D_coo(split_idx, other_idx, values, seps):  # 2D tensors only
-    coo_parts = []
+def sparse_2d_split(st, split_size, split_dim=0):
+    seps = list(range(0, st.size(split_dim), split_size)) + [st.size(split_dim)]
+    parts = []
+    split_idx = st.indices()[split_dim]
+    other_idx = st.indices()[1 - split_dim]
+    make_2d_st = lambda idx0, idx1, val, sz0, sz1: torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()
     for lower, upper in zip(seps[:-1], seps[1:]):
         mask: torch.Tensor = (split_idx < upper) & (split_idx >= lower)
-        coo_parts.append( (split_idx[mask]-lower, other_idx[mask], values[mask], upper-lower) )
-    return coo_parts
-
-
-def make_2D_coo(idx0, idx1, val, sz0, sz1):  # to mat: row is 0, col is 1
-    return torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()
-
-
-def CAGNET_split(coo_adj, split_size):  # A.T==A only
-    seps = list(range(0, coo_adj.size(0), split_size))+[coo_adj.size(0)]
-    row_parts = split_2D_coo(coo_adj.indices()[0], coo_adj.indices()[1], coo_adj.values(), seps)  # Ai is rows part i
-    row_part_coo_list, row_col_part_coos_list = [], []
-    for part_row_idx, full_col_idx, val, row_sz in row_parts:
-        print(f'coo split: {val.size(0)}, {row_sz}')
-        row_part_coo_list.append(make_2D_coo(part_row_idx, full_col_idx, val, row_sz, coo_adj.size(0)))
-        # row_col_part_coos_list.append( [make_2D_coo(p_row, p_col, p_val, row_sz, col_sz) \
-        #     for p_col, p_row, p_val, col_sz in split_2D_coo(full_col_idx, part_row_idx, val, seps)])
-        row_col_part_coos = []
-        for p_col, p_row, p_val, col_sz in split_2D_coo(full_col_idx, part_row_idx, val, seps):
-            print(f'\tcoo split: {p_val.size(0)}, {row_sz}, {col_sz}')
-            row_col_part_coos.append(make_2D_coo(p_row, p_col, p_val, row_sz, col_sz))
-        row_col_part_coos_list.append(row_col_part_coos)
-    return row_part_coo_list, row_col_part_coos_list
+        if split_dim == 0:
+            parts.append(make_2d_st(split_idx[mask]-lower, other_idx[mask], st.values()[mask], upper-lower, st.size(1)))
+        else:
+            parts.append(make_2d_st(other_idx[mask], split_idx[mask]-lower, st.values()[mask], st.size(0), upper-lower))
+    return parts
diff --git a/coo_graph/parted_coo_graph.py b/coo_graph/parted_coo_graph.py
index 550b450..43a65e2 100644
--- a/coo_graph/parted_coo_graph.py
+++ b/coo_graph/parted_coo_graph.py
@@ -6,95 +6,109 @@
 from . import datasets
 
 
-class Parted_COO_Graph():
-    def full_graph_path(self, root=datasets.data_root):
-        return os.path.join(root, f'{self.name}_{self.preprocess_for}_full.coo_graph')
-
-    def parted_graph_path(self, rank, total_parted, root=datasets.data_root):
-        dirpath = os.path.join(root, f'{self.name}_{self.preprocess_for}_{total_parted}_parts')
-        os.makedirs(dirpath, exist_ok=True)
-        return os.path.join(dirpath, f'part_{rank}_of_{total_parted}.coo_graph')
-
-    @property
-    def split_size(self):
-        assert self.num_parts > 1
-        return (self.num_nodes+self.num_parts-1)//self.num_parts
-
-
-    def sparse_resize(self, sp_t, size):
-        resized = torch.sparse_coo_tensor(sp_t._indices(), sp_t._values(), size, device=self.device).coalesce()
-        return resized
-
-    def pad(self):
-        pad_size = self.split_size*self.num_parts - self.num_nodes
-        assert(pad_size>=0)
-        if pad_size==0:
-            return self
-        if self.local_features.size(0)'
+        return f''
+
+class GraphCache:
+    @staticmethod
+    def full_graph_path(name, preprocess_for, root=datasets.data_root):
+        return os.path.join(root, f'{name}_{preprocess_for}_full.coo_graph')
+    @staticmethod
+    def parted_graph_path(name, preprocess_for, rank, num_parts, root=datasets.data_root):
+        dirpath = os.path.join(root, f'{name}_{preprocess_for}_{num_parts}_parts')
+        os.makedirs(dirpath, exist_ok=True)
+        return os.path.join(dirpath, f'part_{rank}_of_{num_parts}.coo_graph')
+    @staticmethod
+    def save_dict(d, path):
+        if os.path.exists(path):
+            print(f'warning: cache file {path} is overwritten.')
+        d_to_save = {}
+        for k, v in d.items():
+            d_to_save[k] = v.clone() if type(v)==torch.Tensor else v
+        torch.save(d_to_save, path)
+    @staticmethod
+    def load_dict(path):
+        d = torch.load(path)
+        updated_d = {}
+        for k, v in d.items():
+            if type(v) == torch.Tensor and v.is_sparse:
+                updated_d[k] = v.coalesce()
+        d.update(updated_d)
+        return d
+
+
+class COO_Graph(BasicGraph):
+    def __init__(self, name, preprocess_for='GCN', full_graph_cache_enabled=True, device='cpu'):
+        self.preprocess_for = preprocess_for
+        self.cache_path = GraphCache.full_graph_path(name, preprocess_for)
+        if full_graph_cache_enabled and os.path.exists(self.cache_path):
+            cached_attr_dict = GraphCache.load_dict(self.cache_path)
         else:
-            local_g = "Full"
-        return f''
+            src_data = datasets.load_dataset(name)
+            cached_attr_dict = graph_utils.preprocess(src_data, preprocess_for)  # norm feat, remove edge_index, add adj
+            GraphCache.save_dict(cached_attr_dict, self.cache_path)
+        super().__init__(cached_attr_dict, name, device)
 
-    def partition(self, num_parts):
+    def partition(self, num_parts, padding=True):
         begin = datetime.datetime.now()
-        print(f'{num_parts} partition begin', self.full_graph_path(), begin)
-        self.num_parts = num_parts
-        full_dict = {k:v for k,v in self.attr_dict.items() if k not in ['adj', 'edge_index', 'features']}
-        local_dict = {'local_'+x: torch.split(getattr(self,x), self.split_size) for x in ['train_mask', 'labels', 'features']}
-        # local_dict.update(dict(zip(['local_adj', 'local_adj_parts'], graph_utils.CAGNET_split(self.adj, self.split_size))))
-        local_adj_list, local_adj_parts_list = graph_utils.CAGNET_split(self.adj, self.split_size)
-        local_dict['local_num_nodes'] = [adj.size(0) for adj in local_adj_list]
-        local_dict['local_num_edges'] = [adj.values().size(0) for adj in local_adj_list]
-        local_dict['local_adj_parts'] = local_adj_parts_list
+        print(self.name, num_parts, 'partition begin', begin)
+        attr_dict = self.attr_dict.copy()
+        split_size = (self.num_nodes+num_parts-1)//num_parts
+        pad_size = split_size*num_parts-self.num_nodes
+
+        adj_list = graph_utils.sparse_2d_split(self.adj, split_size)
+        features_list = list(torch.split(self.features, split_size))
+
+        if padding and pad_size>0:
+            padding_feat = torch.zeros((pad_size, self.features.size(1)), dtype=self.features.dtype, device=self.device)
+            features_list[-1] = torch.cat((features_list[-1], padding_feat))
+
+            padding_labels_size = torch.Size(pad_size)+self.labels.size()[1:]
+            padding_labels = torch.zeros(padding_labels_size, dtype=self.labels.dtype, device=self.device)
+            attr_dict['labels'] = torch.cat((self.labels, padding_labels))
+
+            padding_mask = torch.zeros(pad_size, dtype=self.train_mask.dtype, device=self.device)
+            for key in ['train_mask', 'val_mask', 'test_mask']:
+                attr_dict[key] = torch.cat((attr_dict[key], padding_mask))
+
+        adj_list = [torch.sparse_coo_tensor(adj._indices(), adj._values(), (split_size, split_size*num_parts))
+                    for adj in adj_list]
+
         for i in range(num_parts):
-            full_dict.update({k: v[i] for k, v in local_dict.items()})
-            graph_utils.save_cache_dict(full_dict, self.parted_graph_path(i, num_parts))
-        print(f'{num_parts} partition done ', datetime.datetime.now()-begin)
+            cache_path = GraphCache.parted_graph_path(self.name, self.preprocess_for, i, num_parts)
+            attr_dict.update({'adj': adj_list[i], 'features': features_list[i]})
+            GraphCache.save_dict(attr_dict, cache_path)
+            print(Parted_COO_Graph(self.name, i, num_parts, self.preprocess_for))
+        print(self.name, num_parts, 'partition done', datetime.datetime.now()-begin)
 
-    def preprocess(self, attr_dict):
-        begin = datetime.datetime.now()
-        print('preprocess begin', self.full_graph_path(), begin)
-        attr_dict["features"] = attr_dict["features"] / attr_dict["features"].sum(1, keepdim=True).clamp(min=1)
-        if self.preprocess_for=='GCN':  # make the coo format sym lap matrix
-            attr_dict['adj'] = graph_utils.sym_normalization(attr_dict['edge_index'], attr_dict['num_nodes'])
-        elif self.preprocess_for=='GAT':
-            pass
-        attr_dict.pop('edge_index')
-        print('preprocess done', self.full_graph_path(), datetime.datetime.now()-begin)
-        return attr_dict
 
+class Parted_COO_Graph(BasicGraph):
+    def __init__(self, name, rank, num_parts, preprocess_for='GCN', device='cpu'):
+        # self.full_g = COO_Graph(name, preprocess_for, True, 'cpu')
+        self.rank, self.num_parts = rank, num_parts
+        cache_path = GraphCache.parted_graph_path(name, preprocess_for, rank, num_parts)
+        if not os.path.exists(cache_path):
+            raise Exception('Not parted yet. Run COO_Graph.partition() first.', cache_path)
+        cached_attr_dict = GraphCache.load_dict(cache_path)
+        super().__init__(cached_attr_dict, name, device)
+
+        self.local_num_nodes = self.adj.size(0)
+        self.local_num_edges = self.adj.values().size(0)
+        self.local_labels = self.labels[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)].long()
+        self.local_train_mask = self.train_mask[self.local_num_nodes*rank:self.local_num_nodes*(rank+1)].bool()
+
+        self.adj_parts = graph_utils.sparse_2d_split(self.adj, self.local_num_nodes, split_dim=1)
+
+    def __repr__(self):
+        local_g = f''
+        return super().__repr__() + local_g
diff --git a/dist_gcn_train.py b/dist_train.py
similarity index 91%
rename from dist_gcn_train.py
rename to dist_train.py
index c6f9822..1f1e049 100644
--- a/dist_gcn_train.py
+++ b/dist_train.py
@@ -1,6 +1,7 @@
 import datetime
 from coo_graph import Parted_COO_Graph
 from handcraft_gcn import GCN
+from handcraft_gat import GAT
 import torch
 import torch.nn.functional as F
 
@@ -8,10 +9,11 @@ def train(g, env, total_epoch):
     model = GCN(g, env)
+    # model = GAT(g, env)
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
     for epoch in range(total_epoch):
         with env.timer.timing('epoch'):
-            outputs = model(g.local_features)
+            outputs = model(g.features)
             optimizer.zero_grad()
             if g.local_labels[g.local_train_mask].size(0) > 0:
                 loss = F.nll_loss(outputs[g.local_train_mask], g.local_labels[g.local_train_mask])
@@ -31,7 +33,7 @@ def train(g, env, total_epoch):
 def main(env, args):
     env.logger.log('proc begin:', env)
     with env.timer.timing('total'):
-        g = Parted_COO_Graph(args.dataset, rank=env.rank, num_parts=env.world_size, device=env.device).pad()
+        g = Parted_COO_Graph(args.dataset, rank=env.rank, num_parts=env.world_size, device=env.device)
         env.logger.log('graph loaded', g)
         train(g, env, total_epoch=args.epoch)
     env.logger.log(env.timer.summary_all(), rank=0)
diff --git a/dist_utils.py b/dist_utils.py
deleted file mode 100644
index bd9e20e..0000000
--- a/dist_utils.py
+++ /dev/null
@@ -1,162 +0,0 @@
-import os
-import datetime as dt
-import torch
-import torch.distributed as dist
-import math
-import time
-import pickle
-import statistics
-from collections import defaultdict
-import tempfile
-
-
-class DistEnv:
-    def __init__(self, rank, world_size, backend='nccl'):
-        assert(rank>=0)
-        assert(world_size>0)
-        self.rank, self.world_size = rank, world_size
-        self.backend = backend
-        self.init_device()
-        self.init_dist_groups()
-        self.logger = DistLogger(self)
-        self.timer = DistTimer(self)
-        self.store = dist.FileStore(os.path.join(tempfile.gettempdir(), 'torch-dist'), self.world_size)
-        DistEnv.env = self  # no global...
-
-    def __repr__(self):
-        return ''%(self.rank, self.world_size, self.backend)
-
-    def init_device(self):
-        if torch.cuda.device_count()>1:
-            self.device = torch.device('cuda', self.rank)
-            torch.cuda.set_device(self.device)
-        else:
-            self.device = torch.device('cpu')
-
-    def all_reduce_sum(self, tensor):
-        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.world_group)
-
-    def broadcast(self, tensor, src):
-        dist.broadcast(tensor, src=src, group=self.world_group)
-
-    def all_gather_then_cat(self, src_t):
-        recv_list = [torch.zeros_like(src_t) for _ in range(self.world_size)]
-        dist.all_gather(recv_list, src_t, group=self.world_group)
-        return torch.cat(recv_list, dim=0)
-
-    def barrier_all(self):
-        dist.barrier(self.world_group)
-
-    def init_dist_groups(self):
-        dist.init_process_group(backend=self.backend, rank=self.rank, world_size=self.world_size, init_method='env://')
-        self.world_group = dist.new_group(list(range(self.world_size)))
-        self.p2p_group_dict = {}
-        for src in range(self.world_size):
-            for dst in range(src+1, self.world_size):
-                self.p2p_group_dict[(src, dst)] = dist.new_group([src, dst])
-                self.p2p_group_dict[(dst, src)] = self.p2p_group_dict[(src, dst)]
-
-
-class DistUtil:
-    def __init__(self, env):
-        self.env = env
-
-
-class DistLogger(DistUtil):
-    def __init__(self, env):
-        super().__init__(env)
-        self.log_root = os.path.join(os.path.dirname(__file__), f'logs_{env.world_size}')
-        os.makedirs(self.log_root, exist_ok=True)
-        self.log_fname = os.path.join(self.log_root, 'all_log_%d.txt'%self.env.rank)
-
-    def log(self, *args, oneline=False, rank=-1):
-        if rank!=-1 and self.env.rank!=rank:
-            return
-        head = '%s [%1d] '%(dt.datetime.now().time(), self.env.rank)
-        tail = '\r' if oneline else '\n'
-        the_whole_line = head+' '.join(map(str, args))+tail
-        print(the_whole_line, end='', flush=True)  # to prevent line breaking
-        with open(self.log_fname, 'a+') as f:
-            print(the_whole_line, end='', file=f, flush=True)  # to prevent line breaking
-
-
-class DistTimer(DistUtil):
-    def __init__(self, env):
-        super().__init__(env)
-        self.start_time_dict = {}
-        self.duration_dict = defaultdict(float)
-        self.count_dict = defaultdict(int)
-
-    def summary(self):
-        s = '\ntimer summary:\n' + "\n".join("%6.2fs %5d %s" % (self.duration_dict[key], self.count_dict[key], key) for key in self.duration_dict)
-        return s
-
-    def sync_duration_dicts(self):
-        self.env.store.set('duration_dict_%d'%self.env.rank, pickle.dumps(self.duration_dict))
-        self.env.barrier_all()
-        self.all_durations = [pickle.loads(self.env.store.get('duration_dict_%d'%rank)) for rank in range(self.env.world_size)]
-
-    def summary_all(self):
-        self.sync_duration_dicts()
-        avg_dict = {}
-        std_dict = {}
-        for key in self.duration_dict:
-            data = [d[key] for d in self.all_durations]
-            avg_dict[key], std_dict[key] = statistics.mean(data), statistics.stdev(data)
-        s = '\ntimer summary:\n' + "\n".join("%6.2fs %6.2fs %5d %s" % (avg_dict[key], std_dict[key], self.count_dict[key], key) for key in self.duration_dict)
-        return s
-
-    def detail_all(self):
-        self.sync_duration_dicts()
-        avg_dict = {}
-        std_dict = {}
-        detail_dict = {}
-        for key in self.duration_dict:
-            data = [d[key] for d in self.all_durations]
-            avg_dict[key], std_dict[key] = statistics.mean(data), statistics.stdev(data)
-            detail_dict[key] = ' '.join("%6.2f"%x for x in data)
-        s = '\ntimer summary:\n' + "\n".join("%6.2fs %6.2fs %5d %s \ndetail: %s \n--------------" % (avg_dict[key], std_dict[key], self.count_dict[key], key, detail_dict[key]) for key in self.duration_dict)
-        return s
-
-    class TimerCtx:
-        def __init__(self, timer, key, cuda):
-            self.cuda = cuda
-            self.timer = timer
-            self.key = key
-
-        def __enter__(self):
-            if self.cuda:
-                torch.cuda.synchronize()
-            self.timer.start_time_dict[self.key] = time.time()
-            return self
-
-        def __exit__(self, type, value, traceback):
-            if self.cuda:
-                torch.cuda.synchronize()
-            d=time.time() - self.timer.start_time_dict[self.key]
-            self.timer.duration_dict[self.key]+=d
-            self.timer.count_dict[self.key]+=1
-
-    def timing(self, key):
-        return DistTimer.TimerCtx(self, key, cuda=False)
-
-    def timing_cuda(self, key):
-        return DistTimer.TimerCtx(self, key, cuda=True)
-
-    def start(self, key):
-        self.start_time_dict[key] = time.time()
-        return self.start_time_dict[key]
-
-    def stop(self, key, *other_keys):
-        def log(k, d=time.time() - self.start_time_dict[key]):
-            self.duration_dict[k]+=d
-            self.count_dict[k]+=1
-        log(key)
-        for subkey in other_keys:
-            log(key+'-'+subkey)
-        return
-
-
-if __name__ == '__main__':
-    pass
-
diff --git a/dist_utils/__init__.py b/dist_utils/__init__.py
new file mode 100644
index 0000000..32795c8
--- /dev/null
+++ b/dist_utils/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2021, Zhao CHEN
+# All rights reserved.
+
+from .env import DistEnv
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
diff --git a/dist_utils/env.py b/dist_utils/env.py
new file mode 100644
index 0000000..0a6223a
--- /dev/null
+++ b/dist_utils/env.py
@@ -0,0 +1,64 @@
+import os
+import torch
+import torch.distributed as dist
+import tempfile
+
+from .timer import DistTimer
+from .logger import DistLogger
+
+
+class DistEnv:
+    def __init__(self, rank, world_size, backend='nccl'):
+        assert(rank>=0)
+        assert(world_size>0)
+        self.rank, self.world_size = rank, world_size
+        self.backend = backend
+        self.init_device()
+        self.init_dist_groups()
+        self.logger = DistLogger(self)
+        self.timer = DistTimer(self)
+        self.store = dist.FileStore(os.path.join(tempfile.gettempdir(), 'torch-dist'), self.world_size)
+        DistEnv.env = self  # no global...
+
+    def __repr__(self):
+        return ''%(self.rank, self.world_size, self.backend)
+
+    def init_device(self):
+        if torch.cuda.device_count()>1:
+            self.device = torch.device('cuda', self.rank)
+            torch.cuda.set_device(self.device)
+        else:
+            self.device = torch.device('cpu')
+
+    def all_reduce_sum(self, tensor):
+        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.world_group)
+
+    def broadcast(self, tensor, src):
+        dist.broadcast(tensor, src=src, group=self.world_group)
+
+    def all_gather_then_cat(self, src_t):
+        recv_list = [torch.zeros_like(src_t) for _ in range(self.world_size)]
+        dist.all_gather(recv_list, src_t, group=self.world_group)
+        return torch.cat(recv_list, dim=0)
+
+    def barrier_all(self):
+        dist.barrier(self.world_group)
+
+    def init_dist_groups(self):
+        dist.init_process_group(backend=self.backend, rank=self.rank, world_size=self.world_size, init_method='env://')
+        self.world_group = dist.new_group(list(range(self.world_size)))
+        self.p2p_group_dict = {}
+        for src in range(self.world_size):
+            for dst in range(src+1, self.world_size):
+                self.p2p_group_dict[(src, dst)] = dist.new_group([src, dst])
+                self.p2p_group_dict[(dst, src)] = self.p2p_group_dict[(src, dst)]
+
+
+class DistUtil:
+    def __init__(self, env):
+        self.env = env
+
+
+if __name__ == '__main__':
+    pass
+
diff --git a/dist_utils/logger.py b/dist_utils/logger.py
new file mode 100644
index 0000000..f649d90
--- /dev/null
+++ b/dist_utils/logger.py
@@ -0,0 +1,25 @@
+import os
+import datetime as dt
+
+
+class DistLogger:
+    def __init__(self, env):
+        self.env = env
+        self.log_root = os.path.join(os.path.dirname(__file__), '..', f'logs_{env.world_size}')
+        os.makedirs(self.log_root, exist_ok=True)
+        self.log_fname = os.path.join(self.log_root, 'all_log_%d.txt'%self.env.rank)
+
+    def log(self, *args, oneline=False, rank=-1):
+        if rank!=-1 and self.env.rank!=rank:
+            return
+        head = '%s [%1d] '%(dt.datetime.now().time(), self.env.rank)
+        tail = '\r' if oneline else '\n'
+        the_whole_line = head+' '.join(map(str, args))+tail
+        print(the_whole_line, end='', flush=True)  # to prevent line breaking
+        with open(self.log_fname, 'a+') as f:
+            print(the_whole_line, end='', file=f, flush=True)  # to prevent line breaking
+
+
+if __name__ == '__main__':
+    pass
+
diff --git a/dist_utils/timer.py b/dist_utils/timer.py
new file mode 100644
index 0000000..2965c42
--- /dev/null
+++ b/dist_utils/timer.py
@@ -0,0 +1,90 @@
+import datetime as dt
+import torch
+import math
+import time
+import pickle
+import statistics
+from collections import defaultdict
+
+
+class TimerCtx:
+    def __init__(self, timer, key, cuda):
+        self.cuda = cuda
+        self.timer = timer
+        self.key = key
+
+    def __enter__(self):
+        if self.cuda:
+            torch.cuda.synchronize()
+        self.timer.start_time_dict[self.key] = time.time()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if self.cuda:
+            torch.cuda.synchronize()
+        d = time.time() - self.timer.start_time_dict[self.key]
+        self.timer.duration_dict[self.key] += d
+        self.timer.count_dict[self.key] += 1
+
+
+class DistTimer:
+    def __init__(self, env):
+        self.env = env
+        self.start_time_dict = {}
+        self.duration_dict = defaultdict(float)
+        self.count_dict = defaultdict(int)
+
+    def summary(self):
+        s = '\ntimer summary:\n' + "\n".join("%6.2fs %5d %s" % (self.duration_dict[key], self.count_dict[key], key) for key in self.duration_dict)
+        return s
+
+    def sync_duration_dicts(self):
+        self.env.store.set('duration_dict_%d'%self.env.rank, pickle.dumps(self.duration_dict))
+        self.env.barrier_all()
+        self.all_durations = [pickle.loads(self.env.store.get('duration_dict_%d'%rank)) for rank in range(self.env.world_size)]
+
+    def summary_all(self):
+        self.sync_duration_dicts()
+        avg_dict = {}
+        std_dict = {}
+        for key in self.duration_dict:
+            data = [d[key] for d in self.all_durations]
+            avg_dict[key], std_dict[key] = statistics.mean(data), statistics.stdev(data)
+        s = '\ntimer summary:\n' + "\n".join("%6.2fs %6.2fs %5d %s" % (avg_dict[key], std_dict[key], self.count_dict[key], key) for key in self.duration_dict)
+        return s
+
+    def detail_all(self):
+        self.sync_duration_dicts()
+        avg_dict = {}
+        std_dict = {}
+        detail_dict = {}
+        for key in self.duration_dict:
+            data = [d[key] for d in self.all_durations]
+            avg_dict[key], std_dict[key] = statistics.mean(data), statistics.stdev(data)
+            detail_dict[key] = ' '.join("%6.2f"%x for x in data)
+        s = '\ntimer summary:\n' + "\n".join("%6.2fs %6.2fs %5d %s \ndetail: %s \n--------------" % (avg_dict[key], std_dict[key], self.count_dict[key], key, detail_dict[key]) for key in self.duration_dict)
+        return s
+
+    def timing(self, key):
+        return TimerCtx(self, key, cuda=False)
+
+    def timing_cuda(self, key):
+        return TimerCtx(self, key, cuda=True)
+
+    def start(self, key):
+        self.start_time_dict[key] = time.time()
+        return self.start_time_dict[key]
+
+    def stop(self, key, *other_keys):
+        def log(k, d=time.time() - self.start_time_dict[key]):
+            self.duration_dict[k]+=d
+            self.count_dict[k]+=1
+        log(key)
+        for subkey in other_keys:
+            log(key+'-'+subkey)
+        return
+
+
+if __name__ == '__main__':
+    pass
+
diff --git a/handcraft_gat.py b/handcraft_gat.py
new file mode 100644
index 0000000..c3b4c07
--- /dev/null
+++ b/handcraft_gat.py
@@ -0,0 +1,128 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from dist_utils import DistEnv
+import torch.distributed as dist
+
+try:
+    import spmm_cpp
+    spmm = lambda A,B,C: spmm_cpp.spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), \
+                                                A.size(1), B, C, 1, 1)
+except ImportError:
+    spmm = lambda A,B,C: C.addmm_(A,B)
+
+
+def broadcast(local_adj_parts, local_feature, tag):
+    env = DistEnv.env
+    z_loc = torch.zeros_like(local_feature)
+    feature_bcast = torch.zeros_like(local_feature)
+
+    for src in range(env.world_size):
+        if src==env.rank:
+            feature_bcast = local_feature.clone()
+        # env.barrier_all()
+        with env.timer.timing_cuda('broadcast'):
+            dist.broadcast(feature_bcast, src=src)
+
+        with env.timer.timing_cuda('spmm'):
+            spmm(local_adj_parts[src], feature_bcast, z_loc)
+    return z_loc
+
+
+class DistGCNLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, local_feature, weight, local_adj_parts, tag):
+        ctx.save_for_backward(local_feature, weight)
+        ctx.local_adj_parts = local_adj_parts
+        ctx.tag = tag
+        z_local = broadcast(local_adj_parts, local_feature, 'Forward'+tag)
+        with DistEnv.env.timer.timing_cuda('mm'):
+            z_local = torch.mm(z_local, weight)
+        return z_local
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        local_feature, weight = ctx.saved_tensors
+        ag = broadcast(ctx.local_adj_parts, grad_output, 'Backward'+ctx.tag)
+        with DistEnv.env.timer.timing_cuda('mm'):
+            grad_feature = torch.mm(ag, weight.t())
+            grad_weight = torch.mm(local_feature.t(), ag)
+        with DistEnv.env.timer.timing_cuda('all_reduce'):
+            DistEnv.env.all_reduce_sum(grad_weight)
+        return grad_feature, grad_weight, None, None
+
+
+class DistMMLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, local_feature, weight, tag):
+        ctx.save_for_backward(local_feature, weight)
+        ctx.tag = tag
+        Hw = torch.mm(local_feature, weight)
+        all_Hw = DistEnv.env.all_gather_then_cat(Hw)
+        return all_Hw
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        local_feature, weight = ctx.saved_tensors
+        split_sz = local_feature.size(0)
+        rank = DistEnv.env.rank
+        grad_output = grad_output[split_sz*rank:split_sz*(rank+1),:]
+        grad_feature = torch.mm(grad_output, weight.t())
+        grad_weight = torch.mm(local_feature.t(), grad_output)
+        DistEnv.env.all_reduce_sum(grad_weight)
+        return grad_feature, grad_weight, None
+
+
+class GAT(nn.Module):
+    def __init__(self, g, env, hidden_dim=16):
+        super().__init__()
+        self.g, self.env = g, env
+        in_dim, out_dim = g.local_features.size(1), g.num_classes
+        torch.manual_seed(0)
+
+        self.weight1 = nn.Parameter(torch.rand(in_dim, hidden_dim)).to(env.device)
+        self.weight2 = nn.Parameter(torch.rand(hidden_dim, out_dim)).to(env.device)
+
+        self.attention_weight1 = nn.Parameter(torch.rand(2*hidden_dim, 1)).to(env.device)
+        self.attention_weight2 = nn.Parameter(torch.rand(out_dim*2, 1)).to(env.device)
+
+    def forward(self, local_features):
+        local_edge_index = self.g.local_adj._indices()
+        self.env.logger.log('L1', self.weight1.sum(), self.attention_weight1.sum())
+
+        # Hw1 = torch.mm(local_features, self.weight1)
+        # all_Hw1 = self.env.all_gather_then_cat(Hw1)
+        all_Hw1 = DistMMLayer.apply(local_features, self.weight1, 'L1')
+
+        # Hw_bcast = torch.zeros_like(Hw1)
+        # for src in range(self.env.world_size):
+        #     if src == self.env.rank:
+        #         Hw_bcast = Hw1.clone()
+        #     dist.broadcast(Hw_bcast, src=src)
+
+        edge_features = torch.cat((all_Hw1[local_edge_index[0, :], :], all_Hw1[local_edge_index[1, :], :]), dim=1)
+
+        att_input = F.leaky_relu(torch.mm(edge_features, self.attention_weight1).squeeze())
+        att_input = torch.sparse_coo_tensor(local_edge_index, att_input, self.g.local_adj.size())
+        attention = torch.sparse.softmax(att_input, dim=1)
+        # print(attention.size(), Hw1.size())
+
+        hidden_features = torch.sparse.mm(attention, all_Hw1)
+        hidden_features = F.elu(hidden_features)
+
+
+        # self.env.logger.log('L2', self.weight2.sum(), self.attention_weight2.sum())
+        # Hw2 = torch.mm(hidden_features, self.weight2)
+        # all_Hw2 = self.env.all_gather_then_cat(Hw2)
+        all_Hw2 = DistMMLayer.apply(hidden_features, self.weight2, 'L2')
+        edge_features = torch.cat((all_Hw2[local_edge_index[0, :], :], all_Hw2[local_edge_index[1, :], :]), dim=1)
+
+        att_input = F.leaky_relu(torch.mm(edge_features, self.attention_weight2).squeeze())
+        att_input = torch.sparse_coo_tensor(local_edge_index, att_input, self.g.local_adj.size())
+        attention = torch.sparse.softmax(att_input, dim=1)
+
+        outputs = torch.sparse.mm(attention, all_Hw2)
+        return F.log_softmax(outputs, 1)
+
diff --git a/handcraft_gcn.py b/handcraft_gcn.py
index 5a8f894..93ec4c0 100644
--- a/handcraft_gcn.py
+++ b/handcraft_gcn.py
@@ -7,9 +7,8 @@ import torch.distributed as dist
 
 try:
-    import spmm_cpp
-    spmm = lambda A,B,C: spmm_cpp.spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), \
-                                                A.size(1), B, C, 1, 1)
+    from spmm_cpp import spmm_cusparse
+    spmm = lambda A,B,C: spmm_cusparse(A.indices()[0].int(), A.indices()[1].int(), A.values(), A.size(0), A.size(1), B, C, 1, 1)
 except ImportError:
     spmm = lambda A,B,C: C.addmm_(A,B)
 
@@ -33,38 +32,38 @@ def broadcast(local_adj_parts, local_feature, tag):
 
 class DistGCNLayer(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, local_feature, weight, local_adj_parts, tag):
-        ctx.save_for_backward(local_feature, weight)
-        ctx.local_adj_parts = local_adj_parts
+    def forward(ctx, features, weight, adj_parts, tag):
+        ctx.save_for_backward(features, weight)
+        ctx.adj_parts = adj_parts
         ctx.tag = tag
-        z_local = broadcast(local_adj_parts, local_feature, 'Forward'+tag)
+        z_local = broadcast(adj_parts, features, 'Forward'+tag)
         with DistEnv.env.timer.timing_cuda('mm'):
             z_local = torch.mm(z_local, weight)
         return z_local
 
     @staticmethod
     def backward(ctx, grad_output):
-        local_feature, weight = ctx.saved_tensors
-        ag = broadcast(ctx.local_adj_parts, grad_output, 'Backward'+ctx.tag)
+        features, weight = ctx.saved_tensors
+        ag = broadcast(ctx.adj_parts, grad_output, 'Backward'+ctx.tag)
        with DistEnv.env.timer.timing_cuda('mm'):
-            grad_feature = torch.mm(ag, weight.t())
-            grad_weight = torch.mm(local_feature.t(), ag)
+            grad_features = torch.mm(ag, weight.t())
+            grad_weight = torch.mm(features.t(), ag)
         with DistEnv.env.timer.timing_cuda('all_reduce'):
             DistEnv.env.all_reduce_sum(grad_weight)
-        return grad_feature, grad_weight, None, None
+        return grad_features, grad_weight, None, None
 
 
 class GCN(nn.Module):
     def __init__(self, g, env, hidden_dim=16):
         super().__init__()
         self.g, self.env = g, env
-        in_dim, out_dim = g.local_features.size(1), g.num_classes
+        in_dim, out_dim = g.features.size(1), g.num_classes
         torch.manual_seed(0)
         self.weight1 = nn.Parameter(torch.rand(in_dim, hidden_dim).to(env.device))
         self.weight2 = nn.Parameter(torch.rand(hidden_dim, out_dim).to(env.device))
 
     def forward(self, features):
-        hidden_features1 = F.relu(DistGCNLayer.apply(features, self.weight1, self.g.local_adj_parts, 'L1'))
-        outputs = DistGCNLayer.apply(hidden_features1, self.weight2, self.g.local_adj_parts, 'L2')
+        hidden_features1 = F.relu(DistGCNLayer.apply(features, self.weight1, self.g.adj_parts, 'L1'))
+        outputs = DistGCNLayer.apply(hidden_features1, self.weight2, self.g.adj_parts, 'L2')
         return F.log_softmax(outputs, 1)
diff --git a/dist_main.py b/main.py
similarity index 92%
rename from dist_main.py
rename to main.py
index f12a342..90ad62c 100644
--- a/dist_main.py
+++ b/main.py
@@ -3,7 +3,7 @@ import torch
 
 import dist_utils
-import dist_gcn_train
+import dist_train
 
 import torch.distributed as dist
 
@@ -25,10 +25,10 @@ def process_wrapper(rank, args, func):
 if __name__ == "__main__":
     num_GPUs = torch.cuda.device_count()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--nprocs", type=int, default=num_GPUs if num_GPUs>1 else 4)
+    parser.add_argument("--nprocs", type=int, default=num_GPUs if num_GPUs>1 else 8)
     parser.add_argument("--epoch", type=int, default=20)
     parser.add_argument("--backend", type=str, default='nccl' if num_GPUs>1 else 'gloo')
     parser.add_argument("--dataset", type=str, default='reddit')
     args = parser.parse_args()
-    process_args = (args, dist_gcn_train.main)
+    process_args = (args, dist_train.main)
     torch.multiprocessing.spawn(process_wrapper, process_args, args.nprocs)
diff --git a/prepare_data.py b/prepare_data.py
index 715a914..0a59dcb 100644
--- a/prepare_data.py
+++ b/prepare_data.py
@@ -5,21 +5,20 @@ def main():
-    cached = True
+    cached = False
     # r = COO_Graph('cora')
-    # r = Parted_COO_Graph('flickr')
-    # r = coo_graph.Parted_COO_Graph('cora', full_graph_cache_enabled=cached)
-    # r = coo_graph.Parted_COO_Graph('flickr', full_graph_cache_enabled=cached)
-    r = coo_graph.Parted_COO_Graph('reddit', full_graph_cache_enabled=cached)
-    # r = coo_graph.Parted_COO_Graph('a_quarter_reddit', full_graph_cache_enabled=True)
-    print(r.adj.size())
+    # r = coo_graph.COO_Graph('cora', full_graph_cache_enabled=cached)
+    # r = coo_graph.COO_Graph('flickr', full_graph_cache_enabled=cached)
+    # r = coo_graph.COO_Graph('reddit', full_graph_cache_enabled=cached)
+    r = coo_graph.COO_Graph('a_quarter_reddit', full_graph_cache_enabled=cached)
     r.partition(8)
     r.partition(4)
-
-    # r = COO_Graph('reddit', cached=True)
-    # r = COO_Graph('AmazonProducts')
-    # r = COO_Graph('Yelp')
-    print(r)
+    return
+    for name in ['reddit', 'flickr', 'cora']:
+        r = coo_graph.COO_Graph(name, full_graph_cache_enabled=cached)
+        r.partition(8)
+        r.partition(4)
+        print(r)
     return
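
A minimal usage sketch of the classes this patch introduces: COO_Graph builds and caches the preprocessed full graph, partition() writes the per-rank caches, and Parted_COO_Graph loads one part inside a worker (the same flow used by prepare_data.py and dist_train.py). The dataset name 'reddit', the 4-way split, and rank 0 are illustrative values only; the raw dataset is assumed to have been fetched beforehand via coo_graph.datasets.prepare_dataset.

    import coo_graph

    # one-time preparation in a single process (compare prepare_data.py);
    # assumes data/reddit.torch already exists from datasets.prepare_dataset('reddit')
    g = coo_graph.COO_Graph('reddit', preprocess_for='GCN', full_graph_cache_enabled=True)
    g.partition(4)  # writes part_0_of_4 ... part_3_of_4 cache files

    # inside each training process (compare dist_train.py); rank comes from the launcher
    part = coo_graph.Parted_COO_Graph('reddit', rank=0, num_parts=4, device='cpu')
    print(part.local_num_nodes, part.local_num_edges)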