Skip to content

Commit

Permalink
[AIR] Fix TensorflowTrainer docstring example (ray-project#29463)
Browse files Browse the repository at this point in the history
You can't run the code, because we don't provide an argument to train_loop_per_worker.

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
  • Loading branch information
bveeramani authored and tamohannes committed Jan 25, 2023
1 parent bcd01bb commit 2bdb2f8
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions python/ray/train/tensorflow/tensorflow_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,38 +85,36 @@ def train_loop_per_worker():
Example:
.. code-block:: python
.. testcode::
import tensorflow as tf
import ray
from ray.air import session, Checkpoint
from ray.train.tensorflow import TensorflowTrainer
from ray.air.config import ScalingConfig
input_size = 1
from ray.train.tensorflow import TensorflowTrainer
def build_model():
# toy neural network : 1-layer
return tf.keras.Sequential(
[tf.keras.layers.Dense(
1, activation="linear", input_shape=(input_size,))]
1, activation="linear", input_shape=(1,))]
)
def train_loop_for_worker(config):
def train_loop_per_worker(config):
dataset_shard = session.get_dataset_shard("train")
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model = build_model()
model.compile(
optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
tf_dataset = dataset_shard.to_tf(
feature_columns="x",
label_columns="y",
batch_size=1
)
for epoch in range(config["num_epochs"]):
tf_dataset = dataset_shard.to_tf(
feature_columns="x",
label_columns="y",
batch_size=1
)
model.fit(tf_dataset)
# You can also use ray.air.integrations.keras.Callback
# for reporting and checkpointing instead of reporting manually.
Expand All @@ -127,13 +125,20 @@ def train_loop_for_worker(config):
),
)
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(scaling_config=ScalingConfig(num_workers=3),
train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
train_loop_per_worker=train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=3),
datasets={"train": train_dataset},
train_loop_config={"num_epochs": 2})
train_loop_config={"num_epochs": 2},
)
result = trainer.fit()
.. testoutput::
:hide:
:options: +ELLIPSIS
...
Args:
train_loop_per_worker: The training function to execute.
Expand Down

0 comments on commit 2bdb2f8

Please sign in to comment.