Fix/bugs #181

Merged · 4 commits · Jan 4, 2024
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
@@ -300,7 +300,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k, v in info["final_observation"][truncated_env].items():
torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
- torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+ torch_v = torch_v.view(-1, *v.shape[-2:])
torch_v = torch_v / 255.0 - 0.5
real_next_obs[k][i] = torch_v
with torch.no_grad():
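The old reshape treated a single environment's final observation as if it contained data for all `cfg.env.num_envs` environments; since `v` holds one env's observation, the fix only collapses its leading channel/stack dimensions. A minimal sketch of the failure mode, assuming a hypothetical 3x64x64 image observation and 4 parallel environments:

```python
import torch

num_envs = 4                        # stands in for cfg.env.num_envs
v = torch.zeros(3, 64, 64)          # final observation of ONE truncated env
torch_v = torch.as_tensor(v, dtype=torch.float32)

# Old reshape: asks for num_envs * H * W elements per inferred channel dim,
# but only one environment's worth of data is present, so view() raises.
try:
    torch_v.view(num_envs, -1, *v.shape[-2:])
except RuntimeError as err:
    print("old reshape fails:", err)

# Fixed reshape: keep the single observation and just flatten the leading dims.
print(torch_v.view(-1, *v.shape[-2:]).shape)  # torch.Size([3, 64, 64])
```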
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -222,7 +222,7 @@ def player(
for k, v in info["final_observation"][truncated_env].items():
torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
- torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:])
+ torch_v = torch_v.view(-1, *v.shape[-2:])
torch_v = torch_v / 255.0 - 0.5
real_next_obs[k][i] = torch_v
with torch.no_grad():
@@ -480,7 +480,7 @@ def trainer(

# Start training
with timer(
"Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
"Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
):
# The Join context is needed because there can be the possibility
# that some ranks receive less data
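Besides the same reshape fix as in `ppo.py`, the `timer` call now passes the metric class and its keyword arguments separately instead of an already-constructed `SumMetric` instance, which implies the context manager builds the metric itself. A hedged sketch of a context manager with that calling convention; the real `timer` utility in sheeprl is more involved than this:

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(name, metric_cls=None, **metric_kwargs):
    """Hypothetical stand-in: construct the metric from the class + kwargs, then time the block."""
    metric = metric_cls(**metric_kwargs) if metric_cls is not None else None
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if metric is not None:
            metric.update(elapsed)  # e.g. a torchmetrics SumMetric accumulating elapsed seconds
        print(f"{name}: {elapsed:.4f}s")
```

With this convention the old call would have handed `timer` a ready-made metric instance where a metric class is expected.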
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -315,7 +315,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k, v in info["final_observation"][truncated_env].items():
torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
- torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
+ torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5
real_next_obs[k][0, i] = torch_v
with torch.no_grad():
feat = agent.module.feature_extractor(real_next_obs)
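For the recurrent agent the fix is twofold: the loop handles one truncated environment at a time, so the env dimension must be 1 rather than `len(truncated_envs)`, and the spatial sizes should be read from the final observation itself (`torch_v`) rather than from `torch_obs`. A small sketch with a hypothetical 3x64x64 observation:

```python
import torch

v = torch.zeros(3, 64, 64)          # final observation of ONE truncated env
torch_v = torch.as_tensor(v, dtype=torch.float32)

# Fixed reshape: sequence and env dims are both 1; spatial dims come from torch_v.
fixed = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5
print(fixed.shape)  # torch.Size([1, 1, 3, 64, 64])
```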
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -453,7 +453,7 @@ def trainer(

# Start training
with timer(
"Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg)
"Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg
):
for batch_idxes in sampler:
train(
6 changes: 3 additions & 3 deletions sheeprl/configs/algo/ppo_recurrent.yaml
@@ -18,11 +18,11 @@ per_rank_sequence_length: ???
# Model related parameters
mlp_layers: 1
layer_norm: True
- dense_units: 128
+ dense_units: 256
dense_act: torch.nn.ReLU
rnn:
lstm:
- hidden_size: 64
+ hidden_size: 128
pre_rnn_mlp:
bias: True
apply: False
@@ -36,7 +36,7 @@ rnn:
layer_norm: ${algo.layer_norm}
dense_units: ${algo.rnn.lstm.hidden_size}
encoder:
- dense_units: 64
+ dense_units: 128

# Optimizer related parameters
optimizer:
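Only `encoder.dense_units` needs an explicit bump here: the MLP blocks under `rnn` take `dense_units` from `${algo.rnn.lstm.hidden_size}`, so raising `hidden_size` to 128 propagates to them automatically. A small OmegaConf sketch of how that interpolation resolves (key layout abbreviated for illustration):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "algo": {
            "rnn": {
                "lstm": {"hidden_size": 128},
                "pre_rnn_mlp": {"dense_units": "${algo.rnn.lstm.hidden_size}"},
            },
            "encoder": {"dense_units": 128},  # not interpolated: must be edited by hand
        }
    }
)

# Interpolations resolve on access, so this value follows hidden_size automatically.
print(cfg.algo.rnn.pre_rnn_mlp.dense_units)  # 128
```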
12 changes: 5 additions & 7 deletions sheeprl/configs/exp/ppo_recurrent.yaml
@@ -7,17 +7,15 @@ defaults:
- _self_

algo:
- per_rank_num_batches: 4
- per_rank_sequence_length: 8
- total_steps: 409600
- rollout_steps: 256
- update_epochs: 4
+ per_rank_num_batches: 8
+ per_rank_sequence_length: 16
+ total_steps: 409000
+ rollout_steps: 512
+ update_epochs: 8

# Environment
env:
id: CartPole-v1
- num_envs: 16
- mask_velocities: True

buffer:
memmap: False
2 changes: 1 addition & 1 deletion sheeprl/envs/minedojo.py
@@ -192,7 +192,7 @@ def _convert_action(self, action: np.ndarray) -> np.ndarray:
self._sticky_attack_counter -= 1
# if the selected action is not attack, then the agent stops the sticky attack
elif converted_action[5] != 3:
- self._sticky_attack = 0
+ self._sticky_attack_counter = 0
if self._sticky_jump:
# 2 is the index of the jump/sneak/sprint actions, 1 is the value for the jump action
if converted_action[2] == 1:
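The old line zeroed `self._sticky_attack` (the configured sticky-attack duration) instead of the running `self._sticky_attack_counter`, so the first non-attack action silently disabled sticky attacks for the rest of the episode. A tiny sketch of why the two attributes must not be conflated; the attribute names come from the diff, everything else is a simplified assumption:

```python
class StickySketch:
    """Hypothetical reduction of the wrapper's two sticky-attack attributes."""

    def __init__(self, sticky_attack: int = 30):
        self._sticky_attack = sticky_attack   # configured duration: must stay constant
        self._sticky_attack_counter = 0       # attack steps left in the current burst

    def start_attack(self) -> None:
        self._sticky_attack_counter = self._sticky_attack

    def stop_attack_old(self) -> None:
        self._sticky_attack = 0               # bug: later start_attack() calls arm 0 steps

    def stop_attack_fixed(self) -> None:
        self._sticky_attack_counter = 0       # only the current burst is cancelled


s = StickySketch()
s.stop_attack_old()
s.start_attack()
print(s._sticky_attack_counter)  # 0 -> sticky attack can never trigger again

s = StickySketch()
s.stop_attack_fixed()
s.start_attack()
print(s._sticky_attack_counter)  # 30 -> the feature still works on the next attack
```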