-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsei_body.py
109 lines (89 loc) · 3.54 KB
/
sei_body.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import numpy as np
import torch.nn as nn
from sei import *
class SeiBody(nn.Module):
    """Convolutional trunk of the Sei architecture.

    Three conv stages with residual (additive skip) connections and
    progressively wider channels (480 -> 640 -> 960), each followed by
    4x max-pooling between stages, then five residual dilated-conv
    stages with growing dilation (2, 4, 8, 16, 25), and finally a
    dropout + B-spline transformation that produces the flattened
    feature vector consumed by downstream heads.
    """

    def __init__(self, sequence_length=4096, model_name=''):
        """
        Parameters
        ----------
        sequence_length : int
            Expected input sequence length. NOTE(review): the layer
            sizes below are hard-coded and do not depend on this
            argument; it is kept for interface compatibility.
        model_name : str
            Unused; retained for backward compatibility with callers.
        """
        super().__init__()
        # Stage 1: two linear (no activation) convs; `conv1` adds the
        # ReLU branch that is summed with `lconv1`'s output in forward().
        self.lconv1 = nn.Sequential(
            nn.Conv1d(4, 480, kernel_size=9, padding=4),
            nn.Conv1d(480, 480, kernel_size=9, padding=4))

        self.conv1 = nn.Sequential(
            nn.Conv1d(480, 480, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv1d(480, 480, kernel_size=9, padding=4),
            nn.ReLU(inplace=True))

        # Stage 2: downsample 4x, widen to 640 channels.
        self.lconv2 = nn.Sequential(
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(p=0.2),
            nn.Conv1d(480, 640, kernel_size=9, padding=4),
            nn.Conv1d(640, 640, kernel_size=9, padding=4))

        self.conv2 = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Conv1d(640, 640, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv1d(640, 640, kernel_size=9, padding=4),
            nn.ReLU(inplace=True))

        # Stage 3: downsample 4x again, widen to 960 channels.
        self.lconv3 = nn.Sequential(
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(p=0.2),
            nn.Conv1d(640, 960, kernel_size=9, padding=4),
            nn.Conv1d(960, 960, kernel_size=9, padding=4))

        self.conv3 = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Conv1d(960, 960, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv1d(960, 960, kernel_size=9, padding=4),
            nn.ReLU(inplace=True))

        # Dilated stages: each pads by (kernel_size - 1) * dilation / 2
        # so output length matches input length, enabling residual sums.
        self.dconv1 = nn.Sequential(
            nn.Dropout(p=0.10),
            nn.Conv1d(960, 960, kernel_size=5, dilation=2, padding=4),
            nn.ReLU(inplace=True))
        self.dconv2 = nn.Sequential(
            nn.Dropout(p=0.10),
            nn.Conv1d(960, 960, kernel_size=5, dilation=4, padding=8),
            nn.ReLU(inplace=True))
        self.dconv3 = nn.Sequential(
            nn.Dropout(p=0.10),
            nn.Conv1d(960, 960, kernel_size=5, dilation=8, padding=16),
            nn.ReLU(inplace=True))
        self.dconv4 = nn.Sequential(
            nn.Dropout(p=0.10),
            nn.Conv1d(960, 960, kernel_size=5, dilation=16, padding=32),
            nn.ReLU(inplace=True))
        self.dconv5 = nn.Sequential(
            nn.Dropout(p=0.10),
            nn.Conv1d(960, 960, kernel_size=5, dilation=25, padding=50),
            nn.ReLU(inplace=True))

        # Number of B-spline degrees of freedom used to summarize the
        # spatial axis (was int(128/8); use integer division).
        self._spline_df = 128 // 8
        self.spline_tr = nn.Sequential(
            nn.Dropout(p=0.5),
            # BSplineTransformation is provided by the project's `sei` module.
            BSplineTransformation(self._spline_df, scaled=False))
        # Flattened feature size exposed to downstream heads.
        self.seq_input = 960 * self._spline_df

    def forward(self, x):
        """Forward propagation of a batch.

        Parameters
        ----------
        x : torch.Tensor
            One-hot encoded sequences, shape (batch, 4, sequence_length).

        Returns
        -------
        torch.Tensor
            Flattened features of shape (batch, 960 * spline_df).
        """
        lout1 = self.lconv1(x)
        out1 = self.conv1(lout1)

        # Residual connections: each stage consumes the sum of the
        # previous stage's linear and nonlinear branches.
        lout2 = self.lconv2(out1 + lout1)
        out2 = self.conv2(lout2)

        lout3 = self.lconv3(out2 + lout2)
        out3 = self.conv3(lout3)

        # Dilated stages accumulate into a running residual sum.
        dconv_out1 = self.dconv1(out3 + lout3)
        cat_out1 = out3 + dconv_out1
        dconv_out2 = self.dconv2(cat_out1)
        cat_out2 = cat_out1 + dconv_out2
        dconv_out3 = self.dconv3(cat_out2)
        cat_out3 = cat_out2 + dconv_out3
        dconv_out4 = self.dconv4(cat_out3)
        cat_out4 = cat_out3 + dconv_out4
        dconv_out5 = self.dconv5(cat_out4)
        out = cat_out4 + dconv_out5

        spline_out = self.spline_tr(out)
        # Flatten to (batch, seq_input); reuse the size computed in __init__.
        reshape_out = spline_out.view(spline_out.size(0), self.seq_input)
        return reshape_out