Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Cleanup examples folder (vol 23): Float16 training support and new example script. #47362

Merged

Conversation

sven1977
Copy link
Contributor

@sven1977 sven1977 commented Aug 27, 2024

Cleanup examples folder (vol 23): Float16 training support and new example script.

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) August 28, 2024 12:55
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Aug 28, 2024
@@ -91,7 +91,7 @@ Computing Losses
:nosignatures:
:toctree: doc/

Learner.compute_loss
Learner.compute_losses
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this API name to be even more clear. This method computes one(!) loss per RLModule (in a MultiRLModule) inside the Learner.

Got rid of the confusing TOTAL_LOSS key. We compute this now in the default implementation of compute_gradients.

value_fn_out = torch.tensor(0.0).to(surrogate_loss.device)
mean_vf_unclipped_loss = torch.tensor(0.0).to(surrogate_loss.device)
vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)
z = torch.tensor(0.0, device=surrogate_loss.device)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simplify

@@ -1621,3 +1624,13 @@ def get_optimizer_state(self, *args, **kwargs):
@Deprecated(new="Learner._set_optimizer_state()", error=True)
def set_optimizer_state(self, *args, **kwargs):
pass

@Deprecated(new="Learner.compute_losses(...)", error=False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Emulate the old behavior, in case users use this in their tests or other code.

Copy link
Collaborator

@simonsays1980 simonsays1980 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Awesome PR. It would be nice to add some more documentation of why certain components are needed and how the new loss setup works now.

Comment on lines +2972 to +2979
py_test(
name = "examples/gpus/float16_training_and_inference",
main = "examples/gpus/float16_training_and_inference.py",
tags = ["team:rllib", "exclusive", "examples", "gpu"],
size = "medium",
srcs = ["examples/gpus/float16_training_and_inference.py"],
args = ["--enable-new-api-stack", "--as-test", "--stop-reward=150.0"]
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Is this even used by users?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Actually a customer asked for it ;)
It does get interesting for large models (see also many efforts training LLMs with super compressed precisions down to bfloat16) and multi-agent. I think if one can stabilize this, it's very useful. Another example script with mixed-precision training (and float16 inference on the EnvRunners) is in the making ...

@@ -3194,6 +3197,14 @@ def experimental(
"""Sets the config's experimental settings.

Args:
_torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
unscaling). The class must implement the following methods to be
compatible with a `TorchLearner`. These methods/APIs match exactly the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: Remove "the" at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

`scale([loss])` to scale the loss.
`get_scale()` to get the current scale value.
`step([optimizer])` to unscale the grads and step the given optimizer.
`update()` to update the scaler after an optimizer step.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description is not clear enough imo. For what is it used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enhanced and added a link to the torch docs


def possibly_masked_mean(data_):
return torch.sum(data_[batch[Columns.LOSS_MASK]]) / num_valid
return torch.sum(data_[mask]) / num_valid
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we potentially also mask with this mask values in the observations that are not available in a certain step of the environment, e.g. different number of entities at different steps of a game?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should work. But you would have to manually change that in the Learner connector (where this column is being produced and added to the batch).

mean_vf_unclipped_loss = torch.tensor(0.0).to(surrogate_loss.device)
vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)
z = torch.tensor(0.0, device=surrogate_loss.device)
value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Niceee!

@@ -893,7 +896,7 @@ def compute_loss_for_module(

Think of this as computing loss for a single agent. For multi-agent use-cases
that require more complicated computation for loss, consider overriding the
`compute_loss` method instead.
`compute_losses` method instead.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can compute_losses call internally compute_loss_for_module or is the latter called nevertheless and losses would be computed two times?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default implementation of compute_losses calls n times compute_loss_for_module, where n is the number of RLModules within the Learner's MultiRLModule. So nothing is computed twice.

If you have a complex multi-agent case, you should override compute_losses, in which case the n calls to compute_loss_for_module will NOT be made.

@@ -300,7 +300,7 @@ def _untraced_update(
def helper(_batch):
with tf.GradientTape(persistent=True) as tape:
fwd_out = self._module.forward_train(_batch)
loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we even want this here and motivate users to write TF algorithms?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! But .... DreamerV3 :|

@@ -108,7 +117,7 @@ def configure_optimizers_for_module(

# For this default implementation, the learning rate is handled by the
# attached lr Scheduler (controlled by self.config.lr, which can be a
# fixed value of a schedule setting).
# fixed value or a schedule setting).
params = self.get_parameters(module)
optimizer = torch.optim.Adam(params)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to pass in kwargs to Adam?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used to think the same way.

BUT that would make configuring custom optimizers again very non-pythonic and very yaml'ish. The user would then have to provide a class/type and some kwargs, but then has no chance to customize anything else within the optimizer setup process. Think of a user having to configure two optimizers, or three. Where do you make these options available in the config, then? What if the user needs different optimizers per module?

@@ -149,24 +158,48 @@ def compute_gradients(
for optim in self._optimizer_parameters:
# `set_to_none=True` is a faster way to zero out the gradients.
optim.zero_grad(set_to_none=True)
loss_per_module[ALL_MODULES].backward()

if self._grad_scalers is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!!!

"""Example of using float16 precision for training and inference.

This example:
- shows how to write a custom callback for RLlib to convert all RLModules
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely awesome!!! This is such a great example!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a ton!! Yes, and it actually works :D (only with adding all these weird tricks to stabilize learning). But I agree, it's super nice to see how users can customize not just one component at the same time, but several and then achieve a very high degree of customizability w/o us having to add feature after feature to the lib.

Signed-off-by: sven1977 <svenmika1977@gmail.com>
@github-actions github-actions bot disabled auto-merge August 28, 2024 14:31
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) August 28, 2024 16:10
@github-actions github-actions bot disabled auto-merge August 28, 2024 16:10
@sven1977 sven1977 enabled auto-merge (squash) August 28, 2024 16:15
Signed-off-by: sven1977 <svenmika1977@gmail.com>
@github-actions github-actions bot disabled auto-merge August 28, 2024 17:48
@sven1977 sven1977 merged commit a92f3b4 into ray-project:master Aug 29, 2024
4 of 5 checks passed
@sven1977 sven1977 deleted the cleanup_examples_folder_23_float16 branch August 29, 2024 05:40
@can-anyscale
Copy link
Collaborator

Broke #45088 and is blocking the release, i'm reverting to unblock

ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 12, 2024
…d new example script. (ray-project#47362)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…d new example script. (ray-project#47362)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…d new example script. (ray-project#47362)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…d new example script. (ray-project#47362)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…d new example script. (ray-project#47362)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go add ONLY when ready to merge, run all tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants