[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script. #47116

Merged (24 commits) Aug 29, 2024. The diff below shows changes from 7 commits.
2 changes: 1 addition & 1 deletion doc/source/rllib/package_ref/learner.rst
@@ -93,7 +93,7 @@ Computing Losses
:nosignatures:
:toctree: doc/

- Learner.compute_loss
+ Learner.compute_losses
Learner.compute_loss_for_module
Learner._is_module_compatible_with_learner
Learner._get_tensor_variable
23 changes: 17 additions & 6 deletions rllib/algorithms/algorithm_config.py
@@ -537,6 +537,7 @@ def __init__(self, algo_class: Optional[type] = None):
self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {}

# `self.experimental()`
+ self._enable_torch_mixed_precision_training = False
self._tf_policy_handles_more_than_one_loss = False
self._disable_preprocessor_api = False
self._disable_action_flattening = False
@@ -3143,6 +3144,13 @@ def rl_module(
Returns:
This updated AlgorithmConfig object.
"""
+ if _enable_rl_module_api != DEPRECATED_VALUE:
+     deprecation_warning(
+         old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
+         new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
+         error=True,
+     )

if model_config_dict is not NotProvided:
self._model_config_dict = model_config_dict
if rl_module_spec is not NotProvided:
@@ -3158,17 +3166,12 @@
algorithm_config_overrides_per_module
)

- if _enable_rl_module_api != DEPRECATED_VALUE:
-     deprecation_warning(
-         old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
-         new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
-         error=False,
-     )
return self

def experimental(
self,
*,
+ _enable_torch_mixed_precision_training: Optional[bool] = NotProvided,
_tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
_disable_preprocessor_api: Optional[bool] = NotProvided,
_disable_action_flattening: Optional[bool] = NotProvided,
@@ -3179,6 +3182,10 @@ def experimental(
"""Sets the config's experimental settings.

Args:
+ _enable_torch_mixed_precision_training: Whether to switch on automatic
+     mixed-precision training for torch RLModules. Note that this setting
+     only works on the new API stack, i.e. with
+     `config.api_stack(enable_rl_module_and_learner=True)`.
_tf_policy_handles_more_than_one_loss: Experimental flag.
If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than
@@ -3208,6 +3215,10 @@
)
self.api_stack(enable_rl_module_and_learner=_enable_new_api_stack)

+ if _enable_torch_mixed_precision_training is not NotProvided:
+     self._enable_torch_mixed_precision_training = (
+         _enable_torch_mixed_precision_training
+     )
if _tf_policy_handles_more_than_one_loss is not NotProvided:
self._tf_policy_handles_more_than_one_loss = (
_tf_policy_handles_more_than_one_loss
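As a usage sketch of the new flag (the algorithm choice and environment here are illustrative only; only the `experimental()` kwarg itself is introduced by this PR):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical usage sketch: turn on automatic mixed-precision training
# for torch RLModules on the new API stack.
config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=True)
    .environment("CartPole-v1")
    .framework("torch")
    .experimental(_enable_torch_mixed_precision_training=True)
)
algo = config.build()
```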
75 changes: 44 additions & 31 deletions rllib/core/learner/learner.py
@@ -212,8 +212,8 @@ class Learner(Checkpointable):

class MyLearner(TorchLearner):

- def compute_loss(self, fwd_out, batch):
-     # compute the loss based on batch and output of the forward pass
+ def compute_losses(self, fwd_out, batch):
+     # Compute the loss based on batch and output of the forward pass
# to access the learner hyper-parameters use `self._hps`
return {ALL_MODULES: loss}
"""
@@ -595,9 +595,22 @@ def get_optimizer(
The optimizer object, configured under the given `module_id` and
`optimizer_name`.
"""
+ # `optimizer_name` could possibly be the full optimizer name (including the
+ # module_id under which it is registered).
+ if optimizer_name in self._named_optimizers:
+     return self._named_optimizers[optimizer_name]

+ # Normally, `optimizer_name` is just the optimizer's name, not including the
+ # `module_id`.
full_registration_name = module_id + "_" + optimizer_name
- assert full_registration_name in self._named_optimizers
- return self._named_optimizers[full_registration_name]
+ if full_registration_name in self._named_optimizers:
+     return self._named_optimizers[full_registration_name]

+ # No optimizer found.
+ raise KeyError(
+     f"Optimizer not found! module_id={module_id} "
+     f"optimizer_name={optimizer_name}"
+ )

def get_optimizers_for_module(
self, module_id: ModuleID = ALL_MODULES
@@ -828,33 +841,30 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None):
return should_module_be_updated_fn(module_id, multi_agent_batch)

@OverrideToImplementCustomLogic
- def compute_loss(
+ def compute_losses(
self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
) -> Dict[str, Any]:
"""Computes the loss for the module being optimized.

This method must be overridden by multiagent-specific algorithm learners to
specify the specific loss computation logic. If the algorithm is single agent
`compute_loss_for_module()` should be overridden instead.
`fwd_out` is the output of the `forward_train()` method of the underlying
MultiRLModule. `batch` is the data that was used to compute `fwd_out`.
The returned dictionary must contain a key called
ALL_MODULES, which will be used to compute gradients. It is recommended
to not compute any forward passes within this method, and to use the
`forward_train()` outputs of the RLModule(s) to compute the required tensors for
loss calculations.
"""Computes the loss(es) for the module being optimized.

This method must be overridden by MultiRLModule-specific Learners in order to
define the specific loss computation logic. If the algorithm is single-agent
`compute_loss_for_module()` should be overridden instead. If the algorithm uses
independent multi-agent learning (default behavior for multi-agent setups), also
`compute_loss_for_module()` should be overridden, but it will be called for each
individual RLModule inside the MultiRLModule.
It is recommended to not compute any forward passes within this method, and to
use the `forward_train()` outputs of the RLModule(s) to compute the required
tensors for loss calculations.

Args:
- fwd_out: Output from a call to the `forward_train()` method of self.module
-     during training (`self.update()`).
+ fwd_out: Output from a call to the `forward_train()` method of the
+     underlying MultiRLModule (`self.module`) during training
+     (`self.update()`).
batch: The training batch that was used to compute `fwd_out`.

Returns:
- A dictionary mapping module IDs to individual loss terms. The dictionary
-     must contain one protected key ALL_MODULES which will be used for computing
-     gradients through.
+ A dictionary mapping module IDs to individual loss terms.
"""
- loss_total = None
loss_per_module = {}
for module_id in fwd_out:
module_batch = batch[module_id]
@@ -868,13 +878,6 @@ def compute_loss(
)
loss_per_module[module_id] = loss

- if loss_total is None:
-     loss_total = loss
- else:
-     loss_total += loss
-
- loss_per_module[ALL_MODULES] = loss_total

return loss_per_module

@OverrideToImplementCustomLogic
@@ -891,7 +894,7 @@ def compute_loss_for_module(

Think of this as computing loss for a single agent. For multi-agent use-cases
that require more complicated computation for loss, consider overriding the
- `compute_loss` method instead.
+ `compute_losses` method instead.

Args:
module_id: The id of the module.
@@ -1665,3 +1668,13 @@ def get_optimizer_state(self, *args, **kwargs):
@Deprecated(new="Learner._set_optimizer_state()", error=True)
def set_optimizer_state(self, *args, **kwargs):
pass

@Deprecated(new="Learner.compute_losses(...)", error=False)
def compute_loss(self, *args, **kwargs):
losses_per_module = self.compute_losses(*args, **kwargs)
# To continue supporting the old `compute_loss` behavior (instead of
# the new `compute_losses`, add the ALL_MODULES key here holding the sum
# of all individual loss terms.
if ALL_MODULES not in losses_per_module:
losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
return losses_per_module
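For illustration, a minimal custom Learner after this change might look like the following sketch (it assumes the base-class signatures shown above; `MyLearner` is hypothetical):

```python
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyLearner(TorchLearner):
    def compute_losses(self, *, fwd_out, batch):
        # Return one loss term per module ID. No ALL_MODULES key is
        # required anymore; the Learner sums the terms itself when
        # computing gradients.
        loss_per_module = {}
        for module_id, module_fwd_out in fwd_out.items():
            loss_per_module[module_id] = self.compute_loss_for_module(
                module_id=module_id,
                config=self.config.get_config_for_module(module_id),
                batch=batch[module_id],
                fwd_out=module_fwd_out,
            )
        return loss_per_module
```

Note also that `get_optimizer()` now accepts either the short optimizer name plus a `module_id` or the full registration name, and raises a `KeyError` instead of failing an assert when nothing matches.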
6 changes: 3 additions & 3 deletions rllib/core/learner/tf/tf_learner.py
@@ -25,7 +25,6 @@
OverrideToImplementCustomLogic,
)
from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.typing import (
ModuleID,
Optimizer,
@@ -99,7 +98,8 @@ def compute_gradients(
gradient_tape: "tf.GradientTape",
**kwargs,
) -> ParamDict:
- grads = gradient_tape.gradient(loss_per_module[ALL_MODULES], self._params)
+ total_loss = sum(loss_per_module.values())
+ grads = gradient_tape.gradient(total_loss, self._params)
return grads

@override(Learner)
@@ -300,7 +300,7 @@ def _untraced_update(
def helper(_batch):
with tf.GradientTape(persistent=True) as tape:
fwd_out = self._module.forward_train(_batch)
- loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
+ loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
del tape
postprocessed_gradients = self.postprocess_gradients(gradients)
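A toy TensorFlow snippet (not RLlib code; all names are illustrative) showing why differentiating the sum of per-module losses is equivalent to the old ALL_MODULES total:

```python
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
with tf.GradientTape() as tape:
    # Per-module losses, as returned by `compute_losses()`.
    loss_per_module = {
        "module_a": tf.reduce_sum(w ** 2),
        "module_b": tf.reduce_sum(w),
    }
    # Summing the values replaces the former ALL_MODULES entry.
    total_loss = sum(loss_per_module.values())

# Gradient of the total: d/dw (w^2 + w) = 2w + 1 -> [3.0, 5.0].
grads = tape.gradient(total_loss, [w])
```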
42 changes: 35 additions & 7 deletions rllib/core/learner/torch/torch_learner.py
@@ -1,3 +1,4 @@
+ from collections import defaultdict
import logging
from typing import (
Any,
@@ -97,6 +98,10 @@ def __init__(self, **kwargs):
torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode,
)

+ # Loss scalers for mixed-precision training. Maps module IDs to their
+ # associated torch GradScaler objects.
+ self._amp_grad_scalers = defaultdict(lambda: torch.amp.GradScaler(self._device))

@OverrideToImplementCustomLogic
@override(Learner)
def configure_optimizers_for_module(
@@ -129,8 +134,13 @@ def _uncompiled_update(
# Activate tensor-mode on our MetricsLogger.
self.metrics.activate_tensor_mode()

- fwd_out = self.module.forward_train(batch)
- loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch)
+ if self.config._enable_torch_mixed_precision_training:
+     with torch.cuda.amp.autocast():
+         fwd_out = self.module.forward_train(batch)
+         loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
+ else:
+     fwd_out = self.module.forward_train(batch)
+     loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)

gradients = self.compute_gradients(loss_per_module)
postprocessed_gradients = self.postprocess_gradients(gradients)
@@ -149,24 +159,42 @@ def compute_gradients(
for optim in self._optimizer_parameters:
# `set_to_none=True` is a faster way to zero out the gradients.
optim.zero_grad(set_to_none=True)
- loss_per_module[ALL_MODULES].backward()

+ if self.config._enable_torch_mixed_precision_training:
+     total_loss = sum(
+         self._amp_grad_scalers[key].scale(loss)
+         for key, loss in loss_per_module.items()
+     )
+ else:
+     total_loss = sum(loss_per_module.values())
+
+ total_loss.backward()
grads = {pid: p.grad for pid, p in self._params.items()}

return grads

@override(Learner)
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# Make sure the parameters do not carry gradients on their own.
- for optim in self._optimizer_parameters:
-     optim.zero_grad(set_to_none=True)
+ # for optim in self._optimizer_parameters:
+ #     optim.zero_grad(set_to_none=True)

# Set the gradient of the parameters.
for pid, grad in gradients_dict.items():
self._params[pid].grad = grad

# For each optimizer call its step function.
- for optim in self._optimizer_parameters:
-     optim.step()
+ for module_id, optimizer_names in self._module_optimizers.items():
+     for optimizer_name in optimizer_names:
+         optim = self.get_optimizer(module_id, optimizer_name)
+         if self.config._enable_torch_mixed_precision_training:
+             self._amp_grad_scalers[module_id].step(optim)
+         else:
+             optim.step()
+
+ if self.config._enable_torch_mixed_precision_training:
+     for scaler in self._amp_grad_scalers.values():
+         scaler.update()

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
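For context, the underlying torch AMP pattern in isolation (a sketch assuming a CUDA device; model, data, and hyperparameters are made up; older torch versions spell the scaler `torch.cuda.amp.GradScaler`):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler("cuda")

x = torch.randn(32, 8, device="cuda")
y = torch.randn(32, 1, device="cuda")

# The forward pass and loss computation run in float16 under autocast.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)

optim.zero_grad(set_to_none=True)
scaler.scale(loss).backward()  # Scale the loss to avoid float16 underflow.
scaler.step(optim)  # Unscales gradients first; skips the step on inf/NaN.
scaler.update()  # Adjusts the scale factor for the next iteration.
```

This mirrors `_uncompiled_update()`, `compute_gradients()`, and `apply_gradients()` above, except that the Learner keeps one GradScaler per module.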