Skip to content

Commit

Permalink
Added CrossQ (#453)
Browse files Browse the repository at this point in the history
* Added CrossQ

* Add more hyperparameters

* Swicthed to uv and update requirements

* Update submodule

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
  • Loading branch information
danielpalen and araffin authored Oct 24, 2024
1 parent 726e2f1 commit 6409c41
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 23 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ jobs:
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# cpu version of pytorch - faster to download
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
pip install opencv-python-headless
pip install -e .[plots,tests]
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Lint with ruff
run: |
make lint
Expand All @@ -62,4 +64,4 @@ jobs:
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
make pytest
15 changes: 9 additions & 6 deletions .github/workflows/trained_agents.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ jobs:
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# cpu version of pytorch - faster to download
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
pip install opencv-python-headless
pip install -e .[plots,tests]
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Check trained agents
run: |
make check-trained-agents
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
## Release 2.4.0a4 (WIP)
## Release 2.4.0a10 (WIP)

**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env**

### Breaking Changes
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2)
- Upgraded to SB3 >= 2.4.0

### New Features
- Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen)

### Bug fixes
- Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz)

### Documentation

### Other
- Updated PyTorch version to 2.3.1 in the CI
- Updated PyTorch version to 2.4.1 in the CI
- Switched to uv to download packages faster on GitHub CI

## Release 2.3.0 (2024-03-31)

Expand Down
91 changes: 91 additions & 0 deletions hyperparams/crossq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
MountainCarContinuous-v0:
n_timesteps: !!float 50000
policy: 'MlpPolicy'
learning_rate: !!float 7e-4
buffer_size: 50000
train_freq: 32
gradient_steps: 32
gamma: 0.9999
learning_starts: 100
use_sde: True
policy_delay: 2
policy_kwargs: "dict(use_expln=True, log_std_init=-1, net_arch=[64, 64])"

Pendulum-v1:
n_timesteps: 20000
policy: 'MlpPolicy'
policy_delay: 2
policy_kwargs: "dict(net_arch=[256, 256])"


LunarLanderContinuous-v2:
n_timesteps: !!float 2e5
policy: 'MlpPolicy'
buffer_size: 1000000
learning_starts: 10000


BipedalWalker-v3:
n_timesteps: !!float 2e5
policy: 'MlpPolicy'
buffer_size: 300000
gamma: 0.98
learning_starts: 10000
policy_kwargs: "dict(net_arch=dict(pi=[256, 256], qf=[1024, 1024]))"

# === Mujoco Envs ===

HalfCheetah-v4: &mujoco-defaults
buffer_size: 1_000_000
learning_rate: !!float 1e-3
learning_starts: 5000
n_timesteps: !!float 5e6
policy: 'MlpPolicy'
policy_delay: 3
policy_kwargs: "dict(net_arch=dict(pi=[256, 256], qf=[2048, 2048]))"

Ant-v4:
<<: *mujoco-defaults

Hopper-v4:
<<: *mujoco-defaults

Walker2d-v4:
<<: *mujoco-defaults

Humanoid-v4:
<<: *mujoco-defaults

HumanoidStandup-v4:
<<: *mujoco-defaults

Swimmer-v4:
<<: *mujoco-defaults
gamma: 0.999

# Tuned for SAC, need to check with CrossQ
HalfCheetahBulletEnv-v0: &pybullet-defaults
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
train_freq: 8
gradient_steps: 8
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(use_expln=True, log_std_init=-3)"

# Tuned
AntBulletEnv-v0:
<<: *pybullet-defaults

HopperBulletEnv-v0:
<<: *pybullet-defaults
learning_rate: lin_7.3e-4

Walker2DBulletEnv-v0:
<<: *pybullet-defaults
learning_rate: lin_7.3e-4
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a0,<3.0
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0
box2d-py==2.3.8
pybullet_envs_gymnasium>=0.4.0
# minigrid
Expand Down
3 changes: 2 additions & 1 deletion rl_zoo3/push_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch as th
import yaml
from huggingface_hub import HfApi, Repository
from huggingface_hub import HfApi
from huggingface_hub.repocard import metadata_save
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId
from huggingface_sb3.push_to_hub import _evaluate_agent, _generate_replay, generate_metadata
Expand Down Expand Up @@ -83,6 +83,7 @@ def generate_model_card(
RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo<br/>
SB3: https://github.com/DLR-RM/stable-baselines3<br/>
SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
SBX (SB3 + Jax): https://github.com/araffin/sbx
Install the RL Zoo (with SB3 and SB3-Contrib):
```bash
Expand Down
3 changes: 2 additions & 1 deletion rl_zoo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gymnasium import spaces
from huggingface_hub import HfApi
from huggingface_sb3 import EnvironmentName, ModelName
from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from sb3_contrib import ARS, QRDQN, TQC, TRPO, CrossQ, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
Expand All @@ -32,6 +32,7 @@
"td3": TD3,
# SB3 Contrib,
"ars": ARS,
"crossq": CrossQ,
"qrdqn": QRDQN,
"tqc": TQC,
"trpo": TRPO,
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a4
2.4.0a10
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.4.0a4,<3.0",
"sb3_contrib>=2.4.0a10,<3.0",
"gymnasium~=0.29.1",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
Expand All @@ -24,8 +24,7 @@
"pyyaml>=5.1",
"pytablewriter~=1.2",
]
# TODO(antonin): update to rliable>=1.1.0 once PR is merged and released
plots_requires = ["seaborn", "rliable @ git+https://github.com/araffin/rliable@patch-1", "scipy~=1.10"]
plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"]
test_requires = [
# for MuJoCo envs v4:
"mujoco~=2.3",
Expand Down

0 comments on commit 6409c41

Please sign in to comment.