-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvalidate_rnn.py
93 lines (74 loc) · 3.14 KB
/
validate_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Validate a trained video-classification model (conv_3d, c3d, lrcn, lstm or
mlp). Loads a saved checkpoint and evaluates it with a test-split frame
generator over roughly as many samples as there are videos in the test set.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger
from models import ResearchModels
from keras.models import load_model
from data import DataSet
import argparse
def validate(data_type, model, seq_length=40, saved_model=None,
             class_limit=None, image_shape=None, batch_size=32,
             test_data_num=1084):
    """Evaluate a saved Keras model on the test split and print the scores.

    Args:
        data_type: Kind of input the generator yields; forwarded to
            ``DataSet.frame_generator`` (``main`` passes ``'images'`` or
            ``'features'``).
        model: Architecture name (e.g. ``'lstm'``). Kept for interface
            compatibility; the network itself is loaded from ``saved_model``.
        seq_length: Frames per sample, forwarded to ``DataSet``.
        saved_model: Path of the checkpoint loaded with ``load_model``.
            Required despite the ``None`` default.
        class_limit: Optional cap on the number of classes, forwarded to
            ``DataSet``.
        image_shape: Optional frame shape forwarded to ``DataSet``; when
            ``None`` it is not passed, so ``DataSet`` uses its own default.
        batch_size: Samples per generator batch.
        test_data_num: Approximate number of test samples to evaluate; the
            generator is stepped ``test_data_num // batch_size`` times, and
            at least once.

    Returns:
        The value returned by ``evaluate_generator``. Per the Keras docs this
        is the loss, or a list of loss and metric values ordered like
        ``metrics_names``.

    Raises:
        ValueError: If ``saved_model`` is missing, or ``batch_size`` or
            ``test_data_num`` is not positive.
    """
    # Fail fast with a clear message instead of an obscure load_model error.
    if saved_model is None:
        raise ValueError("saved_model must be the path to a saved Keras model.")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if test_data_num <= 0:
        raise ValueError("test_data_num must be a positive integer.")

    # Forward image_shape only when given so DataSet keeps its own default.
    dataset_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        dataset_kwargs['image_shape'] = image_shape
    data = DataSet(**dataset_kwargs)

    test_generator = data.frame_generator(batch_size, 'test', data_type)

    # Use a distinct name so the architecture string `model` is not shadowed.
    keras_model = load_model(saved_model)

    # Floor division keeps the original sample budget. max(1, ...) avoids
    # asking Keras for zero steps when test_data_num < batch_size.
    steps = max(1, test_data_num // batch_size)
    results = keras_model.evaluate_generator(generator=test_generator,
                                             steps=steps)
    print(results)
    print(keras_model.metrics)
    print(keras_model.metrics_names)
    return results
def main():
    """Parse the command line and validate the selected model's checkpoint.

    ``-m/--model`` selects the architecture. Its data type, frame shape and
    default checkpoint come from the table below. ``--saved-model``
    optionally overrides the checkpoint path.

    Raises:
        ValueError: If the model name is missing or is not one of the
            supported architectures.
    """
    checkpoint_dir = '/data/d14122793/UCF101_Video_Classi/data/checkpoints/'
    # model name -> (data_type, image_shape, default checkpoint file name)
    configs = {
        'conv_3d': ('images', (80, 80, 3), 'conv_3d-images.019-1.926.hdf5'),
        'c3d': ('images', (80, 80, 3), 'c3d-images.012-2.149.hdf5'),
        'lrcn': ('images', (80, 80, 3), 'lrcn-images.030-2.581.hdf5'),
        'lstm': ('features', None, 'lstm-features.017-0.525.hdf5'),
        'mlp': ('features', None, 'mlp-features.006-0.513.hdf5'),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model",
                        help="Select a model to validate (conv_3d, c3d, lrcn, lstm, mlp)")
    parser.add_argument("--saved-model", default=None,
                        help="Checkpoint path to load (defaults to the "
                             "model's stored checkpoint)")
    args = parser.parse_args()

    model = args.model
    # A missing -m leaves model as None, which is rejected here as well.
    if model not in configs:
        raise ValueError("Invalid model. Please choose one of them: conv_3d, c3d, lrcn, lstm, mlp.")

    data_type, image_shape, checkpoint = configs[model]
    saved_model = args.saved_model or checkpoint_dir + checkpoint
    validate(data_type, model, saved_model=saved_model,
             image_shape=image_shape, class_limit=30)
if __name__ == '__main__':
    # Entry point: run the CLI only when executed directly, not on import.
    main()