PixyzRL: A Reinforcement Learning Framework with Probabilistic Generative Models
Documentation | Examples | GitHub
What is PixyzRL?
PixyzRL is a reinforcement learning (RL) framework based on probabilistic generative models and Bayesian theory. Built on top of the Pixyz library, it provides a modular and flexible design to enable uncertainty-aware decision-making and improve sample efficiency. PixyzRL supports:
- Probabilistic Policy Optimization (e.g., PPO, A2C)
- Soft Actor-Critic (SAC) for continuous-control off-policy learning
- On-policy and Off-policy Learning
- Memory Management for RL (Replay Buffer, Rollout Buffer)
- Advantage estimation via Monte Carlo (MC), GAE, and GRPO
- Integration with Gymnasium environments
- Logging and Model Training Utilities
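The advantage estimators listed above can be sketched in a few lines. Below is a minimal stdlib sketch of GAE (the function name `gae` and the `last_value` parameter are illustrative, not the PixyzRL API), using the same `gamma=0.99` / `lam=0.95` values that appear in the Quick Start example:

```python
from typing import Sequence

def gae(rewards: Sequence[float], values: Sequence[float], last_value: float,
        gamma: float = 0.99, lam: float = 0.95) -> list[float]:
    """Generalized Advantage Estimation for one trajectory (no terminal masking)."""
    advantages = [0.0] * len(rewards)
    next_adv = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        next_adv = delta + gamma * lam * next_adv            # exponentially weighted sum
        advantages[t] = next_adv
        next_value = values[t]
    return advantages
```

Setting `lam=1.0` recovers Monte Carlo advantages; `lam=0.0` recovers one-step TD residuals.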
| CartPole-v1 | CarRacing-v3 |
|:-:|:-:|
| <video src="https://github.com/user-attachments/assets/9775e2ef-90dd-4dde-aaff-054af2674fbb"/> | <video src="https://github.com/user-attachments/assets/098b8da4-ce9f-4cff-ac5f-a8ff731589d7"/> |
| examples/cartpole_v1_ppo_discrete_trainer.py | examples/car_racing_v3_ppo_continual.py |

| Bipedal-Walker-v3 | Lunar-Lander-v3 |
|:-:|:-:|
| <video src="https://github.com/user-attachments/assets/72045515-8c56-4308-acd8-6aca48915024"/> | <video src="https://github.com/user-attachments/assets/aba4af68-1ded-4578-b463-529f0f6d0cde"/> |
| examples/bipedal_walker_v3_ppo_continual.py | examples/lunar_lander_v3_ppo_continue_trainer.py |

| CartPole-v1 (GRPO) TEST | |
|:-:|:-:|
| <video src="https://github.com/user-attachments/assets/c4889b64-0936-4f22-8778-466b15e7a4cc"/> | |
SAC example: examples/pendulum_v1_sac.py
Installation
Requirements
- Python 3.10+
- PyTorch 2.5.1+
- Gymnasium (for environment interaction)
Install PixyzRL
Using pip
```shell
pip install pixyzrl
```
Install from Source
```shell
git clone https://github.com/ItoMasaki/PixyzRL.git
cd PixyzRL
pip install -e .
```
Quick Start
1. Set Up Environment
```python
import torch
from pixyz.distributions import Categorical, Deterministic
from torch import nn

from pixyzrl.environments import Env
from pixyzrl.logger import Logger
from pixyzrl.memory import RolloutBuffer
from pixyzrl.models import PPO
from pixyzrl.trainer import OnPolicyTrainer
from pixyzrl.utils import print_latex

env = Env("CartPole-v1", 2, render_mode="rgb_array")
action_dim = env.action_space
obs_dim = env.observation_space
```
2. Define Actor and Critic Networks
```python
class Actor(Categorical):
    def __init__(self):
        super().__init__(var=["a"], cond_var=["o"], name="p")
        self._prob = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, o: torch.Tensor):
        probs = self._prob(o)
        return {"probs": probs}


class Critic(Deterministic):
    def __init__(self):
        super().__init__(var=["v"], cond_var=["o"], name="f")
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, o: torch.Tensor):
        return {"v": self.net(o)}


actor = Actor()
critic = Critic()
```
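For intuition: the Actor's Softmax head maps an observation to a probability vector over the discrete actions, and an action is then drawn categorically. A stdlib sketch of those two steps (the helper names `softmax` and `sample_categorical` are illustrative, not the Pixyz API):

```python
import math
import random

def softmax(logits: list[float]) -> list[float]:
    """Convert logits to a probability vector (what the Softmax head does)."""
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical(probs: list[float], rng: random.Random) -> int:
    """Draw an action index ~ Categorical(probs), as the Actor does at rollout time."""
    u = rng.random()
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if u < cum:
            return i
    return len(probs) - 1                 # guard against floating-point rounding
```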
2.1 Display distributions as LaTeX

```
>>> print_latex(actor)
p(a|o)
>>> print_latex(critic)
f(v|o)
```
3. Prepare PPO and Buffer
```python
ppo = PPO(
    actor,
    critic,
    entropy_coef=0.01,
    mse_coef=0.5,
    lr_actor=1e-4,
    lr_critic=3e-4,
    device="mps",
)

buffer = RolloutBuffer(
    1024,
    {
        "obs": {"shape": (*obs_dim,), "map": "o"},
        "value": {"shape": (1,), "map": "v"},
        "action": {"shape": (action_dim,), "map": "a"},
        "reward": {"shape": (1,)},
        "done": {"shape": (1,)},
        "returns": {"shape": (1,), "map": "r"},
        "advantages": {"shape": (1,), "map": "A"},
    },
    2,
    advantage_normalization=True,
    lam=0.95,
    gamma=0.99,
)
```
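With `advantage_normalization=True`, advantages are typically standardized to zero mean and unit standard deviation before the policy update, which stabilizes gradient scales across rollouts. A stdlib sketch of that step (the function name is illustrative, and whether PixyzRL uses the population or sample standard deviation is an assumption):

```python
import statistics

def normalize_advantages(advantages: list[float], eps: float = 1e-8) -> list[float]:
    """Standardize advantages; eps avoids division by zero for constant batches."""
    mean = statistics.fmean(advantages)
    std = statistics.pstdev(advantages)   # population std, an assumption
    return [(a - mean) / (std + eps) for a in advantages]
```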
3.1 Display model as LaTeX

```
>>> print_latex(ppo)
mean \left(1.0 MSE(f(v|o), r) - min \left(A clip(\frac{p(a|o)}{old(a|o)}, 0.8, 1.2), A \frac{p(a|o)}{old(a|o)}\right) \right)
```
<img width="1272" alt="latex" src="https://github.com/user-attachments/assets/317f1f12-bf29-4015-87ee-1aa53ed6b26f" />
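The printed objective combines a value-function MSE term with PPO's clipped surrogate. A stdlib sketch of the clipped term for a single sample (the function name is illustrative; the 0.8–1.2 clip range matches the formula above):

```python
def clipped_surrogate(ratio: float, advantage: float,
                      clip_low: float = 0.8, clip_high: float = 1.2) -> float:
    """PPO clipped surrogate: min(A * ratio, A * clip(ratio, low, high)).

    ratio is p(a|o) / p_old(a|o); clipping removes the incentive to move the
    policy more than ~20% away from the data-collecting policy in one update.
    """
    clipped = min(max(ratio, clip_low), clip_high)
    return min(advantage * ratio, advantage * clipped)
```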
4. Training with Trainer
```python
logger = Logger("cartpole_v1_ppo_discrete_trainer", log_types=["print"])
trainer = OnPolicyTrainer(env, buffer, ppo, "gae", "mps", logger=logger)
trainer.train(1000000, 32, 10, save_interval=50, test_interval=20)
```
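Schematically, an on-policy trainer alternates rollout collection with several optimization epochs over the freshly filled buffer, then discards the data and repeats. A stdlib sketch of that control flow (parameter names and the mapping onto `train`'s positional arguments are assumptions, not the PixyzRL signature):

```python
def on_policy_loop(total_steps: int, rollout_len: int, epochs: int) -> int:
    """Schematic on-policy training loop; returns the number of update epochs run."""
    updates = 0
    steps = 0
    while steps < total_steps:
        steps += rollout_len          # 1) collect a rollout into the buffer
        for _ in range(epochs):       # 2) optimize the surrogate on that rollout
            updates += 1
        # 3) the buffer is cleared; the next rollout uses the updated policy
    return updates
```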
Directory Structure
```
PixyzRL
├── docs
│   └── pixyz
│       └── README.pixyz.md
├── examples            # Example scripts
├── pixyzrl
│   ├── environments    # Environment wrappers
│   ├── models
│   │   ├── on_policy   # On-policy models (e.g., PPO, A2C)
│   │   └── off_policy  # Off-policy models (e.g., DQN)
│   ├── memory          # Experience replay & rollout buffers
│   ├── trainer         # Training utilities
│   ├── losses          # Loss function definitions
│   ├── logger          # Logging utilities
│   └── utils.py
└── pyproject.toml
```
Future Work
- [x] Implement Soft Actor-Critic (SAC)
- [x] Implement Deep Q-Network (DQN)
- [ ] Implement Dreamer (model-based RL)
- [ ] Integrate with ChatGPT for automatic architecture generation
- [ ] Integrate with Genesis
License
PixyzRL is released under the MIT License.
Community & Support
For questions and discussions, please visit the PixyzRL GitHub repository: https://github.com/ItoMasaki/PixyzRL
