[RLlib; Offline RL] Validate episodes before adding them to the buffer. (ray-project#48083)

Signed-off-by: JP-sDEV <jon.pablo80@gmail.com>
simonsays1980 authored and JP-sDEV committed Nov 14, 2024
1 parent 702e2eb commit 0e04eb1
Showing 6 changed files with 217 additions and 59 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1698,7 +1698,7 @@ py_test(
py_test(
name = "test_offline_prelearner",
tags = ["team:rllib", "offline"],
size = "small",
size = "medium",
srcs = ["offline/tests/test_offline_prelearner.py"],
# Include the offline data files.
data = [
35 changes: 35 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -449,6 +449,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_schema = {}
self.input_read_episodes = False
self.input_read_sample_batches = False
self.input_read_batch_size = None
self.input_filesystem = None
self.input_filesystem_kwargs = {}
self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
@@ -2556,6 +2557,7 @@ def offline_data(
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_read_episodes: Optional[bool] = NotProvided,
input_read_sample_batches: Optional[bool] = NotProvided,
input_read_batch_size: Optional[int] = NotProvided,
input_filesystem: Optional[str] = NotProvided,
input_filesystem_kwargs: Optional[Dict] = NotProvided,
input_compress_columns: Optional[List[str]] = NotProvided,
@@ -2638,6 +2640,15 @@ def offline_data(
RLlib's `EpisodeType` (i.e. `SingleAgentEpisode` or
`MultiAgentEpisode`). The default is False. `input_read_episodes`
and `input_read_sample_batches` cannot be True at the same time.
input_read_batch_size: Batch size to pull from the dataset. This could
differ from the `train_batch_size_per_learner` if a dataset holds
`EpisodeType` (i.e. `SingleAgentEpisode` or `MultiAgentEpisode`),
`BatchType` (i.e. `SampleBatch` or `MultiAgentBatch`), or any other
data type that contains multiple timesteps in a single row of the
dataset. In such cases, pulling `train_batch_size_per_learner` rows
could yield a multiple of `train_batch_size_per_learner` timesteps
from the offline dataset. The default is `None`, in which case
`train_batch_size_per_learner` rows are pulled.
input_filesystem: A cloud filesystem to handle access to cloud storage when
reading experiences. Should be either "gcs" for Google Cloud Storage,
"s3" for AWS S3 buckets, or "abs" for Azure Blob Storage.
@@ -2771,6 +2782,8 @@ def offline_data(
self.input_read_episodes = input_read_episodes
if input_read_sample_batches is not NotProvided:
self.input_read_sample_batches = input_read_sample_batches
if input_read_batch_size is not NotProvided:
self.input_read_batch_size = input_read_batch_size
if input_filesystem is not NotProvided:
self.input_filesystem = input_filesystem
if input_filesystem_kwargs is not NotProvided:
@@ -4662,6 +4675,28 @@ def _validate_offline_settings(self):
"`Single-/MultiAgentEpisode`s."
)

if self.input_read_batch_size and not (
self.input_read_episodes or self.input_read_sample_batches
):
raise ValueError(
"Setting `input_read_batch_size` is only allowed in case of a "
"dataset that holds either `EpisodeType` or `BatchType` data (i.e. "
"rows that contains multiple timesteps), but neither "
"`input_read_episodes` nor `input_read_sample_batches` is set to "
"`True`."
)

if (
self.output
and self.output_write_episodes
and self.batch_mode != "complete_episodes"
):
raise ValueError(
"When recording episodes only complete episodes should be "
"recorded (i.e. `batch_mode=='complete_episodes'`). Otherwise "
"recorded episodes cannot be read in for training."
)

@staticmethod
def _serialize_dict(config):
# Serialize classes to classpaths:
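For context, here is a minimal usage sketch (not part of the diff) of the new `input_read_batch_size` option together with the checks added in `_validate_offline_settings`. The algorithm, environment, input path, and batch sizes below are illustrative assumptions, not values from this commit.

from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment("CartPole-v1")
    .offline_data(
        # Hypothetical local path to recorded episode data.
        input_="local:///tmp/cartpole-episodes",
        # Each row holds a whole `SingleAgentEpisode` ...
        input_read_episodes=True,
        # ... so pull fewer rows per read than the per-learner batch size.
        input_read_batch_size=256,
    )
    .training(train_batch_size_per_learner=1024)
)

# Misconfigurations now fail fast in `_validate_offline_settings`, e.g.
# setting `input_read_batch_size` without `input_read_episodes` or
# `input_read_sample_batches`, or recording episodes while
# `batch_mode != "complete_episodes"`, both raise a `ValueError`.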
5 changes: 4 additions & 1 deletion rllib/offline/offline_data.py
@@ -36,6 +36,9 @@ def __init__(self, config: AlgorithmConfig):
self.data_read_method_kwargs = (
self.default_read_method_kwargs | self.config.input_read_method_kwargs
)
# In case `EpisodeType` or `BatchType` batches are read, the read batch
# size could differ from the final `train_batch_size_per_learner`.
self.data_read_batch_size = self.config.input_read_batch_size

# If data should be materialized.
self.materialize_data = config.materialize_data
@@ -153,7 +156,7 @@ def sample(
self.data = self.data.map_batches(
self.prelearner_class,
fn_constructor_kwargs=fn_constructor_kwargs,
batch_size=num_samples,
batch_size=self.data_read_batch_size or num_samples,
**self.map_batches_kwargs,
)
# Set the flag to `True`.
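As a rough illustration of the change above (not part of the diff), the configured read batch size simply takes precedence over the requested sample count when it is set; the helper name and numbers below are made up for the example.

def resolve_read_batch_size(input_read_batch_size, num_samples):
    # `input_read_batch_size` wins when set (episode/batch rows hold many
    # timesteps per row); otherwise fall back to the requested sample count.
    return input_read_batch_size or num_samples

# Column data: one row is roughly one timestep, so read `num_samples` rows.
assert resolve_read_batch_size(None, 1024) == 1024
# Episode rows: read fewer rows; the prelearner's buffer yields the timesteps.
assert resolve_read_batch_size(256, 1024) == 256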
69 changes: 65 additions & 4 deletions rllib/offline/offline_prelearner.py
@@ -2,7 +2,7 @@
import logging
import numpy as np
import random
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING

import ray
from ray.actor import ActorHandle
@@ -17,6 +17,7 @@
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.spaces.space_utils import from_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType, ModuleID

@@ -165,7 +166,16 @@ def __init__(

@OverrideToImplementCustomLogic
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
"""Prepares plain data batches for training with `Learner`s.
Args:
batch: A dictionary of numpy arrays containing either column data
with `self.config.input_read_schema`, `EpisodeType` data, or
`BatchType` data.
Returns:
A `MultiAgentBatch` that can be passed to `Learner.update` methods.
"""
# If we directly read in episodes we just convert to list.
if self.input_read_episodes:
# Import `msgpack` for decoding.
@@ -179,6 +189,9 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
)
for state in batch["item"]
]
# Ensure that all episodes are done and no duplicates are in the batch.
episodes = self._validate_episodes(episodes)
# Add the episodes to the buffer.
self.episode_buffer.add(episodes)
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
@@ -196,7 +209,11 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
schema=SCHEMA | self.config.input_read_schema,
input_compress_columns=self.config.input_compress_columns,
)["episodes"]
# Ensure that all episodes are done and no duplicates are in the batch.
episodes = self._validate_episodes(episodes)
# Add the episodes to the buffer.
self.episode_buffer.add(episodes)
# Sample steps from the buffer.
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
# TODO (simon): This can be removed as soon as DreamerV3 has been
@@ -274,7 +291,8 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
return {"batch": [batch]}

@property
def default_prelearner_buffer_class(self):
def default_prelearner_buffer_class(self) -> ReplayBuffer:
"""Sets the default replay buffer."""
from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
EpisodeReplayBuffer,
)
@@ -283,13 +301,56 @@ def default_prelearner_buffer_kwargs(self):
return EpisodeReplayBuffer

@property
def default_prelearner_buffer_kwargs(self):
def default_prelearner_buffer_kwargs(self) -> Dict[str, Any]:
"""Sets the default arguments for the replay buffer.
Note, the `capacity` might vary with the size of the episodes or
sample batches in the offline dataset.
"""
return {
"capacity": self.config.train_batch_size_per_learner * 10,
"batch_size_B": self.config.train_batch_size_per_learner,
}

def _should_module_be_updated(self, module_id, multi_agent_batch=None):
def _validate_episodes(
self, episodes: List[SingleAgentEpisode]
) -> Set[SingleAgentEpisode]:
"""Validate episodes sampled from the dataset.
Note, our episode buffers cannot handle either duplicates nor
non-ordered fragmentations, i.e. fragments from episodes that do
not arrive in timestep order.
Args:
episodes: A list of `SingleAgentEpisode` instances sampled
from a dataset.
Returns:
A set of `SingleAgentEpisode` instances.
Raises:
ValueError: If not all episodes are `done`.
"""
# Ensure that episodes are all done.
if not all(eps.is_done for eps in episodes):
raise ValueError(
"When sampling from episodes (`input_read_episodes=True`) all "
"recorded episodes must be done (i.e. either `terminated=True`) "
"or `truncated=True`)."
)
# Ensure that episodes do not contain duplicates. Note, this can happen
# if the dataset is small and a pulled batch contains the same episode
# more than once.
unique_episode_ids = set()
episodes = {
eps
for eps in episodes
if eps.id_ not in unique_episode_ids
and not unique_episode_ids.add(eps.id_)
and eps.id_ not in self.episode_buffer.episode_id_to_index.keys()
}
return episodes

def _should_module_be_updated(self, module_id, multi_agent_batch=None) -> bool:
"""Checks which modules in a MultiRLModule should be updated."""
if not self._policies_to_train:
# In case of no update information, the module is updated.
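To make the new `_validate_episodes` logic easier to follow, here is a simplified standalone sketch (not part of the diff) that unrolls the set comprehension into an explicit loop. `FakeEpisode` is a stand-in for `SingleAgentEpisode`, and the episode IDs are illustrative.

from dataclasses import dataclass
from typing import List, Set

@dataclass(frozen=True)
class FakeEpisode:
    # Stand-in for `SingleAgentEpisode` with only the fields used here.
    id_: str
    is_done: bool

def validate_episodes(
    episodes: List[FakeEpisode], buffered_ids: Set[str]
) -> Set[FakeEpisode]:
    # Reject pulls that contain unfinished episode fragments.
    if not all(eps.is_done for eps in episodes):
        raise ValueError("All recorded episodes must be terminated or truncated.")
    seen: Set[str] = set()
    unique: Set[FakeEpisode] = set()
    for eps in episodes:
        # Skip duplicates within this pull and episodes already in the buffer.
        if eps.id_ in seen or eps.id_ in buffered_ids:
            continue
        seen.add(eps.id_)
        unique.add(eps)
    return unique

episodes = [FakeEpisode("a", True), FakeEpisode("a", True), FakeEpisode("b", True)]
assert {eps.id_ for eps in validate_episodes(episodes, buffered_ids={"b"})} == {"a"}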
(Diffs for the remaining changed files did not load and are not shown.)
