PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models


Documentation | Examples | GitHub

What is PixyzRL?

PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular, flexible design for uncertainty-aware decision-making and improved sample efficiency. PixyzRL supports:

  • Probabilistic Policy Optimization (e.g., PPO, A2C)
  • On-policy and Off-policy Learning
  • Memory Management for RL (Replay Buffer, Rollout Buffer)
  • Advantage estimation via Monte Carlo (MC), GAE, and GRPO
  • Integration with Gymnasium environments
  • Logging and Model Training Utilities
Demo results (videos):

  • CartPole-v1 (CartPole-v1.mp4): examples/cartpole_v1_ppo_discrete_trainer.py
  • CarRacing-v3 (CarRacing-v3.mp4): examples/car_racing_v3_ppo_continual.py
  • BipedalWalker-v3 (bipedal-walker-v3.mp4): examples/bipedal_walker_v3_ppo_continual.py
  • LunarLander-v3 (luna-lunder.mp4): examples/lunar_lander_v3_ppo_continue_trainer.py
  • CartPole-v1 with GRPO, test run (test_760.mp4)

Installation

Requirements

  • Python 3.10+
  • PyTorch 2.5.1+
  • Gymnasium (for environment interaction)

Install PixyzRL

Using pip

pip install pixyzrl

Install from Source

git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .

Quick Start

1. Set Up Environment

import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

from pixyzrl.environments import Env
from pixyzrl.logger import Logger
from pixyzrl.memory import RolloutBuffer
from pixyzrl.models import PPO
from pixyzrl.trainer import OnPolicyTrainer
from pixyzrl.utils import print_latex

env = Env("CartPole-v1", 2, render_mode="rgb_array")  # the 2 matches the environment count passed to RolloutBuffer below
action_dim = env.action_space
obs_dim = env.observation_space
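
A quick sanity check of the spaces (the exact return types are a detail of pixyzrl's Env wrapper; for CartPole-v1 the observation is a 4-dimensional vector and there are 2 discrete actions):

print(obs_dim)     # e.g. (4,)
print(action_dim)  # e.g. 2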

2. Define Actor and Critic Networks

class Actor(Categorical):
    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")  # policy p(a|o)

        self._prob = nn.Sequential(
            nn.Linear(obs_dim[0], 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, o: torch.Tensor):
        probs = self._prob(o)
        return {"probs": probs}

class Critic(Deterministic):
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")  # value function f(v|o)
        self.net = nn.Sequential(
            nn.Linear(obs_dim[0], 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}

actor = Actor()
critic = Critic()
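
As a quick smoke test, you can push a dummy observation batch through both networks. pixyz distributions are sampled with a dict of conditioning variables:

dummy_obs = torch.randn(8, *obs_dim)               # batch of 8 random observations
print(actor.sample({"o": dummy_obs})["a"].shape)   # (8, action_dim); one-hot actions, assuming pixyz's one-hot Categorical
print(critic.sample({"o": dummy_obs})["v"].shape)  # (8, 1); state values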

2.1 Display distributions as LaTeX

>>> print_latex(actor)
p(a|o)

>>> print_latex(critic)
f(v|o)

3. Prepare PPO and Buffer

ppo = PPO(
    actor,
    critic,
    entropy_coef=0.01,
    mse_coef=0.5,
    lr_actor=1e-4,
    lr_critic=3e-4,
    device="mps",  # Apple Silicon GPU; use "cuda" or "cpu" as appropriate
)

buffer = RolloutBuffer(
    1024,  # buffer capacity
    {
        "obs": {
            "shape": (*obs_dim,),
            "map": "o",
        },
        "value": {
            "shape": (1,),
            "map": "v",
        },
        "action": {
            "shape": (action_dim,),
            "map": "a",
        },
        "reward": {
            "shape": (1,),
        },
        "done": {
            "shape": (1,),
        },
        "returns": {
            "shape": (1,),
            "map": "r",
        },
        "advantages": {
            "shape": (1,),
            "map": "A",
        },
    },
    2,  # environment count (matches the Env above)
    advantage_normalization=True,
    lam=0.95,   # GAE lambda
    gamma=0.99, # discount factor
)
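
The buffer computes returns and advantages from the stored rewards, values, and done flags using GAE with the gamma and lam above. For reference, here is a minimal standalone sketch of that computation in plain PyTorch, independent of pixyzrl's actual implementation:

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one rollout of T steps.

    rewards, values, dones: tensors of shape (T, ...); last_value: value
    estimate for the state following the final step.
    """
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(last_value)
    next_value = last_value
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # TD error
        gae = delta + gamma * lam * not_done * gae  # recursive accumulation
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values  # regression targets for the value function
    return returns, advantages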

3.1 Display model as LaTeX

>>> print_latex(ppo)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
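
In conventional notation this is the standard PPO loss, a value-regression MSE term combined with the clipped surrogate objective, where the clip range ε = 0.2 corresponds to the 0.8/1.2 bounds above:

\mathcal{L}(\theta) = \mathbb{E}_t\left[ \mathrm{MSE}\left(V(o_t), R_t\right) - \min\left( r_t(\theta)\, A_t,\ \mathrm{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t \right) \right], \qquad r_t(\theta) = \frac{p_\theta(a_t \mid o_t)}{p_{\theta_{\mathrm{old}}}(a_t \mid o_t)}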

4. Training with Trainer

logger = Logger("cartpole_v1_ppo_discrete_trainer", log_types=["print"])
trainer = OnPolicyTrainer(env, buffer, ppo, "gae", "mps", logger=logger)
trainer.train(1000000, 32, 10, save_interval=50, test_interval=20)
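
Once training finishes, the policy can be evaluated in a plain Gymnasium environment, bypassing pixyzrl's Env wrapper. A minimal sketch, assuming pixyz's Categorical samples one-hot action vectors:

import gymnasium as gym

eval_env = gym.make("CartPole-v1")
obs, info = eval_env.reset(seed=0)
done, total_reward = False, 0.0
while not done:
    with torch.no_grad():
        o = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
        a = actor.sample({"o": o})["a"]  # one-hot action, shape (1, action_dim)
    obs, reward, terminated, truncated, info = eval_env.step(int(a.argmax(dim=-1)))
    total_reward += float(reward)
    done = terminated or truncated
print(f"episode return: {total_reward}")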

Directory Structure

PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples  # Example scripts
├── pixyzrl
│   ├── environments  # Environment wrappers
│   ├── models
│   │   ├── on_policy  # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory  # Experience replay & rollout buffer
│   ├── trainer  # Training utilities
│   ├── losses  # Loss function definitions
│   ├── logger  # Logging utilities
│   └── utils.py
└── pyproject.toml

Future Work

  • Implement Deep Q-Network (DQN)
  • Implement Dreamer (model-based RL)
  • Integrate with ChatGPT for automatic architecture generation
  • Integrate with Genesis

License

PixyzRL is released under the MIT License.

Community & Support

For questions and discussions, please visit: