Skip to content


versioned the 0.13.0 tfx example (#1912)
Browse files Browse the repository at this point in the history
* versioned the 0.13.0 tfx example
* use the taxi_utils from 0.13.1
* update tfx oss readme
  • Loading branch information
gaoning777 authored Aug 26, 2019
1 parent 27d6742 commit 5205339
Show file tree
Hide file tree
Showing 3 changed files with 574 additions and 24 deletions.
23 changes: 8 additions & 15 deletions samples/core/tfx-oss/
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,28 @@ conda create -n tfx-kfp pip python=3.5.3
then activate the environment.

Install TensorFlow, TFX and Kubeflow Pipelines SDK
Install TFX and Kubeflow Pipelines SDK
pip install tensorflow --upgrade
pip install tfx
pip3 install tfx==0.13.0 --upgrade
pip install kfp --upgrade

Clone TFX github repo
git clone

Upload the utility code to your storage bucket. You can modify this code if needed for a different dataset.
gsutil cp tfx/tfx/examples/chicago_taxi_pipeline/ gs://my-bucket/<path>/
gsutil cp utils/ gs://my-bucket/<path>/

If gsutil does not work, try `tensorflow.gfile`:
from tensorflow import gfile
gfile.Copy('tfx/tfx/examples/chicago_taxi_pipeline/', 'gs://<my bucket>/<path>/')
gfile.Copy('utils/', 'gs://<my bucket>/<path>/')

## Configure the TFX Pipeline

Modify the pipeline configuration file at
Modify the pipeline configurations at
TFX Example.ipynb
- Set `_input_bucket` to the GCS directory where you've copied the utility code, i.e. gs://<my bucket>/<path>/
Expand All @@ -54,8 +48,7 @@ Configure
- The original BigQuery dataset has 100M rows, which can take time to process. Modify the selection criteria (% of records) to run a sample test.

## Compile and run the pipeline
python tfx/tfx/examples/chicago_taxi_pipeline/
Run the notebook.

This will generate a file named chicago_taxi_pipeline_kubeflow.tar.gz.
Upload this file to the Pipelines Cluster and create a run.
207 changes: 198 additions & 9 deletions samples/core/tfx-oss/TFX Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip3 install \n",
"!pip3 install kfp --upgrade\n"
"!pip3 install tfx==0.13.0 --upgrade\n",
"!pip3 install kfp --upgrade"
Expand All @@ -39,10 +39,21 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": [
"outputs": [],
"source": [
"!git clone"
"# Directory and data locations (uses Google Cloud Storage).\n",
"import os\n",
"_input_bucket = '<your gcs bucket>'\n",
"_output_bucket = '<your gcs bucket>'\n",
"_pipeline_root = os.path.join(_output_bucket, 'tfx')\n",
"# Google Cloud Platform project id to use when deploying this pipeline.\n",
"_project_id = '<your project id>'"
Expand All @@ -53,7 +64,7 @@
"source": [
"# copy the trainer code to a storage bucket as the TFX pipeline will need that code file in GCS\n",
"from tensorflow import gfile\n",
"gfile.Copy('tfx/examples/chicago_taxi_pipeline/', 'gs://<my bucket>/<path>/')"
"gfile.Copy('utils/', _input_bucket + '/')"
Expand All @@ -77,11 +88,180 @@
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load tfx/examples/chicago_taxi_pipeline/"
"\"\"\"Chicago Taxi example using TFX DSL on Kubeflow.\"\"\"\n",
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"import os\n",
"from tfx.components.evaluator.component import Evaluator\n",
"from tfx.components.example_gen.big_query_example_gen.component import BigQueryExampleGen\n",
"from tfx.components.example_validator.component import ExampleValidator\n",
"from tfx.components.model_validator.component import ModelValidator\n",
"from tfx.components.pusher.component import Pusher\n",
"from tfx.components.schema_gen.component import SchemaGen\n",
"from tfx.components.statistics_gen.component import StatisticsGen\n",
"from tfx.components.trainer.component import Trainer\n",
"from tfx.components.transform.component import Transform\n",
"from tfx.orchestration.kubeflow.runner import KubeflowRunner\n",
"from tfx.orchestration.pipeline import PipelineDecorator\n",
"from tfx.proto import evaluator_pb2\n",
"from tfx.proto import pusher_pb2\n",
"from tfx.proto import trainer_pb2\n",
"# Python module file to inject customized logic into the TFX components. The\n",
"# Transform and Trainer both require user-defined functions to run successfully.\n",
"# Copy this from the current directory to a GCS bucket and update the location\n",
"# below.\n",
"_taxi_utils = os.path.join(_input_bucket, '')\n",
"# Path which can be listened to by the model server. Pusher will output the\n",
"# trained model here.\n",
"_serving_model_dir = os.path.join(_output_bucket, 'serving_model/taxi_bigquery')\n",
"# Region to use for Dataflow jobs and CMLE training.\n",
"# Dataflow:\n",
"# CMLE:\n",
"_gcp_region = 'us-central1'\n",
"# A dict which contains the training job parameters to be passed to Google\n",
"# Cloud ML Engine. For the full set of parameters supported by Google Cloud ML\n",
"# Engine, refer to\n",
"_cmle_training_args = {\n",
" 'pythonModule': None, # Will be populated by TFX\n",
" 'args': None, # Will be populated by TFX\n",
" 'region': _gcp_region,\n",
" 'jobDir': os.path.join(_output_bucket, 'tmp'),\n",
" 'runtimeVersion': '1.12',\n",
" 'pythonVersion': '2.7',\n",
" 'project': _project_id,\n",
"# A dict which contains the serving job parameters to be passed to Google\n",
"# Cloud ML Engine. For the full set of parameters supported by Google Cloud ML\n",
"# Engine, refer to\n",
"_cmle_serving_args = {\n",
" 'model_name': 'chicago_taxi',\n",
" 'project_id': _project_id,\n",
" 'runtime_version': '1.12',\n",
"# The rate at which to sample rows from the Chicago Taxi dataset using BigQuery.\n",
"# The full taxi dataset is > 120M record. In the interest of resource\n",
"# savings and time, we've set the default for this example to be much smaller.\n",
"# Feel free to crank it up and process the full dataset!\n",
"_query_sample_rate = 0.001 # Generate a 0.1% random sample.\n",
"# TODO(zhitaoli): Remove PipelineDecorator after 0.13.0.\n",
" pipeline_name='chicago_taxi_pipeline_kubeflow',\n",
" log_root='/var/tmp/tfx/logs',\n",
" pipeline_root=_pipeline_root,\n",
" additional_pipeline_args={\n",
" 'beam_pipeline_args': [\n",
" '--runner=DataflowRunner',\n",
" '--experiments=shuffle_mode=auto',\n",
" '--project=' + _project_id,\n",
" '--temp_location=' + os.path.join(_output_bucket, 'tmp'),\n",
" '--region=' + _gcp_region,\n",
" ],\n",
" # Optional args:\n",
" # 'tfx_image': custom docker image to use for components. This is needed\n",
" # if TFX package is not installed from an RC or released version.\n",
" })\n",
"def _create_pipeline():\n",
" \"\"\"Implements the chicago taxi pipeline with TFX.\"\"\"\n",
" query = \"\"\"\n",
" SELECT\n",
" pickup_community_area,\n",
" fare,\n",
" EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,\n",
" EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,\n",
" EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,\n",
" UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,\n",
" pickup_latitude,\n",
" pickup_longitude,\n",
" dropoff_latitude,\n",
" dropoff_longitude,\n",
" trip_miles,\n",
" pickup_census_tract,\n",
" dropoff_census_tract,\n",
" payment_type,\n",
" company,\n",
" trip_seconds,\n",
" dropoff_community_area,\n",
" tips\n",
" FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`\n",
" WHERE RAND() < {}\"\"\".format(_query_sample_rate)\n",
" # Brings data into the pipeline or otherwise joins/converts training data.\n",
" example_gen = BigQueryExampleGen(query=query)\n",
" # Computes statistics over data for visualization and example validation.\n",
" statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)\n",
" # Generates schema based on statistics files.\n",
" infer_schema = SchemaGen(stats=statistics_gen.outputs.output)\n",
" # Performs anomaly detection based on statistics and data schema.\n",
" validate_stats = ExampleValidator(\n",
" stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)\n",
" # Performs transformations and feature engineering in training and serving.\n",
" transform = Transform(\n",
" input_data=example_gen.outputs.examples,\n",
" schema=infer_schema.outputs.output,\n",
" module_file=_taxi_utils)\n",
" # Uses user-provided Python function that implements a model using TF-Learn.\n",
" trainer = Trainer(\n",
" module_file=_taxi_utils,\n",
" transformed_examples=transform.outputs.transformed_examples,\n",
" schema=infer_schema.outputs.output,\n",
" transform_output=transform.outputs.transform_output,\n",
" train_args=trainer_pb2.TrainArgs(num_steps=10000),\n",
" eval_args=trainer_pb2.EvalArgs(num_steps=5000),\n",
" custom_config={'cmle_training_args': _cmle_training_args})\n",
" # Uses TFMA to compute a evaluation statistics over features of a model.\n",
" model_analyzer = Evaluator(\n",
" examples=example_gen.outputs.examples,\n",
" model_exports=trainer.outputs.output,\n",
" feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[\n",
" evaluator_pb2.SingleSlicingSpec(\n",
" column_for_slicing=['trip_start_hour'])\n",
" ]))\n",
" # Performs quality validation of a candidate model (compared to a baseline).\n",
" model_validator = ModelValidator(\n",
" examples=example_gen.outputs.examples, model=trainer.outputs.output)\n",
" # Checks whether the model passed the validation steps and pushes the model\n",
" # to a file destination if check passed.\n",
" pusher = Pusher(\n",
" model_export=trainer.outputs.output,\n",
" model_blessing=model_validator.outputs.blessing,\n",
" custom_config={'cmle_serving_args': _cmle_serving_args},\n",
" push_destination=pusher_pb2.PushDestination(\n",
" filesystem=pusher_pb2.PushDestination.Filesystem(\n",
" base_directory=_serving_model_dir)))\n",
" return [\n",
" example_gen, statistics_gen, infer_schema, validate_stats, transform,\n",
" trainer, model_analyzer, model_validator, pusher\n",
" ]\n",
"pipeline = KubeflowRunner().run(_create_pipeline())"
Expand Down Expand Up @@ -188,9 +368,18 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"version": "3.6.7"
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"source": [],
"metadata": {
"collapsed": false
"nbformat": 4,
"nbformat_minor": 2

0 comments on commit 5205339

Please sign in to comment.