Add QR-DQN #13

Merged
merged 34 commits on Dec 21, 2020
Commits (34):
ad4f445
Add QR-DQN(WIP)
toshikwa Dec 8, 2020
0a84573
Update docstring
toshikwa Dec 8, 2020
5a72eba
Add quantile_huber_loss
toshikwa Dec 8, 2020
671d328
Fix typo
toshikwa Dec 8, 2020
04d0612
Merge branch 'master' into feat/qrdqn
araffin Dec 8, 2020
51a52cf
Remove unnecessary lines
toshikwa Dec 8, 2020
d94d583
Update variable names and comments in quantile_huber_loss
toshikwa Dec 8, 2020
8b36a21
Fix mutable arguments
toshikwa Dec 8, 2020
50e7e8d
Update variable names
toshikwa Dec 8, 2020
d456bc0
Merge branch 'master' into feat/qrdqn
araffin Dec 8, 2020
f55b8ad
Ignore import not used warnings
toshikwa Dec 9, 2020
f4ece75
Fix default parameter of optimizer in QR-DQN
toshikwa Dec 9, 2020
d67d5e8
Update quantile_huber_loss to have more reasonable interface
toshikwa Dec 9, 2020
92c8d10
Merge branch 'feat/qrdqn' of https://github.com/ku2482/stable-baselin…
toshikwa Dec 9, 2020
39d5bc7
update tests
toshikwa Dec 9, 2020
b335b37
Add assertion to quantile_huber_loss
toshikwa Dec 9, 2020
62c336a
Update variable names of quantile regression
toshikwa Dec 9, 2020
2f350e5
Update comments
toshikwa Dec 10, 2020
bedbc80
Reduce the number of quantiles during test
toshikwa Dec 10, 2020
11ae6b0
Update comment
toshikwa Dec 10, 2020
faeda56
Merge branch 'master' into feat/qrdqn
araffin Dec 13, 2020
d2b1ab7
Update quantile_huber_loss
toshikwa Dec 13, 2020
cd419da
Fix isort
toshikwa Dec 13, 2020
5449171
Add document of QR-DQN without results
toshikwa Dec 13, 2020
e0de065
Update docs
toshikwa Dec 13, 2020
147d3e8
Fix bugs
toshikwa Dec 13, 2020
4f31b17
Update doc
araffin Dec 19, 2020
b54b5d6
Add comments about shape
araffin Dec 19, 2020
29d1912
Minor edits
araffin Dec 19, 2020
eac6080
Update comments
toshikwa Dec 19, 2020
b27ed43
Add benchmark
araffin Dec 20, 2020
fe9f015
Doc fixes
araffin Dec 20, 2020
f213f5e
Update doc
araffin Dec 21, 2020
3a53ee1
Bug fix in saving/loading + update tests
araffin Dec 21, 2020
1 change: 1 addition & 0 deletions README.md
@@ -25,6 +25,7 @@ See documentation for the full list of included features.

**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC ✔️ ❌ ❌ ❌ ❌
QR-DQN ❌ ✔️ ❌ ❌ ❌
============ =========== ============ ================= =============== ================


15 changes: 15 additions & 0 deletions docs/guide/examples.rst
@@ -16,6 +16,21 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
model.learn(total_timesteps=10000, log_interval=4)
model.save("tqc_pendulum")

QR-DQN
------

Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.

.. code-block:: python

from sb3_contrib import QRDQN

policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")


.. PyBullet: Normalizing input features
.. ------------------------------------
..
1 change: 1 addition & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms

modules/tqc
modules/qrdqn

.. toctree::
:maxdepth: 1
13 changes: 11 additions & 2 deletions docs/misc/changelog.rst
@@ -13,9 +13,11 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``TimeFeatureWrapper`` to the wrappers
- Added ``QR-DQN`` algorithm (`@ku2482`_)

Bug Fixes:
^^^^^^^^^^
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles

Deprecations:
^^^^^^^^^^^^^
@@ -24,6 +26,7 @@ Others:
^^^^^^^
- Updated ``TQC`` to match new SB3 version
- Updated SB3 min version
- Moved ``quantile_huber_loss`` to ``common/utils.py`` (@ku2482)

Documentation:
^^^^^^^^^^^^^^
@@ -62,13 +65,19 @@ Maintainers
-----------

Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).

.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
.. _Maximilian Ernestus: https://github.com/erniejunior
.. _Maximilian Ernestus: https://github.com/ernestum
.. _Adam Gleave: https://gleave.me/
.. _@araffin: https://github.com/araffin
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli
.. _@ku2482: https://github.com/ku2482

Contributors:
-------------

@ku2482
150 changes: 150 additions & 0 deletions docs/modules/qrdqn.rst
@@ -0,0 +1,150 @@
.. _qrdqn:

.. automodule:: sb3_contrib.qrdqn


QR-DQN
======

`Quantile Regression DQN (QR-DQN) <https://arxiv.org/abs/1710.10044>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and makes use of quantile regression to explicitly model the `distribution over returns <https://arxiv.org/abs/1707.06887>`_,
instead of predicting only the mean return as DQN does.
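
In a nutshell (a brief sketch using the notation of the QR-DQN paper): the network outputs :math:`N` quantile estimates :math:`\theta_i(s, a)` of the return for each action, at fixed quantile midpoints, and the greedy action is selected with respect to their mean:

.. math::

\hat{\tau}_i = \frac{2i - 1}{2N}, \qquad Q(s, a) = \frac{1}{N} \sum_{i=1}^{N} \theta_i(s, a)

The quantile estimates themselves are trained by minimizing the quantile Huber loss between current and target quantiles (see ``quantile_huber_loss`` in ``sb3_contrib/common/utils.py``).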


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
CnnPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1710.10044
- Distributional RL (C51): https://arxiv.org/abs/1707.06887


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔ ✔
Box ❌ ✔
MultiDiscrete ❌ ✔
MultiBinary ❌ ✔
============= ====== ===========


Example
-------

.. code-block:: python

import gym

from sb3_contrib import QRDQN

env = gym.make("CartPole-v1")

policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")

del model # remove to demonstrate saving and loading

model = QRDQN.load("qrdqn_cartpole")

obs = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
if done:
obs = env.reset()
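
As a quick sanity check (e.g. before the rendering loop above), the trained agent can also be evaluated with the generic helper from stable-baselines3; below is a minimal sketch that reuses ``model`` and ``env`` from the example:

.. code-block:: python

from stable_baselines3.common.evaluation import evaluate_policy

# Mean and standard deviation of the episodic return over 10 evaluation episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")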


Results
-------

Results on Atari environments (10M steps, Pong and Breakout) and on classic control tasks, using 3 and 5 seeds.

The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/13>`_.


.. note::

The QR-DQN implementation was validated against the `Intel Coach <https://github.com/IntelLabs/coach/tree/master/benchmarks/qr_dqn>`_ one,
and the results roughly compare to those of the original paper (we trained the agent with a smaller budget).


============ ========== ===========
Environments QR-DQN DQN
============ ========== ===========
Breakout 413 +/- 21 ~300
Pong 20 +/- 0 ~20
CartPole 386 +/- 64 500 +/- 0
MountainCar -111 +/- 4 -107 +/- 4
LunarLander 168 +/- 39 195 +/- 28
Acrobot -73 +/- 2 -74 +/- 2
============ ========== ===========

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the RL-Zoo fork and check out the branch ``feat/qrdqn``:

.. code-block:: bash

git clone https://github.com/ku2482/rl-baselines3-zoo/
cd rl-baselines3-zoo/
git checkout feat/qrdqn

Run the benchmark (replace ``$ENV_ID`` with one of the environments mentioned above):

.. code-block:: bash

python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results:

.. code-block:: bash

python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN



Parameters
----------

.. autoclass:: QRDQN
:members:
:inherited-members:

.. _qrdqn_policies:

QR-DQN Policies
---------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:
2 changes: 1 addition & 1 deletion sb3_contrib/__init__.py
@@ -1,6 +1,6 @@
import os

# from sb3_contrib.cmaes import CMAES
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC

# Read version from file
69 changes: 69 additions & 0 deletions sb3_contrib/common/utils.py
@@ -0,0 +1,69 @@
from typing import Optional

import torch as th


def quantile_huber_loss(
current_quantiles: th.Tensor,
target_quantiles: th.Tensor,
cum_prob: Optional[th.Tensor] = None,
sum_over_quantiles: bool = True,
) -> th.Tensor:
"""
The quantile-regression loss, as described in the QR-DQN and TQC papers.
Partially taken from https://github.com/bayesgroup/tqc_pytorch.

:param current_quantiles: current estimate of quantiles, must be either
(batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
:param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
(batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
:param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper),
must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles).
(if None, calculating unit quantiles)
:param sum_over_quantiles: if summing over the quantile dimension or not
:return: the loss
"""
if current_quantiles.ndim != target_quantiles.ndim:
raise ValueError(
f"Error: The dimension of curremt_quantile ({current_quantiles.ndim}) needs to match "
f"the dimension of target_quantiles ({target_quantiles.ndim})."
)
if current_quantiles.shape[0] != target_quantiles.shape[0]:
raise ValueError(
f"Error: The batch size of curremt_quantile ({current_quantiles.shape[0]}) needs to match "
f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
)
if current_quantiles.ndim not in (2, 3):
raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")

if cum_prob is None:
n_quantiles = current_quantiles.shape[-1]
# Cumulative probabilities to calculate quantiles.
cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
if current_quantiles.ndim == 2:
# For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
# broadcastable to (batch_size, n_quantiles, n_target_quantiles)
cum_prob = cum_prob.view(1, -1, 1)
elif current_quantiles.ndim == 3:
# For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
# broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
cum_prob = cum_prob.view(1, 1, -1, 1)

# QR-DQN
# target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
# current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
# pairwise_delta: (batch_size, n_quantiles, n_target_quantiles)
# TQC
# target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
# current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
# pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
# Note: in both cases, the loss has the same shape as pairwise_delta
pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
if sum_over_quantiles:
loss = loss.sum(dim=-2).mean()
else:
loss = loss.mean()
return loss
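
A minimal usage sketch of this helper, for illustration only (it assumes the 2D shapes used by QR-DQN, as described in the docstring; the tensors below are random stand-ins for the online and target quantile estimates):

import torch as th

from sb3_contrib.common.utils import quantile_huber_loss

batch_size, n_quantiles, n_target_quantiles = 32, 50, 50
current_quantiles = th.randn(batch_size, n_quantiles)
target_quantiles = th.randn(batch_size, n_target_quantiles)

# cum_prob is left as None, so the unit quantile midpoints (i + 0.5) / n_quantiles are used
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
print(loss.shape)  # torch.Size([]) -- the loss is reduced to a scalar
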
2 changes: 2 additions & 0 deletions sb3_contrib/qrdqn/__init__.py
@@ -0,0 +1,2 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN