
[RLlib] Make Episode.to_numpy optional at end of EnvRunner.sample() (renamed from finalize()). #49800

Merged
10 changes: 10 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -325,6 +325,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.num_gpus_per_env_runner = 0
self.custom_resources_per_env_runner = {}
self.validate_env_runners_after_construction = True
self.episodes_to_numpy = False
Review comment (Collaborator): Nice!

self.max_requests_in_flight_per_env_runner = 1
self.sample_timeout_s = 60.0
self.create_env_on_local_worker = False
@@ -1758,6 +1759,7 @@ def env_runners(
rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
batch_mode: Optional[str] = NotProvided,
explore: Optional[bool] = NotProvided,
episodes_to_numpy: Optional[bool] = NotProvided,
# @OldAPIStack settings.
exploration_config: Optional[dict] = NotProvided, # @OldAPIStack
create_env_on_local_worker: Optional[bool] = NotProvided, # @OldAPIStack
@@ -1910,6 +1912,10 @@ def env_runners(
explore: Default exploration behavior, iff `explore=None` is passed into
compute_action(s). Set to False for no exploration behavior (e.g.,
for evaluation).
episodes_to_numpy: Whether to numpy'ize episodes before
returning them from an EnvRunner. False by default. If True, EnvRunners
call `to_numpy()` on the episode (chunks) to be returned by
`EnvRunner.sample()`.
exploration_config: A dict specifying the Exploration object's config.
remote_worker_envs: If using num_envs_per_env_runner > 1, whether to create
those new envs in remote processes instead of in the same worker.
@@ -2034,6 +2040,10 @@ def env_runners(
self.batch_mode = batch_mode
if explore is not NotProvided:
self.explore = explore
if episodes_to_numpy is not NotProvided:
self.episodes_to_numpy = episodes_to_numpy

# @OldAPIStack
if exploration_config is not NotProvided:
# Override entire `exploration_config` if `type` key changes.
# Update, if `type` key remains the same or is not specified.
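For context, a minimal usage sketch of the new flag (the PPO algorithm and CartPole env are just illustrative choices, not part of this PR):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Build a config whose EnvRunners numpy'ize episodes before returning them
# from `EnvRunner.sample()`. With the default (False), episodes keep their
# per-timestep python lists.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .env_runners(
        num_env_runners=2,
        episodes_to_numpy=True,
    )
)
algo = config.build()
```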
2 changes: 2 additions & 0 deletions rllib/algorithms/sac/sac.py
@@ -116,6 +116,8 @@ def __init__(self, algo_class=None):
# .env_runners()
# Set to `self.n_step`, if 'auto'.
self.rollout_fragment_length = "auto"

# .training()
self.train_batch_size_per_learner = 256
self.train_batch_size = 256 # @OldAPIstack
# Number of timesteps to collect from rollout workers before we start
Original file line number Diff line number Diff line change
@@ -155,12 +155,19 @@ def __call__(
self.add_n_batch_items(
batch,
Columns.OBS,
# Add all observations, except the very last one.
# For a terminated episode, this is the terminal observation, which
# has no value for training.
# For a truncated episode, algorithms either add an extra NEXT_OBS
# column to the batch (e.g., DQN) or extend the episode length by one
# (using a separate connector piece and this truncated last obs),
# then bootstrap the value estimation for that extra timestep.
items_to_add=sa_episode.get_observations(slice(0, len(sa_episode))),
num_items=len(sa_episode),
single_agent_episode=sa_episode,
)
else:
assert not sa_episode.is_finalized
assert not sa_episode.is_numpy
self.add_batch_item(
batch,
Columns.OBS,
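To make the comment above concrete, here is a tiny toy illustration (plain Python, not RLlib code) of why the slice drops the last stored observation:

```python
# Toy stand-in for a terminated episode of length 3: the reset obs plus one
# obs per step are stored, so there are 4 observations and the last one is
# the terminal observation.
observations = ["o0", "o1", "o2", "o3"]
episode_len = len(observations) - 1  # 3 timesteps

# Equivalent of `sa_episode.get_observations(slice(0, len(sa_episode)))`:
# keep o0..o2 and drop the terminal o3, which has no training value.
obs_for_training = observations[slice(0, episode_len)]
assert obs_for_training == ["o0", "o1", "o2"]
```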
59 changes: 34 additions & 25 deletions rllib/connectors/common/add_states_from_episodes_to_batch.py
@@ -160,7 +160,7 @@ def get_initial_state(self):
output_batch = connector(
rl_module=rl_module,
batch={},
episodes=[episode.finalize()],
episodes=[episode.to_numpy()],
shared_data={},
)
check(
@@ -280,11 +280,8 @@ def __call__(
agents_that_stepped_only=not self._as_learner_connector,
):
if self._as_learner_connector:
assert sa_episode.is_finalized

# Multi-agent case: Extract correct single agent RLModule (to get the
# state for individually).
sa_module = rl_module
if sa_episode.module_id is not None:
sa_module = rl_module[sa_episode.module_id]
else:
@@ -327,38 +324,50 @@
else:
# Then simply use the `look_back_state`, i.e., in this case the
# initial state, as "state_in" in training.
state_outs = tree.map_structure(
lambda a: np.repeat(
a[np.newaxis, ...], len(sa_episode), axis=0
),
if sa_episode.is_numpy:
state_outs = tree.map_structure(
lambda a, _sae=sa_episode: np.repeat(
a[np.newaxis, ...], len(_sae), axis=0
),
look_back_state,
)
else:
state_outs = [look_back_state for _ in range(len(sa_episode))]
# Explanation:
# B = ceil(episode len / max_seq_len)
# [::max_seq_len]: Only keep every max_seq_len'th state.
# [:-1]: Shift the STATE_OUTs by one; ignore the very last
# STATE_OUT, but instead add the lookback/init state at
# the beginning.
items_to_add = (
tree.map_structure(
lambda i, o, m=max_seq_len: np.concatenate([[i], o[:-1]])[::m],
look_back_state,
state_outs,
)
if sa_episode.is_numpy
else ([look_back_state] + state_outs[:-1])[::max_seq_len]
)
self.add_n_batch_items(
batch=batch,
column=Columns.STATE_IN,
# items_to_add.shape=(B,[state-dim])
# B=episode len // max_seq_len
items_to_add=tree.map_structure(
# Explanation:
# [::max_seq_len]: only keep every Tth state.
# [:-1]: Shift state outs by one, ignore very last
# STATE_OUT (but therefore add the lookback/init state at
# the beginning).
lambda i, o, m=max_seq_len: np.concatenate([[i], o[:-1]])[::m],
look_back_state,
state_outs,
),
items_to_add=items_to_add,
num_items=int(math.ceil(len(sa_episode) / max_seq_len)),
single_agent_episode=sa_episode,
)
if Columns.NEXT_OBS in batch:
items_to_add = (
tree.map_structure(
lambda i, m=max_seq_len: i[::m],
state_outs,
)
if sa_episode.is_numpy
else state_outs[::max_seq_len]
)
self.add_n_batch_items(
batch=batch,
column=Columns.NEXT_STATE_IN,
items_to_add=tree.map_structure(
lambda i, m=max_seq_len: i[::m],
state_outs,
),
items_to_add=items_to_add,
num_items=int(math.ceil(len(sa_episode) / max_seq_len)),
single_agent_episode=sa_episode,
)
@@ -382,7 +391,7 @@ def __call__(
single_agent_episode=sa_episode,
)
else:
assert not sa_episode.is_finalized
assert not sa_episode.is_numpy

# Multi-agent case: Extract correct single agent RLModule (to get the
# state for individually).
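A short numpy sketch of the STATE_IN subsampling explained in the comments above (toy shapes; assumes a single flat state tensor instead of a nested state structure):

```python
import math
import numpy as np

max_seq_len = 4
episode_len = 10

# Toy per-timestep STATE_OUTs with shape (T, state_dim), plus a lookback/init
# state with shape (1, state_dim).
state_outs = np.arange(episode_len, dtype=np.float32).reshape(episode_len, 1)
look_back_state = np.full((1, 1), -1.0, dtype=np.float32)

# Shift the STATE_OUTs by one (prepend the init state, drop the very last
# STATE_OUT), then keep only every `max_seq_len`-th entry: one STATE_IN per
# max_seq_len-long sequence.
state_ins = np.concatenate([look_back_state, state_outs[:-1]])[::max_seq_len]

# ceil(10 / 4) = 3 sequences -> 3 STATE_INs: the init state, state_out[3],
# and state_out[7].
assert len(state_ins) == int(math.ceil(episode_len / max_seq_len))
```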
2 changes: 1 addition & 1 deletion rllib/connectors/common/frame_stacking.py
@@ -113,7 +113,7 @@ def _map_fn(s, _sa_episode=sa_episode):
# Env-to-module pipeline. Episodes still operate on lists.
else:
for sa_episode in self.single_agent_episode_iterator(episodes):
assert not sa_episode.is_finalized
assert not sa_episode.is_numpy
# Get the list of observations to stack.
obs_stack = sa_episode.get_observations(
indices=slice(-self.num_frames, None),
3 changes: 0 additions & 3 deletions rllib/connectors/env_to_module/flatten_observations.py
@@ -173,9 +173,6 @@ def __call__(
for sa_episode in self.single_agent_episode_iterator(
episodes, agents_that_stepped_only=True
):
# Episode is not finalized yet and thus still operates on lists of items.
assert not sa_episode.is_finalized

last_obs = sa_episode.get_observations(-1)

if self._multi_agent:
Original file line number Diff line number Diff line change
@@ -118,7 +118,7 @@ def __call__(
episodes, agents_that_stepped_only=True
):
# Episode is not finalized yet and thus still operates on lists of items.
assert not sa_episode.is_finalized
assert not sa_episode.is_numpy

augmented_obs = {self.ORIG_OBS_KEY: sa_episode.get_observations(-1)}

Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ def __call__(
):
# Make sure episodes are NOT finalized yet (we are expecting to run in an
# env-to-module pipeline).
assert not sa_episode.is_finalized
assert not sa_episode.is_numpy
# Write new information into the episode.
sa_episode.set_observations(at_indices=-1, new_data=obs)
# Change the observation space of the sa_episode.
45 changes: 34 additions & 11 deletions rllib/env/multi_agent_env_runner.py
@@ -400,9 +400,17 @@ def _sample_timesteps(
# the connector, if applicable).
self._make_on_episode_callback("on_episode_end")

# Finalize (numpy'ize) the episode.
self._episode.finalize(drop_zero_len_single_agent_episodes=True)
done_episodes_to_return.append(self._episode)
# Numpy'ize the episode.
if self.config.episodes_to_numpy:
# And possibly compress observations.
done_episodes_to_return.append(
self._episode.to_numpy(
drop_zero_len_single_agent_episodes=True,
)
)
# Leave the episode as lists of individual (obs, action, etc.) items.
else:
done_episodes_to_return.append(self._episode)

# Create a new episode instance.
self._episode = self._new_episode()
@@ -442,10 +450,18 @@ def _sample_timesteps(
if self._episode.env_t > 0:
self._episode.validate()
self._ongoing_episodes_for_metrics[self._episode.id_].append(self._episode)
# Return finalized (numpy'ized) Episodes.
ongoing_episodes_to_return.append(
self._episode.finalize(drop_zero_len_single_agent_episodes=True)
)

# Numpy'ize the episode.
if self.config.episodes_to_numpy:
# And possibly compress observations.
ongoing_episodes_to_return.append(
self._episode.to_numpy(
drop_zero_len_single_agent_episodes=True,
)
)
# Leave the episode as lists of individual (obs, action, etc.) items.
else:
ongoing_episodes_to_return.append(self._episode)

# Continue collecting into the cut Episode chunk.
self._episode = ongoing_episode_continuation
@@ -615,10 +631,17 @@ def _sample_episodes(
# the connector, if applicable).
self._make_on_episode_callback("on_episode_end", _episode)

# Finish the episode.
done_episodes_to_return.append(
_episode.finalize(drop_zero_len_single_agent_episodes=True)
)
# Numpy'ize the episode.
if self.config.episodes_to_numpy:
# And possibly compress observations.
done_episodes_to_return.append(
_episode.to_numpy(
drop_zero_len_single_agent_episodes=True,
)
)
# Leave the episode as lists of individual (obs, action, etc.) items.
else:
done_episodes_to_return.append(_episode)

# Also early-out if we reach the number of episodes within this
# for-loop.
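Finally, a hypothetical consumer-side sketch of the changed behavior (`is_numpy`, `to_numpy()`, and `get_observations()` are the episode API used in this PR; the surrounding function is made up for illustration):

```python
def consume_sample(env_runner):
    """Toy consumer of `EnvRunner.sample()`; not part of RLlib."""
    episodes = env_runner.sample()
    for episode in episodes:
        # With `episodes_to_numpy=False` (the default), episodes arrive as
        # python lists of per-timestep items and can still be appended to.
        # Numpy'ize lazily, e.g. right before building a train batch.
        if not episode.is_numpy:
            episode = episode.to_numpy()
        # After `to_numpy()`, getters return stacked arrays instead of lists.
        yield episode.get_observations(slice(0, len(episode)))
```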