Skip to content

Commit

Permalink
updated deploy logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhanush123 committed Mar 9, 2020
1 parent 23c9af8 commit beec15e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ WORKDIR $APP_HOME
ENV GCS_BUCKET_NAME shakti0
ENV GOOGLE_APPLICATION_CREDENTIALS gcs_creds.json
ENV PROJECT_ID shakti123
ENV GCP_EMAIL personalprojects0@gmail.com


# must run deploy command from folder containing server code
Expand Down
36 changes: 23 additions & 13 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os

from flask import Flask, jsonify, request
from joblib import load
from dotenv import load_dotenv

import os
import sklearn
import numpy as np
import glob

from process import preprocess, postprocess

app = Flask(__name__)
model = None

Expand All @@ -19,21 +21,29 @@ def load_resources():
# that way don't need to use global b/c not thread-safe & using gunicorn
global model
if not model:
model = load(glob.glob("*joblib")[0])
load_model()


def transform_data(input_data):
# will convert from 1D to required 2D
return np.array(input_data.tolist()[:784]).reshape(1, -1)
def load_model():
model_type = os.getenv("MODEL_TYPE", "sklearn")
if model_type == "sklearn":
model_file = glob.glob(
"*joblib")[0] if glob.glob("*joblib") else glob.glob("*pkl")[0]
model = load(model_file)
elif model_type == "pytorch":
# TODO: add onnx runtime
temp = None
else:
raise Exception


@app.route("/", methods=["GET", "POST"])
def predict():
img_nparray = np.fromstring(request.files["image"].read(), np.uint8)
transformed_data = transform_data(img_nparray)
prediction = model.predict(transformed_data).tolist()[0]
return jsonify({"prediction": prediction})
@app.route(os.getenv("REST_ENDPOINT", "/"))
def predict(input_data, methods=['POST']):
transformed_input_data = preprocess(input_data)
prediction = model.predict(transformed_input_data)
transformed_prediction = preprocess(prediction)
return jsonify({"prediction": transformed_prediction})


if __name__ == "__main__":
app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))
app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
15 changes: 15 additions & 0 deletions process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from flask import Flask, jsonify, request
import numpy as np

from __main__ import app


def preprocess(input_data):
# will convert from 1D to required 2D
img_nparray = np.fromstring(request.files["image"].read(), np.uint8)
reshaped_data = np.array(img_nparray.tolist()[:784]).reshape(1, -1)
return reshaped_data


def postprocess(input_data):
return input_data.tolist()[0]

0 comments on commit beec15e

Please sign in to comment.