[RLlib; Off-policy] Add sequence sampling to 'EpisodeReplayBuffer'. #48116

Conversation

Collaborator

@simonsays1980 simonsays1980 commented Oct 21, 2024

Why are these changes needed?

At the moment, replay buffers do not allow sampling sequences (which can be helpful for stateful policies or for bias reduction in value functions). This PR offers a solution that

  • Samples sequences in the EpisodeReplayBuffer (see the usage sketch below).
  • Adds given states to the episodes.
  • Simplifies episode sampling by using SingleAgentEpisode.slice instead of constructing episodes from basic data structures.
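
For illustration, a rough usage sketch of sequence sampling (the import path and argument names follow the buffer's existing API where known, but treat them as illustrative rather than the final interface):

from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer

# `episodes` stands for a list of previously collected SingleAgentEpisodes.
buffer = EpisodeReplayBuffer(capacity=10000)
buffer.add(episodes)

# Sample B sequences of length T each, instead of single timesteps.
sampled_episodes = buffer.sample(
    batch_size_B=32,     # number of sequences (batch rows)
    batch_length_T=20,   # sequence length; the new behavior added here
    n_step=1,
)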

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…construction to slicing. This is still incomplete b/c it needs the correct discounting of rewards in case a sequence is returned. In addition, the steps between the end of the sequence and the n-step need to be removed.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@sven1977 sven1977 changed the title [RLlib; Off-policy] - Add sequence sampling to 'EpisodeReplayBuffer'. [RLlib; Off-policy] Add sequence sampling to 'EpisodeReplayBuffer'. Oct 21, 2024
@sven1977 sven1977 marked this pull request as ready for review October 21, 2024 12:18
@sven1977 sven1977 self-requested a review as a code owner October 21, 2024 12:18
@simonsays1980 simonsays1980 added enhancement Request for new feature and/or capability rllib RLlib related issues rllib-offline-rl Offline RL problems labels Oct 21, 2024
@@ -518,78 +525,82 @@ def _sample_episodes(
if random_n_step:
actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))

lookback = int(episode_ts != 0)
Contributor

Based on our Slack discussion, just some thoughts:

  • Maybe the user should provide the lookback value?
  • Even if the user requires a lookback of 10000, it shouldn't matter when filtering out episodes here, b/c if the episode doesn't reach back that many timesteps, that's also ok: the lookback will result in fill values, not actual episode values.
  • Thus, in the if-block below, we should simply do (a small worked example follows the snippet):
if episode_ts + batch_length_T + (actual_n_step - 1) > len(episode):
    continue
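
A tiny worked example of that boundary check (numbers are made up):

# With len(episode) == 100, batch_length_T == 20, and actual_n_step == 3,
# the last valid start index is 100 - 20 - (3 - 1) == 78; any later start
# would need timesteps past the episode end and gets skipped.
len_episode, batch_length_T, actual_n_step = 100, 20, 3
for episode_ts in (78, 79):
    skip = episode_ts + batch_length_T + (actual_n_step - 1) > len_episode
    print(episode_ts, "skip" if skip else "keep")  # -> 78 keep, 79 skip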

Collaborator Author

If no lookback is provided, I guess we run into this line and error out: https://github.com/ray-project/ray/blob/5669b479e13fe65308071f4de8b5b0763afa6aa7/rllib/connectors/common/add_states_from_episodes_to_batch.py#L374

We could set fill=True. The thing is that the starting point is 0 but the length is 1.

episode_ts - lookback,
episode_ts + actual_n_step + batch_length_T - 1,
),
len_lookback_buffer=lookback,
Contributor

This arg should already take care of the slice-adjustment, so you don't need to do this math above anymore:

sampled_episode = episode.slice(
    slice(episode_ts, episode_ts + batch_length_T + (actual_n_step - 1)),
    len_lookback_buffer=lookback,
)

@@ -473,6 +478,8 @@ def _sample_episodes(
the extra model outputs at the `"obs"` in the batch is included (the
timestep at which the action is computed).
finalize: If episodes should be finalized.
states: States of stateful `RLModule` that can be added to the
Contributor

I'm not sure how this would work. We sample (randomly) some episode chunk(s) from the buffer, no? So how would the user know which states (at which timesteps and for which episodes) to provide?

Collaborator Author

The thing is that we actually have cases where the buffer does not know what a state looks like, as it does not know the module. For example: in behavior cloning the expert policy contains state only in rare cases, but if we want to train a stateful model we need to somehow provide states. My goal is to make the offline API as powerful as possible so it can be applied to real industry cases that will necessarily come with complex modules.
A different approach would be to add a connector that does so. This is probably the "nicer" solution and more aligned with our design. What do you think, @sven1977?

Contributor

@sven1977 sven1977 left a comment

I have two high-level questions, which I think we should answer first:

  • Do we even need n_step in combination with sequence sampling? It's overkill, no? If I have an RNN-based loss/model and I sample sequence chunks from the buffer, then I don't really care about n-step, b/c I already have something better: whole, longer sequences plus a memory-capable model.
  • Do we need lookback? I think we do, b/c the first state-out of each batch row is the state-out of the previous(!) timestep (the one in a lookback buffer of size 1). I do NOT think, however, that we need it for the (discounted) rewards. Unless, of course, :) you have an LSTM that requires prev-action/reward inputs as well. As a first solution, I think users should have to provide the lookback as an argument to the sample method.

…d lookback. Furthermore, added 'get_initial_state' to 'DQNRainbowRLModule' and adapted module for stateful training.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…adeletion for sampled episodes.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…s necessary for plain DQN to learn.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…g architecture.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
}
if include_extra_model_outputs
else {}
if batch_length_T == 0:
Contributor

if batch_length_T is None:

sampled_episodes.append(sampled_episode)

# Increment counter.
B += 1
B += batch_length_T or 1
Contributor

👍

if Columns.NEXT_OBS in batch:
self.add_n_batch_items(
batch=batch,
column="new_state_in",
Contributor

nit: create a new Column constant?

Collaborator Author

I had a similar idea, but didn't do it before I had everything running. I have created a new column now.
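
For reference, a minimal sketch of what such a constant in rllib/core/columns.py could look like (the exact name is whatever the later commits settle on, e.g. the follow-up rename of 'new_state_in' to 'next_state_in'):

class Columns:
    # ... existing constants such as OBS, ACTIONS, STATE_IN, STATE_OUT ...
    NEXT_STATE_IN = "next_state_in"  # the state input matching NEXT_OBS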

key=Columns.STATE_OUT
)
else:
state_outs = tree.map_structure(
Contributor

Can we explain here why we would repeat the lookback state len(episode) times? What's the logic behind doing this?

Collaborator Author

Good point. The main case where this happens is in offline learning, when the expert was non-stateful. I have added some comments to make this clearer.
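
To make the intent concrete, a rough sketch of that repeat (shapes and variable names are illustrative; `module` stands for the stateful RLModule being trained and `episode` for the sampled SingleAgentEpisode):

import numpy as np
import tree  # dm_tree

# Offline case: the expert data carries no states, so take the module's
# initial state (see the `get_initial_state` added in this PR) and repeat
# it once per timestep of the sampled episode.
initial_state = module.get_initial_state()
state_outs = tree.map_structure(
    lambda s: np.repeat(np.asarray(s)[np.newaxis], len(episode), axis=0),
    initial_state,
)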

@simonsays1980 simonsays1980 requested a review from a team as a code owner December 28, 2024 11:32
Comment on lines 9 to 10

# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
Contributor

Suggested change
# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

Collaborator Author

Thanks for the fix @sven1977! I was already wondering where this linter message came from. I did not change it myself, however, and hoped that a merge of master would fix it.

@@ -132,6 +132,13 @@ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType
{"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
)

@override(RLModule)
Contributor

haha, nice :)

rllib/core/columns.py (outdated comment, resolved)
del episode

# Add the actually chosen n-step in this episode.
sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer(
Contributor

Do we need to add this information to the episode, even if we don't do n-step?

Collaborator Author

Yes, because in the loss we use the n_step as an exponent.
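
Roughly, a runnable sketch of why that matters (generic n-step TD target with dummy numbers, not the exact loss code):

import torch

gamma = 0.99
# Dummy tensors standing in for what the sampled batch would provide.
rewards = torch.tensor([1.0, 0.5])      # n-step discounted reward sums
n_step = torch.tensor([3.0, 1.0])       # per-transition n-step from extra_model_outputs
terminateds = torch.tensor([0.0, 1.0])
q_next = torch.tensor([2.0, 1.5])       # bootstrapped Q-value at obs_{t+n}

# The sampled n_step shows up as the exponent of gamma in the TD target.
td_target = rewards + (gamma ** n_step) * q_next * (1.0 - terminateds)
print(td_target)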

Contributor

@sven1977 sven1977 left a comment

LGTM! Great hustle through making this PR work, @simonsays1980 !!

Just a few cleanup nits and 2-3 remaining smaller questions.

@sven1977 sven1977 enabled auto-merge (squash) December 28, 2024 22:34
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Dec 28, 2024
Co-authored-by: Sven Mika <sven@anyscale.io>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@github-actions github-actions bot disabled auto-merge December 30, 2024 11:11
…ate_in' to 'next_state_in'.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ePreLearner' in regard to 'n_step' and step counting.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
… in the CI tests.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@sven1977 sven1977 merged commit a29d0c4 into ray-project:master Jan 6, 2025
5 checks passed
roshankathawate pushed a commit to roshankathawate/ray that referenced this pull request Jan 9, 2025