-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnetstacks.py
57 lines (44 loc) · 1.38 KB
/
netstacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
File: netstacks.py
By: Peter Caven, peter@sparseinference.com
Description:
Blocks and stacks of blocks for neural nets.
Blocks (neural net modules), and compositions of blocks (neural stacks),
are a higher level abstraction than layers.
The blocks defined here are variations on Residual Nets,
where the input layer and the output layer have the same dimension,
while the hidden layer(s) could be thinner or wider.
"""
import torch
import torch.nn as nn
class Block(nn.Module):
    """
    A single residual block: ``out = relu(W1(relu(W0(x))) + x)``.

    The input and output share the same width ``iDim`` so the skip
    connection can be added directly; the hidden layer may be thinner
    or wider (``hDim``).
    """
    def __init__(self, iDim, hDim):
        super().__init__()
        #----
        self.W0 = nn.Linear(iDim, hDim)   # iDim -> hDim (hidden layer)
        self.W1 = nn.Linear(hDim, iDim)   # hDim -> iDim (back to input width)
        #----
        # Total trainable parameters (weights + biases of both layers).
        self.parameterCount = sum(p.numel() for p in self.parameters())
        #----
    def forward(self, x):
        hidden = torch.relu(self.W0(x))           # same as .clamp(min=0)
        return torch.relu(self.W1(hidden) + x)    # residual add, then ReLU
class Stack(nn.Module):
    """
    A sequential stack of identically-configured blocks.

    Parameters
    ----------
    block : callable
        Block class or factory, invoked as ``block(iDim, hDim, *args, **kwargs)``.
        Each instance must expose a ``parameterCount`` attribute and map
        tensors of width ``iDim`` back to width ``iDim`` so blocks compose.
    stackDepth : int
        Number of blocks in the stack (0 yields the identity map).
    iDim : int
        Input/output width forwarded to every block.
    hDim : int
        Hidden width forwarded to every block.
    *args, **kwargs
        Extra arguments forwarded verbatim to each block's constructor.
    """
    def __init__(self, block, stackDepth, iDim, hDim, *args, **kwargs):
        super().__init__()
        #----
        self.stack = nn.ModuleList(
            [block(iDim, hDim, *args, **kwargs) for _ in range(stackDepth)]
        )
        #----
        # FIX: loop variables renamed from `nn`, which shadowed the
        # file-level `import torch.nn as nn` inside this class's bodies.
        self.parameterCount = sum(b.parameterCount for b in self.stack)
        #----
    def forward(self, x):
        """Apply each block in order; returns ``x`` unchanged for an empty stack."""
        for module in self.stack:
            x = module(x)
        return x