Fix/minedojo #148

Merged
12 commits merged on Nov 9, 2023
37 changes: 7 additions & 30 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -8,15 +8,13 @@
from lightning.fabric.wrappers import _FabricModule
from sympy import Union
from torch import Tensor, nn
from torch.distributions import Normal

from sheeprl.algos.dreamer_v1.utils import compute_stochastic_state
from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
from sheeprl.algos.dreamer_v2.agent import CNNDecoder, CNNEncoder
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.algos.dreamer_v2.agent import MLPDecoder, MLPEncoder
from sheeprl.models.models import MLP, MultiDecoder, MultiEncoder
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
from sheeprl.utils.utils import init_weights

# In order to use the hydra.utils.get_class method, in this way the user can
@@ -226,7 +224,6 @@ class PlayerDV1(nn.Module):
representation_model (nn.Module): the representation model.
actor (nn.Module): the actor.
actions_dim (Sequence[int]): the dimension of each action.
expl_amout (float): the exploration amout to use during training.
num_envs (int): the number of environments.
stochastic_size (int): the size of the stochastic state.
recurrent_state_size (int): the size of the recurrent state.
@@ -240,7 +237,6 @@ def __init__(
representation_model: nn.Module,
actor: nn.Module,
actions_dim: Sequence[int],
expl_amount: float,
num_envs: int,
stochastic_size: int,
recurrent_state_size: int,
@@ -252,7 +248,6 @@ def __init__(
self.representation_model = representation_model
self.actor = actor
self.device = device
self.expl_amount = expl_amount
self.actions_dim = actions_dim
self.stochastic_size = stochastic_size
self.recurrent_state_size = recurrent_state_size
@@ -277,41 +272,23 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs])
self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

def get_exploration_action(
self, obs: Tensor, is_continuous: bool, mask: Optional[Dict[str, Tensor]] = None
) -> Sequence[Tensor]:
def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]:
"""Return the actions with a certain amount of noise for exploration.

Args:
obs (Tensor): the current observations.
is_continuous (bool): whether or not the actions are continuous.
mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
Defaults to None.

Returns:
The actions the agent has to perform (Sequence[Tensor]).
"""
actions = self.get_greedy_action(obs, mask=mask)
if is_continuous:
self.actions = torch.cat(actions, -1)
if self.expl_amount > 0.0:
self.actions = torch.clip(
Normal(self.actions, self.expl_amount, validate_args=self.validate_args).sample(), -1, 1
)
expl_actions = [self.actions]
else:
expl_actions = []
for act in actions:
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=self.validate_args)
.sample()
.to(self.device)
)
expl_actions.append(
torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act)
)
self.actions = torch.cat(expl_actions, -1)
return tuple(expl_actions)
expl_actions = None
if self.actor.expl_amount > 0:
expl_actions = self.actor.add_exploration_noise(actions, mask=mask)
self.actions = torch.cat(expl_actions, dim=-1)
return expl_actions or actions

def get_greedy_action(
self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None
@@ -397,7 +374,7 @@ def build_models(
keys=cfg.mlp_keys.encoder,
input_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.encoder],
mlp_layers=world_model_cfg.encoder.mlp_layers,
dense_units=world_model_cfg.encoderdense_units,
dense_units=world_model_cfg.encoder.dense_units,
activation=eval(world_model_cfg.encoder.dense_act),
layer_norm=False,
)
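
Taken together, the hunks above move exploration noise out of `PlayerDV1`: the player no longer stores `expl_amount` or receives `is_continuous`, it simply asks the actor for greedy actions and, when `actor.expl_amount > 0`, lets `actor.add_exploration_noise` perturb them. A minimal, self-contained sketch of that delegation pattern (the `ToyActor` below is hypothetical and only stands in for sheeprl's `Actor`):

```python
import torch
from torch import Tensor, nn
from torch.distributions import Normal


class ToyActor(nn.Module):
    """Hypothetical stand-in for sheeprl's Actor: it now owns the exploration amount."""

    def __init__(self, state_size: int, action_dim: int, expl_amount: float = 0.3) -> None:
        super().__init__()
        self.net = nn.Linear(state_size, action_dim)
        self.expl_amount = expl_amount

    def forward(self, state: Tensor) -> tuple[Tensor, ...]:
        # Greedy (noise-free) continuous action in [-1, 1].
        return (torch.tanh(self.net(state)),)

    def add_exploration_noise(self, actions, mask=None):
        # Continuous case: additive Gaussian noise, clipped back to the action range.
        return tuple(torch.clip(Normal(a, self.expl_amount).sample(), -1, 1) for a in actions)


actor = ToyActor(state_size=8, action_dim=2)
greedy = actor(torch.randn(1, 8))
# Player-side logic after this PR: fall back to the greedy actions
# whenever the actor's exploration amount is zero.
expl = actor.add_exploration_noise(greedy) if actor.expl_amount > 0 else greedy
print(expl[0].shape)  # torch.Size([1, 2])
```
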
21 changes: 10 additions & 11 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -491,7 +491,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
world_model.rssm.representation_model.module,
actor.module,
actions_dim,
cfg.algo.player.expl_amount,
cfg.env.num_envs,
cfg.algo.world_model.stochastic_size,
cfg.algo.world_model.recurrent_model.recurrent_state_size,
@@ -548,12 +547,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0
if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint:
learning_starts += start_step
max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size)
max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size)
if cfg.checkpoint.resume_from:
player.expl_amount = polynomial_decay(
actor.expl_amount = polynomial_decay(
expl_decay_steps,
initial=cfg.algo.player.expl_amount,
final=cfg.algo.player.expl_min,
initial=cfg.algo.actor.expl_amount,
final=cfg.algo.actor.expl_min,
max_decay_steps=max_step_expl_decay,
)

@@ -620,7 +619,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask)
real_actions = actions = player.get_exploration_action(preprocessed_obs, mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
@@ -712,16 +711,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)
train_step += world_size
updates_before_training = cfg.algo.train_every // policy_steps_per_update
if cfg.algo.player.expl_decay:
if cfg.algo.actor.expl_decay:
expl_decay_steps += 1
player.expl_amount = polynomial_decay(
actor.expl_amount = polynomial_decay(
expl_decay_steps,
initial=cfg.algo.player.expl_amount,
final=cfg.algo.player.expl_min,
initial=cfg.algo.actor.expl_amount,
final=cfg.algo.actor.expl_min,
max_decay_steps=max_step_expl_decay,
)
if aggregator:
aggregator.update("Params/exploration_amout", player.expl_amount)
aggregator.update("Params/exploration_amout", actor.expl_amount)

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
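
The training-loop changes above read the exploration schedule from `cfg.algo.actor.*` and write the decayed value to `actor.expl_amount` rather than `player.expl_amount`. As a rough illustration of what such a schedule computes, here is a small re-implementation; it is not sheeprl's `polynomial_decay`, and both the signature and the fixed power of 1 are assumptions inferred from the call sites above:

```python
def polynomial_decay_sketch(
    step: int,
    *,
    initial: float = 1.0,
    final: float = 0.1,
    max_decay_steps: int = 10_000,
    power: float = 1.0,
) -> float:
    # Interpolate from `initial` down to `final` over `max_decay_steps`, then hold `final`.
    if step >= max_decay_steps:
        return final
    remaining = 1.0 - step / max_decay_steps
    return (initial - final) * remaining**power + final


# Mirrors the loop above: one decay step per set of gradient steps,
# with the result assigned to `actor.expl_amount`.
for expl_decay_steps in (0, 5_000, 10_000, 20_000):
    print(expl_decay_steps, polynomial_decay_sketch(expl_decay_steps))
# 0 1.0
# 5000 0.55
# 10000 0.1
# 20000 0.1
```
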
1 change: 0 additions & 1 deletion sheeprl/algos/dreamer_v1/evaluate.py
@@ -62,7 +62,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
world_model.rssm.representation_model.module,
actor.module,
actions_dim,
cfg.algo.player.expl_amount,
cfg.env.num_envs,
cfg.algo.world_model.stochastic_size,
cfg.algo.world_model.recurrent_model.recurrent_state_size,
125 changes: 91 additions & 34 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
from typing import Any, Dict, List, Optional, Sequence, Tuple

@@ -430,6 +432,8 @@ class Actor(nn.Module):
Default to 4.
layer_norm (bool): whether or not to use the layer norm.
Default to False.
expl_amount (float): the amount of exploration noise to use during training.
Default to 0.0.
"""

def __init__(
@@ -444,6 +448,7 @@ def __init__(
activation: nn.Module = nn.ELU,
mlp_layers: int = 4,
layer_norm: bool = False,
expl_amount: float = 0.0,
) -> None:
super().__init__()
self.distribution = distribution_cfg.pop("type", "auto").lower()
@@ -477,6 +482,15 @@ def __init__(
self.init_std = torch.tensor(init_std)
self.min_std = min_std
self.distribution_cfg = distribution_cfg
self._expl_amount = expl_amount

@property
def expl_amount(self) -> float:
return self._expl_amount

@expl_amount.setter
def expl_amount(self, amount: float):
self._expl_amount = amount

def forward(
self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None
@@ -541,6 +555,27 @@ def forward(
actions.append(actions_dist[-1].mode)
return tuple(actions), tuple(actions_dist)

def add_exploration_noise(
self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None
) -> Sequence[Tensor]:
if self.is_continuous:
actions = torch.cat(actions, -1)
if self._expl_amount > 0.0:
actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1)
expl_actions = [actions]
else:
expl_actions = []
for act in actions:
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False)
.sample()
.to(act.device)
)
expl_actions.append(
torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act)
)
return tuple(expl_actions)
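
For discrete action spaces, the method above implements epsilon-style exploration: each one-hot action is replaced, with probability `expl_amount`, by a sample from a uniform categorical over the same action head. A self-contained sketch of that branch, using the stock `torch.distributions.OneHotCategorical` in place of sheeprl's `OneHotCategoricalValidateArgs` wrapper (the helper name and the explicit broadcast of the resample mask are additions of this sketch):

```python
import torch
from torch.distributions import OneHotCategorical


def add_discrete_exploration_noise(actions, expl_amount: float):
    """With probability `expl_amount`, swap each greedy one-hot action for a
    uniformly sampled one; otherwise keep the greedy action unchanged."""
    expl_actions = []
    for act in actions:  # act: [B, action_dim], one-hot rows
        uniform = OneHotCategorical(logits=torch.zeros_like(act)).sample()
        resample = torch.rand(act.shape[0], device=act.device) < expl_amount
        expl_actions.append(torch.where(resample.unsqueeze(-1), uniform, act))
    return tuple(expl_actions)


greedy = (torch.nn.functional.one_hot(torch.tensor([2, 0, 1]), num_classes=4).float(),)
noisy = add_discrete_exploration_noise(greedy, expl_amount=0.3)
print(noisy[0].shape)  # torch.Size([3, 4])
```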


class MinedojoActor(Actor):
def __init__(
@@ -555,18 +590,20 @@ def __init__(
activation: nn.Module = nn.ELU,
mlp_layers: int = 4,
layer_norm: bool = False,
expl_amount: float = 0.0,
) -> None:
super().__init__(
latent_state_size=latent_state_size,
actions_dim=actions_dim,
is_continuous=is_continuous,
distribution_cfg=distribution_cfg,
init_std=init_std,
min_std=min_std,
dense_units=dense_units,
activation=activation,
mlp_layers=mlp_layers,
layer_norm=layer_norm,
distribution_cfg=distribution_cfg,
expl_amount=expl_amount,
)

def forward(
@@ -605,12 +642,12 @@ def forward(
logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf
elif i == 2:
mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits)
mask["mask_equip/place"] = mask["mask_equip/place"].expand_as(logits)
mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits)
for t in range(functional_action.shape[0]):
for b in range(functional_action.shape[1]):
sampled_action = functional_action[t, b].item()
if sampled_action in (16, 17): # Equip/Place action
logits[t, b][torch.logical_not(mask["mask_equip/place"][t, b])] = -torch.inf
logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf
elif sampled_action == 18: # Destroy action
logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
actions_dist.append(
@@ -626,6 +663,51 @@
functional_action = actions[0].argmax(dim=-1) # [T, B]
return tuple(actions), tuple(actions_dist)

def add_exploration_noise(
self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None
) -> Sequence[Tensor]:
expl_actions = []
functional_action = actions[0].argmax(dim=-1)
for i, act in enumerate(actions):
logits = torch.zeros_like(act)
# Exploratory action must respect the constraints of the environment
if mask is not None:
if i == 0:
logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf
elif i == 1:
mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits)
for t in range(functional_action.shape[0]):
for b in range(functional_action.shape[1]):
sampled_action = functional_action[t, b].item()
if sampled_action == 15: # Craft action
logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf
elif i == 2:
mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits)
mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits)
for t in range(functional_action.shape[0]):
for b in range(functional_action.shape[1]):
sampled_action = functional_action[t, b].item()
if sampled_action in {16, 17}: # Equip/Place action
logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf
elif sampled_action == 18: # Destroy action
logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device)
)
expl_amount = self.expl_amount
# If the action[0] was changed, and now it is critical, then we force to change also the other 2 actions
# to satisfy the constraints of the environment
if (
i in {1, 2}
and actions[0].argmax() != expl_actions[0].argmax()
and expl_actions[0].argmax().item() in {15, 16, 17, 18}
):
expl_amount = 2
expl_actions.append(torch.where(torch.rand(act.shape[:1], device=self.device) < expl_amount, sample, act))
if mask is not None and i == 0:
functional_action = expl_actions[0].argmax(dim=-1)
return tuple(expl_actions)
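
Beyond applying the action masks, the Minedojo variant above forces the two argument heads (`i in {1, 2}`) to be resampled whenever exploration changed the functional head to a craft/equip/place/destroy action (ids 15 to 18): it bumps the per-head exploration amount to 2, so the `torch.rand(...) < expl_amount` test is always true. A toy demonstration of that forcing mechanism (shapes and tensor names here are illustrative, not sheeprl code):

```python
import torch

n_actions = 19
greedy = torch.nn.functional.one_hot(torch.tensor([[3]]), num_classes=n_actions).float()
random_pick = torch.nn.functional.one_hot(torch.tensor([[7]]), num_classes=n_actions).float()

# Regular exploration: the greedy action is only replaced some of the time.
expl_amount = 0.1
maybe = torch.where(torch.rand(greedy.shape[:1]) < expl_amount, random_pick, greedy)

# Forced resampling: any value > 1 makes the comparison always true,
# so the dependent head is replaced unconditionally.
expl_amount = 2.0
forced = torch.where(torch.rand(greedy.shape[:1]) < expl_amount, random_pick, greedy)
assert torch.equal(forced, random_pick)

print(maybe.argmax(dim=-1))   # usually tensor([[3]]); tensor([[7]]) with prob. 0.1
print(forced.argmax(dim=-1))  # always tensor([[7]])
```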


class WorldModel(nn.Module):
"""
@@ -665,7 +747,6 @@ class PlayerDV2(nn.Module):
representation_model (nn.Module): the representation model.
actor (nn.Module): the actor.
actions_dim (Sequence[int]): the dimension of the actions.
expl_amount (float): the exploration amout to use during training.
num_envs (int): the number of environments.
stochastic_size (int): the size of the stochastic state.
recurrent_state_size (int): the size of the recurrent state.
@@ -682,7 +763,6 @@ def __init__(
representation_model: nn.Module,
actor: nn.Module,
actions_dim: Sequence[int],
expl_amount: float,
num_envs: int,
stochastic_size: int,
recurrent_state_size: int,
@@ -695,7 +775,6 @@ def __init__(
self.representation_model = representation_model
self.actor = actor
self.device = device
self.expl_amount = expl_amount
self.actions_dim = actions_dim
self.stochastic_size = stochastic_size
self.discrete_size = discrete_size
@@ -722,12 +801,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs])
self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

def get_exploration_action(
self,
obs: Dict[str, Tensor],
is_continuous: bool,
mask: Optional[Dict[str, Tensor]] = None,
) -> Tensor:
def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor:
"""
Return the actions with a certain amount of noise for exploration.

@@ -741,28 +815,11 @@ def get_exploration_action(
The actions the agent has to perform.
"""
actions = self.get_greedy_action(obs, mask=mask)
if is_continuous:
self.actions = torch.cat(actions, -1)
if self.expl_amount > 0.0:
self.actions = torch.clip(
Normal(self.actions, self.expl_amount, validate_args=self.validate_args).sample(),
-1,
1,
)
expl_actions = [self.actions]
else:
expl_actions = []
for act in actions:
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=self.validate_args)
.sample()
.to(self.device)
)
expl_actions.append(
torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act)
)
self.actions = torch.cat(expl_actions, -1)
return tuple(expl_actions)
expl_actions = None
if self.actor.expl_amount > 0:
expl_actions = self.actor.add_exploration_noise(actions, mask=mask)
self.actions = torch.cat(expl_actions, dim=-1)
return expl_actions or actions

def get_greedy_action(
self,