[RAY AIR][DOC][TorchTrainer] Rewrote the TorchTrainer code snippet as a working example (#30492)

Signed-off-by: Jules Damji <jules@anyscale.com>

- Rewrote the code snippet as it was not working

- Replaced the python code-block directives with testcode and testoutput directives, so the snippet is run as a test in CI (see the sketch after this list)

- Hide the output, since the three workers produce a large amount of it

- Assert that the loss converges on the training data within the specified number of epochs

- Tested code end-to-end
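
As a rough sketch of the directive pattern the snippet now uses (assuming Sphinx's sphinx.ext.doctest extension, which provides testcode/testoutput; the print statement here is purely illustrative):

    .. testcode::

        print("training...")

    .. testoutput::
        :hide:
        :options: +ELLIPSIS

        ...
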
dmatrix authored Nov 28, 2022
1 parent 53a5b4d commit 19aadd4
Showing 1 changed file with 72 additions and 29 deletions.
101 changes: 72 additions & 29 deletions python/ray/train/torch/torch_trainer.py
@@ -16,20 +16,21 @@ class TorchTrainer(DataParallelTrainer):
"""A Trainer for data parallel PyTorch training.
This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
Actors. These actors already have the necessary torch process group
configured for distributed PyTorch training.
The ``train_loop_per_worker`` function is expected to take in either 0 or 1
arguments:
.. testcode::

    def train_loop_per_worker():
        ...
.. testcode::

    from typing import Dict, Any

    def train_loop_per_worker(config: Dict[str, Any]):
        ...
If ``train_loop_per_worker`` accepts an argument, then
@@ -43,34 +44,34 @@ def train_loop_per_worker(config: Dict):
``session.get_dataset_shard(...)`` will return the entire Dataset.
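
As a rough, non-runnable sketch of how a config dict can be threaded through (the ``lr`` key and the worker count below are illustrative assumptions, not values required by the API):

.. code-block:: python

    def train_loop_per_worker(config):
        lr = config["lr"]  # hypothetical hyperparameter passed via train_loop_config
        ...

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"lr": 0.01},
        scaling_config=ScalingConfig(num_workers=2),
    )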
Inside the ``train_loop_per_worker`` function, you can use any of the
:ref:`Ray AIR session methods <air-session-ref>`. See full example code below.
.. testcode::

    def train_loop_per_worker():
        # Report intermediate results for callbacks or logging and
        # checkpoint data.
        session.report(...)

        # Get dict of last saved checkpoint.
        session.get_checkpoint()

        # Session returns the Ray Dataset shard for the given key.
        session.get_dataset_shard("my_dataset")

        # Get the total number of workers executing training.
        session.get_world_size()

        # Get the rank of this worker.
        session.get_world_rank()

        # Get the rank of the worker on the current node.
        session.get_local_rank()
You can also use any of the Torch-specific function utils, such as
:func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model`.
.. testcode::

    def train_loop_per_worker():
        # Prepares model for distributed training by wrapping in
@@ -83,7 +84,7 @@ def train_loop_per_worker():
        # `session.get_dataset_shard(...).iter_torch_batches(...)`
        train.torch.prepare_data_loader(...)

        # Get the current torch device.
        train.torch.get_device()
Any returns from the ``train_loop_per_worker`` will be discarded and not
@@ -93,7 +94,8 @@ def train_loop_per_worker():
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.
Example:
.. testcode::
    import torch
    import torch.nn as nn
@@ -103,12 +105,17 @@ def train_loop_per_worker():
    from ray.air import session, Checkpoint
    from ray.train.torch import TorchTrainer
    from ray.air.config import ScalingConfig
    from ray.air.config import RunConfig
    from ray.air.config import CheckpointConfig

    # Define NN layers architecture, epochs, and number of workers
    input_size = 1
    layer_size = 32
    output_size = 1
    num_epochs = 200
    num_workers = 3

    # Define your network structure
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
@@ -119,46 +126,82 @@ def __init__(self):
        def forward(self, input):
            return self.layer2(self.relu(self.layer1(input)))
    # Define your train worker loop
    def train_loop_per_worker():
        # Fetch the training set from the session
        dataset_shard = session.get_dataset_shard("train")
        model = NeuralNetwork()

        # Loss function, optimizer, prepare model for training.
        # prepare_model moves the model to the correct device and
        # prepares it for distributed execution.
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=0.01,
                                     weight_decay=0.01)
        model = train.torch.prepare_model(model)

        # Iterate over epochs and batches
        for epoch in range(num_epochs):
            for batches in dataset_shard.iter_torch_batches(batch_size=32,
                                                            dtypes=torch.float):
                # Unsqueeze to add an extra dimension to the batch: [32, 1]
                inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                output = model(inputs)

                # Make the output shape the same as the labels
                loss = loss_fn(output.squeeze(), labels)

                # Zero out grads, do backward, and update the optimizer
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Print the loss every 20 epochs
            if epoch % 20 == 0:
                print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

            # Report and record metrics, and checkpoint the model at the
            # end of each epoch
            session.report({"loss": loss.item(), "epoch": epoch},
                           checkpoint=Checkpoint.from_dict(
                               dict(epoch=epoch, model=model.state_dict()))
                           )
    torch.manual_seed(42)

    train_dataset = ray.data.from_items(
        [{"x": x, "y": 2 * x + 1} for x in range(200)]
    )
    # Define scaling and run configs
    # If using GPUs, use the scaling config below instead.
    # scaling_config = ScalingConfig(num_workers=3, use_gpu=True)
    scaling_config = ScalingConfig(num_workers=num_workers)
    run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=scaling_config,
        run_config=run_config,
        datasets={"train": train_dataset})

    result = trainer.fit()
    best_checkpoint_loss = result.metrics['loss']

    # Assert that the loss is less than 0.09
    assert best_checkpoint_loss <= 0.09
.. testoutput::
    :hide:
    :options: +ELLIPSIS

    ...
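
After ``fit()`` returns, the reported metrics and the dict-style checkpoint saved above can be inspected on the result. A minimal sketch, assuming the same ``result`` object as in the example:

.. code-block:: python

    # The last reported metrics contain the keys passed to session.report().
    print(result.metrics["loss"], result.metrics["epoch"])

    # The checkpoint was created with Checkpoint.from_dict(), so it can be
    # read back as a dict; the "model" key holds the saved state_dict.
    ckpt_dict = result.checkpoint.to_dict()
    print(sorted(ckpt_dict.keys()))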
Args:
    train_loop_per_worker: The training function to execute.
        This can either take in no arguments or a ``config`` dict.
    train_loop_config: Configurations to pass into
