-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
66 lines (54 loc) · 2.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import utils
import os
# Quiet TensorFlow's INFO and WARNING messages before any framework module is
# imported (see http://stackoverflow.com/questions/35911252).
# setdefault only assigns when the variable is absent, so a value the user
# already exported in the environment is respected.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmarks largely inspired by https://github.com/aizvorski/vgg-benchmarks')
parser.add_argument('--run_keras', action="store_true", help="Run keras benchmark")
parser.add_argument('--run_tensorflow', action="store_true", help="Run pure tensorflow benchmark")
parser.add_argument('--run_pytorch', action="store_true", help="Run pytorch benchmark")
parser.add_argument('--batch_size', default=16, type=int, help="Batch size")
parser.add_argument('--n_trials', default=100, type=int,
help="Number of full iterations (forward + backward + update)")
parser.add_argument('--use_XLA', action="store_true", help="Whether to use XLA compiler")
parser.add_argument('--data_format', default="NCHW", type=str, help="Image format")
parser.add_argument('--use_bn', action="store_true",help="Use batch normalization (tf benchmark)")
parser.add_argument('--use_fused', action="store_true",help="Use fused batch normalization (tf benchmark)")
args = parser.parse_args()
assert args.data_format in ["NHWC", "NCHW"]
if args.run_keras:
import benchmark_keras
utils.print_module("Running %s..." % benchmark_keras.__name__)
utils.print_dict(args.__dict__)
benchmark_keras.run_VGG16(args.batch_size,
args.n_trials,
args.use_bn,
args.data_format)
# import benchmark_keras
# utils.print_module("Running %s..." % benchmark_keras.__name__)
# utils.print_dict(args.__dict__)
# benchmark_keras.run_SimpleCNN(args.batch_size)
if args.run_tensorflow:
import benchmark_tensorflow
utils.print_module("Running %s..." % benchmark_tensorflow.__name__)
utils.print_dict(args.__dict__)
benchmark_tensorflow.run_VGG16(args.batch_size,
args.n_trials,
args.data_format,
args.use_XLA,
args.use_bn,
args.use_fused)
if args.run_pytorch:
import benchmark_pytorch
utils.print_module("Running %s..." % benchmark_pytorch.__name__)
utils.print_dict(args.__dict__)
benchmark_pytorch.run_VGG16(args.batch_size,
args.n_trials,)
# utils.print_module("Running %s..." % benchmark_pytorch.__name__)
# utils.print_dict(args.__dict__)
# benchmark_pytorch.run_SimpleCNN(args.batch_size,
# args.n_trials,)