[RLlib; Off-policy] Add sequence sampling to 'EpisodeReplayBuffer'. #48116
Conversation
…construction to slicing. This is still incomplete b/c it needs the correct discounting of rewards in case a sequence is returned. In addition, the steps between the end of the sequence and the n-step need to be removed. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@@ -518,78 +525,82 @@ def _sample_episodes(
    if random_n_step:
        actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))

    lookback = int(episode_ts != 0)
Based on our slack discussion, just some thoughts:
- Maybe the user should provide the lookback value?
- Even if the user requires a lookback of 10000, it shouldn't matter in filtering out episodes here, b/c if the episode does not reach back that many timesteps, that's also ok and the lookback will result in fill values, not actual episode values.
- Thus, in the if-block below, we should simply do:
if episode_ts + batch_length_T + (actual_n_step - 1) > len(episode):
    continue
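A minimal sketch of that check, pulled out as a standalone helper (names follow the diff above; the helper itself is hypothetical, not part of the buffer code):

```python
def fits_into_episode(episode, episode_ts, batch_length_T, actual_n_step):
    # The sampled window only has to fit *forward* from `episode_ts`.
    # The lookback reaches backward and can be padded with fill values,
    # so even a very large lookback never disqualifies an episode.
    return episode_ts + batch_length_T + (actual_n_step - 1) <= len(episode)
```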
If no lookback is provided I guess we run into this line and error out: https://github.com/ray-project/ray/blob/5669b479e13fe65308071f4de8b5b0763afa6aa7/rllib/connectors/common/add_states_from_episodes_to_batch.py#L374
We could set `fill=True`. The thing is that the starting point is 0 but the length is 1.
        episode_ts - lookback,
        episode_ts + actual_n_step + batch_length_T - 1,
    ),
    len_lookback_buffer=lookback,
This arg should already take care of the slice-adjustment, so you don't need to do this math above anymore:
sampled_episode = episode.slice(
    slice(episode_ts, episode_ts + batch_length_T + (actual_n_step - 1)),
    len_lookback_buffer=lookback,
)
@@ -473,6 +478,8 @@ def _sample_episodes(
        the extra model outputs at the `"obs"` in the batch is included (the
        timestep at which the action is computed).
    finalize: If episodes should be finalized.
    states: States of stateful `RLModule` that can be added to the
I'm not sure how this would work. We sample (randomly) some episode chunk(s) from the buffer, no? So how would the user know which states (at which timesteps and for which episodes) to provide?
The thing is that we actually have cases where the buffer does not know what a state looks like, as it does not know the module. For example: in behavior cloning the expert policy contains state only in rare cases, but if we want to train a stateful model we need to somehow provide states. My goal is to make the offline API as powerful as possible to apply it to real industry cases that will necessarily come with complex modules.
A different approach would be to add a connector that does so. This is probably the "nicer" solution and more aligned with our design. What do you think @sven1977?
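A rough sketch of what such a connector (or helper) could do, assuming the module exposes `get_initial_state()` (as added to `DQNRainbowRLModule` in this PR); the function name and the `"state_in"` key are illustrative, not RLlib's actual connector API:

```python
import numpy as np
import tree  # dm-tree, used throughout RLlib for nested structures


def add_missing_states(batch, rl_module, batch_size_B):
    # Illustrative only: if the (stateless) expert data carries no states,
    # seed every sampled sequence (batch row) with the module's initial state.
    if "state_in" in batch:
        return batch
    init_state = rl_module.get_initial_state()
    batch["state_in"] = tree.map_structure(
        # One copy of the initial state per batch row.
        lambda s: np.repeat(np.asarray(s)[None], repeats=batch_size_B, axis=0),
        init_state,
    )
    return batch
```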
I have 2 high level questions, which I think we should answer first:
- Do we even need n_step in combination with sequence sampling? It's overkill, no? If I have an RNN-based loss/model and I sample sequence chunks from the buffer, then I don't really care about n-step, b/c I already have something better: whole and longer sequences + a memory-capable model.
- Do we need lookback? I think we do, b/c the first state-out of each batch row is the state-out of the previous(!) timestep (the one in a lookback buffer of size 1). I do NOT think, however, that we need it for the (discounted) rewards. Unless, however, :) you have an LSTM that requires prev-action/reward inputs as well. As a first solution, I think users should have to provide the `lookback` as an argument to the `sample` method (see the call sketch right after this list).
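A sketch of that first solution from the caller's perspective; the `lookback` keyword is the argument proposed here, not (yet) part of the actual `sample()` signature:

```python
# Hypothetical call -- assuming `buffer` is an EpisodeReplayBuffer instance
# and `lookback` is the proposed new argument.
batch = buffer.sample(
    batch_size_B=32,     # number of sampled sequences
    batch_length_T=20,   # length of each sequence
    lookback=1,          # so the first state-in of each row can come from the
                         # state-out of the previous(!) timestep
)
```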
…d lookback. Furthermore, added 'get_initial_state' to 'DQNRainbowRLModule' and adapted module for stateful training. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…adeletion for sampled episodes. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…s necessary for plain DQN to learn. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…g architecture. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
    }
    if include_extra_model_outputs
    else {}
if batch_length_T == 0:
if batch_length_T is None:
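For context, a tiny illustrative sketch of the distinction the suggestion draws (not the actual buffer code):

```python
# Illustrative only: `None` signals "no sequence sampling requested", which is
# not the same thing as an integer length that merely happens to be falsy.
if batch_length_T is None:
    # Sample single (n-step) transitions as before.
    ...
else:
    # Sample sequences of length `batch_length_T`.
    ...
```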
sampled_episodes.append(sampled_episode)

# Increment counter.
B += 1
B += batch_length_T or 1
👍
if Columns.NEXT_OBS in batch:
    self.add_n_batch_items(
        batch=batch,
        column="new_state_in",
nit: create a new Column constant?
I had a similar idea, but didn't do it before it was running. I have created a new column now.
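A sketch of what the nit suggests; the constant name follows the later rename to `next_state_in`, and where it would live (e.g. next to the other names on `Columns`) is an assumption:

```python
# Hypothetical constant instead of the raw "new_state_in" string literal.
NEXT_STATE_IN = "next_state_in"

# The buffer code above would then read roughly:
#     self.add_n_batch_items(batch=batch, column=NEXT_STATE_IN, ...)
```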
    key=Columns.STATE_OUT
)
else:
    state_outs = tree.map_structure(
Can we explain here why we would repeat the lookback state `len(episode)` times? What's the logic behind doing this?
Good point. The main case where this happens is in offline learning when the expert was non-stateful. I have added some comments to make it more clear.
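A small, self-contained sketch of the repetition being discussed, with dummy stand-ins for the lookback state and the episode length:

```python
import numpy as np
import tree  # dm-tree

# Dummy stand-ins: a nested initial/lookback state and an episode length.
init_state = {"h": np.zeros(16), "c": np.zeros(16)}
episode_len = 20  # would be `len(episode)` in the buffer code

# With a non-stateful (offline) expert the episode has no per-timestep
# state-outs; repeating the single lookback state once per timestep gives
# every step of the sampled sequence a state to feed the stateful module.
state_outs = tree.map_structure(
    lambda s: np.repeat(s[None], repeats=episode_len, axis=0),
    init_state,
)
assert state_outs["h"].shape == (20, 16)
```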
python/ray/data/exceptions.py (Outdated)
# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled |
# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled |
Thanks for the fix @sven1977! I was already wondering where this linter message came from. I did not change it myself, however, and hoped that merging master would fix it.
@@ -132,6 +132,13 @@ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType
    {"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
)

@override(RLModule)
haha, nice :)
del episode

# Add the actually chosen n-step in this episode.
sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer(
Do we need to add this information to the episode, even if we don't do n-step?
Yes, because in the loss we use the `n_step` as an exponent.
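For context, a sketch of how a stored per-sample `n_step` typically enters the TD target (standard n-step formulation, not the exact loss code in this PR):

```python
def n_step_td_target(n_step_rewards, q_next, terminateds, gamma, n_step):
    # Standard n-step target: R_{t:t+n} + gamma**n * (1 - done) * Q(s_{t+n}).
    # `n_step_rewards` already holds the discounted n-step reward sum and
    # `n_step` is the per-sample n the buffer stored in `extra_model_outputs`;
    # all arguments broadcast elementwise (floats, numpy arrays, or tensors).
    return n_step_rewards + gamma ** n_step * (1.0 - terminateds) * q_next
```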
LGTM! Great hustle through making this PR work, @simonsays1980 !!
Just a few cleanup nits and 2-3 remaining smaller questions.
Co-authored-by: Sven Mika <sven@anyscale.io> Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ate_in' to 'next_state_in'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ode-replay-buffer
…ode-replay-buffer
…ePreLearner' in regard to 'n_step' and step counting. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ode-replay-buffer
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ode-replay-buffer
… in the CI tests. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ay-project#48116) Signed-off-by: Roshan Kathawate <roshankathawate@gmail.com>
Why are these changes needed?
At the moment replay buffers do not allow sampling sequences (which could become helpful in case of stateful policies or bias reduction in value functions). This PR offers a solution that
- adds sequence sampling to the `EpisodeReplayBuffer`.
- uses `SingleAgentEpisode.slice` instead of constructing episodes from basic data structures.
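A rough usage sketch of the new behavior; argument names follow the buffer's existing `sample()` signature, and the exact values and defaults are assumptions:

```python
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer

buffer = EpisodeReplayBuffer(capacity=10_000)
# ... SingleAgentEpisode objects are added via `buffer.add(...)` ...

# Instead of B single transitions, request B sequences of T consecutive
# timesteps each; internally these are produced via `SingleAgentEpisode.slice`.
batch = buffer.sample(
    batch_size_B=16,
    batch_length_T=20,
)
```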
Related issue number

Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.