optuna_search.py
import os
import json
import optuna
from datetime import datetime
from argparse import ArgumentParser
from functools import partial
from dotmap import DotMap
from train_vae_model import main as train_vae_model
from train_gan_model import main as train_gan_model

parser = ArgumentParser(description="Arguments for searching optimal hyperparameters")
parser.add_argument("--n-trials", "--n", type=int, default=100,
                    help="Number of times the model will be trained with different hyperparameters")
parser.add_argument("--model", type=str, required=True, choices=["vae", "gan"],
                    help="Which model should be tuned")
parser.add_argument("--dataset-path", type=str, required=True,
                    help="Path to the directory where the train, val and test sets are stored")
parser.add_argument("--gpu", type=int, default=1,
                    help="1 - use GPU, 0 - use CPU")
parser.add_argument("--epochs", type=int, default=100,
                    help="Number of training epochs for each run")


def sample_hyperparameters(trial, args):
    args_dot_map = DotMap()
    args_dot_map.lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    args_dot_map.optimizer = trial.suggest_categorical("optim", ["Adam", "AdamW", "SGD"])
    args_dot_map.weight_decay = trial.suggest_float("decay", 0, 0.3)
    args_dot_map.momentum = trial.suggest_float("momentum", 0.8, 0.99)
    args_dot_map.batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    if args.model == "vae":
        args_dot_map.embedding_size = trial.suggest_categorical("embed_size", [64, 128, 256, 512, 1024])
    elif args.model == "gan":
        args_dot_map.noise_dimension = trial.suggest_categorical("noise_dim", [50, 100, 200, 500, 1000])
    else:
        raise RuntimeError(f"Specified model: '{args.model}' is not supported")
    return args_dot_map


def update_dot_map_with_args(args_dot_map, trial, args):
    args_dot_map.dataset_path = args.dataset_path
    args_dot_map.save_path = f"optuna_search/{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
    args_dot_map.gpu = args.gpu
    args_dot_map.epochs = args.epochs
    args_dot_map.trial = trial
    args_dot_map.trial_number = trial.number
    if args.model == "vae":
        args_dot_map.calc_mifid = False
    return args_dot_map


def serializable_dict(args_dot_map):
    serializable_keys = ["lr", "embedding_size", "noise_dimension", "optimizer", "weight_decay",
                         "momentum", "batch_size", "dataset_path", "save_path", "gpu", "epochs",
                         "trial_number"]
    # Keep only JSON-serializable entries (the trial object is dropped) and only the keys
    # that were actually set for the selected model (embedding_size for the VAE,
    # noise_dimension for the GAN).
    all_params = args_dot_map.toDict()
    return {key: all_params[key] for key in serializable_keys if key in all_params}


def save_hyperparameters_logs(args_dot_map):
    file_path = os.path.join(args_dot_map.save_path, "hyperparameters")
    os.makedirs(args_dot_map.save_path, exist_ok=True)
    with open(file_path, "w") as outfile:
        json.dump(serializable_dict(args_dot_map), outfile)
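

# Illustrative note (not part of the original flow): the file written above is plain JSON,
# so a logged configuration can be inspected afterwards along the lines of
#   with open("optuna_search/<timestamp>/hyperparameters") as infile:
#       logged_config = json.load(infile)
# where "<timestamp>" stands for one of the run directories created by update_dot_map_with_args.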


def objective_fn(trial, args):
    args_dot_map = sample_hyperparameters(trial, args)
    args_dot_map = update_dot_map_with_args(args_dot_map, trial, args)
    save_hyperparameters_logs(args_dot_map)
    if args.model == "vae":
        loss = train_vae_model(args_dot_map)
    elif args.model == "gan":
        loss = train_gan_model(args_dot_map)
    else:
        raise RuntimeError(f"Specified model: '{args.model}' is not supported")
    return loss
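

# Illustrative sketch only: this script passes the Optuna trial through args_dot_map.trial,
# and main() below counts pruned trials, so train_vae_model / train_gan_model are assumed
# to report intermediate results roughly as follows (run_one_epoch is a hypothetical helper,
# not something defined in this repository):
#
#     def train_vae_model(args):
#         for epoch in range(args.epochs):
#             val_loss = run_one_epoch(args)
#             args.trial.report(val_loss, step=epoch)
#             if args.trial.should_prune():
#                 raise optuna.TrialPruned()
#         return val_loss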


def main(args):
    study = optuna.create_study(direction="minimize")
    objective = partial(objective_fn, args=args)
    study.optimize(objective, n_trials=args.n_trials)
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    print(f"Number of all trials: {len(study.trials)}")
    print(f"Number of finished trials: {len(complete_trials)}")
    print(f"Number of pruned trials: {len(pruned_trials)}")
    print(f"Best trial value: {study.best_trial.value}")
    for key, value in study.best_trial.params.items():
        print(f"{key}: {value}")


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
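

# Example invocations (the dataset path is a placeholder):
#   python optuna_search.py --model vae --dataset-path ./data --n-trials 50 --epochs 20
#   python optuna_search.py --model gan --dataset-path ./data --gpu 0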