|
| 1 | +import torch.nn as nn |
| 2 | + |
| 3 | + |
| 4 | +class FConvBigNet(nn.Module): |
| 5 | + def __init__(self): |
| 6 | + super(FConvBigNet, self).__init__() |
| 7 | + self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 8 | + self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 9 | + self.pool1 = nn.MaxPool2d(kernel_size=(1, 2)) |
| 10 | + |
| 11 | + self.conv3 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 12 | + self.conv4 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 13 | + self.pool2 = nn.MaxPool2d(kernel_size=(1, 2)) |
| 14 | + |
| 15 | + self.conv5 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 16 | + self.conv6 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 17 | + self.pool3 = nn.MaxPool2d(kernel_size=(1, 2)) |
| 18 | + |
| 19 | + self.conv7 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 20 | + self.conv8 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 21 | + self.pool4 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1)) |
| 22 | + |
| 23 | + self.conv9 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 24 | + self.conv10 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 25 | + self.pool5 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1)) |
| 26 | + |
| 27 | + self.conv11 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 28 | + self.conv12 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 29 | + self.pool6 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1)) |
| 30 | + |
| 31 | + self.conv13 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 32 | + self.conv14 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular') |
| 33 | + |
| 34 | + self.conv15 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1) |
| 35 | + |
| 36 | + self.relu = nn.ReLU() |
| 37 | + |
| 38 | + @staticmethod |
| 39 | + def init_weights(m): |
| 40 | + if type(m) == nn.Conv2d: |
| 41 | + nn.init.xavier_uniform_(m.weight) |
| 42 | + m.bias.data.fill_(0.00001) |
| 43 | + |
| 44 | + def forward(self, x): |
| 45 | + x = self.conv1(x) |
| 46 | + x = self.relu(x) |
| 47 | + x = self.conv2(x) |
| 48 | + x = self.relu(x) |
| 49 | + x = self.pool1(x) |
| 50 | + |
| 51 | + x = self.conv3(x) |
| 52 | + x = self.relu(x) |
| 53 | + x = self.conv4(x) |
| 54 | + x = self.relu(x) |
| 55 | + x = self.pool2(x) |
| 56 | + |
| 57 | + x = self.conv5(x) |
| 58 | + x = self.relu(x) |
| 59 | + x = self.conv6(x) |
| 60 | + x = self.relu(x) |
| 61 | + x = self.pool3(x) |
| 62 | + |
| 63 | + x = self.conv7(x) |
| 64 | + x = self.relu(x) |
| 65 | + x = self.conv8(x) |
| 66 | + x = self.relu(x) |
| 67 | + x = self.pool4(x) |
| 68 | + |
| 69 | + x = self.conv9(x) |
| 70 | + x = self.relu(x) |
| 71 | + x = self.conv10(x) |
| 72 | + x = self.relu(x) |
| 73 | + x = self.pool5(x) |
| 74 | + |
| 75 | + x = self.conv11(x) |
| 76 | + x = self.relu(x) |
| 77 | + x = self.conv12(x) |
| 78 | + x = self.relu(x) |
| 79 | + x = self.pool6(x) |
| 80 | + |
| 81 | + x = self.conv13(x) |
| 82 | + x = self.relu(x) |
| 83 | + x = self.conv14(x) |
| 84 | + |
| 85 | + x = self.conv15(x) |
| 86 | + |
| 87 | + return x.view(-1, 2048) |
0 commit comments