-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtransforms.py
69 lines (51 loc) · 1.98 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import copy
import torch
from torch_geometric.utils.dropout import dropout_adj
from torch_geometric.transforms import Compose
class DropFeatures:
    r"""Drops whole node-feature dimensions (columns of ``data.x``) with probability p.

    One Bernoulli sample is drawn per feature column, so a dropped column is
    zeroed for *every* node in the graph (feature masking, not per-entry dropout).
    """

    def __init__(self, p=None):
        # Explicit ValueError instead of ``assert``: asserts are stripped under
        # ``python -O``, and the old '%.2f' % p message itself raised TypeError
        # when the default p=None was left in place.
        if p is None or not 0. < p < 1.:
            raise ValueError(
                'Dropout probability has to be between 0 and 1, but got %s' % p)
        self.p = p

    def __call__(self, data):
        # Uniform sample per feature column; columns whose draw falls below p
        # are zeroed in-place across all nodes. Mask lives on data.x's device.
        drop_mask = torch.empty(size=(data.x.size(1),),
                                dtype=torch.float32,
                                device=data.x.device).uniform_(0, 1) < self.p
        data.x[:, drop_mask] = 0
        return data

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)
class DropEdges:
    r"""Drops edges with probability p using ``dropout_adj``.

    Parameters
    ----------
    p : float
        Per-edge drop probability, strictly in (0, 1).
    force_undirected : bool, optional
        Forwarded to ``dropout_adj`` so edges can be dropped symmetrically.
    """

    def __init__(self, p, force_undirected=False):
        # Explicit ValueError instead of ``assert`` so the validation is not
        # stripped when the interpreter runs with -O.
        if not 0. < p < 1.:
            raise ValueError(
                'Dropout probability has to be between 0 and 1, but got %.2f' % p)
        self.p = p
        self.force_undirected = force_undirected

    def __call__(self, data):
        edge_index = data.edge_index
        # Graphs without edge attributes pass edge_attr=None through dropout_adj.
        edge_attr = data.edge_attr if 'edge_attr' in data else None
        edge_index, edge_attr = dropout_adj(edge_index, edge_attr,
                                            p=self.p,
                                            force_undirected=self.force_undirected)
        data.edge_index = edge_index
        # Only write edge_attr back when the input graph had one.
        if edge_attr is not None:
            data.edge_attr = edge_attr
        return data

    def __repr__(self):
        return '{}(p={}, force_undirected={})'.format(
            self.__class__.__name__,
            self.p,
            self.force_undirected)
def get_graph_drop_transform(drop_edge_p, drop_feat_p, force_undirected=False):
    """Build a graph-augmentation pipeline: deep-copy, then optional edge/feature drops.

    Parameters
    ----------
    drop_edge_p : float
        Edge-drop probability; a DropEdges stage is appended only when > 0.
    drop_feat_p : float
        Feature-drop probability; a DropFeatures stage is appended only when > 0.
    force_undirected : bool, optional
        Forwarded to DropEdges (new, default False preserves prior behavior;
        previously this option of DropEdges was not reachable from here).

    Returns
    -------
    Compose
        The composed transform. It always begins with ``copy.deepcopy`` so the
        caller's original graph object is never mutated by the in-place drops.
    """
    # Deep-copy first: DropEdges/DropFeatures mutate their input in place.
    transforms = [copy.deepcopy]
    if drop_edge_p > 0.:
        transforms.append(DropEdges(drop_edge_p, force_undirected=force_undirected))
    if drop_feat_p > 0.:
        transforms.append(DropFeatures(drop_feat_p))
    return Compose(transforms)