Skip to content

Commit

Permalink
fix: make TablesClient.predict permissive to input data types (#13)
Browse files Browse the repository at this point in the history
* fix: make TablesClient.predict permissive to input data types

The current implementation checks each input instance's data type against
the column spec's data type. E.g., if the column spec is float, it
requires the input to be a float or an int, but not a string. However,
this does not match the Tables API contract:

   float column data type could be string or number values.

The current code raises an exception with an error message like

    TypeError: '0' has type str, but expected one of: int, long, float

when a string value is passed in for a numeric column, which should be
allowed.

This PR changes the logic so that the Python SDK side is permissive
about the input data type - basically all JSON-compatible data types are
allowed - and relies on the backend for validation.

* Fix according to comment.

* Fix lint.

* Address comment: use elif instead of if

Co-authored-by: Helin Wang <helin@google.com>
  • Loading branch information
helinwang and Helin Wang authored Mar 23, 2020
1 parent 33445b2 commit 54d27a5
Showing 1 changed file with 60 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

import pkg_resources
import logging
import six

from google.api_core.gapic_v1 import client_info
from google.api_core import exceptions
from google.cloud.automl_v1beta1 import gapic
from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2
from google.cloud.automl_v1beta1.proto import data_items_pb2
from google.cloud.automl_v1beta1.tables import gcs_client
from google.protobuf import struct_pb2

Expand All @@ -31,6 +32,61 @@
# Module-level logger, named after this module via __name__.
_LOGGER = logging.getLogger(__name__)


def to_proto_value(value):
    """Convert a JSON-compatible Python value into a google.protobuf.Value.

    Args:
        value: The Python value to convert. Handled types are None, bool,
            integers, float, strings, dict and list; dict values and list
            items are converted recursively.

    Returns:
        A ``(proto_value, error)`` tuple. On success ``proto_value`` is the
        resulting google.protobuf.Value and ``error`` is None. For any other
        type (including one nested inside a dict or list) ``proto_value`` is
        None and ``error`` is a string naming the offending type.
    """
    # JSON <-> Python correspondence this function follows, per
    # https://simplejson.readthedocs.io/en/latest/#encoders-and-decoders
    #   JSON            Python 2      Python 3
    #   object          dict          dict
    #   array           list          list
    #   string          unicode       str
    #   number (int)    int, long     int
    #   number (real)   float         float
    #   true            True          True
    #   false           False         False
    #   null            None          None
    if value is None:
        # JSON null maps to the protobuf NULL_VALUE.
        return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), None

    # bool must be tested before the numeric types: isinstance(True, int)
    # is True, so a later check would turn booleans into numbers.
    if isinstance(value, bool):
        return struct_pb2.Value(bool_value=value), None

    if isinstance(value, six.integer_types + (float,)):
        return struct_pb2.Value(number_value=value), None

    if isinstance(value, six.string_types + (six.text_type,)):
        return struct_pb2.Value(string_value=value), None

    if isinstance(value, dict):
        converted_struct = struct_pb2.Struct()
        for name, member in value.items():
            member_proto, error = to_proto_value(member)
            if error is not None:
                # Give up on the first member that cannot be converted.
                return None, error
            converted_struct.fields[name].CopyFrom(member_proto)
        return struct_pb2.Value(struct_value=converted_struct), None

    if isinstance(value, list):
        converted_items = []
        for element in value:
            element_proto, error = to_proto_value(element)
            if error is not None:
                # Give up on the first element that cannot be converted.
                return None, error
            converted_items.append(element_proto)
        wrapped = struct_pb2.ListValue(values=converted_items)
        return struct_pb2.Value(list_value=wrapped), None

    return None, "unsupport data type: {}".format(type(value))


class TablesClient(object):
"""
AutoML Tables API helper.
Expand Down Expand Up @@ -404,42 +460,6 @@ def __column_spec_name_from_args(

return column_spec_name

def __data_type_to_proto_value(self, data_type, value):
    """Build a google.protobuf.Value for ``value`` as dictated by ``data_type``.

    Args:
        data_type: Type descriptor whose ``type_code`` selects the
            conversion. ``list_element_type`` is consulted for ARRAY and
            ``struct_type.fields`` for STRUCT.
        value: The Python value to wrap. None always yields a null value;
            for ARRAY / STRUCT an existing ``ListValue`` / ``Struct`` is
            wrapped as-is.

    Returns:
        The resulting ``struct_pb2.Value``.

    Raises:
        ValueError: If ``type_code`` is none of FLOAT64, TIMESTAMP, STRING,
            CATEGORY, ARRAY or STRUCT (and ``value`` is not None).
    """
    code = data_type.type_code

    if value is None:
        return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)

    if code == data_types_pb2.FLOAT64:
        return struct_pb2.Value(number_value=value)

    string_like_codes = (
        data_types_pb2.TIMESTAMP,
        data_types_pb2.STRING,
        data_types_pb2.CATEGORY,
    )
    if code in string_like_codes:
        return struct_pb2.Value(string_value=value)

    if code == data_types_pb2.ARRAY:
        if isinstance(value, struct_pb2.ListValue):
            # The caller already supplied a ListValue; wrap it directly.
            return struct_pb2.Value(list_value=value)
        elements = [
            self.__data_type_to_proto_value(data_type.list_element_type, element)
            for element in value
        ]
        return struct_pb2.Value(list_value=struct_pb2.ListValue(values=elements))

    if code == data_types_pb2.STRUCT:
        if isinstance(value, struct_pb2.Struct):
            # The caller already supplied a Struct; wrap it directly.
            return struct_pb2.Value(struct_value=value)
        result = struct_pb2.Struct()
        for field_name, field_input in value.items():
            # Each member is converted using the type declared for its key.
            member_type = data_type.struct_type.fields[field_name]
            member_proto = self.__data_type_to_proto_value(member_type, field_input)
            result.fields[field_name].CopyFrom(member_proto)
        return struct_pb2.Value(struct_value=result)

    raise ValueError("Unknown type_code: {}".format(code))

def __ensure_gcs_client_is_initialized(self, credentials, project):
"""Checks if GCS client is initialized. Initializes it if not.
Expand Down Expand Up @@ -2714,7 +2734,9 @@ def predict(

values = []
for i, c in zip(inputs, column_specs):
value_type = self.__data_type_to_proto_value(c.data_type, i)
value_type, err = to_proto_value(i)
if err is not None:
raise ValueError(err)
values.append(value_type)

row = data_items_pb2.Row(values=values)
Expand Down

0 comments on commit 54d27a5

Please sign in to comment.