[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic `_forward` to further simplify the user experience). (#47889)
sven1977 authored Oct 5, 2024
1 parent cbde03c commit e182e19
Showing 29 changed files with 545 additions and 585 deletions.
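The headline change is the generic `_forward()` method on `RLModule`: instead of always implementing `_forward_inference()`, `_forward_exploration()`, and `_forward_train()` separately, a module can now provide a single `_forward()` that serves all phases and only override the phase-specific methods where behavior has to differ. A minimal sketch of the pattern, assuming a simple discrete-action setup (the class name and layer sizes are illustrative, not part of this commit):

```python
from typing import Any, Dict

from torch import nn

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class MyTorchRLModule(TorchRLModule):
    def setup(self):
        # Illustrative fixed sizes; a real module would derive these from its
        # observation/action spaces or build the networks from its catalog.
        self.encoder = nn.Sequential(nn.Linear(4, 64), nn.ReLU())
        self.pi = nn.Linear(64, 2)

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # One code path used for inference, exploration, and training alike.
        embeddings = self.encoder(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: self.pi(embeddings)}
```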
13 changes: 8 additions & 5 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -14,7 +14,7 @@
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
@@ -35,6 +35,10 @@ def compute_loss_for_module(
batch: Dict,
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()
assert isinstance(module, TargetNetworkAPI)
assert isinstance(module, ValueFunctionAPI)

# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
# bootstrap values at the end of rollouts in the new stack, we might make
# this a more flexible, configurable parameter for users, e.g.
@@ -51,10 +55,9 @@
)
size_loss_mask = torch.sum(loss_mask)

module = self.module[module_id].unwrapped()
assert isinstance(module, TargetNetworkAPI)

values = fwd_out[Columns.VF_PREDS]
values = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)

action_dist_cls_train = module.get_train_action_dist_cls()
target_policy_dist = action_dist_cls_train.from_logits(
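The same pattern recurs in the learner changes throughout this commit: losses stop reading `fwd_out[Columns.VF_PREDS]` and instead ask the module for values through the `ValueFunctionAPI`, passing along embeddings from the train forward pass when available. A rough sketch of that call pattern (the helper function itself is hypothetical, not part of the commit):

```python
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI


def value_estimates(module, batch: Dict[str, Any], fwd_out: Dict[str, Any]):
    """Hypothetical helper mirroring the new learner-side value lookup."""
    # The unwrapped module must implement the ValueFunctionAPI.
    assert isinstance(module, ValueFunctionAPI)
    # Reuse embeddings from the forward pass if the module emitted them;
    # otherwise compute_values() re-encodes the batch itself.
    return module.compute_values(batch, embeddings=fwd_out.get(Columns.EMBEDDINGS))
```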
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc_catalog.py
@@ -64,7 +64,7 @@ def build_pi_head(self, framework: str) -> Model:
The default behavior is to build the head from the pi_head_config.
This can be overridden to build a custom policy head as a means of configuring
the behavior of a BCRLModule implementation.
the behavior of a BC specific RLModule implementation.
Args:
framework: The framework to use. Either "torch" or "tf2".
82 changes: 0 additions & 82 deletions rllib/algorithms/bc/bc_rl_module.py

This file was deleted.

32 changes: 29 additions & 3 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py
@@ -1,6 +1,32 @@
from ray.rllib.algorithms.bc.bc_rl_module import BCRLModule
from typing import Any, Dict

from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class BCTorchRLModule(TorchRLModule):
@override(RLModule)
def setup(self):
# __sphinx_doc_begin__
# Build models from catalog
self.encoder = self.catalog.build_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)

@override(RLModule)
def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
"""Generic BC forward pass (for all phases of training/evaluation)."""
output = {}

# State encodings.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Actions.
action_logits = self.pi(encoder_outs[ENCODER_OUT])
output[Columns.ACTION_DIST_INPUTS] = action_logits

class BCTorchRLModule(TorchRLModule, BCRLModule):
pass
return output
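With only `_forward()` implemented, the module's public entry points all route through it. A usage sketch, assuming `module` is an already-built `BCTorchRLModule` and `obs` is a NumPy batch of observations (both assumed given, not defined by this commit):

```python
import torch

from ray.rllib.core.columns import Columns

# `module` and `obs` are assumed to exist already (see lead-in above).
batch = {Columns.OBS: torch.as_tensor(obs, dtype=torch.float32)}

# Both calls end up in the single generic `_forward()` defined above, since no
# phase-specific `_forward_inference()`/`_forward_train()` override exists.
inference_out = module.forward_inference(batch)
train_out = module.forward_train(batch)
action_logits = train_out[Columns.ACTION_DIST_INPUTS]
```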
14 changes: 8 additions & 6 deletions rllib/algorithms/impala/torch/impala_torch_learner.py
@@ -28,6 +28,8 @@ def compute_loss_for_module(
batch: Dict,
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()

# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
# bootstrap values at the end of rollouts in the new stack, we might make
# this a more flexible, configurable parameter for users, e.g.
@@ -46,17 +48,17 @@

# Behavior actions logp and target actions logp.
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
target_policy_dist = (
self.module[module_id]
.unwrapped()
.get_train_action_dist_cls()
.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
target_policy_dist = module.get_train_action_dist_cls().from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

# Values and bootstrap values.
values = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
values_time_major = make_time_major(
fwd_out[Columns.VF_PREDS],
values,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
6 changes: 2 additions & 4 deletions rllib/algorithms/marwil/marwil_rl_module.py
@@ -10,6 +10,7 @@

@DeveloperAPI(stability="alpha")
class MARWILRLModule(RLModule, ValueFunctionAPI, abc.ABC):
@override(RLModule)
def setup(self):
# Build models from catalog
self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
@@ -47,7 +48,4 @@ def input_specs_train(self) -> SpecDict:

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.VF_PREDS,
Columns.ACTION_DIST_INPUTS,
]
return [Columns.ACTION_DIST_INPUTS]
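Because the train forward pass no longer emits value predictions, the train output specs shrink accordingly: `Columns.VF_PREDS` disappears from the required outputs here and in the analogous PPO change further down. A compact sketch of the resulting spec declaration (illustrative class, not verbatim from any file):

```python
from ray.rllib.core.columns import Columns


class SpecsSketch:
    # Sketch only: VF_PREDS is no longer a required train output, since values
    # are fetched on demand via ValueFunctionAPI.compute_values() in the learner.
    def output_specs_train(self):
        return [Columns.ACTION_DIST_INPUTS]
```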
9 changes: 5 additions & 4 deletions rllib/algorithms/marwil/torch/marwil_torch_learner.py
@@ -31,6 +31,7 @@ def compute_loss_for_module(
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType]
) -> TensorType:
module = self.module[module_id].unwrapped()

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based model (zero padded `batch`)
@@ -46,9 +47,7 @@ def possibly_masked_mean(data_):
else:
possibly_masked_mean = torch.mean

action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
curr_action_dist = action_dist_class_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
@@ -64,7 +63,9 @@ def possibly_masked_mean(data_):
# Otherwise, compute advantages.
else:
# cumulative_rewards = batch[Columns.ADVANTAGES]
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
advantages = batch[Columns.VALUE_TARGETS] - value_fn_out
advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0))

32 changes: 18 additions & 14 deletions rllib/algorithms/marwil/torch/marwil_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.marwil.marwil_rl_module import MARWILRLModule
from ray.rllib.core.columns import Columns
@@ -42,7 +42,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"flag `inference_only=False` when building the module."
)
output = {}

# Shared encoder.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
@@ -63,18 +62,23 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# (similar to IMPALA's v-trace architecture). This would also get rid of the
# second Connector pass currently necessary.
@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch)[ENCODER_OUT]
# Shared encoder.
else:
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
def compute_values(
self,
batch: Dict[str, Any],
embeddings: Optional[Any] = None,
) -> TensorType:
if embeddings is None:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
embeddings = self.encoder.critic_encoder(batch)[ENCODER_OUT]
# Shared encoder.
else:
embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]
# Value head.
vf_out = self.vf(encoder_outs)
vf_out = self.vf(embeddings)
# Squeeze out last dimension (single node value head).
return vf_out.squeeze(-1)
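The updated `ValueFunctionAPI.compute_values()` contract takes an optional `embeddings` argument: callers (typically learners) can hand back the encoder output from the train forward pass, and only when it is `None` does the module run its own (critic) encoder. A generic sketch of an implementation following that contract, with illustrative stand-ins for the module's encoder and value head:

```python
from typing import Any, Dict, Optional

from torch import nn

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.utils.annotations import override


class ValueHeadSketch(nn.Module, ValueFunctionAPI):
    def __init__(self, obs_dim: int = 4, hidden: int = 64):
        super().__init__()
        # Illustrative networks standing in for a module's encoder and value head.
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.vf = nn.Linear(hidden, 1)

    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ):
        if embeddings is None:
            # No precomputed embeddings handed in: encode the batch ourselves.
            embeddings = self.encoder(batch[Columns.OBS])
        # Value head outputs shape (B, 1); squeeze to (B,).
        return self.vf(embeddings).squeeze(-1)
```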
5 changes: 1 addition & 4 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -72,10 +72,7 @@ def input_specs_train(self) -> SpecDict:

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.VF_PREDS,
Columns.ACTION_DIST_INPUTS,
]
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
49 changes: 4 additions & 45 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -26,21 +26,7 @@
torch, nn = try_import_torch()


def get_expected_module_config(
env: gym.Env,
model_config_dict: dict,
observation_space: gym.spaces.Space,
) -> RLModuleConfig:
"""Get a PPOModuleConfig that we would expect from the catalog otherwise.
Args:
env: Environment for which we build the model later
model_config_dict: Model config to use for the catalog
observation_space: Observation space to use for the catalog.
Returns:
A PPOModuleConfig containing the relevant configs to build PPORLModule
"""
def get_expected_module_config(env, model_config_dict, observation_space):
config = RLModuleConfig(
observation_space=observation_space,
action_space=env.action_space,
@@ -52,22 +38,7 @@


def dummy_torch_ppo_loss(module, batch, fwd_out):
"""Dummy PPO loss function for testing purposes.
Will eventually use the actual PPO loss function implemented in PPO.
Args:
batch: Batch used for training.
fwd_out: Forward output of the model.
Returns:
Loss tensor
"""
# TODO: we should replace these components later with real ppo components when
# RLOptimizer and RLModule are integrated together.
# this is not exactly a ppo loss, just something to show that the
# forward train works
adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS]
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
action_probs = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -80,19 +51,7 @@ def dummy_torch_ppo_loss(module, batch, fwd_out):


def dummy_tf_ppo_loss(module, batch, fwd_out):
"""Dummy PPO loss function for testing purposes.
Will eventually use the actual PPO loss function implemented in PPO.
Args:
module: PPOTfRLModule
batch: Batch used for training.
fwd_out: Forward output of the model.
Returns:
Loss tensor
"""
adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS]
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
action_probs = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -180,7 +139,7 @@ def test_rollouts(self):

def test_forward_train(self):
# TODO: Add FrozenLake-v1 to cover LSTM case.
frameworks = ["tf2", "torch"]
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]
lstm = [False, True]
config_combinations = [frameworks, env_names, lstm]
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tf/ppo_tf_rl_module.py
@@ -58,7 +58,7 @@ def _forward_train(self, batch: Dict):
return output

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
def compute_values(self, batch: Dict[str, Any], embeddings=None) -> TensorType:
infos = batch.pop(Columns.INFOS, None)
batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch)
if infos is not None: