
A2C #33

Merged 22 commits on Jan 4, 2024
Commits
54823e2
RMSProp from timm
belerico Jun 12, 2023
e09eb95
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Jun 13, 2023
d752701
Single actor for both discrete/continuous actions
belerico Jun 13, 2023
758fd09
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Dec 11, 2023
bbe1383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
94adb41
Add A2C algo
belerico Dec 13, 2023
97ff0ae
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Dec 13, 2023
e2a92bc
Merge branch 'feature/a2c' of https://github.com/Eclectic-Sheep/sheep…
belerico Dec 13, 2023
f8638a2
Print log_dir
belerico Dec 13, 2023
fd1f334
Add docs
belerico Dec 14, 2023
f1327fc
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Dec 19, 2023
feb00aa
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Dec 19, 2023
375840d
Add mlflow support
belerico Dec 19, 2023
bec270d
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Dec 21, 2023
6a3a834
Update README
belerico Dec 21, 2023
f18b37d
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
belerico Jan 3, 2024
9626488
FIx comments
belerico Jan 3, 2024
5c7e7c8
Add A2C tests
belerico Jan 4, 2024
1fae0d0
Added model manager A2C config
belerico Jan 4, 2024
287c3d9
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
belerico Jan 4, 2024
bd4d9c6
Fix comments
belerico Jan 4, 2024
5577eed
Merge branch 'main' into feature/a2c
belerico Jan 4, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -171,4 +171,5 @@ pytest_*
mlruns
mlartifacts
examples/models
session_*
session_*
bin/
4 changes: 2 additions & 2 deletions README.md
@@ -82,8 +82,8 @@ The algorithms sheeped by sheeprl out-of-the-box are:

| Algorithm | Coupled | Decoupled | Recurrent | Vector obs | Pixel obs | Status |
| ------------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A2C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :construction: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO Recurrent | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: |
1 change: 1 addition & 0 deletions howto/register_new_algorithm.md
@@ -138,6 +138,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric._loggers = [logger]
fabric.logger.log_hyperparams(cfg)
log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
fabric.print(f"Log dir: {log_dir}")

# Environment setup
vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
1 change: 1 addition & 0 deletions sheeprl/__init__.py
@@ -15,6 +15,7 @@
import numpy as np

# fmt: off
from sheeprl.algos.a2c import a2c # noqa: F401
from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401
from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401
from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401
35 changes: 35 additions & 0 deletions sheeprl/algos/a2c/README.md
@@ -0,0 +1,35 @@
# A2C algorithm
Advantage Actor-Critic (A2C) is an on-policy algorithm that updates the policy with the standard policy gradient, scaled by the advantages.

From the interaction with the environment, it collects trajectories of *observations*, *actions*, *rewards*, *values*, *logprobs*, and *dones*. These trajectories are stored in a buffer and used to train the policy and value networks.

The training loop consists of sampling batches of trajectories from the buffer and computing the *policy loss*, *value loss*, and *entropy loss*, while accumulating the gradients over the trajectories collected during the interaction with the environment. By default, the gradients are summed over multiple batches and averaged across all replicas on the last batch seen.
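The accumulation scheme described above can be sketched as follows. This is an illustrative loop, not sheeprl's actual implementation; `agent`, `optimizer`, `batches`, and `loss_fn` are hypothetical stand-ins.

```python
import torch


def accumulate_and_step(agent, optimizer, batches, loss_fn, world_size: int = 1):
    """Sum gradients over several batches, then average and step once.

    Illustrative sketch only: in a distributed run the averaging across
    replicas would normally happen via an all-reduce on the last batch.
    """
    optimizer.zero_grad(set_to_none=True)
    for batch in batches:
        loss = loss_fn(agent, batch)
        loss.backward()  # gradients accumulate (sum) across batches
    # Average the accumulated gradients over batches and replicas
    for p in agent.parameters():
        if p.grad is not None:
            p.grad /= len(batches) * world_size
    optimizer.step()
```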

From the rewards, *returns* and *advantages* are estimated. The *returns*, together with the values stored in the buffer and the values from the updated critic, are used to compute the *value loss*.

```python
from torch import Tensor
from torch.nn.functional import mse_loss


def value_loss(
    values: Tensor,
    returns: Tensor,
    clip_coef: float,
    clip_vloss: bool,
) -> Tensor:
    # A2C uses a plain MSE loss; `clip_coef` and `clip_vloss` are kept
    # in the signature for API compatibility but are unused here.
    return mse_loss(values, returns)
```
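The returns and advantages are commonly estimated with generalized advantage estimation (GAE); a minimal sketch under that assumption follows. The function name and argument layout are illustrative, not sheeprl's exact API.

```python
import torch
from torch import Tensor


def estimate_returns_and_advantages(
    rewards: Tensor,     # shape (T,)
    values: Tensor,      # shape (T,), critic values for each step
    dones: Tensor,       # shape (T,), 1.0 where the episode ended
    next_value: Tensor,  # bootstrap value for the state after step T-1
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> tuple[Tensor, Tensor]:
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros(())
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        v_next = next_value if t == T - 1 else values[t + 1]
        # TD error, then exponentially-weighted sum of future TD errors
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[t] = last_adv
    returns = advantages + values
    return returns, advantages
```

With `gamma = gae_lambda = 1.0` and zero values this reduces to plain reward-to-go returns, which is a quick sanity check for the recursion.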

The advantages, together with the logprobs from the updated model, are used to compute the *policy loss*.

```python
from torch import Tensor


def policy_loss(
    logprobs: Tensor, advantages: Tensor, reduction: str = "mean"
) -> Tensor:
    pg_loss = -logprobs * advantages.detach()
    reduction = reduction.lower()
    if reduction == "none":
        return pg_loss
    elif reduction == "mean":
        return pg_loss.mean()
    elif reduction == "sum":
        return pg_loss.sum()
    else:
        raise ValueError(f"Unrecognized reduction: {reduction}")
```
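The *entropy loss* mentioned earlier is not shown in this README. A minimal sketch of its usual form (the policy's entropy is maximized by minimizing its negative); the signature mirrors the reduction convention above and is illustrative, not sheeprl's exact API:

```python
import torch
from torch import Tensor


def entropy_loss(entropy: Tensor, reduction: str = "mean") -> Tensor:
    """Negative entropy: minimizing it encourages exploration."""
    loss = -entropy
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unrecognized reduction: {reduction}")
```

The total loss is then a weighted sum of the policy, value, and entropy terms, with coefficients set in the configuration.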

Empty file added sheeprl/algos/a2c/__init__.py