[RLlib] Cleanup examples folder (vol 23): Float16 training support and new example script. #47362
---
@@ -2969,6 +2969,14 @@ py_test(
 # subdirectory: gpus/
 # ....................................
+py_test(
+    name = "examples/gpus/float16_training_and_inference",
+    main = "examples/gpus/float16_training_and_inference.py",
+    tags = ["team:rllib", "exclusive", "examples", "gpu"],
+    size = "medium",
+    srcs = ["examples/gpus/float16_training_and_inference.py"],
+    args = ["--enable-new-api-stack", "--as-test", "--stop-reward=150.0"]
+)
Comment on lines +2972 to +2979:

Reviewer: Interesting. Is this even used by users?

Author: Yes. Actually a customer asked for it ;)
 py_test(
     name = "examples/gpus/fractional_0.5_gpus_per_learner",
     main = "examples/gpus/fractional_gpus_per_learner.py",
---
@@ -540,6 +540,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {}

         # `self.experimental()`
+        self._torch_grad_scaler_class = None
         self._tf_policy_handles_more_than_one_loss = False
         self._disable_preprocessor_api = False
         self._disable_action_flattening = False
@@ -3158,6 +3159,13 @@ def rl_module(
         Returns:
             This updated AlgorithmConfig object.
         """
+        if _enable_rl_module_api != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
+                new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
+                error=True,
+            )
+
         if model_config_dict is not NotProvided:
             self._model_config_dict = model_config_dict
         if rl_module_spec is not NotProvided:
@@ -3173,17 +3181,12 @@ def rl_module(
             algorithm_config_overrides_per_module
         )

-        if _enable_rl_module_api != DEPRECATED_VALUE:
-            deprecation_warning(
-                old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
-                new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
-                error=False,
-            )
         return self

     def experimental(
         self,
         *,
+        _torch_grad_scaler_class: Optional[Type] = NotProvided,
         _tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
         _disable_preprocessor_api: Optional[bool] = NotProvided,
         _disable_action_flattening: Optional[bool] = NotProvided,
@@ -3194,6 +3197,14 @@ def experimental(
         """Sets the config's experimental settings.

         Args:
+            _torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
+                unscaling). The class must implement the following methods to be
+                compatible with a `TorchLearner`. These methods/APIs match exactly the
Reviewer: Small nit: Remove "the" at the end.

Author: done
+                those of torch's own `torch.amp.GradScaler`:
+                `scale([loss])` to scale the loss.
+                `get_scale()` to get the current scale value.
+                `step([optimizer])` to unscale the grads and step the given optimizer.
+                `update()` to update the scaler after an optimizer step.
Reviewer: The description is not clear enough imo. For what is it used?

Author: enhanced and added a link to the torch docs
             _tf_policy_handles_more_than_one_loss: Experimental flag.
                 If True, TFPolicy will handle more than one loss/optimizer.
                 Set this to True, if you would like to return more than
@@ -3235,6 +3246,8 @@ def experimental(
             self._disable_initialize_loss_from_dummy_batch = (
                 _disable_initialize_loss_from_dummy_batch
             )
+        if _torch_grad_scaler_class is not NotProvided:
+            self._torch_grad_scaler_class = _torch_grad_scaler_class

         return self
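As a usage illustration (not part of the diff): with the new setting, float16-friendly loss scaling could be switched on roughly as below. This is a minimal sketch; the PPO/CartPole setup is invented, and `torch.amp.GradScaler` is simply the obvious candidate, since the docstring above names it as the API template.

```python
import torch
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .experimental(
        # Any class implementing `scale()`, `get_scale()`, `step()`, and
        # `update()` is accepted here.
        _torch_grad_scaler_class=torch.amp.GradScaler,
    )
)
```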
---
@@ -47,10 +47,11 @@ def compute_loss_for_module(
         # for which we add an additional (artificial) timestep to each episode to
         # simplify the actual computation.
         if Columns.LOSS_MASK in batch:
-            num_valid = torch.sum(batch[Columns.LOSS_MASK])
+            mask = batch[Columns.LOSS_MASK]
+            num_valid = torch.sum(mask)

             def possibly_masked_mean(data_):
-                return torch.sum(data_[batch[Columns.LOSS_MASK]]) / num_valid
+                return torch.sum(data_[mask]) / num_valid
Reviewer: Could we potentially also mask with this …

Author: That should work. But you would have to manually change that in the Learner connector (where this column is being produced and added to the batch).
         else:
             possibly_masked_mean = torch.mean
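To make the masking concrete, here is a tiny standalone illustration (values invented) of why the masked mean differs from a plain `torch.mean` when an artificial timestep pads the episode:

```python
import torch

data = torch.tensor([1.0, 2.0, 3.0, 100.0])     # last entry: padded timestep
mask = torch.tensor([True, True, True, False])  # LOSS_MASK analog

num_valid = torch.sum(mask)                      # 3
masked_mean = torch.sum(data[mask]) / num_valid  # (1+2+3)/3 = 2.0
plain_mean = torch.mean(data)                    # 26.5, skewed by the padding
```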
@@ -98,9 +99,8 @@ def possibly_masked_mean(data_):
             mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
         # Ignore the value function.
         else:
-            value_fn_out = torch.tensor(0.0).to(surrogate_loss.device)
-            mean_vf_unclipped_loss = torch.tensor(0.0).to(surrogate_loss.device)
-            vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)
+            z = torch.tensor(0.0, device=surrogate_loss.device)
Author: simplify
+            value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
Reviewer: Niceee!
         total_loss = possibly_masked_mean(
             -surrogate_loss
---
@@ -214,8 +214,8 @@ class Learner(Checkpointable):

     class MyLearner(TorchLearner):

-        def compute_loss(self, fwd_out, batch):
-            # compute the loss based on batch and output of the forward pass
+        def compute_losses(self, fwd_out, batch):
+            # Compute the loss based on batch and output of the forward pass
             # to access the learner hyper-parameters use `self._hps`
Reviewer: Remove …

Author: Great catch! Will fix.

Author: fixed
             return {ALL_MODULES: loss}
     """
@@ -597,9 +597,22 @@ def get_optimizer(
             The optimizer object, configured under the given `module_id` and
             `optimizer_name`.
         """
+        # `optimizer_name` could possibly be the full optimizer name (including the
+        # module_id under which it is registered).
+        if optimizer_name in self._named_optimizers:
+            return self._named_optimizers[optimizer_name]
+
+        # Normally, `optimizer_name` is just the optimizer's name, not including the
+        # `module_id`.
         full_registration_name = module_id + "_" + optimizer_name
-        assert full_registration_name in self._named_optimizers
-        return self._named_optimizers[full_registration_name]
+        if full_registration_name in self._named_optimizers:
+            return self._named_optimizers[full_registration_name]
+
+        # No optimizer found.
+        raise KeyError(
+            f"Optimizer not found! module_id={module_id} "
+            f"optimizer_name={optimizer_name}"
+        )

     def get_optimizers_for_module(
         self, module_id: ModuleID = ALL_MODULES
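A self-contained sketch of the new lookup order (dictionary contents made up); the behavioral change is the explicit `KeyError` replacing the bare `assert`:

```python
named_optimizers = {"policy_1_adam": "<adam optimizer object>"}

def get_optimizer(module_id, optimizer_name):
    # 1) `optimizer_name` may already be the full registration name.
    if optimizer_name in named_optimizers:
        return named_optimizers[optimizer_name]
    # 2) Otherwise, build "<module_id>_<optimizer_name>" and look that up.
    full_registration_name = module_id + "_" + optimizer_name
    if full_registration_name in named_optimizers:
        return named_optimizers[full_registration_name]
    # 3) Fail with a descriptive error instead of an assert.
    raise KeyError(
        f"Optimizer not found! module_id={module_id} "
        f"optimizer_name={optimizer_name}"
    )

assert get_optimizer("policy_1", "adam") == get_optimizer("ignored", "policy_1_adam")
```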
@@ -830,33 +843,30 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None):
         return should_module_be_updated_fn(module_id, multi_agent_batch)

     @OverrideToImplementCustomLogic
-    def compute_loss(
+    def compute_losses(
         self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
     ) -> Dict[str, Any]:
-        """Computes the loss for the module being optimized.
-
-        This method must be overridden by multiagent-specific algorithm learners to
-        specify the specific loss computation logic. If the algorithm is single agent
-        `compute_loss_for_module()` should be overridden instead.
-        `fwd_out` is the output of the `forward_train()` method of the underlying
-        MultiRLModule. `batch` is the data that was used to compute `fwd_out`.
-        The returned dictionary must contain a key called
-        ALL_MODULES, which will be used to compute gradients. It is recommended
-        to not compute any forward passes within this method, and to use the
-        `forward_train()` outputs of the RLModule(s) to compute the required tensors for
-        loss calculations.
+        """Computes the loss(es) for the module being optimized.
+
+        This method must be overridden by MultiRLModule-specific Learners in order to
+        define the specific loss computation logic. If the algorithm is single-agent
+        `compute_loss_for_module()` should be overridden instead. If the algorithm uses
+        independent multi-agent learning (default behavior for multi-agent setups), also
+        `compute_loss_for_module()` should be overridden, but it will be called for each
+        individual RLModule inside the MultiRLModule.
Comment on lines +851 to +856:

Reviewer: Let's either refer to the docs about this or to some examples. Just from the text this does not become very clear.

Author: I can link to the custom loss example ...

Author: done
+        It is recommended to not compute any forward passes within this method, and to

Reviewer: Very important hint! Awesome!
+        use the `forward_train()` outputs of the RLModule(s) to compute the required
+        tensors for loss calculations.

         Args:
-            fwd_out: Output from a call to the `forward_train()` method of self.module
-                during training (`self.update()`).
+            fwd_out: Output from a call to the `forward_train()` method of the
+                underlying MultiRLModule (`self.module`) during training
+                (`self.update()`).
             batch: The training batch that was used to compute `fwd_out`.

         Returns:
-            A dictionary mapping module IDs to individual loss terms. The dictionary
-            must contain one protected key ALL_MODULES which will be used for computing
-            gradients through.
+            A dictionary mapping module IDs to individual loss terms.
         """
-        loss_total = None
         loss_per_module = {}
         for module_id in fwd_out:
             module_batch = batch[module_id]
@@ -870,13 +880,6 @@ def compute_loss(
             )
             loss_per_module[module_id] = loss

-            if loss_total is None:
-                loss_total = loss
-            else:
-                loss_total += loss
-
-        loss_per_module[ALL_MODULES] = loss_total
-
         return loss_per_module

     @OverrideToImplementCustomLogic
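For orientation, a hedged sketch of a custom `compute_losses()` override on a `TorchLearner`. The module IDs are invented, and the exact `compute_loss_for_module()` keyword set (notably `config=`) is an assumption based on the signatures visible in this diff:

```python
from ray.rllib.core.learner.torch.torch_learner import TorchLearner

class MyMultiAgentTorchLearner(TorchLearner):
    def compute_losses(self, *, fwd_out, batch):
        # One loss term per RLModule; no ALL_MODULES key required anymore.
        losses = {}
        for module_id, module_fwd_out in fwd_out.items():
            losses[module_id] = self.compute_loss_for_module(
                module_id=module_id,
                config=self.config.get_config_for_module(module_id),
                batch=batch[module_id],
                fwd_out=module_fwd_out,
            )
        return losses
```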
@@ -893,7 +896,7 @@ def compute_loss_for_module(

         Think of this as computing loss for a single agent. For multi-agent use-cases
         that require more complicated computation for loss, consider overriding the
-        `compute_loss` method instead.
+        `compute_losses` method instead.
Reviewer: Can …

Author: The default implementation of … If you have a complex multi-agent case, you should override …
         Args:
             module_id: The id of the module.
@@ -1621,3 +1624,13 @@ def get_optimizer_state(self, *args, **kwargs):
     @Deprecated(new="Learner._set_optimizer_state()", error=True)
     def set_optimizer_state(self, *args, **kwargs):
         pass
+
+    @Deprecated(new="Learner.compute_losses(...)", error=False)
Author: Emulate the old behavior, in case users use this in their tests or other code.
+    def compute_loss(self, *args, **kwargs):
+        losses_per_module = self.compute_losses(*args, **kwargs)
+        # To continue supporting the old `compute_loss` behavior (instead of
+        # the new `compute_losses`), add the ALL_MODULES key here, holding the
+        # sum of all individual loss terms.
+        if ALL_MODULES not in losses_per_module:
+            losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
+        return losses_per_module
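A quick illustration of what this shim restores for legacy callers (module IDs and loss values invented):

```python
from ray.rllib.utils.metrics import ALL_MODULES

losses_per_module = {"module_1": 1.5, "module_2": 2.5}
# The deprecated `compute_loss()` path re-adds the summed total under ALL_MODULES.
if ALL_MODULES not in losses_per_module:
    losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
assert losses_per_module[ALL_MODULES] == 4.0
```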
---
@@ -25,7 +25,6 @@
     OverrideToImplementCustomLogic,
 )
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.metrics import ALL_MODULES
 from ray.rllib.utils.typing import (
     ModuleID,
     Optimizer,
@@ -75,7 +74,7 @@ def configure_optimizers_for_module(

         # For this default implementation, the learning rate is handled by the
         # attached lr Scheduler (controlled by self.config.lr, which can be a
-        # fixed value of a schedule setting).
+        # fixed value or a schedule setting).
         optimizer = tf.keras.optimizers.Adam()
         params = self.get_parameters(module)
@@ -99,7 +98,8 @@ def compute_gradients(
         gradient_tape: "tf.GradientTape",
         **kwargs,
     ) -> ParamDict:
-        grads = gradient_tape.gradient(loss_per_module[ALL_MODULES], self._params)
+        total_loss = sum(loss_per_module.values())
+        grads = gradient_tape.gradient(total_loss, self._params)
         return grads

     @override(Learner)
@@ -300,7 +300,7 @@ def _untraced_update(
         def helper(_batch):
             with tf.GradientTape(persistent=True) as tape:
                 fwd_out = self._module.forward_train(_batch)
-                loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
+                loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
Reviewer: Do we even want this here and motivate users to write TF algorithms?

Author: Good point! But .... DreamerV3 :|
                 gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
                 del tape
             postprocessed_gradients = self.postprocess_gradients(gradients)
---
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import logging
 from typing import (
     Any,
@@ -97,6 +98,14 @@ def __init__(self, **kwargs):
                 torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode,
             )

+        # Loss scalers for mixed precision training. Map optimizer names to
+        # associated torch GradScaler objects.
+        self._grad_scalers = None
+        if self.config._torch_grad_scaler_class:
+            self._grad_scalers = defaultdict(
+                lambda: self.config._torch_grad_scaler_class()
+            )
+
     @OverrideToImplementCustomLogic
     @override(Learner)
     def configure_optimizers_for_module(
@@ -108,7 +117,7 @@ def configure_optimizers_for_module(

         # For this default implementation, the learning rate is handled by the
         # attached lr Scheduler (controlled by self.config.lr, which can be a
-        # fixed value of a schedule setting).
+        # fixed value or a schedule setting).
         params = self.get_parameters(module)
         optimizer = torch.optim.Adam(params)
Reviewer: Don't we want to pass in …

Author: Used to think the same way. BUT that would make configuring custom optimizers again very non-pythonic and very yaml'ish. The user would then have to provide a class/type and some kwargs, but then has no chance to customize anything else within the optimizer setup process. Think of a user having to configure two optimizers, or three. Where do you make these options available in the config, then? What if the user needs different optimizers per module?
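To illustrate the author's point about keeping optimizer setup pythonic, here is a sketch of a per-module override; the optimizer choice and kwargs are invented, and the override/registration pattern is an assumption modeled on the default implementation shown above:

```python
import torch
from ray.rllib.core.learner.torch.torch_learner import TorchLearner

class MyTorchLearner(TorchLearner):
    def configure_optimizers_for_module(self, module_id, config=None):
        module = self._module[module_id]
        params = self.get_parameters(module)
        # Full python control: any optimizer (or several), any kwargs, per module.
        optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)
        self.register_optimizer(
            module_id=module_id,
            optimizer=optimizer,
            params=params,
        )
```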
@@ -130,7 +139,7 @@ def _uncompiled_update(
         self.metrics.activate_tensor_mode()

         fwd_out = self.module.forward_train(batch)
-        loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch)
+        loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)

         gradients = self.compute_gradients(loss_per_module)
         postprocessed_gradients = self.postprocess_gradients(gradients)
@@ -149,24 +158,48 @@ def compute_gradients(
         for optim in self._optimizer_parameters:
             # `set_to_none=True` is a faster way to zero out the gradients.
             optim.zero_grad(set_to_none=True)
-        loss_per_module[ALL_MODULES].backward()

+        if self._grad_scalers is not None:
Reviewer: Awesome!!!
+            total_loss = sum(
+                self._grad_scalers[mid].scale(loss)
+                for mid, loss in loss_per_module.items()
+            )
+        else:
+            total_loss = sum(loss_per_module.values())
+
+        total_loss.backward()
         grads = {pid: p.grad for pid, p in self._params.items()}

         return grads
     @override(Learner)
     def apply_gradients(self, gradients_dict: ParamDict) -> None:
         # Make sure the parameters do not carry gradients on their own.
         for optim in self._optimizer_parameters:
             optim.zero_grad(set_to_none=True)

         # Set the gradient of the parameters.
         for pid, grad in gradients_dict.items():
             self._params[pid].grad = grad

         # For each optimizer call its step function.
-        for optim in self._optimizer_parameters:
-            optim.step()
+        for module_id, optimizer_names in self._module_optimizers.items():
+            for optimizer_name in optimizer_names:
+                optim = self.get_optimizer(module_id, optimizer_name)
+                # Step through the scaler (unscales gradients, if applicable).
+                if self._grad_scalers is not None:
+                    scaler = self._grad_scalers[module_id]
+                    scaler.step(optim)
+                    self.metrics.log_value(
+                        (module_id, "_torch_grad_scaler_current_scale"),
+                        scaler.get_scale(),
+                        window=1,  # snapshot in time, no EMA/mean.
+                    )
+                    # Update the scaler.
+                    scaler.update()
+                # `step` the optimizer (default), but only if all gradients are finite.
+                elif all(
+                    param.grad is None or torch.isfinite(param.grad).all()
+                    for group in optim.param_groups
+                    for param in group["params"]
+                ):
+                    optim.step()

     @override(Learner)
     def _get_optimizer_state(self) -> StateDict:
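For context, the per-module scale → step → update cycle above mirrors torch's standard AMP recipe. A minimal standalone sketch (toy model; `enabled=` keeps it CPU-safe):

```python
import torch

model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 4)).pow(2).mean()
scaler.scale(loss).backward()  # scale the loss, then backprop
scaler.step(optim)             # unscales grads; skips the step if non-finite
scaler.update()                # adjust the scale for the next iteration
```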
Author (closing note on the rename): I changed this API name to be even more clear. This method computes one(!) loss per RLModule (in a MultiRLModule) inside the Learner. Got rid of the confusing `TOTAL_LOSS` key. We compute this now in the default implementation of `compute_gradients`.