-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeometric_linear.py
114 lines (100 loc) · 4.65 KB
/
geometric_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Optional
import copy
import math
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch_geometric.nn import inits
class Linear(torch.nn.Module):
    r"""Applies a linear transformation to the incoming data

    .. math::
        \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}

    similar to :class:`torch.nn.Linear`.
    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int): Size of each input sample.
            Will be initialized lazily in case :obj:`-1` (any non-positive
            value is treated as lazy): the weight is then materialized from
            the last dimension of the first input seen by :meth:`forward`.
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        weight_initializer (str, optional): The initializer for the weight
            matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
            or :obj:`None`).
            If set to :obj:`None`, will match default weight initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)
        bias_initializer (str, optional): The initializer for the bias
            vector (:obj:`"zeros"` or :obj:`None`).
            If set to :obj:`None`, will match default bias initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)

    Raises:
        RuntimeError: If an unsupported initializer name is given. For a
            lazily initialized layer the check only happens once the weight
            is materialized, i.e. on the first forward pass.
    """
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
                 weight_initializer: Optional[str] = None,
                 bias_initializer: Optional[str] = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.empty(out_channels, in_channels))
        else:
            # Input size unknown: keep a placeholder weight and materialize it
            # from the first input through a forward pre-hook.
            self.weight = torch.nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def __deepcopy__(self, memo):
        """Return a deep copy that also works for a still-lazy weight.

        A fresh layer is built with the same configuration (which registers
        its own lazy hook if needed); materialized parameters are then
        replaced by deep copies of the current ones.
        """
        out = Linear(self.in_channels, self.out_channels, self.bias
                     is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        # The constructor always yields a module in training mode; carry over
        # the actual mode so that copying an `eval()`'d layer keeps it in eval.
        out.train(self.training)
        return out

    def reset_parameters(self):
        """(Re-)initialize weight and bias according to the configured
        initializers.

        Does nothing while the input size is still unknown
        (:obj:`in_channels <= 0`); in that case it is invoked again from
        :meth:`initialize_parameters` once the weight is materialized.

        Raises:
            RuntimeError: If the weight or bias initializer is not supported.
        """
        if self.in_channels > 0:
            if self.weight_initializer == 'glorot':
                inits.glorot(self.weight)
            elif self.weight_initializer == 'uniform':
                bound = 1.0 / math.sqrt(self.weight.size(-1))
                torch.nn.init.uniform_(self.weight.data, -bound, bound)
            elif self.weight_initializer in ('kaiming_uniform', None):
                # `None` shares the explicit 'kaiming_uniform' scheme.
                inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                      a=math.sqrt(5))
            else:
                raise RuntimeError(
                    f"Linear layer weight initializer "
                    f"'{self.weight_initializer}' is not supported")

        if self.in_channels > 0 and self.bias is not None:
            if self.bias_initializer == 'zeros':
                inits.zeros(self.bias)
            elif self.bias_initializer is None:
                inits.uniform(self.in_channels, self.bias)
            else:
                raise RuntimeError(
                    f"Linear layer bias initializer "
                    f"'{self.bias_initializer}' is not supported")

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map to :obj:`x`.

        Args:
            x (Tensor): The input features; its last dimension is multiplied
                against the weight matrix.

        Returns:
            Tensor: The result of :obj:`F.linear(x, weight, bias)`.
        """
        return F.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        """Forward pre-hook that materializes a lazy weight.

        Reads the input size from the last dimension of the first positional
        input, materializes and initializes the weight if it is still
        uninitialized, and then unregisters itself so it runs only once.

        Args:
            module: The module the hook fires on (holds the hook handle).
            input: The tuple of positional arguments passed to the forward
                call.
        """
        if isinstance(self.weight, torch.nn.parameter.UninitializedParameter):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        module._hook.remove()
        delattr(module, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')