import argparse
import os.path as osp

import torch
import torch_geometric.transforms as T
from deeprobust.graph.data import Dataset, Dpr2Pyg
from torch_geometric.datasets import Amazon, Coauthor, Planetoid


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')


def uniqueId(dataset_name, target_node, perturbation):
    # Cache key of the form "<dataset>_<target_node>_<perturbation>".
    return f"{dataset_name.lower()}_{target_node}_{perturbation}"


def prepare_data(args, lcc=False):
    transform = T.ToSparseTensor()
    if args.dataset in ("Cora", "CiteSeer", "PubMed"):
        if lcc:
            # Restrict to the largest connected component via DeepRobust,
            # then convert back to a PyG dataset.
            dpr_data = Dataset(root='/tmp/', name=args.dataset.lower())
            dataset = Dpr2Pyg(dpr_data, transform=transform)
        else:
            dataset = get_planetoid_dataset(args.dataset, args.normalize_features, transform)
        permute_masks = random_planetoid_splits if args.random_splits else None
    elif args.dataset in ("cs", "physics"):
        dataset = get_coauthor_dataset(args.dataset, args.normalize_features, transform)
        permute_masks = random_coauthor_amazon_splits
    elif args.dataset in ("computers", "photo"):
        dataset = get_amazon_dataset(args.dataset, args.normalize_features, transform)
        permute_masks = random_coauthor_amazon_splits
    else:
        # Fail fast on unknown names instead of hitting a NameError below.
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    print("Data:", dataset[0])
    return dataset, permute_masks
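
# A minimal usage sketch (hypothetical flag values; the training scripts that
# call this build `args` with argparse). `prepare_data` only reads the
# `dataset`, `normalize_features`, and `random_splits` attributes, and the
# returned `permute_masks` is either None or one of the split functions below:
#
#     args = argparse.Namespace(dataset="Cora", normalize_features=True,
#                               random_splits=True)
#     dataset, permute_masks = prepare_data(args)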


def get_planetoid_dataset(name, normalize_features=False, transform=None):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    dataset = Planetoid(path, name)
    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform
    return dataset


def get_coauthor_dataset(name, normalize_features=False, transform=None):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    dataset = Coauthor(path, name)
    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform
    return dataset


def get_amazon_dataset(name, normalize_features=False, transform=None):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    dataset = Amazon(path, name)
    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform
    return dataset


def index_to_mask(index, size):
    mask = torch.zeros(size, dtype=torch.bool, device=index.device)
    mask[index] = True
    return mask
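
# For example (hypothetical values): index_to_mask(torch.tensor([0, 2]), 4)
# returns tensor([True, False, True, False]).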


def random_planetoid_splits(data, num_classes, lcc_mask=None):
    # Set new random Planetoid splits:
    # * 20 * num_classes labels for training
    # * 500 labels for validation
    # * 1000 labels for testing
    indices = []
    if lcc_mask is not None:
        for i in range(num_classes):
            index = (data.y[lcc_mask] == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0))]
            indices.append(index)
    else:
        for i in range(num_classes):
            index = (data.y == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0))]
            indices.append(index)
    train_index = torch.cat([i[:20] for i in indices], dim=0)
    rest_index = torch.cat([i[20:] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]
    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(rest_index[:500], size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index[500:1500], size=data.num_nodes)
    return data


def random_coauthor_amazon_splits(data, num_classes, lcc_mask=None, seed=None):
    # Set random Coauthor/Amazon co-purchase splits:
    # * 20 * num_classes labels for training
    # * 30 * num_classes labels for validation
    # * the remaining labels for testing
    g = None
    if seed is not None:
        g = torch.Generator()
        g.manual_seed(seed)
    indices = []
    if lcc_mask is not None:
        for i in range(num_classes):
            index = (data.y[lcc_mask] == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0), generator=g)]
            indices.append(index)
    else:
        for i in range(num_classes):
            index = (data.y == i).nonzero().view(-1)
            index = index[torch.randperm(index.size(0), generator=g)]
            indices.append(index)
    train_index = torch.cat([i[:20] for i in indices], dim=0)
    val_index = torch.cat([i[20:50] for i in indices], dim=0)
    rest_index = torch.cat([i[50:] for i in indices], dim=0)
    # Pass the same seeded generator here so the test shuffle is reproducible.
    rest_index = rest_index[torch.randperm(rest_index.size(0), generator=g)]
    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(val_index, size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index, size=data.num_nodes)
    return data
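

# A minimal, runnable sketch of how these helpers compose (assumed example:
# Cora with random splits; downloads the dataset on first use).
if __name__ == "__main__":
    args = argparse.Namespace(dataset="Cora", normalize_features=True,
                              random_splits=True)
    dataset, permute_masks = prepare_data(args)
    data = dataset[0]
    if permute_masks is not None:
        data = permute_masks(data, dataset.num_classes)
    print("train/val/test sizes:",
          int(data.train_mask.sum()),
          int(data.val_mask.sum()),
          int(data.test_mask.sum()))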