Skip to content

Commit

Permalink
[RLlib; Offline RL] - Replace GAE in MARWILOfflinePreLearner with `GeneralAdvantageEstimation` connector in learner pipeline. (#47532)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Sep 9, 2024
1 parent 7648e76 commit 5e2d73d
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 268 deletions.
26 changes: 17 additions & 9 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog
from ray.rllib.algorithms.marwil.marwil_offline_prelearner import (
MARWILOfflinePreLearner,
)
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
from ray.rllib.connectors.learner import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddOneTsToEpisodesAndTruncate,
AddNextObservationsFromEpisodesToTrainBatch,
GeneralAdvantageEstimation,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
Expand Down Expand Up @@ -104,9 +101,6 @@ def __init__(self, algo_class=None):

# Override some of AlgorithmConfig's default values with MARWIL-specific values.

# Define the `OfflinePreLearner` class for `MARWIL`.
self.prelearner_class = MARWILOfflinePreLearner

# You should override input_ to point to an offline dataset
# (see algorithm.py and algorithm_config.py).
# The dataset may have an arbitrary number of timesteps
Expand Down Expand Up @@ -283,13 +277,27 @@ def build_learner_connector(
device=device,
)

# Before anything else, add one extra timestep to each episode (and record this
# in the loss mask, so that the computations at this extra timestep are not
# used to compute the loss).
pipeline.prepend(AddOneTsToEpisodesAndTruncate())

# Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece right
# after the corresponding "add-OBS-..." default piece.
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

# At the end of the pipeline (when the batch is already complete), add the
# GAE connector, which performs a value-function forward pass, computes the
# GAE results (advantages, value targets), and puts them directly back into
# the batch. This is then the batch used for `forward_train` and
# `compute_losses`.
pipeline.append(
GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
)

return pipeline

@override(AlgorithmConfig)
Expand Down
237 changes: 0 additions & 237 deletions rllib/algorithms/marwil/marwil_offline_prelearner.py

This file was deleted.

41 changes: 24 additions & 17 deletions rllib/algorithms/marwil/tests/test_marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

import ray
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.algorithms.marwil.marwil_offline_prelearner import (
MARWILOfflinePreLearner,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
Expand Down Expand Up @@ -169,51 +168,59 @@ def test_marwil_loss_function(self):
batch = algo.offline_data.data.take_batch(2000)

# Create the prelearner and compute advantages and values.
offline_prelearner = MARWILOfflinePreLearner(
config, algo.learner_group._learner
)
offline_prelearner = OfflinePreLearner(config, algo.learner_group._learner)
# Note: in `ray.data`'s pipeline everything has to be a dictionary; therefore,
# the batch is embedded in another dictionary.
batch = offline_prelearner(batch)["batch"][0]
if Columns.LOSS_MASK in batch[DEFAULT_MODULE_ID]:
loss_mask = (
batch[DEFAULT_MODULE_ID][Columns.LOSS_MASK].detach().cpu().numpy()
)
num_valid = np.sum(loss_mask)

def possibly_masked_mean(data_):
return np.sum(data_[loss_mask]) / num_valid

else:
possibly_masked_mean = np.mean

# Calculate our own expected values (to then compare against the
# agent's loss output).
MODULE_ID = "default_policy"
fwd_out = (
algo.learner_group._learner.module[MODULE_ID]
algo.learner_group._learner.module[DEFAULT_MODULE_ID]
.unwrapped()
.forward_train({k: v for k, v in batch[MODULE_ID].items()})
.forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()})
)
advantages = batch[MODULE_ID][Columns.ADVANTAGES].detach().cpu().numpy()
advantages_squared = np.mean(np.square(advantages))
advantages = batch[DEFAULT_MODULE_ID][Columns.ADVANTAGES].detach().cpu().numpy()
advantages_squared = possibly_masked_mean(np.square(advantages))
c_2 = 100.0 + 1e-8 * (advantages_squared - 100.0)
c = np.sqrt(c_2)
exp_advantages = np.exp(config.beta * (advantages / c))
action_dist_cls = (
algo.learner_group._learner.module[MODULE_ID]
algo.learner_group._learner.module[DEFAULT_MODULE_ID]
.unwrapped()
.get_train_action_dist_cls()
)
# Note: we need the actual model's logits, not the ones from the dataset
# stored in `batch[Columns.ACTION_DIST_INPUTS]`.
action_dist = action_dist_cls.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
logp = action_dist.logp(batch[MODULE_ID][Columns.ACTIONS])
logp = action_dist.logp(batch[DEFAULT_MODULE_ID][Columns.ACTIONS])
logp = logp.detach().cpu().numpy()

# Calculate all expected loss components.
expected_vf_loss = 0.5 * advantages_squared
expected_pol_loss = -1.0 * np.mean(exp_advantages * logp)
expected_pol_loss = -1.0 * possibly_masked_mean(exp_advantages * logp)
expected_loss = expected_pol_loss + config.vf_coeff * expected_vf_loss

# Calculate the algorithm's loss (to check against our own
# calculation above).
total_loss = algo.learner_group._learner.compute_loss_for_module(
module_id=MODULE_ID,
batch={k: v for k, v in batch[MODULE_ID].items()},
module_id=DEFAULT_MODULE_ID,
batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()},
fwd_out=fwd_out,
config=config,
)
learner_results = algo.learner_group._learner.metrics.peek(MODULE_ID)
learner_results = algo.learner_group._learner.metrics.peek(DEFAULT_MODULE_ID)

# Check all components.
check(
Expand Down
Loading

0 comments on commit 5e2d73d

Please sign in to comment.