diff --git a/google/cloud/automl_v1beta1/services/tables/tables_client.py b/google/cloud/automl_v1beta1/services/tables/tables_client.py index ed96e98b..566f509c 100644 --- a/google/cloud/automl_v1beta1/services/tables/tables_client.py +++ b/google/cloud/automl_v1beta1/services/tables/tables_client.py @@ -16,8 +16,9 @@ """A tables helper for the google.cloud.automl_v1beta1 AutoML API""" -import pkg_resources +import copy import logging +import pkg_resources import six from google.api_core.gapic_v1 import client_info @@ -107,7 +108,7 @@ def __init__( client=None, prediction_client=None, gcs_client=None, - **kwargs + **kwargs, ): """Constructor. @@ -258,7 +259,6 @@ def __dataset_from_args( dataset_name=None, project=None, region=None, - **kwargs ): if dataset is None and dataset_display_name is None and dataset_name is None: raise ValueError( @@ -275,7 +275,6 @@ def __dataset_from_args( dataset_name=dataset_name, project=project, region=region, - **kwargs ) def __model_from_args( @@ -285,7 +284,6 @@ def __model_from_args( model_name=None, project=None, region=None, - **kwargs ): if model is None and model_display_name is None and model_name is None: raise ValueError( @@ -301,7 +299,6 @@ def __model_from_args( model_name=model_name, project=project, region=region, - **kwargs ) def __dataset_name_from_args( @@ -311,7 +308,6 @@ def __dataset_name_from_args( dataset_name=None, project=None, region=None, - **kwargs ): if dataset is None and dataset_display_name is None and dataset_name is None: raise ValueError( @@ -325,15 +321,12 @@ def __dataset_name_from_args( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) dataset_name = dataset.name else: # we do this to force a NotFound error when needed - self.get_dataset( - dataset_name=dataset_name, project=project, region=region, **kwargs - ) + self.get_dataset(dataset_name=dataset_name, project=project, region=region) return dataset_name def __table_spec_name_from_args( @@ -344,7 +337,6 
@@ def __table_spec_name_from_args( dataset_name=None, project=None, region=None, - **kwargs ): dataset_name = self.__dataset_name_from_args( dataset=dataset, @@ -352,12 +344,9 @@ def __table_spec_name_from_args( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) - table_specs = [ - t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs) - ] + table_specs = [t for t in self.list_table_specs(dataset_name=dataset_name)] table_spec_full_id = table_specs[table_spec_index].name return table_spec_full_id @@ -369,7 +358,6 @@ def __model_name_from_args( model_name=None, project=None, region=None, - **kwargs ): if model is None and model_display_name is None and model_name is None: raise ValueError( @@ -382,14 +370,11 @@ def __model_name_from_args( model_display_name=model_display_name, project=project, region=region, - **kwargs ) model_name = model.name else: # we do this to force a NotFound error when needed - self.get_model( - model_name=model_name, project=project, region=region, **kwargs - ) + self.get_model(model_name=model_name, project=project, region=region) return model_name def __log_operation_info(self, message, op): @@ -426,7 +411,7 @@ def __column_spec_name_from_args( column_spec_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): column_specs = self.list_column_specs( dataset=dataset, @@ -436,7 +421,7 @@ def __column_spec_name_from_args( table_spec_index=table_spec_index, project=project, region=region, - **kwargs + **kwargs, ) if column_spec_display_name is not None: column_specs = {s.display_name: s for s in column_specs} @@ -481,6 +466,29 @@ def __ensure_gcs_client_is_initialized(self, credentials, project): project=project, credentials=credentials ) + def __process_request_kwargs(self, request, **kwargs): + """Add request kwargs to the request and return remaining kwargs. + + Some kwargs are for the request object and others are for + the method itself (retry, metadata). 
+ + Args: + request (proto.Message): The request object. + + Returns: + dict: kwargs to be added to the method. + """ + + method_kwargs = copy.deepcopy(kwargs) + for key, value in kwargs.items(): + try: + setattr(request, key, value) + method_kwargs.pop(key) + except AttributeError: + continue + + return method_kwargs + def list_datasets(self, project=None, region=None, **kwargs): """List all datasets in a particular project and region. @@ -523,17 +531,22 @@ def list_datasets(self, project=None, region=None, **kwargs): to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.auto_ml_client.list_datasets( - parent=self.__location_path(project=project, region=region), **kwargs + + request = google.cloud.automl_v1beta1.ListDatasetsRequest( + parent=self.__location_path(project=project, region=region), ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + + return self.auto_ml_client.list_datasets(request=request, **method_kwargs) + def get_dataset( self, project=None, region=None, dataset_name=None, dataset_display_name=None, - **kwargs + **kwargs, ): """Gets a single dataset in a particular project and region. @@ -586,7 +599,10 @@ def get_dataset( ) if dataset_name is not None: - return self.auto_ml_client.get_dataset(name=dataset_name, **kwargs) + request = google.cloud.automl_v1beta1.GetDatasetRequest(name=dataset_name,) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + + return self.auto_ml_client.get_dataset(request=request, **method_kwargs) return self.__lookup_by_display_name( "dataset", @@ -633,14 +649,16 @@ def create_dataset( to a retryable error and retry attempts failed. ValueError: If required parameters are missing. 
""" - return self.auto_ml_client.create_dataset( + request = google.cloud.automl_v1beta1.CreateDatasetRequest( parent=self.__location_path(project, region), dataset={ "display_name": dataset_display_name, "tables_dataset_metadata": metadata, }, - **kwargs ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + + return self.auto_ml_client.create_dataset(request=request, **method_kwargs) def delete_dataset( self, @@ -649,7 +667,7 @@ def delete_dataset( dataset_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Deletes a dataset. This does not delete any models trained on this dataset. @@ -709,13 +727,15 @@ def delete_dataset( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs + **kwargs, ) # delete is idempotent except exceptions.NotFound: return None - op = self.auto_ml_client.delete_dataset(name=dataset_name, **kwargs) + request = google.cloud.automl_v1beta1.DeleteDatasetRequest(name=dataset_name,) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + op = self.auto_ml_client.delete_dataset(request=request, **method_kwargs) self.__log_operation_info("Delete dataset", op) return op @@ -730,7 +750,7 @@ def import_data( project=None, region=None, credentials=None, - **kwargs + **kwargs, ): """Imports data into a dataset. @@ -814,7 +834,6 @@ def import_data( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) request = {} @@ -838,9 +857,12 @@ def import_data( "One of 'gcs_input_uris', or 'bigquery_input_uri', or 'pandas_dataframe' must be set." 
) - op = self.auto_ml_client.import_data( - name=dataset_name, input_config=request, **kwargs + req = google.cloud.automl_v1beta1.ImportDataRequest( + name=dataset_name, input_config=request ) + method_kwargs = self.__process_request_kwargs(req, **kwargs) + + op = self.auto_ml_client.import_data(request=req, **method_kwargs) self.__log_operation_info("Data import", op) return op @@ -853,7 +875,7 @@ def export_data( bigquery_output_uri=None, project=None, region=None, - **kwargs + **kwargs, ): """Exports data from a dataset. @@ -923,7 +945,6 @@ def export_data( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) request = {} @@ -936,9 +957,12 @@ def export_data( "One of 'gcs_output_uri_prefix', or 'bigquery_output_uri' must be set." ) - op = self.auto_ml_client.export_data( + req = google.cloud.automl_v1beta1.ExportDataRequest( name=dataset_name, output_config=request, **kwargs ) + + method_kwargs = self.__process_request_kwargs(req, **kwargs) + op = self.auto_ml_client.export_data(request=req, **method_kwargs) self.__log_operation_info("Export data", op) return op @@ -980,7 +1004,10 @@ def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs): to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.auto_ml_client.get_table_spec(name=table_spec_name, **kwargs) + request = google.cloud.automl_v1beta1.GetTableSpecRequest(name=table_spec_name,) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + + return self.auto_ml_client.get_table_spec(request=request, **method_kwargs) def list_table_specs( self, @@ -989,7 +1016,7 @@ def list_table_specs( dataset_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Lists table specs. 
@@ -1049,10 +1076,13 @@ def list_table_specs( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) - return self.auto_ml_client.list_table_specs(parent=dataset_name, **kwargs) + request = google.cloud.automl_v1beta1.ListTableSpecsRequest( + parent=dataset_name, + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.list_table_specs(request=request, **method_kwargs) def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs): """Gets a single column spec in a particular project and region. @@ -1092,7 +1122,11 @@ def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs) to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.auto_ml_client.get_column_spec(name=column_spec_name, **kwargs) + request = google.cloud.automl_v1beta1.GetColumnSpecRequest( + name=column_spec_name, + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.get_column_spec(request=request, **method_kwargs) def list_column_specs( self, @@ -1103,7 +1137,7 @@ def list_column_specs( table_spec_index=0, project=None, region=None, - **kwargs + **kwargs, ): """Lists column specs. @@ -1181,13 +1215,17 @@ def list_column_specs( dataset_name=dataset_name, project=project, region=region, - **kwargs ) ] table_spec_name = table_specs[table_spec_index].name - return self.auto_ml_client.list_column_specs(parent=table_spec_name, **kwargs) + request = google.cloud.automl_v1beta1.ListColumnSpecsRequest( + parent=table_spec_name, + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + + return self.auto_ml_client.list_column_specs(request=request, **method_kwargs) def update_column_spec( self, @@ -1202,7 +1240,7 @@ def update_column_spec( nullable=None, project=None, region=None, - **kwargs + **kwargs, ): """Updates a column's specs. 
@@ -1292,7 +1330,6 @@ def update_column_spec( column_spec_display_name=column_spec_display_name, project=project, region=region, - **kwargs ) # type code must always be set @@ -1309,7 +1346,6 @@ def update_column_spec( table_spec_index=table_spec_index, project=project, region=region, - **kwargs ) }[column_spec_name].data_type.type_code @@ -1317,11 +1353,14 @@ def update_column_spec( if nullable is not None: data_type["nullable"] = nullable - data_type["type_code"] = type_code + data_type["type_code"] = google.cloud.automl_v1beta1.TypeCode(type_code) - request = {"name": column_spec_name, "data_type": data_type} + request = google.cloud.automl_v1beta1.UpdateColumnSpecRequest( + column_spec={"name": column_spec_name, "data_type": data_type} + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) - return self.auto_ml_client.update_column_spec(column_spec=request, **kwargs) + return self.auto_ml_client.update_column_spec(request=request, **method_kwargs) def set_target_column( self, @@ -1334,7 +1373,7 @@ def set_target_column( column_spec_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Sets the target column for a given table. 
@@ -1419,7 +1458,7 @@ def set_target_column( column_spec_display_name=column_spec_display_name, project=project, region=region, - **kwargs + **kwargs, ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1429,16 +1468,19 @@ def set_target_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs + **kwargs, ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( metadata, "target_column_spec_id", column_spec_id ) - request = {"name": dataset.name, "tables_dataset_metadata": metadata} + request = google.cloud.automl_v1beta1.UpdateDatasetRequest( + dataset={"name": dataset.name, "tables_dataset_metadata": metadata} + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) - return self.auto_ml_client.update_dataset(dataset=request, **kwargs) + return self.auto_ml_client.update_dataset(request=request, **method_kwargs) def set_time_column( self, @@ -1451,7 +1493,7 @@ def set_time_column( column_spec_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Sets the time column which designates which data will be of type timestamp and will be used for the timeseries data. 
@@ -1534,7 +1576,7 @@ def set_time_column( column_spec_display_name=column_spec_display_name, project=project, region=region, - **kwargs + **kwargs, ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1544,19 +1586,18 @@ def set_time_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) - table_spec_full_id = self.__table_spec_name_from_args( - dataset_name=dataset_name, **kwargs - ) - - my_table_spec = { - "name": table_spec_full_id, - "time_column_spec_id": column_spec_id, - } + table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) - return self.auto_ml_client.update_table_spec(table_spec=my_table_spec, **kwargs) + request = google.cloud.automl_v1beta1.UpdateTableSpecRequest( + table_spec={ + "name": table_spec_full_id, + "time_column_spec_id": column_spec_id, + } + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.update_table_spec(request=request, **method_kwargs) def clear_time_column( self, @@ -1565,7 +1606,7 @@ def clear_time_column( dataset_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Clears the time column which designates which data will be of type timestamp and will be used for the timeseries data. 
@@ -1627,16 +1668,17 @@ def clear_time_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) - table_spec_full_id = self.__table_spec_name_from_args( - dataset_name=dataset_name, **kwargs - ) + table_spec_full_id = self.__table_spec_name_from_args(dataset_name=dataset_name) my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None} - return self.auto_ml_client.update_table_spec(table_spec=my_table_spec, **kwargs) + request = google.cloud.automl_v1beta1.UpdateTableSpecRequest( + table_spec=my_table_spec + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.update_table_spec(request=request, **method_kwargs) def set_weight_column( self, @@ -1649,7 +1691,7 @@ def set_weight_column( column_spec_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Sets the weight column for a given table. @@ -1734,7 +1776,6 @@ def set_weight_column( column_spec_display_name=column_spec_display_name, project=project, region=region, - **kwargs ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1744,16 +1785,19 @@ def set_weight_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( metadata, "weight_column_spec_id", column_spec_id ) - request = {"name": dataset.name, "tables_dataset_metadata": metadata} + request = google.cloud.automl_v1beta1.UpdateDatasetRequest( + dataset={"name": dataset.name, "tables_dataset_metadata": metadata} + ) + + method_kwargs = self.__process_request_kwargs(request, **kwargs) - return self.auto_ml_client.update_dataset(dataset=request, **kwargs) + return self.auto_ml_client.update_dataset(request=request, **method_kwargs) def clear_weight_column( self, @@ -1762,7 +1806,7 @@ def clear_weight_column( dataset_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Clears the weight column for a given 
dataset. @@ -1825,14 +1869,16 @@ def clear_weight_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata(metadata, "weight_column_spec_id", None) - request = {"name": dataset.name, "tables_dataset_metadata": metadata} + request = google.cloud.automl_v1beta1.UpdateDatasetRequest( + dataset={"name": dataset.name, "tables_dataset_metadata": metadata} + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) - return self.auto_ml_client.update_dataset(dataset=request, **kwargs) + return self.auto_ml_client.update_dataset(request=request, **method_kwargs) def set_test_train_column( self, @@ -1845,7 +1891,7 @@ def set_test_train_column( column_spec_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Sets the test/train (ml_use) column which designates which data belongs to the test and train sets. This column must be categorical. @@ -1931,7 +1977,7 @@ def set_test_train_column( column_spec_display_name=column_spec_display_name, project=project, region=region, - **kwargs + **kwargs, ) column_spec_id = column_spec_name.rsplit("/", 1)[-1] @@ -1941,16 +1987,19 @@ def set_test_train_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs + **kwargs, ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata( metadata, "ml_use_column_spec_id", column_spec_id ) - request = {"name": dataset.name, "tables_dataset_metadata": metadata} + request = google.cloud.automl_v1beta1.UpdateDatasetRequest( + dataset={"name": dataset.name, "tables_dataset_metadata": metadata} + ) - return self.auto_ml_client.update_dataset(dataset=request, **kwargs) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.update_dataset(request=request, **method_kwargs) def clear_test_train_column( self, @@ -1959,7 +2008,7 @@ def clear_test_train_column( 
dataset_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Clears the test/train (ml_use) column which designates which data belongs to the test and train sets. @@ -2023,14 +2072,17 @@ def clear_test_train_column( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs + **kwargs, ) metadata = dataset.tables_dataset_metadata metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None) - request = {"name": dataset.name, "tables_dataset_metadata": metadata} + request = google.cloud.automl_v1beta1.UpdateDatasetRequest( + dataset={"name": dataset.name, "tables_dataset_metadata": metadata} + ) - return self.auto_ml_client.update_dataset(dataset=request, **kwargs) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.update_dataset(request=request, **method_kwargs) def list_models(self, project=None, region=None, **kwargs): """List all models in a particular project and region. @@ -2074,10 +2126,14 @@ def list_models(self, project=None, region=None, **kwargs): to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ - return self.auto_ml_client.list_models( - parent=self.__location_path(project=project, region=region), **kwargs + + request = google.cloud.automl_v1beta1.ListModelsRequest( + parent=self.__location_path(project=project, region=region), ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.list_models(request=request, **method_kwargs) + def list_model_evaluations( self, project=None, @@ -2085,7 +2141,7 @@ def list_model_evaluations( model=None, model_display_name=None, model_name=None, - **kwargs + **kwargs, ): """List all model evaluations for a given model. 
@@ -2153,10 +2209,15 @@ def list_model_evaluations( model_display_name=model_display_name, project=project, region=region, - **kwargs ) - return self.auto_ml_client.list_model_evaluations(parent=model_name, **kwargs) + request = google.cloud.automl_v1beta1.ListModelEvaluationsRequest( + parent=model_name, + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.auto_ml_client.list_model_evaluations( + request=request, **method_kwargs + ) def create_model( self, @@ -2172,7 +2233,7 @@ def create_model( include_column_spec_names=None, exclude_column_spec_names=None, disable_early_stopping=False, - **kwargs + **kwargs, ): """Create a model. This will train your model on the given dataset. @@ -2276,7 +2337,7 @@ def create_model( dataset_display_name=dataset_display_name, project=project, region=region, - **kwargs + **kwargs, ) model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours @@ -2292,7 +2353,7 @@ def create_model( dataset=dataset, dataset_name=dataset_name, dataset_display_name=dataset_display_name, - **kwargs + **kwargs, ) ] @@ -2310,17 +2371,19 @@ def create_model( model_metadata["input_feature_column_specs"] = final_columns - request = { - "display_name": model_display_name, - "dataset_id": dataset_id, - "tables_model_metadata": model_metadata, - } - - op = self.auto_ml_client.create_model( + req = google.cloud.automl_v1beta1.CreateModelRequest( parent=self.__location_path(project=project, region=region), - model=request, - **kwargs + model=google.cloud.automl_v1beta1.Model( + display_name=model_display_name, + dataset_id=dataset_id, + tables_model_metadata=google.cloud.automl_v1beta1.TablesModelMetadata( + model_metadata + ), + ), ) + + method_kwargs = self.__process_request_kwargs(req, **kwargs) + op = self.auto_ml_client.create_model(request=req, **method_kwargs) self.__log_operation_info("Model creation", op) return op @@ -2331,7 +2394,7 @@ def delete_model( model_name=None, project=None, region=None, - 
**kwargs + **kwargs, ): """Deletes a model. Note this will not delete any datasets associated with this model. @@ -2391,13 +2454,14 @@ def delete_model( model_display_name=model_display_name, project=project, region=region, - **kwargs ) # delete is idempotent except exceptions.NotFound: return None - op = self.auto_ml_client.delete_model(name=model_name, **kwargs) + request = google.cloud.automl_v1beta1.DeleteModelRequest(name=model_name) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + op = self.auto_ml_client.delete_model(request=request, **method_kwargs) self.__log_operation_info("Delete model", op) return op @@ -2441,8 +2505,12 @@ def get_model_evaluation( to a retryable error and retry attempts failed. ValueError: If required parameters are missing. """ + request = google.cloud.automl_v1beta1.GetModelEvaluationRequest( + name=model_evaluation_name + ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) return self.auto_ml_client.get_model_evaluation( - name=model_evaluation_name, **kwargs + request=request, **method_kwargs ) def get_model( self, project=None, region=None, model_name=None, model_display_name=None, - **kwargs + **kwargs, ): """Gets a single model in a particular project and region. @@ -2503,10 +2571,10 @@ def get_model( ) if model_name is not None: - return self.auto_ml_client.get_model(name=model_name, **kwargs) + return self.auto_ml_client.get_model(name=model_name, **kwargs) return self.__lookup_by_display_name( - "model", self.list_models(project, region, **kwargs), model_display_name + "model", self.list_models(project, region, **kwargs), model_display_name ) # TODO(jonathanskim): allow deployment from just model ID def deploy_model( self, model=None, model_name=None, model_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Deploys a model. This allows you make online predictions using the model you've deployed. 
@@ -2576,10 +2644,12 @@ def deploy_model( model_display_name=model_display_name, project=project, region=region, - **kwargs ) - op = self.auto_ml_client.deploy_model(name=model_name, **kwargs) + request = google.cloud.automl_v1beta1.DeployModelRequest(name=model_name) + + method_kwargs = self.__process_request_kwargs(request, **kwargs) + op = self.auto_ml_client.deploy_model(request=request, **method_kwargs) self.__log_operation_info("Deploy model", op) return op @@ -2590,7 +2660,7 @@ def undeploy_model( model_display_name=None, project=None, region=None, - **kwargs + **kwargs, ): """Undeploys a model. @@ -2648,10 +2718,11 @@ def undeploy_model( model_display_name=model_display_name, project=project, region=region, - **kwargs ) - op = self.auto_ml_client.undeploy_model(name=model_name, **kwargs) + request = google.cloud.automl_v1beta1.UndeployModelRequest(name=model_name) + method_kwargs = self.__process_request_kwargs(request=request, **kwargs) + op = self.auto_ml_client.undeploy_model(request=request, **method_kwargs) self.__log_operation_info("Undeploy model", op) return op @@ -2665,7 +2736,7 @@ def predict( feature_importance=False, project=None, region=None, - **kwargs + **kwargs, ): """Makes a prediction on a deployed model. This will fail if the model was not deployed. 
@@ -2730,7 +2801,6 @@ def predict( model_display_name=model_display_name, project=project, region=region, - **kwargs ) column_specs = model.tables_model_metadata.input_feature_column_specs @@ -2766,9 +2836,11 @@ def predict( if feature_importance: params = {"feature_importance": "true"} - return self.prediction_client.predict( - name=model.name, payload=payload, params=params, **kwargs + request = google.cloud.automl_v1beta1.PredictRequest( + name=model.name, payload=payload, params=params, ) + method_kwargs = self.__process_request_kwargs(request, **kwargs) + return self.prediction_client.predict(request=request, **method_kwargs) def batch_predict( self, @@ -2785,7 +2857,7 @@ def batch_predict( credentials=None, inputs=None, params={}, - **kwargs + **kwargs, ): """Makes a batch prediction on a model. This does _not_ require the model to be deployed. @@ -2873,7 +2945,6 @@ def batch_predict( model_display_name=model_display_name, project=project, region=region, - **kwargs ) input_request = None @@ -2911,11 +2982,11 @@ def batch_predict( "One of 'gcs_output_uri_prefix'/'bigquery_output_uri' must be set" ) - op = self.prediction_client.batch_predict( - name=model_name, - input_config=input_request, - output_config=output_request, - **kwargs + req = google.cloud.automl_v1beta1.BatchPredictRequest( + name=model_name, input_config=input_request, output_config=output_request, ) + + method_kwargs = self.__process_request_kwargs(req, **kwargs) + op = self.prediction_client.batch_predict(request=req, **method_kwargs) self.__log_operation_info("Batch predict", op) return op diff --git a/noxfile.py b/noxfile.py index 62ae86d4..9c69b3b7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -74,6 +74,7 @@ def default(session): session.install("mock", "pytest", "pytest-cov") session.install("-e", ".[pandas,storage]") + session.install("proto-plus==1.8.1") # Run py.test against the unit tests. 
session.run( diff --git a/samples/tables/automl_tables_dataset.py b/samples/tables/automl_tables_dataset.py index 144f2ee6..76ece1c9 100644 --- a/samples/tables/automl_tables_dataset.py +++ b/samples/tables/automl_tables_dataset.py @@ -47,9 +47,7 @@ def create_dataset(project_id, compute_region, dataset_display_name): print("Dataset metadata:") print("\t{}".format(dataset.tables_dataset_metadata)) print("Dataset example count: {}".format(dataset.example_count)) - print("Dataset create time:") - print("\tseconds: {}".format(dataset.create_time.seconds)) - print("\tnanos: {}".format(dataset.create_time.nanos)) + print("Dataset create time: {}".format(dataset.create_time)) # [END automl_tables_create_dataset] @@ -105,9 +103,7 @@ def list_datasets(project_id, compute_region, filter_=None): ) ) print("Dataset example count: {}".format(dataset.example_count)) - print("Dataset create time:") - print("\tseconds: {}".format(dataset.create_time.seconds)) - print("\tnanos: {}".format(dataset.create_time.nanos)) + print("Dataset create time: {}".format(dataset.create_time)) print("\n") # [END automl_tables_list_datasets] @@ -137,9 +133,7 @@ def get_dataset(project_id, compute_region, dataset_display_name): print("Dataset metadata:") print("\t{}".format(dataset.tables_dataset_metadata)) print("Dataset example count: {}".format(dataset.example_count)) - print("Dataset create time:") - print("\tseconds: {}".format(dataset.create_time.seconds)) - print("\tnanos: {}".format(dataset.create_time.nanos)) + print("Dataset create time: {}".format(dataset.create_time)) return dataset diff --git a/samples/tables/automl_tables_model.py b/samples/tables/automl_tables_model.py index 09a8f4ca..95dc5eb8 100644 --- a/samples/tables/automl_tables_model.py +++ b/samples/tables/automl_tables_model.py @@ -190,7 +190,7 @@ def get_model(project_id, compute_region, model_display_name): def list_model_evaluations( - project_id, compute_region, model_display_name, filter_=None + project_id, 
compute_region, model_display_name, filter=None ): """List model evaluations.""" @@ -200,7 +200,7 @@ def list_model_evaluations( # project_id = 'PROJECT_ID_HERE' # compute_region = 'COMPUTE_REGION_HERE' # model_display_name = 'MODEL_DISPLAY_NAME_HERE' - # filter_ = 'filter expression here' + # filter = 'filter expression here' from google.cloud import automl_v1beta1 as automl @@ -208,7 +208,7 @@ def list_model_evaluations( # List all the model evaluations in the model by applying filter. response = client.list_model_evaluations( - model_display_name=model_display_name, filter_=filter_ + model_display_name=model_display_name, filter=filter ) print("List of model evaluations:") @@ -220,9 +220,7 @@ def list_model_evaluations( evaluation.evaluated_example_count ) ) - print("Model evaluation time:") - print("\tseconds: {}".format(evaluation.create_time.seconds)) - print("\tnanos: {}".format(evaluation.create_time.nanos)) + print("Model evaluation time: {}".format(evaluation.create_time)) print("\n") # [END automl_tables_list_model_evaluations] result.append(evaluation) @@ -261,7 +259,7 @@ def get_model_evaluation( def display_evaluation( - project_id, compute_region, model_display_name, filter_=None + project_id, compute_region, model_display_name, filter=None ): """Display evaluation.""" # [START automl_tables_display_evaluation] @@ -269,7 +267,7 @@ def display_evaluation( # project_id = 'PROJECT_ID_HERE' # compute_region = 'COMPUTE_REGION_HERE' # model_display_name = 'MODEL_DISPLAY_NAME_HERE' - # filter_ = 'filter expression here' + # filter = 'filter expression here' from google.cloud import automl_v1beta1 as automl @@ -277,7 +275,7 @@ def display_evaluation( # List all the model evaluations in the model by applying filter. response = client.list_model_evaluations( - model_display_name=model_display_name, filter_=filter_ + model_display_name=model_display_name, filter=filter ) # Iterate through the results. 
diff --git a/samples/tables/automl_tables_predict.py b/samples/tables/automl_tables_predict.py index 9787e1b9..a330213c 100644 --- a/samples/tables/automl_tables_predict.py +++ b/samples/tables/automl_tables_predict.py @@ -58,7 +58,7 @@ def predict( print("Prediction results:") for result in response.payload: print( - "Predicted class name: {}".format(result.tables.value.string_value) + "Predicted class name: {}".format(result.tables.value) ) print("Predicted class score: {}".format(result.tables.score)) diff --git a/samples/tables/endpoint_test.py b/samples/tables/endpoint_test.py index 5a20aba5..6af6b8da 100644 --- a/samples/tables/endpoint_test.py +++ b/samples/tables/endpoint_test.py @@ -23,4 +23,4 @@ def test_client_creation(capsys): automl_tables_set_endpoint.create_client_with_endpoint(PROJECT) out, _ = capsys.readouterr() - assert "GRPCIterator" in out + assert "ListDatasetsPager" in out diff --git a/tests/unit/test_tables_client_v1beta1.py b/tests/unit/test_tables_client_v1beta1.py index 3c5b55d8..1d5b168c 100644 --- a/tests/unit/test_tables_client_v1beta1.py +++ b/tests/unit/test_tables_client_v1beta1.py @@ -48,29 +48,31 @@ def tables_client( def test_list_datasets_empty(self): client = self.tables_client( - { + client_attrs={ "list_datasets.return_value": [], "location_path.return_value": LOCATION_PATH, }, - {}, + prediction_client_attrs={}, ) ds = client.list_datasets() - client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) - client.auto_ml_client.list_datasets.assert_called_with(parent=LOCATION_PATH) + + request = automl_v1beta1.ListDatasetsRequest(parent=LOCATION_PATH) + client.auto_ml_client.list_datasets.assert_called_with(request=request) assert ds == [] def test_list_datasets_not_empty(self): datasets = ["some_dataset"] client = self.tables_client( - { + client_attrs={ "list_datasets.return_value": datasets, "location_path.return_value": LOCATION_PATH, }, - {}, + prediction_client_attrs={}, ) ds = client.list_datasets() - 
client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) - client.auto_ml_client.list_datasets.assert_called_with(parent=LOCATION_PATH) + + request = automl_v1beta1.ListDatasetsRequest(parent=LOCATION_PATH) + client.auto_ml_client.list_datasets.assert_called_with(request=request) assert len(ds) == 1 assert ds[0] == "some_dataset" @@ -84,7 +86,9 @@ def test_get_dataset_name(self): dataset_actual = "dataset" client = self.tables_client({"get_dataset.return_value": dataset_actual}, {}) dataset = client.get_dataset(dataset_name="my_dataset") - client.auto_ml_client.get_dataset.assert_called_with(name="my_dataset") + client.auto_ml_client.get_dataset.assert_called_with( + request=automl_v1beta1.GetDatasetRequest(name="my_dataset") + ) assert dataset == dataset_actual def test_get_no_dataset(self): @@ -93,7 +97,9 @@ def test_get_no_dataset(self): ) with pytest.raises(exceptions.NotFound): client.get_dataset(dataset_name="my_dataset") - client.auto_ml_client.get_dataset.assert_called_with(name="my_dataset") + client.auto_ml_client.get_dataset.assert_called_with( + request=automl_v1beta1.GetDatasetRequest(name="my_dataset") + ) def test_get_dataset_from_empty_list(self): client = self.tables_client({"list_datasets.return_value": []}, {}) @@ -142,12 +148,14 @@ def test_create_dataset(self): }, {}, ) - metadata = {"metadata": "values"} + metadata = {"primary_table_spec_id": "1234"} dataset = client.create_dataset("name", metadata=metadata) - client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.create_dataset.assert_called_with( - parent=LOCATION_PATH, - dataset={"display_name": "name", "tables_dataset_metadata": metadata}, + request=automl_v1beta1.CreateDatasetRequest( + parent=LOCATION_PATH, + dataset={"display_name": "name", "tables_dataset_metadata": metadata}, + ) ) assert dataset.display_name == "name" @@ -156,7 +164,9 @@ def test_delete_dataset(self): dataset.configure_mock(name="name") client = 
self.tables_client({"delete_dataset.return_value": None}, {}) client.delete_dataset(dataset=dataset) - client.auto_ml_client.delete_dataset.assert_called_with(name="name") + client.auto_ml_client.delete_dataset.assert_called_with( + request=automl_v1beta1.DeleteDatasetRequest(name="name") + ) def test_delete_dataset_not_found(self): client = self.tables_client({"list_datasets.return_value": []}, {}) @@ -166,7 +176,9 @@ def test_delete_dataset_not_found(self): def test_delete_dataset_name(self): client = self.tables_client({"delete_dataset.return_value": None}, {}) client.delete_dataset(dataset_name="name") - client.auto_ml_client.delete_dataset.assert_called_with(name="name") + client.auto_ml_client.delete_dataset.assert_called_with( + request=automl_v1beta1.DeleteDatasetRequest(name="name") + ) def test_export_not_found(self): client = self.tables_client({"list_datasets.return_value": []}, {}) @@ -179,14 +191,20 @@ def test_export_gcs_uri(self): client = self.tables_client({"export_data.return_value": None}, {}) client.export_data(dataset_name="name", gcs_output_uri_prefix="uri") client.auto_ml_client.export_data.assert_called_with( - name="name", output_config={"gcs_destination": {"output_uri_prefix": "uri"}} + request=automl_v1beta1.ExportDataRequest( + name="name", + output_config={"gcs_destination": {"output_uri_prefix": "uri"}}, + ) ) def test_export_bq_uri(self): client = self.tables_client({"export_data.return_value": None}, {}) client.export_data(dataset_name="name", bigquery_output_uri="uri") client.auto_ml_client.export_data.assert_called_with( - name="name", output_config={"bigquery_destination": {"output_uri": "uri"}} + request=automl_v1beta1.ExportDataRequest( + name="name", + output_config={"bigquery_destination": {"output_uri": "uri"}}, + ) ) def test_import_not_found(self): @@ -213,7 +231,9 @@ def test_import_pandas_dataframe(self): client.gcs_client.ensure_bucket_exists.assert_called_with(PROJECT, REGION) 
client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe) client.auto_ml_client.import_data.assert_called_with( - name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + request=automl_v1beta1.ImportDataRequest( + name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + ) ) def test_import_pandas_dataframe_init_gcs(self): @@ -240,34 +260,44 @@ def test_import_pandas_dataframe_init_gcs(self): client.gcs_client.ensure_bucket_exists.assert_called_with(PROJECT, REGION) client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe) client.auto_ml_client.import_data.assert_called_with( - name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + request=automl_v1beta1.ImportDataRequest( + name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + ) ) def test_import_gcs_uri(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", gcs_input_uris="uri") client.auto_ml_client.import_data.assert_called_with( - name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + request=automl_v1beta1.ImportDataRequest( + name="name", input_config={"gcs_source": {"input_uris": ["uri"]}} + ) ) def test_import_gcs_uris(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", gcs_input_uris=["uri", "uri"]) client.auto_ml_client.import_data.assert_called_with( - name="name", input_config={"gcs_source": {"input_uris": ["uri", "uri"]}} + request=automl_v1beta1.ImportDataRequest( + name="name", input_config={"gcs_source": {"input_uris": ["uri", "uri"]}} + ) ) def test_import_bq_uri(self): client = self.tables_client({"import_data.return_value": None}, {}) client.import_data(dataset_name="name", bigquery_input_uri="uri") client.auto_ml_client.import_data.assert_called_with( - name="name", input_config={"bigquery_source": {"input_uri": "uri"}} + request=automl_v1beta1.ImportDataRequest( + name="name", 
input_config={"bigquery_source": {"input_uri": "uri"}} + ) ) def test_list_table_specs(self): client = self.tables_client({"list_table_specs.return_value": None}, {}) client.list_table_specs(dataset_name="name") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) def test_list_table_specs_not_found(self): client = self.tables_client( @@ -275,17 +305,23 @@ def test_list_table_specs_not_found(self): ) with pytest.raises(exceptions.NotFound): client.list_table_specs(dataset_name="name") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) def test_get_table_spec(self): client = self.tables_client({}, {}) client.get_table_spec("name") - client.auto_ml_client.get_table_spec.assert_called_with(name="name") + client.auto_ml_client.get_table_spec.assert_called_with( + request=automl_v1beta1.GetTableSpecRequest(name="name") + ) def test_get_column_spec(self): client = self.tables_client({}, {}) client.get_column_spec("name") - client.auto_ml_client.get_column_spec.assert_called_with(name="name") + client.auto_ml_client.get_column_spec.assert_called_with( + request=automl_v1beta1.GetColumnSpecRequest(name="name") + ) def test_list_column_specs(self): table_spec_mock = mock.Mock() @@ -299,171 +335,238 @@ def test_list_column_specs(self): {}, ) client.list_column_specs(dataset_name="name") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + 
request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) def test_update_column_spec_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType(type_code=automl_v1beta1.TypeCode.STRING), ) + client = self.tables_client( - { + client_attrs={ "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, - {}, + prediction_client_attrs={}, ) with pytest.raises(exceptions.NotFound): client.update_column_spec(dataset_name="name", column_spec_name="column2") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_column_spec.assert_not_called() def test_update_column_spec_display_name_not_found(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType(type_code=automl_v1beta1.TypeCode.STRING), ) client = 
self.tables_client( - { + client_attrs={ "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, - {}, + prediction_client_attrs={}, ) with pytest.raises(exceptions.NotFound): client.update_column_spec( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_column_spec.assert_not_called() def test_update_column_spec_name_no_args(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column/2", display_name="column", data_type=data_type_mock + + column_spec = automl_v1beta1.ColumnSpec( + name="column/2", + display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) + client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) client.update_column_spec(dataset_name="name", column_spec_name="column/2") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + 
client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={"name": "column/2", "data_type": {"type_code": "type_code"}} + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column/2", + "data_type": {"type_code": automl_v1beta1.TypeCode.FLOAT64}, + } + ) ) def test_update_column_spec_no_args(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) + client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) client.update_column_spec( dataset_name="name", column_spec_display_name="column" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={"name": "column", "data_type": {"type_code": "type_code"}} + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column", + "data_type": {"type_code": automl_v1beta1.TypeCode.FLOAT64}, + } + ) ) def test_update_column_spec_nullable(self): 
table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) + client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) client.update_column_spec( dataset_name="name", column_spec_display_name="column", nullable=True ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={ - "name": "column", - "data_type": {"type_code": "type_code", "nullable": True}, - } + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column", + "data_type": { + "type_code": automl_v1beta1.TypeCode.FLOAT64, + "nullable": True, + }, + } + ) ) def test_update_column_spec_type_code(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + column_spec = automl_v1beta1.ColumnSpec( + name="column", + 
display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) client.update_column_spec( dataset_name="name", column_spec_display_name="column", - type_code="type_code2", + type_code=automl_v1beta1.TypeCode.ARRAY, + ) + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={"name": "column", "data_type": {"type_code": "type_code2"}} + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column", + "data_type": {"type_code": automl_v1beta1.TypeCode.ARRAY}, + } + ) ) def test_update_column_spec_type_code_nullable(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) @@ -471,30 +574,41 @@ def test_update_column_spec_type_code_nullable(self): 
dataset_name="name", nullable=True, column_spec_display_name="column", - type_code="type_code2", + type_code=automl_v1beta1.TypeCode.ARRAY, + ) + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={ - "name": "column", - "data_type": {"type_code": "type_code2", "nullable": True}, - } + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column", + "data_type": { + "type_code": automl_v1beta1.TypeCode.ARRAY, + "nullable": True, + }, + } + ) ) def test_update_column_spec_type_code_nullable_false(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock = mock.Mock() - data_type_mock = mock.Mock(type_code="type_code") - column_spec_mock.configure_mock( - name="column", display_name="column", data_type=data_type_mock + column_spec = automl_v1beta1.ColumnSpec( + name="column", + display_name="column", + data_type=automl_v1beta1.DataType( + type_code=automl_v1beta1.TypeCode.FLOAT64 + ), ) client = self.tables_client( { "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [column_spec_mock], + "list_column_specs.return_value": [column_spec], }, {}, ) @@ -502,15 +616,24 @@ def test_update_column_spec_type_code_nullable_false(self): dataset_name="name", nullable=False, column_spec_display_name="column", - type_code="type_code2", + type_code=automl_v1beta1.TypeCode.FLOAT64, + ) + client.auto_ml_client.list_table_specs.assert_called_with( + 
request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") client.auto_ml_client.update_column_spec.assert_called_with( - column_spec={ - "name": "column", - "data_type": {"type_code": "type_code2", "nullable": False}, - } + request=automl_v1beta1.UpdateColumnSpecRequest( + column_spec={ + "name": "column", + "data_type": { + "type_code": automl_v1beta1.TypeCode.FLOAT64, + "nullable": False, + }, + } + ) ) def test_set_target_column_table_not_found(self): @@ -521,7 +644,9 @@ def test_set_target_column_table_not_found(self): client.set_target_column( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) client.auto_ml_client.list_column_specs.assert_not_called() client.auto_ml_client.update_dataset.assert_not_called() @@ -542,8 +667,12 @@ def test_set_target_column_not_found(self): client.set_target_column( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_not_called() def test_set_target_column(self): @@ -571,17 +700,23 @@ def test_set_target_column(self): {}, ) client.set_target_column(dataset_name="name", 
column_spec_display_name="column") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_called_with( - dataset={ - "name": "dataset", - "tables_dataset_metadata": { - "target_column_spec_id": "1", - "weight_column_spec_id": "2", - "ml_use_column_spec_id": "3", - }, - } + request=automl_v1beta1.UpdateDatasetRequest( + dataset={ + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) ) def test_set_weight_column_table_not_found(self): @@ -594,7 +729,9 @@ def test_set_weight_column_table_not_found(self): ) except exceptions.NotFound: pass - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) client.auto_ml_client.list_column_specs.assert_not_called() client.auto_ml_client.update_dataset.assert_not_called() @@ -615,8 +752,12 @@ def test_set_weight_column_not_found(self): client.set_weight_column( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_not_called() def 
test_set_weight_column(self): @@ -644,17 +785,23 @@ def test_set_weight_column(self): {}, ) client.set_weight_column(dataset_name="name", column_spec_display_name="column") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_called_with( - dataset={ - "name": "dataset", - "tables_dataset_metadata": { - "target_column_spec_id": "1", - "weight_column_spec_id": "2", - "ml_use_column_spec_id": "3", - }, - } + request=automl_v1beta1.UpdateDatasetRequest( + dataset={ + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) ) def test_clear_weight_column(self): @@ -671,14 +818,16 @@ def test_clear_weight_column(self): client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) client.clear_weight_column(dataset_name="name") client.auto_ml_client.update_dataset.assert_called_with( - dataset={ - "name": "dataset", - "tables_dataset_metadata": { - "target_column_spec_id": "1", - "weight_column_spec_id": None, - "ml_use_column_spec_id": "3", - }, - } + request=automl_v1beta1.UpdateDatasetRequest( + dataset={ + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": None, + "ml_use_column_spec_id": "3", + }, + } + ) ) def test_set_test_train_column_table_not_found(self): @@ -689,7 +838,9 @@ def test_set_test_train_column_table_not_found(self): client.set_test_train_column( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") + 
client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) client.auto_ml_client.list_column_specs.assert_not_called() client.auto_ml_client.update_dataset.assert_not_called() @@ -710,8 +861,12 @@ def test_set_test_train_column_not_found(self): client.set_test_train_column( dataset_name="name", column_spec_display_name="column2" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_not_called() def test_set_test_train_column(self): @@ -741,17 +896,23 @@ def test_set_test_train_column(self): client.set_test_train_column( dataset_name="name", column_spec_display_name="column" ) - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_dataset.assert_called_with( - dataset={ - "name": "dataset", - "tables_dataset_metadata": { - "target_column_spec_id": "1", - "weight_column_spec_id": "2", - "ml_use_column_spec_id": "3", - }, - } + request=automl_v1beta1.UpdateDatasetRequest( + dataset={ + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) ) def test_clear_test_train_column(self): @@ -768,14 +929,16 @@ def 
test_clear_test_train_column(self): client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) client.clear_test_train_column(dataset_name="name") client.auto_ml_client.update_dataset.assert_called_with( - dataset={ - "name": "dataset", - "tables_dataset_metadata": { - "target_column_spec_id": "1", - "weight_column_spec_id": "2", - "ml_use_column_spec_id": None, - }, - } + request=automl_v1beta1.UpdateDatasetRequest( + dataset={ + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": None, + }, + } + ) ) def test_set_time_column(self): @@ -795,10 +958,16 @@ def test_set_time_column(self): {}, ) client.set_time_column(dataset_name="name", column_spec_display_name="column") - client.auto_ml_client.list_table_specs.assert_called_with(parent="name") - client.auto_ml_client.list_column_specs.assert_called_with(parent="table") + client.auto_ml_client.list_table_specs.assert_called_with( + request=automl_v1beta1.ListTableSpecsRequest(parent="name") + ) + client.auto_ml_client.list_column_specs.assert_called_with( + request=automl_v1beta1.ListColumnSpecsRequest(parent="table") + ) client.auto_ml_client.update_table_spec.assert_called_with( - table_spec={"name": "table", "time_column_spec_id": "3"} + request=automl_v1beta1.UpdateTableSpecRequest( + table_spec={"name": "table", "time_column_spec_id": "3"} + ) ) def test_clear_time_column(self): @@ -816,18 +985,24 @@ def test_clear_time_column(self): ) client.clear_time_column(dataset_name="name") client.auto_ml_client.update_table_spec.assert_called_with( - table_spec={"name": "table", "time_column_spec_id": None} + request=automl_v1beta1.UpdateTableSpecRequest( + table_spec={"name": "table", "time_column_spec_id": None} + ) ) def test_get_model_evaluation(self): client = self.tables_client({}, {}) client.get_model_evaluation(model_evaluation_name="x") - 
client.auto_ml_client.get_model_evaluation.assert_called_with(name="x") + client.auto_ml_client.get_model_evaluation.assert_called_with( + request=automl_v1beta1.GetModelEvaluationRequest(name="x") + ) def test_list_model_evaluations_empty(self): client = self.tables_client({"list_model_evaluations.return_value": []}, {}) ds = client.list_model_evaluations(model_name="model") - client.auto_ml_client.list_model_evaluations.assert_called_with(parent="model") + client.auto_ml_client.list_model_evaluations.assert_called_with( + request=automl_v1beta1.ListModelEvaluationsRequest(parent="model") + ) assert ds == [] def test_list_model_evaluations_not_empty(self): @@ -840,7 +1015,9 @@ def test_list_model_evaluations_not_empty(self): {}, ) ds = client.list_model_evaluations(model_name="model") - client.auto_ml_client.list_model_evaluations.assert_called_with(parent="model") + client.auto_ml_client.list_model_evaluations.assert_called_with( + request=automl_v1beta1.ListModelEvaluationsRequest(parent="model") + ) assert len(ds) == 1 assert ds[0] == "eval" @@ -853,8 +1030,10 @@ def test_list_models_empty(self): {}, ) ds = client.list_models() - client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) - client.auto_ml_client.list_models.assert_called_with(parent=LOCATION_PATH) + + client.auto_ml_client.list_models.assert_called_with( + request=automl_v1beta1.ListModelsRequest(parent=LOCATION_PATH) + ) assert ds == [] def test_list_models_not_empty(self): @@ -867,8 +1046,10 @@ def test_list_models_not_empty(self): {}, ) ds = client.list_models() - client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) - client.auto_ml_client.list_models.assert_called_with(parent=LOCATION_PATH) + + client.auto_ml_client.list_models.assert_called_with( + request=automl_v1beta1.ListModelsRequest(parent=LOCATION_PATH) + ) assert len(ds) == 1 assert ds[0] == "some_model" @@ -931,7 +1112,9 @@ def test_delete_model(self): model.configure_mock(name="name") client = 
self.tables_client({"delete_model.return_value": None}, {}) client.delete_model(model=model) - client.auto_ml_client.delete_model.assert_called_with(name="name") + client.auto_ml_client.delete_model.assert_called_with( + request=automl_v1beta1.DeleteModelRequest(name="name") + ) def test_delete_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -941,7 +1124,9 @@ def test_delete_model_not_found(self): def test_delete_model_name(self): client = self.tables_client({"delete_model.return_value": None}, {}) client.delete_model(model_name="name") - client.auto_ml_client.delete_model.assert_called_with(name="name") + client.auto_ml_client.delete_model.assert_called_with( + request=automl_v1beta1.DeleteModelRequest(name="name") + ) def test_deploy_model_no_args(self): client = self.tables_client({}, {}) @@ -952,7 +1137,9 @@ def test_deploy_model_no_args(self): def test_deploy_model(self): client = self.tables_client({}, {}) client.deploy_model(model_name="name") - client.auto_ml_client.deploy_model.assert_called_with(name="name") + client.auto_ml_client.deploy_model.assert_called_with( + request=automl_v1beta1.DeployModelRequest(name="name") + ) def test_deploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -963,7 +1150,9 @@ def test_deploy_model_not_found(self): def test_undeploy_model(self): client = self.tables_client({}, {}) client.undeploy_model(model_name="name") - client.auto_ml_client.undeploy_model.assert_called_with(name="name") + client.auto_ml_client.undeploy_model.assert_called_with( + request=automl_v1beta1.UndeployModelRequest(name="name") + ) def test_undeploy_model_not_found(self): client = self.tables_client({"list_models.return_value": []}, {}) @@ -989,32 +1178,37 @@ def test_create_model(self): "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1000 ) client.auto_ml_client.create_model.assert_called_with( - parent=LOCATION_PATH, - model={ - 
"display_name": "my_model", - "dataset_id": "my_dataset", - "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, - }, + request=automl_v1beta1.CreateModelRequest( + parent=LOCATION_PATH, + model={ + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, + }, + ) ) def test_create_model_include_columns(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock1 = mock.Mock() - column_spec_mock1.configure_mock(name="column/1", display_name="column1") - column_spec_mock2 = mock.Mock() - column_spec_mock2.configure_mock(name="column/2", display_name="column2") + + column_spec_1 = automl_v1beta1.ColumnSpec( + name="column/1", display_name="column1" + ) + column_spec_2 = automl_v1beta1.ColumnSpec( + name="column/2", display_name="column2" + ) + client = self.tables_client( - { - "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [ - column_spec_mock1, - column_spec_mock2, + client_attrs={ + "list_table_specs.return_value": [ + automl_v1beta1.TableSpec(name="table") ], + "list_column_specs.return_value": [column_spec_1, column_spec_2], "location_path.return_value": LOCATION_PATH, }, - {}, + prediction_client_attrs={}, ) client.create_model( "my_model", @@ -1023,35 +1217,37 @@ def test_create_model_include_columns(self): train_budget_milli_node_hours=1000, ) client.auto_ml_client.create_model.assert_called_with( - parent=LOCATION_PATH, - model={ - "display_name": "my_model", - "dataset_id": "my_dataset", - "tables_model_metadata": { - "train_budget_milli_node_hours": 1000, - "input_feature_column_specs": [column_spec_mock1], - }, - }, + request=automl_v1beta1.CreateModelRequest( + parent=LOCATION_PATH, + model=automl_v1beta1.Model( + display_name="my_model", + dataset_id="my_dataset", + tables_model_metadata=automl_v1beta1.TablesModelMetadata( + 
train_budget_milli_node_hours=1000, + input_feature_column_specs=[column_spec_1], + ), + ), + ) ) def test_create_model_exclude_columns(self): table_spec_mock = mock.Mock() # name is reserved in use of __init__, needs to be passed here table_spec_mock.configure_mock(name="table") - column_spec_mock1 = mock.Mock() - column_spec_mock1.configure_mock(name="column/1", display_name="column1") - column_spec_mock2 = mock.Mock() - column_spec_mock2.configure_mock(name="column/2", display_name="column2") + + column_spec_1 = automl_v1beta1.ColumnSpec( + name="column/1", display_name="column1" + ) + column_spec_2 = automl_v1beta1.ColumnSpec( + name="column/2", display_name="column2" + ) client = self.tables_client( - { + client_attrs={ "list_table_specs.return_value": [table_spec_mock], - "list_column_specs.return_value": [ - column_spec_mock1, - column_spec_mock2, - ], + "list_column_specs.return_value": [column_spec_1, column_spec_2], "location_path.return_value": LOCATION_PATH, }, - {}, + prediction_client_attrs={}, ) client.create_model( "my_model", @@ -1060,15 +1256,17 @@ def test_create_model_exclude_columns(self): train_budget_milli_node_hours=1000, ) client.auto_ml_client.create_model.assert_called_with( - parent=LOCATION_PATH, - model={ - "display_name": "my_model", - "dataset_id": "my_dataset", - "tables_model_metadata": { - "train_budget_milli_node_hours": 1000, - "input_feature_column_specs": [column_spec_mock2], - }, - }, + request=automl_v1beta1.CreateModelRequest( + parent=LOCATION_PATH, + model=automl_v1beta1.Model( + display_name="my_model", + dataset_id="my_dataset", + tables_model_metadata=automl_v1beta1.TablesModelMetadata( + train_budget_milli_node_hours=1000, + input_feature_column_specs=[column_spec_2], + ), + ), + ) ) def test_create_model_invalid_hours_small(self): @@ -1125,7 +1323,9 @@ def test_predict_from_array(self): payload = data_items.ExamplePayload(row=row) client.prediction_client.predict.assert_called_with( - name="my_model", 
payload=payload, params=None + request=automl_v1beta1.PredictRequest( + name="my_model", payload=payload, params=None + ) ) def test_predict_from_dict(self): @@ -1149,7 +1349,9 @@ def test_predict_from_dict(self): payload = data_items.ExamplePayload(row=row) client.prediction_client.predict.assert_called_with( - name="my_model", payload=payload, params=None + request=automl_v1beta1.PredictRequest( + name="my_model", payload=payload, params=None + ) ) def test_predict_from_dict_with_feature_importance(self): @@ -1175,7 +1377,9 @@ def test_predict_from_dict_with_feature_importance(self): payload = data_items.ExamplePayload(row=row) client.prediction_client.predict.assert_called_with( - name="my_model", payload=payload, params={"feature_importance": "true"} + request=automl_v1beta1.PredictRequest( + name="my_model", payload=payload, params={"feature_importance": "true"} + ) ) def test_predict_from_dict_missing(self): @@ -1199,7 +1403,9 @@ def test_predict_from_dict_missing(self): payload = data_items.ExamplePayload(row=row) client.prediction_client.predict.assert_called_with( - name="my_model", payload=payload, params=None + request=automl_v1beta1.PredictRequest( + name="my_model", payload=payload, params=None + ) ) def test_predict_all_types(self): @@ -1282,7 +1488,9 @@ def test_predict_all_types(self): payload = data_items.ExamplePayload(row=row) client.prediction_client.predict.assert_called_with( - name="my_model", payload=payload, params=None + request=automl_v1beta1.PredictRequest( + name="my_model", payload=payload, params=None + ) ) def test_predict_from_array_missing(self): @@ -1316,9 +1524,11 @@ def test_batch_predict_pandas_dataframe(self): client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe) client.prediction_client.batch_predict.assert_called_with( - name="my_model", - input_config={"gcs_source": {"input_uris": ["gs://input"]}}, - output_config={"gcs_destination": {"output_uri_prefix": "gs://output"}}, + 
request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"gcs_source": {"input_uris": ["gs://input"]}}, + output_config={"gcs_destination": {"output_uri_prefix": "gs://output"}}, + ) ) def test_batch_predict_pandas_dataframe_init_gcs(self): @@ -1350,9 +1560,13 @@ def test_batch_predict_pandas_dataframe_init_gcs(self): client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe) client.prediction_client.batch_predict.assert_called_with( - name="my_model", - input_config={"gcs_source": {"input_uris": ["gs://input"]}}, - output_config={"gcs_destination": {"output_uri_prefix": "gs://output"}}, + request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"gcs_source": {"input_uris": ["gs://input"]}}, + output_config={ + "gcs_destination": {"output_uri_prefix": "gs://output"} + }, + ) ) def test_batch_predict_gcs(self): @@ -1363,9 +1577,11 @@ def test_batch_predict_gcs(self): gcs_output_uri_prefix="gs://output", ) client.prediction_client.batch_predict.assert_called_with( - name="my_model", - input_config={"gcs_source": {"input_uris": ["gs://input"]}}, - output_config={"gcs_destination": {"output_uri_prefix": "gs://output"}}, + request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"gcs_source": {"input_uris": ["gs://input"]}}, + output_config={"gcs_destination": {"output_uri_prefix": "gs://output"}}, + ) ) def test_batch_predict_bigquery(self): @@ -1376,9 +1592,11 @@ def test_batch_predict_bigquery(self): bigquery_output_uri="bq://output", ) client.prediction_client.batch_predict.assert_called_with( - name="my_model", - input_config={"bigquery_source": {"input_uri": "bq://input"}}, - output_config={"bigquery_destination": {"output_uri": "bq://output"}}, + request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"bigquery_source": {"input_uri": "bq://input"}}, + output_config={"bigquery_destination": {"output_uri": "bq://output"}}, + ) ) def test_batch_predict_mixed(self): 
@@ -1389,9 +1607,11 @@ def test_batch_predict_mixed(self): bigquery_output_uri="bq://output", ) client.prediction_client.batch_predict.assert_called_with( - name="my_model", - input_config={"gcs_source": {"input_uris": ["gs://input"]}}, - output_config={"bigquery_destination": {"output_uri": "bq://output"}}, + request=automl_v1beta1.BatchPredictRequest( + name="my_model", + input_config={"gcs_source": {"input_uris": ["gs://input"]}}, + output_config={"bigquery_destination": {"output_uri": "bq://output"}}, + ) ) def test_batch_predict_missing_input_gcs_uri(self):