Skip to content

Commit

Permalink
Batch & online predict docstrings (googleapis#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
lwander authored Jul 2, 2019
1 parent 8e4f241 commit a66c2c7
Showing 1 changed file with 114 additions and 8 deletions.
122 changes: 114 additions & 8 deletions automl/google/cloud/automl_v1beta1/helper/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def get_model(self, project=None, region=None,
#TODO(jonathanskim): allow deployment from just model ID
def deploy_model(self, model=None, model_name=None,
model_display_name=None, project=None, region=None):
"""Deploys a model. This allows you make online predictions using the
"""Deploys a model. This allows you make online predictions using the
model you've deployed.
Example:
Expand Down Expand Up @@ -1190,7 +1190,7 @@ def deploy_model(self, model=None, model_name=None,

def undeploy_model(self, model=None, model_name=None,
model_display_name=None, project=None, region=None):
"""Undeploys a model.
"""Undeploys a model.
Example:
>>> from google.cloud import automl_v1beta1
Expand Down Expand Up @@ -1248,8 +1248,60 @@ def undeploy_model(self, model=None, model_name=None,
return self.client.undeploy_model(model_name)

## TODO(lwander): support pandas DataFrame as input type
def make_prediction(self, model=None, model_name=None,
model_display_name=None, project=None, region=None, inputs=None):
def predict(self, inputs, model=None, model_name=None,
model_display_name=None, project=None, region=None):
"""Makes a prediction on a deployed model. This will fail if the model
was not deployed.
Example:
>>> from google.cloud import automl_v1beta1
>>>
>>> client = automl_v1beta1.tables.ClientHelper(
... prediction_client=automl_v1beta1.PredictionServiceClient(),
... project='my-project', region='us-central1')
...
>>> client.predict(inputs={'Age': 30, 'Income': 12, 'Category': 'A'}
... model_display_name='my_model')
...
>>> client.predict([30, 12, 'A'], model_display_name='my_model')
>>>
Args:
project (Optional[string]):
If you have initialized the client with a value for `project`
it will be used if this parameter is not supplied. Keep in
mind, the service account this client was initialized with must
have access to this project.
region (Optional[string]):
If you have initialized the client with a value for `region` it
will be used if this parameter is not supplied.
inputs (Union[List[string], Dict[string, string]]):
Either the sorted list of column values to predict with, or a
key-value map of column display name to value to predict with.
model_display_name (Optional[string]):
The human-readable name given to the model you want to predict
with. This must be supplied if `model` or `model_name` are not
supplied.
model_name (Optional[string]):
The AutoML-assigned name given to the model you want to predict
with. This must be supplied if `model_display_name` or `model`
are not supplied.
model (Optional[model]):
The `model` instance you want to predict with . This must be
supplied if `model_display_name` or `model_name` are not
supplied.
Returns:
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
instance.
Raises:
google.api_core.exceptions.GoogleAPICallError: If the request
failed for any reason.
google.api_core.exceptions.RetryError: If the request failed due
to a retryable error and retry attempts failed.
ValueError: If required parameters are missing.
"""
if model is None:
model = self.get_model(
model_name=model_name,
Expand Down Expand Up @@ -1281,10 +1333,64 @@ def make_prediction(self, model=None, model_name=None,

return self.prediction_client.predict(model.name, request)

## TODO(lwander): why multiple input uris? how are they handled?
def make_batch_prediction(self, model=None, model_name=None,
model_display_name=None, project=None, gcs_input_uris=None,
gcs_output_uri_prefix=None, region=None, inputs=None):
def batch_predict(self, gcs_input_uris, gcs_output_uri_prefix,
model=None, model_name=None, model_display_name=None, project=None,
region=None, inputs=None):
"""Makes a batch prediction on a model. This does _not_ require the
model to be deployed.
Example:
>>> from google.cloud import automl_v1beta1
>>>
>>> client = automl_v1beta1.tables.ClientHelper(
... prediction_client=automl_v1beta1.PredictionServiceClient(),
... project='my-project', region='us-central1')
...
>>> client.batch_predict(
... gcs_input_uris='gs://inputs/input.csv',
... gcs_output_uri_prefix='gs://outputs/',
... model_display_name='my_model'
... ).result()
...
Args:
project (Optional[string]):
If you have initialized the client with a value for `project`
it will be used if this parameter is not supplied. Keep in
mind, the service account this client was initialized with must
have access to this project.
region (Optional[string]):
If you have initialized the client with a value for `region` it
will be used if this parameter is not supplied.
gcs_input_uris (Union[List[string], string])
Either a list of or a single GCS URI containing the data you
want to predict off of.
gcs_output_uri_prefix (string)
The folder in GCS you want to write output to.
model_display_name (Optional[string]):
The human-readable name given to the model you want to predict
with. This must be supplied if `model` or `model_name` are not
supplied.
model_name (Optional[string]):
The AutoML-assigned name given to the model you want to predict
with. This must be supplied if `model_display_name` or `model`
are not supplied.
model (Optional[model]):
The `model` instance you want to predict with . This must be
supplied if `model_display_name` or `model_name` are not
supplied.
Returns:
A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
instance.
Raises:
google.api_core.exceptions.GoogleAPICallError: If the request
failed for any reason.
google.api_core.exceptions.RetryError: If the request failed due
to a retryable error and retry attempts failed.
ValueError: If required parameters are missing.
"""
if gcs_input_uris is None or gcs_output_uri_prefix is None:
raise ValueError('Both \'gcs_input_uris\' and '
'\'gcs_output_uri_prefix\' must be set.')
Expand Down

0 comments on commit a66c2c7

Please sign in to comment.