From 3711d056d1bb026d1291af1444049809fc5ec653 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 8 Feb 2024 09:07:13 +0100 Subject: [PATCH 01/51] Decoupled RSSM for DV3 agent --- sheeprl/algos/dreamer_v3/agent.py | 250 +++++++++++++++++++------ sheeprl/algos/dreamer_v3/dreamer_v3.py | 40 ++-- sheeprl/configs/algo/dreamer_v3.yaml | 1 + 3 files changed, 220 insertions(+), 71 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 2c3e1d8a..a734e1e8 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -13,6 +13,7 @@ from torch import Tensor, device, nn from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution from torch.distributions.utils import probs_to_logits +from torch.nn.modules import Module from sheeprl.algos.dreamer_v2.agent import WorldModel from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state @@ -68,9 +69,11 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, activation=activation, norm_layer=[LayerNormChannelLast for _ in range(stages)] if layer_norm else None, - norm_args=[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)] + if layer_norm + else None + ), ), nn.Flatten(-3, -1), ) @@ -123,9 +126,9 @@ def __init__( activation=activation, layer_args={"bias": not layer_norm}, norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None + ), ) self.output_dim = dense_units self.symlog_inputs = symlog_inputs @@ -193,13 +196,15 @@ def __init__( + [{"kernel_size": 4, "stride": 2, "padding": 1}], activation=[activation for _ in range(stages - 1)] + [None], norm_layer=[LayerNormChannelLast for _ in range(stages - 1)] + [None] if layer_norm else None, - norm_args=[ - {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3} - for i in range(stages - 1) - ] - + [None] - if layer_norm - else None, + norm_args=( + [ + {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3} + for i in range(stages - 1) + ] + + [None] + if layer_norm + else None + ), ), ) @@ -248,9 +253,9 @@ def __init__( activation=activation, layer_args={"bias": not layer_norm}, norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None + ), ) self.heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.output_dims]) @@ -457,6 +462,93 @@ def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) - return imagined_prior, recurrent_state +class DecoupledRSSM(RSSM): + """RSSM model for the model-base Dreamer agent. + + Args: + recurrent_model (nn.Module): the recurrent model of the RSSM model described in + [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551). 
+        representation_model (nn.Module): the representation model composed by a + multi-layer perceptron to compute the stochastic part of the latent state. + For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + transition_model (nn.Module): the transition model described in + [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + The model is composed by a multi-layer perceptron to predict the stochastic part of the latent state. + distribution_cfg (Dict[str, Any]): the configs of the distributions. + discrete (int, optional): the size of the Categorical variables. + Defaults to 32. + unimix (float, optional): the percentage of uniform distribution to inject into the categorical + distribution over states, i.e. given some logits `l` and probabilities `p = softmax(l)`, + then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = 1 / self.discrete`. + Defaults to 0.01. + """ +
+    def __init__( + self, + recurrent_model: Module, + representation_model: Module, + transition_model: Module, + distribution_cfg: Dict[str, Any], + discrete: int = 32, + unimix: float = 0.01, + ) -> None: + super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix) +
+    def dynamic( + self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Perform one step of the dynamic learning: + Recurrent model: compute the recurrent state from the previous stochastic state and the action taken by the agent, + i.e., it computes the deterministic state (or ht). + Transition model: predict the prior from the recurrent output. + In the decoupled RSSM the posterior is not computed here: it is produced independently by the + representation model from the embedded observations (see `_representation`). + For more information see [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551) + and [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). +
+        Args: + posterior (Tensor): the stochastic state computed by the representation model (posterior). It is expected + to be of dimension `[stoch_size, self.discrete]`, which by default is `[32, 32]`. + recurrent_state (Tensor): a tuple representing the recurrent state of the recurrent model. + action (Tensor): the action taken by the agent. + is_first (Tensor): if this is the first step in the episode. +
+        Returns: + The recurrent state (Tensor): the recurrent state of the recurrent model. + The prior stochastic state (Tensor): computed by the transition model from the recurrent state. + The logits of the prior state (Tensor): computed by the transition model from the recurrent state.
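+
+        Example (an illustrative sketch added for clarity, not part of the training code; `rssm` denotes a
+        `DecoupledRSSM` instance and the shapes are assumed to follow the usual DV3 conventions, e.g.
+        `posterior: [1, B, stoch_size, discrete]`, `recurrent_state: [1, B, recurrent_state_size]`,
+        `action: [1, B, sum(actions_dim)]`, `is_first: [1, B, 1]`):
+
+            >>> # one step of latent dynamics, with no observation involved
+            >>> recurrent_state, prior, prior_logits = rssm.dynamic(posterior, recurrent_state, action, is_first)
+            >>> # the posterior, instead, comes from the representation model alone:
+            >>> # posterior_logits, posterior = rssm._representation(obs)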
+ """ + action = (1 - is_first) * action + recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state)) + posterior = posterior.view(*posterior.shape[:-2], -1) + # posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ + # 1 + # ].view_as(posterior) + recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state) + prior_logits, prior = self._transition(recurrent_state) + return recurrent_state, prior, prior_logits + + def _representation(self, obs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + """ + Args: + obs (Tensor): the real observations provided by the environment. + + Returns: + logits (Tensor): the logits of the distribution of the posterior state. + posterior (Tensor): the sampled posterior stochastic state. + """ + logits: Tensor = self.representation_model(obs) + logits = self._uniform_mix(logits) + return logits, compute_stochastic_state( + logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args + ) + + class PlayerDV3(nn.Module): """ The model of the Dreamer_v3 player. @@ -482,7 +574,7 @@ class PlayerDV3(nn.Module): def __init__( self, encoder: _FabricModule, - rssm: RSSM, + rssm: RSSM | DecoupledRSSM, actor: _FabricModule, actions_dim: Sequence[int], num_envs: int, @@ -491,10 +583,15 @@ def __init__( device: device = "cpu", discrete_size: int = 32, actor_type: str | None = None, + decoupled_rssm: bool = False, ) -> None: super().__init__() self.encoder = encoder - self.rssm = RSSM( + if decoupled_rssm: + rssm_cls = DecoupledRSSM + else: + rssm_cls = RSSM + self.rssm = rssm_cls( recurrent_model=rssm.recurrent_model.module, representation_model=rssm.representation_model.module, transition_model=rssm.transition_model.module, @@ -511,6 +608,7 @@ def __init__( self.num_envs = num_envs self.validate_args = self.actor.distribution_cfg.validate_args self.actor_type = actor_type + self.decoupled_rssm = decoupled_rssm @torch.no_grad() def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: @@ -576,7 +674,10 @@ def get_greedy_action( self.recurrent_state = self.rssm.recurrent_model( torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state ) - _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs) + if self.decoupled_rssm: + _, self.stochastic_state = self.rssm._representation(obs) + else: + _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs) self.stochastic_state = self.stochastic_state.view( *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) @@ -654,9 +755,9 @@ def __init__( flatten_dim=None, layer_args={"bias": not layer_norm}, norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None + ), ) if is_continuous: self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)]) @@ -972,18 +1073,28 @@ def build_agent( **world_model_cfg.recurrent_model, input_size=int(sum(actions_dim) + stochastic_size), ) - representation_model = MLP( - input_dims=recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim, - output_dim=stochastic_size, - hidden_sizes=[world_model_cfg.representation_model.hidden_size], - 
activation=eval(world_model_cfg.representation_model.dense_act), - layer_args={"bias": not world_model_cfg.representation_model.layer_norm}, - flatten_dim=None, - norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}] - if world_model_cfg.representation_model.layer_norm - else None, - ) + if cfg.algo.decoupled_rssm: + representation_model = nn.Sequential( + copy.deepcopy(encoder), + nn.LayerNorm(encoder.output_dim, eps=1e-3), + eval(world_model_cfg.encoder.cnn_act)(), + nn.Linear(encoder.output_dim, stochastic_size, bias=False), + ) + else: + representation_model = MLP( + input_dims=recurrent_state_size + encoder.output_dim, + output_dim=stochastic_size, + hidden_sizes=[world_model_cfg.representation_model.hidden_size], + activation=eval(world_model_cfg.representation_model.dense_act), + layer_args={"bias": not world_model_cfg.representation_model.layer_norm}, + flatten_dim=None, + norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, + norm_args=( + [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] + if world_model_cfg.representation_model.layer_norm + else None + ), + ) transition_model = MLP( input_dims=recurrent_state_size, output_dim=stochastic_size, @@ -992,11 +1103,17 @@ def build_agent( layer_args={"bias": not world_model_cfg.transition_model.layer_norm}, flatten_dim=None, norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}] - if world_model_cfg.transition_model.layer_norm - else None, + norm_args=( + [{"normalized_shape": world_model_cfg.transition_model.hidden_size}] + if world_model_cfg.transition_model.layer_norm + else None + ), ) - rssm = RSSM( + if cfg.algo.decoupled_rssm: + rssm_cls = DecoupledRSSM + else: + rssm_cls = RSSM + rssm = rssm_cls( recurrent_model=recurrent_model.apply(init_weights), representation_model=representation_model.apply(init_weights), transition_model=transition_model.apply(init_weights), @@ -1040,15 +1157,19 @@ def build_agent( activation=eval(world_model_cfg.reward_model.dense_act), layer_args={"bias": not world_model_cfg.reward_model.layer_norm}, flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] - if world_model_cfg.reward_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.reward_model.dense_units} - for _ in range(world_model_cfg.reward_model.mlp_layers) - ] - if world_model_cfg.reward_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] + if world_model_cfg.reward_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": world_model_cfg.reward_model.dense_units} + for _ in range(world_model_cfg.reward_model.mlp_layers) + ] + if world_model_cfg.reward_model.layer_norm + else None + ), ) continue_model = MLP( input_dims=latent_state_size, @@ -1057,15 +1178,19 @@ def build_agent( activation=eval(world_model_cfg.discount_model.dense_act), layer_args={"bias": not world_model_cfg.discount_model.layer_norm}, flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] - if world_model_cfg.discount_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.discount_model.dense_units} - for _ in range(world_model_cfg.discount_model.mlp_layers) - 
] - if world_model_cfg.discount_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] + if world_model_cfg.discount_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": world_model_cfg.discount_model.dense_units} + for _ in range(world_model_cfg.discount_model.mlp_layers) + ] + if world_model_cfg.discount_model.layer_norm + else None + ), ) world_model = WorldModel( encoder.apply(init_weights), @@ -1096,9 +1221,11 @@ def build_agent( layer_args={"bias": not critic_cfg.layer_norm}, flatten_dim=None, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ) actor.apply(init_weights) critic.apply(init_weights) @@ -1107,7 +1234,10 @@ def build_agent( actor.mlp_heads.apply(uniform_init_weights(1.0)) critic.model[-1].apply(uniform_init_weights(0.0)) rssm.transition_model.model[-1].apply(uniform_init_weights(1.0)) - rssm.representation_model.model[-1].apply(uniform_init_weights(1.0)) + if cfg.algo.decoupled_rssm: + rssm.representation_model[-1].apply(uniform_init_weights(1.0)) + else: + rssm.representation_model.model[-1].apply(uniform_init_weights(1.0)) world_model.reward_model.model[-1].apply(uniform_init_weights(0.0)) world_model.continue_model.model[-1].apply(uniform_init_weights(1.0)) if mlp_decoder is not None: diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 964bb98d..690d76b9 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -1,6 +1,7 @@ """Dreamer-V3 implementation from [https://arxiv.org/abs/2301.04104](https://arxiv.org/abs/2301.04104) Adapted from the original implementation from https://github.com/danijar/dreamerv3 """ + from __future__ import annotations import copy @@ -107,23 +108,39 @@ def train( # Dynamic Learning stoch_state_size = stochastic_size * discrete_size recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device) - posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device) recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device) priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) - posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device) - posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) # Embed observations from the environment embedded_obs = world_model.encoder(batch_obs) - for i in range(0, sequence_length): - recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic( - posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1] - ) - recurrent_states[i] = recurrent_state - priors_logits[i] = prior_logits - posteriors[i] = posterior - posteriors_logits[i] = posterior_logits + if cfg.algo.decoupled_rssm: + posteriors_logits, posteriors = world_model.rssm._representation(batch_obs) + for i in range(0, sequence_length): + recurrent_state, posterior_logits, prior_logits = world_model.rssm.dynamic( + posteriors[i : i + 1], + recurrent_state, + batch_actions[i : i + 1], + 
data["is_first"][i : i + 1], + ) + recurrent_states[i] = recurrent_state + priors_logits[i] = prior_logits + else: + posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device) + posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device) + posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) + for i in range(0, sequence_length): + recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic( + posterior, + recurrent_state, + batch_actions[i : i + 1], + embedded_obs[i : i + 1], + data["is_first"][i : i + 1], + ) + recurrent_states[i] = recurrent_state + priors_logits[i] = prior_logits + posteriors[i] = posterior + posteriors_logits[i] = posterior_logits latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1) # Compute predictions for the observations @@ -456,6 +473,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.algo.world_model.recurrent_model.recurrent_state_size, fabric.device, discrete_size=cfg.algo.world_model.discrete_size, + decoupled_rssm=cfg.algo.decoupled_rssm, ) # Optimizers diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 2fc0bace..941160c5 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -31,6 +31,7 @@ dense_act: torch.nn.SiLU cnn_act: torch.nn.SiLU unimix: 0.01 hafner_initialization: True +decoupled_rssm: False # World model world_model: From e80e9d50d1c80a50793b5ec5196f639c9c832f10 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 8 Feb 2024 09:07:57 +0100 Subject: [PATCH 02/51] Initialize posterior with prior if is_first is True --- sheeprl/algos/dreamer_v3/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index a734e1e8..3779fd86 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -526,9 +526,9 @@ def dynamic( action = (1 - is_first) * action recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state)) posterior = posterior.view(*posterior.shape[:-2], -1) - # posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ - # 1 - # ].view_as(posterior) + posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ + 1 + ].view_as(posterior) recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state) prior_logits, prior = self._transition(recurrent_state) return recurrent_state, prior, prior_logits From f47b8f97f6b886498e92ad50ff8197d7fb11dd3f Mon Sep 17 00:00:00 2001 From: belerico Date: Mon, 12 Feb 2024 09:19:48 +0100 Subject: [PATCH 03/51] Fix PlayerDV3 creation in evaluation --- sheeprl/algos/dreamer_v3/evaluate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index 17b32ddf..7fa239fc 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -63,6 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): cfg.algo.world_model.recurrent_model.recurrent_state_size, fabric.device, discrete_size=cfg.algo.world_model.discrete_size, + decoupled_rssm=cfg.algo.decoupled_rssm, ) test(player, fabric, cfg, log_dir, sample_actions=True) From 
2ec4fbb5760d87cf3a7699b7a0c7280ccbf39bd0 Mon Sep 17 00:00:00 2001 From: belerico Date: Mon, 26 Feb 2024 21:05:36 +0100 Subject: [PATCH 04/51] Fix representation_model --- sheeprl/algos/dreamer_v3/agent.py | 52 +++++++++++--------------- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 2c3c1f2e..fe92cd9a 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -533,16 +533,16 @@ def dynamic( prior_logits, prior = self._transition(recurrent_state) return recurrent_state, prior, prior_logits - def _representation(self, obs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + def _representation(self, embedded_obs: Tensor) -> Tuple[Tensor, Tensor]: """ Args: - obs (Tensor): the real observations provided by the environment. + embedded_obs (Tensor): the embedded real observations provided by the environment. Returns: logits (Tensor): the logits of the distribution of the posterior state. posterior (Tensor): the sampled posterior stochastic state. """ - logits: Tensor = self.representation_model(obs) + logits: Tensor = self.representation_model(embedded_obs) logits = self._uniform_mix(logits) return logits, compute_stochastic_state( logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args @@ -675,7 +675,7 @@ def get_greedy_action( torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state ) if self.decoupled_rssm: - _, self.stochastic_state = self.rssm._representation(obs) + _, self.stochastic_state = self.rssm._representation(embedded_obs) else: _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs) self.stochastic_state = self.stochastic_state.view( @@ -1073,28 +1073,23 @@ def build_agent( **world_model_cfg.recurrent_model, input_size=int(sum(actions_dim) + stochastic_size), ) - if cfg.algo.decoupled_rssm: - representation_model = nn.Sequential( - copy.deepcopy(encoder), - nn.LayerNorm(encoder.output_dim, eps=1e-3), - eval(world_model_cfg.encoder.cnn_act)(), - nn.Linear(encoder.output_dim, stochastic_size, bias=False), - ) - else: - representation_model = MLP( - input_dims=recurrent_state_size + encoder.output_dim, - output_dim=stochastic_size, - hidden_sizes=[world_model_cfg.representation_model.hidden_size], - activation=eval(world_model_cfg.representation_model.dense_act), - layer_args={"bias": not world_model_cfg.representation_model.layer_norm}, - flatten_dim=None, - norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, - norm_args=( - [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] - if world_model_cfg.representation_model.layer_norm - else None - ), - ) + represention_model_input_size = encoder.output_dim + if not cfg.algo.decoupled_rssm: + represention_model_input_size += recurrent_state_size + representation_model = MLP( + input_dims=represention_model_input_size, + output_dim=stochastic_size, + hidden_sizes=[world_model_cfg.representation_model.hidden_size], + activation=eval(world_model_cfg.representation_model.dense_act), + layer_args={"bias": not world_model_cfg.representation_model.layer_norm}, + flatten_dim=None, + norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, + norm_args=( + [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] + if world_model_cfg.representation_model.layer_norm + else None + ), + ) transition_model = MLP( 
input_dims=recurrent_state_size, output_dim=stochastic_size, @@ -1234,10 +1229,7 @@ def build_agent( actor.mlp_heads.apply(uniform_init_weights(1.0)) critic.model[-1].apply(uniform_init_weights(0.0)) rssm.transition_model.model[-1].apply(uniform_init_weights(1.0)) - if cfg.algo.decoupled_rssm: - rssm.representation_model[-1].apply(uniform_init_weights(1.0)) - else: - rssm.representation_model.model[-1].apply(uniform_init_weights(1.0)) + rssm.representation_model.model[-1].apply(uniform_init_weights(1.0)) world_model.reward_model.model[-1].apply(uniform_init_weights(0.0)) world_model.continue_model.model[-1].apply(uniform_init_weights(1.0)) if mlp_decoder is not None: diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 4d646e96..170a0b9f 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -115,7 +115,7 @@ def train( embedded_obs = world_model.encoder(batch_obs) if cfg.algo.decoupled_rssm: - posteriors_logits, posteriors = world_model.rssm._representation(batch_obs) + posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs) for i in range(0, sequence_length): recurrent_state, posterior_logits, prior_logits = world_model.rssm.dynamic( posteriors[i : i + 1], From 3a5380b470f12ce6947a7b1bc89a3e83c93d7b2a Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 27 Feb 2024 09:03:51 +0000 Subject: [PATCH 05/51] Fix compute first prior state with a zero posterior --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 170a0b9f..94e56f75 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -117,8 +117,12 @@ def train( if cfg.algo.decoupled_rssm: posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs) for i in range(0, sequence_length): + if i == 0: + posterior = torch.zeros_like(posteriors[:1]) + else: + posterior = posteriors[i - 1 : i] recurrent_state, posterior_logits, prior_logits = world_model.rssm.dynamic( - posteriors[i : i + 1], + posterior, recurrent_state, batch_actions[i : i + 1], data["is_first"][i : i + 1], From 42d9433dd150d3467cd45bf136e9b8407ab78e6b Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 29 Feb 2024 10:12:55 +0100 Subject: [PATCH 06/51] DV3 replay ratio conversion --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 22 ++++++++----- sheeprl/configs/algo/dreamer_v3.yaml | 4 +-- sheeprl/configs/exp/dreamer_v3.yaml | 1 + .../configs/exp/dreamer_v3_100k_boxing.yaml | 1 - .../exp/dreamer_v3_100k_ms_pacman.yaml | 3 +- sheeprl/configs/exp/dreamer_v3_L_doapp.yaml | 2 +- ..._v3_L_doapp_128px_gray_combo_discrete.yaml | 2 +- .../configs/exp/dreamer_v3_L_navigate.yaml | 2 +- .../configs/exp/dreamer_v3_XL_crafter.yaml | 2 +- .../configs/exp/dreamer_v3_benchmarks.yaml | 2 +- .../exp/dreamer_v3_dmc_walker_walk.yaml | 2 +- .../exp/dreamer_v3_super_mario_bros.yaml | 2 +- sheeprl/utils/utils.py | 33 ++++++++++++++++++- 13 files changed, 55 insertions(+), 23 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 94e56f75..2664662d 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -39,7 +39,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import 
polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, polynomial_decay, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -543,7 +543,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) @@ -558,6 +557,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -688,16 +692,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: + repeats = ratio(policy_step / world_size) + if update >= learning_starts and repeats > 0: local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), + n_samples=repeats, dtype=None, device=fabric.device, from_numpy=cfg.buffer.from_numpy, @@ -727,7 +728,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.actor.expl_decay: expl_decay_steps += 1 actor.expl_amount = polynomial_decay( @@ -747,6 +747,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log("Params/replay_ratio", per_rank_gradient_steps * world_size / policy_step, policy_step) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -784,6 +787,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "moments": moments.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index c033f1d7..37e038b3 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -13,10 +13,8 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 16 +replay_ratio: 1 learning_starts: 65536 -per_rank_pretrain_steps: 1 -per_rank_gradient_steps: 1 per_rank_sequence_length: ??? 
# Encoder and decoder keys diff --git a/sheeprl/configs/exp/dreamer_v3.yaml b/sheeprl/configs/exp/dreamer_v3.yaml index 5907104d..1c0f1419 100644 --- a/sheeprl/configs/exp/dreamer_v3.yaml +++ b/sheeprl/configs/exp/dreamer_v3.yaml @@ -8,6 +8,7 @@ defaults: # Algorithm algo: + replay_ratio: 1 total_steps: 5000000 per_rank_batch_size: 16 per_rank_sequence_length: 64 diff --git a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml index 7479c7e3..0a8f9eda 100644 --- a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml @@ -30,6 +30,5 @@ buffer: # Algorithm algo: - train_every: 1 total_steps: 100000 learning_starts: 1024 \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml index ca440728..8c85d19e 100644 --- a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml @@ -26,6 +26,5 @@ buffer: # Algorithm algo: - learning_starts: 1024 total_steps: 100000 - train_every: 1 + learning_starts: 1024 \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml index 0539a325..cb1ac4c1 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml @@ -30,7 +30,7 @@ buffer: # Algorithm algo: learning_starts: 65536 - train_every: 8 + replay_ratio: 0.125 cnn_keys: encoder: - frame diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml index 5c3636b4..a2c6b78f 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -38,7 +38,7 @@ algo: total_steps: 10000000 per_rank_batch_size: 8 learning_starts: 65536 - train_every: 8 + replay_ratio: 0.125 cnn_keys: encoder: - frame diff --git a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml index 5f4e2f8c..23a762c3 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml @@ -28,7 +28,7 @@ buffer: # Algorithm algo: - train_every: 16 + replay_ratio: 0.015625 learning_starts: 65536 cnn_keys: encoder: diff --git a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml index 8f0a136f..35f8fb2f 100644 --- a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml @@ -26,7 +26,7 @@ buffer: # Algorithm algo: train_every: 2 - learning_starts: 1024 + replay_ratio: 0.5 cnn_keys: encoder: - rgb diff --git a/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml index b0b83b17..e10dfd96 100644 --- a/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml @@ -26,7 +26,7 @@ buffer: # Algorithm algo: learning_starts: 1024 - train_every: 16 + replay_ratio: 1 dense_units: 8 mlp_layers: 1 world_model: diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index f4c0db04..3e90dfe9 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -37,7 +37,7 @@ algo: mlp_keys: encoder: [] learning_starts: 1024 - train_every: 2 + replay_ratio: 0.5 # 
Metric metric: diff --git a/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml b/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml index 1c5d8546..2c219281 100644 --- a/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml +++ b/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml @@ -28,7 +28,7 @@ algo: mlp_keys: encoder: [] learning_starts: 16384 - train_every: 4 + replay_ratio: 0.25 # Metric metric: diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index d85eb16a..fcdee5dd 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -2,7 +2,7 @@ import copy import os -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union import numpy as np import rich.syntax @@ -195,3 +195,34 @@ def unwrap_fabric(model: _FabricModule | nn.Module) -> nn.Module: def save_configs(cfg: dotdict, log_dir: str): OmegaConf.save(cfg.as_dict(), os.path.join(log_dir, "config.yaml"), resolve=True) + + +class Ratio: + """Directly taken from Hafner et al. (2023) implementation: + https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26 + """ + + def __init__(self, ratio: float): + if ratio < 0: + raise ValueError(f"Ratio must be non-negative, got {ratio}") + self._ratio: float = ratio + self._prev: float | None = None + + def __call__(self, step: float): + step = int(step) + if self._ratio == 0: + return 0 + if self._prev is None: + self._prev = step + return 1 + repeats = round((step - self._prev) * self._ratio) + self._prev += repeats / self._ratio + return repeats + + def state_dict(self) -> Dict[str, Any]: + return {"_ratio": self._ratio, "_prev": self._prev} + + def load_state_dict(self, state_dict: Mapping[str, Any]): + self._ratio = state_dict["_ratio"] + self._prev = state_dict["_prev"] + return self From b06433b4e40456d8c56b30b2aef70f67b457ee00 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 29 Feb 2024 10:59:57 +0100 Subject: [PATCH 07/51] Removed expl parameters dependent on old per_Rank_gradient_steps --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 21 +-------------------- sheeprl/configs/algo/dreamer_v3.yaml | 4 ---- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 2664662d..6296499b 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -39,7 +39,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -526,7 +526,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -545,15 +544,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = 
cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step @@ -728,16 +720,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) per_rank_gradient_steps += 1 train_step += world_size - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount", actor.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -785,7 +767,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, "moments": moments.state_dict(), "ratio": ratio.state_dict(), "update": update * fabric.world_size, diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 37e038b3..de77c3c3 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -114,10 +114,6 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - expl_amount: 0.0 - expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 0 # Disttributed percentile model (used to scale the values) moments: From 704b0ceeb21137e1a0173b12e3b698b642a2d961 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Mon, 18 Mar 2024 09:36:03 +0100 Subject: [PATCH 08/51] feat: update repeats computation --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 228b1001..d4cfb21f 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -683,7 +683,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - repeats = ratio(policy_step / world_size) + repeats = ratio(policy_step) if update >= learning_starts and repeats > 0: local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, From e1290ee65db01c73bcd140f306db70d5c56138c0 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 28 Mar 2024 11:17:57 +0100 Subject: [PATCH 09/51] feat: update learning starts in config --- hf.txt | 444 ++++++++ our.txt | 1004 +++++++++++++++++++ sheeprl/configs/algo/dreamer_v3.yaml | 2 +- sheeprl/configs/exp/p2e_dv3_finetuning.yaml | 2 +- 4 files changed, 1450 insertions(+), 2 deletions(-) create mode 100644 hf.txt create mode 100644 our.txt diff --git a/hf.txt b/hf.txt new file mode 100644 index 00000000..24b01d70 --- /dev/null +++ b/hf.txt @@ -0,0 +1,444 @@ +Episode has 500 steps and return 49.6. +Episode has 500 steps and return 28.1. +Episode has 500 steps and return 50.0. +Episode has 500 steps and return 47.3. +Episode has 500 steps and return 56.3. +Episode has 500 steps and return 73.6. +Episode has 500 steps and return 105.3. +Episode has 500 steps and return 53.3. 
+Episode has 500 steps and return 166.2. +Episode has 500 steps and return 100.8. +Episode has 500 steps and return 80.8. +Episode has 500 steps and return 144.9. +Episode has 500 steps and return 94.1. +Episode has 500 steps and return 146.6. +Episode has 500 steps and return 181.6. +Episode has 500 steps and return 146.2. +Episode has 500 steps and return 161.2. +Episode has 500 steps and return 178.6. +Episode has 500 steps and return 168.6. +Episode has 500 steps and return 159.5. +Episode has 500 steps and return 173.2. +Episode has 500 steps and return 173.4. +Episode has 500 steps and return 227.2. +Episode has 500 steps and return 266.1. +Episode has 500 steps and return 205.1. +Episode has 500 steps and return 164.7. +Episode has 500 steps and return 183.3. +Episode has 500 steps and return 208.4. +Episode has 500 steps and return 232.4. +Episode has 500 steps and return 213.2. +Episode has 500 steps and return 224.9. +Episode has 500 steps and return 236.5. +Episode has 500 steps and return 200.1. +Episode has 500 steps and return 198.5. +Episode has 500 steps and return 255.3. +Episode has 500 steps and return 227.0. +Episode has 500 steps and return 404.7. +Episode has 500 steps and return 388.6. +Episode has 500 steps and return 211.1. +Episode has 500 steps and return 229.2. +Episode has 500 steps and return 294.0. +Episode has 500 steps and return 171.7. +Episode has 500 steps and return 230.0. +Episode has 500 steps and return 265.8. +Episode has 500 steps and return 292.8. +Episode has 500 steps and return 228.9. +Episode has 500 steps and return 488.3. +Episode has 500 steps and return 92.0. +Episode has 500 steps and return 218.0. +Episode has 500 steps and return 170.2. +Episode has 500 steps and return 254.1. +Episode has 500 steps and return 75.2. +Episode has 500 steps and return 438.4. +Episode has 500 steps and return 239.5. +Episode has 500 steps and return 74.4. +Episode has 500 steps and return 268.6. +Episode has 500 steps and return 215.2. +Episode has 500 steps and return 191.5. +Episode has 500 steps and return 177.5. +Episode has 500 steps and return 331.3. +Episode has 500 steps and return 98.2. +Episode has 500 steps and return 35.0. +Episode has 500 steps and return 563.7. +Episode has 500 steps and return 52.1. +Episode has 500 steps and return 133.8. +Episode has 500 steps and return 96.4. +Episode has 500 steps and return 117.0. +Episode has 500 steps and return 120.2. +Episode has 500 steps and return 313.9. +Episode has 500 steps and return 375.6. +Episode has 500 steps and return 65.4. +Episode has 500 steps and return 374.2. +Episode has 500 steps and return 442.6. +Episode has 500 steps and return 286.9. +Episode has 500 steps and return 399.1. +Episode has 500 steps and return 434.4. +Episode has 500 steps and return 537.1. +Episode has 500 steps and return 548.6. +Episode has 500 steps and return 293.3. +Episode has 500 steps and return 555.3. +Episode has 500 steps and return 421.8. +Episode has 500 steps and return 170.0. +Episode has 500 steps and return 460.9. +Episode has 500 steps and return 368.1. +Episode has 500 steps and return 507.9. +Episode has 500 steps and return 404.2. +Episode has 500 steps and return 557.2. +Episode has 500 steps and return 472.3. +Episode has 500 steps and return 480.3. +Episode has 500 steps and return 472.7. +Episode has 500 steps and return 442.1. +Episode has 500 steps and return 304.9. +Episode has 500 steps and return 550.3. +Episode has 500 steps and return 458.1. 
+Episode has 500 steps and return 403.1. +Episode has 500 steps and return 422.5. +Episode has 500 steps and return 437.9. +Episode has 500 steps and return 319.1. +Episode has 500 steps and return 505.8. +Episode has 500 steps and return 582.0. +Episode has 500 steps and return 480.5. +Episode has 500 steps and return 466.2. +Episode has 500 steps and return 415.4. +Episode has 500 steps and return 570.2. +Episode has 500 steps and return 441.7. +Episode has 500 steps and return 611.2. +Episode has 500 steps and return 520.0. +Episode has 500 steps and return 527.2. +Episode has 500 steps and return 407.4. +Episode has 500 steps and return 232.7. +Episode has 500 steps and return 547.3. +Episode has 500 steps and return 360.3. +Episode has 500 steps and return 450.2. +Episode has 500 steps and return 670.0. +Episode has 500 steps and return 552.5. +Episode has 500 steps and return 528.5. +Episode has 500 steps and return 541.7. +Episode has 500 steps and return 611.0. +Episode has 500 steps and return 466.6. +Episode has 500 steps and return 608.8. +Episode has 500 steps and return 451.6. +Episode has 500 steps and return 524.8. +Episode has 500 steps and return 666.8. +Episode has 500 steps and return 419.7. +Episode has 500 steps and return 436.0. +Episode has 500 steps and return 478.2. +Episode has 500 steps and return 596.9. +Episode has 500 steps and return 587.7. +Episode has 500 steps and return 677.1. +Episode has 500 steps and return 416.8. +Episode has 500 steps and return 531.3. +Episode has 500 steps and return 609.7. +Episode has 500 steps and return 538.5. +Episode has 500 steps and return 619.3. +Episode has 500 steps and return 510.6. +Episode has 500 steps and return 453.9. +Episode has 500 steps and return 540.1. +Episode has 500 steps and return 601.7. +Episode has 500 steps and return 523.1. +Episode has 500 steps and return 626.5. +Episode has 500 steps and return 568.7. +Episode has 500 steps and return 606.8. +Episode has 500 steps and return 575.8. +Episode has 500 steps and return 648.1. +Episode has 500 steps and return 249.3. +Episode has 500 steps and return 431.9. +Episode has 500 steps and return 414.0. +Episode has 500 steps and return 562.2. +Episode has 500 steps and return 754.8. +Episode has 500 steps and return 732.9. +Episode has 500 steps and return 616.7. +Episode has 500 steps and return 637.3. +Episode has 500 steps and return 455.7. +Episode has 500 steps and return 736.2. +Episode has 500 steps and return 718.0. +Episode has 500 steps and return 620.7. +Episode has 500 steps and return 683.3. +Episode has 500 steps and return 512.4. +Episode has 500 steps and return 654.2. +Episode has 500 steps and return 555.8. +Episode has 500 steps and return 708.6. +Episode has 500 steps and return 711.1. +Episode has 500 steps and return 741.6. +Episode has 500 steps and return 639.3. +Episode has 500 steps and return 678.0. +Episode has 500 steps and return 634.7. +Episode has 500 steps and return 586.6. +Episode has 500 steps and return 582.5. +Episode has 500 steps and return 788.8. +Episode has 500 steps and return 722.7. +Episode has 500 steps and return 759.8. +Episode has 500 steps and return 595.7. +Episode has 500 steps and return 812.6. +Episode has 500 steps and return 793.8. +Episode has 500 steps and return 688.6. +Episode has 500 steps and return 736.5. +Episode has 500 steps and return 748.2. +Episode has 500 steps and return 799.9. +Episode has 500 steps and return 798.6. +Episode has 500 steps and return 328.5. 
+Episode has 500 steps and return 793.1. +Episode has 500 steps and return 686.7. +Episode has 500 steps and return 697.5. +Episode has 500 steps and return 573.3. +Episode has 500 steps and return 840.9. +Episode has 500 steps and return 844.6. +Episode has 500 steps and return 862.8. +Episode has 500 steps and return 853.2. +Episode has 500 steps and return 760.1. +Episode has 500 steps and return 792.1. +Episode has 500 steps and return 871.9. +Episode has 500 steps and return 727.3. +Episode has 500 steps and return 806.6. +Episode has 500 steps and return 915.1. +Episode has 500 steps and return 753.4. +Episode has 500 steps and return 865.3. +Episode has 500 steps and return 905.4. +Episode has 500 steps and return 769.4. +Episode has 500 steps and return 898.3. +Episode has 500 steps and return 872.0. +Episode has 500 steps and return 826.6. +Episode has 500 steps and return 819.1. +Episode has 500 steps and return 775.4. +Episode has 500 steps and return 749.3. +Episode has 500 steps and return 697.5. +Episode has 500 steps and return 923.3. +Episode has 500 steps and return 800.2. +Episode has 500 steps and return 809.2. +Episode has 500 steps and return 637.9. +Episode has 500 steps and return 965.4. +Episode has 500 steps and return 878.0. +Episode has 500 steps and return 462.7. +Episode has 500 steps and return 870.2. +Episode has 500 steps and return 890.6. +Episode has 500 steps and return 888.6. +Episode has 500 steps and return 708.7. +Episode has 500 steps and return 839.5. +Episode has 500 steps and return 861.0. +Episode has 500 steps and return 884.5. +Episode has 500 steps and return 907.9. +Episode has 500 steps and return 896.9. +Episode has 500 steps and return 821.8. +Episode has 500 steps and return 917.7. +Episode has 500 steps and return 858.4. +Episode has 500 steps and return 851.0. +Episode has 500 steps and return 847.9. +Episode has 500 steps and return 752.4. +Episode has 500 steps and return 929.9. +Episode has 500 steps and return 789.4. +Episode has 500 steps and return 854.0. +Episode has 500 steps and return 913.5. +Episode has 500 steps and return 806.3. +Episode has 500 steps and return 808.1. +Episode has 500 steps and return 951.2. +Episode has 500 steps and return 944.4. +Episode has 500 steps and return 891.4. +Episode has 500 steps and return 937.1. +Episode has 500 steps and return 791.8. +Episode has 500 steps and return 944.7. +Episode has 500 steps and return 804.1. +Episode has 500 steps and return 944.1. +Episode has 500 steps and return 893.4. +Episode has 500 steps and return 879.0. +Episode has 500 steps and return 856.9. +Episode has 500 steps and return 911.5. +Episode has 500 steps and return 869.6. +Episode has 500 steps and return 709.7. +Episode has 500 steps and return 911.6. +Episode has 500 steps and return 916.0. +Episode has 500 steps and return 906.5. +Episode has 500 steps and return 893.6. +Episode has 500 steps and return 918.6. +Episode has 500 steps and return 877.6. +Episode has 500 steps and return 905.8. +Episode has 500 steps and return 931.9. +Episode has 500 steps and return 914.2. +Episode has 500 steps and return 902.2. +Episode has 500 steps and return 939.2. +Episode has 500 steps and return 852.3. +Episode has 500 steps and return 877.9. +Episode has 500 steps and return 935.4. +Episode has 500 steps and return 881.7. +Episode has 500 steps and return 861.4. +Episode has 500 steps and return 891.9. +Episode has 500 steps and return 903.1. +Episode has 500 steps and return 931.0. 
+[... remaining `Episode has 500 steps and return <return>.` evaluation lines omitted ...]
\ No newline at end of file
diff --git a/our.txt b/our.txt
new file mode 100644
index 00000000..26bd4cc6
--- /dev/null
+++ b/our.txt
@@ -0,0 +1,1004 @@
+Rank-0: policy_step=2000, reward_env_0=34.73862075805664
+Rank-0: policy_step=2000, reward_env_1=42.344600677490234
+Rank-0: policy_step=2000, reward_env_2=47.81229019165039
+Rank-0: policy_step=2000, reward_env_3=41.856040954589844
+[... reward_env_0..reward_env_3 logged every 2000 policy steps up to policy_step=498000, omitted ...]
+Rank-0: policy_step=500000, reward_env_0=864.3685302734375
+Rank-0: policy_step=500000, reward_env_1=857.979248046875
+Rank-0: policy_step=500000, reward_env_2=901.314697265625
+Rank-0: policy_step=500000, reward_env_3=861.271728515625
\ No newline at end of file
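Editor's note: the two text files added above (hf.txt with per-episode returns, our.txt with per-environment rewards) are raw training/evaluation logs and are deleted again later in this series ("[PATCH 10/51] fix: remove files" below). As a reading aid only, here is a minimal sketch of how the `Rank-0: policy_step=..., reward_env_N=...` lines could be collapsed into one mean-reward curve per policy step; the file name and the regular expression are assumptions taken from the lines quoted above and are not part of the patch.

# Minimal sketch (not part of the patch): collapse the per-environment reward lines
# into a mean reward per policy step, e.g. to compare our.txt against hf.txt.
import re
from collections import defaultdict

LINE_RE = re.compile(r"policy_step=(\d+), reward_env_\d+=([0-9.]+)")

def mean_reward_per_step(path):
    per_step = defaultdict(list)
    with open(path) as f:
        for line in f:
            m = LINE_RE.search(line)
            if m:
                # Group rewards of all environments logged at the same policy step.
                per_step[int(m.group(1))].append(float(m.group(2)))
    return {step: sum(r) / len(r) for step, r in sorted(per_step.items())}

if __name__ == "__main__":
    for step, mean_r in mean_reward_per_step("our.txt").items():
        print(f"policy_step={step}: mean reward = {mean_r:.1f}")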
 
 # Training recipe
 replay_ratio: 1
-learning_starts: 65536
+learning_starts: 1024
 per_rank_sequence_length: ???
 
 # Encoder and decoder keys
diff --git a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml
index 3d67448a..395d47cb 100644
--- a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml
+++ b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml
@@ -8,7 +8,7 @@ defaults:
 algo:
   name: p2e_dv3_finetuning
-  learning_starts: 65536
+  learning_starts: 16384
   total_steps: 1000000
 player:
   actor_type: exploration
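Both hunks above only lower `learning_starts`, which (as far as I understand the sheeprl configs) is the number of environment steps collected before world-model and policy optimization begins. A minimal sketch of the resulting training-recipe block in sheeprl/configs/algo/dreamer_v3.yaml, assuming the surrounding keys stay as shown in the hunk context:

# Training recipe
replay_ratio: 1
learning_starts: 1024  # lowered from 65536 by this change
per_rank_sequence_length: ???

The p2e_dv3_finetuning experiment still sets its own algo.learning_starts to 16384, so the finetuning run waits for a larger warm-up buffer before training than plain DreamerV3 (16384 vs. 1024 steps).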
From 1f0c0ef7c80a2d29a4004bc95fd8b4e6027024d9 Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Thu, 28 Mar 2024 11:19:30 +0100
Subject: [PATCH 10/51] fix: remove files

---
 hf.txt  |  444 ------------------
 our.txt | 1004 ------------------------------------------
 2 files changed, 1448 deletions(-)
 delete mode 100644 hf.txt
 delete mode 100644 our.txt

diff --git a/hf.txt b/hf.txt
deleted file mode 100644
index 24b01d70..00000000
--- a/hf.txt
+++ /dev/null
@@ -1,444 +0,0 @@
-Episode has 500 steps and return 49.6.
-Episode has 500 steps and return 28.1.
-Episode has 500 steps and return 50.0.
[... remaining evaluation-episode returns deleted with hf.txt; the 500-step episode returns climb from roughly 30-180 early in training to mostly 850-985 near the end ...]
-Episode has 500 steps and return 924.5.
-Episode has 500 steps and return 930.4.
\ No newline at end of file
diff --git a/our.txt b/our.txt
deleted file mode 100644
index 26bd4cc6..00000000
--- a/our.txt
+++ /dev/null
@@ -1,1004 +0,0 @@
-Rank-0: policy_step=2000, reward_env_0=34.73862075805664
-Rank-0: policy_step=2000, reward_env_1=42.344600677490234
-Rank-0: policy_step=2000, reward_env_2=47.81229019165039
-Rank-0: policy_step=2000, reward_env_3=41.856040954589844
[... per-environment reward logs for policy_step=4000-400000 deleted with our.txt; rewards rise from below 100 early in training to mostly above 800 by policy_step=400000 ...]
-Rank-0: policy_step=402000, reward_env_0=854.2955932617188
-Rank-0: policy_step=402000, reward_env_1=871.5416870117188
-Rank-0: policy_step=402000, reward_env_2=847.7739868164062
-Rank-0: policy_step=402000, reward_env_3=802.7327880859375 -Rank-0: policy_step=404000, reward_env_0=932.5904541015625 -Rank-0: policy_step=404000, reward_env_1=856.1954956054688 -Rank-0: policy_step=404000, reward_env_2=775.369873046875 -Rank-0: policy_step=404000, reward_env_3=723.0234375 -Rank-0: policy_step=406000, reward_env_0=847.7246704101562 -Rank-0: policy_step=406000, reward_env_1=838.1256713867188 -Rank-0: policy_step=406000, reward_env_2=823.4154663085938 -Rank-0: policy_step=406000, reward_env_3=855.191650390625 -Rank-0: policy_step=408000, reward_env_0=877.60400390625 -Rank-0: policy_step=408000, reward_env_1=862.7110595703125 -Rank-0: policy_step=408000, reward_env_2=876.7628173828125 -Rank-0: policy_step=408000, reward_env_3=817.1781616210938 -Rank-0: policy_step=410000, reward_env_0=880.756103515625 -Rank-0: policy_step=410000, reward_env_1=647.1429443359375 -Rank-0: policy_step=410000, reward_env_2=850.9156494140625 -Rank-0: policy_step=410000, reward_env_3=942.177978515625 -Rank-0: policy_step=412000, reward_env_0=923.8487548828125 -Rank-0: policy_step=412000, reward_env_1=950.1604614257812 -Rank-0: policy_step=412000, reward_env_2=888.6689453125 -Rank-0: policy_step=412000, reward_env_3=909.7418823242188 -Rank-0: policy_step=414000, reward_env_0=905.9585571289062 -Rank-0: policy_step=414000, reward_env_1=898.7376708984375 -Rank-0: policy_step=414000, reward_env_2=938.0211791992188 -Rank-0: policy_step=414000, reward_env_3=864.6925048828125 -Rank-0: policy_step=416000, reward_env_0=926.9373779296875 -Rank-0: policy_step=416000, reward_env_1=910.0982666015625 -Rank-0: policy_step=416000, reward_env_2=891.2000732421875 -Rank-0: policy_step=416000, reward_env_3=873.0259399414062 -Rank-0: policy_step=418000, reward_env_0=830.7296752929688 -Rank-0: policy_step=418000, reward_env_1=792.3489379882812 -Rank-0: policy_step=418000, reward_env_2=785.37109375 -Rank-0: policy_step=418000, reward_env_3=848.4445190429688 -Rank-0: policy_step=420000, reward_env_0=885.6739501953125 -Rank-0: policy_step=420000, reward_env_1=950.7418823242188 -Rank-0: policy_step=420000, reward_env_2=859.4856567382812 -Rank-0: policy_step=420000, reward_env_3=805.8286743164062 -Rank-0: policy_step=422000, reward_env_0=845.3460693359375 -Rank-0: policy_step=422000, reward_env_1=880.4802856445312 -Rank-0: policy_step=422000, reward_env_2=855.9398193359375 -Rank-0: policy_step=422000, reward_env_3=882.0545654296875 -Rank-0: policy_step=424000, reward_env_0=945.623779296875 -Rank-0: policy_step=424000, reward_env_1=916.1929321289062 -Rank-0: policy_step=424000, reward_env_2=887.4605712890625 -Rank-0: policy_step=424000, reward_env_3=904.80419921875 -Rank-0: policy_step=426000, reward_env_0=863.6259155273438 -Rank-0: policy_step=426000, reward_env_1=911.1572875976562 -Rank-0: policy_step=426000, reward_env_2=941.548828125 -Rank-0: policy_step=426000, reward_env_3=884.1109008789062 -Rank-0: policy_step=428000, reward_env_0=823.77099609375 -Rank-0: policy_step=428000, reward_env_1=882.7049560546875 -Rank-0: policy_step=428000, reward_env_2=857.8377075195312 -Rank-0: policy_step=428000, reward_env_3=831.3613891601562 -Rank-0: policy_step=430000, reward_env_0=870.7410278320312 -Rank-0: policy_step=430000, reward_env_1=776.04052734375 -Rank-0: policy_step=430000, reward_env_2=822.292236328125 -Rank-0: policy_step=430000, reward_env_3=845.9228515625 -Rank-0: policy_step=432000, reward_env_0=827.0743408203125 -Rank-0: policy_step=432000, reward_env_1=892.0718383789062 -Rank-0: policy_step=432000, 
reward_env_2=861.017578125 -Rank-0: policy_step=432000, reward_env_3=828.2916259765625 -Rank-0: policy_step=434000, reward_env_0=722.2677001953125 -Rank-0: policy_step=434000, reward_env_1=861.9256591796875 -Rank-0: policy_step=434000, reward_env_2=522.3941650390625 -Rank-0: policy_step=434000, reward_env_3=843.7252197265625 -Rank-0: policy_step=436000, reward_env_0=887.9268798828125 -Rank-0: policy_step=436000, reward_env_1=858.6796875 -Rank-0: policy_step=436000, reward_env_2=881.55322265625 -Rank-0: policy_step=436000, reward_env_3=874.1316528320312 -Rank-0: policy_step=438000, reward_env_0=853.7929077148438 -Rank-0: policy_step=438000, reward_env_1=913.2722778320312 -Rank-0: policy_step=438000, reward_env_2=862.9351196289062 -Rank-0: policy_step=438000, reward_env_3=862.1657104492188 -Rank-0: policy_step=440000, reward_env_0=801.8331298828125 -Rank-0: policy_step=440000, reward_env_1=854.8385009765625 -Rank-0: policy_step=440000, reward_env_2=798.686767578125 -Rank-0: policy_step=440000, reward_env_3=887.5355224609375 -Rank-0: policy_step=442000, reward_env_0=883.7042236328125 -Rank-0: policy_step=442000, reward_env_1=864.5542602539062 -Rank-0: policy_step=442000, reward_env_2=801.9967041015625 -Rank-0: policy_step=442000, reward_env_3=853.4691162109375 -Rank-0: policy_step=444000, reward_env_0=896.71484375 -Rank-0: policy_step=444000, reward_env_1=883.6332397460938 -Rank-0: policy_step=444000, reward_env_2=882.015380859375 -Rank-0: policy_step=444000, reward_env_3=923.2923583984375 -Rank-0: policy_step=446000, reward_env_0=856.3253784179688 -Rank-0: policy_step=446000, reward_env_1=815.7265625 -Rank-0: policy_step=446000, reward_env_2=864.0433349609375 -Rank-0: policy_step=446000, reward_env_3=816.8611450195312 -Rank-0: policy_step=448000, reward_env_0=833.3370971679688 -Rank-0: policy_step=448000, reward_env_1=889.8046875 -Rank-0: policy_step=448000, reward_env_2=881.6996459960938 -Rank-0: policy_step=448000, reward_env_3=888.370361328125 -Rank-0: policy_step=450000, reward_env_0=852.7061157226562 -Rank-0: policy_step=450000, reward_env_1=831.8417358398438 -Rank-0: policy_step=450000, reward_env_2=873.1185302734375 -Rank-0: policy_step=450000, reward_env_3=872.6946411132812 -Rank-0: policy_step=452000, reward_env_0=913.0731811523438 -Rank-0: policy_step=452000, reward_env_1=759.593994140625 -Rank-0: policy_step=452000, reward_env_2=822.0515747070312 -Rank-0: policy_step=452000, reward_env_3=870.2621459960938 -Rank-0: policy_step=454000, reward_env_0=910.6627197265625 -Rank-0: policy_step=454000, reward_env_1=871.4953002929688 -Rank-0: policy_step=454000, reward_env_2=901.0242309570312 -Rank-0: policy_step=454000, reward_env_3=857.83642578125 -Rank-0: policy_step=456000, reward_env_0=818.6113891601562 -Rank-0: policy_step=456000, reward_env_1=866.4872436523438 -Rank-0: policy_step=456000, reward_env_2=762.2593994140625 -Rank-0: policy_step=456000, reward_env_3=819.9625244140625 -Rank-0: policy_step=458000, reward_env_0=890.6290283203125 -Rank-0: policy_step=458000, reward_env_1=913.1181030273438 -Rank-0: policy_step=458000, reward_env_2=912.7213134765625 -Rank-0: policy_step=458000, reward_env_3=844.5999755859375 -Rank-0: policy_step=460000, reward_env_0=905.0780639648438 -Rank-0: policy_step=460000, reward_env_1=881.4569091796875 -Rank-0: policy_step=460000, reward_env_2=839.7293701171875 -Rank-0: policy_step=460000, reward_env_3=893.1539916992188 -Rank-0: policy_step=462000, reward_env_0=883.0076293945312 -Rank-0: policy_step=462000, reward_env_1=877.7626953125 -Rank-0: 
policy_step=462000, reward_env_2=863.9375 -Rank-0: policy_step=462000, reward_env_3=881.8802490234375 -Rank-0: policy_step=464000, reward_env_0=883.3395385742188 -Rank-0: policy_step=464000, reward_env_1=863.7293090820312 -Rank-0: policy_step=464000, reward_env_2=846.1231689453125 -Rank-0: policy_step=464000, reward_env_3=870.9586181640625 -Rank-0: policy_step=466000, reward_env_0=884.1751098632812 -Rank-0: policy_step=466000, reward_env_1=862.9114379882812 -Rank-0: policy_step=466000, reward_env_2=818.036376953125 -Rank-0: policy_step=466000, reward_env_3=860.5357666015625 -Rank-0: policy_step=468000, reward_env_0=821.2963256835938 -Rank-0: policy_step=468000, reward_env_1=798.1824951171875 -Rank-0: policy_step=468000, reward_env_2=821.6298828125 -Rank-0: policy_step=468000, reward_env_3=863.978515625 -Rank-0: policy_step=470000, reward_env_0=856.8347778320312 -Rank-0: policy_step=470000, reward_env_1=833.1890869140625 -Rank-0: policy_step=470000, reward_env_2=787.0861206054688 -Rank-0: policy_step=470000, reward_env_3=801.3120727539062 -Rank-0: policy_step=472000, reward_env_0=883.5061645507812 -Rank-0: policy_step=472000, reward_env_1=791.2484130859375 -Rank-0: policy_step=472000, reward_env_2=888.4317626953125 -Rank-0: policy_step=472000, reward_env_3=280.5549011230469 -Rank-0: policy_step=474000, reward_env_0=915.7325439453125 -Rank-0: policy_step=474000, reward_env_1=921.8428955078125 -Rank-0: policy_step=474000, reward_env_2=920.48388671875 -Rank-0: policy_step=474000, reward_env_3=858.83349609375 -Rank-0: policy_step=476000, reward_env_0=894.6270141601562 -Rank-0: policy_step=476000, reward_env_1=919.9764404296875 -Rank-0: policy_step=476000, reward_env_2=858.1167602539062 -Rank-0: policy_step=476000, reward_env_3=912.5479125976562 -Rank-0: policy_step=478000, reward_env_0=921.7483520507812 -Rank-0: policy_step=478000, reward_env_1=904.6066284179688 -Rank-0: policy_step=478000, reward_env_2=845.4376220703125 -Rank-0: policy_step=478000, reward_env_3=944.1141967773438 -Rank-0: policy_step=480000, reward_env_0=888.0297241210938 -Rank-0: policy_step=480000, reward_env_1=893.0316772460938 -Rank-0: policy_step=480000, reward_env_2=902.1946411132812 -Rank-0: policy_step=480000, reward_env_3=905.1110229492188 -Rank-0: policy_step=482000, reward_env_0=885.919677734375 -Rank-0: policy_step=482000, reward_env_1=809.8402099609375 -Rank-0: policy_step=482000, reward_env_2=901.9151611328125 -Rank-0: policy_step=482000, reward_env_3=884.3057861328125 -Rank-0: policy_step=484000, reward_env_0=904.5450439453125 -Rank-0: policy_step=484000, reward_env_1=893.8999633789062 -Rank-0: policy_step=484000, reward_env_2=914.0784301757812 -Rank-0: policy_step=484000, reward_env_3=909.2919311523438 -Rank-0: policy_step=486000, reward_env_0=865.3569946289062 -Rank-0: policy_step=486000, reward_env_1=852.3546142578125 -Rank-0: policy_step=486000, reward_env_2=699.8411254882812 -Rank-0: policy_step=486000, reward_env_3=897.8310546875 -Rank-0: policy_step=488000, reward_env_0=789.0899658203125 -Rank-0: policy_step=488000, reward_env_1=865.5814208984375 -Rank-0: policy_step=488000, reward_env_2=864.0103759765625 -Rank-0: policy_step=488000, reward_env_3=809.711181640625 -Rank-0: policy_step=490000, reward_env_0=891.7025146484375 -Rank-0: policy_step=490000, reward_env_1=884.9774780273438 -Rank-0: policy_step=490000, reward_env_2=890.9956665039062 -Rank-0: policy_step=490000, reward_env_3=846.63232421875 -Rank-0: policy_step=492000, reward_env_0=901.2996826171875 -Rank-0: policy_step=492000, 
reward_env_1=902.7505493164062 -Rank-0: policy_step=492000, reward_env_2=899.1532592773438 -Rank-0: policy_step=492000, reward_env_3=796.5845947265625 -Rank-0: policy_step=494000, reward_env_0=831.8873901367188 -Rank-0: policy_step=494000, reward_env_1=875.12548828125 -Rank-0: policy_step=494000, reward_env_2=848.447509765625 -Rank-0: policy_step=494000, reward_env_3=882.6404418945312 -Rank-0: policy_step=496000, reward_env_0=845.0203247070312 -Rank-0: policy_step=496000, reward_env_1=889.7410888671875 -Rank-0: policy_step=496000, reward_env_2=882.0408935546875 -Rank-0: policy_step=496000, reward_env_3=859.1314697265625 -Rank-0: policy_step=498000, reward_env_0=833.8394775390625 -Rank-0: policy_step=498000, reward_env_1=888.3397827148438 -Rank-0: policy_step=498000, reward_env_2=905.193359375 -Rank-0: policy_step=498000, reward_env_3=880.4007568359375 -Rank-0: policy_step=500000, reward_env_0=864.3685302734375 -Rank-0: policy_step=500000, reward_env_1=857.979248046875 -Rank-0: policy_step=500000, reward_env_2=901.314697265625 -Rank-0: policy_step=500000, reward_env_3=861.271728515625 \ No newline at end of file From cd4a4c4cc5aa8ee8d357690f0a0fcf7bf67fa922 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 28 Mar 2024 11:21:24 +0100 Subject: [PATCH 11/51] feat: update repeats --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index b3b83335..0c76583b 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -685,7 +685,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - repeats = ratio(policy_step) + repeats = ratio(policy_step / world_size) if update >= learning_starts and repeats > 0: local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, From b17d451ca5e896485d2fa383d26833e92cfbb519 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 28 Mar 2024 12:29:55 +0100 Subject: [PATCH 12/51] Let Dv3 compute bootstrap correctly --- sheeprl/algos/dreamer_v3/dreamer_v3.py | 35 ++++++++++++++++---------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 7984e85c..984b0bf0 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -171,7 +171,7 @@ def train( 1, validate_args=validate_args, ) - continue_targets = 1 - data["dones"] + continue_targets = 1 - data["terminated"] # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -255,7 +255,7 @@ def train( 1, validate_args=validate_args, ).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) + true_done = (1 - data["terminated"]).flatten().reshape(1, -1, 1) continues = torch.cat((true_done, continues[1:])) # Estimate lambda-values @@ -577,9 +577,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() 
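The hunks above split the former `dones` flag into separate `terminated` and `truncated` entries in the replay buffer. A minimal NumPy sketch (made-up values, not part of the patch) of why the two are kept apart: environment resets follow either signal, while only a true termination should zero the continue target the world model is trained against, so time-limit truncations keep bootstrapping from the predicted value.

```python
import numpy as np

terminated = np.array([[1], [0], [0], [0]], dtype=np.uint8)  # env 0 reached a terminal state
truncated = np.array([[0], [1], [0], [0]], dtype=np.uint8)   # env 1 hit its time limit

dones = np.logical_or(terminated, truncated).astype(np.uint8)  # drives environment resets
continue_targets = 1 - terminated                              # target for the continue predictor

print(dones.squeeze(-1))             # [1 1 0 0] -> both env 0 and env 1 get reset
print(continue_targets.squeeze(-1))  # [0 1 1 1] -> only env 0 stops the value bootstrap
```

This distinction is what `continue_targets = 1 - data["terminated"]` in the training function changed by this patch relies on.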
per_rank_gradient_steps = 0 @@ -626,16 +627,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -667,7 +673,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -676,15 +683,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) From e8c9049531c487ff425003c95961f3c4b11ff4dd Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 28 Mar 2024 14:41:32 +0100 Subject: [PATCH 13/51] feat: added replay ratio and update exploration --- sheeprl/algos/dreamer_v1/agent.py | 23 +++-- sheeprl/algos/dreamer_v1/dreamer_v1.py | 39 +++----- sheeprl/algos/dreamer_v2/agent.py | 77 +++++++-------- sheeprl/algos/dreamer_v2/dreamer_v2.py | 45 +++------ sheeprl/algos/dreamer_v2/utils.py | 7 +- sheeprl/algos/dreamer_v3/agent.py | 131 ++++--------------------- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 +- 
sheeprl/algos/dreamer_v3/utils.py | 3 +- sheeprl/configs/algo/dreamer_v1.yaml | 3 +- sheeprl/configs/algo/dreamer_v2.yaml | 7 +- sheeprl/configs/exp/dreamer_v2.yaml | 3 - sheeprl/configs/exp/dreamer_v3.yaml | 3 - 12 files changed, 103 insertions(+), 242 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index d9b69d54..a8c8e28c 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -288,32 +288,38 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]: + def get_exploration_action( + self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0 + ) -> Sequence[Tensor]: """Return the actions with a certain amount of noise for exploration. Args: obs (Tensor): the current observations. + sample_actions (bool): whether or not to sample the actions. + Default to True. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. + step (int): the step of the training, used for the exploration amount. + Default to 0. Returns: The actions the agent has to perform (Sequence[Tensor]). """ - actions = self.get_greedy_action(obs, mask=mask) + actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask) expl_actions = None if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) + expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask) self.actions = torch.cat(expl_actions, dim=-1) return expl_actions or actions - def get_greedy_action( - self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + def get_actions( + self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: """Return the greedy actions. Args: obs (Tensor): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. 
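The `get_exploration_action` hunk above now forwards a `step` argument to the actor's `add_exploration_noise`, and the noise is applied on top of whatever `get_actions` returns. A toy sketch of that pattern for discrete actions (hypothetical helper and values, not the library code): with probability `expl_amount`, the chosen one-hot action is swapped for a uniformly random one.

```python
import torch
from torch.distributions import OneHotCategorical

def add_random_action_noise(actions: torch.Tensor, expl_amount: float) -> torch.Tensor:
    # actions: one-hot tensor of shape (batch, num_actions)
    random_actions = OneHotCategorical(logits=torch.zeros_like(actions)).sample()
    replace = torch.rand(actions.shape[0], 1) < expl_amount
    return torch.where(replace, random_actions, actions)

greedy = torch.nn.functional.one_hot(torch.tensor([2, 0, 1]), num_classes=4).float()
noisy = add_random_action_noise(greedy, expl_amount=0.3)
```

In the collection loop the call site passes the current policy step so the noise amount can be scheduled over training, e.g. `player.get_exploration_action(obs, mask=mask, step=policy_step)`.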
@@ -329,7 +335,7 @@ def get_greedy_action( self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)), validate_args=self.validate_args, ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions @@ -488,6 +494,9 @@ def build_agent( activation=eval(actor_cfg.dense_act), distribution_cfg=cfg.distribution, layer_norm=False, + expl_amount=actor_cfg.expl_amount, + expl_decay=actor_cfg.expl_decay, + expl_min=actor_cfg.expl_min, ) critic = MLP( input_dims=latent_state_size, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 0326a142..da0810f4 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -27,7 +27,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -547,22 +547,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -624,7 +620,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_action(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -684,13 +680,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - - # Train the agent - if update > learning_starts and updates_before_training <= 0: + repeats = ratio(policy_step / world_size) + if update >= learning_starts and repeats > 0: # Start 
training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): + for i in range(repeats): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -713,17 +707,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg, ) train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if aggregator: - aggregator.update("Params/exploration_amount", actor.expl_amount) + aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -767,7 +752,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 6f546f43..a369349a 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -435,6 +435,10 @@ class Actor(nn.Module): Default to False. expl_amount (float): the exploration amount to use during training. Default to 0.0. + expl_decay (float): the exploration decay to use during training. + Default to 0.0. + expl_min (float): the exploration amount minimum to use during training. + Default to 0.0. """ def __init__( @@ -450,6 +454,8 @@ def __init__( mlp_layers: int = 4, layer_norm: bool = False, expl_amount: float = 0.0, + expl_decay: float = 0.0, + expl_min: float = 0.0, ) -> None: super().__init__() self.distribution_cfg = distribution_cfg @@ -485,17 +491,17 @@ def __init__( self.min_std = min_std self.distribution_cfg = distribution_cfg self._expl_amount = expl_amount + self._expl_decay = expl_decay + self._expl_min = expl_min - @property - def expl_amount(self) -> float: - return self._expl_amount - - @expl_amount.setter - def expl_amount(self, amount: float): - self._expl_amount = amount + def _get_expl_amount(self, step: int) -> Tensor: + amount = self._expl_amount + if self._expl_decay: + amount *= 0.5 ** float(step) / self._expl_decay + return max(amount, self._expl_min) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -503,7 +509,7 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - is_training (bool): whether it is in the training phase. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. 
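The training trigger in the loop above replaces the old `train_every` counter with a replay-ratio scheduler: `ratio(policy_step / world_size)` returns how many gradient steps to run so that, summed across ranks, roughly `replay_ratio` gradient steps are performed per policy step. A hypothetical stand-in for such a helper (not the actual `sheeprl.utils.utils.Ratio` implementation) to illustrate the bookkeeping:

```python
class SimpleRatio:
    """Hypothetical stand-in for the replay-ratio scheduler (not sheeprl's `Ratio`)."""

    def __init__(self, replay_ratio: float) -> None:
        self.replay_ratio = replay_ratio
        self.granted = 0  # gradient steps handed out so far

    def __call__(self, policy_steps: float) -> int:
        target = int(policy_steps * self.replay_ratio)
        repeats = max(target - self.granted, 0)
        self.granted += repeats
        return repeats


ratio = SimpleRatio(replay_ratio=0.5)
print(ratio(16), ratio(32))  # 8 8 -> keeps 0.5 gradient steps per policy step
```

Because the scheduler is stateful, the patch also stores `ratio.state_dict()` in the checkpoint so resumed runs keep the same effective ratio.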
@@ -534,7 +540,7 @@ def forward( std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std dist = TruncatedNormal(torch.tanh(mean), std, -1, 1, validate_args=self.distribution_cfg.validate_args) actions_dist = Independent(dist, 1, validate_args=self.distribution_cfg.validate_args) - if is_training: + if sample_actions: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -551,19 +557,20 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) return tuple(actions), tuple(actions_dist) def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + self, actions: Sequence[Tensor], step: int = 0, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: if self.is_continuous: actions = torch.cat(actions, -1) - if self._expl_amount > 0.0: - actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) + expl_amount = self._get_expl_amount(step) + if expl_amount > 0.0: + actions = torch.clip(Normal(actions, expl_amount).sample(), -1, 1) expl_actions = [actions] else: expl_actions = [] @@ -593,6 +600,8 @@ def __init__( mlp_layers: int = 4, layer_norm: bool = False, expl_amount: float = 0.0, + expl_decay: float = 0.0, + expl_min: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, @@ -606,10 +615,12 @@ def __init__( mlp_layers=mlp_layers, layer_norm=layer_norm, expl_amount=expl_amount, + expl_decay=expl_decay, + expl_min=expl_min, ) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -617,7 +628,7 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - is_training (bool): whether it is in the training phase. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. 
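With `expl_decay` and `expl_min` now held by the actor, the exploration amount is derived from the current training step instead of being mutated inside the training loop. A sketch of a half-life schedule of this kind (illustrative numbers, not the library code): the amount halves every `expl_decay` steps and is clamped from below by `expl_min`.

```python
def expl_amount_at(step: int, expl_amount: float = 0.3,
                   expl_decay: float = 200_000, expl_min: float = 0.05) -> float:
    # Half-life decay: the exponent counts how many half-lives have elapsed.
    if expl_decay <= 0:
        return max(expl_amount, expl_min)
    return max(expl_amount * 0.5 ** (step / expl_decay), expl_min)

for step in (0, 200_000, 400_000):
    print(step, round(expl_amount_at(step), 3))  # 0.3, 0.15, 0.075
```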
@@ -657,7 +668,7 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -666,7 +677,7 @@ def forward( return tuple(actions), tuple(actions_dist) def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + self, actions: Sequence[Tensor], step: int = 0, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: expl_actions = [] functional_action = actions[0].argmax(dim=-1) @@ -696,7 +707,7 @@ def add_exploration_noise( sample = ( OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) ) - expl_amount = self.expl_amount + expl_amount = self._get_expl_amount(step) # If the action[0] was changed, and now it is critical, then we force to change also the other 2 actions # to satisfy the constraints of the environment if ( @@ -816,30 +827,10 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: - """ - Return the actions with a certain amount of noise for exploration. - - Args: - obs (Dict[str, Tensor]): the current observations. - is_continuous (bool): whether or not the actions are continuous. - mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). - Default to None. - - Returns: - The actions the agent has to perform. - """ - actions = self.get_greedy_action(obs, mask=mask) - expl_actions = None - if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) - self.actions = torch.cat(expl_actions, dim=-1) - return expl_actions or actions - - def get_greedy_action( + def get_actions( self, obs: Dict[str, Tensor], - is_training: bool = True, + sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -847,7 +838,7 @@ def get_greedy_action( Args: obs (Dict[str, Tensor]): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. 
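Across the actors and players above, the old `is_training` flag becomes `sample_actions`: during data collection the action is sampled from the policy distribution, while evaluation takes the most likely action. A compact sketch of the two modes for a discrete head (illustrative logits, standalone example):

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.tensor([[2.0, 0.5, -1.0]])
dist = OneHotCategoricalStraightThrough(logits=logits)

sampled = dist.rsample()  # stochastic one-hot, used when sample_actions=True
greedy = torch.nn.functional.one_hot(logits.argmax(-1), logits.shape[-1]).float()  # used otherwise
```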
@@ -866,7 +857,7 @@ def get_greedy_action( self.stochastic_state = stochastic_state.view( *stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 613ce927..ff296172 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -32,7 +32,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -579,22 +579,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -658,7 +654,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -720,13 +716,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ) + repeats = ratio(policy_step / world_size) + if update >= learning_starts and repeats > 0: + n_samples = cfg.algo.per_rank_pretrain_steps if update == learning_starts else ratio local_data = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, 
sequence_length=cfg.algo.per_rank_sequence_length, @@ -736,7 +729,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): + for i in range(n_samples): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): tcp.data.copy_(cp.data) @@ -757,17 +750,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount", actor.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -777,6 +759,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log("Params/replay_ratio", per_rank_gradient_steps * world_size / policy_step, policy_step) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -812,7 +797,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 799674fc..d64134e9 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence import gymnasium as gym import torch @@ -34,7 +34,6 @@ "State/post_entropy", "State/prior_entropy", "State/kl", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", @@ -107,7 +106,7 @@ def compute_lambda_values( @torch.no_grad() def test( - player: Union["PlayerDV2", "PlayerDV1"], + player: "PlayerDV2" | "PlayerDV1", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, @@ -143,7 +142,7 @@ def test( preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) - real_actions = player.get_greedy_action( + real_actions = player.get_actions( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index f29546dc..b7fa0065 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -18,11 +18,7 @@ from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder -from 
sheeprl.utils.distribution import ( - OneHotCategoricalStraightThroughValidateArgs, - OneHotCategoricalValidateArgs, - TruncatedNormal, -) +from sheeprl.utils.distribution import OneHotCategoricalStraightThroughValidateArgs, TruncatedNormal from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward from sheeprl.utils.utils import symlog @@ -641,29 +637,10 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs], sample_state=False )[1].reshape(1, len(reset_envs), -1) - def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: - """ - Return the actions with a certain amount of noise for exploration. - - Args: - obs (Dict[str, Tensor]): the current observations. - mask (Dict[str, Tensor], optional): the mask of the actions. - Default to None. - - Returns: - The actions the agent has to perform. - """ - actions = self.get_greedy_action(obs, mask=mask) - expl_actions = None - if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) - self.actions = torch.cat(expl_actions, dim=-1) - return expl_actions or actions - - def get_greedy_action( + def get_actions( self, obs: Dict[str, Tensor], - is_training: bool = True, + sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -671,7 +648,7 @@ def get_greedy_action( Args: obs (Dict[str, Tensor]): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. Returns: @@ -688,7 +665,7 @@ def get_greedy_action( self.stochastic_state = self.stochastic_state.view( *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions @@ -720,8 +697,6 @@ class Actor(nn.Module): then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = `1 / self.discrete`. Defaults to 0.01. - expl_amount (float): the exploration amount to use during training. - Default to 0.0. """ def __init__( @@ -737,7 +712,6 @@ def __init__( mlp_layers: int = 5, layer_norm: bool = True, unimix: float = 0.01, - expl_amount: float = 0.0, ) -> None: super().__init__() self.distribution_cfg = distribution_cfg @@ -775,18 +749,9 @@ def __init__( self.init_std = torch.tensor(init_std) self.min_std = min_std self._unimix = unimix - self._expl_amount = expl_amount - - @property - def expl_amount(self) -> float: - return self._expl_amount - - @expl_amount.setter - def expl_amount(self, amount: float): - self._expl_amount = amount def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -794,6 +759,10 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + sample_actions (bool): whether or not to sample the actions. + Default to True. 
+ mask (Dict[str, Tensor], optional): the mask to use on the actions. + Default to None. Returns: The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). @@ -821,7 +790,7 @@ def forward( std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std dist = TruncatedNormal(torch.tanh(mean), std, -1, 1, validate_args=self.distribution_cfg.validate_args) actions_dist = Independent(dist, 1, validate_args=self.distribution_cfg.validate_args) - if is_training: + if sample_actions: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -838,7 +807,7 @@ def forward( logits=self._uniform_mix(logits), validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -852,27 +821,6 @@ def _uniform_mix(self, logits: Tensor) -> Tensor: logits = probs_to_logits(probs) return logits - def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None - ) -> Sequence[Tensor]: - if self.is_continuous: - actions = torch.cat(actions, -1) - if self._expl_amount > 0.0: - actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) - expl_actions = [actions] - else: - expl_actions = [] - for act in actions: - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False) - .sample() - .to(act.device) - ) - expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) - ) - return tuple(expl_actions) - class MinedojoActor(Actor): def __init__( @@ -888,7 +836,6 @@ def __init__( mlp_layers: int = 5, layer_norm: bool = True, unimix: float = 0.01, - expl_amount: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, @@ -902,11 +849,10 @@ def __init__( mlp_layers=mlp_layers, layer_norm=layer_norm, unimix=unimix, - expl_amount=expl_amount, ) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -914,6 +860,10 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + sample_actions (bool): whether or not to sample the actions. + Default to True. + mask (Dict[str, Tensor], optional): the mask to apply to the actions. + Default to None. Returns: The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). 
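The discrete branch above builds its categorical distribution from `self._uniform_mix(logits)`. A minimal standalone sketch of that trick (illustrative input, assuming a plain softmax over the last dimension): a small fraction of the uniform distribution is blended into the action probabilities so no action ever gets exactly zero mass, which keeps log-probabilities finite during training.

```python
import torch
from torch.distributions.utils import probs_to_logits

def uniform_mix(logits: torch.Tensor, unimix: float = 0.01) -> torch.Tensor:
    probs = logits.softmax(dim=-1)
    uniform = torch.ones_like(probs) / probs.shape[-1]
    probs = (1 - unimix) * probs + unimix * uniform
    return probs_to_logits(probs)

mixed = uniform_mix(torch.tensor([[10.0, 0.0, -10.0]]))
```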
@@ -950,7 +900,7 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -958,51 +908,6 @@ def forward( functional_action = actions[0].argmax(dim=-1) # [T, B] return tuple(actions), tuple(actions_dist) - def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None - ) -> Sequence[Tensor]: - expl_actions = [] - functional_action = actions[0].argmax(dim=-1) - for i, act in enumerate(actions): - logits = torch.zeros_like(act) - # Exploratory action must respect the constraints of the environment - if mask is not None: - if i == 0: - logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf - elif i == 1: - mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits) - for t in range(functional_action.shape[0]): - for b in range(functional_action.shape[1]): - sampled_action = functional_action[t, b].item() - if sampled_action == 15: # Craft action - logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf - elif i == 2: - mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) - mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) - for t in range(functional_action.shape[0]): - for b in range(functional_action.shape[1]): - sampled_action = functional_action[t, b].item() - if sampled_action in {16, 17}: # Equip/Place action - logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf - elif sampled_action == 18: # Destroy action - logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) - ) - expl_amount = self.expl_amount - # If the action[0] was changed, and now it is critical, then we force to change also the other 2 actions - # to satisfy the constraints of the environment - if ( - i in {1, 2} - and actions[0].argmax() != expl_actions[0].argmax() - and expl_actions[0].argmax().item() in {15, 16, 17, 18} - ): - expl_amount = 2 - expl_actions.append(torch.where(torch.rand(act.shape[:1], device=self.device) < expl_amount, sample, act)) - if mask is not None and i == 0: - functional_action = expl_actions[0].argmax(dim=-1) - return tuple(expl_actions) - def build_agent( fabric: Fabric, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 0c76583b..551dac57 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -610,7 +610,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -696,7 +696,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): + for i in range(repeats): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in 
zip(critic.module.parameters(), target_critic.parameters()): diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index d18d1fb0..efd10ac3 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -30,7 +30,6 @@ "State/kl", "State/post_entropy", "State/prior_entropy", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", @@ -116,7 +115,7 @@ def test( preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) - real_actions = player.get_greedy_action( + real_actions = player.get_actions( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index 212668f5..e76536fc 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -100,8 +100,7 @@ actor: clip_gradients: 100.0 expl_amount: 0.3 expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 200000 + expl_decay: 0.0 # Actor optimizer optimizer: diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index 719d91ec..a252728c 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -11,9 +11,8 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 5 +replay_ratio: 2 learning_starts: 1000 -per_rank_gradient_steps: 1 per_rank_pretrain_steps: 100 per_rank_sequence_length: ??? @@ -111,10 +110,6 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - expl_amount: 0.0 - expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 0 # Actor optimizer optimizer: diff --git a/sheeprl/configs/exp/dreamer_v2.yaml b/sheeprl/configs/exp/dreamer_v2.yaml index 66faf0c9..5565d62d 100644 --- a/sheeprl/configs/exp/dreamer_v2.yaml +++ b/sheeprl/configs/exp/dreamer_v2.yaml @@ -63,9 +63,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/dreamer_v3.yaml b/sheeprl/configs/exp/dreamer_v3.yaml index 1c0f1419..fc51c1d2 100644 --- a/sheeprl/configs/exp/dreamer_v3.yaml +++ b/sheeprl/configs/exp/dreamer_v3.yaml @@ -62,9 +62,6 @@ metric: State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} From 88c69689cfcc14d1cec41c9fb7607789cdd73035 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 28 Mar 2024 14:53:33 +0100 Subject: [PATCH 14/51] Fix exploration actions computation on DV1 --- sheeprl/algos/dreamer_v1/agent.py | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 18 +++++++++--------- sheeprl/algos/dreamer_v2/agent.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index a8c8e28c..9de6bc4a 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -241,7 +241,7 @@ def __init__( encoder: nn.Module | _FabricModule, recurrent_model: nn.Module 
| _FabricModule, representation_model: nn.Module | _FabricModule, - actor: nn.Module | _FabricModule, + actor: DV2Actor | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index da0810f4..ba7f15d4 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -684,16 +684,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if update >= learning_starts and repeats > 0: # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=repeats, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] for i in range(repeats): - sample = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [N_samples, Seq_len, Batch_size, ...] - batch = {k: v[0].float() for k, v in sample.items()} + batch = {k: v[i].float() for k, v in sample.items()} train( fabric, world_model, diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index a369349a..63f766f3 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -566,9 +566,9 @@ def forward( def add_exploration_noise( self, actions: Sequence[Tensor], step: int = 0, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: + expl_amount = self._get_expl_amount(step) if self.is_continuous: actions = torch.cat(actions, -1) - expl_amount = self._get_expl_amount(step) if expl_amount > 0.0: actions = torch.clip(Normal(actions, expl_amount).sample(), -1, 1) expl_actions = [actions] @@ -581,7 +581,7 @@ def add_exploration_noise( .to(act.device) ) expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) + torch.where(torch.rand(act.shape[:1], device=act.device) < expl_amount, sample, act) ) return tuple(expl_actions) From a5c957c22fdbae078c2a177644b358248e7e3ec8 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 28 Mar 2024 15:02:53 +0100 Subject: [PATCH 15/51] Fix naming --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 16 ++++++++++++---- sheeprl/algos/dreamer_v2/dreamer_v2.py | 15 +++++++++------ sheeprl/algos/dreamer_v3/dreamer_v3.py | 17 ++++++++++------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index ba7f15d4..33934a4b 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -588,6 +588,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -680,19 +682,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - repeats = ratio(policy_step / world_size) - if update >= learning_starts and repeats > 0: + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: # Start training with 
timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=repeats, + n_samples=per_rank_gradient_steps, dtype=None, device=device, from_numpy=cfg.buffer.from_numpy, ) # [N_samples, Seq_len, Batch_size, ...] - for i in range(repeats): + for i in range(per_rank_gradient_steps): batch = {k: v[i].float() for k, v in sample.items()} train( fabric, @@ -706,6 +708,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator, cfg, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size if aggregator: aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) @@ -718,6 +721,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index ff296172..74238a8d 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -622,6 +622,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -717,9 +718,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - repeats = ratio(policy_step / world_size) - if update >= learning_starts and repeats > 0: - n_samples = cfg.algo.per_rank_pretrain_steps if update == learning_starts else ratio + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + n_samples = cfg.algo.per_rank_pretrain_steps if update == learning_starts else per_rank_gradient_steps local_data = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -730,7 +731,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(n_samples): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): tcp.data.copy_(cp.data) batch = {k: v[i].float() for k, v in local_data.items()} @@ -748,7 +749,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg, actions_dim, ) - per_rank_gradient_steps += 1 + cumulative_per_rank_gradient_steps += 1 train_step += world_size # Log metrics @@ -760,7 +761,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.reset() # Log replay ratio - fabric.log("Params/replay_ratio", per_rank_gradient_steps * world_size / policy_step, policy_step) + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) # Sync distributed timers if not timer.disabled: diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 551dac57..be125a73 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -579,6 +579,7 @@ def 
main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -685,19 +686,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - repeats = ratio(policy_step / world_size) - if update >= learning_starts and repeats > 0: + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=repeats, + n_samples=per_rank_gradient_steps, dtype=None, device=fabric.device, from_numpy=cfg.buffer.from_numpy, ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(repeats): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for i in range(per_rank_gradient_steps): + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) @@ -718,7 +719,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim, moments, ) - per_rank_gradient_steps += 1 + cumulative_per_rank_gradient_steps += 1 train_step += world_size # Log metrics @@ -730,7 +731,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.reset() # Log replay ratio - fabric.log("Params/replay_ratio", per_rank_gradient_steps * world_size / policy_step, policy_step) + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) # Sync distributed timers if not timer.disabled: From c36577ddce5225d21d6049a73bd318832e40516f Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 28 Mar 2024 15:57:06 +0100 Subject: [PATCH 16/51] Add replay-ratio to SAC --- sheeprl/algos/sac/sac.py | 23 +++++++++++++++++----- sheeprl/algos/sac/sac_decoupled.py | 31 ++++++++++++++++++++---------- sheeprl/configs/algo/sac.yaml | 2 +- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 73f1ca43..2c9e56ae 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -27,7 +27,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -211,6 +211,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -232,6 +237,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -282,12 +289,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # 
Train the agent - if update >= learning_starts: - training_steps = learning_starts if update == learning_starts else 1 - + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: # We sample one time to reduce the communications between processes sample = rb.sample_tensors( - batch_size=training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + batch_size=per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, dtype=None, device=device, @@ -334,6 +340,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg, policy_steps_per_update, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size # Log metrics @@ -344,6 +351,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -376,6 +388,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 44a23569..dae1a7fa 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -26,7 +26,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs @torch.inference_mode() @@ -149,6 +149,11 @@ def player( if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -170,6 +175,8 @@ def player( obs = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs @@ -219,7 +226,9 @@ def player( obs = next_obs # Send data to the training agents - if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / (fabric.world_size - 1)) + cumulative_per_rank_gradient_steps += per_rank_gradient_steps + if update >= learning_starts and per_rank_gradient_steps > 0: # Send local info to the trainers if not first_info_sent: world_collective.broadcast_object_list( @@ -228,12 +237,8 @@ def player( first_info_sent = True # Sample data to be sent to the trainers - training_steps = learning_starts if update == learning_starts else 1 sample = rb.sample_tensors( - batch_size=training_steps - * cfg.algo.per_rank_gradient_steps - * cfg.algo.per_rank_batch_size - * (fabric.world_size - 1), + batch_size=per_rank_gradient_steps * cfg.algo.per_rank_batch_size * (fabric.world_size - 1), sample_next_obs=cfg.buffer.sample_next_obs, dtype=None, device=device, @@ 
-241,8 +246,7 @@ def player( ) # chunks = {k1: [k1_chunk_1, k1_chunk_2, ...], k2: [k2_chunk_1, k2_chunk_2, ...]} chunks = { - k: v.float().split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size) - for k, v in sample.items() + k: v.float().split(per_rank_gradient_steps * cfg.algo.per_rank_batch_size) for k, v in sample.items() } # chunks = [{k1: k1_chunk_1, k2: k2_chunk_1}, {k1: k1_chunk_2, k2: k2_chunk_2}, ...] chunks = [{k: v[i] for k, v in chunks.items()} for i in range(len(chunks[next(iter(chunks.keys()))]))] @@ -269,6 +273,13 @@ def player( fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", + cumulative_per_rank_gradient_steps * (fabric.world_size - 1) / policy_step, + policy_step, + ) + # Sync timers if not timer.disabled: timer_metrics = timer.compute() @@ -401,7 +412,7 @@ def trainer( if not MetricAggregator.disabled: aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device) - # Receive data from player reagrding the: + # Receive data from player regarding the: # * update # * last_log # * last_checkpoint diff --git a/sheeprl/configs/algo/sac.yaml b/sheeprl/configs/algo/sac.yaml index 452f447e..fcc428c0 100644 --- a/sheeprl/configs/algo/sac.yaml +++ b/sheeprl/configs/algo/sac.yaml @@ -11,8 +11,8 @@ gamma: 0.99 hidden_size: 256 # Training recipe +replay_ratio: 1.0 learning_starts: 100 -per_rank_gradient_steps: 1 # Model related parameters # Actor From 0bc9f07ae7e892d9553829ed9eb54f195801e85e Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 28 Mar 2024 17:26:56 +0100 Subject: [PATCH 17/51] feat: added replay ratio to p2e algos --- notebooks/dreamer_v3_imagination.ipynb | 2 +- sheeprl/algos/dreamer_v1/agent.py | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 3 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 1 - sheeprl/algos/dreamer_v3/dreamer_v3.py | 3 +- sheeprl/algos/p2e_dv1/agent.py | 3 + sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 80 ++++++++----------- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 82 ++++++++------------ sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 66 +++++----------- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 67 +++++----------- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 67 +++++----------- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 69 ++++++---------- 12 files changed, 154 insertions(+), 291 deletions(-) diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 1a7f2a44..c58531f2 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -230,7 +230,7 @@ " mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n", " if len(mask) == 0:\n", " mask = None\n", - " real_actions = actions = player.get_exploration_action(preprocessed_obs, mask)\n", + " real_actions = actions = player.get_actions(preprocessed_obs, mask)\n", " actions = torch.cat(actions, -1).cpu().numpy()\n", " if is_continuous:\n", " real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n", diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 9de6bc4a..4885a26d 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -288,7 +288,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = 
torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action( + def get_exploration_actions( self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0 ) -> Sequence[Tensor]: """Return the actions with a certain amount of noise for exploration. diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 33934a4b..fdbe21bf 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -588,7 +588,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -622,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 74238a8d..e169450a 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -621,7 +621,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index be125a73..047b69c2 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -578,7 +578,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -699,7 +698,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(per_rank_gradient_steps): if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) batch = {k: v[i].float() for k, v in local_data.items()} diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 7d13b21d..32a6269d 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -95,6 +95,9 @@ def build_agent( activation=eval(actor_cfg.dense_act), distribution_cfg=cfg.distribution, layer_norm=False, + expl_amount=actor_cfg.expl_amount, + expl_decay=actor_cfg.expl_decay, + expl_min=actor_cfg.expl_min, ) critic_task = MLP( input_dims=latent_state_size, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py 
b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 9ad04faa..059eeed4 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -29,7 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -557,7 +557,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = {} - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -574,27 +573,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -624,6 +614,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -656,7 +647,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -716,22 +707,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - - # Train the agent - if update >= learning_starts and updates_before_training <= 0: + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: # Start training with 
timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): - sample = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [N_samples, Seq_len, Batch_size, ...] - batch = {k: v[0].float() for k, v in sample.items()} + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} train( fabric, world_model, @@ -750,25 +739,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actor_exploration_optimizer=actor_exploration_optimizer, critic_exploration_optimizer=critic_exploration_optimizer, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) + ) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -778,6 +755,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -814,7 +796,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 6da54a91..7037b46e 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # 
os.environ["MINEDOJO_HEADLESS"] = "1" @@ -201,7 +201,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -218,27 +217,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if resume_from_checkpoint and not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -268,6 +258,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -284,7 +275,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -344,24 +335,22 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + if player.actor_type != "task": player.actor_type = "task" player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): - sample = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - 
sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [N_samples, Seq_len, Batch_size, ...] - batch = {k: v[0].float() for k, v in sample.items()} + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [N_samples, Seq_len, Batch_size, ...] + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} train( fabric, world_model, @@ -374,25 +363,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): aggregator, cfg, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) + ) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -402,6 +379,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -436,7 +418,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 54fad041..f66881f9 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -29,7 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -103,7 +103,6 @@ def train( critic_exploration_optimizer (_FabricOptimizer): the optimizer of the critic for exploration. is_continuous (bool): whether or not are continuous actions. 
actions_dim (Sequence[int]): the actions dimension. - is_exploring (bool): whether the agent is exploring. """ batch_size = cfg.algo.per_rank_batch_size sequence_length = cfg.algo.per_rank_sequence_length @@ -702,7 +701,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -719,27 +717,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -770,7 +759,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -803,7 +792,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -865,13 +854,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ) + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + n_samples = cfg.algo.per_rank_pretrain_steps if update == learning_starts else per_rank_gradient_steps local_data = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -882,8 +868,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Start 
training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for i in range(n_samples): + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) for cp, tcp in zip( @@ -913,25 +899,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous=is_continuous, actions_dim=actions_dim, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -941,6 +910,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -978,7 +952,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index d00f8c69..2b33a960 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -220,7 +220,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -237,27 +236,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = 
int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -288,7 +278,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -305,7 +295,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -367,16 +357,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + if player.actor_type != "task": player.actor_type = "task" player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ) + n_samples = cfg.algo.per_rank_pretrain_steps if update == learning_starts else per_rank_gradient_steps local_data = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -387,8 +374,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for i in range(n_samples): + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), 
target_critic_task.parameters()): tcp.data.copy_(cp.data) batch = {k: v[i].float() for k, v in local_data.items()} @@ -406,25 +393,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg, actions_dim=actions_dim, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -434,6 +404,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -469,7 +444,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 106cb655..8ae603e1 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -33,7 +33,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -768,7 +768,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -785,27 +784,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = 
cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -832,7 +822,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -864,7 +854,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -938,25 +928,22 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), + n_samples=per_rank_gradient_steps, dtype=None, device=fabric.device, from_numpy=cfg.buffer.from_numpy, ) # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for i in range(per_rank_gradient_steps): + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) for k in critics_exploration.keys(): @@ -988,25 +975,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): moments_exploration=moments_exploration, moments_task=moments_task, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - 
initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -1016,6 +986,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -1061,7 +1036,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index b5b482f8..3ac28ed4 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -22,7 +22,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs @register_algorithm() @@ -215,7 +215,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -232,27 +231,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = 
Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -279,7 +269,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -295,7 +285,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -369,28 +359,25 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + if player.actor_type != "task": player.actor_type = "task" player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), + n_samples=per_rank_gradient_steps, dtype=None, device=fabric.device, from_numpy=cfg.buffer.from_numpy, ) # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for i in range(per_rank_gradient_steps): + if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) batch = {k: v[i].float() for k, v in local_data.items()} @@ -410,25 +397,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): actions_dim=actions_dim, moments=moments_task, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - 
aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -438,6 +408,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -473,7 +448,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), From b5fbe5dd8faf54a44a280930d121f633e241c0e3 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 28 Mar 2024 17:38:58 +0100 Subject: [PATCH 18/51] feat: update configs and utils of p2e algos --- sheeprl/algos/p2e_dv2/utils.py | 2 -- sheeprl/algos/p2e_dv3/utils.py | 2 -- sheeprl/configs/exp/p2e_dv2_exploration.yaml | 6 ------ sheeprl/configs/exp/p2e_dv2_finetuning.yaml | 6 ------ sheeprl/configs/exp/p2e_dv3_exploration.yaml | 6 ------ sheeprl/configs/exp/p2e_dv3_finetuning.yaml | 6 ------ 6 files changed, 28 deletions(-) diff --git a/sheeprl/algos/p2e_dv2/utils.py b/sheeprl/algos/p2e_dv2/utils.py index c717ce2d..91847673 100644 --- a/sheeprl/algos/p2e_dv2/utils.py +++ b/sheeprl/algos/p2e_dv2/utils.py @@ -29,8 +29,6 @@ "State/kl", "State/post_entropy", "State/prior_entropy", - "Params/exploration_amount_task", - "Params/exploration_amount_exploration", "Rewards/intrinsic", "Values_exploration/predicted_values", "Values_exploration/lambda_values", diff --git a/sheeprl/algos/p2e_dv3/utils.py b/sheeprl/algos/p2e_dv3/utils.py index c126e6c2..c2563336 100644 --- a/sheeprl/algos/p2e_dv3/utils.py +++ b/sheeprl/algos/p2e_dv3/utils.py @@ -28,8 +28,6 @@ "Loss/continue_loss", "Loss/ensemble_loss", "State/kl", - "Params/exploration_amount_task", - "Params/exploration_amount_exploration", "State/post_entropy", "State/prior_entropy", "Grads/world_model", diff --git a/sheeprl/configs/exp/p2e_dv2_exploration.yaml b/sheeprl/configs/exp/p2e_dv2_exploration.yaml index bae53323..3c33a758 100644 --- a/sheeprl/configs/exp/p2e_dv2_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv2_exploration.yaml @@ -51,12 +51,6 @@ metric: State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Rewards/intrinsic: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml index 1d315969..e55ca8b2 100644 --- a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml @@ -52,12 +52,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: 
${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_exploration.yaml b/sheeprl/configs/exp/p2e_dv3_exploration.yaml index 009c2d99..66b475a3 100644 --- a/sheeprl/configs/exp/p2e_dv3_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv3_exploration.yaml @@ -48,12 +48,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml index 395d47cb..d03b1495 100644 --- a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml @@ -46,12 +46,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} From 24c9352a67c76a5a041f489b55ef5851ddb6f96f Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 28 Mar 2024 22:46:20 +0100 Subject: [PATCH 19/51] Add replay-ratio to SAC-AE --- sheeprl/algos/sac_ae/sac_ae.py | 49 ++++++++++++++++++-------------- sheeprl/configs/algo/sac_ae.yaml | 5 ++-- sheeprl/data/buffers.py | 5 +++- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 09c47b56..a40c5ef3 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -31,7 +31,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -46,16 +46,12 @@ def train( decoder_optimizer: Optimizer, data: Dict[str, Tensor], aggregator: MetricAggregator | None, - update: int, + cumulative_per_rank_gradient_steps: int, cfg: Dict[str, Any], - policy_steps_per_update: int, group: Optional[CollectibleGroup] = None, ): - critic_target_network_frequency = cfg.algo.critic.target_network_frequency // policy_steps_per_update + 1 - actor_network_frequency = cfg.algo.actor.network_frequency // policy_steps_per_update + 1 - decoder_update_freq = cfg.algo.decoder.update_freq // policy_steps_per_update + 1 - normalized_obs = {} normalized_next_obs = {} + normalized_obs = {} for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: if k in cfg.algo.cnn_keys.encoder: normalized_obs[k] = data[k] / 255.0 @@ -77,12 +73,12 @@ def train( aggregator.update("Loss/value_loss", qf_loss) # Update the target networks with EMA - if update % critic_target_network_frequency == 0: + if cumulative_per_rank_gradient_steps % 
cfg.algo.critic.target_network_update_freq == 0: agent.critic_target_ema() agent.critic_encoder_target_ema() # Update the actor - if update % actor_network_frequency == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.actor.update_freq == 0: actions, logprobs = agent.get_actions_and_log_probs(normalized_obs, detach_encoder_features=True) qf_values = agent.get_q_values(normalized_obs, actions, detach_encoder_features=True) min_qf_values = torch.min(qf_values, dim=-1, keepdim=True)[0] @@ -103,7 +99,7 @@ def train( aggregator.update("Loss/alpha_loss", alpha_loss) # Update the decoder - if update % decoder_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.decoder.update_freq == 0: hidden = encoder(normalized_obs) reconstruction = decoder(hidden) reconstruction_loss = 0 @@ -284,6 +280,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -307,13 +308,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - if update < learning_starts: + if update <= learning_starts: actions = envs.action_space.sample() else: with torch.inference_mode(): @@ -363,18 +366,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # Train the agent - if update >= learning_starts - 1: - training_steps = learning_starts if update == learning_starts - 1 else 1 - + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: # We sample one time to reduce the communications between processes sample = rb.sample_tensors( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, from_numpy=cfg.buffer.from_numpy, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample) # [G*B, World, 1] - flatten_dim = 3 if fabric.world_size > 1 else 2 - gathered_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in gathered_data.items()} # [G*B*World] + ) # [1, G*B] + gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, 1, G*B] + for k, v in gathered_data.items(): + gathered_data[k] = v.flatten(start_dim=0, end_dim=2).float() # [G*B*World] len_data = len(gathered_data[next(iter(gathered_data.keys()))]) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( @@ -408,10 +410,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): decoder_optimizer, {k: v[batch_idxes] for k, v in gathered_data.items()}, aggregator, - update, + cumulative_per_rank_gradient_steps, cfg, - policy_steps_per_update, ) + cumulative_per_rank_gradient_steps += 1 train_step += world_size # Log metrics @@ -422,6 +424,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): 
fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() diff --git a/sheeprl/configs/algo/sac_ae.yaml b/sheeprl/configs/algo/sac_ae.yaml index e7dfd94b..2fc24f33 100644 --- a/sheeprl/configs/algo/sac_ae.yaml +++ b/sheeprl/configs/algo/sac_ae.yaml @@ -7,6 +7,7 @@ defaults: name: sac_ae # Training recipe +replay_ratio: 1.0 learning_starts: 1000 # Model related parameters @@ -53,12 +54,12 @@ decoder: tau: 0.01 hidden_size: 1024 actor: - network_frequency: 2 + update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 critic: - target_network_frequency: 2 + target_network_update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 73672e39..bbf10d5a 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -317,7 +317,10 @@ def sample_tensors( Dict[str, Tensor]: the sampled dictionary, containing the sampled array, one for every key, with a shape of [n_samples, batch_size, ...] """ - samples = self.sample(batch_size=batch_size, sample_next_obs=sample_next_obs, clone=clone, **kwargs) + n_samples = kwargs.pop("n_samples", 1) + samples = self.sample( + batch_size=batch_size, sample_next_obs=sample_next_obs, clone=clone, n_samples=n_samples, **kwargs + ) return { k: get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) for k, v in samples.items() } From 32b89b446afa53244c2fe24bee81781d24bb2f4f Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 29 Mar 2024 08:40:28 +0100 Subject: [PATCH 20/51] Add DrOQ replay ratio --- sheeprl/algos/droq/droq.py | 62 +++++++++++++++++++++++++++------- sheeprl/algos/sac_ae/sac_ae.py | 1 + sheeprl/configs/algo/droq.yaml | 2 +- 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 2d7f8c1a..26503249 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -26,7 +26,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -38,17 +38,22 @@ def train( rb: ReplayBuffer, aggregator: MetricAggregator | None, cfg: Dict[str, Any], + per_rank_gradient_steps: int, ): # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes sample = rb.sample_tensors( - cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, from_numpy=cfg.buffer.from_numpy, ) - critic_data = fabric.all_gather(sample) - flatten_dim = 3 if fabric.world_size > 1 else 2 - critic_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in critic_data.items()} + critic_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, G*B] + for k, v in critic_data.items(): + critic_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + critic_data[k] = critic_data[k].flatten(start_dim=0, end_dim=2) + else: + critic_data[k] = critic_data[k].flatten(start_dim=0, end_dim=1) critic_idxes = range(len(critic_data[next(iter(critic_data.keys()))])) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( @@ -68,7 +73,12 
@@ def train( # Sample a different minibatch in a distributed way to update actor and alpha parameter sample = rb.sample_tensors(cfg.algo.per_rank_batch_size, from_numpy=cfg.buffer.from_numpy) actor_data = fabric.all_gather(sample) - actor_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in actor_data.items()} + for k, v in actor_data.items(): + actor_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + actor_data[k] = actor_data[k].flatten(start_dim=0, end_dim=2) + else: + actor_data[k] = actor_data[k].flatten(start_dim=0, end_dim=1) if fabric.world_size > 1: actor_sampler: DistributedSampler = DistributedSampler( range(len(actor_data[next(iter(actor_data.keys()))])), @@ -259,6 +269,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -280,16 +295,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): o = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([o[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): - # Sample an action given the observation received by the environment - actions, _ = actor(torch.from_numpy(obs).to(device)) - actions = actions.cpu().numpy() + if update <= learning_starts: + actions = envs.action_space.sample() + else: + with torch.inference_mode(): + # Sample an action given the observation received by the environment + actions, _ = actor(torch.from_numpy(obs).to(device)) + actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) @@ -328,9 +348,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # Train the agent - if update > learning_starts: - train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg) + per_rank_gradient_steps = ratio(policy_step / world_size) + if update >= learning_starts and per_rank_gradient_steps > 0: + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + rb, + aggregator, + cfg, + per_rank_gradient_steps, + ) train_step += world_size + cumulative_per_rank_gradient_steps += per_rank_gradient_steps # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -340,6 +372,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -372,6 +409,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), + 
"ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index a40c5ef3..d8f15500 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -465,6 +465,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "alpha_optimizer": alpha_optimizer.state_dict(), "encoder_optimizer": encoder_optimizer.state_dict(), "decoder_optimizer": decoder_optimizer.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/configs/algo/droq.yaml b/sheeprl/configs/algo/droq.yaml index 82a88b2a..29d0aff1 100644 --- a/sheeprl/configs/algo/droq.yaml +++ b/sheeprl/configs/algo/droq.yaml @@ -5,7 +5,7 @@ defaults: name: droq # Training recipe -per_rank_gradient_steps: 20 +replay_ratio: 20.0 # Override from `sac` config critic: From d05788661be6445f3cd58924e09e1e020feb5e0a Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 29 Mar 2024 08:40:34 +0100 Subject: [PATCH 21/51] Fix tests --- tests/test_algos/test_algos.py | 26 +++++++++++++------------- tests/test_algos/test_cli.py | 8 ++++---- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 3d6fc985..fed41a6e 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -69,7 +69,7 @@ def test_droq(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -87,7 +87,7 @@ def test_sac(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -105,7 +105,7 @@ def test_sac_ae(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", "algo.mlp_keys.encoder=[state]", @@ -130,7 +130,7 @@ def test_sac_decoupled(standard_args, start_time): "exp=sac_decoupled", "algo.per_rank_batch_size=1", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", f"fabric.devices={os.environ['LT_DEVICES']}", f"root_dir={root_dir}", f"run_name={run_name}", @@ -239,7 +239,7 @@ def test_dreamer_v1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=2", f"env.id={env_id}", f"root_dir={root_dir}", @@ -270,7 +270,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -311,7 +311,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", 
"algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -341,7 +341,7 @@ def test_dreamer_v2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -377,7 +377,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -418,7 +418,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -448,7 +448,7 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -483,7 +483,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -526,7 +526,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "alog.replay_ratio=1", "algo.horizon=8", "env=dummy", "env.id=" + env_id, diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 701a5e7a..08938055 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -125,7 +125,7 @@ def test_resume_from_checkpoint(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -168,7 +168,7 @@ def test_resume_from_checkpoint_env_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -221,7 +221,7 @@ def test_resume_from_checkpoint_algo_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - 
"algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -276,7 +276,7 @@ def test_evaluate(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " From b9044a3b88be80dcfc5dd313863aad1bd7fbb67d Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 29 Mar 2024 08:42:19 +0100 Subject: [PATCH 22/51] Fix mispelled --- tests/test_algos/test_algos.py | 26 +++++++++++++------------- tests/test_algos/test_cli.py | 8 ++++---- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index fed41a6e..f506c88f 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -69,7 +69,7 @@ def test_droq(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -87,7 +87,7 @@ def test_sac(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -105,7 +105,7 @@ def test_sac_ae(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", "algo.mlp_keys.encoder=[state]", @@ -130,7 +130,7 @@ def test_sac_decoupled(standard_args, start_time): "exp=sac_decoupled", "algo.per_rank_batch_size=1", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", f"fabric.devices={os.environ['LT_DEVICES']}", f"root_dir={root_dir}", f"run_name={run_name}", @@ -239,7 +239,7 @@ def test_dreamer_v1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=2", f"env.id={env_id}", f"root_dir={root_dir}", @@ -270,7 +270,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -311,7 +311,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -341,7 +341,7 @@ def test_dreamer_v2(standard_args, env_id, start_time): 
"algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -377,7 +377,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -418,7 +418,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -448,7 +448,7 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -483,7 +483,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -526,7 +526,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "alog.replay_ratio=1", + "algo.replay_ratio=1", "algo.horizon=8", "env=dummy", "env.id=" + env_id, diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 08938055..31772981 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -125,7 +125,7 @@ def test_resume_from_checkpoint(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -168,7 +168,7 @@ def test_resume_from_checkpoint_env_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -221,7 +221,7 @@ def test_resume_from_checkpoint_algo_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " 
"algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " @@ -276,7 +276,7 @@ def test_evaluate(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 alog.replay_ratio=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " From 5bd7d75c217fccaa6f9e6b66697f9a2c4bc2050c Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 29 Mar 2024 09:09:19 +0100 Subject: [PATCH 23/51] Fix wrong attribute accesing --- sheeprl/algos/dreamer_v1/agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 4885a26d..24274e09 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -23,8 +23,8 @@ # In order to use the hydra.utils.get_class method, in this way the user can # specify in the configs the name of the class without having to know where # to go to retrieve the class -Actor = DV2Actor -MinedojoActor = DV2MinedojoActor +DV1Actor = DV2Actor +DV1MinedojoActor = DV2MinedojoActor class RecurrentModel(nn.Module): @@ -241,7 +241,7 @@ def __init__( encoder: nn.Module | _FabricModule, recurrent_model: nn.Module | _FabricModule, representation_model: nn.Module | _FabricModule, - actor: DV2Actor | _FabricModule, + actor: DV1Actor | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, @@ -307,7 +307,7 @@ def get_exploration_actions( """ actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask) expl_actions = None - if self.actor.expl_amount > 0: + if self.actor._expl_amount > 0: expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask) self.actions = torch.cat(expl_actions, dim=-1) return expl_actions or actions @@ -483,7 +483,7 @@ def build_agent( continue_model.apply(init_weights) if world_model_cfg.use_continues else None, ) actor_cls = hydra.utils.get_class(cfg.algo.actor.cls) - actor: Union[Actor, MinedojoActor] = actor_cls( + actor: Union[DV1Actor, DV1MinedojoActor] = actor_cls( latent_state_size=latent_state_size, actions_dim=actions_dim, is_continuous=is_continuous, From 8d94f68431240892386182621aa4709f4890cbca Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 29 Mar 2024 09:09:57 +0100 Subject: [PATCH 24/51] FIx naming and configs --- howto/configs.md | 6 ++---- sheeprl/algos/dreamer_v2/dreamer_v2.py | 2 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 +- sheeprl/algos/sac_ae/sac_ae.py | 6 +++--- sheeprl/configs/algo/dreamer_v1.yaml | 5 ++--- sheeprl/configs/algo/dreamer_v2.yaml | 2 +- sheeprl/configs/algo/dreamer_v3.yaml | 2 +- sheeprl/configs/algo/sac_ae.yaml | 6 +++--- sheeprl/configs/exp/dreamer_v1_benchmarks.yaml | 4 ++-- sheeprl/configs/exp/dreamer_v2_benchmarks.yaml | 2 +- sheeprl/configs/exp/dreamer_v2_crafter.yaml | 2 +- sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml | 4 ++-- 
sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml | 2 +- .../exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml | 2 +- ...L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml | 2 +- ..._dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml | 2 +- sheeprl/configs/exp/sac_benchmarks.yaml | 2 +- tests/test_algos/test_algos.py | 7 ++----- tests/test_algos/test_cli.py | 8 ++++---- 23 files changed, 35 insertions(+), 41 deletions(-) diff --git a/howto/configs.md b/howto/configs.md index ea49072b..4d9aa09a 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -139,10 +139,8 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 16 learning_starts: 65536 per_rank_pretrain_steps: 1 -per_rank_gradient_steps: 1 per_rank_sequence_length: ??? # Encoder and decoder keys @@ -266,7 +264,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 1 + per_rank_target_network_update_freq: 1 tau: 0.02 bins: 255 clip_gradients: 100.0 @@ -410,7 +408,7 @@ buffer: algo: learning_starts: 1024 total_steps: 100000 - train_every: 1 + dense_units: 512 mlp_layers: 2 world_model: diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index e169450a..c7f2cd09 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -730,7 +730,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(n_samples): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): tcp.data.copy_(cp.data) batch = {k: v[i].float() for k, v in local_data.items()} diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 047b69c2..4b889ac1 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -697,7 +697,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(per_rank_gradient_steps): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index f66881f9..b8335ba2 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -869,7 +869,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(n_samples): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) for cp, tcp in zip( diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py 
b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 2b33a960..ff9e6328 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -375,7 +375,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(n_samples): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) batch = {k: v[i].float() for k, v in local_data.items()} diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 8ae603e1..9a609478 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -942,7 +942,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(per_rank_gradient_steps): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 3ac28ed4..04cb0d23 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -376,7 +376,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Start training with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(per_rank_gradient_steps): - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index d8f15500..d5d6783f 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -73,12 +73,12 @@ def train( aggregator.update("Loss/value_loss", qf_loss) # Update the target networks with EMA - if cumulative_per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: agent.critic_target_ema() agent.critic_encoder_target_ema() # Update the actor - if cumulative_per_rank_gradient_steps % cfg.algo.actor.update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.actor.per_rank_update_freq == 0: actions, logprobs = agent.get_actions_and_log_probs(normalized_obs, detach_encoder_features=True) qf_values = agent.get_q_values(normalized_obs, actions, detach_encoder_features=True) min_qf_values = torch.min(qf_values, dim=-1, keepdim=True)[0] @@ -99,7 +99,7 @@ def train( aggregator.update("Loss/alpha_loss", 
alpha_loss) # Update the decoder - if cumulative_per_rank_gradient_steps % cfg.algo.decoder.update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.decoder.per_rank_update_freq == 0: hidden = encoder(normalized_obs) reconstruction = decoder(hidden) reconstruction_loss = 0 diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index e76536fc..53402948 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -11,9 +11,8 @@ horizon: 15 name: dreamer_v1 # Training recipe -train_every: 1000 +replay_ratio: 0.1 learning_starts: 5000 -per_rank_gradient_steps: 100 per_rank_sequence_length: ??? # Encoder and decoder keys @@ -91,7 +90,7 @@ world_model: # Actor actor: - cls: sheeprl.algos.dreamer_v1.agent.Actor + cls: sheeprl.algos.dreamer_v1.agent.DV1Actor min_std: 0.1 init_std: 5.0 dense_act: ${algo.dense_act} diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index a252728c..19830fe4 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -123,7 +123,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 100 + per_rank_target_network_update_freq: 100 clip_gradients: 100.0 # Critic optimizer diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 315baa7a..f227a45b 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -135,7 +135,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 1 + per_rank_target_network_update_freq: 1 tau: 0.02 bins: 255 clip_gradients: 100.0 diff --git a/sheeprl/configs/algo/sac_ae.yaml b/sheeprl/configs/algo/sac_ae.yaml index 2fc24f33..dde8668f 100644 --- a/sheeprl/configs/algo/sac_ae.yaml +++ b/sheeprl/configs/algo/sac_ae.yaml @@ -39,7 +39,7 @@ encoder: # Decoder decoder: l2_lambda: 1e-6 - update_freq: 1 + per_rank_update_freq: 1 cnn_channels_multiplier: ${algo.cnn_channels_multiplier} dense_units: ${algo.dense_units} mlp_layers: ${algo.mlp_layers} @@ -54,12 +54,12 @@ decoder: tau: 0.01 hidden_size: 1024 actor: - update_freq: 2 + per_rank_update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 critic: - target_network_update_freq: 2 + per_rank_target_network_update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 diff --git a/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml index efa04f19..12f29b1f 100644 --- a/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml @@ -26,10 +26,10 @@ buffer: # Algorithm algo: learning_starts: 1024 - train_every: 16 + dense_units: 8 mlp_layers: 1 - per_rank_gradient_steps: 1 + world_model: stochastic_size: 4 encoder: diff --git a/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml index 27bcb515..cfa2977a 100644 --- a/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml @@ -27,7 +27,7 @@ buffer: algo: learning_starts: 1024 per_rank_pretrain_steps: 1 - train_every: 16 + dense_units: 8 mlp_layers: world_model: diff --git a/sheeprl/configs/exp/dreamer_v2_crafter.yaml b/sheeprl/configs/exp/dreamer_v2_crafter.yaml index db7e5249..caff4d8b 100644 --- a/sheeprl/configs/exp/dreamer_v2_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v2_crafter.yaml @@ -40,7 +40,7 @@ mlp_keys: # 
Algorithm algo: gamma: 0.999 - train_every: 5 + layer_norm: True learning_starts: 10000 per_rank_pretrain_steps: 1 diff --git a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml index dc8b146b..4c731b77 100644 --- a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml @@ -26,12 +26,12 @@ buffer: # Algorithm algo: gamma: 0.995 - train_every: 16 + total_steps: 200000000 learning_starts: 200000 per_rank_batch_size: 32 per_rank_pretrain_steps: 1 - per_rank_gradient_steps: 1 + world_model: use_continues: True kl_free_nats: 0.0 diff --git a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml index 35f8fb2f..666f93dc 100644 --- a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml @@ -25,7 +25,7 @@ buffer: # Algorithm algo: - train_every: 2 + replay_ratio: 0.5 cnn_keys: encoder: diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml index 7003be3a..704a07c5 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml @@ -40,7 +40,7 @@ algo: mlp_keys: encoder: [] learning_starts: 1024 - train_every: 2 + # Metric metric: diff --git a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml index df1f356b..6eda91af 100644 --- a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml @@ -35,7 +35,7 @@ buffer: # Algorithm algo: learning_starts: 131072 - train_every: 1 + dense_units: 768 mlp_layers: 4 world_model: diff --git a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml index 8dcad491..94d4d95a 100644 --- a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml @@ -37,7 +37,7 @@ buffer: # Algorithm algo: learning_starts: 65536 - train_every: 1 + dense_units: 768 mlp_layers: 4 world_model: diff --git a/sheeprl/configs/exp/sac_benchmarks.yaml b/sheeprl/configs/exp/sac_benchmarks.yaml index 63dc2086..b3ce9a7d 100644 --- a/sheeprl/configs/exp/sac_benchmarks.yaml +++ b/sheeprl/configs/exp/sac_benchmarks.yaml @@ -15,7 +15,7 @@ env: algo: name: sac learning_starts: 100 - per_rank_gradient_steps: 1 + per_rank_batch_size: 512 # # If you want to run this benchmark with older versions, # you need to comment the test function in the `./sheeprl/algos/ppo/ppo.py` file. 
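For orientation, the patches above replace the old `train_every` / `per_rank_gradient_steps` recipe with a single `replay_ratio` handled by the `Ratio` helper: it is called once per environment step with the per-rank policy step (`ratio(policy_step / world_size)`), returns how many gradient steps to run, and is checkpointed through `state_dict()` / `load_state_dict()`. The snippet below is only an illustrative sketch reconstructed from those call sites; it is not the actual `sheeprl.utils.utils.Ratio` implementation, and the internal bookkeeping (attribute names, rounding of the fractional remainder) is assumed.

# Hypothetical sketch of a replay-ratio tracker matching the call sites in these diffs.
# Assumption: gradient steps accrue as policy_steps * ratio, with the fractional
# remainder carried over between calls.
class Ratio:
    def __init__(self, ratio: float):
        self._ratio = ratio        # desired gradient steps per policy step
        self._prev_steps = 0.0     # policy steps already accounted for
        self._budget = 0.0         # fractional gradient steps carried over

    def __call__(self, policy_steps: float) -> int:
        # Accrue budget for the policy steps taken since the last call
        self._budget += (policy_steps - self._prev_steps) * self._ratio
        self._prev_steps = policy_steps
        gradient_steps = int(self._budget)
        self._budget -= gradient_steps
        return gradient_steps

    def state_dict(self) -> dict:
        return {"prev_steps": self._prev_steps, "budget": self._budget}

    def load_state_dict(self, state: dict) -> "Ratio":
        self._prev_steps = state["prev_steps"]
        self._budget = state["budget"]
        return self

# Example usage mirroring the training loops touched by these patches:
#   ratio = Ratio(cfg.algo.replay_ratio)
#   per_rank_gradient_steps = ratio(policy_step / world_size)

Under this reading, checkpointing `ratio.state_dict()` in place of the old `expl_decay_steps` entry is what lets a resumed run keep the same effective replay ratio.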
diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index f506c88f..132d7a9d 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -114,8 +114,8 @@ def test_sac_ae(standard_args, start_time): "algo.hidden_size=4", "algo.dense_units=4", "algo.cnn_channels_multiplier=2", - "algo.actor.network_frequency=1", - "algo.decoder.update_freq=1", + "algo.actor.per_rank_update_freq=1", + "algo.decoder.per_rank_update_freq=1", ] with mock.patch.object(sys, "argv", args): @@ -459,7 +459,6 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", ] @@ -494,7 +493,6 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "buffer.checkpoint=True", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", @@ -538,7 +536,6 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", ] diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 31772981..0e95871f 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -130,7 +130,7 @@ def test_resume_from_checkpoint(): "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -173,7 +173,7 @@ def test_resume_from_checkpoint_env_error(): "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -226,7 +226,7 @@ def test_resume_from_checkpoint_algo_error(): "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -281,7 +281,7 @@ def test_evaluate(): "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, 
check=True, From e5dd8fd9a5eb49c6e8a35c51b7fe7a84159c0f3f Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 29 Mar 2024 15:04:33 +0100 Subject: [PATCH 25/51] feat: add terminated and truncated to dreamer, p2e and ppo algos --- sheeprl/algos/a2c/a2c.py | 30 ++++++++++--- sheeprl/algos/dreamer_v1/dreamer_v1.py | 25 ++++++----- sheeprl/algos/dreamer_v2/dreamer_v2.py | 37 +++++++++------- sheeprl/algos/dreamer_v3/dreamer_v3.py | 8 ++-- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 25 ++++++----- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 19 +++++--- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 41 ++++++++++-------- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 27 +++++++----- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 43 +++++++++++-------- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 31 ++++++++----- sheeprl/algos/ppo/ppo.py | 8 ++-- sheeprl/algos/ppo/ppo_decoupled.py | 8 ++-- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 6 ++- sheeprl/configs/algo/dreamer_v3_S.yaml | 12 +++--- .../exp/dreamer_v3_dmc_walker_walk.yaml | 2 + 15 files changed, 202 insertions(+), 120 deletions(-) diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index bca58a05..8d45d0ed 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -243,11 +243,31 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) - - dones = np.logical_or(done, truncated) - dones = dones.reshape(cfg.env.num_envs, -1) - rewards = rewards.reshape(cfg.env.num_envs, -1) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + truncated_envs = np.nonzero(truncated)[0] + if len(truncated_envs) > 0: + real_next_obs = { + k: torch.empty( + len(truncated_envs), + *observation_space[k].shape, + dtype=torch.float32, + device=device, + ) + for k in obs_keys + } + for i, truncated_env in enumerate(truncated_envs): + for k, v in info["final_observation"][truncated_env].items(): + torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) + if k in cfg.algo.cnn_keys.encoder: + torch_v = torch_v.view(-1, *v.shape[-2:]) + torch_v = torch_v / 255.0 - 0.5 + real_next_obs[k][i] = torch_v + _, _, vals = player(real_next_obs) + rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( + rewards[truncated_envs].shape + ) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data step_data["dones"] = dones[np.newaxis] diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index fdbe21bf..52434e39 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -192,9 +192,9 @@ def train( 1, validate_args=validate_args, ) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - qc = continue_targets = None + qc = continues_targets = None # compute the distributions of the states (posteriors and priors) # it is necessary an Independent distribution because @@ -224,7 +224,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, qc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -582,7 +582,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in 
cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -629,8 +630,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -659,7 +662,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -671,13 +675,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index c7f2cd09..c18f6e65 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -185,9 +185,9 @@ def train( 1, validate_args=validate_args, ) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - pc = continue_targets = None + pc = continues_targets = None # Reshape posterior and prior logits to shape [T, B, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -207,7 +207,7 @@ def train( cfg.algo.world_model.kl_free_avg, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.discount_scale_factor, validate_args=validate_args, ) @@ -282,8 +282,8 @@ def train( predicted_rewards = world_model.reward_model(imagined_trajectories) if cfg.algo.world_model.use_continues and world_model.continue_model: continues = logits_to_probs(world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - 
data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma @@ -612,12 +612,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -663,9 +665,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -693,7 +697,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -705,14 +710,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(dones_idxes) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 96389821..9b1bfee7 100644 
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -171,7 +171,7 @@ def train( 1, validate_args=validate_args, ) - continue_targets = 1 - data["terminated"] + continues_targets = 1 - data["terminated"] # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -191,7 +191,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, validate_args=validate_args, ) @@ -255,8 +255,8 @@ def train( 1, validate_args=validate_args, ).mode - true_done = (1 - data["terminated"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) # Estimate lambda-values lambda_values = compute_lambda_values( diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 059eeed4..ee9c8a68 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -148,9 +148,9 @@ def train( 1, validate_args=validate_args, ) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - qc = continue_targets = None + qc = continues_targets = None posteriors_dist = Independent( Normal(posteriors_mean, posteriors_std, validate_args=validate_args), 1, validate_args=validate_args ) @@ -169,7 +169,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, qc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -608,7 +608,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -655,8 +656,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -685,7 +688,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -697,13 +701,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} 
for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 7037b46e..1bbfc2ee 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -252,7 +252,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -283,8 +284,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -313,7 +316,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -325,13 +329,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, 
d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index b8335ba2..1e593100 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -167,9 +167,9 @@ def train( 1, validate_args=validate_args, ) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - pc = continue_targets = None + pc = continues_targets = None # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -189,7 +189,7 @@ def train( cfg.algo.world_model.kl_free_avg, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.discount_scale_factor, validate_args=validate_args, ) @@ -278,8 +278,8 @@ def train( if cfg.algo.world_model.use_continues and world_model.continue_model: continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(intrinsic_reward.detach()) * cfg.algo.gamma @@ -381,8 +381,8 @@ def train( predicted_rewards = world_model.reward_model(imagined_trajectories) if cfg.algo.world_model.use_continues and world_model.continue_model: continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma @@ -750,12 +750,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -801,9 +803,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + 
real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -831,7 +835,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -843,14 +848,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(dones_idxes) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index ff9e6328..50e3114c 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -269,12 +269,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -304,9 +306,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and 
buffer_type == "episode": dones = np.ones_like(dones) @@ -334,7 +338,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -346,14 +351,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(dones_idxes) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 9a609478..5e171383 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -167,7 +167,7 @@ def train( 1, validate_args=validate_args, ) - continue_targets = 1 - data["dones"] + continues_targets = 1 - data["terminated"] # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -187,7 +187,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -274,8 +274,8 @@ def train( 1, validate_args=validate_args, ).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) if critic["reward_type"] == "intrinsic": # Predict intrinsic reward @@ -418,8 +418,8 @@ def train( 1, validate_args=validate_args, ).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) lambda_values = compute_lambda_values( predicted_rewards[1:], @@ -817,9 +817,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, 
cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() cumulative_per_rank_gradient_steps = 0 @@ -866,16 +867,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -907,7 +913,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -916,15 +923,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 04cb0d23..e1ede0a0 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -264,9 +264,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = 
obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() cumulative_per_rank_gradient_steps = 0 @@ -297,16 +298,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -338,7 +344,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -347,15 +354,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 
9d73f41d..ba2aa447 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -283,7 +283,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -303,8 +303,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v _, _, _, vals = player(real_next_obs) - rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( + rewards[truncated_envs].shape + ) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index c506f66e..866228d7 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -206,7 +206,7 @@ def player( actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -226,8 +226,10 @@ def player( torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v _, _, _, vals = agent(real_next_obs) - rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( + rewards[truncated_envs].shape + ) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index b16a5459..94399b6f 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -308,7 +308,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch_actions.cpu().numpy() # Single environment step - next_obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + next_obs, rewards, terminated, truncated, info = envs.step( + real_actions.reshape(envs.action_space.shape) + ) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -334,7 +336,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) vals = player.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) + dones = np.logical_or(terminated, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1) diff --git 
a/sheeprl/configs/algo/dreamer_v3_S.yaml b/sheeprl/configs/algo/dreamer_v3_S.yaml index c8f36153..ff067597 100644 --- a/sheeprl/configs/algo/dreamer_v3_S.yaml +++ b/sheeprl/configs/algo/dreamer_v3_S.yaml @@ -2,14 +2,14 @@ defaults: - dreamer_v3_XL - _self_ -dense_units: 512 -mlp_layers: 2 +dense_units: 2 +mlp_layers: 1 world_model: encoder: - cnn_channels_multiplier: 32 + cnn_channels_multiplier: 1 recurrent_model: - recurrent_state_size: 512 + recurrent_state_size: 2 transition_model: - hidden_size: 512 + hidden_size: 2 representation_model: - hidden_size: 512 \ No newline at end of file + hidden_size: 2 \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index 9964c3c3..c3657f83 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -39,6 +39,8 @@ algo: encoder: [] learning_starts: 1024 replay_ratio: 0.5 + per_rank_batch_size: 2 + per_rank_sequence_lenght: 4 # Metric metric: From fdd4579631b6756cbf6fa86f6fec6575a64a5b65 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 29 Mar 2024 15:05:00 +0100 Subject: [PATCH 26/51] fix: dmc wrapper --- sheeprl/envs/dmc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sheeprl/envs/dmc.py b/sheeprl/envs/dmc.py index c2643993..3ea0825d 100644 --- a/sheeprl/envs/dmc.py +++ b/sheeprl/envs/dmc.py @@ -220,13 +220,14 @@ def step( action = self._convert_action(action) time_step = self.env.step(action) reward = time_step.reward or 0.0 - done = time_step.last() obs = self._get_obs(time_step) self.current_state = _flatten_obs(time_step.observation) extra = {} extra["discount"] = time_step.discount extra["internal_state"] = self.env.physics.get_state().copy() - return obs, reward, done, False, extra + truncated = time_step.last() and time_step.discount == 1 + terminated = False if time_step.first() else time_step.last() and time_step.discount == 0 + return obs, reward, terminated, truncated, extra def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None From a2a26909ed3073dc19437cd2b53852b9a1358379 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 29 Mar 2024 16:18:00 +0100 Subject: [PATCH 27/51] feat: update algos to split terminated from truncated --- sheeprl/algos/droq/droq.py | 8 ++++---- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +- sheeprl/algos/sac/sac.py | 8 ++++---- sheeprl/algos/sac/sac_decoupled.py | 6 +++--- sheeprl/algos/sac_ae/sac_ae.py | 8 ++++---- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 26503249..b78468b3 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -97,7 +97,7 @@ def train( next_target_qf_value = agent.get_next_target_q_values( critic_batch_data["next_observations"], critic_batch_data["rewards"], - critic_batch_data["dones"], + critic_batch_data["terminated"], cfg.algo.gamma, ) for qf_value_idx in range(agent.num_critics): @@ -310,8 +310,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment actions, _ = actor(torch.from_numpy(obs).to(device)) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) if 
cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -340,7 +339,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.sample_next_obs: step_data["next_observations"] = real_next_obs[np.newaxis] step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) - step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) rb.add(step_data, validate_args=cfg.buffer.validate_args) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 94399b6f..4054261d 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -335,7 +335,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): tuple(s[:, truncated_envs, ...] for s in states), ) vals = player.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() - rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) + rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(terminated, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 2c9e56ae..6e856931 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -45,7 +45,7 @@ def train( ): # Update the soft-critic next_target_qf_value = agent.get_next_target_q_values( - data["next_observations"], data["rewards"], data["dones"], cfg.algo.gamma + data["next_observations"], data["rewards"], data["terminated"], cfg.algo.gamma ) qf_values = agent.get_q_values(data["observations"], data["actions"]) qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics) @@ -253,9 +253,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -277,7 +276,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 ) - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["actions"] = actions[np.newaxis] step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index dae1a7fa..c0b4541a 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -190,9 +190,8 @@ def player( torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, 
rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -214,7 +213,8 @@ def player( [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 ) - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["actions"] = actions[np.newaxis] step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index d5d6783f..714e72bf 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -62,7 +62,7 @@ def train( # Update the soft-critic next_target_qf_value = agent.get_next_target_q_values( - normalized_next_obs, data["rewards"], data["dones"], cfg.algo.gamma + normalized_next_obs, data["rewards"], data["terminated"], cfg.algo.gamma ) qf_values = agent.get_q_values(normalized_obs, data["actions"]) qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics) @@ -324,8 +324,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -357,7 +356,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): 1, cfg.env.num_envs, -1, *step_data[f"next_{k}"].shape[-2:] ) - step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) rb.add(step_data, validate_args=cfg.buffer.validate_args) From 74bfb6bbf82a59d23fe69ca30a678648e5f926c3 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 29 Mar 2024 17:59:54 +0100 Subject: [PATCH 28/51] fix: crafter and diambra wrappers --- sheeprl/envs/crafter.py | 2 +- sheeprl/envs/diambra.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sheeprl/envs/crafter.py b/sheeprl/envs/crafter.py index ae5a94cc..f0c6f71d 100644 --- a/sheeprl/envs/crafter.py +++ b/sheeprl/envs/crafter.py @@ -50,7 +50,7 @@ def _convert_obs(self, obs: np.ndarray) -> Dict[str, np.ndarray]: def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: obs, reward, done, info = self.env.step(action) - return self._convert_obs(obs), reward, done, False, info + return self._convert_obs(obs), reward, done and info["discount"] == 0, done and info["discount"] != 0, info def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None 
diff --git a/sheeprl/envs/diambra.py b/sheeprl/envs/diambra.py index 2e773ac1..002ed0ba 100644 --- a/sheeprl/envs/diambra.py +++ b/sheeprl/envs/diambra.py @@ -123,9 +123,9 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A if self._action_type == "discrete" and isinstance(action, np.ndarray): action = action.squeeze() action = action.item() - obs, reward, done, truncated, infos = self.env.step(action) + obs, reward, terminated, truncated, infos = self.env.step(action) infos["env_domain"] = "DIAMBRA" - return self._convert_obs(obs), reward, done or infos.get("env_done", False), truncated, infos + return self._convert_obs(obs), reward, terminated or infos.get("env_done", False), truncated, infos def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]: return self.env.render() From 05e43705a27bbe8f8e9cbe57c2bcfccae6665016 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Sat, 30 Mar 2024 10:41:15 +0100 Subject: [PATCH 29/51] feat: replace done with truncated key in when the buffer is added to the checkpoint --- sheeprl/utils/callback.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py index eaa099c6..19e588a7 100644 --- a/sheeprl/utils/callback.py +++ b/sheeprl/utils/callback.py @@ -105,14 +105,14 @@ def _ckpt_rb( """ if isinstance(rb, ReplayBuffer): # clone the true done - state = rb["dones"][(rb._pos - 1) % rb.buffer_size, :].copy() + state = rb["truncated"][(rb._pos - 1) % rb.buffer_size, :].copy() # substitute the last done with all True values (all the environment are truncated) - rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = True + rb["truncated"][(rb._pos - 1) % rb.buffer_size, :] = 1 elif isinstance(rb, EnvIndependentReplayBuffer): state = [] for b in rb.buffer: - state.append(b["dones"][(b._pos - 1) % b.buffer_size, :].copy()) - b["dones"][(b._pos - 1) % b.buffer_size, :] = True + state.append(b["truncated"][(b._pos - 1) % b.buffer_size, :].copy()) + b["truncated"][(b._pos - 1) % b.buffer_size, :] = 1 elif isinstance(rb, EpisodeBuffer): # remove open episodes from the buffer because the state of the environment is not saved state = rb._open_episodes @@ -133,10 +133,10 @@ def _experiment_consistent_rb( """ if isinstance(rb, ReplayBuffer): # reinsert the true dones in the buffer - rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = state + rb["truncated"][(rb._pos - 1) % rb.buffer_size, :] = state elif isinstance(rb, EnvIndependentReplayBuffer): for i, b in enumerate(rb.buffer): - b["dones"][(b._pos - 1) % b.buffer_size, :] = state[i] + b["truncated"][(b._pos - 1) % b.buffer_size, :] = state[i] elif isinstance(rb, EpisodeBuffer): # reinsert the open episodes to continue the training rb._open_episodes = state From 87c9098b3d0f02bbd20afbc9768f78db93c0ab37 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Sat, 30 Mar 2024 17:20:51 +0100 Subject: [PATCH 30/51] feat: added truncated/terminated to minedojo environment --- sheeprl/envs/minedojo.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 00c0837f..4098864e 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -245,6 +245,8 @@ def step(self, action: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict action[3] = 12 obs, reward, done, info = self.env.step(action) + terminated = done and not info.get("TimeLimit.truncated", False) + truncated = 
done and info.get("TimeLimit.truncated", False) self._pos = { "x": float(obs["location_stats"]["pos"][0]), "y": float(obs["location_stats"]["pos"][1]), @@ -252,17 +254,19 @@ def step(self, action: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict "pitch": float(obs["location_stats"]["pitch"].item()), "yaw": float(obs["location_stats"]["yaw"].item()), } - info = { - "life_stats": { - "life": float(obs["life_stats"]["life"].item()), - "oxygen": float(obs["life_stats"]["oxygen"].item()), - "food": float(obs["life_stats"]["food"].item()), - }, - "location_stats": copy.deepcopy(self._pos), - "action": a.tolist(), - "biomeid": float(obs["location_stats"]["biome_id"].item()), - } - return self._convert_obs(obs), reward, done, False, info + info.update( + { + "life_stats": { + "life": float(obs["life_stats"]["life"].item()), + "oxygen": float(obs["life_stats"]["oxygen"].item()), + "food": float(obs["life_stats"]["food"].item()), + }, + "location_stats": copy.deepcopy(self._pos), + "action": a.tolist(), + "biomeid": float(obs["location_stats"]["biome_id"].item()), + } + ) + return self._convert_obs(obs), reward, terminated, truncated, info def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None From e137a38a1bb376255019fbbea25aa16b01fe7297 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 13:53:18 +0200 Subject: [PATCH 31/51] feat: added terminated/truncated to minerl and super mario bros envs --- sheeprl/configs/env/minerl.yaml | 5 ++++- sheeprl/configs/env/minerl_obtain_diamond.yaml | 12 ++++++++++++ .../configs/env/minerl_obtain_iron_pickaxe.yaml | 12 ++++++++++++ sheeprl/envs/minerl.py | 3 +-- sheeprl/envs/minerl_envs/navigate.py | 9 ++++++--- sheeprl/envs/minerl_envs/obtain.py | 17 ++++++++++------- sheeprl/envs/super_mario_bros.py | 3 ++- 7 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 sheeprl/configs/env/minerl_obtain_diamond.yaml create mode 100644 sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml diff --git a/sheeprl/configs/env/minerl.yaml b/sheeprl/configs/env/minerl.yaml index 74e564a6..c3849c4f 100644 --- a/sheeprl/configs/env/minerl.yaml +++ b/sheeprl/configs/env/minerl.yaml @@ -5,6 +5,9 @@ defaults: # Override from `minecraft` config id: custom_navigate action_repeat: 1 +max_episode_steps: 12000 +reward_as_observation: True +num_envs: 4 # Wrapper to be instantiated wrapper: @@ -17,7 +20,7 @@ wrapper: - ${env.max_pitch} seed: null break_speed_multiplier: ${env.break_speed_multiplier} - multihot_inventory: True + multihot_inventory: False sticky_attack: ${env.sticky_attack} sticky_jump: ${env.sticky_jump} dense: True diff --git a/sheeprl/configs/env/minerl_obtain_diamond.yaml b/sheeprl/configs/env/minerl_obtain_diamond.yaml new file mode 100644 index 00000000..55699873 --- /dev/null +++ b/sheeprl/configs/env/minerl_obtain_diamond.yaml @@ -0,0 +1,12 @@ +defaults: + - minerl + - _self_ + +id: custom_obtain_diamond +action_repeat: 1 +max_episode_steps: 36000 +num_envs: 16 + +wrapper: + multihot_inventory: True + dense: False \ No newline at end of file diff --git a/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml b/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml new file mode 100644 index 00000000..78125a32 --- /dev/null +++ b/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml @@ -0,0 +1,12 @@ +defaults: + - minerl + - _self_ + +id: custom_obtain_iron_pickaxe +action_repeat: 1 +max_episode_steps: 36000 +num_envs: 16 + +wrapper: + multihot_inventory: True + dense: False \ No newline at end of file diff 
--git a/sheeprl/envs/minerl.py b/sheeprl/envs/minerl.py index 35047bdf..338d384f 100644 --- a/sheeprl/envs/minerl.py +++ b/sheeprl/envs/minerl.py @@ -85,7 +85,7 @@ def __init__( self._height = height self._width = width self._pitch_limits = pitch_limits - self._sticky_attack = sticky_attack + self._sticky_attack = 0 if break_speed_multiplier > 1 else sticky_attack self._sticky_jump = sticky_jump self._sticky_attack_counter = 0 self._sticky_jump_counter = 0 @@ -303,7 +303,6 @@ def step(self, actions: np.ndarray) -> Tuple[Dict[str, Any], SupportsFloat, bool "pitch": next_pitch, "yaw": next_yaw, } - info = {} return self._convert_obs(obs), reward, done, False, info def reset( diff --git a/sheeprl/envs/minerl_envs/navigate.py b/sheeprl/envs/minerl_envs/navigate.py index f02723d5..619d17e6 100644 --- a/sheeprl/envs/minerl_envs/navigate.py +++ b/sheeprl/envs/minerl_envs/navigate.py @@ -9,7 +9,6 @@ import minerl.herobraine.hero.handlers as handlers from minerl.herobraine.hero.handler import Handler -from minerl.herobraine.hero.mc import MS_PER_STEP from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec @@ -22,7 +21,11 @@ def __init__(self, dense, extreme, *args, **kwargs): suffix += "Dense" if dense else "" name = "CustomMineRLNavigate{}-v0".format(suffix) self.dense, self.extreme = dense, extreme - super().__init__(name, *args, max_episode_steps=6000, **kwargs) + + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between terminated and truncated + kwargs.pop("max_episode_steps", None) + super().__init__(name, *args, max_episode_steps=None, **kwargs) def is_from_folder(self, folder: str) -> bool: return folder == "navigateextreme" if self.extreme else folder == "navigate" @@ -60,7 +63,7 @@ def create_server_world_generators(self) -> List[Handler]: return [handlers.DefaultWorldGenerator(force_reset=True)] def create_server_quit_producers(self) -> List[Handler]: - return [handlers.ServerQuitFromTimeUp(NAVIGATE_STEPS * MS_PER_STEP), handlers.ServerQuitWhenAnyAgentFinishes()] + return [handlers.ServerQuitWhenAnyAgentFinishes()] def create_server_decorators(self) -> List[Handler]: return [ diff --git a/sheeprl/envs/minerl_envs/obtain.py b/sheeprl/envs/minerl_envs/obtain.py index ce5f9d1c..3302024e 100644 --- a/sheeprl/envs/minerl_envs/obtain.py +++ b/sheeprl/envs/minerl_envs/obtain.py @@ -9,7 +9,6 @@ from minerl.herobraine.hero import handlers from minerl.herobraine.hero.handler import Handler -from minerl.herobraine.hero.mc import MS_PER_STEP from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec @@ -28,7 +27,7 @@ def __init__( dense, reward_schedule: List[Dict[str, Union[str, int, float]]], *args, - max_episode_steps=6000, + max_episode_steps=None, **kwargs, ): # 6000 for obtain iron (5 mins) @@ -138,10 +137,7 @@ def create_server_world_generators(self) -> List[Handler]: return [handlers.DefaultWorldGenerator(force_reset=True)] def create_server_quit_producers(self) -> List[Handler]: - return [ - handlers.ServerQuitFromTimeUp(time_limit_ms=self.max_episode_steps * MS_PER_STEP), - handlers.ServerQuitWhenAnyAgentFinishes(), - ] + return [handlers.ServerQuitWhenAnyAgentFinishes()] def create_server_decorators(self) -> List[Handler]: return [] @@ -175,6 +171,9 @@ def determine_success_from_rewards(self, rewards: list) -> bool: class CustomObtainDiamond(CustomObtain): def __init__(self, dense, *args, **kwargs): + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between 
terminated and truncated + kwargs.pop("max_episode_steps", None) super(CustomObtainDiamond, self).__init__( *args, target_item="diamond", @@ -193,7 +192,7 @@ def __init__(self, dense, *args, **kwargs): dict(type="iron_pickaxe", amount=1, reward=256), dict(type="diamond", amount=1, reward=1024), ], - max_episode_steps=18000, + max_episode_steps=None, **kwargs, ) @@ -251,6 +250,9 @@ def get_docstring(self): class CustomObtainIronPickaxe(CustomObtain): def __init__(self, dense, *args, **kwargs): + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between terminated and truncated + kwargs.pop("max_episode_steps", None) super(CustomObtainIronPickaxe, self).__init__( *args, target_item="iron_pickaxe", @@ -268,6 +270,7 @@ def __init__(self, dense, *args, **kwargs): dict(type="iron_ingot", amount=1, reward=128), dict(type="iron_pickaxe", amount=1, reward=256), ], + max_episode_steps=None, **kwargs, ) diff --git a/sheeprl/envs/super_mario_bros.py b/sheeprl/envs/super_mario_bros.py index 7fe41a13..0c1043fa 100644 --- a/sheeprl/envs/super_mario_bros.py +++ b/sheeprl/envs/super_mario_bros.py @@ -55,7 +55,8 @@ def step(self, action: np.ndarray | int) -> Tuple[Any, SupportsFloat, bool, bool action = action.squeeze().item() obs, reward, done, info = self.env.step(action) converted_obs = {"rgb": obs.copy()} - return converted_obs, reward, done, False, info + is_timelimit = info.get("time", False) + return converted_obs, reward, done and not is_timelimit, done and is_timelimit, info def render(self) -> RenderFrame | list[RenderFrame] | None: rendered_frame: np.ndarray | None = self.env.render(mode=self.render_mode) From 64d3c81092dbc8ab16086ffe8b2a9c1b5f73db6e Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:37:05 +0200 Subject: [PATCH 32/51] docs: update howto --- howto/learn_in_minedojo.md | 4 ++++ howto/learn_in_minerl.md | 3 +++ howto/logs_and_checkpoints.md | 1 - howto/select_observations.md | 1 + howto/work_with_steps.md | 10 ++++++---- 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md index b9fabbe1..c9c3ee8a 100644 --- a/howto/learn_in_minedojo.md +++ b/howto/learn_in_minedojo.md @@ -62,6 +62,10 @@ Moreover, we restrict the look-up/down actions between `min_pitch` and `max_pitc In addition, we added the forward action when the agent selects one of the following actions: `jump`, `sprint`, and `sneak`. Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` cli arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. +> [!NOTE] +> +> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`. + For more information about the MineDojo action space, check [here](https://docs.minedojo.org/sections/core_api/action_space.html). > [!NOTE] diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md index 80edff01..28389df6 100644 --- a/howto/learn_in_minerl.md +++ b/howto/learn_in_minerl.md @@ -51,6 +51,9 @@ Finally, we added sticky actions for the `jump` and `attack` actions. You can se > > The action repeat in the Minecraft environments is set to 1, indeed, It makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). 
+> [!NOTE] > +> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`. + ## Headless machines If you work on a headless machine, you need to software renderer. We recommend to adopt one of the following solutions: diff --git a/howto/logs_and_checkpoints.md b/howto/logs_and_checkpoints.md index 57a2d8e9..7c72d4c3 100644 --- a/howto/logs_and_checkpoints.md +++ b/howto/logs_and_checkpoints.md @@ -122,7 +122,6 @@ AGGREGATOR_KEYS = { "State/post_entropy", "State/prior_entropy", "State/kl", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", diff --git a/howto/select_observations.md b/howto/select_observations.md index e984ee9f..61a6188c 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -8,6 +8,7 @@ In the first case, the observations are returned in the form of python dictionar ### Both observations The algorithms that can work with both image and vector observations are specified in [Table 1](../README.md) in the README, and are reported here: +* A2C * PPO * PPO Recurrent * SAC-AE diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md index ec5a7147..5bc62a8c 100644 --- a/howto/work_with_steps.md +++ b/howto/work_with_steps.md @@ -22,11 +22,13 @@ The hyper-parameters that refer to the *policy steps* are: * `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms. * `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached a `terminated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to: `env_steps = max_steps * action_repeat`$. * `learning_starts`: how many policy steps the agent has to perform before starting the training. -* `train_every`: how many policy steps the agent has to perform between one training and the next. ## Gradient steps -A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes, indeed, if there are $n$ parallel processes, `n * gradient_steps` calls to the *train* method will be executed. +A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes, indeed, if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method will be executed. The hyper-parameters which refer to the *gradient steps* are: -* `algo.per_rank_gradient_steps`: the number of gradient steps per rank to perform in a single iteration. -* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration. \ No newline at end of file +* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration. + +> [!NOTE] +> +> The `replay_ratio` is the ratio between the gradient steps and the policy steps played by the agent. 
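To make the replay-ratio note above concrete, here is a tiny illustrative sketch; the numbers and variable names are invented for the example and are not taken from the sheeprl codebase:

```python
# Hypothetical numbers, only to illustrate the ratio described in the note above.
replay_ratio = 0.5          # desired gradient steps per policy step
policy_steps_played = 1024  # policy steps collected so far by the agent

# Roughly this many train() calls (gradient steps) are expected to have run:
expected_gradient_steps = int(policy_steps_played * replay_ratio)
assert expected_gradient_steps == 512
```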
\ No newline at end of file From 2e156f3fe119f945028d98f5c1f0ff405571a268 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:37:15 +0200 Subject: [PATCH 33/51] fix: minedojo wrapper --- sheeprl/envs/minedojo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 4098864e..673f9207 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -69,7 +69,7 @@ def __init__( self._pos = kwargs.get("start_position", None) self._break_speed_multiplier = kwargs.get("break_speed_multiplier", 100) self._start_pos = copy.deepcopy(self._pos) - self._sticky_attack = sticky_attack + self._sticky_attack = 0 if self._break_speed_multiplier > 1 else sticky_attack self._sticky_jump = sticky_jump self._sticky_attack_counter = 0 self._sticky_jump_counter = 0 From 0167fd5c6a3d848989a07d76084d0289dada08aa Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:38:23 +0200 Subject: [PATCH 34/51] docs: update --- howto/learn_in_minedojo.md | 1 + howto/learn_in_minerl.md | 2 ++ 2 files changed, 3 insertions(+) diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md index c9c3ee8a..1eea5a3f 100644 --- a/howto/learn_in_minedojo.md +++ b/howto/learn_in_minedojo.md @@ -69,6 +69,7 @@ Finally, we added sticky actions for the `jump` and `attack` actions. You can se For more information about the MineDojo action space, check [here](https://docs.minedojo.org/sections/core_api/action_space.html). > [!NOTE] +> > Since the MineDojo environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously. > > The action repeat in the Minecraft environments is set to 1, indeed, It makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md index 28389df6..50d61a33 100644 --- a/howto/learn_in_minerl.md +++ b/howto/learn_in_minerl.md @@ -47,11 +47,13 @@ In addition, we added the forward action when the agent selects one of the follo Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. > [!NOTE] +> > Since the MineRL environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously. > > The action repeat in the Minecraft environments is set to 1, indeed, It makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). > [!NOTE] +> > The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`. 
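To make the sticky-action notes above concrete, here is an illustrative, counter-based sketch (assumed names and values, not the actual wrapper code), including the clamp applied when the break speed multiplier is greater than 1:

```python
class StickyAction:
    """Illustrative sketch of a counter-based sticky action (assumed names/values)."""

    def __init__(self, repeat: int) -> None:
        self.repeat = repeat  # 0 disables stickiness
        self.counter = 0

    def __call__(self, selected: bool) -> bool:
        if self.repeat == 0:
            return selected
        if selected:
            self.counter = self.repeat  # restart the window whenever the agent picks the action
        if self.counter > 0:
            self.counter -= 1
            return True  # keep the action active for `repeat` consecutive steps
        return False


break_speed_multiplier = 100  # assumed wrapper default
sticky_attack = StickyAction(repeat=0 if break_speed_multiplier > 1 else 30)
```

The clamp exists because, with a high break speed multiplier, a single attack is usually enough to break a block, so forcing repeated attacks would only cause unintended breaking.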
## Headless machines From 09e051eb99a2b05fb20c1a6e1e1bdb42a5f4a717 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:43:10 +0200 Subject: [PATCH 35/51] fix: minedojo --- sheeprl/envs/minedojo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 673f9207..44746fca 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -245,8 +245,9 @@ def step(self, action: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict action[3] = 12 obs, reward, done, info = self.env.step(action) - terminated = done and not info.get("TimeLimit.truncated", False) - truncated = done and info.get("TimeLimit.truncated", False) + is_timelimit = info.get("TimeLimit.truncated", False) + terminated = done and not is_timelimit + truncated = done and is_timelimit self._pos = { "x": float(obs["location_stats"]["pos"][0]), "y": float(obs["location_stats"]["pos"][1]), From dacd42577e2cd9f8b9238ad6562a181168c4b7f4 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:44:05 +0200 Subject: [PATCH 36/51] update dependencies --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b4b7362b..6577c495 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,8 +81,8 @@ atari = [ "gymnasium[accept-rom-license]==0.29.*", "gymnasium[other]==0.29.*", ] -minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"] -minerl = ["setuptools==66.0.0", "minerl==0.4.4"] +minedojo = ["minedojo==0.1", "importlib_resources==5.12.0", "gym==0.21.0"] +minerl = ["setuptools==66.0.0", "minerl==0.4.4", "gym==0.19.0"] diambra = ["diambra==0.0.17", "diambra-arena==2.2.6"] crafter = ["crafter==1.8.3"] mlflow = ["mlflow==2.11.1"] From f2557a3c6db40cea1b9ca0267442f625213d13d1 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:49:08 +0200 Subject: [PATCH 37/51] fix: minedojo --- sheeprl/envs/minedojo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 44746fca..17cf7b0d 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -83,7 +83,6 @@ def __init__( task_id=id, image_size=(height, width), world_seed=seed, - generate_world_type="default", fast_reset=True, **kwargs, ) From 5bf50dd1e5184a4b06131c2b74b301c671519033 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 14:58:54 +0200 Subject: [PATCH 38/51] fix: dv3 small configs --- sheeprl/configs/algo/dreamer_v3_S.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sheeprl/configs/algo/dreamer_v3_S.yaml b/sheeprl/configs/algo/dreamer_v3_S.yaml index ff067597..c8f36153 100644 --- a/sheeprl/configs/algo/dreamer_v3_S.yaml +++ b/sheeprl/configs/algo/dreamer_v3_S.yaml @@ -2,14 +2,14 @@ defaults: - dreamer_v3_XL - _self_ -dense_units: 2 -mlp_layers: 1 +dense_units: 512 +mlp_layers: 2 world_model: encoder: - cnn_channels_multiplier: 1 + cnn_channels_multiplier: 32 recurrent_model: - recurrent_state_size: 2 + recurrent_state_size: 512 transition_model: - hidden_size: 2 + hidden_size: 512 representation_model: - hidden_size: 2 \ No newline at end of file + hidden_size: 512 \ No newline at end of file From f58a3c21a08ac9593850c5e6cd793d1991574ea0 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 15:50:22 +0200 Subject: [PATCH 39/51] fix: episode buffer and tests --- sheeprl/data/buffers.py | 24 ++- tests/test_data/test_episode_buffer.py | 234 
+++++++++++++++---------- 2 files changed, 160 insertions(+), 98 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index bbf10d5a..8c51c9b1 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -923,8 +923,10 @@ def add( last_key = current_key last_batch_shape = current_batch_shape - if "dones" not in data: - raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}") + if "terminated" not in data and "truncated" not in data: + raise RuntimeError( + f"The episode must contain the `terminated` and the `truncated` keys, got: {data.keys()}" + ) if env_idxes is not None and (np.array(env_idxes) >= self._n_envs).any(): raise ValueError( @@ -937,7 +939,7 @@ def add( for i, env in enumerate(env_idxes): # Take the data from a single environment env_data = {k: v[:, i] for k, v in data.items()} - done = env_data["dones"] + done = np.logical_or(env_data["terminated"], env_data["truncated"]) # Take episode ends episode_ends = done.nonzero()[0].tolist() # If there is not any done, then add the data to the respective open episode @@ -954,12 +956,15 @@ def add( episode = {k: env_data[k][start : stop + 1] for k in env_data.keys()} # If the episode length is greater than zero, then add it to the open episode # of the corresponding environment. - if len(episode["dones"]) > 0: + if len(np.logical_or(episode["terminated"], episode["truncated"])) > 0: self._open_episodes[env].append(episode) start = stop + 1 # If the open episode is not empty and the last element is a done, then save the episode # in the buffer and clear the open episode - if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1: + should_save = len(self._open_episodes[env]) > 0 and np.logical_or( + self._open_episodes[env][-1]["terminated"][-1], self._open_episodes[env][-1]["truncated"][-1] + ) + if should_save: self._save_episode(self._open_episodes[env]) self._open_episodes[env] = [] @@ -974,9 +979,10 @@ def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapAr episode = {k: np.concatenate(v, axis=0) for k, v in episode.items()} # Control the validity of the episode - ep_len = episode["dones"].shape[0] - if len(episode["dones"].nonzero()[0]) != 1 or episode["dones"][-1] != 1: - raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(episode['dones']))}") + ends = np.logical_or(episode["terminated"], episode["truncated"]) + ep_len = ends.shape[0] + if len(ends.nonzero()[0]) != 1 or ends[-1] != 1: + raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(ends))}") if ep_len < self._minimum_episode_length: raise RuntimeError( f"Episode too short (at least {self._minimum_episode_length} steps), got: {ep_len} steps" @@ -1076,7 +1082,7 @@ def sample( samples_per_eps.update({f"next_{k}": [] for k in self._obs_keys}) for i, n in enumerate(nsample_per_eps): if n > 0: - ep_len = valid_episodes[i]["dones"].shape[0] + ep_len = np.logical_or(valid_episodes[i]["terminated"], valid_episodes[i]["truncated"]).shape[0] if sample_next_obs: ep_len -= 1 # Define the maximum index that can be sampled in the episodes diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index 967abfbf..12fdba37 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -35,51 +35,51 @@ def test_episode_buffer_add_episodes(): buf_size = 30 sl = 5 n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated",) rb = 
EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep2 = {"dones": np.zeros((sl + 5, n_envs, 1))} - ep3 = {"dones": np.zeros((sl + 10, n_envs, 1))} - ep4 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 - ep4["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep2 = {"terminated": np.zeros((sl + 5, n_envs, 1)), "truncated": np.zeros((sl + 5, n_envs, 1))} + ep3 = {"terminated": np.zeros((sl + 10, n_envs, 1)), "truncated": np.zeros((sl + 10, n_envs, 1))} + ep4 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["terminated"][-1] = 1 + ep2["truncated"][-1] = 1 + ep3["terminated"][-1] = 1 + ep4["truncated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) rb.add(ep4) assert rb.full - assert (rb._buf[-1]["dones"] == ep4["dones"][:, 0]).all() - assert (rb._buf[0]["dones"] == ep2["dones"][:, 0]).all() + assert (rb._buf[-1]["terminated"] == ep4["terminated"][:, 0]).all() + assert (rb._buf[0]["terminated"] == ep2["terminated"][:, 0]).all() def test_episode_buffer_add_single_dict(): buf_size = 5 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["truncated"][-1] = 1 rb.add(ep1) assert rb.full for env in range(n_envs): - assert (rb._buf[0]["dones"] == ep1["dones"][:, env]).all() + assert (rb._buf[0]["terminated"] == ep1["terminated"][:, env]).all() def test_episode_buffer_error_add(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated",) rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) ep1 = torch.zeros(sl, n_envs, 1) with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays, but `data` is of type"): rb.add(ep1, validate_args=True) - ep2 = {"dones": torch.zeros((sl, n_envs, 1))} + ep2 = {"terminated": torch.zeros((sl, n_envs, 1)), "truncated": torch.zeros((sl, n_envs, 1))} with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays. 
Found key"): rb.add(ep2, validate_args=True) @@ -87,22 +87,22 @@ def test_episode_buffer_error_add(): with pytest.raises(ValueError, match="The `data` replay buffer must be not None"): rb.add(ep3, validate_args=True) - ep4 = {"dones": np.zeros((1,))} + ep4 = {"terminated": np.zeros((1,)), "truncated": np.zeros((1,))} with pytest.raises(RuntimeError, match=r"`data` must have at least 2: \[sequence_length, n_envs"): rb.add(ep4, validate_args=True) - obs_keys = ("dones", "obs") + obs_keys = ("terminated", "obs") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep5 = {"dones": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))} + ep5 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))} with pytest.raises(RuntimeError, match="Every array in `data` must be congruent in the first 2 dimensions"): rb.add(ep5, validate_args=True) ep6 = {"obs": np.zeros((sl, 1, 6))} - with pytest.raises(RuntimeError, match="The episode must contain the `dones` key"): + with pytest.raises(RuntimeError, match="The episode must contain the `terminated`"): rb.add(ep6, validate_args=True) - ep7 = {"dones": np.zeros((sl, 1, 1))} - ep7["dones"][-1] = 1 + ep7 = {"terminated": np.zeros((sl, 1, 1)), "truncated": np.zeros((sl, 1, 1))} + ep7["terminated"][-1] = 1 with pytest.raises(ValueError, match="The indices of the environment must be integers in"): rb.add(ep7, validate_args=True, env_idxes=[10]) @@ -111,9 +111,9 @@ def test_add_only_for_some_envs(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs - 2, 1))} + ep1 = {"terminated": np.zeros((sl, n_envs - 2, 1)), "truncated": np.zeros((sl, n_envs - 2, 1))} rb.add(ep1, env_idxes=[0, 3]) assert len(rb._open_episodes[0]) > 0 assert len(rb._open_episodes[1]) == 0 @@ -125,16 +125,24 @@ def test_save_episode(): buf_size = 100 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append( + { + "terminated": np.zeros(chunk_dim), + "truncated": np.zeros(chunk_dim), + } + ) + ep_chunks[-1]["terminated"][-1] = 1 rb._save_episode(ep_chunks) assert len(rb._buf) == 1 assert ( - np.concatenate([e["dones"] for e in rb.buffer], axis=0) - == np.concatenate([c["dones"] for c in ep_chunks], axis=0) + np.concatenate([e["terminated"] for e in rb.buffer], axis=0) + == np.concatenate([c["terminated"] for c in ep_chunks], axis=0) ).all() @@ -142,29 +150,35 @@ def test_save_episode_errors(): buf_size = 100 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) with pytest.raises(RuntimeError, match="Invalid episode, an empty sequence is given"): rb._save_episode([]) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 - ep_chunks[0]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append({"terminated": np.zeros(chunk_dim), "truncated": np.zeros(chunk_dim)}) + ep_chunks[-1]["terminated"][-1] = 1 + 
ep_chunks[0]["truncated"][-1] = 1 with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[0]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append({"terminated": np.zeros(chunk_dim), "truncated": np.zeros(chunk_dim)}) + ep_chunks[0]["terminated"][-1] = 1 with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.ones((1, 1))}] + ep_chunks = [{"terminated": np.ones((1, 1)), "truncated": np.zeros((1, 1))}] with pytest.raises(RuntimeError, match="Episode too short"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.zeros((110, 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 + ep_chunks = [{"terminated": np.zeros((110, 1)), "truncated": np.zeros((110, 1))} for _ in range(8)] + ep_chunks[-1]["truncated"][-1] = 1 with pytest.raises(RuntimeError, match="Episode too long"): rb._save_episode(ep_chunks) @@ -173,14 +187,19 @@ def test_episode_buffer_sample_one_element(): buf_size = 5 sl = 5 n_envs = 1 - obs_keys = ("dones", "a") + obs_keys = ("terminated", "truncated", "a") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep = {"dones": np.zeros((sl, n_envs, 1)), "a": np.random.rand(sl, n_envs, 1)} - ep["dones"][-1] = 1 + ep = { + "terminated": np.zeros((sl, n_envs, 1)), + "truncated": np.zeros((sl, n_envs, 1)), + "a": np.random.rand(sl, n_envs, 1), + } + ep["terminated"][-1] = 1 rb.add(ep) sample = rb.sample(1, n_samples=1, sequence_length=sl) assert rb.full - assert (sample["dones"][0, :, 0] == ep["dones"][:, 0]).all() + assert (sample["terminated"][0, :, 0] == ep["terminated"][:, 0]).all() + assert (sample["truncated"][0, :, 0] == ep["truncated"][:, 0]).all() assert (sample["a"][0, :, 0] == ep["a"][:, 0]).all() @@ -188,42 +207,58 @@ def test_episode_buffer_sample_shapes(): buf_size = 30 sl = 2 n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep = {"dones": np.zeros((sl, n_envs, 1))} - ep["dones"][-1] = 1 + ep = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep["truncated"][-1] = 1 rb.add(ep) sample = rb.sample(3, n_samples=2, sequence_length=sl) - assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) + assert sample["terminated"].shape[:-1] == tuple([2, sl, 3]) + assert sample["truncated"].shape[:-1] == tuple([2, sl, 3]) sample = rb.sample(3, n_samples=2, sequence_length=sl, clone=True) - assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) + assert sample["terminated"].shape[:-1] == tuple([2, sl, 3]) + assert sample["truncated"].shape[:-1] == tuple([2, sl, 3]) def test_episode_buffer_sample_more_episodes(): buf_size = 100 sl = 15 n_envs = 1 - obs_keys = ("dones", "a") + obs_keys = ("terminated", "truncated", "a") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((40, n_envs, 1)), "a": np.ones((40, n_envs, 1)) * -1} - ep2 = {"dones": np.zeros((45, n_envs, 1)), "a": np.ones((45, n_envs, 1)) * -2} - ep3 = {"dones": np.zeros((50, n_envs, 1)), "a": np.ones((50, n_envs, 1)) * -3} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 + ep1 = { + "terminated": np.zeros((40, n_envs, 1)), + "a": np.ones((40, n_envs, 1)) * -1, + "truncated": np.zeros((40, n_envs, 
1)), + } + ep2 = { + "terminated": np.zeros((45, n_envs, 1)), + "a": np.ones((45, n_envs, 1)) * -2, + "truncated": np.zeros((45, n_envs, 1)), + } + ep3 = { + "terminated": np.zeros((50, n_envs, 1)), + "a": np.ones((50, n_envs, 1)) * -3, + "truncated": np.zeros((50, n_envs, 1)), + } + ep1["terminated"][-1] = 1 + ep2["truncated"][-1] = 1 + ep3["terminated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) samples = rb.sample(50, n_samples=5, sequence_length=sl) - assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) + assert samples["terminated"].shape[:-1] == tuple([5, sl, 50]) + assert samples["truncated"].shape[:-1] == tuple([5, sl, 50]) samples = {k: np.moveaxis(samples[k], 2, 1).reshape(-1, sl, 1) for k in obs_keys} - for i in range(len(samples["dones"])): + for i in range(len(samples["terminated"])): assert ( np.isin(samples["a"][i], -1).all() or np.isin(samples["a"][i], -2).all() or np.isin(samples["a"][i], -3).all() ) - assert len(samples["dones"][i].nonzero()[0]) == 0 or samples["dones"][i][-1] == 1 + assert len(samples["terminated"][i].nonzero()[0]) == 0 or samples["terminated"][i][-1] == 1 + assert len(samples["truncated"][i].nonzero()[0]) == 0 or samples["truncated"][i][-1] == 1 def test_episode_buffer_error_sample(): @@ -236,7 +271,7 @@ def test_episode_buffer_error_sample(): rb.sample(-1, n_samples=2) with pytest.raises(ValueError, match="The number of samples must be greater than 0"): rb.sample(2, n_samples=-1) - ep1 = {"dones": np.zeros((15, 1, 1))} + ep1 = {"terminated": np.zeros((15, 1, 1)), "truncated": np.zeros((15, 1, 1))} rb.add(ep1) with pytest.raises(RuntimeError, match="No valid episodes has been added to the buffer"): rb.sample(2, n_samples=2, sequence_length=20) @@ -247,34 +282,38 @@ def test_episode_buffer_prioritize_ends(): buf_size = 100 sl = 15 n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, prioritize_ends=True) - ep1 = {"dones": np.zeros((15, n_envs, 1))} - ep2 = {"dones": np.zeros((25, n_envs, 1))} - ep3 = {"dones": np.zeros((30, n_envs, 1))} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((15, n_envs, 1)), "truncated": np.zeros((15, n_envs, 1))} + ep2 = {"terminated": np.zeros((25, n_envs, 1)), "truncated": np.zeros((25, n_envs, 1))} + ep3 = {"terminated": np.zeros((30, n_envs, 1)), "truncated": np.zeros((30, n_envs, 1))} + ep1["truncated"][-1] = 1 + ep2["terminated"][-1] = 1 + ep3["truncated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) samples = rb.sample(50, n_samples=5, sequence_length=sl) - assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) - assert np.isin(samples["dones"], 1).any() > 0 + assert samples["terminated"].shape[:-1] == tuple([5, sl, 50]) + assert samples["truncated"].shape[:-1] == tuple([5, sl, 50]) + assert np.isin(samples["terminated"], 1).any() > 0 + assert np.isin(samples["truncated"], 1).any() > 0 def test_sample_next_obs(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["terminated"][-1] = 1 rb.add(ep1) sample = rb.sample(10, True, n_samples=5, sequence_length=sl - 1) - assert "next_dones" in sample - assert (sample["next_dones"][:, -1] == 1).all() + assert "next_terminated" in sample + assert 
"next_truncated" in sample + assert (sample["next_terminated"][:, -1] == 1).all() + assert not (sample["next_truncated"][:, -1] == 1).any() def test_memmap_episode_buffer(): @@ -282,7 +321,7 @@ def test_memmap_episode_buffer(): bs = 4 sl = 4 n_envs = 1 - obs_keys = ("dones", "observations") + obs_keys = ("terminated", "truncated", "observations") with pytest.raises( ValueError, match="The buffer is set to be memory-mapped but the `memmap_dir` attribute is None", @@ -293,11 +332,13 @@ def test_memmap_episode_buffer(): for _ in range(buf_size // bs): ep = { "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((bs, n_envs, 1)), + "terminated": np.zeros((bs, n_envs, 1)), + "truncated": np.zeros((bs, n_envs, 1)), } - ep["dones"][-1] = 1 + ep["truncated"][-1] = 1 rb.add(ep) - assert isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["terminated"], MemmapArray) + assert isinstance(rb._buf[-1]["truncated"], MemmapArray) assert isinstance(rb._buf[-1]["observations"], MemmapArray) assert rb.is_memmap del rb @@ -309,7 +350,7 @@ def test_memmap_to_file_episode_buffer(): bs = 5 sl = 4 n_envs = 1 - obs_keys = ("dones", "observations") + obs_keys = ("terminated", "truncated", "observations") memmap_dir = "test_episode_buffer" rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir=memmap_dir) for i in range(4): @@ -319,15 +360,19 @@ def test_memmap_to_file_episode_buffer(): bs = 5 ep = { "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((bs, n_envs, 1)), + "terminated": np.zeros((bs, n_envs, 1)), + "truncated": np.zeros((bs, n_envs, 1)), } - ep["dones"][-1] = 1 + ep["terminated"][-1] = 1 rb.add(ep) del ep - assert isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["terminated"], MemmapArray) + assert isinstance(rb._buf[-1]["truncated"], MemmapArray) assert isinstance(rb._buf[-1]["observations"], MemmapArray) - memmap_dir = os.path.dirname(rb._buf[-1]["dones"].filename) - assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) + memmap_dir = os.path.dirname(rb._buf[-1]["terminated"].filename) + memmap_dir = os.path.dirname(rb._buf[-1]["truncated"].filename) + assert os.path.exists(os.path.join(memmap_dir, "terminated.memmap")) + assert os.path.exists(os.path.join(memmap_dir, "truncated.memmap")) assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) assert rb.is_memmap for ep in rb.buffer: @@ -342,8 +387,12 @@ def test_sample_tensors(): buf_size = 10 n_envs = 1 rb = EpisodeBuffer(buf_size, n_envs) - td = {"observations": np.arange(8).reshape(-1, 1, 1), "dones": np.zeros((8, 1, 1))} - td["dones"][-1] = 1 + td = { + "observations": np.arange(8).reshape(-1, 1, 1), + "terminated": np.zeros((8, 1, 1)), + "truncated": np.zeros((8, 1, 1)), + } + td["truncated"][-1] = 1 rb.add(td) s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3, sequence_length=5) assert isinstance(s["observations"], torch.Tensor) @@ -363,9 +412,10 @@ def test_sample_tensor_memmap(): rb = EpisodeBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) td = { "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((buf_size, n_envs, 1)), + "terminated": np.zeros((buf_size, n_envs, 1)), + "truncated": np.zeros((buf_size, n_envs, 1)), } - td["dones"][-1] = 1 + td["terminated"][-1] = 1 rb.add(td) sample = rb.sample_tensors(10, False, n_samples=3, 
sequence_length=5) assert isinstance(sample["observations"], torch.Tensor) @@ -378,8 +428,14 @@ def test_add_rb(): buf_size = 10 n_envs = 3 rb = ReplayBuffer(buf_size, n_envs) - rb.add({"dones": np.zeros((buf_size, n_envs, 1)), "a": np.random.rand(buf_size, n_envs, 5)}) - rb["dones"][-1] = 1 + rb.add( + { + "terminated": np.zeros((buf_size, n_envs, 1)), + "truncated": np.zeros((buf_size, n_envs, 1)), + "a": np.random.rand(buf_size, n_envs, 5), + } + ) + rb["truncated"][-1] = 1 epb = EpisodeBuffer(buf_size * n_envs, minimum_episode_length=2, n_envs=n_envs) epb.add(rb) assert (rb["a"][:, 0] == epb._buf[0]["a"]).all() From d19a8ba1ac4aa58cf497e9db9fe07b4c14684bdd Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 17:16:53 +0200 Subject: [PATCH 40/51] feat: added possibility to choose layernorm and kwargs --- sheeprl/algos/dreamer_v3/agent.py | 237 ++++++++++++++------------- sheeprl/algos/dreamer_v3/utils.py | 2 +- sheeprl/configs/algo/dreamer_v3.yaml | 29 ++-- sheeprl/models/models.py | 21 ++- sheeprl/utils/model.py | 17 ++ 5 files changed, 173 insertions(+), 133 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index a39e6d84..0faff0be 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import gymnasium import hydra @@ -26,7 +26,7 @@ from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder from sheeprl.utils.fabric import get_single_device_fabric -from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward +from sheeprl.utils.model import LayerNormChannelLastFP32, LayerNormFP32, ModuleType, cnn_forward from sheeprl.utils.utils import symlog @@ -43,8 +43,10 @@ class CNNEncoder(nn.Module): image_size (Tuple[int, int]): the image size as (Height,Width). channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels will be [1, 2, 4, 8] * `channels_multiplier`. - layer_norm (bool, optional): whether to apply the layer normalization. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormChannelLastFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. activation (ModuleType, optional): the activation function. Defaults to nn.SiLU. stages (int, optional): how many stages for the CNN. 
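The pattern introduced by this patch, shown as a standalone sketch (example values, not the defaults): the normalization layer becomes a class plus keyword arguments, so it can be chosen from the config, and picking `nn.Identity` both disables normalization and re-enables the bias of the preceding layer.

```python
import torch
from torch import nn

# Example values; in the configs they come from `cnn_layer_norm.cls` / `cnn_layer_norm.kw`
# and are resolved with hydra.utils.get_class.
layer_norm_cls = nn.LayerNorm
layer_norm_kw = {"eps": 1e-3}

norm = layer_norm_cls(normalized_shape=64, **layer_norm_kw)
linear = nn.Linear(32, 64, bias=layer_norm_cls == nn.Identity)  # bias only when there is no norm

x = torch.randn(8, 32)
print(norm(linear(x)).shape)  # torch.Size([8, 64])
```

The same `cls`/`kw` pair is threaded through the MLP encoder, the decoders, the recurrent model, the actor and the critic in the rest of the patch.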
@@ -56,7 +58,8 @@ def __init__( input_channels: Sequence[int], image_size: Tuple[int, int], channels_multiplier: int, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, activation: ModuleType = nn.SiLU, stages: int = 4, ) -> None: @@ -68,14 +71,12 @@ def __init__( input_channels=self.input_dim[0], hidden_channels=(torch.tensor([2**i for i in range(stages)]) * channels_multiplier).tolist(), cnn_layer=nn.Conv2d, - layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, + layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity}, activation=activation, - norm_layer=[LayerNormChannelLast for _ in range(stages)] if layer_norm else None, - norm_args=( - [{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)] - if layer_norm - else None - ), + norm_layer=[layer_norm_cls] * stages, + norm_args=[ + {"normalized_shape": (2**i) * channels_multiplier, **layer_norm_kw} for i in range(stages) + ], ), nn.Flatten(-3, -1), ) @@ -100,8 +101,10 @@ class MLPEncoder(nn.Module): Defaults to 4. dense_units (int, optional): the dimension of every mlp. Defaults to 512. - layer_norm (bool, optional): whether to apply the layer normalization. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. activation (ModuleType, optional): the activation function after every layer. Defaults to nn.SiLU. symlog_inputs (bool, optional): whether to squash the input with the symlog function. @@ -114,7 +117,8 @@ def __init__( input_dims: Sequence[int], mlp_layers: int = 4, dense_units: int = 512, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, activation: ModuleType = nn.SiLU, symlog_inputs: bool = True, ) -> None: @@ -126,11 +130,9 @@ def __init__( None, [dense_units] * mlp_layers, activation=activation, - layer_args={"bias": not layer_norm}, - norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=( - [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None - ), + layer_args={"bias": layer_norm_cls == nn.Identity}, + norm_layer=layer_norm_cls, + norm_args={"normalized_shape": dense_units, **layer_norm_kw}, ) self.output_dim = dense_units self.symlog_inputs = symlog_inputs @@ -159,8 +161,10 @@ class CNNDecoder(nn.Module): image_size (Tuple[int, int]): the final image size. activation (nn.Module, optional): the activation function. Defaults to nn.SiLU. - layer_norm (bool, optional): whether to apply the layer normalization. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormChannelLastFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. stages (int): how many stages in the CNN decoder. 
""" @@ -173,7 +177,8 @@ def __init__( cnn_encoder_output_dim: int, image_size: Tuple[int, int], activation: nn.Module = nn.SiLU, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, stages: int = 4, ) -> None: super().__init__() @@ -193,20 +198,17 @@ def __init__( + [self.output_dim[0]], cnn_layer=nn.ConvTranspose2d, layer_args=[ - {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm} for _ in range(stages - 1) + {"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity} + for _ in range(stages - 1) ] + [{"kernel_size": 4, "stride": 2, "padding": 1}], activation=[activation for _ in range(stages - 1)] + [None], - norm_layer=[LayerNormChannelLast for _ in range(stages - 1)] + [None] if layer_norm else None, - norm_args=( - [ - {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3} - for i in range(stages - 1) - ] - + [None] - if layer_norm - else None - ), + norm_layer=[layer_norm_cls for _ in range(stages - 1)] + [None], + norm_args=[ + {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, **layer_norm_kw} + for i in range(stages - 1) + ] + + [None], ), ) @@ -229,8 +231,10 @@ class MLPDecoder(nn.Module): Defaults to 4. dense_units (int, optional): the dimension of every mlp. Defaults to 512. - layer_norm (bool, optional): whether to apply the layer normalization. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. activation (ModuleType, optional): the activation function after every layer. Defaults to nn.SiLU. """ @@ -243,7 +247,8 @@ def __init__( mlp_layers: int = 4, dense_units: int = 512, activation: ModuleType = nn.SiLU, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, ) -> None: super().__init__() self.output_dims = output_dims @@ -253,11 +258,9 @@ def __init__( None, [dense_units] * mlp_layers, activation=activation, - layer_args={"bias": not layer_norm}, - norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=( - [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None - ), + layer_args={"bias": layer_norm_cls == nn.Identity}, + norm_layer=layer_norm_cls, + norm_args={"normalized_shape": dense_units, **layer_norm_kw}, ) self.heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.output_dims]) @@ -278,8 +281,10 @@ class RecurrentModel(nn.Module): recurrent_state_size (int): the size of the recurrent state. activation_fn (nn.Module): the activation function. Default to SiLU. - layer_norm (bool): whether to use the LayerNorm inside the GRU. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. 
""" def __init__( @@ -288,7 +293,8 @@ def __init__( recurrent_state_size: int, dense_units: int, activation_fn: nn.Module = nn.SiLU, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, ) -> None: super().__init__() self.mlp = MLP( @@ -296,11 +302,18 @@ def __init__( output_dim=None, hidden_sizes=[dense_units], activation=activation_fn, - layer_args={"bias": not layer_norm}, - norm_layer=[nn.LayerNorm] if layer_norm else None, - norm_args=[{"normalized_shape": dense_units, "eps": 1e-3}] if layer_norm else None, + layer_args={"bias": layer_norm_cls == nn.Identity}, + norm_layer=[layer_norm_cls], + norm_args=[{"normalized_shape": dense_units, **layer_norm_kw}], + ) + self.rnn = LayerNormGRUCell( + dense_units, + recurrent_state_size, + bias=False, + batch_first=False, + layer_norm_cls=layer_norm_cls, + layer_norm_kw=layer_norm_kw, ) - self.rnn = LayerNormGRUCell(dense_units, recurrent_state_size, bias=False, batch_first=False, layer_norm=True) def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor: """ @@ -691,8 +704,10 @@ class Actor(nn.Module): Default to nn.SiLU. mlp_layers (int): the number of dense layers. Default to 5. - layer_norm (bool, optional): whether to apply the layer normalization. - Defaults to True. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to LayerNormFP32. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {"eps": 1e-3}. unimix: (float, optional): the percentage of uniform distribution to inject into the categorical distribution over actions, i.e. given some logits `l` and probabilities `p = softmax(l)`, then `p = (1 - self.unimix) * p + self.unimix * unif`, @@ -714,7 +729,8 @@ def __init__( dense_units: int = 1024, activation: nn.Module = nn.SiLU, mlp_layers: int = 5, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, unimix: float = 0.01, action_clip: float = 1.0, ) -> None: @@ -739,11 +755,9 @@ def __init__( hidden_sizes=[dense_units] * mlp_layers, activation=activation, flatten_dim=None, - layer_args={"bias": not layer_norm}, - norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, - norm_args=( - [{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None - ), + layer_args={"bias": layer_norm_cls == nn.Identity}, + norm_layer=layer_norm_cls, + norm_args={"normalized_shape": dense_units, **layer_norm_kw}, ) if is_continuous: self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)]) @@ -834,7 +848,8 @@ def __init__( dense_units: int = 1024, activation: nn.Module = nn.SiLU, mlp_layers: int = 5, - layer_norm: bool = True, + layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32, + layer_norm_kw: Dict[str, Any] = {"eps": 1e-3}, unimix: float = 0.01, action_clip: float = 1.0, ) -> None: @@ -848,7 +863,8 @@ def __init__( dense_units=dense_units, activation=activation, mlp_layers=mlp_layers, - layer_norm=layer_norm, + layer_norm_cls=layer_norm_cls, + layer_norm_kw=layer_norm_kw, unimix=unimix, action_clip=action_clip, ) @@ -959,7 +975,8 @@ def build_agent( input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.encoder], image_size=obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:], channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier, - 
layer_norm=world_model_cfg.encoder.layer_norm, + layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.cnn_layer_norm.cls), + layer_norm_kw=world_model_cfg.encoder.cnn_layer_norm.kw, activation=eval(world_model_cfg.encoder.cnn_act), stages=cnn_stages, ) @@ -973,15 +990,19 @@ def build_agent( mlp_layers=world_model_cfg.encoder.mlp_layers, dense_units=world_model_cfg.encoder.dense_units, activation=eval(world_model_cfg.encoder.dense_act), - layer_norm=world_model_cfg.encoder.layer_norm, + layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder._mlp_layer_norm.cls), + layer_norm_kw=world_model_cfg.encoder.mlp_layer_norm.kw, ) if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) recurrent_model = RecurrentModel( - **world_model_cfg.recurrent_model, input_size=int(sum(actions_dim) + stochastic_size), + recurrent_state_size=world_model_cfg.recurrent_model.recurrent_state_size, + dense_units=world_model_cfg.recurrent_model.dense_units, + layer_norm_cls=hydra.utils.get_class(world_model_cfg.recurrent_model.layer_norm.cls), + layer_norm_kw=world_model_cfg.recurrent_model.layer_norm.kw, ) represention_model_input_size = encoder.output_dim if not cfg.algo.decoupled_rssm: @@ -991,28 +1012,30 @@ def build_agent( output_dim=stochastic_size, hidden_sizes=[world_model_cfg.representation_model.hidden_size], activation=eval(world_model_cfg.representation_model.dense_act), - layer_args={"bias": not world_model_cfg.representation_model.layer_norm}, + layer_args={"bias": world_model_cfg.representation_model.layer_norm.cls == "torch.nn.Identity"}, flatten_dim=None, - norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, - norm_args=( - [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] - if world_model_cfg.representation_model.layer_norm - else None - ), + norm_layer=[hydra.utils.get_class(world_model_cfg.representation_model.layer_norm.cls)], + norm_args=[ + { + "normalized_shape": world_model_cfg.representation_model.hidden_size, + **world_model_cfg.representation_model.layer_norm.kw, + } + ], ) transition_model = MLP( input_dims=recurrent_state_size, output_dim=stochastic_size, hidden_sizes=[world_model_cfg.transition_model.hidden_size], activation=eval(world_model_cfg.transition_model.dense_act), - layer_args={"bias": not world_model_cfg.transition_model.layer_norm}, + layer_args={"bias": world_model_cfg.transition_model.layer_norm.cls == "torch.nn.Identity"}, flatten_dim=None, - norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None, - norm_args=( - [{"normalized_shape": world_model_cfg.transition_model.hidden_size}] - if world_model_cfg.transition_model.layer_norm - else None - ), + norm_layer=[hydra.utils.get_class(world_model_cfg.transition_model.layer_norm.cls)], + norm_args=[ + { + "normalized_shape": world_model_cfg.transition_model.hidden_size, + **world_model_cfg.transition_model.layer_norm.kw, + } + ], ) if cfg.algo.decoupled_rssm: rssm_cls = DecoupledRSSM @@ -1035,7 +1058,8 @@ def build_agent( cnn_encoder_output_dim=cnn_encoder.output_dim, image_size=obs_space[cfg.algo.cnn_keys.decoder[0]].shape[-2:], activation=eval(world_model_cfg.observation_model.cnn_act), - layer_norm=world_model_cfg.observation_model.layer_norm, + layer_norm_cls=hydra.utils.get_class(world_model_cfg.observation_model.cnn_layer_norm.cls), + layer_norm_kw=world_model_cfg.observation_model.mlp_layer_norm.kw, stages=cnn_stages, ) if 
cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 @@ -1049,7 +1073,8 @@ def build_agent( mlp_layers=world_model_cfg.observation_model.mlp_layers, dense_units=world_model_cfg.observation_model.dense_units, activation=eval(world_model_cfg.observation_model.dense_act), - layer_norm=world_model_cfg.observation_model.layer_norm, + layer_norm_cls=hydra.utils.get_class(world_model_cfg.observation_model.mlp_layer_norm.cls), + layer_norm_kw=world_model_cfg.observation_model.mlp_layer_norm.kw, ) if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 else None @@ -1060,42 +1085,26 @@ def build_agent( output_dim=world_model_cfg.reward_model.bins, hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers, activation=eval(world_model_cfg.reward_model.dense_act), - layer_args={"bias": not world_model_cfg.reward_model.layer_norm}, + layer_args={"bias": world_model_cfg.reward_model.layer_norm.cls == "torch.nn.Identity"}, flatten_dim=None, - norm_layer=( - [nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] - if world_model_cfg.reward_model.layer_norm - else None - ), - norm_args=( - [ - {"normalized_shape": world_model_cfg.reward_model.dense_units} - for _ in range(world_model_cfg.reward_model.mlp_layers) - ] - if world_model_cfg.reward_model.layer_norm - else None - ), + norm_layer=hydra.utils.get_class(world_model_cfg.reward_model.layer_norm.cls), + norm_args={ + "normalized_shape": world_model_cfg.reward_model.dense_units, + **world_model_cfg.reward_model.layer_norm.kw, + }, ) continue_model = MLP( input_dims=latent_state_size, output_dim=1, hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers, activation=eval(world_model_cfg.discount_model.dense_act), - layer_args={"bias": not world_model_cfg.discount_model.layer_norm}, + layer_args={"bias": world_model_cfg.discount_model.layer_norm.cls == "torch.nn.Identity"}, flatten_dim=None, - norm_layer=( - [nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] - if world_model_cfg.discount_model.layer_norm - else None - ), - norm_args=( - [ - {"normalized_shape": world_model_cfg.discount_model.dense_units} - for _ in range(world_model_cfg.discount_model.mlp_layers) - ] - if world_model_cfg.discount_model.layer_norm - else None - ), + norm_layer=hydra.utils.get_class(world_model_cfg.discount_model.layer_norm.cls), + norm_args={ + "normalized_shape": world_model_cfg.discount_model.dense_units, + **world_model_cfg.discount_model.layer_norm.kw, + }, ) world_model = WorldModel( encoder.apply(init_weights), @@ -1115,7 +1124,8 @@ def build_agent( activation=eval(actor_cfg.dense_act), mlp_layers=actor_cfg.mlp_layers, distribution_cfg=cfg.distribution, - layer_norm=actor_cfg.layer_norm, + layer_norm_cls=hydra.utils.get_class(actor_cfg.layer_norm.cls), + layer_norm_kw=actor_cfg.layer_norm.kw, unimix=cfg.algo.unimix, action_clip=actor_cfg.action_clip, ) @@ -1124,14 +1134,13 @@ def build_agent( output_dim=critic_cfg.bins, hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers, activation=eval(critic_cfg.dense_act), - layer_args={"bias": not critic_cfg.layer_norm}, + layer_args={"bias": critic_cfg.layer_norm.cls == "torch.nn.Identity"}, flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=( - [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None - 
), + norm_layer=hydra.utils.get_class(critic_cfg.layer_norm.cls), + norm_args={ + "normalized_shape": critic_cfg.dense_units, + **critic_cfg.layer_norm.kw, + }, ) actor.apply(init_weights) critic.apply(init_weights) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index efd10ac3..bb8bf297 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -54,7 +54,7 @@ def __init__( self.register_buffer("high", torch.zeros((), dtype=torch.float32)) def forward(self, x: Tensor, fabric: Fabric) -> Any: - gathered_x = fabric.all_gather(x).detach() + gathered_x = fabric.all_gather(x).float().detach() low = torch.quantile(gathered_x, self._percentile_low) high = torch.quantile(gathered_x, self._percentile_high) self.low = self._decay * self.low + (1 - self._decay) * low diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index e0b60985..9b6e85fd 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -25,7 +25,14 @@ mlp_keys: decoder: ${algo.mlp_keys.encoder} # Model related parameters -layer_norm: True +cnn_layer_norm: + cls: sheeprl.utils.model.LayerNormChannelLastFP32 + kw: + eps: 1e-3 +mlp_layer_norm: + cls: sheeprl.utils.model.LayerNormFP32 + kw: + eps: 1e-3 dense_units: 1024 mlp_layers: 5 dense_act: torch.nn.SiLU @@ -51,26 +58,27 @@ world_model: cnn_act: ${algo.cnn_act} dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + cnn_layer_norm: ${algo.cnn_layer_norm} + mlp_layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} # Recurrent model recurrent_model: recurrent_state_size: 4096 - layer_norm: True + layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} # Prior transition_model: hidden_size: 1024 dense_act: ${algo.dense_act} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} # Posterior representation_model: hidden_size: 1024 dense_act: ${algo.dense_act} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} # Decoder observation_model: @@ -78,14 +86,15 @@ world_model: cnn_act: ${algo.cnn_act} dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + cnn_layer_norm: ${algo.cnn_layer_norm} + mlp_layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} # Reward model reward_model: dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} bins: 255 @@ -94,7 +103,7 @@ world_model: learnable: True dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} # World model optimizer @@ -112,7 +121,7 @@ actor: init_std: 2.0 dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 unimix: ${algo.unimix} @@ -136,7 +145,7 @@ actor: critic: dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} dense_units: ${algo.dense_units} per_rank_target_network_update_freq: 1 tau: 0.02 diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index 089c45ce..646395b5 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -4,7 +4,7 @@ import warnings from math import prod -from typing import Dict, 
Optional, Sequence, Union, no_type_check +from typing import Any, Callable, Dict, Optional, Sequence, Union, no_type_check import torch import torch.nn.functional as F @@ -342,12 +342,20 @@ class LayerNormGRUCell(nn.Module): Defaults to True. batch_first (bool, optional): whether the first dimension represent the batch dimension or not. Defaults to False. - layer_norm (bool, optional): whether to apply a LayerNorm after the input projection. - Defaults to False. + layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection. + Defaults to nn.Identiy. + layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm. + Default to {}. """ def __init__( - self, input_size: int, hidden_size: int, bias: bool = True, batch_first: bool = False, layer_norm: bool = False + self, + input_size: int, + hidden_size: int, + bias: bool = True, + batch_first: bool = False, + layer_norm_cls: Callable[..., nn.Module] = nn.Identity, + layer_norm_kw: Dict[str, Any] = {}, ) -> None: super().__init__() self.input_size = input_size @@ -355,10 +363,7 @@ def __init__( self.bias = bias self.batch_first = batch_first self.linear = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=self.bias) - if layer_norm: - self.layer_norm = torch.nn.LayerNorm(3 * hidden_size) - else: - self.layer_norm = nn.Identity() + self.layer_norm = layer_norm_cls(3 * hidden_size, **layer_norm_kw) def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: is_3d = input.dim() == 3 diff --git a/sheeprl/utils/model.py b/sheeprl/utils/model.py index f74ba626..1552020e 100644 --- a/sheeprl/utils/model.py +++ b/sheeprl/utils/model.py @@ -1,6 +1,7 @@ """ Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py """ + from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch @@ -233,3 +234,19 @@ def forward(self, x: Tensor) -> Tensor: x = super().forward(x) x = x.permute(0, 3, 1, 2) return x + + +class LayerNormChannelLastFP32(LayerNormChannelLast): + def forward(self, x: Tensor) -> Tensor: + input_dtype = x.dtype + x = x.to(torch.float32) + out = super().forward(x) + return out.to(input_dtype) + + +class LayerNormFP32(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + input_dtype = x.dtype + x = x.to(torch.float32) + out = super().forward(x) + return out.to(input_dtype) From 4f15952bcd9939d9be946648e10a325c6dd5a393 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 17:49:46 +0200 Subject: [PATCH 41/51] feat: added first recurrent state learnable in dv3 --- sheeprl/algos/dreamer_v3/agent.py | 64 +++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 0faff0be..77efd216 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -354,7 +354,7 @@ class RSSM(nn.Module): def __init__( self, - recurrent_model: nn.Module | _FabricModule, + recurrent_model: RecurrentModel | _FabricModule, representation_model: nn.Module | _FabricModule, transition_model: nn.Module | _FabricModule, distribution_cfg: Dict[str, Any], @@ -368,6 +368,14 @@ def __init__( self.discrete = discrete self.unimix = unimix self.distribution_cfg = distribution_cfg + self.initial_recurrent_state = nn.Parameter( + torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32) + ) + + def get_initial_states(self, batch_shape: Sequence[int] | torch.Size) -> Tuple[Tensor, Tensor]: + initial_recurrent_state = 
torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1) + initial_posterior = self._transition(initial_recurrent_state, sample_state=False)[1] + return initial_recurrent_state, initial_posterior def dynamic( self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, embedded_obs: Tensor, is_first: Tensor @@ -399,11 +407,12 @@ def dynamic( from the recurrent state and the embbedded observation. """ action = (1 - is_first) * action - recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state)) + + initial_recurrent_state, initial_posterior = self.get_initial_states(recurrent_state.shape[:2]) + recurrent_state = (1 - is_first) * recurrent_state + is_first * initial_recurrent_state posterior = posterior.view(*posterior.shape[:-2], -1) - posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ - 1 - ].view_as(posterior) + posterior = (1 - is_first) * posterior + is_first * initial_posterior.view_as(posterior) + recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state) prior_logits, prior = self._transition(recurrent_state) posterior_logits, posterior = self._representation(recurrent_state, embedded_obs) @@ -535,11 +544,12 @@ def dynamic( from the recurrent state and the embbedded observation. """ action = (1 - is_first) * action - recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state)) + + initial_recurrent_state, initial_posterior = self.get_initial_states(recurrent_state.shape[:2]) + recurrent_state = (1 - is_first) * recurrent_state + is_first * initial_recurrent_state posterior = posterior.view(*posterior.shape[:-2], -1) - posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ - 1 - ].view_as(posterior) + posterior = (1 - is_first) * posterior + is_first * initial_posterior.view_as(posterior) + recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state) prior_logits, prior = self._transition(recurrent_state) return recurrent_state, prior, prior_logits @@ -614,7 +624,7 @@ def __init__( distribution_cfg=actor.distribution_cfg, discrete=rssm.discrete, unimix=rssm.unimix, - ) + ).to(single_device_fabric.device) self.actor = single_device_fabric.setup_module(getattr(actor, "module", actor)) self.device = single_device_fabric.device self.actions_dim = actions_dim @@ -997,6 +1007,7 @@ def build_agent( else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) + recurrent_model = RecurrentModel( input_size=int(sum(actions_dim) + stochastic_size), recurrent_state_size=world_model_cfg.recurrent_model.recurrent_state_size, @@ -1007,14 +1018,15 @@ def build_agent( represention_model_input_size = encoder.output_dim if not cfg.algo.decoupled_rssm: represention_model_input_size += recurrent_state_size + representation_ln_cls = hydra.utils.get_class(world_model_cfg.representation_model.layer_norm.cls) representation_model = MLP( input_dims=represention_model_input_size, output_dim=stochastic_size, hidden_sizes=[world_model_cfg.representation_model.hidden_size], activation=eval(world_model_cfg.representation_model.dense_act), - layer_args={"bias": world_model_cfg.representation_model.layer_norm.cls == "torch.nn.Identity"}, + layer_args={"bias": representation_ln_cls == nn.Identity}, flatten_dim=None, - norm_layer=[hydra.utils.get_class(world_model_cfg.representation_model.layer_norm.cls)], + 
norm_layer=[representation_ln_cls], norm_args=[ { "normalized_shape": world_model_cfg.representation_model.hidden_size, @@ -1022,14 +1034,15 @@ def build_agent( } ], ) + transition_ln_cls = hydra.utils.get_class(world_model_cfg.transition_model.layer_norm.cls) transition_model = MLP( input_dims=recurrent_state_size, output_dim=stochastic_size, hidden_sizes=[world_model_cfg.transition_model.hidden_size], activation=eval(world_model_cfg.transition_model.dense_act), - layer_args={"bias": world_model_cfg.transition_model.layer_norm.cls == "torch.nn.Identity"}, + layer_args={"bias": transition_ln_cls == nn.Identity}, flatten_dim=None, - norm_layer=[hydra.utils.get_class(world_model_cfg.transition_model.layer_norm.cls)], + norm_layer=[transition_ln_cls], norm_args=[ { "normalized_shape": world_model_cfg.transition_model.hidden_size, @@ -1037,6 +1050,7 @@ def build_agent( } ], ) + if cfg.algo.decoupled_rssm: rssm_cls = DecoupledRSSM else: @@ -1048,7 +1062,8 @@ def build_agent( distribution_cfg=cfg.distribution, discrete=world_model_cfg.discrete_size, unimix=cfg.algo.unimix, - ) + ).to(fabric.device) + cnn_decoder = ( CNNDecoder( keys=cfg.algo.cnn_keys.decoder, @@ -1080,27 +1095,31 @@ def build_agent( else None ) observation_model = MultiDecoder(cnn_decoder, mlp_decoder) + + reward_ln_cls = hydra.utils.get_class(world_model_cfg.reward_model.layer_norm.cls) reward_model = MLP( input_dims=latent_state_size, output_dim=world_model_cfg.reward_model.bins, hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers, activation=eval(world_model_cfg.reward_model.dense_act), - layer_args={"bias": world_model_cfg.reward_model.layer_norm.cls == "torch.nn.Identity"}, + layer_args={"bias": reward_ln_cls == nn.Identity}, flatten_dim=None, - norm_layer=hydra.utils.get_class(world_model_cfg.reward_model.layer_norm.cls), + norm_layer=reward_ln_cls, norm_args={ "normalized_shape": world_model_cfg.reward_model.dense_units, **world_model_cfg.reward_model.layer_norm.kw, }, ) + + discount_ln_cls = hydra.utils.get_class(world_model_cfg.discount_model.layer_norm.cls) continue_model = MLP( input_dims=latent_state_size, output_dim=1, hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers, activation=eval(world_model_cfg.discount_model.dense_act), - layer_args={"bias": world_model_cfg.discount_model.layer_norm.cls == "torch.nn.Identity"}, + layer_args={"bias": discount_ln_cls == nn.Identity}, flatten_dim=None, - norm_layer=hydra.utils.get_class(world_model_cfg.discount_model.layer_norm.cls), + norm_layer=discount_ln_cls, norm_args={ "normalized_shape": world_model_cfg.discount_model.dense_units, **world_model_cfg.discount_model.layer_norm.kw, @@ -1113,6 +1132,7 @@ def build_agent( reward_model.apply(init_weights), continue_model.apply(init_weights), ) + actor_cls = hydra.utils.get_class(cfg.algo.actor.cls) actor: Actor | MinedojoActor = actor_cls( latent_state_size=latent_state_size, @@ -1129,14 +1149,16 @@ def build_agent( unimix=cfg.algo.unimix, action_clip=actor_cfg.action_clip, ) + + critic_ln_cls = hydra.utils.get_class(critic_cfg.layer_norm.cls) critic = MLP( input_dims=latent_state_size, output_dim=critic_cfg.bins, hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers, activation=eval(critic_cfg.dense_act), - layer_args={"bias": critic_cfg.layer_norm.cls == "torch.nn.Identity"}, + layer_args={"bias": critic_ln_cls == nn.Identity}, flatten_dim=None, - norm_layer=hydra.utils.get_class(critic_cfg.layer_norm.cls), + 
norm_layer=critic_ln_cls, norm_args={ "normalized_shape": critic_cfg.dense_units, **critic_cfg.layer_norm.kw, From 66d4d92e127664bdd3ff6640501538032cf7d22f Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 17:54:37 +0200 Subject: [PATCH 42/51] feat: update dv3 ww configs --- .../configs/exp/dreamer_v3_dmc_walker_walk.yaml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index af38aee4..18719b55 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -38,9 +38,22 @@ algo: - rgb mlp_keys: encoder: [] - learning_starts: 1024 + learning_starts: 1300 replay_ratio: 0.5 # Metric metric: log_every: 5000 + +fabric: + accelerator: cuda + precision: bf16-mixed + # precision: None + # plugins: + # - _target_: lightning.fabric.plugins.precision.MixedPrecision + # precision: 16-mixed + # device: cuda + # scaler: + # _target_: torch.cuda.amp.GradScaler + # init_scale: 1e4 + # growth_interval: 1000 \ No newline at end of file From d536264fd1673640a57f93805e4301adf1456211 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 18:00:21 +0200 Subject: [PATCH 43/51] feat: learned initial recurrent state when resetting the player states) --- sheeprl/algos/dreamer_v3/agent.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 77efd216..c2b17063 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -646,18 +646,12 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: """ if reset_envs is None or len(reset_envs) == 0: self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) - self.recurrent_state = torch.tanh( - torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) - ) - self.stochastic_state = self.rssm._transition(self.recurrent_state, sample_state=False)[1].reshape( - 1, self.num_envs, -1 - ) + self.recurrent_state, stochastic_state = self.rssm.get_initial_states((1, self.num_envs)) + self.stochastic_state = stochastic_state.reshape(1, self.num_envs, -1) else: self.actions[:, reset_envs] = torch.zeros_like(self.actions[:, reset_envs]) - self.recurrent_state[:, reset_envs] = torch.tanh(torch.zeros_like(self.recurrent_state[:, reset_envs])) - self.stochastic_state[:, reset_envs] = self.rssm._transition( - self.recurrent_state[:, reset_envs], sample_state=False - )[1].reshape(1, len(reset_envs), -1) + self.recurrent_state[:, reset_envs], stochastic_state = self.rssm.get_initial_states((1, len(reset_envs))) + self.stochastic_state[:, reset_envs] = stochastic_state.reshape(1, len(reset_envs), -1) def get_actions( self, From 6108afb7f994c50c2af2999014fb74b791f59428 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 2 Apr 2024 18:53:57 +0200 Subject: [PATCH 44/51] fix: env interaction --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 2 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 2 +- sheeprl/algos/dreamer_v3/agent.py | 1 + sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 2 +- 
sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 +- 10 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 079a8065..381cc2b6 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -594,7 +594,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 97dccccf..05ab0a86 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -620,7 +620,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index c2b17063..325523c3 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -314,6 +314,7 @@ def __init__( layer_norm_cls=layer_norm_cls, layer_norm_kw=layer_norm_kw, ) + self.recurrent_state_size = recurrent_state_size def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor: """ diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index d0e76a81..7a7d2c17 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -586,7 +586,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 632633b9..081a2a79 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -620,7 +620,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index f7dfc34d..7779274a 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ 
b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -276,7 +276,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 4b7cc09b..a474f2c2 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -757,7 +757,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 365722b1..4e0dcc6b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index bb622993..e74d10db 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -827,7 +827,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index e220fea8..d1110472 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -286,7 +286,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() From 01484e99b01d333e27976eb7d60e60db027ff048 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 09:21:59 +0200 
Subject: [PATCH 45/51] fix: avoid to rewrite with layer_norm kwargs --- sheeprl/algos/dreamer_v3/agent.py | 12 ++++++------ sheeprl/models/models.py | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 325523c3..51a89738 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -75,7 +75,7 @@ def __init__( activation=activation, norm_layer=[layer_norm_cls] * stages, norm_args=[ - {"normalized_shape": (2**i) * channels_multiplier, **layer_norm_kw} for i in range(stages) + {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages) ], ), nn.Flatten(-3, -1), @@ -132,7 +132,7 @@ def __init__( activation=activation, layer_args={"bias": layer_norm_cls == nn.Identity}, norm_layer=layer_norm_cls, - norm_args={"normalized_shape": dense_units, **layer_norm_kw}, + norm_args={**layer_norm_kw, "normalized_shape": dense_units}, ) self.output_dim = dense_units self.symlog_inputs = symlog_inputs @@ -205,7 +205,7 @@ def __init__( activation=[activation for _ in range(stages - 1)] + [None], norm_layer=[layer_norm_cls for _ in range(stages - 1)] + [None], norm_args=[ - {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, **layer_norm_kw} + {**layer_norm_kw, "normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier} for i in range(stages - 1) ] + [None], @@ -260,7 +260,7 @@ def __init__( activation=activation, layer_args={"bias": layer_norm_cls == nn.Identity}, norm_layer=layer_norm_cls, - norm_args={"normalized_shape": dense_units, **layer_norm_kw}, + norm_args={**layer_norm_kw, "normalized_shape": dense_units}, ) self.heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.output_dims]) @@ -304,7 +304,7 @@ def __init__( activation=activation_fn, layer_args={"bias": layer_norm_cls == nn.Identity}, norm_layer=[layer_norm_cls], - norm_args=[{"normalized_shape": dense_units, **layer_norm_kw}], + norm_args=[{**layer_norm_kw, "normalized_shape": dense_units}], ) self.rnn = LayerNormGRUCell( dense_units, @@ -762,7 +762,7 @@ def __init__( flatten_dim=None, layer_args={"bias": layer_norm_cls == nn.Identity}, norm_layer=layer_norm_cls, - norm_args={"normalized_shape": dense_units, **layer_norm_kw}, + norm_args={**layer_norm_kw, "normalized_shape": dense_units}, ) if is_continuous: self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)]) diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index 646395b5..dbc810ad 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -363,6 +363,8 @@ def __init__( self.bias = bias self.batch_first = batch_first self.linear = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=self.bias) + # Avoid multiple values for the `normalized_shape` argument + layer_norm_kw.pop("normalized_shape", None) self.layer_norm = layer_norm_cls(3 * hidden_size, **layer_norm_kw) def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: From fd62f69737fe1366e5ee311e3c30442344ccfd3a Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 09:28:12 +0200 Subject: [PATCH 46/51] fix: avoid to rewrite with layer_norm kwargs --- sheeprl/algos/dreamer_v3/agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 51a89738..03733eea 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -1024,8 
+1024,8 @@ def build_agent( norm_layer=[representation_ln_cls], norm_args=[ { - "normalized_shape": world_model_cfg.representation_model.hidden_size, **world_model_cfg.representation_model.layer_norm.kw, + "normalized_shape": world_model_cfg.representation_model.hidden_size, } ], ) @@ -1040,8 +1040,8 @@ def build_agent( norm_layer=[transition_ln_cls], norm_args=[ { - "normalized_shape": world_model_cfg.transition_model.hidden_size, **world_model_cfg.transition_model.layer_norm.kw, + "normalized_shape": world_model_cfg.transition_model.hidden_size, } ], ) @@ -1101,8 +1101,8 @@ def build_agent( flatten_dim=None, norm_layer=reward_ln_cls, norm_args={ - "normalized_shape": world_model_cfg.reward_model.dense_units, **world_model_cfg.reward_model.layer_norm.kw, + "normalized_shape": world_model_cfg.reward_model.dense_units, }, ) @@ -1116,8 +1116,8 @@ def build_agent( flatten_dim=None, norm_layer=discount_ln_cls, norm_args={ - "normalized_shape": world_model_cfg.discount_model.dense_units, **world_model_cfg.discount_model.layer_norm.kw, + "normalized_shape": world_model_cfg.discount_model.dense_units, }, ) world_model = WorldModel( @@ -1155,8 +1155,8 @@ def build_agent( flatten_dim=None, norm_layer=critic_ln_cls, norm_args={ - "normalized_shape": critic_cfg.dense_units, **critic_cfg.layer_norm.kw, + "normalized_shape": critic_cfg.dense_units, }, ) actor.apply(init_weights) From 2df458a6a99fba1b4015c461b77e05d29615e312 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 09:38:42 +0200 Subject: [PATCH 47/51] fix: dv2 LayerNormGruCell creation --- sheeprl/algos/dreamer_v2/agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index a2eafe87..de4707d6 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -278,7 +278,9 @@ def __init__( norm_layer=[nn.LayerNorm] if layer_norm else None, norm_args=[{"normalized_shape": dense_units}] if layer_norm else None, ) - self.rnn = LayerNormGRUCell(dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm=True) + self.rnn = LayerNormGRUCell( + dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm=nn.LayerNorm + ) def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor: """ From fe4807a8cba92e0422bb393966c511020af13db2 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 09:41:21 +0200 Subject: [PATCH 48/51] fix: dv2 LayerNormGruCell creation --- sheeprl/algos/dreamer_v2/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index de4707d6..c0f7a369 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -279,7 +279,7 @@ def __init__( norm_args=[{"normalized_shape": dense_units}] if layer_norm else None, ) self.rnn = LayerNormGRUCell( - dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm=nn.LayerNorm + dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm_cls=nn.LayerNorm ) def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor: From 6a8c6cb6aab2dddd0a86999a45353d70a1eb5cd7 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 10:15:01 +0200 Subject: [PATCH 49/51] fix: update p2e dv3 + fix tests --- sheeprl/algos/p2e_dv3/agent.py | 33 +++++++++++++++---------------- sheeprl/configs/algo/p2e_dv3.yaml | 2 +- tests/test_algos/test_algos.py | 9 ++++++--- 3 files 
changed, 23 insertions(+), 21 deletions(-) diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index d19b329e..2f35674d 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -106,13 +106,15 @@ def build_agent( activation=eval(actor_cfg.dense_act), mlp_layers=actor_cfg.mlp_layers, distribution_cfg=cfg.distribution, - layer_norm=actor_cfg.layer_norm, + layer_norm_cls=hydra.utils.get_class(actor_cfg.layer_norm.cls), + layer_norm_kw=actor_cfg.layer_norm.kw, unimix=cfg.algo.unimix, ) single_device_fabric = get_single_device_fabric(fabric) critics_exploration = {} intrinsic_critics = 0 + critic_ln_cls = hydra.utils.get_class(critic_cfg.layer_norm.cls) for k, v in cfg.algo.critics_exploration.items(): if v.weight > 0: if v.reward_type == "intrinsic": @@ -126,13 +128,12 @@ def build_agent( hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers, activation=eval(critic_cfg.dense_act), flatten_dim=None, - layer_args={"bias": not critic_cfg.layer_norm}, - norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=( - [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None - ), + layer_args={"bias": critic_ln_cls == nn.Identity}, + norm_layer=critic_ln_cls, + norm_args={ + **critic_cfg.layer_norm.kw, + "normalized_shape": critic_cfg.dense_units, + }, ), } critics_exploration[k]["module"].apply(init_weights) @@ -170,6 +171,7 @@ def build_agent( # initialize the ensembles with different seeds to be sure they have different weights ens_list = [] cfg_ensembles = cfg.algo.ensembles + ensembles_ln_cls = hydra.utils.get_class(cfg_ensembles.layer_norm.cls) with isolate_rng(): for i in range(cfg_ensembles.n): fabric.seed_everything(cfg.seed + i) @@ -184,15 +186,12 @@ def build_agent( hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, activation=eval(cfg_ensembles.dense_act), flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), + layer_args={"bias": ensembles_ln_cls == nn.Identity}, + norm_layer=ensembles_ln_cls, + norm_args={ + **cfg_ensembles.layer_norm.kw, + "normalized_shape": cfg_ensembles.dense_units, + }, ).apply(init_weights) ) ensembles = nn.ModuleList(ens_list) diff --git a/sheeprl/configs/algo/p2e_dv3.yaml b/sheeprl/configs/algo/p2e_dv3.yaml index 292d00b4..49ce538c 100644 --- a/sheeprl/configs/algo/p2e_dv3.yaml +++ b/sheeprl/configs/algo/p2e_dv3.yaml @@ -22,7 +22,7 @@ ensembles: dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} dense_units: ${algo.dense_units} - layer_norm: ${algo.layer_norm} + layer_norm: ${algo.mlp_layer_norm} clip_gradients: 100 optimizer: lr: 1e-4 diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 132d7a9d..db6f22a5 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -458,9 +458,10 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "algo.layer_norm=True", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", + 
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm", + "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast", ] with mock.patch.object(sys, "argv", args): @@ -492,11 +493,12 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "algo.layer_norm=True", "buffer.checkpoint=True", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", "checkpoint.save_last=True", + "algo.mlp_layer_norm.cls=torch.nn.LayerNorm", + "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast", ] with mock.patch.object(sys, "argv", args): @@ -535,9 +537,10 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "algo.layer_norm=True", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", + "algo.mlp_layer_norm.cls=torch.nn.LayerNorm", + "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast", ] with mock.patch.object(sys, "argv", args): run() From 1e6f30978cfaa079630042f970937d678218dcc6 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 10:22:31 +0200 Subject: [PATCH 50/51] fix: tests --- tests/test_algos/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 0e95871f..d1ad2839 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -172,7 +172,7 @@ def test_resume_from_checkpoint_env_error(): "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " + "algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, From e77c52dc72ca2f4d448fffea1910e704144e524a Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 3 Apr 2024 10:30:58 +0200 Subject: [PATCH 51/51] fix: tests --- tests/test_algos/test_cli.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index d1ad2839..0e465d24 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -129,7 +129,7 @@ def test_resume_from_checkpoint(): "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " + "algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, @@ -225,7 +225,7 @@ def test_resume_from_checkpoint_algo_error(): "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " + "algo.per_rank_batch_size=1 
algo.per_rank_sequence_length=1 " f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, @@ -280,7 +280,7 @@ def test_evaluate(): "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " + "algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True,
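
The patches above make two recurring changes: the zero/tanh initial recurrent state becomes a learnable parameter consumed both by `RSSM.dynamic()` and by the player's `init_states`, and every layer norm is now built from a configurable class plus kwargs, with the kwargs spread before `normalized_shape` so a config entry cannot override the model width. Below is a minimal, self-contained sketch of those two patterns, not the actual sheeprl implementation: `TinyRSSM`, its linear `transition_model`, and the example config values are hypothetical stand-ins for illustration only.

# Minimal sketch (assumptions noted above): learnable initial recurrent state
# expanded to the batch shape, plus the layer-norm kwargs-ordering fix.
from typing import Sequence, Tuple

import torch
from torch import Tensor, nn


class TinyRSSM(nn.Module):
    def __init__(self, recurrent_state_size: int, stochastic_size: int, discrete: int = 32) -> None:
        super().__init__()
        self.discrete = discrete
        # Zero-initialized learnable parameter; tanh is applied at use time,
        # mirroring the previous torch.tanh(torch.zeros_like(...)) at initialization.
        self.initial_recurrent_state = nn.Parameter(torch.zeros(recurrent_state_size, dtype=torch.float32))
        # Stand-in for the transition model: recurrent state -> prior logits.
        self.transition_model = nn.Linear(recurrent_state_size, stochastic_size * discrete)

    def get_initial_states(self, batch_shape: Sequence[int]) -> Tuple[Tensor, Tensor]:
        # Expand the single learned vector to (sequence, num_envs, recurrent_state_size).
        h0 = torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1)
        # Derive the matching initial stochastic state from the transition model
        # (plain probabilities here, no sampling).
        logits = self.transition_model(h0).view(*batch_shape, -1, self.discrete)
        return h0, torch.softmax(logits, dim=-1)


rssm = TinyRSSM(recurrent_state_size=8, stochastic_size=4)
h0, post0 = rssm.get_initial_states((1, 3))  # e.g. (sequence_length, num_envs)
print(h0.shape, post0.shape)  # torch.Size([1, 3, 8]) torch.Size([1, 3, 4, 32])

# The kwargs-ordering fix in one line: spreading the (hypothetical) config kwargs
# first means a stray "normalized_shape" entry can no longer override the value
# computed from the model width.
layer_norm_kw = {"eps": 1e-3, "normalized_shape": 999}  # hypothetical config contents
norm_args = {**layer_norm_kw, "normalized_shape": 16}
assert norm_args["normalized_shape"] == 16

Since both `dynamic()` and the player's `init_states` call the same `get_initial_states`, training and environment interaction start from the same learned state instead of two separately hard-coded zero tensors.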