A2C (#33)
* RMSProp from timm

* Single actor for both discrete/continuous actions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add A2C algo

* Print log_dir

* Add docs

* Add mlflow support

* Update README

* Fix comments

* Add A2C tests

* Added model manager A2C config

* Fix comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
belerico and pre-commit-ci[bot] authored Jan 4, 2024
1 parent 038facd commit 0633789
Showing 43 changed files with 976 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -171,4 +171,5 @@ pytest_*
mlruns
mlartifacts
examples/models
session_*
session_*
bin/
4 changes: 2 additions & 2 deletions README.md
@@ -82,8 +82,8 @@ The algorithms sheeped by sheeprl out-of-the-box are:

| Algorithm | Coupled | Decoupled | Recurrent | Vector obs | Pixel obs | Status |
| ------------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A2C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :construction: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO Recurrent | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: |
1 change: 1 addition & 0 deletions howto/register_new_algorithm.md
@@ -138,6 +138,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric._loggers = [logger]
fabric.logger.log_hyperparams(cfg)
log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
fabric.print(f"Log dir: {log_dir}")

# Environment setup
vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
1 change: 1 addition & 0 deletions sheeprl/__init__.py
@@ -15,6 +15,7 @@
import numpy as np

# fmt: off
from sheeprl.algos.a2c import a2c # noqa: F401
from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401
from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401
from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401
35 changes: 35 additions & 0 deletions sheeprl/algos/a2c/README.md
@@ -0,0 +1,35 @@
# A2C algorithm
Advantage Actor-Critic (A2C) is an on-policy algorithm that updates the policy with the standard policy gradient, scaled by the advantages.

From the interaction with the environment, it collects trajectories of *observations*, *actions*, *rewards*, *values*, *logprobs* and *dones*. These trajectories are stored in a buffer, and used to train the policy and value networks.
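
A minimal sketch of such a buffer, here just a dictionary of per-step lists that is stacked into tensors at the end of the rollout (the field names mirror the quantities above; the shapes and helper names are illustrative assumptions, not taken from this commit):

```python
from typing import Dict, List

import torch
from torch import Tensor

# One list per collected quantity; each entry is a (num_envs, ...) tensor.
buffer: Dict[str, List[Tensor]] = {
    k: [] for k in ("observations", "actions", "rewards", "values", "logprobs", "dones")
}


def add_step(obs, action, reward, value, logprob, done) -> None:
    # Append a single interaction step to every field of the buffer.
    for key, item in zip(buffer, (obs, action, reward, value, logprob, done)):
        buffer[key].append(item)


def stack_buffer() -> Dict[str, Tensor]:
    # Stack each field into a (num_steps, num_envs, ...) tensor for training.
    return {k: torch.stack(v) for k, v in buffer.items()}
```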

The training loop consists of sampling batches of trajectories from the buffer and computing the *policy loss*, *value loss* and *entropy loss*, while accumulating the gradients over the trajectories collected during the interaction with the environment. By default the gradients are summed over multiple batches and averaged across all replicas on the last batch seen.
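
A minimal sketch of that accumulation pattern in plain PyTorch (the `compute_losses` helper and the batch iterable are placeholders; the cross-replica averaging that the distributed backend performs on the last batch is omitted here):

```python
import torch


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, batches, compute_losses):
    # Accumulate gradients over every batch, then take a single optimizer step.
    optimizer.zero_grad()
    for batch in batches:
        policy_loss, value_loss, entropy_loss = compute_losses(model, batch)
        loss = policy_loss + value_loss + entropy_loss
        # backward() sums the new gradients into the existing .grad buffers.
        loss.backward()
    optimizer.step()
```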

From the rewards, *returns* and *advantages* are estimated. The *returns*, together with the values stored in the buffer and the values from the updated critic, are used to compute the *value loss*.
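
A minimal sketch of how returns and advantages could be estimated from the collected rewards, values and dones, assuming a GAE-style estimator (the `gamma`, `lmbda` and `next_value` arguments are illustrative assumptions, not taken from this commit):

```python
from typing import Tuple

import torch
from torch import Tensor


@torch.no_grad()
def estimate_returns_and_advantages(
    rewards: Tensor,  # (num_steps, num_envs)
    values: Tensor,  # (num_steps, num_envs)
    dones: Tensor,  # (num_steps, num_envs), 1.0 if the episode ended on that transition
    next_value: Tensor,  # (num_envs,) bootstrap value for the state after the last step
    gamma: float = 0.99,
    lmbda: float = 1.0,
) -> Tuple[Tensor, Tensor]:
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros_like(next_value)
    for t in reversed(range(num_steps)):
        not_done = 1.0 - dones[t]
        next_v = next_value if t == num_steps - 1 else values[t + 1]
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), with bootstrapping masked on episode ends.
        delta = rewards[t] + gamma * next_v * not_done - values[t]
        last_advantage = delta + gamma * lmbda * not_done * last_advantage
        advantages[t] = last_advantage
    returns = advantages + values
    return returns, advantages
```

With these returns, the value loss below reduces to a mean-squared error on the updated critic's predictions: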

```python
from torch import Tensor
from torch.nn.functional import mse_loss


def value_loss(
    values: Tensor,
    returns: Tensor,
    clip_coef: float,  # unused in this loss
    clip_vloss: bool,  # unused in this loss
) -> Tensor:
    # Plain mean-squared error between the critic's predictions and the returns.
    return mse_loss(values, returns)
```

The advantages, together with the logprobs computed by the updated model, are used to compute the *policy loss*.

```python
from torch import Tensor


def policy_loss(logprobs: Tensor, advantages: Tensor, reduction: str = "mean") -> Tensor:
    # Policy-gradient loss: -log pi(a|s) * A(s, a), with the advantages detached.
    pg_loss = -logprobs * advantages.detach()
    reduction = reduction.lower()
    if reduction == "none":
        return pg_loss
    elif reduction == "mean":
        return pg_loss.mean()
    elif reduction == "sum":
        return pg_loss.sum()
    else:
        raise ValueError(f"Unrecognized reduction: {reduction}")
```
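
The *entropy loss* mentioned in the training loop is not shown in this README; a minimal sketch of such a bonus term, assuming the policy exposes a `torch.distributions.Distribution` over the actions, could look like:

```python
from torch import Tensor
from torch.distributions import Distribution


def entropy_loss(dist: Distribution, ent_coef: float = 0.01) -> Tensor:
    # Negated entropy scaled by a coefficient: minimizing this loss keeps the
    # policy stochastic, which encourages exploration.
    return -ent_coef * dist.entropy().mean()
```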

Empty file added sheeprl/algos/a2c/__init__.py
Empty file.
