[RLlib; Offline RL] - Replace GAE in MARWILOfflinePreLearner with GeneralAdvantageEstimation connector in learner pipeline. #47532

26 changes: 17 additions & 9 deletions rllib/algorithms/marwil/marwil.py
@@ -3,14 +3,11 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog
from ray.rllib.algorithms.marwil.marwil_offline_prelearner import (
MARWILOfflinePreLearner,
)
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
from ray.rllib.connectors.learner import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddOneTsToEpisodesAndTruncate,
AddNextObservationsFromEpisodesToTrainBatch,
GeneralAdvantageEstimation,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
@@ -104,9 +101,6 @@ def __init__(self, algo_class=None):

# Override some of AlgorithmConfig's default values with MARWIL-specific values.

# Define the `OfflinePreLearner` class for `MARWIL`.
self.prelearner_class = MARWILOfflinePreLearner

# You should override input_ to point to an offline dataset
# (see algorithm.py and algorithm_config.py).
# The dataset may have an arbitrary number of timesteps
@@ -283,13 +277,27 @@ def build_learner_connector(
device=device,
)

# Before anything, add one ts to each episode (and record this in the loss
# mask, so that the computations at this extra ts are not used to compute
# the loss).
pipeline.prepend(AddOneTsToEpisodesAndTruncate())

# Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece right
# after the corresponding "add-OBS-..." default piece.
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

# At the end of the pipeline (when the batch is already complete), add the
# GAE connector, which performs a vf forward pass, computes GAE, and writes
# the results (advantages, value targets) directly back into the batch. This
# is then the batch used for `forward_train` and `compute_losses`.
pipeline.append(
GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
)

return pipeline

@override(AlgorithmConfig)
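
For reference, a minimal sketch of the per-episode advantage math that the appended GeneralAdvantageEstimation connector performs after its value-function forward pass (the function name and the explicit bootstrap-value convention below are illustrative, not the connector's actual API):

import numpy as np

def gae_for_episode(rewards, values, gamma=0.99, lambda_=0.95):
    # `values` holds one value prediction per timestep plus one extra
    # bootstrap value for the appended timestep (cf. the
    # AddOneTsToEpisodesAndTruncate piece prepended above), so
    # len(values) == len(rewards) + 1.
    advantages = np.zeros(len(rewards), dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Discounted, lambda-weighted accumulation of TD errors.
        last_gae = delta + gamma * lambda_ * last_gae
        advantages[t] = last_gae
    # Value targets used for the value-function loss.
    value_targets = advantages + values[:-1]
    return advantages, value_targets

In the actual connector, the computed advantages and value targets are written back into the train batch, and the extra timestep is excluded from the loss via the loss mask recorded by AddOneTsToEpisodesAndTruncate.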
237 changes: 0 additions & 237 deletions rllib/algorithms/marwil/marwil_offline_prelearner.py

This file was deleted.

41 changes: 24 additions & 17 deletions rllib/algorithms/marwil/tests/test_marwil.py
@@ -4,11 +4,10 @@

import ray
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.algorithms.marwil.marwil_offline_prelearner import (
MARWILOfflinePreLearner,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
@@ -169,51 +168,59 @@ def test_marwil_loss_function(self):
batch = algo.offline_data.data.take_batch(2000)

# Create the prelearner and compute advantages and values.
offline_prelearner = MARWILOfflinePreLearner(
config, algo.learner_group._learner
)
offline_prelearner = OfflinePreLearner(config, algo.learner_group._learner)
# Note: for `ray.data`'s pipeline everything has to be a dictionary;
# therefore, the batch is embedded into another dictionary.
batch = offline_prelearner(batch)["batch"][0]
if Columns.LOSS_MASK in batch[DEFAULT_MODULE_ID]:
loss_mask = (
batch[DEFAULT_MODULE_ID][Columns.LOSS_MASK].detach().cpu().numpy()
)
num_valid = np.sum(loss_mask)

def possibly_masked_mean(data_):
return np.sum(data_[loss_mask]) / num_valid

else:
possibly_masked_mean = np.mean

# Calculate our own expected values (to then compare against the
# agent's loss output).
MODULE_ID = "default_policy"
fwd_out = (
algo.learner_group._learner.module[MODULE_ID]
algo.learner_group._learner.module[DEFAULT_MODULE_ID]
.unwrapped()
.forward_train({k: v for k, v in batch[MODULE_ID].items()})
.forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()})
)
advantages = batch[MODULE_ID][Columns.ADVANTAGES].detach().cpu().numpy()
advantages_squared = np.mean(np.square(advantages))
advantages = batch[DEFAULT_MODULE_ID][Columns.ADVANTAGES].detach().cpu().numpy()
advantages_squared = possibly_masked_mean(np.square(advantages))
c_2 = 100.0 + 1e-8 * (advantages_squared - 100.0)
c = np.sqrt(c_2)
exp_advantages = np.exp(config.beta * (advantages / c))
action_dist_cls = (
algo.learner_group._learner.module[MODULE_ID]
algo.learner_group._learner.module[DEFAULT_MODULE_ID]
.unwrapped()
.get_train_action_dist_cls()
)
# Note: we need the actual model's logits, not the ones from the dataset
# stored in `batch[Columns.ACTION_DIST_INPUTS]`.
action_dist = action_dist_cls.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
logp = action_dist.logp(batch[MODULE_ID][Columns.ACTIONS])
logp = action_dist.logp(batch[DEFAULT_MODULE_ID][Columns.ACTIONS])
logp = logp.detach().cpu().numpy()

# Calculate all expected loss components.
expected_vf_loss = 0.5 * advantages_squared
expected_pol_loss = -1.0 * np.mean(exp_advantages * logp)
expected_pol_loss = -1.0 * possibly_masked_mean(exp_advantages * logp)
expected_loss = expected_pol_loss + config.vf_coeff * expected_vf_loss

# Calculate the algorithm's loss (to check against our own
# calculation above).
total_loss = algo.learner_group._learner.compute_loss_for_module(
module_id=MODULE_ID,
batch={k: v for k, v in batch[MODULE_ID].items()},
module_id=DEFAULT_MODULE_ID,
batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()},
fwd_out=fwd_out,
config=config,
)
learner_results = algo.learner_group._learner.metrics.peek(MODULE_ID)
learner_results = algo.learner_group._learner.metrics.peek(DEFAULT_MODULE_ID)

# Check all components.
check(
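
The expected-loss arithmetic above follows MARWIL's exponentially weighted imitation loss. A minimal numpy sketch of the same computation, assuming flat (already loss-masked) advantage and log-probability arrays and illustrative beta/vf_coeff values:

import numpy as np

def marwil_expected_loss(advantages, logp, beta=1.0, vf_coeff=1.0):
    # Mean squared advantages feed both the vf loss and the
    # normalization coefficient c (a moving average seeded at 100.0).
    advantages_squared = np.mean(np.square(advantages))
    c = np.sqrt(100.0 + 1e-8 * (advantages_squared - 100.0))
    # Exponentially weighted advantages scale the action log-probs.
    exp_advantages = np.exp(beta * advantages / c)
    policy_loss = -np.mean(exp_advantages * logp)
    vf_loss = 0.5 * advantages_squared
    return policy_loss + vf_coeff * vf_loss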