# gmn_layer.py
from typing import Any, Callable, Dict, List, Optional, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch_geometric.nn.conv import PNAConv
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.resolver import activation_resolver


class TransposeModule(nn.Module):
    """Swaps the :obj:`dim0` and :obj:`dim1` dimensions of its input."""
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(self.dim0, self.dim1)


class ResidualModule(nn.Module):
    """Wraps :obj:`module` with a skip connection: ``x + module(x)``."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: Tensor) -> Tensor:
        return x + self.module(x)
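

# A quick doctest-style sketch of the transpose helper (illustrative only,
# not part of the layer's logic): on 2-D node features of shape
# ``(num_nodes, channels)``, ``TransposeModule(1, -1)`` is a no-op, since
# dim ``1`` and dim ``-1`` coincide; the token-mixing transpose only takes
# effect on batched 3-D inputs.
#
#     >>> import torch
#     >>> t = TransposeModule(1, -1)
#     >>> t(torch.randn(4, 8)).shape       # 2-D: unchanged
#     torch.Size([4, 8])
#     >>> t(torch.randn(4, 8, 16)).shape   # 3-D: dims 1 and -1 swapped
#     torch.Size([4, 16, 8])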


class GMNConv(PNAConv):
    r"""The Graph Mixer convolution operator
    from the `"The Graph Mixer Networks"
    <https://arxiv.org/abs/2301.12493>`_ paper.

    .. math::
        \mathbf{x}_i^{\prime} = \mathrm{Mix} \left(
        \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus}
        \left( \mathbf{x}_i, \mathbf{x}_j \right)
        \right)

    with

    .. math::
        \mathrm{Mix}(\mathbf{x}) = \mathrm{MLP}_{2}(\mathrm{LayerNorm}
        ((\mathrm{MLP}_{1}((\mathrm{LayerNorm}(\mathbf{x}))^{T}))^{T}
        + \mathbf{x})) + \mathbf{x}

    and

    .. math::
        \bigoplus = \underbrace{\begin{bmatrix}
            1 \\
            S(\mathbf{D}, \alpha=1) \\
            S(\mathbf{D}, \alpha=-1)
        \end{bmatrix} }_{\text{scalers}}
        \otimes \underbrace{\begin{bmatrix}
            \mu \\
            \sigma \\
            \max \\
            \min
        \end{bmatrix}}_{\text{aggregators}},
    where :math:`\mathrm{MLP}_{1}` and :math:`\mathrm{MLP}_{2}` denote MLPs
    and :math:`S(\mathbf{D}, \alpha)` denotes the degree scalers.

    .. note::

        For an example of using :obj:`GMNConv`, see `examples/gmn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/gmn.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str]): Set of aggregation function identifiers,
            namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"` and :obj:`"std"`.
        scalers (List[str]): Set of scaling function identifiers, namely
            :obj:`"identity"`, :obj:`"amplification"`,
            :obj:`"attenuation"`, :obj:`"linear"` and
            :obj:`"inverse_linear"`.
        deg (torch.Tensor): Histogram of in-degrees of nodes in the training
            set, used by scalers to normalize.
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). (default: :obj:`None`)
        towers (int, optional): Number of towers. (default: :obj:`1`)
        post_layers (int, optional): Number of transformation layers after
            aggregation. (default: :obj:`1`)
        divide_input (bool, optional): Whether the input features should
            be split between towers or not. (default: :obj:`False`)
        act (str or callable, optional): Pre- and post-layer activation
            function to use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        train_norm (bool, optional): Whether normalization parameters
            are trainable. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    def __init__(self, in_channels: int, out_channels: int,
                 aggregators: List[str], scalers: List[str], deg: Tensor,
                 edge_dim: Optional[int] = None, towers: int = 1,
                 post_layers: int = 1, divide_input: bool = False,
                 act: Union[str, Callable, None] = "relu",
                 act_kwargs: Optional[Dict[str, Any]] = None, **kwargs):
        # Forward arguments by keyword: `PNAConv` also takes `pre_layers`
        # and `post_layers` positionally, so passing `divide_input` as a
        # positional argument would misalign it.
        super().__init__(in_channels, out_channels, aggregators, scalers,
                         deg, edge_dim=edge_dim, towers=towers,
                         divide_input=divide_input, act=act,
                         act_kwargs=act_kwargs, **kwargs)

        # Replace PNA's post-aggregation MLPs with Mixer-style blocks:
        #   x' = x + MLP_1((LayerNorm(x))^T)^T   (token mixing)
        #   out = x' + MLP_2(LayerNorm(x'))      (channel mixing)
        self.post_nns = ModuleList()
        for _ in range(towers):
            in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in
            modules = [Linear(in_channels, self.F_out)]
            for _ in range(post_layers - 1):
                token_mixing = Sequential(
                    nn.LayerNorm(self.F_out),
                    TransposeModule(1, -1),
                    activation_resolver(act, **(act_kwargs or {})),
                    Linear(self.F_out, self.F_out),
                    TransposeModule(1, -1),
                )
                channel_mixing = Sequential(
                    nn.LayerNorm(self.F_out),
                    activation_resolver(act, **(act_kwargs or {})),
                    Linear(self.F_out, self.F_out),
                )
                modules += [ResidualModule(token_mixing)]
                modules += [ResidualModule(channel_mixing)]
            self.post_nns.append(Sequential(*modules))

        self.lin = Linear(out_channels, out_channels)
        self.reset_parameters()
    def message(self, x_j: Tensor) -> Tensor:
        # Unlike PNAConv, neighbor features are aggregated directly, without
        # pre-aggregation MLPs; all transformation happens in `post_nns`.
        return x_j
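

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch for illustration, assuming a random graph and
# arbitrary channel sizes (none of these numbers come from the paper). It
# computes the in-degree histogram that the degree scalers expect and runs a
# single forward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch_geometric.utils import degree

    num_nodes, num_edges = 32, 128
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    x = torch.randn(num_nodes, 16)

    # Histogram of in-degrees over the training set (here: one random graph).
    d = degree(edge_index[1], num_nodes=num_nodes, dtype=torch.long)
    deg = torch.bincount(d)

    conv = GMNConv(
        in_channels=16, out_channels=32,
        aggregators=['mean', 'min', 'max', 'std'],
        scalers=['identity', 'amplification', 'attenuation'],
        deg=deg, towers=2, post_layers=2,
    )
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([32, 32])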