forked from meteo-team/oemer
-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy path: train.py
64 lines (55 loc) · 2.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
import time
import os
import tensorflow as tf
from oemer import train
from oemer import classifier
def write_text_to_file(text: str, path: str) -> None:
    """Write *text* to the file at *path*, replacing any existing content.

    The file is opened with an explicit UTF-8 encoding so the bytes written
    do not depend on the platform's locale default encoding.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
if len(sys.argv) != 2:
    # Exactly one positional argument (the model name) is expected. Usage
    # errors go to stderr so they do not pollute redirected stdout.
    print("Usage: python train.py <model_name>", file=sys.stderr)
    sys.exit(1)
def get_model_base_name(model_name: str) -> str:
    """Return *model_name* suffixed with the current Unix time in whole seconds.

    For example ``"unet"`` becomes something like ``"unet_1700000000"``.
    """
    epoch_seconds = round(time.time())
    return "{}_{}".format(model_name, epoch_seconds)
# The single CLI argument names the model to train/export (validated against
# the known choices by the dispatch below).
model_type: str = sys.argv[1]
def prepare_classifier_data():
    """Collect classifier training data unless a ``train_data`` path already exists.

    Calls ``classifier.collect_data(2000)`` only when nothing exists at
    ``train_data`` (relative to the current working directory).
    """
    if os.path.exists("train_data"):
        # Data (or at least the path) is already present; nothing to do.
        return
    classifier.collect_data(2000)
# Segmentation models trained from scratch: CLI name -> dataset directory
# passed to train.train_model.
_SEGMENTATION_DATASETS = {
    "segnet": "ds2_dense",
    "unet": "CvcMuscima-Distortions",
}

# Segmentation models re-exported from an existing Keras checkpoint.
_CHECKPOINT_MODELS = ("unet_from_checkpoint", "segnet_from_checkpoint")

# Classifier models: CLI name -> name of the oemer.classifier function that
# trains it. Stored as names (not function objects) so only the selected
# trainer is looked up, just as the original if/elif chain did.
_CLASSIFIER_TRAINERS = {
    "rests_above8": "train_rests_above8",
    "rests": "train_rests",
    "all_rests": "train_all_rests",
    "sfn": "train_sfn",
    "clef": "train_clefs",
}


def _save_model(model, model_name):
    """Save *model*'s architecture and weights into a new timestamped directory.

    Creates ``get_model_base_name(model_name)`` and writes ``arch.json``
    (from ``model.to_json()``) and ``weights.h5`` (via ``model.save_weights``)
    inside it. ``os.makedirs`` raises ``FileExistsError`` if the directory
    already exists, so an earlier run is never silently overwritten.
    """
    directory = get_model_base_name(model_name)
    os.makedirs(directory)
    write_text_to_file(model.to_json(), os.path.join(directory, "arch.json"))
    model.save_weights(os.path.join(directory, "weights.h5"))


if model_type in _SEGMENTATION_DATASETS:
    model = train.train_model(
        _SEGMENTATION_DATASETS[model_type],
        data_model=model_type,
        steps=1500,
        epochs=15,
    )
    _save_model(model, model_type)
elif model_type in _CHECKPOINT_MODELS:
    # NOTE(review): both variants load the same "seg_unet" path; presumably
    # that is where train.train_model checkpoints every run - confirm in
    # oemer/train.py before relying on segnet_from_checkpoint.
    model = tf.keras.models.load_model(
        "seg_unet",
        custom_objects={"WarmUpLearningRate": train.WarmUpLearningRate},
    )
    # "unet_from_checkpoint" -> "unet", "segnet_from_checkpoint" -> "segnet".
    _save_model(model, model_type.split("_")[0])
elif model_type in _CLASSIFIER_TRAINERS:
    prepare_classifier_data()
    trainer = getattr(classifier, _CLASSIFIER_TRAINERS[model_type])
    trainer(get_model_base_name(model_type))
else:
    valid_models = ", ".join(
        [*_SEGMENTATION_DATASETS, *_CHECKPOINT_MODELS, *_CLASSIFIER_TRAINERS]
    )
    print("Unknown model: " + model_type, file=sys.stderr)
    print("Valid models: " + valid_models, file=sys.stderr)
    sys.exit(1)