-
Notifications
You must be signed in to change notification settings - Fork 28
/
DEMO_GEDI_regression_crossval_ensemble.sh
executable file
·71 lines (57 loc) · 2.47 KB
/
DEMO_GEDI_regression_crossval_ensemble.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/bin/bash
# Demo: train and test ONE member of a cross-validated model ensemble.
#
# Usage: DEMO_GEDI_regression_crossval_ensemble.sh [job_index]
#
# job_index (default: 0) selects both the ensemble member and the test fold:
#   model_idx     = job_index % n_models
#   test_fold_idx = job_index / n_models   (integer division)
# e.g. job_index=0 --> model_idx=0, test_fold_idx=0.
# Job indices 0 .. n_models*n_folds-1 therefore cover every (fold, model) pair.
# For parallel job arrays, pass your system's job variable as the argument.
#
# -u: abort on an unset (e.g. misspelled) variable instead of silently passing
#     an empty option value to the training script.
# -e is intentionally not set: the training call is the last command, so its
#     exit status already becomes the exit status of this script.
set -u

# job index; used to set model_idx and test_fold_idx below.
index=${1:-0}

inputs_path='demo_data/GEDI_BDL_demo/GEDI_BDL_demo_subset_neon.npy'
target_key='als_rh098'
min_gt=0
max_gt=100
input_key='rxwaveform'
sample_length=1420
noise_mean_key='noise_mean_corrected'
model_name='SimpleResNet_8blocks'
loss_key='gaussian_nll'
n_models=10
n_folds=10
normalize_targets=true
# the batch size was reduced to obtain a stable optimization with the small
# demo dataset
batch_size=16
nb_epoch=200
base_lr=0.0001

# data augmentation
shift_left=0.2
shift_right=0.2

# quality flags to filter different expected noise levels
# 0: power-night
# 1: power-night + power-day
# 2: power-night + power-day + coverage-night
# 3: all
setting_idx=3

# filtering for complete crossover data including waveform matching
# information, otherwise all data is used
# NOTE(review): use_quality_flag is defined here but is not forwarded to
# train.py by this script -- confirm whether it should be passed as an option.
use_quality_flag=true
pearson_thresh=0.95

# Validate the job index before doing arithmetic with it.
if ! [[ ${index} =~ ^[0-9]+$ ]]; then
  printf 'error: job index must be a non-negative integer, got: %s\n' \
    "${index}" >&2
  exit 2
fi
# Force base 10 so that an index such as "08" is not parsed as invalid octal.
index=$(( 10#${index} ))
if (( index >= n_models * n_folds )); then
  printf 'error: job index %d out of range (valid: 0..%d)\n' \
    "${index}" "$(( n_models * n_folds - 1 ))" >&2
  exit 2
fi

# select the model index for the model ensemble
model_idx=$(( index % n_models ))
# select the test fold index
test_fold_idx=$(( index / n_models ))

out_dir="output_demo/testfold_${test_fold_idx}/model_${model_idx}"

printf 'job index: %s\n' "${index}"
printf 'model_idx: %s\n' "${model_idx}"
printf 'test_fold_idx: %s\n' "${test_fold_idx}"
printf 'output directory: %s\n' "${out_dir}"

# Options for the training script, one array element per option. Using an
# array (instead of backslash line continuations) guarantees that every option
# is passed as its own word; a continuation without a preceding space would
# glue two options together.
train_args=(
  --out_dir="${out_dir}"
  --n_folds="${n_folds}"
  --test_fold_idx="${test_fold_idx}"
  --min_gt="${min_gt}"
  --max_gt="${max_gt}"
  --batch_size="${batch_size}"
  --nb_epoch="${nb_epoch}"
  --base_learning_rate="${base_lr}"
  --loss_key="${loss_key}"
  --sample_length="${sample_length}"
  --inputs_path="${inputs_path}"
  --input_key="${input_key}"
  --target_key="${target_key}"
  --shift_left="${shift_left}"
  --shift_right="${shift_right}"
  --model_name="${model_name}"
  --setting_idx="${setting_idx}"
  --normalize_targets="${normalize_targets}"
  --pearson_thresh="${pearson_thresh}"
  --noise_mean_key="${noise_mean_key}"
)

# train and test
python3 torch_code/train.py "${train_args[@]}"