diff --git a/doc/source/rllib/package_ref/learner.rst b/doc/source/rllib/package_ref/learner.rst index dd91de94b60e4..aa9f3ded0edb0 100644 --- a/doc/source/rllib/package_ref/learner.rst +++ b/doc/source/rllib/package_ref/learner.rst @@ -91,7 +91,7 @@ Computing Losses :nosignatures: :toctree: doc/ - Learner.compute_loss + Learner.compute_losses Learner.compute_loss_for_module Learner._is_module_compatible_with_learner Learner._get_tensor_variable diff --git a/rllib/BUILD b/rllib/BUILD index e681e4c252a02..9daeaeed483d3 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2969,6 +2969,14 @@ py_test( # subdirectory: gpus/ # .................................... +py_test( + name = "examples/gpus/float16_training_and_inference", + main = "examples/gpus/float16_training_and_inference.py", + tags = ["team:rllib", "exclusive", "examples", "gpu"], + size = "medium", + srcs = ["examples/gpus/float16_training_and_inference.py"], + args = ["--enable-new-api-stack", "--as-test", "--stop-reward=150.0"] +) py_test( name = "examples/gpus/fractional_0.5_gpus_per_learner", main = "examples/gpus/fractional_gpus_per_learner.py", diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 41cc4874e25c1..ac1b1f296eed1 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -540,6 +540,7 @@ def __init__(self, algo_class: Optional[type] = None): self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {} # `self.experimental()` + self._torch_grad_scaler_class = None self._tf_policy_handles_more_than_one_loss = False self._disable_preprocessor_api = False self._disable_action_flattening = False @@ -3158,6 +3159,13 @@ def rl_module( Returns: This updated AlgorithmConfig object. """ + if _enable_rl_module_api != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)", + new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)", + error=True, + ) + if model_config_dict is not NotProvided: self._model_config_dict = model_config_dict if rl_module_spec is not NotProvided: @@ -3173,17 +3181,12 @@ def rl_module( algorithm_config_overrides_per_module ) - if _enable_rl_module_api != DEPRECATED_VALUE: - deprecation_warning( - old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)", - new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)", - error=False, - ) return self def experimental( self, *, + _torch_grad_scaler_class: Optional[Type] = NotProvided, _tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided, _disable_preprocessor_api: Optional[bool] = NotProvided, _disable_action_flattening: Optional[bool] = NotProvided, @@ -3194,6 +3197,14 @@ def experimental( """Sets the config's experimental settings. Args: + _torch_grad_scaler_class: Class to use for torch loss scaling (and gradient + unscaling). The class must implement the following methods to be + compatible with a `TorchLearner`. These methods/APIs match exactly + those of torch's own `torch.amp.GradScaler`: + `scale([loss])` to scale the loss. + `get_scale()` to get the current scale value. + `step([optimizer])` to unscale the grads and step the given optimizer. + `update()` to update the scaler after an optimizer step. _tf_policy_handles_more_than_one_loss: Experimental flag. If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than @@ -3235,6 +3246,8 @@ def experimental( self._disable_initialize_loss_from_dummy_batch = ( _disable_initialize_loss_from_dummy_batch ) + if _torch_grad_scaler_class is not NotProvided: + self._torch_grad_scaler_class = _torch_grad_scaler_class return self diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index 8442410c51dbe..88e7b5737de7b 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -47,10 +47,11 @@ def compute_loss_for_module( # for which we add an additional (artificial) timestep to each episode to # simplify the actual computation. if Columns.LOSS_MASK in batch: - num_valid = torch.sum(batch[Columns.LOSS_MASK]) + mask = batch[Columns.LOSS_MASK] + num_valid = torch.sum(mask) def possibly_masked_mean(data_): - return torch.sum(data_[batch[Columns.LOSS_MASK]]) / num_valid + return torch.sum(data_[mask]) / num_valid else: possibly_masked_mean = torch.mean @@ -98,9 +99,8 @@ def possibly_masked_mean(data_): mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) # Ignore the value function. else: - value_fn_out = torch.tensor(0.0).to(surrogate_loss.device) - mean_vf_unclipped_loss = torch.tensor(0.0).to(surrogate_loss.device) - vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device) + z = torch.tensor(0.0, device=surrogate_loss.device) + value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z total_loss = possibly_masked_mean( -surrogate_loss diff --git a/rllib/benchmarks/torch_compile/run_inference_bm.py b/rllib/benchmarks/torch_compile/run_inference_bm.py index d5da1f57a18ac..bc5d83cf40870 100644 --- a/rllib/benchmarks/torch_compile/run_inference_bm.py +++ b/rllib/benchmarks/torch_compile/run_inference_bm.py @@ -15,7 +15,7 @@ from ray.rllib.benchmarks.torch_compile.utils import get_ppo_batch_for_env, timed from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.rl_module.torch.torch_rl_module import TorchCompileConfig -from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind +from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack from ray.rllib.models.catalog import MODEL_DEFAULTS from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -91,8 +91,8 @@ def main(pargs): with open(output / "args.json", "w") as f: json.dump(config, f) - # create the environment - env = wrap_deepmind(gym.make("GymV26Environment-v0", env_id="ALE/Breakout-v5")) + # Create the environment. + env = wrap_atari_for_new_api_stack(gym.make("ALE/Breakout-v5")) # setup RLModule model_cfg = MODEL_DEFAULTS.copy() diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index cb19783a5bfa1..db959691dcb57 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -214,8 +214,8 @@ class Learner(Checkpointable): class MyLearner(TorchLearner): - def compute_loss(self, fwd_out, batch): - # compute the loss based on batch and output of the forward pass + def compute_losses(self, fwd_out, batch): + # Compute the loss based on batch and output of the forward pass # to access the learner hyper-parameters use `self._hps` return {ALL_MODULES: loss} """ @@ -597,9 +597,22 @@ def get_optimizer( The optimizer object, configured under the given `module_id` and `optimizer_name`. """ + # `optimizer_name` could possibly be the full optimizer name (including the + # module_id under which it is registered). 
+ if optimizer_name in self._named_optimizers: + return self._named_optimizers[optimizer_name] + + # Normally, `optimizer_name` is just the optimizer's name, not including the + # `module_id`. full_registration_name = module_id + "_" + optimizer_name - assert full_registration_name in self._named_optimizers - return self._named_optimizers[full_registration_name] + if full_registration_name in self._named_optimizers: + return self._named_optimizers[full_registration_name] + + # No optimizer found. + raise KeyError( + f"Optimizer not found! module_id={module_id} " + f"optimizer_name={optimizer_name}" + ) def get_optimizers_for_module( self, module_id: ModuleID = ALL_MODULES @@ -830,33 +843,30 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None): return should_module_be_updated_fn(module_id, multi_agent_batch) @OverrideToImplementCustomLogic - def compute_loss( + def compute_losses( self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any] ) -> Dict[str, Any]: - """Computes the loss for the module being optimized. - - This method must be overridden by multiagent-specific algorithm learners to - specify the specific loss computation logic. If the algorithm is single agent - `compute_loss_for_module()` should be overridden instead. - `fwd_out` is the output of the `forward_train()` method of the underlying - MultiRLModule. `batch` is the data that was used to compute `fwd_out`. - The returned dictionary must contain a key called - ALL_MODULES, which will be used to compute gradients. It is recommended - to not compute any forward passes within this method, and to use the - `forward_train()` outputs of the RLModule(s) to compute the required tensors for - loss calculations. + """Computes the loss(es) for the module being optimized. + + This method must be overridden by MultiRLModule-specific Learners in order to + define the specific loss computation logic. If the algorithm is single-agent, + `compute_loss_for_module()` should be overridden instead. If the algorithm uses + independent multi-agent learning (the default behavior for multi-agent setups), + `compute_loss_for_module()` should also be overridden; it is then called for + each individual RLModule inside the MultiRLModule. + It is recommended not to compute any forward passes within this method, and to + use the `forward_train()` outputs of the RLModule(s) to compute the required + tensors for loss calculations. Args: + fwd_out: Output from a call to the `forward_train()` method of the + underlying MultiRLModule (`self.module`) during training + (`self.update()`). batch: The training batch that was used to compute `fwd_out`. Returns: - A dictionary mapping module IDs to individual loss terms. The dictionary - must contain one protected key ALL_MODULES which will be used for computing - gradients through. + A dictionary mapping module IDs to individual loss terms. """ - loss_total = None loss_per_module = {} for module_id in fwd_out: module_batch = batch[module_id] @@ -870,13 +880,6 @@ def compute_loss( ) loss_per_module[module_id] = loss - if loss_total is None: - loss_total = loss - else: - loss_total += loss - - loss_per_module[ALL_MODULES] = loss_total - return loss_per_module @OverrideToImplementCustomLogic @@ -893,7 +896,7 @@ def compute_loss_for_module( Think of this as computing loss for a single agent.
For multi-agent use-cases that require more complicated computation for loss, consider overriding the - `compute_loss` method instead. + `compute_losses` method instead. Args: module_id: The id of the module. @@ -1621,3 +1624,13 @@ def get_optimizer_state(self, *args, **kwargs): @Deprecated(new="Learner._set_optimizer_state()", error=True) def set_optimizer_state(self, *args, **kwargs): pass + + @Deprecated(new="Learner.compute_losses(...)", error=False) + def compute_loss(self, *args, **kwargs): + losses_per_module = self.compute_losses(*args, **kwargs) + # To continue supporting the old `compute_loss` behavior (instead of + # the new `compute_losses`), add the ALL_MODULES key here, holding the + # sum of all individual loss terms. + if ALL_MODULES not in losses_per_module: + losses_per_module[ALL_MODULES] = sum(losses_per_module.values()) + return losses_per_module diff --git a/rllib/core/learner/tests/test_learner.py b/rllib/core/learner/tests/test_learner.py index 815b9b54a2d4a..de8e700629eb1 100644 --- a/rllib/core/learner/tests/test_learner.py +++ b/rllib/core/learner/tests/test_learner.py @@ -40,7 +40,7 @@ def test_end_to_end_update(self): batch = reader.next() results = learner.update_from_batch(batch=batch.as_multi_agent()) - loss = results[ALL_MODULES][Learner.TOTAL_LOSS_KEY] + loss = results[DEFAULT_MODULE_ID][Learner.TOTAL_LOSS_KEY].peek() min_loss = min(loss, min_loss) print(f"[iter = {iter_i}] Loss: {loss:.3f}, Min Loss: {min_loss:.3f}") self.assertLess(min_loss, 0.58) diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py index 07caf7419f2cb..2745ad817f823 100644 --- a/rllib/core/learner/tf/tf_learner.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -25,7 +25,6 @@ OverrideToImplementCustomLogic, ) from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.typing import ( ModuleID, Optimizer, @@ -75,7 +74,7 @@ def configure_optimizers_for_module( # For this default implementation, the learning rate is handled by the # attached lr Scheduler (controlled by self.config.lr, which can be a - fixed value of a schedule setting). + fixed value or a schedule setting). optimizer = tf.keras.optimizers.Adam() params = self.get_parameters(module) @@ -99,7 +98,8 @@ def compute_gradients( gradient_tape: "tf.GradientTape", **kwargs, ) -> ParamDict: - grads = gradient_tape.gradient(loss_per_module[ALL_MODULES], self._params) + total_loss = sum(loss_per_module.values()) + grads = gradient_tape.gradient(total_loss, self._params) return grads @override(Learner) @@ -300,7 +300,7 @@ def _untraced_update( def helper(_batch): with tf.GradientTape(persistent=True) as tape: fwd_out = self._module.forward_train(_batch) - loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch) + loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch) gradients = self.compute_gradients(loss_per_module, gradient_tape=tape) del tape postprocessed_gradients = self.postprocess_gradients(gradients) diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index a560f8a7b1f43..25f54c5a4c706 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -1,3 +1,4 @@ +from collections import defaultdict import logging from typing import ( Any, @@ -97,6 +98,14 @@ def __init__(self, **kwargs): torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode, ) + # Loss scalers for mixed precision training.
Maps ModuleIDs to the + # associated torch GradScaler objects. + self._grad_scalers = None + if self.config._torch_grad_scaler_class: + self._grad_scalers = defaultdict( + lambda: self.config._torch_grad_scaler_class() + ) + @OverrideToImplementCustomLogic @override(Learner) def configure_optimizers_for_module( @@ -108,7 +117,7 @@ def configure_optimizers_for_module( # For this default implementation, the learning rate is handled by the # attached lr Scheduler (controlled by self.config.lr, which can be a - fixed value of a schedule setting). + fixed value or a schedule setting). params = self.get_parameters(module) optimizer = torch.optim.Adam(params) @@ -130,7 +139,7 @@ def _uncompiled_update( self.metrics.activate_tensor_mode() fwd_out = self.module.forward_train(batch) - loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=batch) + loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch) gradients = self.compute_gradients(loss_per_module) postprocessed_gradients = self.postprocess_gradients(gradients) @@ -149,24 +158,48 @@ def compute_gradients( for optim in self._optimizer_parameters: # `set_to_none=True` is a faster way to zero out the gradients. optim.zero_grad(set_to_none=True) + + if self._grad_scalers is not None: + total_loss = sum( + self._grad_scalers[mid].scale(loss) + for mid, loss in loss_per_module.items() + ) + else: + total_loss = sum(loss_per_module.values()) + - loss_per_module[ALL_MODULES].backward() + total_loss.backward() grads = {pid: p.grad for pid, p in self._params.items()} return grads @override(Learner) def apply_gradients(self, gradients_dict: ParamDict) -> None: - # Make sure the parameters do not carry gradients on their own. - for optim in self._optimizer_parameters: - optim.zero_grad(set_to_none=True) - # Set the gradient of the parameters. for pid, grad in gradients_dict.items(): self._params[pid].grad = grad # For each optimizer call its step function. - for optim in self._optimizer_parameters: - optim.step() + for module_id, optimizer_names in self._module_optimizers.items(): + for optimizer_name in optimizer_names: + optim = self.get_optimizer(module_id, optimizer_name) + # Step through the scaler (unscales gradients, if applicable). + if self._grad_scalers is not None: + scaler = self._grad_scalers[module_id] + scaler.step(optim) + self.metrics.log_value( + (module_id, "_torch_grad_scaler_current_scale"), + scaler.get_scale(), + window=1, # snapshot in time, no EMA/mean. + ) + # Update the scaler. + scaler.update() + # `step` the optimizer (default), but only if all gradients are finite. + elif all( + param.grad is None or torch.isfinite(param.grad).all() + for group in optim.param_groups + for param in group["params"] + ): + optim.step() @override(Learner) def _get_optimizer_state(self) -> StateDict: diff --git a/rllib/core/models/torch/primitives.py b/rllib/core/models/torch/primitives.py index 0be0937302795..9c4e557435109 100644 --- a/rllib/core/models/torch/primitives.py +++ b/rllib/core/models/torch/primitives.py @@ -174,10 +174,8 @@ def __init__( self.mlp = nn.Sequential(*layers) - self.expected_input_dtype = torch.float32 - def forward(self, x): - return self.mlp(x.type(self.expected_input_dtype)) + return self.mlp(x) class TorchCNN(nn.Module): @@ -307,13 +305,11 @@ def __init__( # Create the CNN.
self.cnn = nn.Sequential(*layers) - self.expected_input_dtype = torch.float32 - def forward(self, inputs): # Permute b/c data comes in as channels_last ([B, dim, dim, channels]) -> # Convert to `channels_first` for torch: inputs = inputs.permute(0, 3, 1, 2) - out = self.cnn(inputs.type(self.expected_input_dtype)) + out = self.cnn(inputs) # Permute back to `channels_last`. return out.permute(0, 2, 3, 1) @@ -459,12 +455,10 @@ def __init__( # Create the final CNNTranspose network. self.cnn_transpose = nn.Sequential(*layers) - self.expected_input_dtype = torch.float32 - def forward(self, inputs): # Permute b/c data comes in as [B, dim, dim, channels]: out = inputs.permute(0, 3, 1, 2) - out = self.cnn_transpose(out.type(self.expected_input_dtype)) + out = self.cnn_transpose(out) return out.permute(0, 2, 3, 1) diff --git a/rllib/examples/gpus/float16_training_and_inference.py b/rllib/examples/gpus/float16_training_and_inference.py new file mode 100644 index 0000000000000..75c3b86604790 --- /dev/null +++ b/rllib/examples/gpus/float16_training_and_inference.py @@ -0,0 +1,242 @@ +"""Example of using float16 precision for training and inference. + +This example: + - shows how to write a custom callback for RLlib to convert all RLModules + (on the EnvRunners and Learners) to float16 precision. + - shows how to write a custom env-to-module ConnectorV2 piece to convert all + observations and rewards in the collected trajectories to float16 (numpy) arrays. + - shows how to write a custom grad scaler for torch that is necessary to stabilize + learning with float16 weight matrices and gradients. This custom scaler behaves + exactly like the torch built-in `torch.amp.GradScaler` but also works for float16 + gradients (which the torch built-in one doesn't). + - demonstrates how to plug in all the above custom components into an + `AlgorithmConfig` instance and start training (and inference) with float16 + precision. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + +You can visualize experiment results in ~/ray_results using TensorBoard.
+ + +Results to expect +----------------- +You should see something similar to the following on your terminal, when running this +script with the above recommended options: + ++-----------------------------+------------+-----------------+--------+ +| Trial name | status | loc | iter | +| | | | | +|-----------------------------+------------+-----------------+--------+ +| PPO_CartPole-v1_437ee_00000 | TERMINATED | 127.0.0.1:81045 | 6 | ++-----------------------------+------------+-----------------+--------+ ++------------------+------------------------+------------------------+ +| total time (s) | episode_return_mean | num_episodes_lifetime | +| | | | +|------------------+------------------------+------------------------+ +| 71.3123 | 153.79 | 358 | ++------------------+------------------------+------------------------+ +""" +from typing import Optional + +import gymnasium as gym +import numpy as np +import torch + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls + +parser = add_rllib_example_script_args( + default_iters=50, default_reward=150.0, default_timesteps=100000 +) +parser.set_defaults( + enable_new_api_stack=True, +) + + +class Float16InitCallback(DefaultCallbacks): + """Callback making sure that all RLModules in the algo are `half()`'ed.""" + + def on_algorithm_init( + self, + *, + algorithm: Algorithm, + metrics_logger: Optional[MetricsLogger] = None, + **kwargs, + ) -> None: + # Switch all Learner RLModules to float16. + algorithm.learner_group.foreach_learner( + lambda learner: learner.module.foreach_module(lambda mid, mod: mod.half()) + ) + # Switch all EnvRunner RLModules (assuming a single RLModule per EnvRunner) + # to float16. + algorithm.env_runner_group.foreach_worker( + lambda env_runner: env_runner.module.half() + ) + if algorithm.eval_env_runner_group: + algorithm.eval_env_runner_group.foreach_worker( + lambda env_runner: env_runner.module.half() + ) + + +class Float16Connector(ConnectorV2): + """ConnectorV2 piece converting observations and rewards to float16. + + Note that users can also write a gymnasium.Wrapper for observations and rewards + to achieve the same thing. + """ + + def recompute_output_observation_space( + self, + input_observation_space, + input_action_space, + ): + return gym.spaces.Box( + input_observation_space.low.astype(np.float16), + input_observation_space.high.astype(np.float16), + input_observation_space.shape, + np.float16, + ) + + def __call__(self, *, rl_module, batch, episodes, **kwargs): + for sa_episode in self.single_agent_episode_iterator(episodes): + obs = sa_episode.get_observations(-1) + float16_obs = obs.astype(np.float16) + sa_episode.set_observations(new_data=float16_obs, at_indices=-1) + if len(sa_episode) > 0: + rew = sa_episode.get_rewards(-1).astype(np.float16) + sa_episode.set_rewards(new_data=rew, at_indices=-1) + return batch + + +class Float16GradScaler: + """Custom grad scaler for `TorchLearner`.
+ + This class makes use of the `TorchLearner`'s experimental support + for loss/gradient scaling (analogous to how a `torch.amp.GradScaler` would work). + + TorchLearner performs the following steps using this class (`scaler`): + - loss_per_module = TorchLearner.compute_losses() + - for L in loss_per_module: L = scaler.scale(L) + - grads = TorchLearner.compute_gradients() # L.backward() on scaled loss + - TorchLearner.apply_gradients(grads): + for optim in optimizers: + scaler.step(optim) # <- grads should get unscaled + scaler.update() # <- update scaling factor + """ + + def __init__( + self, + init_scale=1000.0, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + ): + self._scale = init_scale + self.growth_factor = growth_factor + self.backoff_factor = backoff_factor + self.growth_interval = growth_interval + self._found_inf_or_nan = False + self.steps_since_growth = 0 + + def scale(self, loss): + # Scale the loss by `self._scale`. + return loss * self._scale + + def get_scale(self): + return self._scale + + def step(self, optimizer): + # Unscale the gradients of all model parameters and check them for + # inf/NaN values. + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is not None: + param.grad.data.div_(self._scale) + if torch.isinf(param.grad).any() or torch.isnan(param.grad).any(): + self._found_inf_or_nan = True + break + if self._found_inf_or_nan: + break + # Only step if no inf/NaN grad found. + if not self._found_inf_or_nan: + optimizer.step() + + def update(self): + # If gradients are found to be inf/NaN, reduce the scale. + if self._found_inf_or_nan: + self._scale *= self.backoff_factor + self.steps_since_growth = 0 + # Increase the scale after a set number of steps without inf/NaN. + else: + self.steps_since_growth += 1 + if self.steps_since_growth >= self.growth_interval: + self._scale *= self.growth_factor + self.steps_since_growth = 0 + # Reset inf/NaN flag. + self._found_inf_or_nan = False + + +class Float16TorchLearner(PPOTorchLearner): + @override(TorchLearner) + def configure_optimizers_for_module(self, module_id, config): + module = self._module[module_id] + + params = self.get_parameters(module) + # Create an Adam optimizer with a larger eps (1e-4 instead of the default + # 1e-8) for better float16 stability. + optimizer = torch.optim.Adam(params, eps=1e-4) + + # Register the created optimizer (under the default optimizer name). + self.register_optimizer( + module_id=module_id, + optimizer=optimizer, + params=params, + lr_or_lr_schedule=config.lr, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment("CartPole-v1") + # Plug in our custom loss scaler class. + .experimental(_torch_grad_scaler_class=Float16GradScaler) + .env_runners(env_to_module_connector=lambda env: Float16Connector()) + .callbacks(Float16InitCallback) + .training( + learner_class=Float16TorchLearner, + # Switch off grad clipping entirely b/c we use our custom grad scaler with + # built-in inf/nan detection (see `step` method of `Float16GradScaler`).
+ grad_clip=None, + # Typical CartPole-v1 hyperparams known to work well: + gamma=0.99, + lr=0.0003, + num_sgd_iter=6, + vf_loss_coeff=0.01, + use_kl_loss=True, + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/learners/classes/curiosity_torch_learner_utils.py b/rllib/examples/learners/classes/curiosity_torch_learner_utils.py index fd15b18e72066..d34819e549f20 100644 --- a/rllib/examples/learners/classes/curiosity_torch_learner_utils.py +++ b/rllib/examples/learners/classes/curiosity_torch_learner_utils.py @@ -71,7 +71,7 @@ def build(self): AddNextObservationsFromEpisodesToTrainBatch(), ) - def compute_loss( + def compute_losses( self, *, fwd_out: Dict[str, Any], diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py index 900c2bf06e5eb..b9667ec050580 100644 --- a/rllib/utils/torch_utils.py +++ b/rllib/utils/torch_utils.py @@ -274,8 +274,8 @@ def mapping(item): else: tensor = torch.from_numpy(np.asarray(item)) - # Floatify all float64 tensors. - if tensor.is_floating_point(): + # Floatify all float64 tensors (but leave float16 as-is). + if tensor.is_floating_point() and tensor.dtype != torch.float16: tensor = tensor.float() # Pin the tensor's memory (for faster transfer to GPU later).
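
For reference, the new `_torch_grad_scaler_class` setting can be exercised with a one-line config change. A minimal sketch, not part of this diff, assuming a recent torch (2.3+) where `torch.amp.GradScaler` is constructible with no arguments, which is how `TorchLearner.__init__` instantiates the configured class:

import torch

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Any class implementing `scale()`, `get_scale()`, `step()`, and
    # `update()` is accepted here; TorchLearner lazily creates one instance
    # per ModuleID via a no-arg constructor call.
    .experimental(_torch_grad_scaler_class=torch.amp.GradScaler)
)

As the float16 example above notes, torch's built-in scaler only suits mixed-precision (autocast) setups; pure float16 gradients require a custom scaler such as `Float16GradScaler`.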
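
The `compute_loss` to `compute_losses` rename also changes the override contract: the returned dict no longer needs the protected ALL_MODULES key, because both `TfLearner.compute_gradients()` and `TorchLearner.compute_gradients()` now sum the per-module loss terms themselves. A sketch of a custom override under the new contract, mirroring the default implementation above (`MyTorchLearner` is a hypothetical name):

from typing import Any, Dict

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyTorchLearner(TorchLearner):
    def compute_losses(
        self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Return one loss term per ModuleID; no ALL_MODULES entry required.
        return {
            module_id: self.compute_loss_for_module(
                module_id=module_id,
                config=self.config.get_config_for_module(module_id),
                batch=batch[module_id],
                fwd_out=module_fwd_out,
            )
            for module_id, module_fwd_out in fwd_out.items()
        }

Old code that still overrides or calls `compute_loss()` keeps working through the deprecation shim above, which re-adds the summed ALL_MODULES entry.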
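
The `Float16Connector` docstring mentions that a `gymnasium.Wrapper` can achieve the same casting at the env level. A minimal sketch of that alternative, assuming a Box observation space as in CartPole (`Float16Casting` is a hypothetical name):

import gymnasium as gym
import numpy as np


class Float16Casting(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        # Advertise the cast observation space to downstream components.
        self.observation_space = gym.spaces.Box(
            env.observation_space.low.astype(np.float16),
            env.observation_space.high.astype(np.float16),
            env.observation_space.shape,
            np.float16,
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return obs.astype(np.float16), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return (
            obs.astype(np.float16),
            np.float16(reward),
            terminated,
            truncated,
            info,
        )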