From df9f1fa42b3213558e61721c837051eed58c2c94 Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Wed, 3 Jan 2024 09:41:50 +0100
Subject: [PATCH 1/3] fix: minedojo _convert_action() function

---
 sheeprl/algos/ppo/ppo.py                |  2 +-
 sheeprl/configs/algo/ppo_recurrent.yaml |  6 +++---
 sheeprl/configs/exp/ppo_recurrent.yaml  | 10 +++++-----
 sheeprl/envs/minedojo.py                |  2 +-
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 300a8244..3219940d 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -300,7 +300,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+                        torch_v = torch_v.view(-1, *v.shape[-2:])
                         torch_v = torch_v / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():
diff --git a/sheeprl/configs/algo/ppo_recurrent.yaml b/sheeprl/configs/algo/ppo_recurrent.yaml
index d4bac2cd..24384415 100644
--- a/sheeprl/configs/algo/ppo_recurrent.yaml
+++ b/sheeprl/configs/algo/ppo_recurrent.yaml
@@ -18,11 +18,11 @@ per_rank_sequence_length: ???
 # Model related parameters
 mlp_layers: 1
 layer_norm: True
-dense_units: 128
+dense_units: 256
 dense_act: torch.nn.ReLU
 rnn:
   lstm:
-    hidden_size: 64
+    hidden_size: 128
     pre_rnn_mlp:
       bias: True
       apply: False
@@ -36,7 +36,7 @@ rnn:
       layer_norm: ${algo.layer_norm}
       dense_units: ${algo.rnn.lstm.hidden_size}
 encoder:
-  dense_units: 64
+  dense_units: 128
 
 # Optimizer related parameters
 optimizer:
diff --git a/sheeprl/configs/exp/ppo_recurrent.yaml b/sheeprl/configs/exp/ppo_recurrent.yaml
index 4664cc39..317e9ba3 100644
--- a/sheeprl/configs/exp/ppo_recurrent.yaml
+++ b/sheeprl/configs/exp/ppo_recurrent.yaml
@@ -7,11 +7,11 @@ defaults:
   - _self_
 
 algo:
-  per_rank_num_batches: 4
-  per_rank_sequence_length: 8
-  total_steps: 409600
-  rollout_steps: 256
-  update_epochs: 4
+  per_rank_num_batches: 8
+  per_rank_sequence_length: 16
+  total_steps: 650000
+  rollout_steps: 512
+  update_epochs: 8
 
 # Environment
 env:
diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py
index 08f5c13a..00c0837f 100644
--- a/sheeprl/envs/minedojo.py
+++ b/sheeprl/envs/minedojo.py
@@ -192,7 +192,7 @@ def _convert_action(self, action: np.ndarray) -> np.ndarray:
                 self._sticky_attack_counter -= 1
             # it the selected action is not attack, then the agent stops the sticky attack
             elif converted_action[5] != 3:
-                self._sticky_attack = 0
+                self._sticky_attack_counter = 0
         if self._sticky_jump:
             # 2 is the index of the jump/sneak/sprint actions, 1 is the value for the jump action
             if converted_action[2] == 1:

From d5018f3feb0efeede6dc705d00c72ae6d8d2d96d Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Wed, 3 Jan 2024 09:57:12 +0100
Subject: [PATCH 2/3] fix: ppo decoupled and ppo recurrent

---
 sheeprl/algos/ppo/ppo_decoupled.py           | 2 +-
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index bf019b52..7a10d18b 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -222,7 +222,7 @@ def player(
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+                        torch_v = torch_v.view(-1, *v.shape[-2:])
                         torch_v = torch_v / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index 794c9fc8..fd974b4b 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -315,7 +315,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
+                        torch_v = torch_v.view(1, 1, -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
                     real_next_obs[k][0, i] = torch_v
             with torch.no_grad():
                 feat = agent.module.feature_extractor(real_next_obs)

From 33c0cd00bd4c6e7fb48e955a0f6d69b33b06d039 Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Thu, 4 Jan 2024 16:20:12 +0000
Subject: [PATCH 3/3] fix: ppo and sac

---
 sheeprl/algos/ppo/ppo_decoupled.py           | 2 +-
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +-
 sheeprl/algos/sac/sac_decoupled.py           | 2 +-
 sheeprl/configs/exp/ppo_recurrent.yaml       | 4 +---
 4 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index 7a10d18b..f8a81b5f 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -480,7 +480,7 @@ def trainer(
 
     # Start training
     with timer(
-        "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
+        "Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
     ):
         # The Join context is needed because there can be the possibility
         # that some ranks receive less data
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index fd974b4b..dddff56d 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -315,7 +315,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(1, 1, -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
+                        torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5
                     real_next_obs[k][0, i] = torch_v
             with torch.no_grad():
                 feat = agent.module.feature_extractor(real_next_obs)
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 705a53c0..daf98a47 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -453,7 +453,7 @@ def trainer(
 
     # Start training
     with timer(
-        "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
+        "Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
     ):
         for batch_idxes in sampler:
             train(
diff --git a/sheeprl/configs/exp/ppo_recurrent.yaml b/sheeprl/configs/exp/ppo_recurrent.yaml
index 317e9ba3..001afa6b 100644
--- a/sheeprl/configs/exp/ppo_recurrent.yaml
+++ b/sheeprl/configs/exp/ppo_recurrent.yaml
@@ -9,15 +9,13 @@ defaults:
 algo:
   per_rank_num_batches: 8
   per_rank_sequence_length: 16
-  total_steps: 650000
+  total_steps: 409000
   rollout_steps: 512
   update_epochs: 8
 
 # Environment
 env:
-  id: CartPole-v1
   num_envs: 16
-  mask_velocities: True
 
 buffer:
   memmap: False