Skip to content

Commit

Permalink
[AIR] Add default prefetch to tf dataset in AIR examples (ray-project…
Browse files Browse the repository at this point in the history
…#28306)

We should recommend better practice in our tf examples.

Signed-off-by: Jiao Dong jiaodong@anyscale.com
Signed-off-by: ilee300a <ilee300@anyscale.com>
  • Loading branch information
jiaodong authored and ilee300a committed Sep 12, 2022
1 parent aff6344 commit 0492a6b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
7 changes: 2 additions & 5 deletions doc/source/ray-air/examples/tfx_tabular_train_to_serve.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,7 @@
" # This will make sure that the training workers will get their own\n",
" # share of batch to work on.\n",
" # See `ray.train.tensorflow.prepare_dataset_shard` for more information.\n",
" tf_dataset = to_tf_dataset(\n",
" dataset=dataset_shard,\n",
" batch_size=BATCH_SIZE,\n",
" )\n",
" tf_dataset = to_tf_dataset(dataset=dataset_shard, batch_size=BATCH_SIZE)\n",
"\n",
" model.fit(tf_dataset, verbose=0)\n",
" # This saves checkpoint in a way that can be used by Ray Serve coherently.\n",
Expand Down Expand Up @@ -969,7 +966,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.7.10"
},
"vscode": {
"interpreter": {
Expand Down
12 changes: 7 additions & 5 deletions python/ray/train/tensorflow/train_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@

@PublicAPI(stability="beta")
def prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset):
"""A utility function that disables Tensorflow autosharding.
"""A utility function that overrides default config for Tensorflow Dataset.
This should be used on a TensorFlow ``Dataset`` created by calling
``iter_tf_batches()`` on a ``ray.data.Dataset`` returned by
``ray.train.get_dataset_shard()`` since the dataset has already been sharded across
the workers.
``ray.train.get_dataset_shard()`` since the dataset has already been sharded
across the workers.
Args:
tf_dataset_shard (tf.data.Dataset): A TensorFlow Dataset.
Returns:
A TensorFlow Dataset with autosharding turned off.
A TensorFlow Dataset with:
- autosharding turned off
- prefetching turned on with autotune enabled
"""
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF
)
return tf_dataset_shard.with_options(options)
return tf_dataset_shard.with_options(options).prefetch(tf.data.AUTOTUNE)

0 comments on commit 0492a6b

Please sign in to comment.