[RLlib] Make Episode.to_numpy optional at end of `EnvRunner.sample()` (renamed from `finalize()`). (#49800)
sven1977 authored Jan 14, 2025
1 parent 8607755 commit 74bc097
Showing 37 changed files with 483 additions and 278 deletions.
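
A minimal sketch of the renamed API (assuming the new-API-stack `SingleAgentEpisode`; the episode contents below are made up for illustration):

```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Build a tiny episode; observations/actions here are placeholders.
episode = SingleAgentEpisode()
episode.add_env_reset(observation=0)
episode.add_env_step(observation=1, action=0, reward=1.0)

# Old API (pre-#49800): episode.finalize() / episode.is_finalized
# New API:
assert episode.is_numpy is False
episode.to_numpy()  # converts the internal python lists to numpy arrays
assert episode.is_numpy is True
```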
14 changes: 7 additions & 7 deletions doc/source/rllib/doc_code/sa_episode.py
@@ -73,15 +73,15 @@

# __rllib-sa-episode-03-begin__

# Episodes start in the non-finalized state (in which data is stored
# under the hood in python lists).
assert episode.is_finalized is False
# Episodes start in the non-numpy'ized state (in which data is stored
# under the hood in lists).
assert episode.is_numpy is False

# Call `finalize()` to convert all stored data from lists of individual (possibly
# Call `to_numpy()` to convert all stored data from lists of individual (possibly
# complex) items to numpy arrays. Note that RLlib normally performs this method call,
# so users don't need to call `finalize()` themselves.
episode.finalize()
assert episode.is_finalized is True
# so users don't need to call `to_numpy()` themselves.
episode.to_numpy()
assert episode.is_numpy is True

# __rllib-sa-episode-03-end__

4 changes: 2 additions & 2 deletions doc/source/rllib/package_ref/env/multi_agent_episode.rst
@@ -33,7 +33,7 @@ Getting basic information
~MultiAgentEpisode.get_return
~MultiAgentEpisode.get_duration_s
~MultiAgentEpisode.is_done
~MultiAgentEpisode.is_finalized
~MultiAgentEpisode.is_numpy
~MultiAgentEpisode.env_steps
~MultiAgentEpisode.agent_steps

@@ -81,4 +81,4 @@ Creating and handling episode chunks
~MultiAgentEpisode.cut
~MultiAgentEpisode.slice
~MultiAgentEpisode.concat_episode
~MultiAgentEpisode.finalize
~MultiAgentEpisode.to_numpy
4 changes: 2 additions & 2 deletions doc/source/rllib/package_ref/env/single_agent_episode.rst
@@ -33,7 +33,7 @@ Getting basic information
~SingleAgentEpisode.get_return
~SingleAgentEpisode.get_duration_s
~SingleAgentEpisode.is_done
~SingleAgentEpisode.is_finalized
~SingleAgentEpisode.is_numpy
~SingleAgentEpisode.env_steps

Getting environment data
@@ -68,4 +68,4 @@ Creating and handling episode chunks
~SingleAgentEpisode.cut
~SingleAgentEpisode.slice
~SingleAgentEpisode.concat_episode
~SingleAgentEpisode.finalize
~SingleAgentEpisode.to_numpy
20 changes: 10 additions & 10 deletions doc/source/rllib/rllib-offline.rst
@@ -1353,7 +1353,7 @@ The following example demonstrates how to use a custom :py:class:`~ray.rllib.off
is_multi_agent: bool,
batch: Dict[str, Union[list, np.ndarray]],
schema: Dict[str, str] = SCHEMA,
finalize: bool = False,
to_numpy: bool = False,
input_compress_columns: Optional[List[str]] = None,
observation_space: gym.Space = None,
action_space: gym.Space = None,
@@ -1404,9 +1404,9 @@ The following example demonstrates how to use a custom :py:class:`~ray.rllib.off
t_started=0,
)

# If episodes should be finalized. Some connectors need this.
if finalize:
episode.finalize()
# If episodes should be numpy'ized. Some connectors need this.
if to_numpy:
episode.to_numpy()

# Append the episode to the list of episodes.
episodes.append(episode)
@@ -1430,7 +1430,7 @@ The following example demonstrates how to use a custom :py:class:`~ray.rllib.off
episodes = TextOfflinePreLearner._map_to_episodes(
is_multi_agent=False,
batch=batch,
finalize=True,
to_numpy=True,
schema=None,
input_compress_columns=False,
action_space=None,
@@ -1527,7 +1527,7 @@ The preceding example illustrates the flexibility of RLlib's Offline RL API for
episodes = TextOfflinePreLearner._map_to_episodes(
is_multi_agent=False,
batch=batch,
finalize=True,
to_numpy=True,
schema=None,
input_compress_columns=False,
action_space=self.spaces[0],
@@ -1564,7 +1564,7 @@ The preceding example illustrates the flexibility of RLlib's Offline RL API for
is_multi_agent: bool,
batch: Dict[str, Union[list, np.ndarray]],
schema: Dict[str, str] = SCHEMA,
finalize: bool = False,
to_numpy: bool = False,
input_compress_columns: Optional[List[str]] = None,
observation_space: gym.Space = None,
action_space: gym.Space = None,
@@ -1615,9 +1615,9 @@ The preceding example illustrates the flexibility of RLlib's Offline RL API for
t_started=0,
)

# If episodes should be finalized. Some connectors need this.
if finalize:
episode.finalize()
# If episodes should be numpy'ized. Some connectors need this.
if to_numpy:
episode.to_numpy()

# Append the episode to the list of episodes.
episodes.append(episode)
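
A rough sketch of the mapping pattern above, with hypothetical column data; the real `_map_to_episodes` additionally handles schemas, spaces, and column compression:

```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

def rows_to_episode(observations, actions, rewards, to_numpy=False):
    """Hypothetical helper turning raw data columns into one episode chunk."""
    episode = SingleAgentEpisode(
        observations=observations,  # one more observation than actions/rewards
        actions=actions,
        rewards=rewards,
        len_lookback_buffer=0,
    )
    # Numpy'ize only if requested; some connectors need numpy'ized episodes.
    if to_numpy:
        episode.to_numpy()
    return episode

episode = rows_to_episode(
    observations=[0, 1, 2], actions=[0, 1], rewards=[0.1, 0.2], to_numpy=True
)
assert episode.is_numpy
```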
32 changes: 16 additions & 16 deletions doc/source/rllib/single-agent-episode.rst
@@ -93,7 +93,7 @@ and extract information from this episode using its different "getter" methods:
**SingleAgentEpisode getter APIs**: "getter" methods exist for all of the Episode's fields, which are `observations`,
`actions`, `rewards`, `infos`, and `extra_model_outputs`. For simplicity, only the getters for observations, actions, and rewards
are shown here. Their behavior is intuitive, returning a single item when provided with a single index and returning a list of items
(in the non-finalized case; see further below) when provided with a list of indices or a slice (range) of indices.
(in the non-numpy'ized case; see further below) when provided with a list of indices or a slice of indices.
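
For example, a short sketch of the getter behavior described above (the episode data is hypothetical):

```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Hypothetical 3-step episode with integer observations.
episode = SingleAgentEpisode(
    observations=[0, 1, 2, 3],
    actions=[0, 1, 2],
    rewards=[0.1, 0.2, 0.3],
    len_lookback_buffer=0,
)
episode.get_observations(-1)        # single index -> single item (3)
episode.get_observations([0, 1])    # list of indices -> list of items ([0, 1])
episode.get_actions(slice(0, 2))    # slice -> list of items ([0, 1])
```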


Note that for `extra_model_outputs`, the getter is slightly more complicated as there exist sub-keys in this data (for example:
@@ -107,24 +107,24 @@ The following code snippet summarizes the various capabilities of the different
:end-before: rllib-sa-episode-02-end


Finalized and Non-Finalized Episodes
------------------------------------
Numpy'ized and non-numpy'ized Episodes
--------------------------------------

The data in a :py:class:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode` can exist in two states:
non-finalized and finalized. A non-finalized episode stores its data items in plain python lists
and appends new timestep data to these. In a finalized episode,
these lists have been converted into (possibly complex) structures that have NumPy arrays at their leafs.
Note that a "finalized" episode doesn't necessarily have to be terminated or truncated yet
in the sense that the underlying RL environment declared the episode to be over (or has reached some
maximum number of timesteps).
non-numpy'ized and numpy'ized. A non-numpy'ized episode stores its data items in plain python lists
and appends new timestep data to these. In a numpy'ized episode,
these lists have been converted into possibly complex structures that have NumPy arrays at their leafs.
Note that a numpy'ized episode doesn't necessarily have to be terminated or truncated yet
in the sense that the underlying RL environment declared the episode to be over or has reached some
maximum number of timesteps.

.. figure:: images/episodes/sa_episode_non_finalized_vs_finalized.svg
:width: 900
:align: left


:py:class:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode` objects start in the non-finalized
state (data stored in python lists), making it very fast to append data from an ongoing episode:
:py:class:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode` objects start in the non-numpy'ized
state, in which data is stored in python lists, making it very fast to append data from an ongoing episode:


.. literalinclude:: doc_code/sa_episode.py
@@ -133,22 +133,22 @@ state (data stored in python lists), making it very fast to append data from an
:end-before: rllib-sa-episode-03-end


To illustrate the differences between the data stored in a non-finalized episode vs. the same data stored in
a finalized one, take a look at this complex observation example here, showing the exact same observation data in two
episodes (one non-finalized the other finalized):
To illustrate the differences between the data stored in a non-numpy'ized episode vs. the same data stored in
a numpy'ized one, take a look at this complex observation example here, showing the exact same observation data in two
episodes (one non-numpy'ized the other numpy'ized):

.. figure:: images/episodes/sa_episode_non_finalized.svg
:width: 800
:align: left

**Complex observations in a non-finalized episode**: Each individual observation is a (complex) dict matching the
**Complex observations in a non-numpy'ized episode**: Each individual observation is a (complex) dict matching the
gymnasium environment's observation space. There are three such observation items stored in the episode so far.

.. figure:: images/episodes/sa_episode_finalized.svg
:width: 600
:align: left

**Complex observations in a finalized episode**: The entire observation record is a single (complex) dict matching the
**Complex observations in a numpy'ized episode**: The entire observation record is a single complex dict matching the
gymnasium environment's observation space. At the leafs of the structure are `NDArrays` holding the individual values of the leaf.
Note that these `NDArrays` have an extra batch dim (axis=0), whose length matches the length of the episode stored (here 3).

73 changes: 72 additions & 1 deletion rllib/algorithms/algorithm.py
@@ -41,6 +41,7 @@
import ray.cloudpickle as pickle
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS
from ray.rllib.algorithms.utils import AggregatorActor
from ray.rllib.callbacks.utils import make_callback
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.core import (
@@ -80,6 +81,7 @@
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils import deep_update, FilterManager, force_list
from ray.rllib.utils.actor_manager import FaultTolerantActorManager, RemoteCallResults
from ray.rllib.utils.annotations import (
DeveloperAPI,
ExperimentalAPI,
@@ -106,6 +108,7 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
AGGREGATOR_ACTOR_RESULTS,
ALL_MODULES,
ENV_RUNNER_RESULTS,
ENV_RUNNER_SAMPLING_TIMER,
@@ -878,6 +881,53 @@ def setup(self, config: AlgorithmConfig) -> None:
# need it for reading recorded experiences.
self.offline_data.spaces = spaces

# Create an Aggregator actor set, if necessary.
self._aggregator_actor_manager = None
if self.config.enable_rl_module_and_learner and (
self.config.num_aggregator_actors_per_learner > 0
):
# Get the devices of each learner.
learner_locations = self.learner_group.foreach_learner(
func=lambda _learner: (_learner.node, _learner.device),
)
rl_module_spec = self.config.get_multi_rl_module_spec(
spaces=self.env_runner_group.get_spaces(),
inference_only=False,
)
agg_cls = ray.remote(
num_cpus=1,
num_gpus=0.01 if self.config.num_gpus_per_learner > 0 else 0,
max_restarts=-1,
)(AggregatorActor)
self._aggregator_actor_manager = FaultTolerantActorManager(
[
agg_cls.remote(self.config, rl_module_spec)
for _ in range(
(self.config.num_learners or 1)
* self.config.num_aggregator_actors_per_learner
)
],
max_remote_requests_in_flight_per_actor=(
self.config.max_requests_in_flight_per_aggregator_actor
),
)
aggregator_locations = self._aggregator_actor_manager.foreach_actor(
func=lambda actor: (actor._node, actor._device)
)
self._aggregator_actor_to_learner = {}
for agg_idx, aggregator_location in enumerate(aggregator_locations):
for learner_idx, learner_location in enumerate(learner_locations):
if learner_location.get() == aggregator_location.get():
self._aggregator_actor_to_learner[agg_idx] = learner_idx
break
if agg_idx not in self._aggregator_actor_to_learner:
raise RuntimeError(
"No Learner worker found that matches aggregation worker "
f"#{agg_idx}'s node ({aggregator_location[0]}) and device "
f"({aggregator_location[1]})! The Learner workers' locations "
f"are {learner_locations}."
)

# Run `on_algorithm_init` callback after initialization is done.
make_callback(
"on_algorithm_init",
@@ -931,7 +981,6 @@ def step(self) -> ResultDict:
and (self.iteration + 1) % self.config.evaluation_interval == 0
)
# Results dict for training (and if applicable: evaluation).
train_results: ResultDict = {}
eval_results: ResultDict = {}

# Parallel eval + training: Kick off evaluation-loop and parallel train() call.
@@ -3325,6 +3374,28 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
clear_on_reduce=True,
)

if self.config.num_aggregator_actors_per_learner:
remote_aggregator_metrics: RemoteCallResults = (
self._aggregator_actor_manager.fetch_ready_async_reqs(
timeout_seconds=0.0,
return_obj_refs=False,
tags="metrics",
)
)
self._aggregator_actor_manager.foreach_actor_async(
func=lambda actor: actor.get_metrics(),
tag="metrics",
)

FaultTolerantActorManager.handle_remote_call_result_errors(
remote_aggregator_metrics,
ignore_ray_errors=False,
)
self.metrics.merge_and_log_n_dicts(
[res.get() for res in remote_aggregator_metrics.result_or_errors],
key=AGGREGATOR_ACTOR_RESULTS,
)

# Only here (at the end of the iteration), reduce the results into a single
# result dict.
return self.metrics.reduce(), train_iter_ctx
10 changes: 10 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -325,6 +325,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.num_gpus_per_env_runner = 0
self.custom_resources_per_env_runner = {}
self.validate_env_runners_after_construction = True
self.episodes_to_numpy = True
self.max_requests_in_flight_per_env_runner = 1
self.sample_timeout_s = 60.0
self.create_env_on_local_worker = False
@@ -1754,6 +1755,7 @@ def env_runners(
rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
batch_mode: Optional[str] = NotProvided,
explore: Optional[bool] = NotProvided,
episodes_to_numpy: Optional[bool] = NotProvided,
# @OldAPIStack settings.
exploration_config: Optional[dict] = NotProvided, # @OldAPIStack
create_env_on_local_worker: Optional[bool] = NotProvided, # @OldAPIStack
@@ -1906,6 +1908,10 @@ def env_runners(
explore: Default exploration behavior, iff `explore=None` is passed into
compute_action(s). Set to False for no exploration behavior (e.g.,
for evaluation).
episodes_to_numpy: Whether to numpy'ize episodes before returning them from
the EnvRunner. True by default. If True, EnvRunners call `to_numpy()` on the
episode (chunks) they return from `EnvRunner.sample()`.
exploration_config: A dict specifying the Exploration object's config.
remote_worker_envs: If using num_envs_per_env_runner > 1, whether to create
those new envs in remote processes instead of in the same worker.
@@ -2030,6 +2036,10 @@ def env_runners(
self.batch_mode = batch_mode
if explore is not NotProvided:
self.explore = explore
if episodes_to_numpy is not NotProvided:
self.episodes_to_numpy = episodes_to_numpy

# @OldAPIStack
if exploration_config is not NotProvided:
# Override entire `exploration_config` if `type` key changes.
# Update, if `type` key remains the same or is not specified.
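
A usage sketch for the new `episodes_to_numpy` option added above; `PPOConfig` and `CartPole-v1` are just example choices:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Keep sampled episodes in their list-based (non-numpy'ized) form, e.g. when a
# custom connector or offline pipeline prefers plain python lists.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .env_runners(num_env_runners=1, episodes_to_numpy=False)
)
algo = config.build()
```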
8 changes: 4 additions & 4 deletions rllib/algorithms/dreamerv3/utils/env_runner.py
@@ -409,8 +409,8 @@ def _sample(

eps += 1

# Then finalize (numpy'ize) the episode.
done_episodes_to_return.append(episodes[env_index].finalize())
# Then numpy'ize the episode.
done_episodes_to_return.append(episodes[env_index].to_numpy())

# Also early-out if we reach the number of episodes within this
# for-loop.
Expand Down Expand Up @@ -447,8 +447,8 @@ def _sample(
continue
episode.validate()
self._ongoing_episodes_for_metrics[episode.id_].append(episode)
# Return finalized (numpy'ized) Episodes.
ongoing_episodes_to_return.append(episode.finalize())
# Return numpy'ized Episodes.
ongoing_episodes_to_return.append(episode.to_numpy())

# Continue collecting into the cut Episode chunks.
self._episodes = ongoing_episodes_continuations
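
Putting the pieces together, a rough sketch of the sampling side of this change, assuming a `SingleAgentEnvRunner` can be built directly from the config as shown; with `episodes_to_numpy=False`, the returned episode chunks stay list-based and can be numpy'ized later by the caller:

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .env_runners(episodes_to_numpy=False)
)
env_runner = SingleAgentEnvRunner(config=config)

# Sample non-numpy'ized episode chunks ...
episodes = env_runner.sample(num_timesteps=32)
assert all(not eps.is_numpy for eps in episodes)

# ... and numpy'ize them later, e.g. right before batching for a Learner.
episodes = [eps.to_numpy() for eps in episodes]
assert all(eps.is_numpy for eps in episodes)
```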