[RLlib] Cleanup examples folder (vol 23): Float16 training support and new example script. #47362

Merged

Changes from 8 commits
2 changes: 1 addition & 1 deletion doc/source/rllib/package_ref/learner.rst
@@ -91,7 +91,7 @@ Computing Losses
:nosignatures:
:toctree: doc/

Learner.compute_loss
Learner.compute_losses
Contributor Author

I changed this API name to be even clearer. This method computes one(!) loss per RLModule (in a MultiRLModule) inside the Learner.

Got rid of the confusing TOTAL_LOSS key. We now compute the total loss in the default implementation of compute_gradients.
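
For illustration, a minimal sketch of the new return contract (the module IDs and loss values below are made up):

import torch

# One loss tensor per ModuleID; no ALL_MODULES / TOTAL_LOSS entry anymore.
loss_per_module = {
    "policy_1": torch.tensor(0.25),  # hypothetical per-module losses
    "policy_2": torch.tensor(0.75),
}

# The total loss is now formed inside the default compute_gradients():
total_loss = sum(loss_per_module.values())
assert torch.isclose(total_loss, torch.tensor(1.0))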

Learner.compute_loss_for_module
Learner._is_module_compatible_with_learner
Learner._get_tensor_variable
8 changes: 8 additions & 0 deletions rllib/BUILD
@@ -2969,6 +2969,14 @@ py_test(

# subdirectory: gpus/
# ....................................
py_test(
name = "examples/gpus/float16_training_and_inference",
main = "examples/gpus/float16_training_and_inference.py",
tags = ["team:rllib", "exclusive", "examples", "gpu"],
size = "medium",
srcs = ["examples/gpus/float16_training_and_inference.py"],
args = ["--enable-new-api-stack", "--as-test", "--stop-reward=150.0"]
)
Comment on lines +2972 to +2979
Collaborator

Interesting. Is this even used by users?

Contributor Author

Yes. Actually, a customer asked for it ;)
It does get interesting for large models (see also the many efforts to train LLMs with heavily compressed precisions, down to bfloat16) and multi-agent setups. I think if one can stabilize this, it's very useful. Another example script with mixed-precision training (and float16 inference on the EnvRunners) is in the making ...

py_test(
name = "examples/gpus/fractional_0.5_gpus_per_learner",
main = "examples/gpus/fractional_gpus_per_learner.py",
25 changes: 19 additions & 6 deletions rllib/algorithms/algorithm_config.py
@@ -540,6 +540,7 @@ def __init__(self, algo_class: Optional[type] = None):
self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {}

# `self.experimental()`
self._torch_grad_scaler_class = None
self._tf_policy_handles_more_than_one_loss = False
self._disable_preprocessor_api = False
self._disable_action_flattening = False
@@ -3158,6 +3159,13 @@ def rl_module(
Returns:
This updated AlgorithmConfig object.
"""
if _enable_rl_module_api != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
error=True,
)

if model_config_dict is not NotProvided:
self._model_config_dict = model_config_dict
if rl_module_spec is not NotProvided:
@@ -3173,17 +3181,12 @@ (
algorithm_config_overrides_per_module
)

if _enable_rl_module_api != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
error=False,
)
return self

def experimental(
self,
*,
_torch_grad_scaler_class: Optional[Type] = NotProvided,
_tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
_disable_preprocessor_api: Optional[bool] = NotProvided,
_disable_action_flattening: Optional[bool] = NotProvided,
Expand All @@ -3194,6 +3197,14 @@ def experimental(
"""Sets the config's experimental settings.

Args:
_torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
unscaling). The class must implement the following methods to be
compatible with a `TorchLearner`. These methods/APIs match exactly the
Collaborator

Small nit: Remove "the" at the end.

Contributor Author

done

those of torch's own `torch.amp.GradScaler`:
`scale([loss])` to scale the loss.
`get_scale()` to get the current scale value.
`step([optimizer])` to unscale the grads and step the given optimizer.
`update()` to update the scaler after an optimizer step.
Collaborator

The description is not clear enough imo. What is it used for?

Contributor Author

enhanced and added a link to the torch docs
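
For reference, a hedged usage sketch of the new setting (this assumes a PPO config on the new API stack and a torch version that exposes torch.amp.GradScaler; on older torch versions, torch.cuda.amp.GradScaler would be the equivalent):

import torch

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=True)
    .experimental(
        # Any class implementing scale(), get_scale(), step(), and update()
        # works here; torch's own GradScaler satisfies this interface.
        _torch_grad_scaler_class=torch.amp.GradScaler,
    )
)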

_tf_policy_handles_more_than_one_loss: Experimental flag.
If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than
@@ -3235,6 +3246,8 @@ def experimental(
self._disable_initialize_loss_from_dummy_batch = (
_disable_initialize_loss_from_dummy_batch
)
if _torch_grad_scaler_class is not NotProvided:
self._torch_grad_scaler_class = _torch_grad_scaler_class

return self

10 changes: 5 additions & 5 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -47,10 +47,11 @@ def compute_loss_for_module(
# for which we add an additional (artificial) timestep to each episode to
# simplify the actual computation.
if Columns.LOSS_MASK in batch:
num_valid = torch.sum(batch[Columns.LOSS_MASK])
mask = batch[Columns.LOSS_MASK]
num_valid = torch.sum(mask)

def possibly_masked_mean(data_):
return torch.sum(data_[batch[Columns.LOSS_MASK]]) / num_valid
return torch.sum(data_[mask]) / num_valid
Collaborator

Could we potentially also use this mask to mask out values in the observations that are not available at a certain step of the environment, e.g. a different number of entities at different steps of a game?

Contributor Author

That should work. But you would have to manually change that in the Learner connector (where this column is being produced and added to the batch).
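
To sketch the idea (hedged: the entity_valid_mask key below is hypothetical and would have to be produced by such a custom Learner connector piece; only the loss-mask column exists today):

import torch

# Hypothetical combination of RLlib's loss mask with an extra, made-up
# per-timestep entity-validity mask (shape [B, T] each).
batch = {
    "loss_mask": torch.tensor([[True, True, False], [True, False, False]]),
    "entity_valid_mask": torch.tensor([[True, False, False], [True, True, False]]),
}
data = torch.arange(6, dtype=torch.float32).reshape(2, 3)

combined_mask = batch["loss_mask"] & batch["entity_valid_mask"]
num_valid = torch.sum(combined_mask)


def possibly_masked_mean(data_):
    # Average only over timesteps that are valid under BOTH masks.
    return torch.sum(data_[combined_mask]) / num_valid


print(possibly_masked_mean(data))  # -> tensor(1.5000), mean over valid entries only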


else:
possibly_masked_mean = torch.mean
@@ -98,9 +99,8 @@ def possibly_masked_mean(data_):
mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
# Ignore the value function.
else:
value_fn_out = torch.tensor(0.0).to(surrogate_loss.device)
mean_vf_unclipped_loss = torch.tensor(0.0).to(surrogate_loss.device)
vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)
z = torch.tensor(0.0, device=surrogate_loss.device)
Contributor Author

simplify

value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
Collaborator

Niceee!


total_loss = possibly_masked_mean(
-surrogate_loss
75 changes: 44 additions & 31 deletions rllib/core/learner/learner.py
@@ -214,8 +214,8 @@ class Learner(Checkpointable):

class MyLearner(TorchLearner):

def compute_loss(self, fwd_out, batch):
# compute the loss based on batch and output of the forward pass
def compute_losses(self, fwd_out, batch):
# Compute the loss based on batch and output of the forward pass
# to access the learner hyper-parameters use `self._hps`
Collaborator

Remove _hps. Hyperparameters are not used anymore.

Contributor Author

Great catch! Will fix.

Contributor Author

fixed

return {ALL_MODULES: loss}
"""
@@ -597,9 +597,22 @@ def get_optimizer(
The optimizer object, configured under the given `module_id` and
`optimizer_name`.
"""
# `optimizer_name` could possibly be the full optimizer name (including the
# module_id under which it is registered).
if optimizer_name in self._named_optimizers:
return self._named_optimizers[optimizer_name]

# Normally, `optimizer_name` is just the optimizer's name, not including the
# `module_id`.
full_registration_name = module_id + "_" + optimizer_name
assert full_registration_name in self._named_optimizers
return self._named_optimizers[full_registration_name]
if full_registration_name in self._named_optimizers:
return self._named_optimizers[full_registration_name]

# No optimizer found.
raise KeyError(
f"Optimizer not found! module_id={module_id} "
f"optimizer_name={optimizer_name}"
)

def get_optimizers_for_module(
self, module_id: ModuleID = ALL_MODULES
@@ -830,33 +843,30 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None):
return should_module_be_updated_fn(module_id, multi_agent_batch)

@OverrideToImplementCustomLogic
def compute_loss(
def compute_losses(
self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
) -> Dict[str, Any]:
"""Computes the loss for the module being optimized.

This method must be overridden by multiagent-specific algorithm learners to
specify the specific loss computation logic. If the algorithm is single agent
`compute_loss_for_module()` should be overridden instead.
`fwd_out` is the output of the `forward_train()` method of the underlying
MultiRLModule. `batch` is the data that was used to compute `fwd_out`.
The returned dictionary must contain a key called
ALL_MODULES, which will be used to compute gradients. It is recommended
to not compute any forward passes within this method, and to use the
`forward_train()` outputs of the RLModule(s) to compute the required tensors for
loss calculations.
"""Computes the loss(es) for the module being optimized.

This method must be overridden by MultiRLModule-specific Learners in order to
define the specific loss computation logic. If the algorithm is single-agent
`compute_loss_for_module()` should be overridden instead. If the algorithm uses
independent multi-agent learning (default behavior for multi-agent setups), also
`compute_loss_for_module()` should be overridden, but it will be called for each
individual RLModule inside the MultiRLModule.
Comment on lines +851 to +856
Collaborator

Let's either refer to the docs about this or to some examples. Just from the text this does not become very clear.

Contributor Author

I can link to the custom loss example ...

Contributor Author

done

It is recommended to not compute any forward passes within this method, and to
Collaborator

Very important hint! Awesome!

use the `forward_train()` outputs of the RLModule(s) to compute the required
tensors for loss calculations.

Args:
fwd_out: Output from a call to the `forward_train()` method of self.module
during training (`self.update()`).
fwd_out: Output from a call to the `forward_train()` method of the
underlying MultiRLModule (`self.module`) during training
(`self.update()`).
batch: The training batch that was used to compute `fwd_out`.

Returns:
A dictionary mapping module IDs to individual loss terms. The dictionary
must contain one protected key ALL_MODULES which will be used for computing
gradients through.
A dictionary mapping module IDs to individual loss terms.
"""
loss_total = None
loss_per_module = {}
for module_id in fwd_out:
module_batch = batch[module_id]
@@ -870,13 +880,6 @@ def compute_loss(
)
loss_per_module[module_id] = loss

if loss_total is None:
loss_total = loss
else:
loss_total += loss

loss_per_module[ALL_MODULES] = loss_total

return loss_per_module

@OverrideToImplementCustomLogic
@@ -893,7 +896,7 @@ def compute_loss_for_module(

Think of this as computing loss for a single agent. For multi-agent use-cases
that require more complicated computation for loss, consider overriding the
`compute_loss` method instead.
`compute_losses` method instead.
Collaborator

Can compute_losses internally call compute_loss_for_module, or is the latter called regardless, so that losses would be computed twice?

Contributor Author

The default implementation of compute_losses calls compute_loss_for_module n times, where n is the number of RLModules within the Learner's MultiRLModule, so nothing is computed twice.

If you have a complex multi-agent case, you should override compute_losses, in which case the n calls to compute_loss_for_module will NOT be made.
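
To make that concrete, a hedged sketch (the loss math and the "logits" key are made up; the signature is assumed to match the current compute_loss_for_module API):

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MySingleModuleLossLearner(TorchLearner):
    # Called once per RLModule by the default compute_losses() implementation,
    # with that module's sub-batch and forward-pass outputs.
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # Hypothetical loss for illustration only.
        return torch.mean(fwd_out["logits"] ** 2)

Overriding compute_losses itself instead would replace this per-module loop entirely.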


Args:
module_id: The id of the module.
@@ -1621,3 +1624,13 @@ def get_optimizer_state(self, *args, **kwargs):
@Deprecated(new="Learner._set_optimizer_state()", error=True)
def set_optimizer_state(self, *args, **kwargs):
pass

@Deprecated(new="Learner.compute_losses(...)", error=False)
Contributor Author

Emulate the old behavior, in case users use this in their tests or other code.

def compute_loss(self, *args, **kwargs):
losses_per_module = self.compute_losses(*args, **kwargs)
# To continue supporting the old `compute_loss` behavior (instead of
# the new `compute_losses`, add the ALL_MODULES key here holding the sum
# of all individual loss terms.
if ALL_MODULES not in losses_per_module:
losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
return losses_per_module
8 changes: 4 additions & 4 deletions rllib/core/learner/tf/tf_learner.py
@@ -25,7 +25,6 @@
OverrideToImplementCustomLogic,
)
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.typing import (
ModuleID,
Optimizer,
Expand Down Expand Up @@ -75,7 +74,7 @@ def configure_optimizers_for_module(

# For this default implementation, the learning rate is handled by the
# attached lr Scheduler (controlled by self.config.lr, which can be a
# fixed value of a schedule setting).
# fixed value or a schedule setting).
optimizer = tf.keras.optimizers.Adam()
params = self.get_parameters(module)

@@ -99,7 +98,8 @@ def compute_gradients(
gradient_tape: "tf.GradientTape",
**kwargs,
) -> ParamDict:
grads = gradient_tape.gradient(loss_per_module[ALL_MODULES], self._params)
total_loss = sum(loss_per_module.values())
grads = gradient_tape.gradient(total_loss, self._params)
return grads

@override(Learner)
@@ -300,7 +300,7 @@ def _untraced_update(
def helper(_batch):
with tf.GradientTape(persistent=True) as tape:
fwd_out = self._module.forward_train(_batch)
loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
Collaborator

Do we even want this here and motivate users to write TF algorithms?

Contributor Author

Good point! But .... DreamerV3 :|

gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
del tape
postprocessed_gradients = self.postprocess_gradients(gradients)
51 changes: 42 additions & 9 deletions rllib/core/learner/torch/torch_learner.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import logging
from typing import (
Any,
@@ -97,6 +98,14 @@ def __init__(self, **kwargs):
torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode,
)

# Loss scalers for mixed precision training. Map optimizer names to
# associated torch GradScaler objects.
self._grad_scalers = None
if self.config._torch_grad_scaler_class:
self._grad_scalers = defaultdict(
lambda: self.config._torch_grad_scaler_class()
)

@OverrideToImplementCustomLogic
@override(Learner)
def configure_optimizers_for_module(
@@ -108,7 +117,7 @@ def configure_optimizers_for_module(

# For this default implementation, the learning rate is handled by the
# attached lr Scheduler (controlled by self.config.lr, which can be a
# fixed value of a schedule setting).
# fixed value or a schedule setting).
params = self.get_parameters(module)
optimizer = torch.optim.Adam(params)
Collaborator

Don't we want to pass in kwargs to Adam?

Contributor Author

I used to think the same way.

BUT that would make configuring custom optimizers very non-pythonic and very yaml'ish again. The user would then have to provide a class/type and some kwargs, but would have no chance to customize anything else in the optimizer setup process. Think of a user having to configure two optimizers, or three. Where do you make these options available in the config then? What if the user needs different optimizers per module?
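
As a hedged sketch of that pythonic route (the exact signatures of configure_optimizers_for_module and register_optimizer should be checked against the current Learner API; the Adam kwargs are made up):

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyTorchLearner(TorchLearner):
    def configure_optimizers_for_module(self, module_id, config=None):
        module = self.module[module_id]
        params = self.get_parameters(module)
        # Full, per-module control over optimizer class and kwargs.
        optimizer = torch.optim.Adam(params, betas=(0.9, 0.95), eps=1e-5)
        self.register_optimizer(
            module_id=module_id,
            optimizer=optimizer,
            params=params,
            lr_or_lr_schedule=config.lr if config is not None else None,
        )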


@@ -130,7 +139,7 @@ def _uncompiled_update(
self.metrics.activate_tensor_mode()

fwd_out = self.module.forward_train(batch)
loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch)
loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)

gradients = self.compute_gradients(loss_per_module)
postprocessed_gradients = self.postprocess_gradients(gradients)
@@ -149,24 +158,48 @@ def compute_gradients(
for optim in self._optimizer_parameters:
# `set_to_none=True` is a faster way to zero out the gradients.
optim.zero_grad(set_to_none=True)
loss_per_module[ALL_MODULES].backward()

if self._grad_scalers is not None:
Collaborator

Awesome!!!

total_loss = sum(
self._grad_scalers[mid].scale(loss)
for mid, loss in loss_per_module.items()
)
else:
total_loss = sum(loss_per_module.values())

total_loss.backward()
grads = {pid: p.grad for pid, p in self._params.items()}

return grads

@override(Learner)
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# Make sure the parameters do not carry gradients on their own.
for optim in self._optimizer_parameters:
optim.zero_grad(set_to_none=True)

# Set the gradient of the parameters.
for pid, grad in gradients_dict.items():
self._params[pid].grad = grad

# For each optimizer call its step function.
for optim in self._optimizer_parameters:
optim.step()
for module_id, optimizer_names in self._module_optimizers.items():
for optimizer_name in optimizer_names:
optim = self.get_optimizer(module_id, optimizer_name)
# Step through the scaler (unscales gradients, if applicable).
if self._grad_scalers is not None:
scaler = self._grad_scalers[module_id]
scaler.step(optim)
self.metrics.log_value(
(module_id, "_torch_grad_scaler_current_scale"),
scaler.get_scale(),
window=1, # snapshot in time, no EMA/mean.
)
# Update the scaler.
scaler.update()
# `step` the optimizer (default), but only if all gradients are finite.
elif all(
param.grad is None or torch.isfinite(param.grad).all()
for group in optim.param_groups
for param in group["params"]
):
optim.step()

@override(Learner)
def _get_optimizer_state(self) -> StateDict: