Documentation | Examples | GitHub
PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular and flexible design to enable uncertainty-aware decision-making and improve sample efficiency. PixyzRL supports:
- Probabilistic Policy Optimization (e.g., PPO, A2C)
- On-policy and Off-policy Learning
- Memory Management for RL (Replay Buffer, Rollout Buffer)
- Advantage Estimation via MC / GAE / GRPO
- Integration with Gymnasium environments
- Logging and Model Training Utilities
| CartPole-v1 | CarRacing-v3 |
|---|---|
| CartPole-v1.mp4 | CarRacing-v3.mp4 |
| examples/cartpole_v1_ppo_discrete_trainer.py | examples/car_racing_v3_ppo_continual.py |

| Bipedal-Walker-v3 | Lunar-Lander-v3 |
|---|---|
| bipedal-walker-v3.mp4 | luna-lunder.mp4 |
| examples/bipedal_walker_v3_ppo_continual.py | examples/lunar_lander_v3_ppo_continue_trainer.py |

| CartPole-v1 (GRPO) TEST | |
|---|---|
| test_760.mp4 | |
- Python 3.10+
- PyTorch 2.5.1+
- Gymnasium (for environment interaction)
Install from PyPI:

pip install pixyzrl

Or install from source in editable mode:

git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .
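Either way, a minimal smoke test is simply importing the package (a sketch; nothing here is assumed beyond the package name):

import pixyzrl
print(pixyzrl.__name__)  # prints "pixyzrl" when the package is importable

The quick-start example below then walks through building a PPO agent for CartPole-v1.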
import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn
from pixyzrl.environments import Env
from pixyzrl.logger import Logger
from pixyzrl.memory import RolloutBuffer
from pixyzrl.models import PPO
from pixyzrl.trainer import OnPolicyTrainer
from pixyzrl.utils import print_latex
env = Env("CartPole-v1", 2, render_mode="rgb_array")
action_dim = env.action_space    # 2 discrete actions in CartPole-v1
obs_dim = env.observation_space  # CartPole-v1 observations are 4-dimensional
class Actor(Categorical):
    def __init__(self):
        # Condition on the observation variable "o" so the policy matches the
        # buffer mapping ("obs" -> "o") and prints as p(a|o) below.
        super().__init__(var=["a"], cond_var=["o"], name="p")
        self._prob = nn.Sequential(
            nn.Linear(*obs_dim, 64),  # input size taken from the observation shape
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, o: torch.Tensor):
        probs = self._prob(o)
        return {"probs": probs}
class Critic(Deterministic):
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(*obs_dim, 64),  # same observation input as the actor
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}
actor = Actor()
critic = Critic()
>>> print_latex(actor)
p(a|o)
>>> print_latex(critic)
f(v|o)
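Because both networks are Pixyz distributions, data moves through them as dictionaries keyed by variable names. A quick sanity check with a placeholder observation batch (illustrative only; the dummy tensor is not part of the original example):

dummy_obs = torch.zeros(8, *obs_dim)               # a batch of 8 placeholder observations
print(actor.sample({"o": dummy_obs})["a"].shape)   # sampled actions, returned under variable "a"
print(critic.sample({"o": dummy_obs})["v"].shape)  # estimated values, returned under variable "v"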
ppo = PPO(
actor,
critic,
entropy_coef=0.01,
mse_coef=0.5,
lr_actor=1e-4,
lr_critic=3e-4,
device="mps",
)
buffer = RolloutBuffer(
1024,  # buffer capacity
{
"obs": {
"shape": (*obs_dim,),
"map": "o",
},
"value": {
"shape": (1,),
"map": "v",
},
"action": {
"shape": (action_dim,),
"map": "a",
},
"reward": {
"shape": (1,),
},
"done": {
"shape": (1,),
},
"returns": {
"shape": (1,),
"map": "r",
},
"advantages": {
"shape": (1,),
"map": "A",
},
},
2,  # matches the 2 environments created by Env("CartPole-v1", 2) above
advantage_normalization=True,
lam=0.95,    # GAE lambda
gamma=0.99,  # discount factor
)
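Here lam and gamma parameterize Generalized Advantage Estimation, which the trainer below selects with the "gae" argument. The buffer's "advantages" entry (mapped to A) and "returns" entry (mapped to r, written R_t below to distinguish it from the per-step reward r_t) are filled following the standard GAE recursion (reference formulation; the sum is cut at episode boundaries):

\delta_t = r_t + \gamma\, V(o_{t+1}) - V(o_t), \qquad \hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^{l}\, \delta_{t+l}, \qquad R_t = \hat{A}_t + V(o_t)

with gamma = 0.99 and lambda = 0.95 as configured above; advantage_normalization=True additionally rescales the advantages (commonly to zero mean and unit variance) before each update.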
>>> print_latex(ppo)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
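The printed objective is the standard clipped-surrogate PPO loss combined with a value-regression term; in conventional notation, with probability ratio ρ and clip range ε = 0.2 (the 0.8 / 1.2 bounds above), it reads:

\mathcal{L} = \mathrm{mean}\left( c_v\, \mathrm{MSE}\big(f(v|o),\, r\big) - \min\left( \rho\, A,\ \mathrm{clip}(\rho,\, 1-\epsilon,\, 1+\epsilon)\, A \right) \right), \qquad \rho = \frac{p(a|o)}{p_{\mathrm{old}}(a|o)}

Minimizing it fits the value head f(v|o) to the computed returns while raising the (clipped) probability of actions with positive advantage; the entropy_coef argument above typically weights an additional entropy bonus that encourages exploration.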

logger = Logger("cartpole_v1_ppo_discrete_trainer", log_types=["print"])
trainer = OnPolicyTrainer(env, buffer, ppo, "gae", "mps", logger=logger)
trainer.train(1000000, 32, 10, save_interval=50, test_interval=20)
PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples            # Example scripts
├── pixyzrl
│   ├── environments    # Environment wrappers
│   ├── models
│   │   ├── on_policy   # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory          # Experience replay & rollout buffer
│   ├── trainer         # Training utilities
│   ├── losses          # Loss function definitions
│   ├── logger          # Logging utilities
│   └── utils.py
└── pyproject.toml
- Implement Deep Q-Network (DQN)
- Implement Dreamer (model-based RL)
- Integrate with ChatGPT for automatic architecture generation
- Integrate with Genesis
PixyzRL is released under the MIT License.
For questions and discussions, please visit: