From 5205339b99b22e221a74cfd3a5a441a42a1da5ee Mon Sep 17 00:00:00 2001 From: Ning Date: Mon, 26 Aug 2019 11:45:46 -0700 Subject: [PATCH] versioned the 0.13.0 tfx example (#1912) * versioned the 0.13.0 tfx example * use the taxi_utils from 0.13.1 * update tfx oss readme --- samples/core/tfx-oss/README.md | 23 +- samples/core/tfx-oss/TFX Example.ipynb | 207 ++++++++++++- samples/core/tfx-oss/utils/taxi_utils.py | 368 +++++++++++++++++++++++ 3 files changed, 574 insertions(+), 24 deletions(-) create mode 100644 samples/core/tfx-oss/utils/taxi_utils.py diff --git a/samples/core/tfx-oss/README.md b/samples/core/tfx-oss/README.md index de8716c0dca..053e9222bd0 100644 --- a/samples/core/tfx-oss/README.md +++ b/samples/core/tfx-oss/README.md @@ -18,34 +18,28 @@ conda create -n tfx-kfp pip python=3.5.3 then activate the environment. -Install TensorFlow, TFX and Kubeflow Pipelines SDK +Install TFX and Kubeflow Pipelines SDK ``` -pip install tensorflow --upgrade -pip install tfx +pip3 install tfx==0.13.0 --upgrade pip install kfp --upgrade ``` -Clone TFX github repo -``` -git clone https://github.com/tensorflow/tfx -``` - Upload the utility code to your storage bucket. You can modify this code if needed for a different dataset. ``` -gsutil cp tfx/tfx/examples/chicago_taxi_pipeline/taxi_utils.py gs://my-bucket// +gsutil cp utils/taxi_utils.py gs://my-bucket// ``` If gsutil does not work, try `tensorflow.gfile`: ``` from tensorflow import gfile -gfile.Copy('tfx/tfx/examples/chicago_taxi_pipeline/taxi_utils.py', 'gs:////taxi_utils.py') +gfile.Copy('utils/taxi_utils.py', 'gs:////taxi_utils.py') ``` ## Configure the TFX Pipeline -Modify the pipeline configuration file at +Modify the pipeline configurations at ``` -tfx/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow.py +TFX Example.ipynb ``` Configure - Set `_input_bucket` to the GCS directory where you've copied taxi_utils.py. I.e. gs://// @@ -54,8 +48,7 @@ Configure - The original BigQuery dataset has 100M rows, which can take time to process. Modify the selection criteria (% of records) to run a sample test. ## Compile and run the pipeline -``` -python tfx/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow.py -``` +Run the notebook. + This will generate a file named chicago_taxi_pipeline_kubeflow.tar.gz Upload this file to the Pipelines Cluster and create a run. 
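Editor's note: the upload-and-run step described above can also be scripted with the `kfp` SDK instead of the Pipelines UI. The sketch below is illustrative only and is not part of this patch; the endpoint, pipeline name, and experiment name are placeholders to replace with your own values, and the exact client methods can vary slightly across `kfp` SDK versions.

```python
# Sketch: upload the compiled package and start a run via the kfp SDK.
# '<pipelines-endpoint>' and the names below are placeholders.
import kfp

client = kfp.Client(host='<pipelines-endpoint>')

# Register the compiled pipeline so it appears in the Pipelines UI.
client.upload_pipeline('chicago_taxi_pipeline_kubeflow.tar.gz',
                       pipeline_name='chicago-taxi-tfx')

# Create an experiment and launch a run from the same package.
experiment = client.create_experiment('tfx-oss-sample')
client.run_pipeline(experiment.id,
                    job_name='chicago-taxi-run',
                    pipeline_package_path='chicago_taxi_pipeline_kubeflow.tar.gz')
```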
diff --git a/samples/core/tfx-oss/TFX Example.ipynb b/samples/core/tfx-oss/TFX Example.ipynb index 180341e167c..f98450ade3a 100644 --- a/samples/core/tfx-oss/TFX Example.ipynb +++ b/samples/core/tfx-oss/TFX Example.ipynb @@ -17,8 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "!pip3 install https://storage.googleapis.com/ml-pipeline/tfx/tfx-0.12.0rc0-py2.py3-none-any.whl \n", - "!pip3 install kfp --upgrade\n" + "!pip3 install tfx==0.13.0 --upgrade\n", + "!pip3 install kfp --upgrade" ] }, { @@ -39,10 +39,21 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "parameters" + ] + }, "outputs": [], "source": [ - "!git clone https://github.com/tensorflow/tfx" + "# Directory and data locations (uses Google Cloud Storage).\n", + "import os\n", + "_input_bucket = ''\n", + "_output_bucket = ''\n", + "_pipeline_root = os.path.join(_output_bucket, 'tfx')\n", + "\n", + "# Google Cloud Platform project id to use when deploying this pipeline.\n", + "_project_id = ''" ] }, { @@ -53,7 +64,7 @@ "source": [ "# copy the trainer code to a storage bucket as the TFX pipeline will need that code file in GCS\n", "from tensorflow import gfile\n", - "gfile.Copy('tfx/examples/chicago_taxi_pipeline/taxi_utils.py', 'gs:////taxi_utils.py')" + "gfile.Copy('utils/taxi_utils.py', _input_bucket + '/taxi_utils.py')" ] }, { @@ -77,11 +88,180 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "%load tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow.py" + "\"\"\"Chicago Taxi example using TFX DSL on Kubeflow.\"\"\"\n", + "\n", + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import os\n", + "from tfx.components.evaluator.component import Evaluator\n", + "from tfx.components.example_gen.big_query_example_gen.component import BigQueryExampleGen\n", + "from tfx.components.example_validator.component import ExampleValidator\n", + "from tfx.components.model_validator.component import ModelValidator\n", + "from tfx.components.pusher.component import Pusher\n", + "from tfx.components.schema_gen.component import SchemaGen\n", + "from tfx.components.statistics_gen.component import StatisticsGen\n", + "from tfx.components.trainer.component import Trainer\n", + "from tfx.components.transform.component import Transform\n", + "from tfx.orchestration.kubeflow.runner import KubeflowRunner\n", + "from tfx.orchestration.pipeline import PipelineDecorator\n", + "from tfx.proto import evaluator_pb2\n", + "from tfx.proto import pusher_pb2\n", + "from tfx.proto import trainer_pb2\n", + "\n", + "# Python module file to inject customized logic into the TFX components. The\n", + "# Transform and Trainer both require user-defined functions to run successfully.\n", + "# Copy this from the current directory to a GCS bucket and update the location\n", + "# below.\n", + "_taxi_utils = os.path.join(_input_bucket, 'taxi_utils.py')\n", + "\n", + "# Path which can be listened to by the model server. 
Pusher will output the\n", + "# trained model here.\n", + "_serving_model_dir = os.path.join(_output_bucket, 'serving_model/taxi_bigquery')\n", + "\n", + "# Region to use for Dataflow jobs and CMLE training.\n", + "# Dataflow: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints\n", + "# CMLE: https://cloud.google.com/ml-engine/docs/tensorflow/regions\n", + "_gcp_region = 'us-central1'\n", + "\n", + "# A dict which contains the training job parameters to be passed to Google\n", + "# Cloud ML Engine. For the full set of parameters supported by Google Cloud ML\n", + "# Engine, refer to\n", + "# https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job\n", + "_cmle_training_args = {\n", + " 'pythonModule': None, # Will be populated by TFX\n", + " 'args': None, # Will be populated by TFX\n", + " 'region': _gcp_region,\n", + " 'jobDir': os.path.join(_output_bucket, 'tmp'),\n", + " 'runtimeVersion': '1.12',\n", + " 'pythonVersion': '2.7',\n", + " 'project': _project_id,\n", + "}\n", + "\n", + "# A dict which contains the serving job parameters to be passed to Google\n", + "# Cloud ML Engine. For the full set of parameters supported by Google Cloud ML\n", + "# Engine, refer to\n", + "# https://cloud.google.com/ml-engine/reference/rest/v1/projects.models\n", + "_cmle_serving_args = {\n", + " 'model_name': 'chicago_taxi',\n", + " 'project_id': _project_id,\n", + " 'runtime_version': '1.12',\n", + "}\n", + "\n", + "# The rate at which to sample rows from the Chicago Taxi dataset using BigQuery.\n", + "# The full taxi dataset is > 120M record. In the interest of resource\n", + "# savings and time, we've set the default for this example to be much smaller.\n", + "# Feel free to crank it up and process the full dataset!\n", + "_query_sample_rate = 0.001 # Generate a 0.1% random sample.\n", + "\n", + "\n", + "# TODO(zhitaoli): Remove PipelineDecorator after 0.13.0.\n", + "@PipelineDecorator(\n", + " pipeline_name='chicago_taxi_pipeline_kubeflow',\n", + " log_root='/var/tmp/tfx/logs',\n", + " pipeline_root=_pipeline_root,\n", + " additional_pipeline_args={\n", + " 'beam_pipeline_args': [\n", + " '--runner=DataflowRunner',\n", + " '--experiments=shuffle_mode=auto',\n", + " '--project=' + _project_id,\n", + " '--temp_location=' + os.path.join(_output_bucket, 'tmp'),\n", + " '--region=' + _gcp_region,\n", + " ],\n", + " # Optional args:\n", + " # 'tfx_image': custom docker image to use for components. 
This is needed\n", + " # if TFX package is not installed from an RC or released version.\n", + " })\n", + "def _create_pipeline():\n", + " \"\"\"Implements the chicago taxi pipeline with TFX.\"\"\"\n", + "\n", + " query = \"\"\"\n", + " SELECT\n", + " pickup_community_area,\n", + " fare,\n", + " EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,\n", + " EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,\n", + " EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,\n", + " UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,\n", + " pickup_latitude,\n", + " pickup_longitude,\n", + " dropoff_latitude,\n", + " dropoff_longitude,\n", + " trip_miles,\n", + " pickup_census_tract,\n", + " dropoff_census_tract,\n", + " payment_type,\n", + " company,\n", + " trip_seconds,\n", + " dropoff_community_area,\n", + " tips\n", + " FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`\n", + " WHERE RAND() < {}\"\"\".format(_query_sample_rate)\n", + "\n", + " # Brings data into the pipeline or otherwise joins/converts training data.\n", + " example_gen = BigQueryExampleGen(query=query)\n", + "\n", + " # Computes statistics over data for visualization and example validation.\n", + " statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)\n", + "\n", + " # Generates schema based on statistics files.\n", + " infer_schema = SchemaGen(stats=statistics_gen.outputs.output)\n", + "\n", + " # Performs anomaly detection based on statistics and data schema.\n", + " validate_stats = ExampleValidator(\n", + " stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)\n", + "\n", + " # Performs transformations and feature engineering in training and serving.\n", + " transform = Transform(\n", + " input_data=example_gen.outputs.examples,\n", + " schema=infer_schema.outputs.output,\n", + " module_file=_taxi_utils)\n", + "\n", + " # Uses user-provided Python function that implements a model using TF-Learn.\n", + " trainer = Trainer(\n", + " module_file=_taxi_utils,\n", + " transformed_examples=transform.outputs.transformed_examples,\n", + " schema=infer_schema.outputs.output,\n", + " transform_output=transform.outputs.transform_output,\n", + " train_args=trainer_pb2.TrainArgs(num_steps=10000),\n", + " eval_args=trainer_pb2.EvalArgs(num_steps=5000),\n", + " custom_config={'cmle_training_args': _cmle_training_args})\n", + "\n", + " # Uses TFMA to compute a evaluation statistics over features of a model.\n", + " model_analyzer = Evaluator(\n", + " examples=example_gen.outputs.examples,\n", + " model_exports=trainer.outputs.output,\n", + " feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[\n", + " evaluator_pb2.SingleSlicingSpec(\n", + " column_for_slicing=['trip_start_hour'])\n", + " ]))\n", + "\n", + " # Performs quality validation of a candidate model (compared to a baseline).\n", + " model_validator = ModelValidator(\n", + " examples=example_gen.outputs.examples, model=trainer.outputs.output)\n", + "\n", + " # Checks whether the model passed the validation steps and pushes the model\n", + " # to a file destination if check passed.\n", + " pusher = Pusher(\n", + " model_export=trainer.outputs.output,\n", + " model_blessing=model_validator.outputs.blessing,\n", + " custom_config={'cmle_serving_args': _cmle_serving_args},\n", + " push_destination=pusher_pb2.PushDestination(\n", + " filesystem=pusher_pb2.PushDestination.Filesystem(\n", + " base_directory=_serving_model_dir)))\n", + "\n", + " return [\n", + " example_gen, statistics_gen, 
infer_schema, validate_stats, transform,\n", + " trainer, model_analyzer, model_validator, pusher\n", + " ]\n", + "\n", + "\n", + "pipeline = KubeflowRunner().run(_create_pipeline())" ] }, { @@ -188,9 +368,18 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.4" + "version": "3.6.7" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "source": [], + "metadata": { + "collapsed": false + } + } } }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/samples/core/tfx-oss/utils/taxi_utils.py b/samples/core/tfx-oss/utils/taxi_utils.py new file mode 100644 index 00000000000..6dca172e6ca --- /dev/null +++ b/samples/core/tfx-oss/utils/taxi_utils.py @@ -0,0 +1,368 @@ +# Copyright 2019 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python source file include taxi pipeline functions and necesasry utils. + +For a TFX pipeline to successfully run, a preprocessing_fn and a +_build_estimator function needs to be provided. This file contains both. + +This file is equivalent to examples/chicago_taxi/trainer/model.py and +examples/chicago_taxi/preprocess.py. +""" + +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf +import tensorflow_model_analysis as tfma +import tensorflow_transform as tft +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.saved import saved_transform_io +from tensorflow_transform.tf_metadata import metadata_io +from tensorflow_transform.tf_metadata import schema_utils + +# Categorical features are assumed to each have a maximum value in the dataset. +_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12] + +_CATEGORICAL_FEATURE_KEYS = [ + 'trip_start_hour', 'trip_start_day', 'trip_start_month', + 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', + 'dropoff_community_area' +] + +_DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] + +# Number of buckets used by tf.transform for encoding each feature. +_FEATURE_BUCKET_COUNT = 10 + +_BUCKET_FEATURE_KEYS = [ + 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', + 'dropoff_longitude' +] + +# Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform +_VOCAB_SIZE = 1000 + +# Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are hashed. 
+_OOV_SIZE = 10 + +_VOCAB_FEATURE_KEYS = [ + 'payment_type', + 'company', +] + +# Keys +_LABEL_KEY = 'tips' +_FARE_KEY = 'fare' + + +def _transformed_name(key): + return key + '_xf' + + +def _transformed_names(keys): + return [_transformed_name(key) for key in keys] + + +# Tf.Transform considers these features as "raw" +def _get_raw_feature_spec(schema): + return schema_utils.schema_as_feature_spec(schema).feature_spec + + +def _gzip_reader_fn(): + """Small utility returning a record reader that can read gzip'ed files.""" + return tf.TFRecordReader( + options=tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.GZIP)) + + +def _fill_in_missing(x): + """Replace missing values in a SparseTensor. + + Fills in missing values of `x` with '' or 0, and converts to a dense tensor. + + Args: + x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 + in the second dimension. + + Returns: + A rank 1 tensor where missing values of `x` have been filled in. + """ + default_value = '' if x.dtype == tf.string else 0 + return tf.squeeze( + tf.sparse.to_dense( + tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), + default_value), + axis=1) + + +def preprocessing_fn(inputs): + """tf.transform's callback function for preprocessing inputs. + + Args: + inputs: map from feature keys to raw not-yet-transformed features. + + Returns: + Map from string feature key to transformed feature operations. + """ + outputs = {} + for key in _DENSE_FLOAT_FEATURE_KEYS: + # Preserve this feature as a dense float, setting nan's to the mean. + outputs[_transformed_name(key)] = tft.scale_to_z_score( + _fill_in_missing(inputs[key])) + + for key in _VOCAB_FEATURE_KEYS: + # Build a vocabulary for this feature. + outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( + _fill_in_missing(inputs[key]), + top_k=_VOCAB_SIZE, + num_oov_buckets=_OOV_SIZE) + + for key in _BUCKET_FEATURE_KEYS: + outputs[_transformed_name(key)] = tft.bucketize( + _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT) + + for key in _CATEGORICAL_FEATURE_KEYS: + outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) + + # Was this passenger a big tipper? + taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) + tips = _fill_in_missing(inputs[_LABEL_KEY]) + outputs[_transformed_name(_LABEL_KEY)] = tf.where( + tf.is_nan(taxi_fare), + tf.cast(tf.zeros_like(taxi_fare), tf.int64), + # Test if the tip was > 20% of the fare. + tf.cast( + tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64)) + + return outputs + + +def _build_estimator(config, hidden_units=None, warm_start_from=None): + """Build an estimator for predicting the tipping behavior of taxi riders. + + Args: + config: tf.contrib.learn.RunConfig defining the runtime environment for the + estimator (including model_dir). + hidden_units: [int], the layer sizes of the DNN (input layer first) + warm_start_from: Optional directory to warm start from. + + Returns: + A dict of the following: + - estimator: The estimator that will be used for training and eval. + - train_spec: Spec for training. + - eval_spec: Spec for eval. + - eval_input_receiver_fn: Input function for eval. 
+ """ + real_valued_columns = [ + tf.feature_column.numeric_column(key, shape=()) + for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) + ] + categorical_columns = [ + tf.feature_column.categorical_column_with_identity( + key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) + for key in _transformed_names(_VOCAB_FEATURE_KEYS) + ] + categorical_columns += [ + tf.feature_column.categorical_column_with_identity( + key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) + for key in _transformed_names(_BUCKET_FEATURE_KEYS) + ] + categorical_columns += [ + tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension + key, + num_buckets=num_buckets, + default_value=0) for key, num_buckets in zip( + _transformed_names(_CATEGORICAL_FEATURE_KEYS), + _MAX_CATEGORICAL_FEATURE_VALUES) + ] + return tf.estimator.DNNLinearCombinedClassifier( + config=config, + linear_feature_columns=categorical_columns, + dnn_feature_columns=real_valued_columns, + dnn_hidden_units=hidden_units or [100, 70, 50, 25], + warm_start_from=warm_start_from) + + +def _example_serving_receiver_fn(transform_output, schema): + """Build the serving in inputs. + + Args: + transform_output: directory in which the tf-transform model was written + during the preprocessing step. + schema: the schema of the input data. + + Returns: + Tensorflow graph which parses examples, applying tf-transform to them. + """ + raw_feature_spec = _get_raw_feature_spec(schema) + raw_feature_spec.pop(_LABEL_KEY) + + raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + raw_feature_spec, default_batch_size=None) + serving_input_receiver = raw_input_fn() + + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform( + os.path.join(transform_output, transform_fn_io.TRANSFORM_FN_DIR), + serving_input_receiver.features)) + + return tf.estimator.export.ServingInputReceiver( + transformed_features, serving_input_receiver.receiver_tensors) + + +def _eval_input_receiver_fn(transform_output, schema): + """Build everything needed for the tf-model-analysis to run the model. + + Args: + transform_output: directory in which the tf-transform model was written + during the preprocessing step. + schema: the schema of the input data. + + Returns: + EvalInputReceiver function, which contains: + - Tensorflow graph which parses raw untransformed features, applies the + tf-transform preprocessing operators. + - Set of raw, untransformed features. + - Label against which predictions will be compared. + """ + # Notice that the inputs are raw features, not transformed features here. + raw_feature_spec = _get_raw_feature_spec(schema) + + serialized_tf_example = tf.placeholder( + dtype=tf.string, shape=[None], name='input_example_tensor') + + # Add a parse_example operator to the tensorflow graph, which will parse + # raw, untransformed, tf examples. + features = tf.parse_example(serialized_tf_example, raw_feature_spec) + + # Now that we have our raw examples, process them through the tf-transform + # function computed during the preprocessing step. + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform( + os.path.join(transform_output, transform_fn_io.TRANSFORM_FN_DIR), + features)) + + # The key name MUST be 'examples'. + receiver_tensors = {'examples': serialized_tf_example} + + # NOTE: Model is driven by transformed features (since training works on the + # materialized output of TFT, but slicing will happen on raw features. 
+ features.update(transformed_features) + + return tfma.export.EvalInputReceiver( + features=features, + receiver_tensors=receiver_tensors, + labels=transformed_features[_transformed_name(_LABEL_KEY)]) + + +def _input_fn(filenames, transform_output, batch_size=200): + """Generates features and labels for training or evaluation. + + Args: + filenames: [str] list of CSV files to read data from. + transform_output: directory in which the tf-transform model was written + during the preprocessing step. + batch_size: int First dimension size of the Tensors returned by input_fn + + Returns: + A (features, indices) tuple where features is a dictionary of + Tensors, and indices is a single Tensor of label indices. + """ + metadata_dir = os.path.join(transform_output, + transform_fn_io.TRANSFORMED_METADATA_DIR) + transformed_metadata = metadata_io.read_metadata(metadata_dir) + transformed_feature_spec = transformed_metadata.schema.as_feature_spec() + + transformed_features = tf.contrib.learn.io.read_batch_features( + filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn) + + # We pop the label because we do not want to use it as a feature while we're + # training. + return transformed_features, transformed_features.pop( + _transformed_name(_LABEL_KEY)) + + +# TFX will call this function +def trainer_fn(hparams, schema): + """Build the estimator using the high level API. + + Args: + hparams: Holds hyperparameters used to train the model as name/value pairs. + schema: Holds the schema of the training examples. + + Returns: + A dict of the following: + - estimator: The estimator that will be used for training and eval. + - train_spec: Spec for training. + - eval_spec: Spec for eval. + - eval_input_receiver_fn: Input function for eval. + """ + # Number of nodes in the first layer of the DNN + first_dnn_layer_size = 100 + num_dnn_layers = 4 + dnn_decay_factor = 0.7 + + train_batch_size = 40 + eval_batch_size = 40 + + train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda + hparams.train_files, + hparams.transform_output, + batch_size=train_batch_size) + + eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda + hparams.eval_files, + hparams.transform_output, + batch_size=eval_batch_size) + + train_spec = tf.estimator.TrainSpec( # pylint: disable=g-long-lambda + train_input_fn, + max_steps=hparams.train_steps) + + serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda + hparams.transform_output, schema) + + exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn) + eval_spec = tf.estimator.EvalSpec( + eval_input_fn, + steps=hparams.eval_steps, + exporters=[exporter], + name='chicago-taxi-eval') + + run_config = tf.estimator.RunConfig( + save_checkpoints_steps=999, keep_checkpoint_max=1) + + run_config = run_config.replace(model_dir=hparams.serving_model_dir) + + estimator = _build_estimator( + # Construct layers sizes with exponetial decay + hidden_units=[ + max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) + for i in range(num_dnn_layers) + ], + config=run_config, + warm_start_from=hparams.warm_start_from) + + # Create an input receiver for TFMA processing + receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda + hparams.transform_output, schema) + + return { + 'estimator': estimator, + 'train_spec': train_spec, + 'eval_spec': eval_spec, + 'eval_input_receiver_fn': receiver_fn + }
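Editor's note: for orientation, the dict returned by `trainer_fn` above is what the TFX 0.13 Trainer drives training and TFMA export from. The following is a hedged, illustrative sketch of that flow rather than the executor's actual code: `types.SimpleNamespace` stands in for the HParams object TFX passes in, and the paths and `schema` are placeholders that the real pipeline obtains from the Transform and SchemaGen outputs.

```python
# Illustrative sketch of how the trainer_fn above is exercised; not TFX executor code.
import types

import tensorflow as tf
import tensorflow_model_analysis as tfma
from tensorflow_metadata.proto.v0 import schema_pb2

import taxi_utils  # the module added by this PR (assumes utils/ is on PYTHONPATH)

# Stand-in for the HParams object the Trainer component passes to trainer_fn.
hparams = types.SimpleNamespace(
    train_files='<gcs-path>/train/*.gz',              # placeholder
    eval_files='<gcs-path>/eval/*.gz',                # placeholder
    transform_output='<gcs-path>/transform_output',   # placeholder
    serving_model_dir='<gcs-path>/serving_model',     # placeholder
    train_steps=10000,
    eval_steps=5000,
    warm_start_from=None)

# In the pipeline this schema comes from SchemaGen; an empty proto is a placeholder.
schema = schema_pb2.Schema()

training_spec = taxi_utils.trainer_fn(hparams, schema)

# Train and evaluate using the estimator and specs built by trainer_fn.
tf.estimator.train_and_evaluate(training_spec['estimator'],
                                training_spec['train_spec'],
                                training_spec['eval_spec'])

# Export the EvalSavedModel that the Evaluator component later analyzes with TFMA.
tfma.export.export_eval_savedmodel(
    estimator=training_spec['estimator'],
    export_dir_base='<gcs-path>/eval_model',           # placeholder
    eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
```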