diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 664b7eba..ff776684 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -8,7 +8,6 @@ from lightning.fabric.wrappers import _FabricModule from sympy import Union from torch import Tensor, nn -from torch.distributions import Normal from sheeprl.algos.dreamer_v1.utils import compute_stochastic_state from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor @@ -16,7 +15,6 @@ from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.algos.dreamer_v2.agent import MLPDecoder, MLPEncoder from sheeprl.models.models import MLP, MultiDecoder, MultiEncoder -from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.utils import init_weights # In order to use the hydra.utils.get_class method, in this way the user can @@ -226,7 +224,6 @@ class PlayerDV1(nn.Module): representation_model (nn.Module): the representation model. actor (nn.Module): the actor. actions_dim (Sequence[int]): the dimension of each action. - expl_amout (float): the exploration amout to use during training. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. @@ -240,7 +237,6 @@ def __init__( representation_model: nn.Module, actor: nn.Module, actions_dim: Sequence[int], - expl_amount: float, num_envs: int, stochastic_size: int, recurrent_state_size: int, @@ -252,7 +248,6 @@ def __init__( self.representation_model = representation_model self.actor = actor self.device = device - self.expl_amount = expl_amount self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.recurrent_state_size = recurrent_state_size @@ -277,14 +272,11 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action( - self, obs: Tensor, is_continuous: bool, mask: Optional[Dict[str, Tensor]] = None - ) -> Sequence[Tensor]: + def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]: """Return the actions with a certain amount of noise for exploration. Args: obs (Tensor): the current observations. - is_continuous (bool): whether or not the actions are continuous. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. @@ -292,26 +284,11 @@ def get_exploration_action( The actions the agent has to perform (Sequence[Tensor]). 
""" actions = self.get_greedy_action(obs, mask=mask) - if is_continuous: - self.actions = torch.cat(actions, -1) - if self.expl_amount > 0.0: - self.actions = torch.clip( - Normal(self.actions, self.expl_amount, validate_args=self.validate_args).sample(), -1, 1 - ) - expl_actions = [self.actions] - else: - expl_actions = [] - for act in actions: - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=self.validate_args) - .sample() - .to(self.device) - ) - expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act) - ) - self.actions = torch.cat(expl_actions, -1) - return tuple(expl_actions) + expl_actions = None + if self.actor.expl_amount > 0: + expl_actions = self.actor.add_exploration_noise(actions, mask=mask) + self.actions = torch.cat(expl_actions, dim=-1) + return expl_actions or actions def get_greedy_action( self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None @@ -397,7 +374,7 @@ def build_models( keys=cfg.mlp_keys.encoder, input_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.encoder], mlp_layers=world_model_cfg.encoder.mlp_layers, - dense_units=world_model_cfg.encoderdense_units, + dense_units=world_model_cfg.encoder.dense_units, activation=eval(world_model_cfg.encoder.dense_act), layer_norm=False, ) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 2cb9abf1..7c52b4bf 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -491,7 +491,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_model.rssm.representation_model.module, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, @@ -548,12 +547,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) + max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) @@ -620,7 +619,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -712,16 +711,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.player.expl_decay: + if cfg.algo.actor.expl_decay: expl_decay_steps += 1 - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + 
initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) if aggregator: - aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Params/exploration_amout", actor.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 18c75004..6e0829ad 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -62,7 +62,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm.representation_model.module, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 3ffdebb6..d513ce85 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -430,6 +432,8 @@ class Actor(nn.Module): Default to 4. layer_norm (bool): whether or not to use the layer norm. Default to False. + expl_amount (float): the exploration amount to use during training. + Default to 0.0. """ def __init__( @@ -444,6 +448,7 @@ def __init__( activation: nn.Module = nn.ELU, mlp_layers: int = 4, layer_norm: bool = False, + expl_amount: float = 0.0, ) -> None: super().__init__() self.distribution = distribution_cfg.pop("type", "auto").lower() @@ -477,6 +482,15 @@ def __init__( self.init_std = torch.tensor(init_std) self.min_std = min_std self.distribution_cfg = distribution_cfg + self._expl_amount = expl_amount + + @property + def expl_amount(self) -> float: + return self._expl_amount + + @expl_amount.setter + def expl_amount(self, amount: float): + self._expl_amount = amount def forward( self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: @@ -541,6 +555,27 @@ def forward( actions.append(actions_dist[-1].mode) return tuple(actions), tuple(actions_dist) + def add_exploration_noise( + self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + ) -> Sequence[Tensor]: + if self.is_continuous: + actions = torch.cat(actions, -1) + if self._expl_amount > 0.0: + actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) + expl_actions = [actions] + else: + expl_actions = [] + for act in actions: + sample = ( + OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False) + .sample() + .to(act.device) + ) + expl_actions.append( + torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) + ) + return tuple(expl_actions) + class MinedojoActor(Actor): def __init__( @@ -555,18 +590,20 @@ def __init__( activation: nn.Module = nn.ELU, mlp_layers: int = 4, layer_norm: bool = False, + expl_amount: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, actions_dim=actions_dim, is_continuous=is_continuous, + distribution_cfg=distribution_cfg, init_std=init_std, min_std=min_std, dense_units=dense_units, activation=activation, mlp_layers=mlp_layers, layer_norm=layer_norm, - distribution_cfg=distribution_cfg, + expl_amount=expl_amount, ) def forward( @@ -605,12 +642,12 @@ def forward( logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] =
-torch.inf elif i == 2: mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) - mask["mask_equip/place"] = mask["mask_equip/place"].expand_as(logits) + mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) for t in range(functional_action.shape[0]): for b in range(functional_action.shape[1]): sampled_action = functional_action[t, b].item() if sampled_action in (16, 17): # Equip/Place action - logits[t, b][torch.logical_not(mask["mask_equip/place"][t, b])] = -torch.inf + logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf elif sampled_action == 18: # Destroy action logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf actions_dist.append( @@ -626,6 +663,51 @@ def forward( functional_action = actions[0].argmax(dim=-1) # [T, B] return tuple(actions), tuple(actions_dist) + def add_exploration_noise( + self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + ) -> Sequence[Tensor]: + expl_actions = [] + functional_action = actions[0].argmax(dim=-1) + for i, act in enumerate(actions): + logits = torch.zeros_like(act) + # Exploratory action must respect the constraints of the environment + if mask is not None: + if i == 0: + logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf + elif i == 1: + mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action == 15: # Craft action + logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf + elif i == 2: + mask["mask_destroy"] = mask["mask_destroy"].expand_as(logits) + mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action in {16, 17}: # Equip/Place action + logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf + elif sampled_action == 18: # Destroy action + logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf + sample = ( + OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) + ) + expl_amount = self.expl_amount + # If action[0] was changed and is now critical, then we also force the other 2 actions to change + # to satisfy the constraints of the environment + if ( + i in {1, 2} + and actions[0].argmax() != expl_actions[0].argmax() + and expl_actions[0].argmax().item() in {15, 16, 17, 18} + ): + expl_amount = 2 + expl_actions.append(torch.where(torch.rand(act.shape[:1], device=act.device) < expl_amount, sample, act)) + if mask is not None and i == 0: + functional_action = expl_actions[0].argmax(dim=-1) + return tuple(expl_actions) + class WorldModel(nn.Module): """ @@ -665,7 +747,6 @@ class PlayerDV2(nn.Module): representation_model (nn.Module): the representation model. actor (nn.Module): the actor. actions_dim (Sequence[int]): the dimension of the actions. - expl_amount (float): the exploration amout to use during training. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state.
@@ -682,7 +763,6 @@ def __init__( representation_model: nn.Module, actor: nn.Module, actions_dim: Sequence[int], - expl_amount: float, num_envs: int, stochastic_size: int, recurrent_state_size: int, @@ -695,7 +775,6 @@ def __init__( self.representation_model = representation_model self.actor = actor self.device = device - self.expl_amount = expl_amount self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.discrete_size = discrete_size @@ -722,12 +801,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action( - self, - obs: Dict[str, Tensor], - is_continuous: bool, - mask: Optional[Dict[str, Tensor]] = None, - ) -> Tensor: + def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: """ Return the actions with a certain amount of noise for exploration. @@ -741,28 +815,11 @@ def get_exploration_action( The actions the agent has to perform. """ actions = self.get_greedy_action(obs, mask=mask) - if is_continuous: - self.actions = torch.cat(actions, -1) - if self.expl_amount > 0.0: - self.actions = torch.clip( - Normal(self.actions, self.expl_amount, validate_args=self.validate_args).sample(), - -1, - 1, - ) - expl_actions = [self.actions] - else: - expl_actions = [] - for act in actions: - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=self.validate_args) - .sample() - .to(self.device) - ) - expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act) - ) - self.actions = torch.cat(expl_actions, -1) - return tuple(expl_actions) + expl_actions = None + if self.actor.expl_amount > 0: + expl_actions = self.actor.add_exploration_noise(actions, mask=mask) + self.actions = torch.cat(expl_actions, dim=-1) + return expl_actions or actions def get_greedy_action( self, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 583f6d43..20d4d669 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -516,7 +516,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_model.rssm.representation_model.module, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, @@ -586,12 +585,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) + max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) @@ -666,7 +665,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - 
real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -790,16 +789,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): per_rank_gradient_steps += 1 train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.player.expl_decay: + if cfg.algo.actor.expl_decay: expl_decay_steps += 1 - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Params/exploration_amout", actor.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index 3640b1c9..f32da474 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -62,7 +62,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm.representation_model.module, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index cc064a13..702b0b5e 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -465,7 +467,6 @@ class PlayerDV3(nn.Module): representation_model (_FabricModule): the representation model. actor (_FabricModule): the actor. actions_dim (Sequence[int]): the dimension of the actions. - expl_amout (float): the exploration amout to use during training. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. @@ -482,7 +483,6 @@ def __init__( rssm: RSSM, actor: _FabricModule, actions_dim: Sequence[int], - expl_amount: float, num_envs: int, stochastic_size: int, recurrent_state_size: int, @@ -501,7 +501,6 @@ def __init__( ) self.actor = actor self.device = device - self.expl_amount = expl_amount self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.discrete_size = discrete_size @@ -533,47 +532,30 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs], sample_state=False )[1].reshape(1, len(reset_envs), -1) - def get_exploration_action( - self, - obs: Dict[str, Tensor], - is_continuous: bool, - mask: Optional[Dict[str, np.ndarray]] = None, - ) -> Tensor: + def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: """ Return the actions with a certain amount of noise for exploration. Args: obs (Dict[str, Tensor]): the current observations. - is_continuous (bool): whether or not the actions are continuous. + mask (Dict[str, Tensor], optional): the mask of the actions. 
+ Default to None. Returns: The actions the agent has to perform. """ actions = self.get_greedy_action(obs, mask=mask) - if is_continuous: - self.actions = torch.cat(actions, -1) - if self.expl_amount > 0.0: - self.actions = torch.clip(Normal(self.actions, self.expl_amount).sample(), -1, 1) - expl_actions = [self.actions] - else: - expl_actions = [] - for act in actions: - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False) - .sample() - .to(self.device) - ) - expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act) - ) - self.actions = torch.cat(expl_actions, -1) - return tuple(expl_actions) + expl_actions = None + if self.actor.expl_amount > 0: + expl_actions = self.actor.add_exploration_noise(actions, mask=mask) + self.actions = torch.cat(expl_actions, dim=-1) + return expl_actions or actions def get_greedy_action( self, obs: Dict[str, Tensor], is_training: bool = True, - mask: Optional[Dict[str, np.ndarray]] = None, + mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ Return the greedy actions. @@ -626,6 +608,8 @@ class Actor(nn.Module): then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = `1 / self.discrete`. Defaults to 0.01. + expl_amount (float): the exploration amount to use during training. + Default to 0.0. """ def __init__( @@ -641,6 +625,7 @@ def __init__( mlp_layers: int = 5, layer_norm: bool = True, unimix: float = 0.01, + expl_amount: float = 0.0, ) -> None: super().__init__() self.distribution_cfg = distribution_cfg @@ -678,9 +663,18 @@ def __init__( self.init_std = torch.tensor(init_std) self.min_std = min_std self._unimix = unimix + self._expl_amount = expl_amount + + @property + def expl_amount(self) -> float: + return self._expl_amount + + @expl_amount.setter + def expl_amount(self, amount: float): + self._expl_amount = amount def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, np.ndarray]] = None + self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -746,6 +740,27 @@ def _uniform_mix(self, logits: Tensor) -> Tensor: logits = probs_to_logits(probs) return logits + def add_exploration_noise( + self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + ) -> Sequence[Tensor]: + if self.is_continuous: + actions = torch.cat(actions, -1) + if self._expl_amount > 0.0: + actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) + expl_actions = [actions] + else: + expl_actions = [] + for act in actions: + sample = ( + OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False) + .sample() + .to(act.device) + ) + expl_actions.append( + torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) + ) + return tuple(expl_actions) + class MinedojoActor(Actor): def __init__( @@ -757,9 +772,11 @@ def __init__( init_std: float = 0, min_std: float = 0.1, dense_units: int = 1024, - dense_act: nn.Module = nn.SiLU, + activation: nn.Module = nn.SiLU, mlp_layers: int = 5, layer_norm: bool = True, + unimix: float = 0.01, + expl_amount: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, actions_dim=actions_dim, is_continuous=is_continuous, distribution_cfg=distribution_cfg, init_std=init_std, min_std=min_std, dense_units=dense_units, - dense_act=dense_act, +
activation=activation, mlp_layers=mlp_layers, layer_norm=layer_norm, + unimix=unimix, + expl_amount=expl_amount, ) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, np.ndarray]] = None + self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -789,7 +808,7 @@ def forward( The distribution of the actions """ out: Tensor = self.model(state) - actions_logits: List[Tensor] = [head(out) for head in self.mlp_heads] + actions_logits: List[Tensor] = [self._uniform_mix(head(out)) for head in self.mlp_heads] actions_dist: List[Distribution] = [] actions: List[Tensor] = [] functional_action = None @@ -806,12 +825,12 @@ def forward( logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf elif i == 2: mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) - mask["mask_equip/place"] = mask["mask_equip/place"].expand_as(logits) + mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) for t in range(functional_action.shape[0]): for b in range(functional_action.shape[1]): sampled_action = functional_action[t, b].item() if sampled_action in (16, 17): # Equip/Place action - logits[t, b][torch.logical_not(mask["mask_equip/place"][t, b])] = -torch.inf + logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf elif sampled_action == 18: # Destroy action logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf actions_dist.append( @@ -827,6 +846,51 @@ def forward( functional_action = actions[0].argmax(dim=-1) # [T, B] return tuple(actions), tuple(actions_dist) + def add_exploration_noise( + self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + ) -> Sequence[Tensor]: + expl_actions = [] + functional_action = actions[0].argmax(dim=-1) + for i, act in enumerate(actions): + logits = torch.zeros_like(act) + # Exploratory action must respect the constraints of the environment + if mask is not None: + if i == 0: + logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf + elif i == 1: + mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action == 15: # Craft action + logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf + elif i == 2: + mask["mask_destroy"] = mask["mask_destroy"].expand_as(logits) + mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action in {16, 17}: # Equip/Place action + logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf + elif sampled_action == 18: # Destroy action + logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf + sample = ( + OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) + ) + expl_amount = self.expl_amount + # If action[0] was changed and is now critical, then we also force the other 2 actions to change + # to satisfy the constraints of the environment + if ( + i in {1, 2} + and actions[0].argmax() != expl_actions[0].argmax() + and expl_actions[0].argmax().item() in {15, 16, 17, 18} + ): + expl_amount = 2 + expl_actions.append(torch.where(torch.rand(act.shape[:1], device=act.device) < expl_amount, sample, act)) + if mask is not None and i == 0: + functional_action = expl_actions[0].argmax(dim=-1) + return tuple(expl_actions) + def build_models( fabric: Fabric, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index defcfe55..247c4885 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -449,7 +449,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_model.rssm, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, @@ -516,12 +515,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) @@ -589,7 +588,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -707,16 +706,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): per_rank_gradient_steps += 1 train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.player.expl_decay: + if cfg.algo.actor.expl_decay: expl_decay_steps += 1 - player.expl_amount = polynomial_decay( + actor.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Params/exploration_amout", actor.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index eeef7915..e9f298d9 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -61,7 +61,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm, actor.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 0dd7abb2..8321cb6c
100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -63,7 +63,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm.representation_model.module, actor_task.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 99b776d1..6fb1305b 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -529,7 +529,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_model.rssm.representation_model.module, actor_exploration.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, @@ -611,12 +610,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): exploration_updates = min(num_updates, exploration_updates) if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) + max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: - player.expl_amount = polynomial_decay( + actor_task.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, + max_decay_steps=max_step_expl_decay, + ) + actor_exploration.expl_amount = polynomial_decay( + expl_decay_steps, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) @@ -691,7 +696,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -792,16 +797,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.player.expl_decay: + if cfg.algo.actor.expl_decay: expl_decay_steps += 1 - player.expl_amount = polynomial_decay( + actor_task.expl_amount = polynomial_decay( + expl_decay_steps, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, + max_decay_steps=max_step_expl_decay, + ) + actor_exploration.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Params/exploration_amout_task", actor_task.expl_amount) + aggregator.update("Params/exploration_amout_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 8a6b1f99..ae29731a 100644 
--- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -63,7 +63,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm.representation_model.module, actor_task.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index b5e374e3..a395559d 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -668,7 +668,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_model.rssm.representation_model.module, actor_exploration.module, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, @@ -760,12 +759,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) + max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: - player.expl_amount = polynomial_decay( + actor_task.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, + max_decay_steps=max_step_expl_decay, + ) + actor_exploration.expl_amount = polynomial_decay( + expl_decay_steps, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) @@ -852,7 +857,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -988,16 +993,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.player.expl_decay: + if cfg.algo.actor.expl_decay: expl_decay_steps += 1 - player.expl_amount = polynomial_decay( + actor_task.expl_amount = polynomial_decay( + expl_decay_steps, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, + max_decay_steps=max_step_expl_decay, + ) + actor_exploration.expl_amount = polynomial_decay( expl_decay_steps, - initial=cfg.algo.player.expl_amount, - final=cfg.algo.player.expl_min, + initial=cfg.algo.actor.expl_amount, + final=cfg.algo.actor.expl_min, max_decay_steps=max_step_expl_decay, ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Params/exploration_amout_task", actor_task.expl_amount) + aggregator.update("Params/exploration_amout_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): diff --git a/sheeprl/configs/algo/dreamer_v1.yaml 
b/sheeprl/configs/algo/dreamer_v1.yaml index f76fcc36..06cd789f 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -91,6 +91,10 @@ actor: mlp_layers: ${algo.mlp_layers} dense_units: ${algo.dense_units} clip_gradients: 100.0 + expl_amount: 0.3 + expl_min: 0.0 + expl_decay: False + max_step_expl_decay: 200000 # Actor optimizer optimizer: @@ -110,10 +114,3 @@ critic: lr: 8e-5 eps: 1e-5 weight_decay: 0 - -# Player agent (it interacts with the environment) -player: - expl_min: 0.0 - expl_amount: 0.3 - expl_decay: False - max_step_expl_decay: 200000 diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index 86c31012..dd8f2a9c 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -104,6 +104,10 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 + expl_amount: 0.0 + expl_min: 0.0 + expl_decay: False + max_step_expl_decay: 0 # Actor optimizer optimizer: @@ -128,8 +132,4 @@ critic: # Player agent (it interacts with the environment) player: - expl_min: 0.0 - expl_amount: 0.0 - expl_decay: False - max_step_expl_decay: 0 discrete_size: ${algo.world_model.discrete_size} diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index b13876f2..0a37436e 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -106,6 +106,10 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 + expl_amount: 0.0 + expl_min: 0.0 + expl_decay: False + max_step_expl_decay: 0 # Disttributed percentile model (used to scale the values) moments: @@ -140,8 +144,4 @@ critic: # Player agent (it interacts with the environment) player: - expl_min: 0.0 - expl_amount: 0.0 - expl_decay: False - max_step_expl_decay: 0 discrete_size: ${algo.world_model.discrete_size} diff --git a/sheeprl/configs/env/minedojo.yaml b/sheeprl/configs/env/minedojo.yaml index fb79138b..590cc15b 100644 --- a/sheeprl/configs/env/minedojo.yaml +++ b/sheeprl/configs/env/minedojo.yaml @@ -5,6 +5,7 @@ defaults: # Override from `minecraft` config id: open-ended action_repeat: 1 +capture_video: True # Wrapper to be instantiated wrapper: diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index bb30835d..8d6e610f 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -11,7 +11,9 @@ import gymnasium as gym import minedojo import numpy as np +from gymnasium.core import RenderFrame from minedojo.sim import ALL_CRAFT_SMELT_ITEMS, ALL_ITEMS +from minedojo.sim.wrappers.ar_nn import ARNNWrapper N_ALL_ITEMS = len(ALL_ITEMS) ACTION_MAP = { @@ -77,7 +79,7 @@ def __init__( f"The initial position must respect the pitch limits {self._pitch_limits}, given {self._pos['pitch']}" ) - env = minedojo.make( + env: ARNNWrapper = minedojo.make( task_id=id, image_size=(height, width), world_seed=seed, @@ -103,7 +105,7 @@ def __init__( "equipment": gym.spaces.Box(0.0, 1.0, (N_ALL_ITEMS,), np.int32), "life_stats": gym.spaces.Box(0.0, np.array([20.0, 20.0, 300.0]), (3,), np.float32), "mask_action_type": gym.spaces.Box(0, 1, (len(ACTION_MAP),), bool), - "mask_equip/place": gym.spaces.Box(0, 1, (N_ALL_ITEMS,), bool), + "mask_equip_place": gym.spaces.Box(0, 1, (N_ALL_ITEMS,), bool), "mask_destroy": gym.spaces.Box(0, 1, (N_ALL_ITEMS,), bool), "mask_craft_smelt": gym.spaces.Box(0, 1, (len(ALL_CRAFT_SMELT_ITEMS),), bool), } @@ -133,7 +135,10 @@ def _convert_inventory(self, inventory: Dict[str, 
Any]) -> np.ndarray: else: self._inventory[item].append(i) # count the items in the inventory - converted_inventory[ITEM_NAME_TO_ID[item]] += quantity + if item == "air": + converted_inventory[ITEM_NAME_TO_ID[item]] += 1 + else: + converted_inventory[ITEM_NAME_TO_ID[item]] += quantity self._inventory_max = np.maximum(converted_inventory, self._inventory_max) return converted_inventory @@ -170,7 +175,7 @@ def _convert_masks(self, masks: Dict[str, Any]) -> Dict[str, np.ndarray]: masks["action_type"][7] *= np.any(destroy_mask).item() return { "mask_action_type": np.concatenate((np.array([True] * 12), masks["action_type"][1:])), - "mask_equip/place": equip_mask, + "mask_equip_place": equip_mask, "mask_destroy": destroy_mask, "mask_craft_smelt": masks["craft_smelt"], } @@ -284,3 +289,13 @@ def reset( "location_stats": copy.deepcopy(self._pos), "biomeid": float(obs["location_stats"]["biome_id"].item()), } + + def render(self) -> RenderFrame | list[RenderFrame] | None: + if self.render_mode == "human": + super().render() + elif self.render_mode == "rgb_array": + if self.env.unwrapped._prev_obs is None: + return None + else: + return self.env.unwrapped._prev_obs["rgb"] + return None
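
Note on the refactor above: the exploration amount now lives on the Actor (new expl_amount property and add_exploration_noise method), and the players simply delegate to it. For the discrete branch, the added method amounts to epsilon-style resampling: with probability expl_amount, each one-hot action is replaced by a sample from a uniform categorical over the same action space. The snippet below is a self-contained illustration of that logic only, written against plain torch (torch.distributions.OneHotCategorical stands in for sheeprl's OneHotCategoricalValidateArgs wrapper, and the helper name and toy shapes are made up for the example):

import torch
from torch.distributions import OneHotCategorical

def add_exploration_noise(actions, expl_amount):
    # actions: sequence of one-hot tensors shaped [T, B, action_dim_i]
    expl_actions = []
    for act in actions:
        # uniform random action over the same discrete space
        sample = OneHotCategorical(logits=torch.zeros_like(act)).sample()
        # keep the greedy action, or swap in the random one with probability expl_amount
        replace = torch.rand(act.shape[:1], device=act.device) < expl_amount
        expl_actions.append(torch.where(replace, sample, act))
    return tuple(expl_actions)

greedy = (torch.nn.functional.one_hot(torch.randint(0, 4, (1, 2)), num_classes=4).float(),)
noisy = add_exploration_noise(greedy, expl_amount=0.3)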