[RLlib; Off-policy] Add sequence sampling to 'EpisodeReplayBuffer'. #48116

Merged
Changes from 21 commits (32 commits total)
3868c9d
Added sequence sampling to 'EpisodeReplayBuffer' and changed episode …
simonsays1980 Oct 21, 2024
c26bf76
Removed states from 'EpisodeReplayBuffer._sample_episodes' and remove…
simonsays1980 Oct 22, 2024
8489719
Fixed multiple shape errors in DQN when training with LSTM.
simonsays1980 Oct 23, 2024
36bf8ac
Fixed some shape errors in 'forward_train' of DQN. Furthermore, added …
simonsays1980 Oct 27, 2024
602aebb
Removed number of atoms from tuned example as we always use 1.
simonsays1980 Oct 27, 2024
0b8ec9c
Modified no-sequence sampling to replicate the master version. This i…
simonsays1980 Oct 27, 2024
45befc6
Added state-ins for the next timestep to compute DQN TD-loss.
simonsays1980 Oct 27, 2024
4ef4255
Added next obs states to batch of the q-network when using the duelin…
simonsays1980 Oct 29, 2024
570baf1
Added @sven1977's review.
simonsays1980 Oct 30, 2024
bc8dc4b
Replaced 'new_state_in' by 'Column.NEXT_STATE_IN'.
simonsays1980 Oct 30, 2024
b4be41c
Added partial sequence sampling to add more terminal nodes to training.
simonsays1980 Nov 1, 2024
28d805b
Added stateless Cartpole example to tuned examples for DQN.
simonsays1980 Nov 11, 2024
61a43dc
Merge branch 'master' of https://github.com/ray-project/ray into offp…
sven1977 Nov 13, 2024
4391d93
Merge branch 'offpolicy-enable-sequence-sampling-in-episode-replay-bu…
sven1977 Nov 13, 2024
f2affc5
wip
sven1977 Nov 13, 2024
6dec441
wip
sven1977 Nov 13, 2024
052b729
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Nov 25, 2024
1b5a81a
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Nov 28, 2024
2cf7b6b
Merged master and resolved conflicts.
simonsays1980 Dec 21, 2024
cc61432
Rewrote DefaultDQNTorchRLModule to enable stateful modules. Fixed som…
simonsays1980 Dec 28, 2024
ba46118
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Dec 28, 2024
081b645
Update python/ray/data/_internal/planner/plan_udf_map_op.py
simonsays1980 Dec 30, 2024
59b56c7
Removed commented code in DefaultDQNTorchRLModule and changed 'new_st…
simonsays1980 Dec 30, 2024
10978bb
Added assertion to avoid sequence sampling and multi-n-step.
simonsays1980 Dec 30, 2024
d55dd5e
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Dec 30, 2024
99aac92
WIP
simonsays1980 Jan 3, 2025
42448d9
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Jan 3, 2025
913f7e4
Fixed a couple of small nits in the 'EpisodeReplayBuffer' and 'Offlin…
simonsays1980 Jan 3, 2025
43870ec
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Jan 3, 2025
9854a7d
Fixed a linting error.
simonsays1980 Jan 3, 2025
4ee0cc3
Merge branch 'master' into offpolicy-enable-sequence-sampling-in-epis…
simonsays1980 Jan 3, 2025
a9f53cf
Redid small change in 'InfiniteLookbackBuffer' b/c it caused an error…
simonsays1980 Jan 3, 2025
3 changes: 2 additions & 1 deletion python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -48,7 +48,8 @@
)
from ray.data.context import DataContext
from ray.data.exceptions import UserCodeException
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled


class _MapActorContext:
(diff truncated)
3 changes: 2 additions & 1 deletion python/ray/data/exceptions.py
@@ -6,7 +6,8 @@
from ray.exceptions import UserCodeException
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
Contributor review comment (suggested change, restoring the import):
# from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

Collaborator (author) reply:

Thanks for the fix, @sven1977! I was already wondering where this linter message came from. I did not change it myself, however, and hoped that merging master would fix it.


logger = logging.getLogger(__name__)

(diff truncated)
7 changes: 7 additions & 0 deletions rllib/algorithms/dqn/default_dqn_rl_module.py
@@ -132,6 +132,13 @@ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
{"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
)

@override(RLModule)
Contributor review comment:
haha, nice :)

def get_initial_state(self) -> dict:
if hasattr(self.encoder, "get_initial_state"):
return self.encoder.get_initial_state()
else:
return {}

@override(RLModule)
def input_specs_train(self) -> SpecType:
return [
(diff truncated)
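The new `get_initial_state` override simply delegates to the encoder: a stateful encoder reports its initial recurrent state, while a stateless one falls back to an empty dict. Below is a minimal, self-contained sketch of that delegation pattern; the two dummy encoder classes and the zero-filled `h`/`c` state layout are illustrative assumptions, not RLlib's actual encoder API.

```python
import torch


class DummyLSTMEncoder:
    """Hypothetical stand-in for a stateful (LSTM-based) encoder."""

    def __init__(self, cell_size: int = 32):
        self.cell_size = cell_size

    def get_initial_state(self) -> dict:
        # Zero-initialized hidden and cell states.
        return {
            "h": torch.zeros(self.cell_size),
            "c": torch.zeros(self.cell_size),
        }


class DummyMLPEncoder:
    """Hypothetical stand-in for a stateless encoder (no initial state)."""


def get_initial_state(encoder) -> dict:
    # Same delegation logic as the method added in the diff above.
    if hasattr(encoder, "get_initial_state"):
        return encoder.get_initial_state()
    return {}


print(sorted(get_initial_state(DummyLSTMEncoder()).keys()))  # ['c', 'h']
print(get_initial_state(DummyMLPEncoder()))                  # {}
```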
32 changes: 4 additions & 28 deletions rllib/algorithms/dqn/dqn.py
@@ -18,12 +18,6 @@
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.execution.rollout_ops import (
@@ -544,28 +538,6 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
"Use `config.framework('torch')` instead."
)

@override(AlgorithmConfig)
def build_learner_connector(
self,
input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)

# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

return pipeline


def calculate_rr_weights(config: AlgorithmConfig) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
@@ -674,9 +646,13 @@ def _training_step_new_api_stack(self):
for _ in range(sample_and_train_weight):
# Sample a list of episodes used for learning from the replay buffer.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):

episodes = self.local_replay_buffer.sample(
num_items=self.config.total_train_batch_size,
n_step=self.config.n_step,
batch_length_T=self.env_runner.module.is_stateful()
* self.config.model_config.get("max_seq_len", 0),
lookback=int(self.env_runner.module.is_stateful()),
gamma=self.config.gamma,
beta=self.config.replay_buffer_config.get("beta"),
sample_episodes=True,
(diff truncated)
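The sample call now derives `batch_length_T` and `lookback` from whether the RLModule is stateful. A small sketch of that arithmetic (the `max_seq_len=20` value is an assumption for illustration; multiplying a bool by an int yields 0 or the int):

```python
def sequence_sample_kwargs(is_stateful: bool, max_seq_len: int) -> dict:
    return {
        # Stateless module -> 0 (presumably falls back to single-timestep
        # sampling); stateful module -> sequences of length `max_seq_len`.
        "batch_length_T": is_stateful * max_seq_len,
        # One extra lookback timestep is requested only for stateful modules,
        # to provide the STATE_IN at the start of each sampled sequence.
        "lookback": int(is_stateful),
    }


print(sequence_sample_kwargs(False, 20))  # {'batch_length_T': 0, 'lookback': 0}
print(sequence_sample_kwargs(True, 20))   # {'batch_length_T': 20, 'lookback': 1}
```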
13 changes: 13 additions & 0 deletions rllib/algorithms/dqn/dqn_learner.py
@@ -1,5 +1,11 @@
from typing import Any, Dict, Optional

from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI
@@ -48,6 +54,13 @@ def build(self) -> None:
)
)

# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
self._learner_connector.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

@override(Learner)
def add_module(
self,
(diff truncated)
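The `insert_after` call that used to live in `DQNConfig.build_learner_connector` (removed in dqn.py above) now runs in `DQNLearner.build`. A toy sketch of what such an insertion does to a connector pipeline; the pipeline contents and the list-based `insert_after` helper are simplified stand-ins, not the actual RLlib connector API:

```python
pipeline = [
    "AddObservationsFromEpisodesToBatch",
    "AddColumnsFromEpisodesToTrainBatch",
    "BatchIndividualItems",
]


def insert_after(pipeline: list, anchor: str, new_piece: str) -> None:
    # Place `new_piece` directly behind the first occurrence of `anchor`.
    pipeline.insert(pipeline.index(anchor) + 1, new_piece)


insert_after(
    pipeline,
    "AddObservationsFromEpisodesToBatch",
    "AddNextObservationsFromEpisodesToTrainBatch",
)
# NEXT_OBS is added right after OBS, so every downstream piece can use it.
print(pipeline)
```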
78 changes: 70 additions & 8 deletions rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py
@@ -1,3 +1,4 @@
import tree
from typing import Dict, Union

from ray.rllib.algorithms.dqn.default_dqn_rl_module import (
@@ -46,13 +47,20 @@ def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
# outputs directly the `argmax` of the logits.
exploit_actions = action_dist.to_deterministic().sample()

output = {Columns.ACTIONS: exploit_actions}
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

# In inference, we only need the exploitation actions.
return {Columns.ACTIONS: exploit_actions}
return output

@override(RLModule)
def _forward_exploration(
self, batch: Dict[str, TensorType], t: int
) -> Dict[str, TensorType]:
# Define the return dictionary.
output = {}

# Q-network forward pass.
qf_outs = self.compute_q_values(batch)

@@ -73,7 +81,13 @@ def _forward_exploration(
B = qf_outs[QF_PREDS].shape[0]
random_actions = torch.squeeze(
torch.multinomial(
(torch.nan_to_num(qf_outs[QF_PREDS], neginf=0.0) != 0.0).float(),
(
torch.nan_to_num(
qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)),
neginf=0.0,
)
!= 0.0
).float(),
num_samples=1,
),
dim=1,
@@ -85,7 +99,14 @@ def _forward_exploration(
exploit_actions,
)

return {Columns.ACTIONS: actions}
# Add the actions to the return dictionary.
output[Columns.ACTIONS] = actions

# If this is a stateful module, add output states.
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

return output

@override(RLModule)
def _forward_train(
@@ -107,11 +128,33 @@ def _forward_train(
[batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0
),
}
# If this is a stateful module add the input states.
if Columns.STATE_IN in batch:
# Add both, the input state for the actual observation and
# the one for the next observation.
batch_base.update(
{
Columns.STATE_IN: tree.map_structure(
lambda t1, t2: torch.cat([t1, t2], dim=0),
batch[Columns.STATE_IN],
batch[Columns.NEXT_STATE_IN],
)
}
)
# Otherwise we can just use the current observations.
else:
batch_base = {Columns.OBS: batch[Columns.OBS]}
# If this is a stateful module add the input state.
if Columns.STATE_IN in batch:
batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]})

batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]}

# If we have a stateful encoder, add the states for the target forward
# pass.
if Columns.NEXT_STATE_IN in batch:
batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]})

# Q-network forward passes.
qf_outs = self.compute_q_values(batch_base)
if self.uses_double_q:
Expand All @@ -135,6 +178,14 @@ def _forward_train(
# Probabilities of the target Q-value distribution of the next state.
output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS]

# Add the states to the output, if the module is stateful.
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]
# For correctness, also add the output states from the target forward pass.
# Note, we do not backpropagate through this state.
if Columns.STATE_OUT in qf_target_next_outs:
output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT]

return output

@override(QNetAPI)
@@ -154,8 +205,11 @@ def compute_advantage_distribution(
# Reshape the action values.
# NOTE: Handcrafted action shape.
logits_per_action_per_atom = torch.reshape(
batch, shape=(-1, self.action_space.n, self.num_atoms)
batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms)
)
# logits_per_action_per_atom = torch.reshape(
# batch, shape=(-1, self.action_space.n, self.num_atoms)
# )
# Calculate the probability for each action value atom. Note,
# the sum along action value atoms of a single action value
# must sum to one.
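The reshape now preserves all leading batch dimensions instead of flattening them into one, so the same code handles both `[B, ...]` and `[B, T, ...]` inputs from sequence sampling. A minimal sketch with assumed sizes (B=4, T=5, 2 actions, 3 atoms):

```python
import torch

B, T, num_actions, num_atoms = 4, 5, 2, 3

# Last dim of the head output holds num_actions * num_atoms logits.
batch = torch.randn(B, T, num_actions * num_atoms)

# Old reshape: collapses batch and time into a single leading dimension.
old = torch.reshape(batch, (-1, num_actions, num_atoms))
# New reshape: keeps every leading dimension as-is.
new = torch.reshape(batch, (*batch.shape[:-1], num_actions, num_atoms))

print(old.shape)  # torch.Size([20, 2, 3])
print(new.shape)  # torch.Size([4, 5, 2, 3])
```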
Expand Down Expand Up @@ -216,10 +270,13 @@ def _qf_forward_helper(
# Center the advantage stream distribution.
centered_af_logits = af_dist_output["logits"] - af_dist_output[
"logits"
].mean(dim=1, keepdim=True)
].mean(dim=-1, keepdim=True)
# Calculate the Q-value distribution by adding advantage and
# value stream.
qf_logits = centered_af_logits + vf_outs.unsqueeze(dim=-1)
qf_logits = centered_af_logits + vf_outs.view(
-1, *((1,) * (centered_af_logits.dim() - 1))
)
# qf_logits = centered_af_logits + vf_outs.unsqueeze(dim=-1)
# Calculate probabilities for the Q-value distribution along
# the support given by the atoms.
qf_probs = nn.functional.softmax(qf_logits, dim=-1)
@@ -236,8 +293,8 @@ def _qf_forward_helper(
# https://discuss.pytorch.org/t/gradient-computation-issue-due-to-
# inplace-operation-unsure-how-to-debug-for-custom-model/170133
# Has to be a mean for each batch element.
af_outs_mean = torch.unsqueeze(
torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(dim=1), dim=1
af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(
dim=-1, keepdim=True
)
qf_outs = qf_outs - af_outs_mean
# Add advantage and value stream. Note, we broadcast here.
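Taking the mean over `dim=-1` with `keepdim=True` makes the advantage centering independent of whether a time axis is present. A simplified sketch (the real code additionally masks `-inf` values via `nan_to_num`/`nanmean`; the shapes here are assumptions):

```python
import torch

# Dueling advantage-stream outputs without and with a time axis.
af_no_time = torch.randn(4, 2)       # [B, num_actions]
af_with_time = torch.randn(4, 5, 2)  # [B, T, num_actions]

for af in (af_no_time, af_with_time):
    # Center over the action axis; keepdim=True keeps shapes broadcastable.
    centered = af - af.mean(dim=-1, keepdim=True)
    # Per-step advantages now sum to (approximately) zero over actions.
    print(centered.shape, torch.allclose(centered.sum(dim=-1), torch.zeros(1), atol=1e-5))
```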
@@ -266,4 +323,9 @@ def _qf_forward_helper(
# In this case we have a Q-head of dimension (1, action_space.n).
output[QF_PREDS] = qf_outs

# If we have a stateful encoder add the output states to the return
# dictionary.
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

return output
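For stateful modules, `_forward_train` concatenates current and next observations (and their matching input states) along the batch axis so a single Q-network pass covers both; `tree.map_structure` applies the concatenation to every leaf of the possibly nested state dict. A self-contained sketch of that batching trick; the shapes, key names, and the stand-in Q-output are assumptions for illustration:

```python
import torch
import tree  # dm-tree, the same package imported at the top of the diff

B, T, obs_dim, cell = 4, 5, 8, 16

batch = {
    "obs": torch.randn(B, T, obs_dim),
    "next_obs": torch.randn(B, T, obs_dim),
    "state_in": {"h": torch.randn(B, cell), "c": torch.randn(B, cell)},
    "next_state_in": {"h": torch.randn(B, cell), "c": torch.randn(B, cell)},
}

# Stack current and next observations (and states) along the batch axis.
obs_both = torch.cat([batch["obs"], batch["next_obs"]], dim=0)  # [2B, T, obs_dim]
state_both = tree.map_structure(
    lambda t1, t2: torch.cat([t1, t2], dim=0),
    batch["state_in"],
    batch["next_state_in"],
)  # each leaf: [2B, cell]

# One forward pass would produce Q-values for both halves; split them back.
q_both = torch.randn(2 * B, T, 2)  # stand-in for the Q-network output
q_curr, q_next = torch.chunk(q_both, 2, dim=0)
print(obs_both.shape, q_curr.shape, q_next.shape)
```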
49 changes: 37 additions & 12 deletions rllib/algorithms/dqn/torch/dqn_torch_learner.py
@@ -44,6 +44,29 @@ def compute_loss_for_module(
fwd_out: Dict[str, TensorType]
) -> TensorType:

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based model (zero padded `batch`)
# and for PPO's batched value function (and bootstrap value) computations,
# for which we add an (artificial) timestep to each episode to
# simplify the actual computation.
if Columns.LOSS_MASK in batch:
mask = batch[Columns.LOSS_MASK]
num_valid = torch.sum(mask)

def possibly_masked_mean(data_):
return torch.sum(data_[mask]) / num_valid

def possibly_masked_min(data_):
return torch.min(data_[mask])

def possibly_masked_max(data_):
return torch.max(data_[mask])

else:
possibly_masked_mean = torch.mean
possibly_masked_min = torch.min
possibly_masked_max = torch.max

q_curr = fwd_out[QF_PREDS]
q_target_next = fwd_out[QF_TARGET_NEXT_PREDS]

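With sequence sampling, episodes shorter than the sampled sequence length are zero-padded, so a plain `torch.mean` would dilute the loss with padded timesteps; the `LOSS_MASK` helpers reduce only over valid steps. A toy sketch of the masked mean (values are made up for illustration):

```python
import torch

# Per-timestep losses for B=2 zero-padded sequences of length T=3.
losses = torch.tensor([[1.0, 2.0, 0.0],
                       [4.0, 0.0, 0.0]])
# True where the timestep is real data, False where it is padding.
mask = torch.tensor([[True, True, False],
                     [True, False, False]])

plain_mean = torch.mean(losses)                          # 7 / 6 ~ 1.17 (diluted)
masked_mean = torch.sum(losses[mask]) / torch.sum(mask)  # 7 / 3 ~ 2.33

print(plain_mean.item(), masked_mean.item())
```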
@@ -53,34 +76,36 @@ def compute_loss_for_module(
q_selected = torch.nan_to_num(
torch.gather(
q_curr,
dim=1,
index=batch[Columns.ACTIONS].view(-1, 1).expand(-1, 1).long(),
dim=-1,
index=batch[Columns.ACTIONS]
.view(*batch[Columns.ACTIONS].shape, 1)
.long(),
),
neginf=0.0,
).squeeze()
).squeeze(dim=-1)

# Use double Q learning.
if config.double_q:
# Then we evaluate the target Q-function at the best action (greedy action)
# over the online Q-function.
# Mark the best online Q-value of the next state.
q_next_best_idx = (
torch.argmax(fwd_out[QF_NEXT_PREDS], dim=1).unsqueeze(dim=-1).long()
torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long()
)
# Get the Q-value of the target network at maximum of the online network
# (bootstrap action).
q_next_best = torch.nan_to_num(
torch.gather(q_target_next, dim=1, index=q_next_best_idx),
torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
neginf=0.0,
).squeeze()
else:
# Mark the maximum Q-value(s).
q_next_best_idx = (
torch.argmax(q_target_next, dim=1).unsqueeze(dim=-1).long()
torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long()
)
# Get the maximum Q-value(s).
q_next_best = torch.nan_to_num(
torch.gather(q_target_next, dim=1, index=q_next_best_idx),
torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
neginf=0.0,
).squeeze()

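Switching `gather`/`argmax` from `dim=1` to `dim=-1` lets the same code select Q-values whether the predictions are `[B, num_actions]` or `[B, T, num_actions]`; the actions just need one trailing index dimension appended. A brief sketch with assumed shapes:

```python
import torch

B, T, num_actions = 2, 3, 4
q_preds = torch.randn(B, T, num_actions)
actions = torch.randint(num_actions, size=(B, T))

# Append a trailing index dim, gather along the action axis, squeeze it again.
q_selected = torch.gather(
    q_preds, dim=-1, index=actions.view(*actions.shape, 1).long()
).squeeze(dim=-1)

print(q_selected.shape)  # torch.Size([2, 3])
```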
@@ -179,7 +204,7 @@ def compute_loss_for_module(
# Compute the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# Compute the weighted loss (importance sampling weights).
total_loss = torch.mean(
total_loss = possibly_masked_mean(
batch["weights"]
* loss_fn(reduction="none")(q_selected, q_selected_target)
)
@@ -198,10 +223,10 @@ def compute_loss_for_module(
self.metrics.log_dict(
{
QF_LOSS_KEY: total_loss,
QF_MEAN_KEY: torch.mean(q_selected),
QF_MAX_KEY: torch.max(q_selected),
QF_MIN_KEY: torch.min(q_selected),
TD_ERROR_MEAN_KEY: torch.mean(td_error),
QF_MEAN_KEY: possibly_masked_mean(q_selected),
QF_MAX_KEY: possibly_masked_max(q_selected),
QF_MIN_KEY: possibly_masked_min(q_selected),
TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error),
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
(diff truncated)