# 🎮 JAXAtari: JAX-Based Object-Centric Atari Environments

Quentin Delfosse, Raban Emunds, Jannis Blüml, Paul Seitz, Sebastian Wette, Dominik Mandok (AI/ML Lab, TU Darmstadt)

A GPU-accelerated, object-centric Atari environment suite built with JAX for fast, scalable reinforcement learning research.
JAXAtari is a GPU-accelerated, object-centric Atari environment framework powered by JAX. Inspired by OCAtari, it enables up to 16,000x faster training through just-in-time (JIT) compilation, vectorization, and massive parallelization on GPU. Similar to HackAtari, it implements a number of small game modifications for simple testing of agents' generalization capabilities.
## Features
- Object-centric extraction of Atari game states with structured observations
- JAX-based vectorized execution with full GPU support and JIT compilation
- Comprehensive wrapper system for different observation types (pixel, object-centric, combined)
- Game modifications to test agent generalization across distribution shifts, plus a simple interface for implementing custom modifications
## Getting Started

### Install
```bash
python3 -m venv .venv
source .venv/bin/activate
python3 -m pip install -U pip
pip3 install -e .
```
Note: This installs JAX without GPU acceleration. CUDA users should run the following to add GPU support:

```bash
pip install -U "jax[cuda12]"
```
For other accelerator types, please follow the official JAX installation instructions.
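To verify which accelerators JAX actually picked up after installation, you can list the visible devices (a quick check, independent of JAXAtari itself):

```python
import jax

# With a CUDA-enabled install this should list one or more GPU devices;
# a CPU-only install will show a single CpuDevice.
print(jax.devices())
```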
Note: Next, you need to download the original Atari 2600 sprites. Before downloading, you will be asked to confirm ownership of the original ROMs.

```bash
.venv/bin/install_sprites
```
## Usage

### Basic Environment Creation

The main entry point is the `make()` function:
```python
import jax
import jaxatari

# Create an environment
env = jaxatari.make("pong")  # or "seaquest", "kangaroo", "freeway", etc.

# Get available games
available_games = jaxatari.list_available_games()
print(f"Available games: {available_games}")
```
### Using Modifications

JAXAtari provides some pre-implemented game modifications:
```python
import jax
import jaxatari

# Create base environment
base_env = jaxatari.make("pong")

# Apply the LazyEnemy modification
mod_env = jaxatari.modify(base_env, "pong", "LazyEnemyWrapper")
```
### Using Wrappers

JAXAtari provides a comprehensive wrapper system for different use cases:
```python
import jax
import jaxatari
from jaxatari.wrappers import (
    AtariWrapper,
    ObjectCentricWrapper,
    PixelObsWrapper,
    PixelAndObjectCentricWrapper,
    FlattenObservationWrapper,
    LogWrapper,
)

# Create base environment
base_env = jaxatari.make("pong")

# Apply wrappers for different observation types
env = AtariWrapper(base_env, frame_stack_size=4, frame_skip=4)
env = ObjectCentricWrapper(env)  # Returns flattened object features
# OR
env = PixelObsWrapper(AtariWrapper(base_env))  # Returns pixel observations
# OR
env = PixelAndObjectCentricWrapper(AtariWrapper(base_env))  # Returns both
# OR
env = FlattenObservationWrapper(ObjectCentricWrapper(AtariWrapper(base_env)))  # Returns flattened observations

# Add logging wrapper for training
env = LogWrapper(env)
```
### Vectorized Training Example
```python
import jax
import jaxatari
from jaxatari.wrappers import (
    AtariWrapper,
    FlattenObservationWrapper,
    ObjectCentricWrapper,
)

# Create environment with wrappers
base_env = jaxatari.make("pong")
env = FlattenObservationWrapper(ObjectCentricWrapper(AtariWrapper(base_env)))

n_envs = 1024
rng = jax.random.PRNGKey(0)
reset_keys = jax.random.split(rng, n_envs)

# Initialize n_envs parallel environments
init_obs, env_state = jax.vmap(env.reset)(reset_keys)

# Take one random step in each env
action = jax.random.randint(rng, (n_envs,), 0, env.action_space().n)
new_obs, new_env_state, reward, done, info = jax.vmap(env.step)(env_state, action)

# Take 100 steps with scan
def step_fn(carry, unused):
    _, env_state = carry
    new_obs, new_env_state, reward, done, info = jax.vmap(env.step)(env_state, action)
    return (new_obs, new_env_state), (reward, done, info)

carry = (init_obs, env_state)
_, (rewards, dones, infos) = jax.lax.scan(
    step_fn, carry, None, length=100
)
```
### Manual Game Play

Run a game manually with human input (e.g. Pong):

```bash
pip install pygame
python3 scripts/play.py -g Pong
```
## Supported Games

| Game     | Supported |
|----------|-----------|
| Freeway  | ✅        |
| Kangaroo | ✅        |
| Pong     | ✅        |
| Seaquest | ✅        |
More games can be added via the uniform wrapper system.
## Wrapper System

JAXAtari provides several wrappers to customize environment behavior:

- `AtariWrapper`: base wrapper with frame stacking, frame skipping, and sticky actions
- `ObjectCentricWrapper`: returns flattened object-centric features (2D array: `[frame_stack, features]`)
- `PixelObsWrapper`: returns pixel observations (4D array: `[frame_stack, height, width, channels]`)
- `PixelAndObjectCentricWrapper`: returns both pixel and object-centric observations
- `FlattenObservationWrapper`: flattens any observation structure to a single 1D array
- `LogWrapper`: tracks episode returns and lengths for training
- `MultiRewardLogWrapper`: tracks multiple reward components separately
### Wrapper Usage Patterns

```python
# For pure RL with object-centric features (recommended)
env = ObjectCentricWrapper(AtariWrapper(jaxatari.make("pong")))

# For computer vision approaches
env = PixelObsWrapper(AtariWrapper(jaxatari.make("pong")))

# For multi-modal approaches
env = PixelAndObjectCentricWrapper(AtariWrapper(jaxatari.make("pong")))

# For training with logging
env = LogWrapper(ObjectCentricWrapper(AtariWrapper(jaxatari.make("pong"))))

# All wrapper combinations can be flattened using the FlattenObservationWrapper
```
## Contributing

Contributions are welcome!

- Fork this repository
- Create your feature branch: `git checkout -b feature/my-feature`
- Commit your changes: `git commit -m 'Add some feature'`
- Push to the branch: `git push origin feature/my-feature`
- Open a pull request
Cite us
@misc{jaxatari2026,
author = {Delfosse, Quentin and Emunds, Raban and Seitz, Paul and Wette, Sebastian and Bl{\"u}ml, Jannis and Kersting, Kristian},
title = {JAXAtari: A High-Performance Framework for Reasoning agents},
year = {2026},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {https://github.com/k4ntz/JAXAtari/},
}
## License
This project is licensed under the MIT License.
See the LICENSE file for details.