Fix/player precision plugin (#244)
* Fix PPO player with precision

* Fix PPO agent to run with the correct precision plugin

* Add detach_actor + fix creation of testing agent

* Fix SAC to use the correct precision plugin

* Add get_single_device_fabric method

* Fix DroQ agent to handle fp16

* Fix SACAE agent to handle fp16

* Fix PPO recurrent to handle fp16

* Fix Join context

* Fix PPO evaluate

* Fix Dreamer-V1 player to handle fp16

* Dreamer-V1 inference mode

* Fix Dreamer-V2 to handle fp16

* Fix Dreamer-V3 to handle fp16

* Extract module inside the player

* Fix P2E-DV1 to handle fp16

* Fix P2E-DV2 to handle fp16

* Fix P2E-DV3 to handle fp16

* Fix A2C to handle fp16

* Add `greedy` flag to actor forward method

* Fix PPORecurrent arg position

* Let P2E algos handle fp16

* Wrap target critics with a single-device fabric

* import annotations from future

* Fix P2E actor task during test

* Wrap actor with single-device fabric (see the sketch below)
belerico authored Mar 27, 2024
1 parent 939d30d commit df3734a
Showing 45 changed files with 1,360 additions and 1,183 deletions.
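The common thread in the commits above is that inference-time modules (players, target critics) previously bypassed Fabric's precision plugin because they were called through the raw `.module`. The snippet below is a minimal, hypothetical illustration of the difference, assuming a `bf16-mixed` Fabric on CPU (the model and shapes are made up for the example); it is not code from this repository.

```python
import torch
from lightning.fabric import Fabric

# Only the Fabric-wrapped module runs its forward pass inside the precision
# plugin's autocast context; calling the underlying nn.Module directly skips
# precision handling entirely.
fabric = Fabric(accelerator="cpu", devices=1, precision="bf16-mixed")
fabric.launch()

model = torch.nn.Linear(4, 2)
wrapped = fabric.setup_module(model)

x = torch.randn(1, 4)
print(wrapped(x).dtype)         # torch.bfloat16 -> autocast applied
print(wrapped.module(x).dtype)  # torch.float32  -> precision plugin bypassed
```

Wrapping the player with a single-device Fabric therefore gives rollouts and tests the same precision behaviour as training, without pulling the player into the distributed strategy.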
94 changes: 47 additions & 47 deletions sheeprl/algos/a2c/a2c.py
@@ -163,7 +163,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Given that the environment has been created with the `make_env` method, the agent
# forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}.
# The agent should be able to process both image and vector-like observations.
agent = build_agent(
agent, player = build_agent(
fabric,
actions_dim,
is_continuous,
@@ -224,68 +224,68 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data[k] = next_obs[k][np.newaxis]

for update in range(1, num_updates + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size
with torch.inference_mode():
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
actions, _, values = agent.module(torch_obs)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
for k in obs_keys:
_obs = obs[k]
step_data[k] = _obs[np.newaxis]
next_obs[k] = _obs

if cfg.metric.log_level > 0 and "final_info" in info:
for i, agent_ep_info in enumerate(info["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
if aggregator and "Rewards/rew_avg" in aggregator:
aggregator.update("Rewards/rew_avg", ep_rew)
if aggregator and "Game/ep_len_avg" in aggregator:
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")
# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
for k in obs_keys:
_obs = obs[k]
step_data[k] = _obs[np.newaxis]
next_obs[k] = _obs

if cfg.metric.log_level > 0 and "final_info" in info:
for i, agent_ep_info in enumerate(info["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
if aggregator and "Rewards/rew_avg" in aggregator:
aggregator.update("Rewards/rew_avg", ep_rew)
if aggregator and "Game/ep_len_avg" in aggregator:
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Transform the data into PyTorch Tensors
local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy)

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.no_grad():
with torch.inference_mode():
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
next_values = agent.module.get_value(torch_obs)
_, _, next_values = player(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
local_data["values"],
@@ -351,7 +351,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
if fabric.is_global_zero:
test(agent.module, fabric, cfg, log_dir)
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.ppo.utils import log_models
51 changes: 23 additions & 28 deletions sheeprl/algos/a2c/agent.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
from typing import Any, Dict, List, Optional, Sequence, Tuple

import gymnasium
@@ -10,8 +11,10 @@
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

from sheeprl.algos.ppo.agent import PPOActor
from sheeprl.models.models import MLP
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
from sheeprl.utils.fabric import get_single_device_fabric


class MLPEncoder(nn.Module):
@@ -94,7 +97,7 @@ def __init__(
)

# Actor
self.actor_backbone = MLP(
actor_backbone = MLP(
input_dims=features_dim,
output_dim=None,
hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
@@ -109,19 +112,17 @@
)
if is_continuous:
# Output is a tuple of two elements: mean and log_std, one for every action
self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
else:
# Output is a tuple of one element: logits, one for every action
self.actor_heads = nn.ModuleList(
[nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]
)
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous=is_continuous)

def forward(
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None, greedy: bool = False
) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
feat = self.feature_extractor(obs)
out: Tensor = self.actor_backbone(feat)
pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
pre_dist: List[Tensor] = self.actor(feat)
values = self.critic(feat)
if self.is_continuous:
mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1)
@@ -132,7 +133,7 @@ def forward(
validate_args=self.distribution_cfg.validate_args,
)
if actions is None:
actions = normal.sample()
actions = normal.mode if greedy else normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
@@ -151,30 +152,14 @@
OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args)
)
if should_append:
actions.append(actions_dist[-1].sample())
actions.append(actions_dist[-1].mode if greedy else actions_dist[-1].sample())
actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
return tuple(actions), torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), values

def get_value(self, obs: Dict[str, Tensor]) -> Tensor:
feat = self.feature_extractor(obs)
return self.critic(feat)

def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]:
feat = self.feature_extractor(obs)
out = self.actor_backbone(feat)
pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
if self.is_continuous:
# Just take the mean of the distribution
return [torch.chunk(pre_dist[0], 2, -1)[0]]
else:
# Take the mode of the distribution
return tuple(
[
OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args).mode
for logits in pre_dist
]
)


def build_agent(
fabric: Fabric,
@@ -183,7 +168,7 @@
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
agent_state: Optional[Dict[str, Tensor]] = None,
) -> _FabricModule:
) -> Tuple[_FabricModule, _FabricModule]:
agent = A2CAgent(
actions_dim=actions_dim,
obs_space=obs_space,
@@ -196,6 +181,16 @@
)
if agent_state:
agent.load_state_dict(agent_state)
player = copy.deepcopy(agent)

# Setup training agent
agent = fabric.setup_module(agent)

return agent
# Setup player agent
fabric_player = get_single_device_fabric(fabric)
player = fabric_player.setup_module(player)

# Tie weights between the agent and the player
for agent_p, player_p in zip(agent.parameters(), player.parameters()):
player_p.data = agent_p.data
return agent, player
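For orientation, here is a rough sketch of what a helper like `get_single_device_fabric` could look like and of why the parameter tying in `build_agent` keeps the player in sync with the training agent. It is an assumption reconstructed from the usage in this diff, not sheeprl's actual implementation.

```python
import copy

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import SingleDeviceStrategy


def get_single_device_fabric(fabric: Fabric) -> Fabric:
    # Reuse the parent Fabric's accelerator and precision plugin, but run on a
    # single device so the wrapped module is never sharded or DDP-wrapped.
    strategy = SingleDeviceStrategy(
        device=fabric.device,
        accelerator=fabric.strategy.accelerator,
        precision=fabric.strategy.precision,
    )
    return Fabric(strategy=strategy)


# Tying `player_p.data = agent_p.data` makes the player share the training
# agent's parameter storage: optimizer steps on the agent are immediately
# visible to the player, without registering the player for gradient sync.
agent = torch.nn.Linear(4, 2)
player = copy.deepcopy(agent)
for agent_p, player_p in zip(agent.parameters(), player.parameters()):
    player_p.data = agent_p.data

with torch.no_grad():
    agent.weight.add_(1.0)
assert torch.equal(agent.weight, player.weight)  # the player tracks the agent
```

Because only tensor storage is shared, the player never takes part in gradient synchronization, yet every update to the agent is reflected in the player used for rollouts and testing.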
2 changes: 1 addition & 1 deletion sheeprl/algos/a2c/evaluate.py
@@ -54,5 +54,5 @@ def evaluate_a2c(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
)
# Create the actor and critic models
agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
test(agent, fabric, cfg, log_dir)
10 changes: 7 additions & 3 deletions sheeprl/algos/a2c/utils.py
@@ -1,7 +1,10 @@
from __future__ import annotations

from typing import Any, Dict

import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule

from sheeprl.algos.a2c.agent import A2CAgent
from sheeprl.utils.env import make_env
@@ -10,7 +13,7 @@


@torch.no_grad()
def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
def test(agent: A2CAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
@@ -25,10 +28,11 @@ def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):

while not done:
# Act greedily through the environment
actions, _, _ = agent(obs, greedy=True)
if agent.is_continuous:
actions = torch.cat(agent.get_greedy_actions(obs), dim=-1)
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1)
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
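A small, self-contained illustration of the `greedy` flag used by `test` above (the logits and shapes are made up): with `greedy=True` the actor takes the mode of each action distribution instead of sampling from it.

```python
import torch
from torch.distributions import Normal, OneHotCategorical

greedy = True  # what `agent(obs, greedy=True)` toggles in the test path above

# Discrete head: the mode is the one-hot of the highest-probability action.
discrete = OneHotCategorical(logits=torch.tensor([[0.1, 2.0, -1.0]]))
print(discrete.mode if greedy else discrete.sample())  # tensor([[0., 1., 0.]])

# Continuous head: the mode of a Normal is its mean, i.e. no exploration noise.
normal = Normal(loc=torch.zeros(2), scale=torch.ones(2))
print(normal.mode if greedy else normal.sample())  # tensor([0., 0.])
```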
42 changes: 26 additions & 16 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -17,6 +17,7 @@
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.algos.dreamer_v2.agent import MLPDecoder, MLPEncoder
from sheeprl.models.models import MLP, MultiDecoder, MultiEncoder
from sheeprl.utils.fabric import get_single_device_fabric
from sheeprl.utils.utils import init_weights

# In order to use the hydra.utils.get_class method, in this way the user can
@@ -221,45 +222,54 @@ class PlayerDV1(nn.Module):
"""The model of the DreamerV1 player.
Args:
encoder (nn.Module): the encoder.
recurrent_model (nn.Module): the recurrent model.
representation_model (nn.Module): the representation model.
actor (nn.Module): the actor.
fabric (Fabric): the fabric object.
encoder (nn.Module | _FabricModule): the encoder.
recurrent_model (nn.Module | _FabricModule): the recurrent model.
representation_model (nn.Module | _FabricModule): the representation model.
actor (nn.Module | _FabricModule): the actor.
actions_dim (Sequence[int]): the dimension of each action.
num_envs (int): the number of environments.
stochastic_size (int): the size of the stochastic state.
recurrent_state_size (int): the size of the recurrent state.
device (torch.device): the device to work on.
actor_type (str, optional): which actor the player is using ('task' or 'exploration').
Default to None.
"""

def __init__(
self,
encoder: nn.Module,
recurrent_model: nn.Module,
representation_model: nn.Module,
actor: nn.Module,
fabric: Fabric,
encoder: nn.Module | _FabricModule,
recurrent_model: nn.Module | _FabricModule,
representation_model: nn.Module | _FabricModule,
actor: nn.Module | _FabricModule,
actions_dim: Sequence[int],
num_envs: int,
stochastic_size: int,
recurrent_state_size: int,
device: torch.device,
actor_type: str | None = None,
) -> None:
super().__init__()
self.encoder = encoder
self.recurrent_model = recurrent_model
self.representation_model = representation_model
self.actor = actor
self.device = device
single_device_fabric = get_single_device_fabric(fabric)
self.encoder = single_device_fabric.setup_module(
getattr(encoder, "module", encoder),
)
self.recurrent_model = single_device_fabric.setup_module(
getattr(recurrent_model, "module", recurrent_model),
)
self.representation_model = single_device_fabric.setup_module(
getattr(representation_model, "module", representation_model)
)
self.actor = single_device_fabric.setup_module(
getattr(actor, "module", actor),
)
self.device = single_device_fabric.device
self.actions_dim = actions_dim
self.stochastic_size = stochastic_size
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.init_states()
self.actor_type = actor_type
self.init_states()

def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
"""Initialize the states and the actions for the ended environments.
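The `getattr(encoder, "module", encoder)` calls above unwrap a module that may already be Fabric-wrapped before re-wrapping it with the single-device fabric. A minimal sketch of the pattern, with illustrative names:

```python
from __future__ import annotations

from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import nn


def unwrap(module: nn.Module | _FabricModule) -> nn.Module:
    # A _FabricModule exposes the original network as `.module`; a plain
    # nn.Module has no such attribute and is returned unchanged.
    return getattr(module, "module", module)


fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

encoder = nn.Linear(8, 4)
wrapped = fabric.setup_module(encoder)

assert unwrap(wrapped) is encoder  # already wrapped -> original module
assert unwrap(encoder) is encoder  # plain module -> returned as-is
```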
