diff --git a/rllib/algorithms/dqn/default_dqn_rl_module.py b/rllib/algorithms/dqn/default_dqn_rl_module.py index 5f56aae7104fd..056051f50ca7e 100644 --- a/rllib/algorithms/dqn/default_dqn_rl_module.py +++ b/rllib/algorithms/dqn/default_dqn_rl_module.py @@ -132,6 +132,13 @@ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType {"af": self.af, "vf": self.vf} if self.uses_dueling else self.af, ) + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} + @override(RLModule) def input_specs_train(self) -> SpecType: return [ diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 6a6c47f46038c..0575b07925fed 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -18,12 +18,6 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy -from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( - AddObservationsFromEpisodesToBatch, -) -from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa - AddNextObservationsFromEpisodesToTrainBatch, -) from ray.rllib.core.learner import Learner from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.execution.rollout_ops import ( @@ -544,28 +538,6 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]: "Use `config.framework('torch')` instead." ) - @override(AlgorithmConfig) - def build_learner_connector( - self, - input_observation_space, - input_action_space, - device=None, - ): - pipeline = super().build_learner_connector( - input_observation_space=input_observation_space, - input_action_space=input_action_space, - device=device, - ) - - # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right - # after the corresponding "add-OBS-..." default piece). - pipeline.insert_after( - AddObservationsFromEpisodesToBatch, - AddNextObservationsFromEpisodesToTrainBatch(), - ) - - return pipeline - def calculate_rr_weights(config: AlgorithmConfig) -> List[float]: """Calculate the round robin weights for the rollout and train steps""" @@ -674,9 +646,15 @@ def _training_step_new_api_stack(self): for _ in range(sample_and_train_weight): # Sample a list of episodes used for learning from the replay buffer. with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)): + episodes = self.local_replay_buffer.sample( num_items=self.config.total_train_batch_size, n_step=self.config.n_step, + # In case an `EpisodeReplayBuffer` is used we need to provide + # the sequence length. 
+ batch_length_T=self.env_runner.module.is_stateful() + * self.config.model_config.get("max_seq_len", 0), + lookback=int(self.env_runner.module.is_stateful()), gamma=self.config.gamma, beta=self.config.replay_buffer_config.get("beta"), sample_episodes=True, diff --git a/rllib/algorithms/dqn/dqn_learner.py b/rllib/algorithms/dqn/dqn_learner.py index 4a4c17271dfbc..b55385eaf939d 100644 --- a/rllib/algorithms/dqn/dqn_learner.py +++ b/rllib/algorithms/dqn/dqn_learner.py @@ -1,5 +1,11 @@ from typing import Any, Dict, Optional +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.utils import update_target_network from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI @@ -48,6 +54,13 @@ def build(self) -> None: ) ) + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). + self._learner_connector.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + @override(Learner) def add_module( self, diff --git a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py index a98c44322cfdf..f583c504800c7 100644 --- a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py +++ b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py @@ -1,3 +1,4 @@ +import tree from typing import Dict, Union from ray.rllib.algorithms.dqn.default_dqn_rl_module import ( @@ -46,13 +47,20 @@ def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorTy # outputs directly the `argmax` of the logits. exploit_actions = action_dist.to_deterministic().sample() + output = {Columns.ACTIONS: exploit_actions} + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + # In inference, we only need the exploitation actions. - return {Columns.ACTIONS: exploit_actions} + return output @override(RLModule) def _forward_exploration( self, batch: Dict[str, TensorType], t: int ) -> Dict[str, TensorType]: + # Define the return dictionary. + output = {} + # Q-network forward pass. qf_outs = self.compute_q_values(batch) @@ -73,7 +81,13 @@ def _forward_exploration( B = qf_outs[QF_PREDS].shape[0] random_actions = torch.squeeze( torch.multinomial( - (torch.nan_to_num(qf_outs[QF_PREDS], neginf=0.0) != 0.0).float(), + ( + torch.nan_to_num( + qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)), + neginf=0.0, + ) + != 0.0 + ).float(), num_samples=1, ), dim=1, @@ -85,7 +99,14 @@ def _forward_exploration( exploit_actions, ) - return {Columns.ACTIONS: actions} + # Add the actions to the return dictionary. + output[Columns.ACTIONS] = actions + + # If this is a stateful module, add output states. + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + + return output @override(RLModule) def _forward_train( @@ -107,11 +128,33 @@ def _forward_train( [batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0 ), } + # If this is a stateful module add the input states. + if Columns.STATE_IN in batch: + # Add both, the input state for the actual observation and + # the one for the next observation. 
+ batch_base.update( + { + Columns.STATE_IN: tree.map_structure( + lambda t1, t2: torch.cat([t1, t2], dim=0), + batch[Columns.STATE_IN], + batch[Columns.NEXT_STATE_IN], + ) + } + ) # Otherwise we can just use the current observations. else: batch_base = {Columns.OBS: batch[Columns.OBS]} + # If this is a stateful module add the input state. + if Columns.STATE_IN in batch: + batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]}) + batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]} + # If we have a stateful encoder, add the states for the target forward + # pass. + if Columns.NEXT_STATE_IN in batch: + batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]}) + # Q-network forward passes. qf_outs = self.compute_q_values(batch_base) if self.uses_double_q: @@ -135,6 +178,14 @@ def _forward_train( # Probabilities of the target Q-value distribution of the next state. output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS] + # Add the states to the output, if the module is stateful. + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + # For correctness, also add the output states from the target forward pass. + # Note, we do not backpropagate through this state. + if Columns.STATE_OUT in qf_target_next_outs: + output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT] + return output @override(QNetAPI) @@ -154,7 +205,7 @@ def compute_advantage_distribution( # Reshape the action values. # NOTE: Handcrafted action shape. logits_per_action_per_atom = torch.reshape( - batch, shape=(-1, self.action_space.n, self.num_atoms) + batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms) ) # Calculate the probability for each action value atom. Note, # the sum along action value atoms of a single action value @@ -216,10 +267,12 @@ def _qf_forward_helper( # Center the advantage stream distribution. centered_af_logits = af_dist_output["logits"] - af_dist_output[ "logits" - ].mean(dim=1, keepdim=True) + ].mean(dim=-1, keepdim=True) # Calculate the Q-value distribution by adding advantage and # value stream. - qf_logits = centered_af_logits + vf_outs.unsqueeze(dim=-1) + qf_logits = centered_af_logits + vf_outs.view( + -1, *((1,) * (centered_af_logits.dim() - 1)) + ) # Calculate probabilites for the Q-value distribution along # the support given by the atoms. qf_probs = nn.functional.softmax(qf_logits, dim=-1) @@ -236,8 +289,8 @@ def _qf_forward_helper( # https://discuss.pytorch.org/t/gradient-computation-issue-due-to- # inplace-operation-unsure-how-to-debug-for-custom-model/170133 # Has to be a mean for each batch element. - af_outs_mean = torch.unsqueeze( - torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(dim=1), dim=1 + af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean( + dim=-1, keepdim=True ) qf_outs = qf_outs - af_outs_mean # Add advantage and value stream. Note, we broadcast here. @@ -266,4 +319,9 @@ def _qf_forward_helper( # In this case we have a Q-head of dimension (1, action_space.n). output[QF_PREDS] = qf_outs + # If we have a stateful encoder add the output states to the return + # dictionary. 
+        if Columns.STATE_OUT in encoder_outs:
+            output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]
+
         return output
diff --git a/rllib/algorithms/dqn/torch/dqn_torch_learner.py b/rllib/algorithms/dqn/torch/dqn_torch_learner.py
index 0bc1db4a6fc82..c66b5b3a2a4af 100644
--- a/rllib/algorithms/dqn/torch/dqn_torch_learner.py
+++ b/rllib/algorithms/dqn/torch/dqn_torch_learner.py
@@ -44,6 +44,29 @@ def compute_loss_for_module(
         fwd_out: Dict[str, TensorType]
     ) -> TensorType:
 
+        # Possibly apply masking to some sub loss terms and to the total loss
+        # term at the end. Masking is needed for RNN-based modules, whose train
+        # batches are zero-padded (right-padded) to `max_seq_len`. Without the
+        # mask, the padded (invalid) timesteps would dilute both the loss and
+        # the logged Q-value statistics below.
+        if Columns.LOSS_MASK in batch:
+            mask = batch[Columns.LOSS_MASK]
+            num_valid = torch.sum(mask)
+
+            def possibly_masked_mean(data_):
+                return torch.sum(data_[mask]) / num_valid
+
+            def possibly_masked_min(data_):
+                return torch.min(data_[mask])
+
+            def possibly_masked_max(data_):
+                return torch.max(data_[mask])
+
+        else:
+            possibly_masked_mean = torch.mean
+            possibly_masked_min = torch.min
+            possibly_masked_max = torch.max
+
         q_curr = fwd_out[QF_PREDS]
         q_target_next = fwd_out[QF_TARGET_NEXT_PREDS]
 
@@ -53,11 +76,13 @@ def compute_loss_for_module(
         q_selected = torch.nan_to_num(
             torch.gather(
                 q_curr,
-                dim=1,
-                index=batch[Columns.ACTIONS].view(-1, 1).expand(-1, 1).long(),
+                dim=-1,
+                index=batch[Columns.ACTIONS]
+                .view(*batch[Columns.ACTIONS].shape, 1)
+                .long(),
             ),
             neginf=0.0,
-        ).squeeze()
+        ).squeeze(dim=-1)
 
         # Use double Q learning.
         if config.double_q:
@@ -65,22 +90,22 @@ def compute_loss_for_module(
             # over the online Q-function.
             # Mark the best online Q-value of the next state.
             q_next_best_idx = (
-                torch.argmax(fwd_out[QF_NEXT_PREDS], dim=1).unsqueeze(dim=-1).long()
+                torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long()
             )
             # Get the Q-value of the target network at maximum of the online network
             # (bootstrap action).
             q_next_best = torch.nan_to_num(
-                torch.gather(q_target_next, dim=1, index=q_next_best_idx),
+                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                 neginf=0.0,
             ).squeeze()
         else:
             # Mark the maximum Q-value(s).
             q_next_best_idx = (
-                torch.argmax(q_target_next, dim=1).unsqueeze(dim=-1).long()
+                torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long()
             )
             # Get the maximum Q-value(s).
             q_next_best = torch.nan_to_num(
-                torch.gather(q_target_next, dim=1, index=q_next_best_idx),
+                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                 neginf=0.0,
             ).squeeze()
 
@@ -179,7 +204,7 @@ def compute_loss_for_module(
         # Compute the TD error.
         td_error = torch.abs(q_selected - q_selected_target)
         # Compute the weighted loss (importance sampling weights).
-        total_loss = torch.mean(
+        total_loss = possibly_masked_mean(
             batch["weights"]
             * loss_fn(reduction="none")(q_selected, q_selected_target)
         )
@@ -198,10 +223,10 @@ def compute_loss_for_module(
         self.metrics.log_dict(
             {
                 QF_LOSS_KEY: total_loss,
-                QF_MEAN_KEY: torch.mean(q_selected),
-                QF_MAX_KEY: torch.max(q_selected),
-                QF_MIN_KEY: torch.min(q_selected),
-                TD_ERROR_MEAN_KEY: torch.mean(td_error),
+                QF_MEAN_KEY: possibly_masked_mean(q_selected),
+                QF_MAX_KEY: possibly_masked_max(q_selected),
+                QF_MIN_KEY: possibly_masked_min(q_selected),
+                TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error),
             },
             key=module_id,
             window=1,  # <- single items (should not be mean/ema-reduced over time).
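# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the masked
# reductions added to `compute_loss_for_module` above only change behavior
# when the train batch contains zero-padded RNN sequences. The toy tensors
# below are assumptions for illustration only and are not taken from the PR.
import torch

# Two sequences padded to max_seq_len=4; the second one has only 2 valid steps.
td_error = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0])
loss_mask = torch.tensor([True, True, True, True, True, True, False, False])

num_valid = torch.sum(loss_mask)
masked_mean = torch.sum(td_error[loss_mask]) / num_valid  # 3.5 (padding ignored)
plain_mean = torch.mean(td_error)  # 2.625 (diluted by the two padded steps)
print(masked_mean.item(), plain_mean.item())
# ---------------------------------------------------------------------------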
diff --git a/rllib/connectors/common/add_observations_from_episodes_to_batch.py b/rllib/connectors/common/add_observations_from_episodes_to_batch.py index 2bc02e5068881..5e007d2e6734c 100644 --- a/rllib/connectors/common/add_observations_from_episodes_to_batch.py +++ b/rllib/connectors/common/add_observations_from_episodes_to_batch.py @@ -136,15 +136,22 @@ def __call__( # If "obs" already in data, early out. if Columns.OBS in batch: return batch - - for sa_episode in self.single_agent_episode_iterator( - episodes, - # If Learner connector, get all episodes (for train batch). - # If EnvToModule, get only those ongoing episodes that just had their - # agent step (b/c those are the ones we need to compute actions for next). - agents_that_stepped_only=not self._as_learner_connector, + for i, sa_episode in enumerate( + self.single_agent_episode_iterator( + episodes, + # If Learner connector, get all episodes (for train batch). + # If EnvToModule, get only those ongoing episodes that just had their + # agent step (b/c those are the ones we need to compute actions for + # next). + agents_that_stepped_only=not self._as_learner_connector, + ) ): if self._as_learner_connector: + # TODO (sven): Resolve this hack by adding a new connector piece that + # performs this very task. + if "_" not in sa_episode.id_: + sa_episode.id_ += "_" + str(i) + self.add_n_batch_items( batch, Columns.OBS, @@ -160,4 +167,5 @@ def __call__( item_to_add=sa_episode.get_observations(-1), single_agent_episode=sa_episode, ) + return batch diff --git a/rllib/connectors/common/add_states_from_episodes_to_batch.py b/rllib/connectors/common/add_states_from_episodes_to_batch.py index 3a5cf0a12c5ec..058b06395ef68 100644 --- a/rllib/connectors/common/add_states_from_episodes_to_batch.py +++ b/rllib/connectors/common/add_states_from_episodes_to_batch.py @@ -305,6 +305,7 @@ def __call__( # state. convert_to_numpy(sa_module.get_initial_state()) if sa_episode.t_started == 0 + or (Columns.STATE_OUT not in sa_episode.extra_model_outputs) # Episode starts somewhere in the middle (is a cut # continuation chunk) -> Use previous chunk's last # STATE_OUT as initial state. @@ -314,8 +315,24 @@ def __call__( neg_index_as_lookback=True, ) ) - # state_outs.shape=(T,[state-dim]) T=episode len - state_outs = sa_episode.get_extra_model_outputs(key=Columns.STATE_OUT) + # If we have `"state_out"`s (e.g. from rollouts) use them for the + # `"state_in"`s. + if Columns.STATE_OUT in sa_episode.extra_model_outputs: + # state_outs.shape=(T,[state-dim]) T=episode len + state_outs = sa_episode.get_extra_model_outputs( + key=Columns.STATE_OUT + ) + # Otherwise, we have no `"state_out"` (e.g. because we are sampling + # from offline data and the expert policy was not stateful). + else: + # Then simply use the `look_back_state`, i.e. in this case the + # initial state as `"state_in` in training. 
+ state_outs = tree.map_structure( + lambda a: np.repeat( + a[np.newaxis, ...], len(sa_episode), axis=0 + ), + look_back_state, + ) self.add_n_batch_items( batch=batch, column=Columns.STATE_IN, @@ -334,6 +351,17 @@ def __call__( num_items=int(math.ceil(len(sa_episode) / max_seq_len)), single_agent_episode=sa_episode, ) + if Columns.NEXT_OBS in batch: + self.add_n_batch_items( + batch=batch, + column=Columns.NEXT_STATE_IN, + items_to_add=tree.map_structure( + lambda i, m=max_seq_len: i[::m], + state_outs, + ), + num_items=int(math.ceil(len(sa_episode) / max_seq_len)), + single_agent_episode=sa_episode, + ) # Also, create the loss mask (b/c of our now possibly zero-padded data) # as well as the seq_lens array and add these to `data` as well. @@ -365,13 +393,17 @@ def __call__( if not sa_module.is_stateful(): continue - # Episode just started -> Get initial state from our RLModule. - if sa_episode.t_started == 0 and len(sa_episode) == 0: + # Episode just started or has no `"state_out"` (e.g. in offline + # sampling) -> Get initial state from our RLModule. + if (sa_episode.t_started == 0 and len(sa_episode) == 0) or ( + Columns.STATE_OUT not in sa_episode.extra_model_outputs + ): state = sa_module.get_initial_state() # Episode is already ongoing -> Use most recent STATE_OUT. else: state = sa_episode.get_extra_model_outputs( - key=Columns.STATE_OUT, indices=-1 + key=Columns.STATE_OUT, + indices=-1, ) self.add_batch_item( batch, diff --git a/rllib/core/columns.py b/rllib/core/columns.py index b7ce7d2a67015..98cb8646913e2 100644 --- a/rllib/core/columns.py +++ b/rllib/core/columns.py @@ -43,7 +43,9 @@ class Columns: # Common extra RLModule output keys. STATE_IN = "state_in" + NEXT_STATE_IN = "next_state_in" STATE_OUT = "state_out" + NEXT_STATE_OUT = "next_state_out" EMBEDDINGS = "embeddings" ACTION_DIST_INPUTS = "action_dist_inputs" ACTION_PROB = "action_prob" diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index c33b12bee392f..3235f7ae6cfb7 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -176,8 +176,13 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] episodes = self._validate_episodes(episodes) # Add the episodes to the buffer. self.episode_buffer.add(episodes) + # TODO (simon): Refactor into a single code block for both cases. episodes = self.episode_buffer.sample( num_items=self.config.train_batch_size_per_learner, + batch_length_T=self.config.model_config.get("max_seq_len", 0) + if self._module.is_stateful() + else None, + n_step=self.config.get("n_step", 1) or 1, # TODO (simon): This can be removed as soon as DreamerV3 has been # cleaned up, i.e. can use episode samples for training. sample_episodes=True, @@ -199,6 +204,10 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] # Sample steps from the buffer. episodes = self.episode_buffer.sample( num_items=self.config.train_batch_size_per_learner, + batch_length_T=self.config.model_config.get("max_seq_len", 0) + if self._module.is_stateful() + else None, + n_step=self.config.get("n_step", 1) or 1, # TODO (simon): This can be removed as soon as DreamerV3 has been # cleaned up, i.e. can use episode samples for training. 
sample_episodes=True, diff --git a/rllib/tuned_examples/dqn/stateless_cartpole_dqn.py b/rllib/tuned_examples/dqn/stateless_cartpole_dqn.py new file mode 100644 index 0000000000000..e8610dea67a19 --- /dev/null +++ b/rllib/tuned_examples/dqn/stateless_cartpole_dqn.py @@ -0,0 +1,57 @@ +from ray.rllib.algorithms.dqn import DQNConfig +from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole +from ray.rllib.utils.test_utils import add_rllib_example_script_args + +parser = add_rllib_example_script_args( + default_timesteps=2000000, + default_reward=350.0, +) +parser.set_defaults( + enable_new_api_stack=True, + num_env_runners=3, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below. +args = parser.parse_args() + +config = ( + DQNConfig() + .environment(StatelessCartPole) + .env_runners( + env_to_module_connector=lambda env: MeanStdFilter(), + ) + .training( + lr=0.0003, + train_batch_size_per_learner=32, + replay_buffer_config={ + "type": "EpisodeReplayBuffer", + "capacity": 50000, + }, + n_step=1, + double_q=True, + dueling=True, + num_atoms=1, + epsilon=[(0, 1.0), (10000, 0.02)], + ) + .rl_module( + # Settings identical to old stack. + model_config=DefaultModelConfig( + fcnet_hiddens=[256], + fcnet_activation="tanh", + fcnet_bias_initializer="zeros_", + head_fcnet_bias_initializer="zeros_", + head_fcnet_hiddens=[256], + head_fcnet_activation="tanh", + lstm_kernel_initializer="xavier_uniform_", + use_lstm=True, + max_seq_len=20, + ), + ) +) + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args) diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py index 4a3309b301d9a..c02c066ccbe59 100644 --- a/rllib/utils/replay_buffers/episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -6,6 +6,7 @@ import scipy from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer from ray.rllib.utils.annotations import override from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface from ray.rllib.utils.typing import SampleBatchType @@ -219,6 +220,8 @@ def sample( include_extra_model_outputs: bool = False, sample_episodes: Optional[bool] = False, finalize: bool = False, + # TODO (simon): Check, if we need here 1 as default. + lookback: Optional[int] = 0, **kwargs, ) -> Union[SampleBatchType, SingleAgentEpisode]: """Samples from a buffer in a randomized way. @@ -264,6 +267,7 @@ def sample( the extra model outputs at the `"obs"` in the batch is included (the timestep at which the action is computed). finalize: If episodes should be finalized. + lookback: A desired lookback. Any non-negative integer is valid. 
Returns: Either a batch with transitions in each row or (if `return_episodes=True`) @@ -282,6 +286,7 @@ def sample( include_infos=include_infos, include_extra_model_outputs=include_extra_model_outputs, finalize=finalize, + lookback=lookback, ) else: return self._sample_batch( @@ -428,6 +433,7 @@ def _sample_episodes( include_infos: bool = False, include_extra_model_outputs: bool = False, finalize: bool = False, + lookback: Optional[int] = 1, **kwargs, ) -> List[SingleAgentEpisode]: """Samples episodes from a buffer in a randomized way. @@ -448,8 +454,8 @@ def _sample_episodes( buffer. batch_size_B: The number of rows (transitions) to return in the batch - batch_length_T: THe sequence length to sample. At this point in time - only sequences of length 1 are possible. + batch_length_T: The sequence length to sample. Can be either `None` + (the default) or any positive integer. n_step: The n-step to apply. For the default the batch contains in `"new_obs"` the observation and in `"obs"` the observation `n` time steps before. The reward will be the sum of rewards @@ -473,6 +479,7 @@ def _sample_episodes( the extra model outputs at the `"obs"` in the batch is included (the timestep at which the action is computed). finalize: If episodes should be finalized. + lookback: A desired lookback. Any non-negative integer is valid. Returns: A list of 1-step long episodes containing all basic episode data and if @@ -487,12 +494,22 @@ def _sample_episodes( # Use our default values if no sizes/lengths provided. batch_size_B = batch_size_B or self.batch_size_B - # TODO (simon): Implement trajectory sampling for RNNs. - batch_length_T = batch_length_T or self.batch_length_T - # Sample the n-step if necessary. - actual_n_step = n_step or 1 - random_n_step = isinstance(n_step, tuple) + assert n_step is not None, ( + "When sampling episodes, `n_step` must be " + "provided, but `n_step` is `None`." + ) + # If no sequence should be sampled, we sample n-steps. + if not batch_length_T: + # Sample the `n_step`` itself, if necessary. + actual_n_step = n_step + random_n_step = isinstance(n_step, tuple) + # Otherwise we use an n-step of 1. + else: + assert ( + not isinstance(n_step, tuple) and n_step == 1 + ), "When sampling sequences n-step must be 1." + actual_n_step = n_step # Keep track of the indices that were sampled last for updating the # weights later (see `ray.rllib.utils.replay_buffer.utils. @@ -515,81 +532,96 @@ def _sample_episodes( episode = self.episodes[episode_idx] # If we use random n-step sampling, draw the n-step for this item. - if random_n_step: + if not batch_length_T and random_n_step: actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) # Skip, if we are too far to the end and `episode_ts` + n_step would go # beyond the episode's end. - if episode_ts + actual_n_step > len(episode): - continue - - # Note, this will be the reward after executing action - # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the discounted - # sum of all discounted rewards that were collected over the last n steps. - raw_rewards = episode.get_rewards( - slice(episode_ts, episode_ts + actual_n_step) - ) - rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[ - -1 - ] - - # Generate the episode to be returned. - sampled_episode = SingleAgentEpisode( - # Ensure that each episode contains a tuple of the form: - # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) - # Two observations (t and t+n). 
- observations=[ - episode.get_observations(episode_ts), - episode.get_observations(episode_ts + actual_n_step), - ], - observation_space=episode.observation_space, - infos=( - [ - episode.get_infos(episode_ts), - episode.get_infos(episode_ts + actual_n_step), - ] - if include_infos - else None - ), - actions=[episode.get_actions(episode_ts)], - action_space=episode.action_space, - rewards=[rewards], - # If the sampled time step is the episode's last time step check, if - # the episode is terminated or truncated. - terminated=( - False - if episode_ts + actual_n_step < len(episode) - else episode.is_terminated - ), - truncated=( - False - if episode_ts + actual_n_step < len(episode) - else episode.is_truncated - ), - extra_model_outputs={ - # TODO (simon): Check, if we have to correct here for sequences - # later. - "n_step": [actual_n_step], - **( - { - k: [episode.get_extra_model_outputs(k, episode_ts)] - for k in episode.extra_model_outputs.keys() - } - if include_extra_model_outputs - else {} + if episode_ts + (batch_length_T or 0) + (actual_n_step - 1) > len(episode): + actual_length = len(episode) + else: + actual_length = episode_ts + (batch_length_T or 0) + (actual_n_step - 1) + + # If no sequence should be sampled, we sample here the n-step. + if not batch_length_T: + sampled_episode = episode.slice( + slice( + episode_ts, + episode_ts + actual_n_step, + ) + ) + # Note, this will be the reward after executing action + # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the discounted + # sum of all discounted rewards that were collected over the last n + # steps. + raw_rewards = sampled_episode.get_rewards() + + rewards = scipy.signal.lfilter( + [1], [1, -gamma], raw_rewards[::-1], axis=0 + )[-1] + + sampled_episode = SingleAgentEpisode( + id_=sampled_episode.id_, + agent_id=sampled_episode.agent_id, + module_id=sampled_episode.module_id, + observation_space=sampled_episode.observation_space, + action_space=sampled_episode.action_space, + observations=[ + sampled_episode.get_observations(0), + sampled_episode.get_observations(-1), + ], + actions=[sampled_episode.get_actions(0)], + rewards=[rewards], + infos=[ + sampled_episode.get_infos(0), + sampled_episode.get_infos(-1), + ], + terminated=sampled_episode.is_terminated, + truncated=sampled_episode.is_truncated, + extra_model_outputs={ + **( + { + k: [episode.get_extra_model_outputs(k, 0)] + for k in episode.extra_model_outputs.keys() + } + if include_extra_model_outputs + else {} + ), + }, + t_started=episode_ts, + len_lookback_buffer=0, + ) + # Otherwise we simply slice the episode. + else: + sampled_episode = episode.slice( + slice( + episode_ts, + actual_length, ), - }, - # TODO (sven): Support lookback buffers. - len_lookback_buffer=0, - t_started=episode_ts, + len_lookback_buffer=lookback, + ) + + # Remove reference to sampled episode. + del episode + + # Add the actually chosen n-step in this episode. + sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer( + np.full((len(sampled_episode) + lookback,), actual_n_step), + lookback=lookback, ) - if finalize: - sampled_episode.finalize() + # Some loss functions need `weights` - which are only relevant when + # prioritizing. + sampled_episode.extra_model_outputs["weights"] = InfiniteLookbackBuffer( + np.ones((len(sampled_episode) + lookback,)), lookback=lookback + ) + + # Append the sampled episode. sampled_episodes.append(sampled_episode) # Increment counter. - B += 1 + B += (actual_length - episode_ts + 1) or 1 + # Update the metric. 
         self.sampled_timesteps += batch_size_B
 
         return sampled_episodes
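# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): how the new
# `batch_length_T` / `lookback` arguments passed to `EpisodeReplayBuffer.sample()`
# in `_training_step_new_api_stack` behave for stateful vs. stateless RLModules.
# The literal values below are assumptions for illustration; the real call reads
# them from the algorithm config and the EnvRunner's RLModule.
is_stateful = True  # e.g. self.env_runner.module.is_stateful()
max_seq_len = 20  # e.g. self.config.model_config.get("max_seq_len", 0)

# Stateful module -> sample whole sequences of length `max_seq_len` plus a
# lookback of 1, so the previous timestep's STATE_OUT is available as the
# STATE_IN of each sampled chunk.
# Stateless module -> `batch_length_T == 0` (falsy), i.e. the buffer falls back
# to plain n-step transition sampling and no lookback is requested.
batch_length_T = is_stateful * max_seq_len  # 20 if stateful, else 0
lookback = int(is_stateful)  # 1 if stateful, else 0
print(batch_length_T, lookback)
# ---------------------------------------------------------------------------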