Feature/a2c benchmarks #266

Merged (2 commits) on Apr 17, 2024
README.md (25 additions, 0 deletions)
@@ -91,6 +91,7 @@ The training times of our implementations compared to the ones of Stable Baselines3
<th>SheepRL v0.4.0</th>
<th>SheepRL v0.4.9</th>
<th>SheepRL v0.5.2<br />(Numpy Buffers)</th>
<th>SheepRL v0.5.5<br />(Numpy Buffers)</th>
<th>StableBaselines3<sup>1</sup></th>
</tr>
</thead>
@@ -101,13 +102,32 @@ The training times of our implementations compared to the ones of Stable Baselines3
<td>192.31s &plusmn; 1.11</td>
<td>138.3s &plusmn; 0.16</td>
<td>80.81s &plusmn; 0.68</td>
<td>81.27s &plusmn; 0.47</td>
<td>77.21s &plusmn; 0.36</td>
</tr>
<tr>
<td><i>2 devices</i></td>
<td>85.42s &plusmn; 2.27</td>
<td>59.53s &plusmn; 0.78</td>
<td>46.09s &plusmn; 0.59</td>
<td>36.88s &plusmn; 0.30</td>
<td>N.D.</td>
</tr>
<tr>
<td rowspan="2"><b>A2C</b></td>
<td><i>1 device</i></td>
<td>N.D.</td>
<td>N.D.</td>
<td>N.D.</td>
<td>84.76s &plusmn; 0.37</td>
<td>84.22s &plusmn; 0.99</td>
</tr>
<tr>
<td><i>2 devices</i></td>
<td>N.D.</td>
<td>N.D.</td>
<td>N.D.</td>
<td>28.95s &plusmn; 0.75</td>
<td>N.D.</td>
</tr>
<tr>
@@ -116,13 +136,15 @@ The training times of our implementations compared to the ones of Stable Baselines3
<td>421.37s &plusmn; 5.27</td>
<td>363.74s &plusmn; 3.44</td>
<td>318.06s &plusmn; 4.46</td>
<td>320.21s &plusmn; 6.29</td>
<td>336.06s &plusmn; 12.26</td>
</tr>
<tr>
<td><i>2 devices</i></td>
<td>264.29s &plusmn; 1.81</td>
<td>238.88s &plusmn; 4.97</td>
<td>210.07s &plusmn; 27</td>
<td>225.95s &plusmn; 3.65</td>
<td>N.D.</td>
</tr>
<tr>
@@ -131,6 +153,7 @@ The training times of our implementations compared to the ones of Stable Baselines3
<td>4201.23s</td>
<td>N.D.</td>
<td>2921.38s</td>
<td>2207.13s</td>
<td>N.D.</td>
</tr>
<tr>
@@ -139,6 +162,7 @@ The training times of our implementations compared to the ones of Stable Baselines3
<td>1874.62s</td>
<td>N.D.</td>
<td>1148.1s</td>
<td>906.42s</td>
<td>N.D.</td>
</tr>
<tr>
@@ -147,6 +171,7 @@ The training times of our implementations compared to the ones of Stable Baselines3
<td>2022.99s</td>
<td>N.D.</td>
<td>1378.01s</td>
<td>1589.30s</td>
<td>N.D.</td>
</tr>
</tbody>
benchmarks/benchmark.py (11 additions, 0 deletions)
@@ -17,6 +17,17 @@
# "algo.per_rank_batch_size=128"
]

# A2C Arguments
# args = [
# os.path.join(ROOT_DIR, "__main__.py"),
# "exp=a2c_benchmarks",
# # Decomment below to run with 2 devices
# # "fabric.devices=2",
# # "env.num_envs=2",
# # "algo.per_rank_batch_size=10",
# # "algo.rollout_steps=20",
# ]
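
The hunk above only shows the Hydra override lists; the part of benchmarks/benchmark.py that actually consumes `args` is outside the diff. As a rough sketch of one way such a list could be driven end to end, assuming the script simply re-executes the SheepRL entry point (`sheeprl/__main__.py`) in a subprocess and times the run externally (the `ROOT_DIR` lookup and the timing code are illustrative assumptions, not taken from this PR):

import os
import subprocess
import sys
import time

import sheeprl

# Assumption: ROOT_DIR points at the installed sheeprl package directory.
ROOT_DIR = os.path.dirname(sheeprl.__file__)

args = [
    os.path.join(ROOT_DIR, "__main__.py"),
    "exp=a2c_benchmarks",
    # "fabric.devices=2",
    # "env.num_envs=2",
    # "algo.per_rank_batch_size=10",
    # "algo.rollout_steps=20",
]

start = time.perf_counter()
subprocess.run([sys.executable, *args], check=True)  # run the experiment once
print(f"A2C benchmark wall-clock time: {time.perf_counter() - start:.2f}s")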

# SAC Arguments
# args = [
# os.path.join(ROOT_DIR, "__main__.py"),
benchmarks/benchmark_sb3.py (14 additions, 2 deletions)
@@ -1,6 +1,6 @@
import gymnasium as gym
import stable_baselines3 as sb3
from stable_baselines3 import PPO, SAC # noqa: F401
from stable_baselines3 import A2C, PPO, SAC # noqa: F401
from torchmetrics import SumMetric

from sheeprl.utils.timer import timer
@@ -15,6 +15,18 @@
print(sb3.common.evaluation.evaluate_policy(model.policy, env))


# Stable Baselines3 - A2C - CartPole-v1
# Decomment below to run A2C benchmarks

# if __name__ == "__main__":
# with timer("run_time", SumMetric, sync_on_compute=False):
# env = gym.make("CartPole-v1", render_mode="rgb_array")
# model = A2C("MlpPolicy", env, verbose=0, device="cpu", vf_coef=1.0)
# model.learn(total_timesteps=1024 * 64, log_interval=None)
# print(timer.compute())
# print(sb3.common.evaluation.evaluate_policy(model.policy, env))


# Stable Baselines3 SAC - LunarLanderContinuous-v2
# Decomment below to run SAC benchmarks

@@ -23,7 +35,7 @@
# env = sb3.common.vec_env.DummyVecEnv(
# [lambda: gym.make("LunarLanderContinuous-v2", render_mode="rgb_array") for _ in range(4)]
# )
# model = SAC("MlpPolicy", env, verbose=0, device="cpu", ent_coef=1.0)
# model = SAC("MlpPolicy", env, verbose=0, device="cpu")
# model.learn(total_timesteps=1024 * 64, log_interval=None)
# print(timer.compute())
# print(sb3.common.evaluation.evaluate_policy(model.policy, env.envs[0]))
examples/ratio.py (3 additions, 49 deletions)
@@ -1,55 +1,9 @@
import warnings
from typing import Any, Dict, Mapping


class Ratio:
"""Directly taken from Hafner et al. (2023) implementation:
https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26
"""

def __init__(self, ratio: float, pretrain_steps: int = 0):
if pretrain_steps < 0:
raise ValueError(f"'pretrain_steps' must be non-negative, got {pretrain_steps}")
if ratio < 0:
raise ValueError(f"'ratio' must be non-negative, got {ratio}")
self._pretrain_steps = pretrain_steps
self._ratio = ratio
self._prev = None

def __call__(self, step: int) -> int:
if self._ratio == 0:
return 0
if self._prev is None:
self._prev = step
repeats = 1
if self._pretrain_steps > 0:
if step < self._pretrain_steps:
warnings.warn(
"The number of pretrain steps is greater than the number of current steps. This could lead to "
f"a higher ratio than the one specified ({self._ratio}). Setting the 'pretrain_steps' equal to "
"the number of current steps."
)
self._pretrain_steps = step
repeats = round(self._pretrain_steps * self._ratio)
return repeats
repeats = round((step - self._prev) * self._ratio)
self._prev += repeats / self._ratio
return repeats

def state_dict(self) -> Dict[str, Any]:
return {"_ratio": self._ratio, "_prev": self._prev, "_pretrain_steps": self._pretrain_steps}

def load_state_dict(self, state_dict: Mapping[str, Any]):
self._ratio = state_dict["_ratio"]
self._prev = state_dict["_prev"]
self._pretrain_steps = state_dict["_pretrain_steps"]
return self

from sheeprl.utils.utils import Ratio

if __name__ == "__main__":
num_envs = 1
world_size = 1
replay_ratio = 0.5
replay_ratio = 0.0625
per_rank_batch_size = 16
per_rank_sequence_length = 64
replayed_steps = world_size * per_rank_batch_size * per_rank_sequence_length
@@ -62,7 +16,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any]):
for i in range(0, total_policy_steps, policy_steps):
if i >= 128:
per_rank_repeats = r(i / world_size)
if per_rank_repeats > 0 and not printed:
if per_rank_repeats > 0: # and not printed:
print(
f"Training the agent with {per_rank_repeats} repeats on every rank "
f"({per_rank_repeats * world_size} global repeats) at global iteration {i}"
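
The Ratio class removed above now lives in sheeprl.utils.utils; its job is to turn a replay ratio into an integer number of gradient repeats per chunk of newly collected policy steps. A small worked example of its core update rule, written as a standalone re-implementation for illustration only (it mirrors the arithmetic of the deleted class with pretrain_steps = 0):

# Core rule from the deleted class: repeats = round((step - prev) * ratio); prev += repeats / ratio
ratio = 0.0625   # i.e. one gradient repeat every 1 / 0.0625 = 16 policy steps
prev = 0.0
for step in range(16, 129, 16):   # pretend the counter is observed every 16 policy steps
    repeats = round((step - prev) * ratio)
    prev += repeats / ratio
    print(f"step={step:4d} -> {repeats} repeat(s)")
# Every line prints 1 repeat: the dreamer_v*_benchmarks configs below set
# replay_ratio to 0.0625, i.e. one gradient step per 16 policy steps on each rank.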
sheeprl/algos/a2c/a2c.py (2 additions, 2 deletions)
@@ -358,7 +358,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or update == num_updates
or (update == num_updates and cfg.checkpoint.save_last)
):
last_checkpoint = policy_step
state = {
@@ -370,7 +370,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
if fabric.is_global_zero:
if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
sheeprl/algos/a2c/agent.py (17 additions, 13 deletions)
@@ -30,7 +30,7 @@ def __init__(
super().__init__()
self.keys = keys
self.input_dim = input_dim
self.output_dim = features_dim
self.output_dim = features_dim if features_dim else dense_units
self.model = MLP(
input_dim,
features_dim,
@@ -96,18 +96,22 @@ def __init__(
)

# Actor
actor_backbone = MLP(
input_dims=features_dim,
output_dim=None,
hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
activation=hydra.utils.get_class(actor_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] * actor_cfg.mlp_layers if actor_cfg.layer_norm else None,
norm_args=(
[{"normalized_shape": actor_cfg.dense_units} for _ in range(actor_cfg.mlp_layers)]
if actor_cfg.layer_norm
else None
),
actor_backbone = (
MLP(
input_dims=features_dim,
output_dim=None,
hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
activation=hydra.utils.get_class(actor_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] * actor_cfg.mlp_layers if actor_cfg.layer_norm else None,
norm_args=(
[{"normalized_shape": actor_cfg.dense_units} for _ in range(actor_cfg.mlp_layers)]
if actor_cfg.layer_norm
else None
),
)
if actor_cfg.mlp_layers > 0
else nn.Identity()
)
if is_continuous:
# Output is a tuple of two elements: mean and log_std, one for every action
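
The wrapping above means that with actor.mlp_layers: 0 (as in the a2c_benchmarks config) the actor backbone collapses to nn.Identity(), so the policy heads read the encoder features directly instead of going through an empty MLP; the related change at the top of the hunk lets the encoder report dense_units as its output size when mlp_features_dim is null, which is how the benchmark config sets it. A stripped-down sketch of the same fallback pattern (illustrative names, not SheepRL's actual MLP helper):

import torch
from torch import nn

def build_backbone(features_dim: int, dense_units: int, mlp_layers: int) -> nn.Module:
    # With zero layers the backbone degenerates to a pass-through,
    # so downstream heads consume the encoder features unchanged.
    if mlp_layers == 0:
        return nn.Identity()
    layers, in_dim = [], features_dim
    for _ in range(mlp_layers):
        layers += [nn.Linear(in_dim, dense_units), nn.Tanh()]
        in_dim = dense_units
    return nn.Sequential(*layers)

backbone = build_backbone(features_dim=64, dense_units=64, mlp_layers=0)
print(backbone(torch.randn(1, 64)).shape)  # torch.Size([1, 64]): features pass through untouched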
sheeprl/algos/sac/sac.py (1 addition, 1 deletion)
@@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if update >= learning_starts:
per_rank_gradient_steps = ratio(policy_step / world_size)
per_rank_gradient_steps = ratio(policy_step / world_size) if not cfg.run_benchmarks else 1
if per_rank_gradient_steps > 0:
# We sample one time to reduce the communications between processes
sample = rb.sample_tensors(
sheeprl/configs/exp/a2c_benchmarks.yaml (59 additions, 0 deletions)
@@ -0,0 +1,59 @@
# @package _global_

defaults:
- override /algo: a2c
- override /env: gym
- _self_

# Environment
env:
capture_video: False
num_envs: 1
sync_env: True

# Algorithm
algo:
name: a2c
rollout_steps: 5
loss_reduction: mean
normalize_advantages: False
max_grad_norm: 0.5
encoder:
mlp_layers: 2
mlp_features_dim: null
actor:
mlp_layers: 0
critic:
mlp_layers: 0
optimizer:
lr: 7e-4
eps: 1e-5
alpha: 0.99
per_rank_batch_size: 5
  # If you want to run this benchmark with older versions,
# you need to comment the test function in the `./sheeprl/algos/ppo/ppo.py` file.
run_test: False
# If you want to run this benchmark with older versions,
# you need to move the `total_steps` and the `mlp_keys` config from `algo` to the root.
total_steps: 65536
mlp_keys:
encoder: [state]

# Buffer
buffer:
share_data: False
size: ${algo.rollout_steps}
memmap: False

fabric:
devices: 1
accelerator: cpu

checkpoint:
every: 70000
save_last: False

metric:
log_every: 70000
log_level: 0
disable_timer: True
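
For a sense of scale, assuming the usual on-policy bookkeeping where each update consumes num_envs * rollout_steps * world_size policy steps (an assumption, not verified against SheepRL's exact counter), the config above implies roughly the following number of A2C updates per benchmark run:

total_steps = 65536
# (world_size, num_envs, rollout_steps): the single-device config above,
# and the 2-device overrides suggested in benchmarks/benchmark.py.
for world_size, num_envs, rollout_steps in [(1, 1, 5), (2, 2, 20)]:
    steps_per_update = num_envs * rollout_steps * world_size
    print(f"{world_size} device(s): ~{total_steps // steps_per_update} updates")
# 1 device(s): ~13107 updates
# 2 device(s): ~819 updates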
sheeprl/configs/exp/default.yaml (1 addition, 0 deletions)
@@ -0,0 +1 @@
run_benchmarks: False
4 changes: 1 addition & 3 deletions sheeprl/configs/exp/dreamer_v1_benchmarks.yaml
Original file line number Diff line number Diff line change
@@ -5,9 +5,6 @@ defaults:
- override /env: atari
- _self_

# Experiment
seed: 5

# Environment
env:
num_envs: 1
@@ -26,6 +23,7 @@ buffer:
# Algorithm
algo:
learning_starts: 1024
replay_ratio: 0.0625

dense_units: 8
mlp_layers: 1
sheeprl/configs/exp/dreamer_v2_benchmarks.yaml (3 additions, 5 deletions)
@@ -5,9 +5,6 @@ defaults:
- override /env: atari
- _self_

# Experiment
seed: 5

# Environment
env:
num_envs: 1
@@ -26,10 +23,11 @@ buffer:
# Algorithm
algo:
learning_starts: 1024
per_rank_pretrain_steps: 1
per_rank_pretrain_steps: 0
replay_ratio: 0.0625

dense_units: 8
mlp_layers:
mlp_layers: 1
world_model:
discrete_size: 4
stochastic_size: 4
sheeprl/configs/exp/dreamer_v3_benchmarks.yaml (1 addition, 4 deletions)
@@ -5,9 +5,6 @@ defaults:
- override /env: atari
- _self_

# Experiment
seed: 5

# Environment
env:
num_envs: 1
@@ -26,7 +23,7 @@ buffer:
# Algorithm
algo:
learning_starts: 1024
replay_ratio: 1
replay_ratio: 0.0625
dense_units: 8
mlp_layers: 1
world_model: