Commit
[RLlib; offline RL] Add sequence sampling to 'EpisodeReplayBuffer'. (#…
simonsays1980 authored Jan 6, 2025
1 parent f4c4c81 commit a29d0c4
Showing 11 changed files with 353 additions and 132 deletions.
7 changes: 7 additions & 0 deletions rllib/algorithms/dqn/default_dqn_rl_module.py
@@ -132,6 +132,13 @@ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType
{"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
)

@override(RLModule)
def get_initial_state(self) -> dict:
if hasattr(self.encoder, "get_initial_state"):
return self.encoder.get_initial_state()
else:
return {}

@override(RLModule)
def input_specs_train(self) -> SpecType:
return [
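The new `get_initial_state` override simply delegates to the encoder, so a recurrent encoder exposes its initial hidden state while a stateless one falls back to an empty dict. Below is a minimal, self-contained sketch (plain PyTorch, not the RLlib API) of that delegation pattern and of how a caller would thread the returned state through successive forward passes; the toy class names and shapes are assumptions for illustration only.

import torch
import torch.nn as nn


class ToyRecurrentEncoder(nn.Module):
    """Stand-in for a recurrent encoder exposing `get_initial_state`."""

    def __init__(self, obs_dim: int = 4, hidden_dim: int = 8):
        super().__init__()
        self.gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)

    def get_initial_state(self) -> dict:
        # Per-sample state; the caller adds the batch dimension.
        return {"h": torch.zeros(self.gru.hidden_size)}

    def forward(self, obs, state):
        # obs: [B, T, obs_dim], state["h"]: [B, hidden_dim].
        out, h = self.gru(obs, state["h"].unsqueeze(0).contiguous())
        return out, {"h": h.squeeze(0)}


class ToyModule:
    """Mimics the delegation added to the DQN RLModule above."""

    def __init__(self):
        self.encoder = ToyRecurrentEncoder()

    def get_initial_state(self) -> dict:
        if hasattr(self.encoder, "get_initial_state"):
            return self.encoder.get_initial_state()
        return {}


module = ToyModule()
# Batch the initial state for a single environment and roll two steps,
# feeding each step's state output back in as the next state input.
state = {k: v.unsqueeze(0) for k, v in module.get_initial_state().items()}
for _ in range(2):
    obs = torch.randn(1, 1, 4)  # one timestep of observations
    _, state = module.encoder(obs, state)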
34 changes: 6 additions & 28 deletions rllib/algorithms/dqn/dqn.py
@@ -18,12 +18,6 @@
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.execution.rollout_ops import (
@@ -544,28 +538,6 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
"Use `config.framework('torch')` instead."
)

@override(AlgorithmConfig)
def build_learner_connector(
self,
input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)

# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

return pipeline


def calculate_rr_weights(config: AlgorithmConfig) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
@@ -674,9 +646,15 @@ def _training_step_new_api_stack(self):
for _ in range(sample_and_train_weight):
# Sample a list of episodes used for learning from the replay buffer.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):

episodes = self.local_replay_buffer.sample(
num_items=self.config.total_train_batch_size,
n_step=self.config.n_step,
# In case an `EpisodeReplayBuffer` is used, we need to provide
# the sequence length.
batch_length_T=self.env_runner.module.is_stateful()
* self.config.model_config.get("max_seq_len", 0),
lookback=int(self.env_runner.module.is_stateful()),
gamma=self.config.gamma,
beta=self.config.replay_buffer_config.get("beta"),
sample_episodes=True,
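The two new `sample()` arguments above only kick in for stateful modules: `batch_length_T` is driven by the model's `max_seq_len`, and `lookback`, presumably, gives each sampled sequence one step of leading context for deriving the first state input. The sketch below just spells out how those expressions resolve; the helper name and the exact semantics of a zero value (plain, non-sequence sampling) are assumptions for illustration, not part of the RLlib API.

def sequence_sampling_args(is_stateful: bool, model_config: dict) -> dict:
    """Hypothetical helper mirroring the argument expressions used above."""
    return {
        # Only stateful (e.g. recurrent) modules need fixed-length sequences;
        # multiplying by the bool zeroes the length out otherwise.
        "batch_length_T": is_stateful * model_config.get("max_seq_len", 0),
        # One step of lookback per sequence, presumably so the learner
        # connector can compute the first state input of each sequence.
        "lookback": int(is_stateful),
    }


print(sequence_sampling_args(True, {"max_seq_len": 20}))
# -> {'batch_length_T': 20, 'lookback': 1}
print(sequence_sampling_args(False, {"max_seq_len": 20}))
# -> {'batch_length_T': 0, 'lookback': 0}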
13 changes: 13 additions & 0 deletions rllib/algorithms/dqn/dqn_learner.py
@@ -1,5 +1,11 @@
from typing import Any, Dict, Optional

from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI
@@ -48,6 +54,13 @@ def build(self) -> None:
)
)

# Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece right
# after the corresponding "add-OBS-..." default piece.
self._learner_connector.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

@override(Learner)
def add_module(
self,
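The `insert_after` call above pins the ordering the comment describes: the NEXT_OBS piece runs directly behind the default OBS piece, presumably so both columns end up in the train batch. Here is a rough, self-contained sketch of that ordering pattern with stand-in classes; it is not the RLlib connector API.

class Pipeline:
    """Toy ordered pipeline with an `insert_after` similar in spirit to the above."""

    def __init__(self, pieces):
        self.pieces = list(pieces)

    def insert_after(self, anchor_cls, piece):
        # Place `piece` directly behind the first piece of type `anchor_cls`.
        idx = next(i for i, p in enumerate(self.pieces) if isinstance(p, anchor_cls))
        self.pieces.insert(idx + 1, piece)


class AddObs: ...
class AddNextObs: ...
class AddStates: ...


pipe = Pipeline([AddObs(), AddStates()])
pipe.insert_after(AddObs, AddNextObs())
print([type(p).__name__ for p in pipe.pieces])
# -> ['AddObs', 'AddNextObs', 'AddStates']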
74 changes: 66 additions & 8 deletions rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py
@@ -1,3 +1,4 @@
import tree
from typing import Dict, Union

from ray.rllib.algorithms.dqn.default_dqn_rl_module import (
@@ -46,13 +47,20 @@ def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorTy
# outputs directly the `argmax` of the logits.
exploit_actions = action_dist.to_deterministic().sample()

output = {Columns.ACTIONS: exploit_actions}
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

# In inference, we only need the exploitation actions.
return {Columns.ACTIONS: exploit_actions}
return output

@override(RLModule)
def _forward_exploration(
self, batch: Dict[str, TensorType], t: int
) -> Dict[str, TensorType]:
# Define the return dictionary.
output = {}

# Q-network forward pass.
qf_outs = self.compute_q_values(batch)

@@ -73,7 +81,13 @@ def _forward_exploration(
B = qf_outs[QF_PREDS].shape[0]
random_actions = torch.squeeze(
torch.multinomial(
(torch.nan_to_num(qf_outs[QF_PREDS], neginf=0.0) != 0.0).float(),
(
torch.nan_to_num(
qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)),
neginf=0.0,
)
!= 0.0
).float(),
num_samples=1,
),
dim=1,
@@ -85,7 +99,14 @@ def _forward_exploration(
exploit_actions,
)

return {Columns.ACTIONS: actions}
# Add the actions to the return dictionary.
output[Columns.ACTIONS] = actions

# If this is a stateful module, add output states.
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

return output

@override(RLModule)
def _forward_train(
@@ -107,11 +128,33 @@ def _forward_train(
[batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0
),
}
# If this is a stateful module, add the input states.
if Columns.STATE_IN in batch:
# Add both the input state for the actual observation and
# the one for the next observation.
batch_base.update(
{
Columns.STATE_IN: tree.map_structure(
lambda t1, t2: torch.cat([t1, t2], dim=0),
batch[Columns.STATE_IN],
batch[Columns.NEXT_STATE_IN],
)
}
)
# Otherwise we can just use the current observations.
else:
batch_base = {Columns.OBS: batch[Columns.OBS]}
# If this is a stateful module, add the input state.
if Columns.STATE_IN in batch:
batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]})

batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]}

# If we have a stateful encoder, add the states for the target forward
# pass.
if Columns.NEXT_STATE_IN in batch:
batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]})

# Q-network forward passes.
qf_outs = self.compute_q_values(batch_base)
if self.uses_double_q:
@@ -135,6 +178,14 @@ def _forward_train(
# Probabilities of the target Q-value distribution of the next state.
output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS]

# Add the states to the output, if the module is stateful.
if Columns.STATE_OUT in qf_outs:
output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]
# For correctness, also add the output states from the target forward pass.
# Note, we do not backpropagate through this state.
if Columns.STATE_OUT in qf_target_next_outs:
output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT]

return output

@override(QNetAPI)
@@ -154,7 +205,7 @@ def compute_advantage_distribution(
# Reshape the action values.
# NOTE: Handcrafted action shape.
logits_per_action_per_atom = torch.reshape(
batch, shape=(-1, self.action_space.n, self.num_atoms)
batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms)
)
# Calculate the probability for each action value atom. Note,
# the sum along action value atoms of a single action value
@@ -216,10 +267,12 @@ def _qf_forward_helper(
# Center the advantage stream distribution.
centered_af_logits = af_dist_output["logits"] - af_dist_output[
"logits"
].mean(dim=1, keepdim=True)
].mean(dim=-1, keepdim=True)
# Calculate the Q-value distribution by adding advantage and
# value stream.
qf_logits = centered_af_logits + vf_outs.unsqueeze(dim=-1)
qf_logits = centered_af_logits + vf_outs.view(
-1, *((1,) * (centered_af_logits.dim() - 1))
)
# Calculate probabilities for the Q-value distribution along
# the support given by the atoms.
qf_probs = nn.functional.softmax(qf_logits, dim=-1)
@@ -236,8 +289,8 @@ def _qf_forward_helper(
# https://discuss.pytorch.org/t/gradient-computation-issue-due-to-
# inplace-operation-unsure-how-to-debug-for-custom-model/170133
# Has to be a mean for each batch element.
af_outs_mean = torch.unsqueeze(
torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(dim=1), dim=1
af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(
dim=-1, keepdim=True
)
qf_outs = qf_outs - af_outs_mean
# Add advantage and value stream. Note, we broadcast here.
@@ -266,4 +319,9 @@ def _qf_forward_helper(
# In this case we have a Q-head of dimension (1, action_space.n).
output[QF_PREDS] = qf_outs

# If we have a stateful encoder, add the output states to the return
# dictionary.
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

return output
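The `dim=-1` / `keepdim=True` changes above let the dueling aggregation work when the batch carries an extra time dimension (e.g. `[B, T, num_actions]` instead of `[B, num_actions]`). The check below is a small, self-contained illustration of that aggregation, Q = V + (A - mean(A)); shapes and values are made up for the example and do not come from the module above.

import torch

B, T, num_actions = 2, 5, 3
vf_outs = torch.randn(B, T)               # value stream: one value per (batch, time)
af_outs = torch.randn(B, T, num_actions)  # advantage stream: one value per action

# Center the advantages over the action axis (last dim), keeping the dim for
# broadcasting against the value stream.
af_centered = af_outs - af_outs.mean(dim=-1, keepdim=True)
q = vf_outs.unsqueeze(-1) + af_centered

assert q.shape == (B, T, num_actions)
# Centering keeps the per-(batch, time) mean of Q equal to V (up to float error).
assert torch.allclose(q.mean(dim=-1), vf_outs, atol=1e-6)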
49 changes: 37 additions & 12 deletions rllib/algorithms/dqn/torch/dqn_torch_learner.py
@@ -44,6 +44,29 @@ def compute_loss_for_module(
fwd_out: Dict[str, TensorType]
) -> TensorType:

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based models (zero-padded `batch`)
# and for PPO's batched value function (and bootstrap value) computations,
# for which we add an (artificial) timestep to each episode to
# simplify the actual computation.
if Columns.LOSS_MASK in batch:
mask = batch[Columns.LOSS_MASK]
num_valid = torch.sum(mask)

def possibly_masked_mean(data_):
return torch.sum(data_[mask]) / num_valid

def possibly_masked_min(data_):
return torch.min(data_[mask])

def possibly_masked_max(data_):
return torch.max(data_[mask])

else:
possibly_masked_mean = torch.mean
possibly_masked_min = torch.min
possibly_masked_max = torch.max

q_curr = fwd_out[QF_PREDS]
q_target_next = fwd_out[QF_TARGET_NEXT_PREDS]

@@ -53,34 +76,36 @@ def compute_loss_for_module(
q_selected = torch.nan_to_num(
torch.gather(
q_curr,
dim=1,
index=batch[Columns.ACTIONS].view(-1, 1).expand(-1, 1).long(),
dim=-1,
index=batch[Columns.ACTIONS]
.view(*batch[Columns.ACTIONS].shape, 1)
.long(),
),
neginf=0.0,
).squeeze()
).squeeze(dim=-1)

# Use double Q learning.
if config.double_q:
# Then we evaluate the target Q-function at the best action (greedy action)
# over the online Q-function.
# Mark the best online Q-value of the next state.
q_next_best_idx = (
torch.argmax(fwd_out[QF_NEXT_PREDS], dim=1).unsqueeze(dim=-1).long()
torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long()
)
# Get the Q-value of the target network at maximum of the online network
# (bootstrap action).
q_next_best = torch.nan_to_num(
torch.gather(q_target_next, dim=1, index=q_next_best_idx),
torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
neginf=0.0,
).squeeze()
else:
# Mark the maximum Q-value(s).
q_next_best_idx = (
torch.argmax(q_target_next, dim=1).unsqueeze(dim=-1).long()
torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long()
)
# Get the maximum Q-value(s).
q_next_best = torch.nan_to_num(
torch.gather(q_target_next, dim=1, index=q_next_best_idx),
torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
neginf=0.0,
).squeeze()

@@ -179,7 +204,7 @@ def compute_loss_for_module(
# Compute the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# Compute the weighted loss (importance sampling weights).
total_loss = torch.mean(
total_loss = possibly_masked_mean(
batch["weights"]
* loss_fn(reduction="none")(q_selected, q_selected_target)
)
@@ -198,10 +223,10 @@ def compute_loss_for_module(
self.metrics.log_dict(
{
QF_LOSS_KEY: total_loss,
QF_MEAN_KEY: torch.mean(q_selected),
QF_MAX_KEY: torch.max(q_selected),
QF_MIN_KEY: torch.min(q_selected),
TD_ERROR_MEAN_KEY: torch.mean(td_error),
QF_MEAN_KEY: possibly_masked_mean(q_selected),
QF_MAX_KEY: possibly_masked_max(q_selected),
QF_MIN_KEY: possibly_masked_min(q_selected),
TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error),
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
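The masking helpers added at the top of `compute_loss_for_module` matter once the batch contains zero-padded sequences: padded timesteps would otherwise pull every mean-based statistic toward zero. A tiny, self-contained illustration with made-up numbers:

import torch

# Two sequences of length 3; trailing steps are zero padding.
td_error = torch.tensor([[1.0, 2.0, 0.0],
                         [3.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, False],
                     [True, False, False]])

plain_mean = torch.mean(td_error)                          # 1.0, padding included
masked_mean = torch.sum(td_error[mask]) / torch.sum(mask)  # 2.0, valid steps only
print(plain_mean.item(), masked_mean.item())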
22 changes: 15 additions & 7 deletions rllib/connectors/common/add_observations_from_episodes_to_batch.py
@@ -136,15 +136,22 @@ def __call__(
# If "obs" already in data, early out.
if Columns.OBS in batch:
return batch

for sa_episode in self.single_agent_episode_iterator(
episodes,
# If Learner connector, get all episodes (for train batch).
# If EnvToModule, get only those ongoing episodes that just had their
# agent step (b/c those are the ones we need to compute actions for next).
agents_that_stepped_only=not self._as_learner_connector,
for i, sa_episode in enumerate(
self.single_agent_episode_iterator(
episodes,
# If Learner connector, get all episodes (for train batch).
# If EnvToModule, get only those ongoing episodes that just had their
# agent step (b/c those are the ones we need to compute actions for
# next).
agents_that_stepped_only=not self._as_learner_connector,
)
):
if self._as_learner_connector:
# TODO (sven): Resolve this hack by adding a new connector piece that
# performs this very task.
if "_" not in sa_episode.id_:
sa_episode.id_ += "_" + str(i)

self.add_n_batch_items(
batch,
Columns.OBS,
Expand All @@ -160,4 +167,5 @@ def __call__(
item_to_add=sa_episode.get_observations(-1),
single_agent_episode=sa_episode,
)

return batch
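The id suffix added in the learner-connector branch above presumably guards against id collisions: with sequence sampling, the replay buffer can hand back several slices cut from the same episode, and those slices would otherwise share an `id_`. A minimal sketch of the effect, using a stand-in class rather than an RLlib `SingleAgentEpisode`:

class FakeEpisodeSlice:
    """Stand-in carrying only the `id_` attribute used by the connector."""

    def __init__(self, id_):
        self.id_ = id_


# Two slices sampled from the same episode plus one from another episode.
sampled = [FakeEpisodeSlice("ep42"), FakeEpisodeSlice("ep42"), FakeEpisodeSlice("ep7")]

for i, slice_ in enumerate(sampled):
    # Same suffixing rule as in the connector above.
    if "_" not in slice_.id_:
        slice_.id_ += "_" + str(i)

print([s.id_ for s in sampled])
# -> ['ep42_0', 'ep42_1', 'ep7_2'], i.e. unique keys per sampled slice.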