
Commit 038facd
Fix/bugs (#181)
* fix: minedojo _convert_action() function

* fix: ppo decoupled and ppo recurrent

* fix: ppo and sac
michele-milesi authored Jan 4, 2024
1 parent 2201903 commit 038facd
Showing 7 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
@@ -300,7 +300,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+                        torch_v = torch_v.view(-1, *v.shape[-2:])
                         torch_v = torch_v / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():
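
The dropped cfg.env.num_envs is the whole fix: info["final_observation"][truncated_env] holds a single environment's observation, so forcing a leading num_envs dimension onto it is wrong. A minimal sketch of the difference, assuming an illustrative frame-stacked RGB observation of shape (4, 3, 64, 64) and 16 envs (both values are made up for demonstration):

    import torch

    v = torch.zeros(4, 3, 64, 64)  # one env's stacked image observation (illustrative)
    num_envs = 16                  # stands in for cfg.env.num_envs

    # Old reshape: asks view() for 16 * k * 64 * 64 elements, but v only holds
    # 4 * 3 * 64 * 64 of them, so no integer k exists and view() raises.
    try:
        v.view(num_envs, -1, *v.shape[-2:])
    except RuntimeError as err:
        print("old reshape fails:", err)

    # New reshape: collapse everything but the spatial dims for this single env.
    print("new reshape:", v.view(-1, *v.shape[-2:]).shape)  # torch.Size([12, 64, 64])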
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -222,7 +222,7 @@ def player(
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+                        torch_v = torch_v.view(-1, *v.shape[-2:])
                         torch_v = torch_v / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():

@@ -480,7 +480,7 @@ def trainer(
 
     # Start training
     with timer(
-        "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
+        "Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
     ):
         # The Join context is needed because there can be the possibility
         # that some ranks receive less data
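
The second hunk changes the timer call site from passing an already-constructed SumMetric(...) to passing the class plus its keyword arguments, so the context manager instantiates the metric itself. A minimal sketch of that pattern; the timer body, the _timings dict, and SumMetricStub are illustrative stand-ins, not sheeprl's actual implementation (sheeprl uses torchmetrics' real SumMetric):

    import time
    from contextlib import contextmanager

    # Minimal stand-in for a torchmetrics-style metric with an update() method.
    class SumMetricStub:
        def __init__(self, sync_on_compute: bool = False):
            self.sync_on_compute = sync_on_compute
            self.total = 0.0

        def update(self, value: float) -> None:
            self.total += value

    _timings = {}

    @contextmanager
    def timer(name, metric_cls=None, **metric_kwargs):
        # The fix's shape: receive the class + kwargs, build the instance here.
        metric = metric_cls(**metric_kwargs) if metric_cls is not None else None
        start = time.perf_counter()
        try:
            yield
        finally:
            if metric is not None:
                metric.update(time.perf_counter() - start)
                _timings[name] = metric

    # Usage mirrors the fixed call sites: the class goes in uninstantiated.
    with timer("Time/train_time", SumMetricStub, sync_on_compute=False):
        time.sleep(0.01)
    print(_timings["Time/train_time"].total)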
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -315,7 +315,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.algo.cnn_keys.encoder:
-                        torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
+                        torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5
                     real_next_obs[k][0, i] = torch_v
             with torch.no_grad():
                 feat = agent.module.feature_extractor(real_next_obs)
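
In the recurrent variant each truncated env is handled one at a time, so the target layout is (seq_len=1, num_envs=1, C, H, W); the old code sized dim 1 by len(truncated_envs) and read the spatial dims from the unrelated torch_obs tensor. A sketch with illustrative shapes:

    import torch

    v = torch.zeros(4, 3, 64, 64)  # one truncated env's observation (illustrative)
    truncated_envs = [0, 2, 5]     # stands in for the real index list

    # Old reshape: here it happens to fit arithmetically, but it smears one
    # env's 12 channel planes across a fabricated 3-env axis (and the original
    # read the spatial dims from torch_obs, not from v).
    print(v.view(1, len(truncated_envs), -1, *v.shape[-2:]).shape)  # [1, 3, 4, 64, 64]

    # New reshape: one sequence step, one env, spatial dims taken from v itself,
    # matching the (seq_len, num_envs, C, H, W) layout the recurrent agent expects.
    print(v.view(1, 1, -1, *v.shape[-2:]).shape)  # [1, 1, 12, 64, 64]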
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -453,7 +453,7 @@ def trainer(
 
     # Start training
     with timer(
-        "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
+        "Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
     ):
         for batch_idxes in sampler:
             train(
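
This is the same timer fix as in ppo_decoupled.py above: SumMetric is handed to timer as an uninstantiated class together with its keyword arguments; see the sketch after that hunk.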
6 changes: 3 additions & 3 deletions sheeprl/configs/algo/ppo_recurrent.yaml
@@ -18,11 +18,11 @@ per_rank_sequence_length: ???
 # Model related parameters
 mlp_layers: 1
 layer_norm: True
-dense_units: 128
+dense_units: 256
 dense_act: torch.nn.ReLU
 rnn:
   lstm:
-    hidden_size: 64
+    hidden_size: 128
   pre_rnn_mlp:
     bias: True
     apply: False

@@ -36,7 +36,7 @@ rnn:
     layer_norm: ${algo.layer_norm}
     dense_units: ${algo.rnn.lstm.hidden_size}
 encoder:
-  dense_units: 64
+  dense_units: 128
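
Doubling lstm.hidden_size also carries into every field that interpolates it, such as the dense_units: ${algo.rnn.lstm.hidden_size} context line above. A minimal OmegaConf sketch of how those ${...} references resolve; the post_rnn_mlp nesting is an assumption inferred from the visible context lines:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        """
        algo:
          layer_norm: True
          rnn:
            lstm:
              hidden_size: 128
            post_rnn_mlp:
              layer_norm: ${algo.layer_norm}
              dense_units: ${algo.rnn.lstm.hidden_size}
        """
    )
    # The interpolation follows the lstm change automatically: prints 128.
    print(cfg.algo.rnn.post_rnn_mlp.dense_units)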
12 changes: 5 additions & 7 deletions sheeprl/configs/exp/ppo_recurrent.yaml
@@ -7,17 +7,15 @@ defaults:
   - _self_
 
 algo:
-  per_rank_num_batches: 4
-  per_rank_sequence_length: 8
-  total_steps: 409600
-  rollout_steps: 256
-  update_epochs: 4
+  per_rank_num_batches: 8
+  per_rank_sequence_length: 16
+  total_steps: 409000
+  rollout_steps: 512
+  update_epochs: 8
 
 # Environment
 env:
   id: CartPole-v1
-  num_envs: 16
   mask_velocities: True
 
 buffer:
   memmap: False
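
A back-of-the-envelope check on how the new experiment settings fit together, assuming each rank chops its rollout into fixed-length sequences and spreads them over the configured number of batches (sheeprl's exact batching may differ):

    rollout_steps = 512
    per_rank_sequence_length = 16
    per_rank_num_batches = 8

    sequences_per_env = rollout_steps // per_rank_sequence_length        # 32
    seqs_per_batch_per_env = sequences_per_env // per_rank_num_batches   # 4
    print(sequences_per_env, seqs_per_batch_per_env)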
2 changes: 1 addition & 1 deletion sheeprl/envs/minedojo.py
@@ -192,7 +192,7 @@ def _convert_action(self, action: np.ndarray) -> np.ndarray:
                 self._sticky_attack_counter -= 1
             # it the selected action is not attack, then the agent stops the sticky attack
             elif converted_action[5] != 3:
-                self._sticky_attack = 0
+                self._sticky_attack_counter = 0
         if self._sticky_jump:
             # 2 is the index of the jump/sneak/sprint actions, 1 is the value for the jump action
             if converted_action[2] == 1:
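
The one-word change matters because _sticky_attack is the configured repeat length (and the gate enabling the feature) while _sticky_attack_counter is the runtime countdown; zeroing the former silently disabled sticky attacks for the rest of the episode. A distilled sketch of the logic; the class name, constructor, and default repeat length are illustrative, only the attribute names follow the diff:

    ATTACK = 3  # value of the attack action at functional-action index 5

    class StickyAttackSketch:
        def __init__(self, sticky_attack: int = 30):
            self._sticky_attack = sticky_attack   # config: how many steps to repeat
            self._sticky_attack_counter = 0       # runtime countdown

        def convert(self, functional_action: int) -> int:
            if self._sticky_attack:               # the old bug zeroed this gate...
                if functional_action == ATTACK:
                    self._sticky_attack_counter = self._sticky_attack
                if self._sticky_attack_counter > 0:
                    functional_action = ATTACK    # keep repeating the attack
                    self._sticky_attack_counter -= 1
                elif functional_action != ATTACK:
                    # ...here. The fix resets only the runtime countdown, so the
                    # configured repeat length survives for later attacks.
                    self._sticky_attack_counter = 0
            return functional_action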
