
The CPU-based environments have been moved into MOMAland for long-term maintenance. If you do not need the JAX-based implementations, I suggest using those instead.

<img src="surround.gif" alt="Swarm" align="right" width="50%"/>


CrazyRL

A hardware-accelerated library for doing Multi-Agent Reinforcement Learning with Crazyflie drones. A video showing the results with real drones in our lab is available on YouTube.

It has:

  • ⚡️ A lightweight and fast simulator that is good enough to control Crazyflies in practice;
  • 🤝 A set of environments implemented in Python and Numpy, under the PettingZoo parallel API;
  • 🚀 The same environments implemented in Jax, that can be run fully on GPU;
  • 🧠 MARL algorithms implemented in Jax, both for PettingZoo and for full Jax environments;
  • 🚁 A set of utilities based on the cflib to control actual Crazyflies;
  • ✅ Good quality, tested and documented Python code;

The real-life example shown in the video is the result of executing, on real drones, policies learned in the lightweight simulator. Once trained, a policy can be replayed in the simulation environment or deployed in reality on the Crazyflies.

Environments

The red balls represent the position of the controlled drones.

Circle

The drones learn to perform a coordinated circle.

<img src="circle.gif" alt="Circle" width="500"/>

The yellow balls represent the target position of the drones.

Available in both Numpy and JAX versions.

Surround

The drones learn to surround a fixed target point.

<img src="surround.gif" alt="Surround" width="500"/>

The yellow ball represents the target the drones have to surround.

Available in both Numpy and JAX versions.

Escort

The drones learn to escort a target moving in a straight line from one point to another.

<img src="escort.gif" alt="Escort" width="500"/>

The yellow ball represents the target the drones have to surround.

Available in both Numpy and JAX versions.

Catch

The drones learn to catch a target trying to escape.

<img src="catch.gif" alt="Catch" width="500"/>

The yellow ball represents the target the drones have to surround.

Available in both Numpy and JAX versions.

Learning

We provide implementations of MAPPO [1] compatible both with a CPU env (PettingZoo parallel API) and with a GPU env (our JAX API). The two implementations should be very close to each other in terms of sample efficiency, but the GPU version is dramatically faster in wall-clock time. We also have a multi-agent version of SAC, MASAC, which is compatible with the CPU envs.

<img src="results/Circle.png"> In the above image, we can see that the sample efficiency of both MAPPO versions is very close, but the JAX version is much faster in wall-clock time. Note that the JAX version can be sped up further by relying on vectorized envs.

Multi-Objective Multi-Agent RL

When vmapping over a set of weight vectors to perform MOMARL learning, we achieve sublinear scaling with respect to the number of Pareto-optimal policies we aim to learn:

<img src="results/mo/training_time.png" alt="Learning time"/>
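The mechanism behind this scaling can be sketched with a toy example: a single `jax.vmap` call runs one update for many weight vectors at once, so the cost grows sublinearly with the number of target policies. Note that `train_step` and its linear scalarization below are hypothetical stand-ins for illustration, not CrazyRL's actual update.

```python
import jax
import jax.numpy as jnp


def scalarize(vector_reward, weights):
    # Linear scalarization of a multi-objective reward
    return jnp.dot(vector_reward, weights)


def train_step(weights):
    # Stand-in "update": returns the scalarized return for a fixed vector reward
    vector_reward = jnp.array([1.0, 2.0])
    return scalarize(vector_reward, weights)


# One weight vector per Pareto-optimal policy we aim to learn
weight_vectors = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])

# A single vmapped call performs the update for all weight vectors at once
returns = jax.vmap(train_step)(weight_vectors)
print(returns)  # one scalarized return per weight vector
```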

API

There are examples of usage in the test files and main methods of the environments. Moreover, the learning folder contains examples of MARL algorithms.

Python/Numpy version

A basic version that can be used for training, in simulation, and on the real drones. It follows the PettingZoo parallel API.

Execution:

from typing import Dict

import numpy as np
from pettingzoo import ParallelEnv

from crazy_rl.multi_agent.numpy.circle.circle import Circle

env: ParallelEnv = Circle(
    drone_ids=np.array([0, 1]),
    render_mode="human",    # or "real", or None
    init_flying_pos=np.array([[0, 0, 1], [2, 2, 1]]),
)

obs, info = env.reset()

done = False
while not done:
    # Execute the policy of each agent
    actions: Dict[str, np.ndarray] = {}
    for agent_id in env.possible_agents:
        actions[agent_id] = actor.get_action(obs[agent_id], agent_id)

    obs, _, terminated, truncated, info = env.step(actions)
    # The PettingZoo parallel API returns terminations/truncations as per-agent dicts
    done = any(terminated.values()) or any(truncated.values())
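The `actor` above stands for whatever policy you bring. A minimal, hypothetical stand-in that simply samples random actions from each agent's action space could look like this:

```python
# Hypothetical stand-in for `actor`: it samples uniformly from each
# agent's action space instead of running a learned policy.
class RandomActor:
    def __init__(self, env):
        self.env = env

    def get_action(self, obs, agent_id):
        # obs is ignored here; a real policy would condition on it
        return self.env.action_space(agent_id).sample()
```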

You can have a look at the learning/ folder to see how we execute pre-trained policies.

JAX version

This version is specifically optimized for GPU usage and intended for agent training purposes. However, simulation and real-world functionalities are not available in this version.

Moreover, it is not compliant with the PettingZoo API as it heavily relies on functional programming. We sacrificed the API compatibility for huge performance gains.

Some functionality is handled automatically by wrappers: for example, vmap enables parallelized training, leveraging all the cores of the GPU. While this offers faster performance on GPUs, it may exhibit slower execution on CPUs.

You can find other wrappers you may need defined in jax_wrappers.
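As an illustration of what the vmap wrapper does (a toy sketch, not CrazyRL's actual `VecEnv`): vectorizing a pure step function over a batch of states turns N sequential steps into one batched call on the accelerator.

```python
import jax
import jax.numpy as jnp


# Toy pure step function: positions move by the commanded action
def step(state, action):
    return state + action


states = jnp.zeros((4, 3))        # 4 parallel envs, one 3-D position each
actions = 0.1 * jnp.ones((4, 3))

batched_step = jax.vmap(step)     # maps step over the leading (env) axis
next_states = batched_step(states, actions)
print(next_states.shape)  # (4, 3)
```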

Execution:

import jax.numpy as jnp
from jax import random

from crazy_rl.multi_agent.jax.circle.circle import Circle
from crazy_rl.utils.jax_wrappers import AutoReset, VecEnv

parallel_env = Circle(
    num_drones=5,
    init_flying_pos=jnp.array([[0.0, 0.0, 1.0], [2.0, 1.0, 1.0], [0.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 0.0, 1.0]]),
    num_intermediate_points=100,
)

num_envs = 3  # number of envs in parallel
seed = 5  # PRNG seed
key = random.PRNGKey(seed)
key, subkey = random.split(key)
reset_keys = random.split(subkey, num_envs)

# Wrappers
env = AutoReset(parallel_env)  # Auto reset the env when done, stores additional info in the dict
env = VecEnv(env)  # Vectorizes the env public methods

obs, info, state = env.reset(reset_keys)

# Example of stepping through the 3 parallel environments
for i in range(301):
    actions = jnp.zeros((num_envs, parallel_env.num_drones, parallel_env.action_space(0).shape[0]))
    for env_id, env_obs in enumerate(obs):
        for agent_id in range(parallel_env.num_drones):
            key, subkey = random.split(key)
            # JAX arrays are immutable, so item assignment uses .at[...].set(...)
            actions = actions.at[env_id, agent_id].set(actor.get_action(env_obs, agent_id, subkey))  # YOUR POLICY HERE

    key, *step_keys = random.split(key, num_envs + 1)
    obs, rewards, term, trunc, info, state = env.step(state, actions, jnp.stack(step_keys))

    # This is where you would learn or add transitions to a buffer

Install & run

Numpy version

poetry install
poetry run python crazy_rl/multi_agent/numpy/circle/circle.py

JAX on CPU

poetry install
poetry run python crazy_rl/multi_agent/jax/circle/circle.py

JAX on GPU

JAX GPU support is not included in the pyproject.toml file, since JAX CPU is the default option. You therefore need to install the JAX GPU build manually, outside of the Poetry requirements.

poetry install
poetry shell
pip install --upgrade pip

# Using CUDA 12
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Or using CUDA 11
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

python crazy_rl/learning/mappo.py

Please refer to the JAX installation GitHub page for the specific CUDA version requirements.

After installation, the JAX version automatically uses the GPU as the default device. However, if you prefer to switch to the CPU without reinstalling, you can manually set the device with the following line:

jax.config.update("jax_platform_name", "cpu")
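To confirm the switch took effect, you can inspect the devices JAX will use (a quick sanity check; note that `jax_platform_name` must be set before any computation runs):

```python
import jax

jax.config.update("jax_platform_name", "cpu")
print(jax.devices()[0].platform)  # "cpu"
```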

Modes

Simulation

render_mode = "human"

The simulation is a simple particle representation in a 3D Cartesian frame based on the Crazyflie Lighthouse reference frame. It is sufficient since the control of the Crazyflies is high-level and precise enough.

Available in the Numpy version.

Real

render_mode = "real"

In our experiments, positioning was managed by Lighthouse positioning. It can probably be deployed with other positioning systems too.

Available in the Numpy version.

