[RLlib] Make the "tiny CNN" example RLModule run with APPO (by implementing `TargetNetAPI`) (ray-project#49825)

Signed-off-by: Puyuan Yao <williamyao034@gmail.com>
sven1977 authored and anyadontfly committed Feb 13, 2025
1 parent 83ecf23 commit 33608b8
Showing 7 changed files with 73 additions and 34 deletions.
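With this change, the tiny CNN example RLModule implements `TargetNetworkAPI` and can therefore be trained with APPO on the new API stack. A minimal usage sketch (not part of this commit) follows; the environment name is a placeholder since the module expects 42x42x4 stacked, resized frames (an appropriate Atari wrapper setup is assumed), `RLModuleSpec` was called `SingleAgentRLModuleSpec` in older Ray releases, and exact import paths and config methods may differ slightly across Ray versions:

from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN

config = (
    APPOConfig()
    # Placeholder env; it must produce 42x42x4 (resized, frame-stacked) observations.
    .environment("my_wrapped_atari_env")
    .rl_module(rl_module_spec=RLModuleSpec(module_class=TinyAtariCNN))
)
algo = config.build()
print(algo.train())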
1 change: 0 additions & 1 deletion rllib/algorithms/appo/appo.py
@@ -33,7 +33,6 @@
LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
OLD_ACTION_DIST_KEY = "old_action_dist"
OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"


class APPOConfig(IMPALAConfig):
8 changes: 5 additions & 3 deletions rllib/algorithms/appo/default_appo_rl_module.py
@@ -2,11 +2,13 @@
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.apis import TargetNetworkAPI
from ray.rllib.core.rl_module.apis import (
TARGET_NETWORK_ACTION_DIST_INPUTS,
TargetNetworkAPI,
)
from ray.rllib.utils.typing import NetworkType

from ray.rllib.utils.annotations import (
@@ -43,7 +45,7 @@ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}
return {TARGET_NETWORK_ACTION_DIST_INPUTS: old_action_dist_logits}

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(DefaultPPORLModule)
9 changes: 6 additions & 3 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -15,7 +15,6 @@
APPOConfig,
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_learner import APPOLearner
from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
@@ -25,7 +24,11 @@
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.apis import (
TARGET_NETWORK_ACTION_DIST_INPUTS,
TargetNetworkAPI,
ValueFunctionAPI,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
@@ -76,7 +79,7 @@ def compute_loss_for_module(
)

old_target_policy_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[OLD_ACTION_DIST_LOGITS_KEY]
module.forward_target(batch)[TARGET_NETWORK_ACTION_DIST_INPUTS]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[Columns.ACTIONS]
6 changes: 5 additions & 1 deletion rllib/core/rl_module/apis/__init__.py
@@ -1,7 +1,10 @@
from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.target_network_api import (
TargetNetworkAPI,
TARGET_NETWORK_ACTION_DIST_INPUTS,
)
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


@@ -10,5 +13,6 @@
"QNetAPI",
"SelfSupervisedLossAPI",
"TargetNetworkAPI",
"TARGET_NETWORK_ACTION_DIST_INPUTS",
"ValueFunctionAPI",
]
9 changes: 6 additions & 3 deletions rllib/core/rl_module/apis/target_network_api.py
@@ -5,6 +5,9 @@
from ray.util.annotations import PublicAPI


TARGET_NETWORK_ACTION_DIST_INPUTS = "target_network_action_dist_inputs"


@PublicAPI(stability="alpha")
class TargetNetworkAPI(abc.ABC):
"""An API to be implemented by RLModules for handling target networks.
@@ -20,9 +23,9 @@ class TargetNetworkAPI:
def make_target_networks(self) -> None:
"""Creates the required target nets for this RLModule.
You should use the convenience
`ray.rllib.core.learner.utils.make_target_network()` utility and pass in
an already existing, corresponding "main" net (for which you need a target net).
Use the convenience `ray.rllib.core.learner.utils.make_target_network()` utility
when implementing this method. Pass in an already existing, corresponding "main"
net (for which you need a target net).
This function already takes care of initialization (from the "main" net).
"""

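In general (a sketch, not code from this commit), an RLModule becomes usable by APPO by implementing the three `TargetNetworkAPI` methods; the attribute names `self._pi` and `self._target_pi` below are hypothetical, and `setup()` is assumed to have built the main policy head:

from typing import Any, Dict, List, Tuple

from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import (
    TARGET_NETWORK_ACTION_DIST_INPUTS,
    TargetNetworkAPI,
)
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.typing import NetworkType


class MyTargetNetModule(TorchRLModule, TargetNetworkAPI):
    # `setup()` (not shown) is assumed to build `self._pi`, the main policy head.

    def make_target_networks(self) -> None:
        # Copy the main net; make_target_network handles initialization from it.
        self._target_pi = make_target_network(self._pi)

    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        # (main net, target net) pairs that the learner syncs periodically.
        return [(self._pi, self._target_pi)]

    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Action-distribution inputs computed by the target net, under the new key.
        return {TARGET_NETWORK_ACTION_DIST_INPUTS: self._target_pi(batch["obs"])}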
72 changes: 50 additions & 22 deletions rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py
@@ -1,7 +1,12 @@
from typing import Any, Dict, Optional

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import (
TargetNetworkAPI,
ValueFunctionAPI,
TARGET_NETWORK_ACTION_DIST_INPUTS,
)
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.misc import (
normc_initializer,
@@ -15,7 +20,7 @@
torch, nn = try_import_torch()


class TinyAtariCNN(TorchRLModule, ValueFunctionAPI):
class TinyAtariCNN(TorchRLModule, ValueFunctionAPI, TargetNetworkAPI):
"""A tiny CNN stack for fast-learning of Atari envs.
The architecture here is the exact same as the one used by the old API stack as
@@ -26,27 +31,29 @@ class TinyAtariCNN(TorchRLModule, ValueFunctionAPI):
Simple reshaping (no flattening or extra linear layers necessary) leads to the
action logits, which can directly be used inside a distribution or loss.
.. testcode::
import numpy as np
import gymnasium as gym
my_net = TinyAtariCNN(
observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
action_space=gym.spaces.Discrete(4),
)
B = 10
w = 42
h = 42
c = 4
data = torch.from_numpy(
np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
)
print(my_net.forward_inference({"obs": data}))
print(my_net.forward_exploration({"obs": data}))
print(my_net.forward_train({"obs": data}))
num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
print(f"num params = {num_all_params}")
import gymnasium as gym
my_net = TinyAtariCNN(
observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
action_space=gym.spaces.Discrete(4),
)
B = 10
w = 42
h = 42
c = 4
data = torch.from_numpy(
np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
)
print(my_net.forward_inference({"obs": data}))
print(my_net.forward_exploration({"obs": data}))
print(my_net.forward_train({"obs": data}))
num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
print(f"num params = {num_all_params}")
"""

@override(TorchRLModule)
@@ -141,6 +148,27 @@ def _forward_train(self, batch, **kwargs):
Columns.EMBEDDINGS: embeddings,
}

# We implement this RLModule as a TargetNetworkAPI RLModule, so it can be used
# by the APPO algorithm.
@override(TargetNetworkAPI)
def make_target_networks(self) -> None:
self._target_base_cnn_stack = make_target_network(self._base_cnn_stack)
self._target_logits = make_target_network(self._logits)

@override(TargetNetworkAPI)
def get_target_network_pairs(self):
return [
(self._base_cnn_stack, self._target_base_cnn_stack),
(self._logits, self._target_logits),
]

@override(TargetNetworkAPI)
def forward_target(self, batch, **kw):
obs = batch[Columns.OBS].permute(0, 3, 1, 2)
embeddings = self._target_base_cnn_stack(obs)
logits = self._target_logits(embeddings)
return {TARGET_NETWORK_ACTION_DIST_INPUTS: torch.squeeze(logits, dim=[-1, -2])}

# We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used
# by value-based methods like PPO or IMPALA.
@override(ValueFunctionAPI)
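A quick, hedged sanity check (not part of the diff) of the new methods: after `make_target_networks()`, `forward_target()` should return action-distribution inputs under the new constant, matching the shape of the train-time logits (assuming `_forward_train` exposes them under `Columns.ACTION_DIST_INPUTS`, as standard RLModules do):

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import TARGET_NETWORK_ACTION_DIST_INPUTS

my_net = TinyAtariCNN(
    observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
    action_space=gym.spaces.Discrete(4),
)
my_net.make_target_networks()

batch = {Columns.OBS: torch.rand(10, 42, 42, 4)}
target_out = my_net.forward_target(batch)
train_out = my_net.forward_train(batch)
assert (
    target_out[TARGET_NETWORK_ACTION_DIST_INPUTS].shape
    == train_out[Columns.ACTION_DIST_INPUTS].shape
)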
2 changes: 1 addition & 1 deletion rllib/tuned_examples/impala/stateless_cartpole_impala.py
@@ -9,7 +9,7 @@
)
parser.set_defaults(
enable_new_api_stack=True,
num_env_runners=3,
num_env_runners=5,
)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.