-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmodels2d.py
52 lines (45 loc) · 1.59 KB
/
models2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
from sklearn import model_selection
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
class VAE(nn.Module):
    """Variational autoencoder with a Conv1d encoder, 2-D latent, GRU decoder.

    The encoder stacks three unpadded ``Conv1d`` layers (kernel sizes 9, 9
    and 11) followed by a dense layer, and produces two 2-dimensional
    vectors that ``forward`` treats as the latent mean and log-variance.
    The decoder repeats the latent vector 120 times, runs it through a
    3-layer GRU and maps every step to a 35-way probability distribution.

    Shape constraints implied by the layer sizes:
        * input ``x``: ``(batch, 120, 120)`` -- 120 input channels are
          required by ``conv1d1``; the three convolutions shorten the last
          axis by 8 + 8 + 10 = 26, and ``fc0`` expects 10 * 94 = 940
          flattened features, so the last axis must be 120 long.
        * output of ``decode``: ``(batch, 120, 35)``.
    """

    def __init__(self):
        super().__init__()
        # Encoder: convolutional feature extractor + dense bottleneck.
        self.conv1d1 = nn.Conv1d(120, 9, kernel_size=9)
        self.conv1d2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv1d3 = nn.Conv1d(9, 10, kernel_size=11)
        self.fc0 = nn.Linear(940, 435)
        # Two parallel heads: latent mean and latent log-variance.
        self.fc11 = nn.Linear(435, 2)
        self.fc12 = nn.Linear(435, 2)
        # Decoder: dense layer, 3-layer GRU, per-step output projection.
        self.fc2 = nn.Linear(2, 2)
        self.gru = nn.GRU(2, 501, 3, batch_first=True)
        self.fc3 = nn.Linear(501, 35)

    def encode(self, x: torch.Tensor):
        """Map an input batch to the two latent heads.

        Args:
            x: Tensor of shape ``(batch, 120, 120)`` (see class docstring).

        Returns:
            Tuple ``(fc11(h), fc12(h))``, each of shape ``(batch, 2)``;
            ``forward`` uses them as mean and log-variance respectively.
        """
        h = F.relu(self.conv1d1(x))
        h = F.relu(self.conv1d2(h))
        h = F.relu(self.conv1d3(h))
        # Flatten (batch, channels, length) to (batch, channels * length).
        h = h.view(h.size(0), -1)
        h = F.selu(self.fc0(h))
        return self.fc11(h), self.fc12(h)

    def reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample a latent vector with the reparameterisation trick.

        In training mode returns ``mu + eps * exp(0.5 * logvar)`` with
        ``eps`` drawn from a standard normal; in eval mode returns ``mu``
        unchanged so that inference is deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into per-step probability distributions.

        Args:
            z: Latent batch of shape ``(batch, 2)``.

        Returns:
            Tensor of shape ``(batch, 120, 35)`` whose last axis sums to 1.
        """
        z = F.selu(self.fc2(z))
        # Feed the same latent vector to the GRU at each of the 120 steps.
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 120, 1)
        out, _ = self.gru(z)
        # Collapse batch and time so the projection sees a 2-D matrix.
        flat = out.contiguous().view(-1, out.size(-1))
        # An explicit ``dim`` replaces the deprecated implicit-dimension
        # softmax; for a 2-D input the implicit choice was dim=1 as well.
        probs = F.softmax(self.fc3(flat), dim=1)
        return probs.contiguous().view(out.size(0), -1, probs.size(-1))

    def forward(self, x: torch.Tensor):
        """Run the full autoencoder.

        Returns:
            Tuple ``(reconstruction, mu, logvar)`` where ``reconstruction``
            is the output of ``decode`` on the sampled latent vector.
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar