Stable Baselines Jax (SB3 + Jax = SBX)
Proof of concept version of Stable-Baselines3 in Jax.
Implemented algorithms:
- Soft Actor-Critic (SAC) and SAC-N
- Truncated Quantile Critics (TQC)
- Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)
- Proximal Policy Optimization (PPO)
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)
Note: parameter resets for off-policy algorithms can be activated by passing a list of timesteps to the model constructor (e.g., param_resets=[int(1e5), int(5e5)]) to reset parameters and optimizers after 100,000 and 500,000 timesteps.
Install using pip
For the latest master version:
pip install git+https://github.com/araffin/sbx
or:
pip install sbx-rl
Example
import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    vec_env.render()
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

vec_env.close()
Using SBX with the RL Zoo
Since SBX shares the SB3 API, it is compatible with the RL Zoo; you just need to override the algorithm mapping:
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
    train()
Then you can run this script as you would with the RL Zoo:
python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P
The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
    enjoy()
Note about DroQ
DroQ is a special configuration of SAC.
To run the algorithm with the hyperparameters from the paper, use (in RL Zoo config format):
HalfCheetah-v4:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_starts: 10000
  gradient_steps: 20
  policy_delay: 20
  policy_kwargs: "dict(dropout_rate=0.01, layer_norm=True)"
and then run the RL Zoo script defined above: python train.py --algo sac --env HalfCheetah-v4 -c droq.yml -P.
We recommend playing with the policy_delay and gradient_steps parameters for better speed/efficiency.
Having a higher learning rate for the q-value function is also helpful: qf_learning_rate: !!float 1e-3.
Note: when using the DroQ configuration with CrossQ, you should set layer_norm=False as there is already batch normalization.
Note about SimBa
SimBa is a special network architecture for off-policy algorithms (SAC, TQC, ...).
Some recommended hyperparameters (tested on MuJoCo and PyBullet environments):
import optax

default_hyperparams = dict(
    n_envs=1,
    n_timesteps=int(1e6),
    policy="SimbaPolicy",
    learning_rate=3e-4,
    # qf_learning_rate=1e-3,
    policy_kwargs={
        "optimizer_class": optax.adamw,
        # "optimizer_kwargs": {"weight_decay": 0.01},
        # Note: here [128] represents a residual block, not just a single layer
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    learning_starts=10_000,
    # Important: input normalization using VecNormalize
    normalize={"norm_obs": True, "norm_reward": False},
)

hyperparams = {}

# You can also loop over gym.registry
for env_id in [
    "HalfCheetah-v4",
    "HalfCheetahBulletEnv-v0",
    "Ant-v4",
]:
    hyperparams[env_id] = default_hyperparams
and then run the RL Zoo script defined above: python train.py --algo tqc --env HalfCheetah-v4 -c simba.py -P.
Benchmark
A partial benchmark can be found on OpenRL Benchmark where you can also find several reports.
Citing the Project
To cite this repository in publications, please cite the SB3 paper:
@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}
Maintainers
Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @ernestum), Adam Gleave (@AdamGleave), Anssi Kanervisto (@Miffyli) and Quentin Gallouédec (@qgallouedec).
Important Note: We do not provide technical support or consulting, and we do not answer personal questions by email. Please post your question on the RL Discord, Reddit, or Stack Overflow instead.
How To Contribute
To anyone interested in making the baselines better: there is still some documentation that needs to be done. If you want to contribute, please read the CONTRIBUTING.md guide first.
Contributors
We would like to thank our contributors: @jan1854.