Skip to content

Commit

Permalink
Fix AI Platform online prediction tests (#4791)
Browse files Browse the repository at this point in the history
## Description

Fixes #4776 and fixes #4777 by using a new model version (created in the Cloud Console based on the same trained ML model) and updating code accordingly. It's unknown why the old model version that the test used stopped working.
Fixes #4778 by removing the code in question, which is no longer used in documentation.

## Checklist
- [x] I have followed [Sample Guidelines from AUTHORING_GUIDE.MD](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md)
- [ ] README is updated to include [all relevant information](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#readme-file)
- [x] **Tests** pass:   `nox -s py-3.6` (see [Test Environment Setup](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#test-environment-setup))
- [x] **Lint** pass:   `nox -s lint` (see [Test Environment Setup](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#test-environment-setup))
- [ ] These samples need a new **API enabled** in testing projects to pass (let us know which ones)
- [ ] These samples need a new/updated **env vars** in testing projects set to pass (let us know which ones)
- [x] Please **merge** this PR for me once it is approved.
- [ ] This sample adds a new sample directory, and I updated the [CODEOWNERS file](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/.github/CODEOWNERS) with the codeowners for this sample
  • Loading branch information
Alec Glassford authored Oct 6, 2020
1 parent f6ad120 commit e087818
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 118 deletions.
99 changes: 3 additions & 96 deletions ml_engine/online_prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@

"""Examples of using AI Platform's online prediction service."""
import argparse
import base64
import json

# [START import_libraries]
import googleapiclient.discovery
import six
# [END import_libraries]


Expand Down Expand Up @@ -61,83 +59,7 @@ def predict_json(project, model, instances, version=None):
# [END predict_json]


# [START predict_tf_records]
def predict_examples(project,
model,
example_bytes_list,
version=None):
"""Send protocol buffer data to a deployed model for prediction.
Args:
project (str): project where the AI Platform Model is deployed.
model (str): model name.
example_bytes_list ([str]): A list of bytestrings representing
serialized tf.train.Example protocol buffers. The contents of this
protocol buffer will change depending on the signature of your
deployed model.
version: str, version of the model to target.
Returns:
Mapping[str: any]: dictionary of prediction results defined by the
model.
"""
service = googleapiclient.discovery.build('ml', 'v1')
name = 'projects/{}/models/{}'.format(project, model)

if version is not None:
name += '/versions/{}'.format(version)

response = service.projects().predict(
name=name,
body={'instances': [
{'b64': base64.b64encode(example_bytes).decode('utf-8')}
for example_bytes in example_bytes_list
]}
).execute()

if 'error' in response:
raise RuntimeError(response['error'])

return response['predictions']
# [END predict_tf_records]


# [START census_to_example_bytes]
def census_to_example_bytes(json_instance):
"""Serialize a JSON example to the bytes of a tf.train.Example.
This method is specific to the signature of the Census example.
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
for details.
Args:
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
your deployed model expects to parse using it's tf.FeatureSpec.
Values should be datatypes convertible to Tensors, or (potentially
nested) lists of datatypes convertible to tensors.
Returns:
str: A string as a container for the serialized bytes of
tf.train.Example protocol buffer.
"""
import tensorflow as tf
feature_dict = {}
for key, data in six.iteritems(json_instance):
if isinstance(data, six.string_types):
feature_dict[key] = tf.train.Feature(
bytes_list=tf.train.BytesList(value=[data.encode('utf-8')]))
elif isinstance(data, float):
feature_dict[key] = tf.train.Feature(
float_list=tf.train.FloatList(value=[data]))
elif isinstance(data, int):
feature_dict[key] = tf.train.Feature(
int64_list=tf.train.Int64List(value=[data]))
return tf.train.Example(
features=tf.train.Features(
feature=feature_dict
)
).SerializeToString()
# [END census_to_example_bytes]


def main(project, model, version=None, force_tfrecord=False):
def main(project, model, version=None):
"""Send user input to the prediction service."""
while True:
try:
Expand All @@ -148,16 +70,8 @@ def main(project, model, version=None, force_tfrecord=False):
if not isinstance(user_input, list):
user_input = [user_input]
try:
if force_tfrecord:
example_bytes_list = [
census_to_example_bytes(e)
for e in user_input
]
result = predict_examples(
project, model, example_bytes_list, version=version)
else:
result = predict_json(
project, model, user_input, version=version)
result = predict_json(
project, model, user_input, version=version)
except RuntimeError as err:
print(str(err))
else:
Expand All @@ -183,16 +97,9 @@ def main(project, model, version=None, force_tfrecord=False):
help='Name of the version.',
type=str
)
parser.add_argument(
'--force-tfrecord',
help='Send predictions as TFRecords rather than raw JSON',
action='store_true',
default=False
)
args = parser.parse_args()
main(
args.project,
args.model,
version=args.version,
force_tfrecord=args.force_tfrecord
)
24 changes: 2 additions & 22 deletions ml_engine/online_prediction/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import predict

MODEL = 'census'
JSON_VERSION = 'v1json'
EXAMPLES_VERSION = 'v1example'
JSON_VERSION = 'v2json'
PROJECT = 'python-docs-samples-tests'
EXPECTED_OUTPUT = {
u'confidence': 0.7760371565818787,
u'confidence': 0.7760370969772339,
u'predictions': u' <=50K'
}

Expand All @@ -37,10 +36,6 @@
JSON = json.load(f)


with open('resources/census_example_bytes.pb', 'rb') as f:
BYTESTRING = f.read()


@pytest.mark.flaky
def test_predict_json():
result = predict.predict_json(
Expand All @@ -53,18 +48,3 @@ def test_predict_json_error():
with pytest.raises(RuntimeError):
predict.predict_json(
PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION)


@pytest.mark.flaky
def test_census_example_to_bytes():
import tensorflow as tf
b = predict.census_to_example_bytes(JSON)
assert tf.train.Example.FromString(b) == tf.train.Example.FromString(
BYTESTRING)


@pytest.mark.flaky(max_runs=6)
def test_predict_examples():
result = predict.predict_examples(
PROJECT, MODEL, [BYTESTRING, BYTESTRING], version=EXAMPLES_VERSION)
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result

0 comments on commit e087818

Please sign in to comment.