PixyzRL

A Bayesian RL Framework with Probabilistic Generative Models

PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models

What is PixyzRL?

PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular and flexible design to enable uncertainty-aware decision-making and improve sample efficiency. PixyzRL supports:

  • Probabilistic Policy Optimization (e.g., PPO, A2C)
  • Soft Actor-Critic (SAC) for continuous-control off-policy learning
  • On-policy and Off-policy Learning
  • Memory Management for RL (Replay Buffer, Rollout Buffer)
  • Advantage estimation with Monte Carlo returns (MC), Generalized Advantage Estimation (GAE), or group-relative advantages (GRPO)
  • Integration with Gymnasium environments
  • Logging and Model Training Utilities
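Two of the advantage estimators listed above can be illustrated in a few lines of plain Python. This is a minimal sketch, not PixyzRL's implementation: `mc_returns` and `grpo_advantages` are hypothetical names, `dones[t]` is assumed to mean "the episode ended after step t", and the GRPO-style function standardizes each rollout's return against the other rollouts in its group instead of using a learned value baseline (GAE, which does use a value baseline, is omitted here).

```python
from statistics import mean, pstdev


def mc_returns(rewards, dones, gamma=0.99):
    """Discounted Monte Carlo returns G_t = r_t + gamma * G_{t+1}.

    The running return is reset at episode boundaries, so returns never
    leak from one episode into the previous one.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # episode ended after step t: nothing to add from the future
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def grpo_advantages(group_returns, eps=1e-8):
    """GRPO-style advantages: standardize each rollout's return within its group."""
    mu = mean(group_returns)
    sigma = pstdev(group_returns)
    return [(g - mu) / (sigma + eps) for g in group_returns]


print(mc_returns([1.0, 1.0, 1.0], [False, False, True], gamma=0.5))  # [1.75, 1.5, 1.0]
```

For a group of two rollouts with returns 1.0 and 3.0, `grpo_advantages([1.0, 3.0])` gives approximately `[-1.0, 1.0]`: the better rollout is pushed up and the worse one down, with no critic involved.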

| CartPole-v1 | CarRacing-v3 |
| :-: | :-: |
| <video src="https://github.com/user-attachments/assets/9775e2ef-90dd-4dde-aaff-054af2674fbb"/> | <video src="https://github.com/user-attachments/assets/098b8da4-ce9f-4cff-ac5f-a8ff731589d7"/> |
| examples/cartpole_v1_ppo_discrete_trainer.py | examples/car_racing_v3_ppo_continual.py |
| Bipedal-Walker-v3 | Lunar-Lander-v3 |
| <video src="https://github.com/user-attachments/assets/72045515-8c56-4308-acd8-6aca48915024"/> | <video src="https://github.com/user-attachments/assets/aba4af68-1ded-4578-b463-529f0f6d0cde"/> |
| examples/bipedal_walker_v3_ppo_continual.py | examples/lunar_lander_v3_ppo_continue_trainer.py |
| CartPole-v1 (GRPO) TEST | |
| <video src="https://github.com/user-attachments/assets/c4889b64-0936-4f22-8778-466b15e7a4cc"/> | |

SAC example: examples/pendulum_v1_sac.py

Installation

Requirements

  • Python 3.10+
  • PyTorch 2.5.1+
  • Gymnasium (for environment interaction)

Install PixyzRL

Using pip

pip install pixyzrl

Install from Source

git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .

Quick Start

1. Set Up Environment

import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

from pixyzrl.environments import Env
from pixyzrl.logger import Logger
from pixyzrl.memory import RolloutBuffer
from pixyzrl.models import PPO
from pixyzrl.trainer import OnPolicyTrainer
from pixyzrl.utils import print_latex

env = Env("CartPole-v1", 2, render_mode="rgb_array")
action_dim = env.action_space
obs_dim = env.observation_space

2. Define Actor and Critic Networks

class Actor(Categorical):
    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")

        # p(a|o): maps an observation to a vector of action probabilities
        self._prob = nn.Sequential(
            nn.Linear(*obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, o: torch.Tensor):
        probs = self._prob(o)
        return {"probs": probs}

class Critic(Deterministic):
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(*obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}

actor = Actor()
critic = Critic()

2.1 Display Distributions as LaTeX

>>> print_latex(actor)
p(a|o)

>>> pixyzrl.utils.print_latex(critic)
f(v|o)

3. Prepare PPO and Buffer

ppo = PPO(
    actor,
    critic,
    entropy_coef=0.01,
    mse_coef=0.5,
    lr_actor=1e-4,
    lr_critic=3e-4,
    device="mps",  # "mps" targets Apple Silicon GPUs; "cuda" or "cpu" are the usual alternatives
)

buffer = RolloutBuffer(
    1024,
    {
        "obs": {
            "shape": (*obs_dim,),
            "map": "o",
        },
        "value": {
            "shape": (1,),
            "map": "v",
        },
        "action": {
            "shape": (action_dim,),
            "map": "a",
        },
        "reward": {
            "shape": (1,),
        },
        "done": {
            "shape": (1,),
        },
        "returns": {
            "shape": (1,),
            "map": "r",
        },
        "advantages": {
            "shape": (1,),
            "map": "A",
        },
    },
    2,
    advantage_normalization=True,
    lam=0.95,
    gamma=0.99,
)
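The `"map"` entry in each buffer field appears to name the Pixyz variable that the field is fed into: `o`, `v`, `a`, `r` and `A` are exactly the symbols used by the actor, the critic and the PPO loss, while `reward` and `done` carry no `"map"`. The sketch below illustrates that reading under this assumption only; `remap_to_variables` is a hypothetical helper, not part of PixyzRL, and the shapes are those of CartPole-v1 (4-dimensional observations, 2 actions).

```python
def remap_to_variables(batch, spec):
    """Hypothetical helper: rename buffer fields to the variable names in
    their "map" entry. Fields without a "map" entry (e.g. "reward") are left out."""
    return {cfg["map"]: batch[name] for name, cfg in spec.items() if "map" in cfg and name in batch}


spec = {
    "obs": {"shape": (4,), "map": "o"},
    "action": {"shape": (2,), "map": "a"},
    "reward": {"shape": (1,)},
    "advantages": {"shape": (1,), "map": "A"},
}
batch = {"obs": [0.1, 0.0, -0.2, 0.3], "action": [0, 1], "reward": [1.0], "advantages": [0.5]}
print(remap_to_variables(batch, spec))
# {'o': [0.1, 0.0, -0.2, 0.3], 'a': [0, 1], 'A': [0.5]}
```

Under this reading, a sampled batch can be passed straight to distributions such as `p(a|o)` and `f(v|o)` because its keys already match their variable names.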

3.1 Display the PPO Model as LaTeX

>>> print_latex(ppo)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
<img width="1272" alt="latex" src="https://github.com/user-attachments/assets/317f1f12-bf29-4015-87ee-1aa53ed6b26f" />
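The printed expression is a clipped PPO objective: a value-regression term minus the pessimistic minimum of the clipped and unclipped surrogates, where the clip range (0.8, 1.2) corresponds to epsilon = 0.2. The plain-Python sketch below evaluates that same form for a batch; it is an illustration, not PixyzRL's loss code. It follows the printed output (value coefficient 1.0, no entropy term) rather than the `mse_coef=0.5` and `entropy_coef=0.01` passed to `PPO` above, and `ppo_loss` and its arguments are hypothetical names.

```python
import math


def ppo_loss(new_logp, old_logp, advantages, values, returns, clip_eps=0.2, value_coef=1.0):
    """mean(value_coef * (v - r)^2 - min(A * clip(ratio, 1 - eps, 1 + eps), A * ratio))."""
    terms = []
    for lp, old_lp, adv, v, r in zip(new_logp, old_logp, advantages, values, returns):
        ratio = math.exp(lp - old_lp)  # p(a|o) / old(a|o), computed from log-probabilities
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        surrogate = min(adv * clipped, adv * ratio)  # pessimistic surrogate objective
        terms.append(value_coef * (v - r) ** 2 - surrogate)
    return sum(terms) / len(terms)


# A ratio of 2 is clipped to 1.2 for a positive advantage, but not for a negative one.
print(ppo_loss([math.log(2.0)], [0.0], [1.0], [0.0], [0.0]))   # ≈ -1.2
print(ppo_loss([math.log(2.0)], [0.0], [-1.0], [0.0], [0.0]))  # ≈ 2.0
```

The asymmetry in the two calls is the point of the `min`: clipping caps how much a single update can exploit a good action, while a bad action is still penalized in full.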

4. Training with Trainer

logger = Logger("cartpole_v1_ppo_discrete_trainer", log_types=["print"])
trainer = OnPolicyTrainer(env, buffer, ppo, "gae", "mps", logger=logger)  # "gae": advantage estimator; "mps": device
trainer.train(1000000, 32, 10, save_interval=50, test_interval=20)
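The `"gae"` argument, together with the buffer's `gamma=0.99`, `lam=0.95` and `advantage_normalization=True`, points to Generalized Advantage Estimation. The sketch below shows the standard GAE recursion in plain Python, assuming a single rollout with one value estimate per step plus a bootstrap value for the observation after the last step. It is not PixyzRL's code; the function name and argument layout are hypothetical.

```python
from statistics import mean, pstdev


def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95, normalize=True, eps=1e-8):
    """Generalized Advantage Estimation over one rollout.

    delta_t = r_t + gamma * V(o_{t+1}) * (1 - done_t) - V(o_t)
    A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
    Returns (advantages, returns); returns = A_t + V(o_t) are computed
    before the optional advantage normalization.
    """
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0  # no bootstrapping across episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
    returns = [a + v for a, v in zip(advantages, values)]
    if normalize:  # analogous to advantage_normalization=True
        mu, sigma = mean(advantages), pstdev(advantages)
        advantages = [(a - mu) / (sigma + eps) for a in advantages]
    return advantages, returns


# With gamma = lam = 1 and zero values, GAE reduces to undiscounted Monte Carlo returns.
adv, ret = gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, False], 0.0,
               gamma=1.0, lam=1.0, normalize=False)
print(adv, ret)  # [3.0, 2.0, 1.0] [3.0, 2.0, 1.0]
```

`lam` trades bias for variance: `lam=0` uses only the one-step TD error, while `lam=1` recovers the full Monte Carlo advantage; 0.95 sits close to the low-bias end.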

Directory Structure

PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples  # Example scripts
├── pixyzrl
│   ├── environments  # Environment wrappers
│   ├── models
│   │   ├── on_policy  # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory  # Experience replay & rollout buffer
│   ├── trainer  # Training utilities
│   ├── losses  # Loss function definitions
│   ├── logger  # Logging utilities
│   └── utils.py
└── pyproject.toml

Future Work

  • [x] Implement Soft Actor-Critic (SAC)
  • [x] Implement Deep Q-Network (DQN)
  • [ ] Implement Dreamer (model-based RL)
  • [ ] Integrate with ChatGPT for automatic architecture generation
  • [ ] Integrate with Genesis

License

PixyzRL is released under the MIT License.

Community & Support

For questions and discussions, please visit:
