Commit 1ca94f9

Training code for FCNs on ARIM.
1 parent 526f276 commit 1ca94f9

13 files changed: +534 −4 lines

README.md

+1 −4

@@ -60,7 +60,4 @@ BibTeX:
 r.catalin196@yahoo.ro, raducu.ionescu@gmail.com
 
 ### Last Update:
-August 5, 2020
-
-
-
+June 1, 2021

training/config.json

+29
@@ -0,0 +1,29 @@
{
  "exp_name": "test",
  "exp_path": "",
  "train_data_path": "./arim_train.npy",
  "test_data_path": "./arim_test.npy",

  "use_cuda": false,
  "no_fft_points": 2048,
  "no_points": 1024,
  "fs": 40e6,
  "nperseg": 102,
  "noverlap": 96,
  "window_type": "hamming",

  "normalize_labels": true,
  "stft": true,

  "lr": 0.00001,
  "weight_decay": 0.000001,
  "loss_function": "mse",

  "train_eval_split_ratio": 0.8,
  "train_epochs": 100,
  "batch_size": 2,
  "eval_net_epoch": 1,
  "save_net_epochs": 2,

  "print_loss": 10
}
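
The STFT settings above fix a 102-sample Hamming window hopped by 6 samples (nperseg − noverlap). An illustrative sanity check, not part of the commit, that derives these quantities from the config:

import json

# Load the training configuration shipped in this commit.
config = json.load(open('./config.json'))

hop = config['nperseg'] - config['noverlap']      # 102 - 96 = 6 samples between frame starts
window_s = config['nperseg'] / config['fs']       # ~2.55e-06 s analysis window at fs = 40 MHz
hop_s = hop / config['fs']                        # 1.5e-07 s between frames
print(hop, window_s, hop_s)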

training/data/data_manager.py

+56
@@ -0,0 +1,56 @@
from training.data.r_dataset import RadarDataset
import numpy as np
import torch


class DataManager:
    def __init__(self, config):
        self.config = config

    def get_dataloader(self, path):
        dataset = RadarDataset(path, self.config)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            pin_memory=self.config['use_cuda'],
            drop_last=True,
        )
        return dataloader

    def get_train_eval_dataloaders(self, path):
        np.random.seed(707)

        dataset = RadarDataset(path, self.config)
        dataset_size = len(dataset)

        # Split the dataset into train/validation index sets.
        train_split = self.config['train_eval_split_ratio']
        train_size = int(train_split * dataset_size)
        validation_size = dataset_size - train_size

        indices = list(range(dataset_size))
        np.random.shuffle(indices)
        train_indices = indices[:train_size]
        val_indices = indices[train_size:train_size + validation_size]

        # Build the data loaders with subset samplers over the shared dataset.
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)

        train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                   batch_size=self.config['batch_size'],
                                                   sampler=train_sampler,
                                                   pin_memory=self.config['use_cuda'],
                                                   drop_last=True)

        validation_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                        batch_size=self.config['batch_size'],
                                                        sampler=valid_sampler,
                                                        pin_memory=self.config['use_cuda'],
                                                        drop_last=True)
        return train_loader, validation_loader
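
A hedged usage sketch for DataManager, mirroring how training/main.py consumes these loaders (paths come from config.json above):

import json

from training.data.data_manager import DataManager

config = json.load(open('./config.json'))
manager = DataManager(config)

# 80/20 train/validation split of the training file; the shuffle is fixed by seed 707 above.
train_loader, validation_loader = manager.get_train_eval_dataloaders(config['train_data_path'])

# A single shuffled loader over the held-out test file.
test_loader = manager.get_dataloader(config['test_data_path'])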

training/data/r_dataset.py

+36
@@ -0,0 +1,36 @@
import torch.utils.data
from scipy import signal
import numpy as np
from training.utils.stft_local import stft


class RadarDataset(torch.utils.data.Dataset):
    def __init__(self, path, config):
        self.config = config
        allData = np.load(path, allow_pickle=True)

        # NOTE: update this if the ARIM database layout changes.
        self.sb_raw = allData[()]['sb']
        self.sb0_fft = np.fft.fft(allData[()]['sb0'], config['no_fft_points']) / config['no_points']
        self.labels = allData[()]['amplitudes']

        # Work with FFT magnitudes; the optional rescaling maps [0, 2.5] onto [-1, 1].
        self.sb0_fft = np.abs(self.sb0_fft)
        if self.config['normalize_labels']:
            self.sb0_fft = self.sb0_fft * (2 / 2.5) - 1

    def __getitem__(self, index):
        if self.config['stft']:
            # Custom STFT (training/utils/stft_local); FFT size and window are hard-coded here.
            spectrogram = stft(self.sb_raw[index], 2048, signal.get_window('hamming', 102), 1)
        else:
            spectrogram = signal.spectrogram(self.sb_raw[index], nfft=self.config['no_fft_points'], fs=self.config['fs'],
                                             nperseg=self.config['nperseg'], noverlap=self.config['noverlap'],
                                             window=self.config['window_type'], return_onesided=False, mode='complex')[2]

        return [np.expand_dims(np.abs(spectrogram), 0), self.sb0_fft[index], np.abs(self.labels[index])]

    def __len__(self):
        return len(self.sb0_fft)

    def __repr__(self):
        return self.__class__.__name__
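
RadarDataset expects the .npy file to wrap a dict with the keys 'sb', 'sb0' and 'amplitudes'. A minimal synthetic file with that layout, with illustrative shapes rather than the real ARIM dimensions, could be written like this:

import numpy as np

n_samples, n_points = 4, 1024                     # n_points mirrors "no_points" in config.json
dummy = {
    'sb': np.random.randn(n_samples, n_points),   # raw beat signals, turned into spectrograms
    'sb0': np.random.randn(n_samples, n_points),  # reference signals, FFT'd into the |sb0_fft| labels
    'amplitudes': np.random.rand(n_samples, 4),   # per-sample target amplitudes (count is illustrative)
}
np.save('./arim_dummy.npy', dummy)                # np.load(..., allow_pickle=True)[()] recovers the dict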

training/main.py

+29
@@ -0,0 +1,29 @@
import json
import torch
import torch.optim as optim

from training.networks.FConvNet import FConvNet
from training.networks.FConvBigNet import FConvBigNet
from training.trainer import Trainer
from training.data.data_manager import DataManager


def main():
    with open('./config.json') as f:
        config = json.load(f)
    data_manager = DataManager(config)

    model = FConvNet()
    if config['use_cuda']:
        model = model.cuda()
    model.apply(FConvNet.init_weights)

    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    train_loader, validation_loader = data_manager.get_train_eval_dataloaders(config['train_data_path'])

    trainer = Trainer(model, train_loader, validation_loader, criterion, optimizer, config)
    trainer.train()


if __name__ == "__main__":
    main()
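
trainer.py is part of this commit but not shown in this diff, so the exact training target is an assumption; a hedged evaluation sketch that scores a trained FConvNet against the |FFT(sb0)| labels, using a hypothetical checkpoint path, might look like this:

import json

import torch

from training.data.data_manager import DataManager
from training.networks.FConvNet import FConvNet

config = json.load(open('./config.json'))

model = FConvNet()
model.load_state_dict(torch.load('./fconvnet_checkpoint.pth', map_location='cpu'))  # hypothetical path
model.eval()

test_loader = DataManager(config).get_dataloader(config['test_data_path'])
criterion = torch.nn.MSELoss()

total_loss = 0.0
with torch.no_grad():
    for spectrogram, sb0_fft, _amplitudes in test_loader:
        prediction = model(spectrogram.float())              # assumed shape: (batch_size, 2048)
        total_loss += criterion(prediction, sb0_fft.float()).item()
print(total_loss / len(test_loader))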

training/networks/FConvBigNet.py

+87
@@ -0,0 +1,87 @@
import torch.nn as nn


class FConvBigNet(nn.Module):
    def __init__(self):
        super(FConvBigNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv3 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv4 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv5 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv7 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv8 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool4 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1))

        self.conv9 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool5 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1))

        self.conv11 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.pool6 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1))

        self.conv13 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv14 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')

        self.conv15 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1)

        self.relu = nn.ReLU()

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.00001)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.conv3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.pool2(x)

        x = self.conv5(x)
        x = self.relu(x)
        x = self.conv6(x)
        x = self.relu(x)
        x = self.pool3(x)

        x = self.conv7(x)
        x = self.relu(x)
        x = self.conv8(x)
        x = self.relu(x)
        x = self.pool4(x)

        x = self.conv9(x)
        x = self.relu(x)
        x = self.conv10(x)
        x = self.relu(x)
        x = self.pool5(x)

        x = self.conv11(x)
        x = self.relu(x)
        x = self.conv12(x)
        x = self.relu(x)
        x = self.pool6(x)

        x = self.conv13(x)
        x = self.relu(x)
        x = self.conv14(x)

        x = self.conv15(x)

        return x.view(-1, 2048)
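
FConvBigNet stacks seven double-conv blocks with circular padding along the first spectrogram axis, six of them followed by width-only max pooling, and a final 1×1 convolution that collapses the 128 channels into one output map reshaped to 2048 values. A quick, illustrative size check (not in the commit):

from training.networks.FConvBigNet import FConvBigNet

model = FConvBigNet()
model.apply(FConvBigNet.init_weights)              # Xavier-uniform weights, biases filled with 1e-5
num_params = sum(p.numel() for p in model.parameters())
print(num_params)                                  # total trainable parameters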

training/networks/FConvNet.py

+73
@@ -0,0 +1,73 @@
import torch.nn as nn


class FConvNet(nn.Module):
    def __init__(self):
        super(FConvNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')

        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv4 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')

        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), padding=(0, 1))

        self.conv7 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv8 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv9 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')

        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv10 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')
        self.conv11 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=(4, 0), padding_mode='circular')

        self.conv12 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1)

        self.relu = nn.ReLU()

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.00001)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)

        x = self.pool1(x)

        x = self.conv4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.relu(x)
        x = self.conv6(x)
        x = self.relu(x)

        x = self.pool2(x)

        x = self.conv7(x)
        x = self.relu(x)
        x = self.conv8(x)
        x = self.relu(x)
        x = self.conv9(x)
        x = self.relu(x)

        x = self.pool3(x)

        x = self.conv10(x)
        x = self.relu(x)
        x = self.conv11(x)
        x = self.relu(x)
        x = self.conv12(x)

        return x.view(-1, 2048)
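
FConvNet is the smaller variant instantiated in training/main.py: 12 convolutions instead of 15 (three pooled triple-conv blocks, a double-conv block, and a 1×1 output layer), topping out at 64 channels instead of 128. An illustrative comparison of the two models' sizes:

from training.networks.FConvNet import FConvNet
from training.networks.FConvBigNet import FConvBigNet

small = sum(p.numel() for p in FConvNet().parameters())
big = sum(p.numel() for p in FConvBigNet().parameters())
print(small, big)                                  # the Big variant is deeper and wider, hence larger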

training/requirements.txt

+14
@@ -0,0 +1,14 @@
h5py==2.10.0
hdf5storage==0.1.15
jsonpatch==1.24
jsonpointer==2.0
matplotlib==3.1.1
numpy==1.17.3
Pillow==6.2.1
pyparsing==2.4.5
scipy==1.3.2
torch==1.3.1
torchfile==0.1.0
torchvision==0.4.2
visdom==0.1.8.9
websocket-client==0.56.0
