
Commit: refactor
chenzhao committed Nov 12, 2021
1 parent d30bcb0 commit a75e87a
Showing 14 changed files with 511 additions and 404 deletions.
2 changes: 1 addition & 1 deletion coo_graph/__init__.py
@@ -1 +1 @@
from .parted_coo_graph import Parted_COO_Graph
from .parted_coo_graph import COO_Graph, Parted_COO_Graph
93 changes: 37 additions & 56 deletions coo_graph/datasets.py
@@ -2,15 +2,15 @@
# All rights reserved.

import os
import numpy
import scipy
import scipy.sparse
import torch
import json


data_root = os.path.join(os.path.dirname(__file__), '..', 'data')
dgl_root = os.path.join(data_root, 'dgl_datasets')
pyg_root = os.path.join(data_root, 'pyg_datasets')
for path in [data_root, dgl_root, pyg_root]:
    os.makedirs(path, exist_ok=True)


def save_dataset(edge_index, features, labels, train_mask, val_mask, test_mask, num_nodes, num_edges, num_classes, name):
    if name.startswith('a_quarter'):
@@ -26,8 +26,8 @@ def save_dataset(edge_index, features, labels, train_mask, val_mask, test_mask,
        num_nodes = max_node
        num_edges = edge_index.size(1)
    path = os.path.join(data_root, name+'.torch')
    torch.save({"edge_index": edge_index, "features": features, "labels": labels,
                "train_mask": train_mask, 'val_mask': val_mask, 'test_mask': test_mask,
    torch.save({"edge_index": edge_index.int(), "features": features, "labels": labels.char(),
                "train_mask": train_mask.bool(), 'val_mask': val_mask.bool(), 'test_mask': test_mask.bool(),
                "num_nodes": num_nodes, 'num_edges': num_edges, 'num_classes': num_classes}, path)


@@ -38,46 +38,45 @@ def load_dataset(name):
    return torch.load(path)


def prepare_dgl_dataset(source, name):
    dgl_dataset: dgl.data.DGLDataset = source(raw_dir=dgl_root)
def prepare_dgl_dataset(dgl_name, tag):
    import dgl
    dataset_sources = {'cora': dgl.data.CoraGraphDataset, 'reddit': dgl.data.RedditDataset}
    dgl_dataset: dgl.data.DGLDataset = dataset_sources[dgl_name](raw_dir=dgl_root)
    g = dgl_dataset[0]
    edge_index = torch.stack(g.adj_sparse('coo'))
    save_dataset(edge_index, g.ndata['feat'], g.ndata['label'],
                 g.ndata['train_mask'], g.ndata['val_mask'], g.ndata['test_mask'],
                 g.num_nodes(), g.num_edges(), dgl_dataset.num_classes, name)
                 g.num_nodes(), g.num_edges(), dgl_dataset.num_classes, tag)


def prepare_pyg_dataset(source, name):
    pyg_dataset: torch_geometric.data.Dataset = source(root=os.path.join(pyg_root, name))
def prepare_pyg_dataset(pyg_name, tag):
    import torch_geometric
    dataset_sources = {'reddit': torch_geometric.datasets.Reddit,
                       'flickr': torch_geometric.datasets.Flickr,
                       'yelp': torch_geometric.datasets.Yelp,
                       'amazon-products': torch_geometric.datasets.AmazonProducts}
    pyg_dataset: torch_geometric.data.Dataset = dataset_sources[pyg_name](root=os.path.join(pyg_root, pyg_name))
    data: torch_geometric.data.Data = pyg_dataset[0]
    save_dataset(data.edge_index, data.x, data.y,
                 data.train_mask, data.val_mask, data.test_mask,
                 data.num_nodes, data.num_edges, pyg_dataset.num_classes, name)


def prepare_dataset(name):
    import dgl
    import torch_geometric
    dataset_source_mapping = {'cora': dgl.data.CoraGraphDataset,
                              'reddit_reorder': dgl.data.RedditDataset,
                              'reddit': torch_geometric.datasets.Reddit,
                              'a_quarter_reddit': torch_geometric.datasets.Reddit,
                              'flickr': torch_geometric.datasets.Flickr,
                              'yelp': torch_geometric.datasets.Yelp}

    for path in [data_root, dgl_root, pyg_root]:
        os.makedirs(path, exist_ok=True)
    try:
        source_class = dataset_source_mapping[name]
    except KeyError:
        raise Exception('no source for such dataset', name)
    if source_class.__module__.startswith('dgl.'):
        prepare_dgl_dataset(source_class, name)
    elif source_class.__module__.startswith('torch_geometric.'):
        prepare_pyg_dataset(source_class, name)
    else:  # other libs TODO
        pass

                 data.num_nodes, data.num_edges, pyg_dataset.num_classes, tag)


def prepare_dataset(tag):
    if tag == 'reddit':
        return prepare_pyg_dataset('reddit', tag)
    elif tag == 'flickr':
        return prepare_pyg_dataset('flickr', tag)
    elif tag == 'yelp':  # GraphSAINT
        return prepare_pyg_dataset('yelp', tag)
    elif tag == 'amazon-products':  # GraphSAINT
        return prepare_pyg_dataset('amazon-products', tag)
    elif tag == 'cora':
        return prepare_dgl_dataset('cora', tag)
    elif tag == 'reddit_reorder':
        return prepare_dgl_dataset('reddit', tag)
    elif tag == 'a_quarter_reddit':
        return prepare_pyg_dataset('reddit', tag)
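Usage sketch for the rewritten dispatch (illustrative only, requires DGL or PyG installed): the tag both selects the loader and names the cache file written by save_dataset, which is how 'reddit_reorder' and 'a_quarter_reddit' reuse existing sources under separate cache entries.

prepare_dataset('cora')              # DGL source
prepare_dataset('amazon-products')   # PyG GraphSAINT source
prepare_dataset('a_quarter_reddit')  # PyG Reddit; save_dataset subsamples 'a_quarter*' tags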


def check_edges(edge_index, num_nodes):
@@ -100,26 +99,8 @@ def check_edges(edge_index, num_nodes):


def main():
    r = load_dataset('reddit')
    check_edges(r['edge_index'], r['num_nodes'])

    data = numpy.load(os.path.join(dgl_root, 'reddit', 'reddit_data.npz'))
    x = torch.from_numpy(data['feature']).to(torch.float)
    y = torch.from_numpy(data['label']).to(torch.long)
    # split = torch.from_numpy(data['node_types'])

    adj = scipy.sparse.load_npz(os.path.join(dgl_root, 'reddit', 'reddit_graph.npz'))
    row = torch.from_numpy(adj.row).to(torch.long)
    col = torch.from_numpy(adj.col).to(torch.long)
    edge_index = torch.stack([row, col], dim=0)
    check_edges(edge_index, x.size(0))

    return
    for dataset_name in ['cora', 'reddit', 'flickr', 'yelp']:
    for dataset_name in ['cora', 'reddit', 'flickr', 'yelp', 'a_quarter_reddit', 'amazon-products']:
        prepare_dataset(dataset_name)
    pass


if __name__ == '__main__':
97 changes: 26 additions & 71 deletions coo_graph/graph_utils.py
@@ -1,11 +1,18 @@
import os
import os.path
import torch
import math
import datetime

import random
import numpy as np

def preprocess(attr_dict, preprocess_for):  # normalize feature and make adj matrix from edge index
    begin = datetime.datetime.now()
    print(preprocess_for, 'preprocess begin', begin)
    attr_dict["features"] = attr_dict["features"] / attr_dict["features"].sum(1, keepdim=True).clamp(min=1)
    if preprocess_for == 'GCN':  # make the coo format sym lap matrix
        attr_dict['adj'] = sym_normalization(attr_dict['edge_index'], attr_dict['num_nodes'])
    elif preprocess_for == 'GAT':
        attr_dict['adj'] = attr_dict['edge_index']
    attr_dict.pop('edge_index')
    print(preprocess_for, 'preprocess done', datetime.datetime.now() - begin)
    return attr_dict
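The row normalization in preprocess divides each feature row by its sum, with clamp(min=1) guarding all-zero rows against division by zero. A tiny self-contained check of that expression:

import torch

feats = torch.tensor([[2., 2.], [0., 0.]])
normed = feats / feats.sum(1, keepdim=True).clamp(min=1)
assert normed.tolist() == [[0.5, 0.5], [0.0, 0.0]]  # the zero row passes through unchanged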


def add_self_loops(edge_index, num_nodes): # from pyg
@@ -16,83 +23,31 @@ def add_self_loops(edge_index, num_nodes): # from pyg

def sym_normalization(edge_index, num_nodes, faster_device='cuda:0'):
    original_device = edge_index.device
    # begin = datetime.datetime.now()
    begin = datetime.datetime.now()
    edge_index = add_self_loops(edge_index, num_nodes)
    # A = torch.sparse_coo_tensor(edge_index, torch.ones(len(edge_index[0])), (num_nodes, num_nodes)).coalesce()
    # return A
    A = torch.sparse_coo_tensor(edge_index, torch.ones(len(edge_index[0])), (num_nodes, num_nodes), device=faster_device).coalesce()
    degree_vec = torch.sparse.sum(A, 0).pow(-0.5).to_dense()
    I_edge_index = torch.stack((torch.arange(num_nodes), torch.arange(num_nodes)))
    D_rsqrt = torch.sparse_coo_tensor(I_edge_index, degree_vec, (num_nodes, num_nodes), device=faster_device)
    DA = torch.sparse.mm(D_rsqrt, A)
    del A  # to save GPU mem
    # print(DA)
    DAD = torch.sparse.mm(DA, D_rsqrt)
    del DA
    # end = datetime.datetime.now()
    # print('sym norm done', end - begin)
    end = datetime.datetime.now()
    print('sym norm done', end - begin)
    return DAD.coalesce().to(original_device)
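For reference, sym_normalization builds the GCN propagation matrix D^(-1/2) (A + I) D^(-1/2) on faster_device and moves the result back to the edge list's original device. A small sanity-check sketch, assuming the functions above are in scope and using the CPU as the "faster" device:

import torch

edge_index = torch.tensor([[0, 1], [1, 0]])  # a single undirected edge
dad = sym_normalization(edge_index, num_nodes=2, faster_device='cpu')
# with self-loops both degrees are 2, so every entry is 1/sqrt(2*2) = 0.5
assert torch.allclose(dad.to_dense(), torch.full((2, 2), 0.5))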


def save_cache_dict(d, path):
    if os.path.exists(path):
        print(f'warning: cache file {path} is overwritten.')
    torch.save(d, path)


def to(v, device):
    if type(v) == torch.Tensor:
        return v.to(device)
    if type(v) in (list, tuple) and type(v[0]) == torch.Tensor:
        return [i.to(device) for i in v]
    return v

def load_cache_dict(path):
    if not os.path.exists(path):
        raise Exception('no such file: ' + path)
    d = torch.load(path)
    updated_d = {}
    for k, v in d.items():
        if type(v) == torch.Tensor and v.is_sparse:
            updated_d[k] = v.coalesce()
        if type(v) == list and type(v[0]) == torch.Tensor and v[0].is_sparse:
            updated_d[k] = [i.coalesce() for i in v]
    d.update(updated_d)
    return d
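Why load_cache_dict re-coalesces: depending on the PyTorch version, a sparse tensor round-tripped through torch.save/torch.load can come back uncoalesced, and .indices()/.values() raise on uncoalesced tensors. A hedged illustration of the same defensive pattern:

import os
import tempfile
import torch

st = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.], (2, 2)).coalesce()
path = os.path.join(tempfile.mkdtemp(), 'adj.pt')
torch.save({'adj': st}, path)
adj = torch.load(path)['adj'].coalesce()  # mirrors load_cache_dict's re-coalesce
print(adj.indices())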

def split_2D_coo_by_size(split_idx, other_idx, values, split_size):  # 2D tensors only
    coo_parts = []
    while len(split_idx) > 0:
        mask: torch.Tensor = split_idx < split_size
        coo_parts.append((split_idx[mask], other_idx[mask], values[mask], split_size))  # padding? TODO
        split_idx = split_idx[mask.logical_not()] - split_size
    return coo_parts


def split_2D_coo(split_idx, other_idx, values, seps):  # 2D tensors only
    coo_parts = []
def sparse_2d_split(st, split_size, split_dim=0):
    seps = list(range(0, st.size(split_dim), split_size)) + [st.size(split_dim)]
    parts = []
    split_idx = st.indices()[split_dim]
    other_idx = st.indices()[1 - split_dim]
    make_2d_st = lambda idx0, idx1, val, sz0, sz1: torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()
    for lower, upper in zip(seps[:-1], seps[1:]):
        mask: torch.Tensor = (split_idx < upper) & (split_idx >= lower)
        coo_parts.append((split_idx[mask]-lower, other_idx[mask], values[mask], upper-lower))
    return coo_parts


def make_2D_coo(idx0, idx1, val, sz0, sz1):  # to mat: row is 0, col is 1
    return torch.sparse_coo_tensor(torch.stack([idx0, idx1]), val, (sz0, sz1)).coalesce()


def CAGNET_split(coo_adj, split_size):  # A.T==A only
    seps = list(range(0, coo_adj.size(0), split_size)) + [coo_adj.size(0)]
    row_parts = split_2D_coo(coo_adj.indices()[0], coo_adj.indices()[1], coo_adj.values(), seps)  # Ai is rows part i
    row_part_coo_list, row_col_part_coos_list = [], []
    for part_row_idx, full_col_idx, val, row_sz in row_parts:
        print(f'coo split: {val.size(0)}, {row_sz}')
        row_part_coo_list.append(make_2D_coo(part_row_idx, full_col_idx, val, row_sz, coo_adj.size(0)))
        # row_col_part_coos_list.append([make_2D_coo(p_row, p_col, p_val, row_sz, col_sz)
        #                                for p_col, p_row, p_val, col_sz in split_2D_coo(full_col_idx, part_row_idx, val, seps)])
        row_col_part_coos = []
        for p_col, p_row, p_val, col_sz in split_2D_coo(full_col_idx, part_row_idx, val, seps):
            print(f'\tcoo split: {p_val.size(0)}, {row_sz}, {col_sz}')
            row_col_part_coos.append(make_2D_coo(p_row, p_col, p_val, row_sz, col_sz))
        row_col_part_coos_list.append(row_col_part_coos)
    return row_part_coo_list, row_col_part_coos_list
        if split_dim == 0:
            parts.append(make_2d_st(split_idx[mask]-lower, other_idx[mask], st.values()[mask], upper-lower, st.size(1)))
        else:
            parts.append(make_2d_st(other_idx[mask], split_idx[mask]-lower, st.values()[mask], st.size(0), upper-lower))
    return parts
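Usage sketch for the new sparse_2d_split, which appears to replace the removed split_2D_coo helpers above with a single split_dim parameter: splitting a 4x4 sparse COO matrix into two 2x4 row blocks (split_dim=1 would yield column blocks instead).

import torch

idx = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])
st = torch.sparse_coo_tensor(idx, torch.ones(4), (4, 4)).coalesce()
row_blocks = sparse_2d_split(st, split_size=2, split_dim=0)
assert [tuple(p.size()) for p in row_blocks] == [(2, 4), (2, 4)]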
