Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib; Offline RL] - Replace GAE in MARWILOfflinePreLearner with GeneralAdvantageEstimation connector in learner pipeline. #47532

26 changes: 17 additions & 9 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog
from ray.rllib.algorithms.marwil.marwil_offline_prelearner import (
MARWILOfflinePreLearner,
)
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
from ray.rllib.connectors.learner import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddOneTsToEpisodesAndTruncate,
AddNextObservationsFromEpisodesToTrainBatch,
GeneralAdvantageEstimation,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
Expand Down Expand Up @@ -104,9 +101,6 @@ def __init__(self, algo_class=None):

# Override some of AlgorithmConfig's default values with MARWIL-specific values.

# Define the `OfflinePreLearner` class for `MARWIL`.
self.prelearner_class = MARWILOfflinePreLearner

# You should override input_ to point to an offline dataset
# (see algorithm.py and algorithm_config.py).
# The dataset may have an arbitrary number of timesteps
Expand Down Expand Up @@ -283,13 +277,27 @@ def build_learner_connector(
device=device,
)

# Before anything, add one ts to each episode (and record this in the loss
# mask, so that the computations at this extra ts are not used to compute
# the loss).
pipeline.prepend(AddOneTsToEpisodesAndTruncate())

# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

# At the end of the pipeline (when the batch is already completed), add the
# GAE connector, which performs a vf forward pass, then computes the GAE
# computations, and puts the results of this (advantages, value targets)
# directly back in the batch. This is then the batch used for
# `forward_train` and `compute_losses`.
pipeline.append(
GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
)

return pipeline

@override(AlgorithmConfig)
Expand Down
237 changes: 0 additions & 237 deletions rllib/algorithms/marwil/marwil_offline_prelearner.py

This file was deleted.

8 changes: 4 additions & 4 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
episodes = OfflinePreLearner._map_sample_batch_to_episode(
self._is_multi_agent,
batch,
finalize=False,
finalize=True,
schema=SCHEMA | self.config.input_read_schema,
input_compress_columns=self.config.input_compress_columns,
)["episodes"]
Expand All @@ -160,7 +160,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
self._is_multi_agent,
batch,
schema=SCHEMA | self.config.input_read_schema,
finalize=False,
finalize=True,
input_compress_columns=self.config.input_compress_columns,
observation_space=self.observation_space,
action_space=self.action_space,
Expand Down Expand Up @@ -285,7 +285,7 @@ def convert(sample, space):
else:
# Build a single-agent episode with a single row of the batch.
episode = SingleAgentEpisode(
id_=batch[schema[Columns.EPS_ID]][i],
id_=str(batch[schema[Columns.EPS_ID]][i]),
agent_id=agent_id,
# Observations might be (a) serialized and/or (b) converted
# to a JSONable (when a composite space was used). We unserialize
Expand Down Expand Up @@ -412,7 +412,7 @@ def _map_sample_batch_to_episode(
)
# Create a `SingleAgentEpisode`.
episode = SingleAgentEpisode(
id_=batch[schema[Columns.EPS_ID]][i][0],
id_=str(batch[schema[Columns.EPS_ID]][i][0]),
agent_id=agent_id,
observations=obs,
infos=(
Expand Down
7 changes: 7 additions & 0 deletions rllib/tuned_examples/cql/pendulum_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
)
.offline_data(
input_=[data_path.as_posix()],
# Define the number of reading blocks; these should be larger than 1
# and aligned with the data size.
input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 2)},
# Concurrency defines the number of processes that run the
# `map_batches` transformations. This should be aligned with the
# 'prefetch_batches' argument in 'iter_batches_kwargs'.
map_batches_kwargs={"concurrency": max(2, args.num_gpus * 2)},
actions_in_input_normalized=True,
dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
)
Expand Down
2 changes: 1 addition & 1 deletion rllib/tuned_examples/marwil/cartpole_marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# as remote learners.
.offline_data(
input_=[data_path.as_posix()],
input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 1)},
input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 2)},
prelearner_module_synch_period=20,
dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
)
Expand Down