diff --git a/gan_server/server.py b/gan_server/server.py index 00d7e56..89fbc3f 100644 --- a/gan_server/server.py +++ b/gan_server/server.py @@ -1,9 +1,8 @@ import sys import json from flask import Flask, request, jsonify -from flask_cors import CORS, cross_origin import time -import tensorflow as tf +import tensorflow.compat.v1 as tf import tensorflow_hub as hub import numpy as np from scipy.stats import truncnorm @@ -17,8 +16,10 @@ truncation = 0.5 tf.reset_default_graph() +tf.disable_eager_execution() print('Loading BigGAN module from:', module_path) module = hub.Module(module_path) +print('BigGAN module loaded') inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) for k, v in module.get_input_info_dict().items()} output = module(inputs) @@ -34,6 +35,7 @@ sess = tf.Session() sess.run(initializer) +print('ready!') def truncated_z_sample(batch_size): values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state) @@ -120,7 +122,6 @@ def encode_img(arr): return img_str app = Flask(__name__, static_url_path='') #, static_folder='public', ) -CORS(app) @app.route('/') def index(): @@ -186,4 +187,4 @@ def mix_images(): if __name__ == '__main__': port = int(sys.argv[1]) if len(sys.argv) > 1 else 5000 print('port=', port) - app.run(host='0.0.0.0', debug=True, port=port) + app.run(host='0.0.0.0', debug=True, port=port) \ No newline at end of file