Jaxadi
Transform your CasADi functions into batchable, JAX-compatible functions. By combining the power of CasADi with the flexibility of JAX, JAXADI enables efficient code that runs on CPUs, GPUs, and TPUs.
JAXADI is a Python library that bridges the gap between casadi.Function objects and JAX-compatible functions. By drawing on the strengths of both CasADi and JAX, JAXADI makes it possible to build efficient, batchable code that runs on CPUs, GPUs, and TPUs.
JAXADI can be particularly useful in scenarios involving:
- Robotics simulations
- Optimal control problems
- Machine learning models with complex dynamics
To get up to speed quickly, see the tutorial notebook: <a href="https://colab.research.google.com/github/based-robotics/jaxadi/blob/master/examples/_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" width="120" align="center"/></a>
Installation
You can install JAXADI using pip:
pip install jaxadi
For a complete environment setup for examples, we recommend using Conda/Mamba:
mamba env create -f environment.yml
Usage
JAXADI provides a simple and intuitive API:
import casadi as cs
import numpy as np
from jaxadi import translate, convert
from jax import numpy as jnp
x = cs.SX.sym("x", 2, 2)
y = cs.SX.sym("y", 2, 2)
# Define a complex nonlinear function
z = x @ y # Matrix multiplication
z_squared = z * z # Element-wise squaring
z_sin = cs.sin(z) # Element-wise sine
result = z_squared + z_sin # Element-wise addition
# Create the CasADi function
casadi_fn = cs.Function("complex_nonlinear_func", [x, y], [result])
# Get JAX-compatible function string representation
jax_fn_string = translate(casadi_fn)
print(jax_fn_string)
# Define JAX function from CasADi one
jax_fn = convert(casadi_fn, compile=True)
# Run compiled function
input_x = jnp.array(np.random.rand(2, 2))
input_y = jnp.array(np.random.rand(2, 2))
output = jax_fn(input_x, input_y)
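To illustrate what "batchable" means once a function lives in JAX, here is a minimal sketch that does not call JAXADI at all. `handwritten_fn` is a hypothetical, hand-written JAX equivalent of the CasADi expression above, standing in for the function returned by `convert`. Whether the converted function can be passed to `jax.vmap` in exactly the same way is an assumption; check it against the bundled examples.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical hand-written stand-in for the converted function:
# result = (x @ y) * (x @ y) + sin(x @ y), element-wise after the matmul.
def handwritten_fn(x, y):
    z = x @ y
    return z * z + jnp.sin(z)

# vmap maps over the leading (batch) axis of both inputs; jit compiles the result.
batched_fn = jax.jit(jax.vmap(handwritten_fn))

rng = np.random.default_rng(0)
batch_x = jnp.asarray(rng.random((8, 2, 2)))  # 8 independent 2x2 inputs
batch_y = jnp.asarray(rng.random((8, 2, 2)))

batch_out = batched_fn(batch_x, batch_y)
print(batch_out.shape)
```

The same batched call runs on whichever backend JAX detects, which is where the CPU/GPU/TPU portability comes from.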
<strong>Note:</strong> Translation does not yet support functions with a very large number of operations, because of how it is implemented. A key component of the translation is work-tree expansion, which can lead to a large overhead in the number of symbols. We are working on a compromise between speed and support for large functions.
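The symbol overhead mentioned in the note can be pictured with a toy example. The sketch below is not JAXADI's implementation; it assumes only that expanding a computation graph into a tree duplicates every shared subexpression. The tuple-based node format and all function names are hypothetical.

```python
# Toy expression graph: a node is ("sym", name) or ("add", left, right).
def build_chain(depth):
    """Build x_k = x_{k-1} + x_{k-1}, reusing the same child object twice."""
    node = ("sym", "x")
    for _ in range(depth):
        node = ("add", node, node)
    return node

def count_shared(node, seen=None):
    """Count unique nodes: shared subexpressions are counted once."""
    seen = set() if seen is None else seen
    if id(node) in seen:
        return 0
    seen.add(id(node))
    if node[0] == "sym":
        return 1
    return 1 + count_shared(node[1], seen) + count_shared(node[2], seen)

def count_expanded(node):
    """Count nodes after full expansion into a tree (no sharing)."""
    if node[0] == "sym":
        return 1
    return 1 + count_expanded(node[1]) + count_expanded(node[2])

expr = build_chain(10)
shared_count = count_shared(expr)      # 11 nodes in graph form
expanded_count = count_expanded(expr)  # 2047 nodes in tree form
print(shared_count, expanded_count)
```

In this toy case the graph grows linearly with depth while the expanded tree grows exponentially, which is one way a function with many operations could become impractical to translate.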
Examples
JAXADI comes with several examples to help you get started:
- Basic Translation: Learn how to translate CasADi functions to JAX.
- Lowering Operations: Understand the lowering process in JAXADI.
- Function Conversion: See how to fully convert CasADi functions to JAX.
- Pendulum Rollout: Batched rollout of a passive nonlinear pendulum.
- Pinocchio Integration: Explore how to convert Pinocchio-based CasADi functions to JAX.
- MJX Comparison: Compare the transformed Pinocchio forward kinematics with the one provided by MuJoCo MJX.
Note: To run the Pinocchio and MJX examples, ensure you have them properly installed in your environment.
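As a rough picture of what a batched rollout involves, the following sketch rolls out a passive pendulum for many initial conditions at once using plain JAX. It is not the bundled example: the dynamics are hand-written here instead of being converted from CasADi, and the parameters, integrator, and function names are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters, not taken from the JAXADI example.
G, LENGTH, DT = 9.81, 1.0, 0.01

def pendulum_step(state):
    """One semi-implicit Euler step; state = [theta, omega]."""
    theta, omega = state
    omega_next = omega - (G / LENGTH) * jnp.sin(theta) * DT
    theta_next = theta + omega_next * DT
    return jnp.array([theta_next, omega_next])

def rollout(initial_state, num_steps=100):
    """Roll the dynamics forward with lax.scan and return the visited states."""
    def body(state, _):
        next_state = pendulum_step(state)
        return next_state, next_state
    _, trajectory = jax.lax.scan(body, initial_state, None, length=num_steps)
    return trajectory  # shape (num_steps, 2)

# vmap batches over initial conditions; jit compiles the whole rollout.
batched_rollout = jax.jit(jax.vmap(rollout))

# 16 pendulums with different initial angles and zero initial velocity.
initial_states = jnp.stack([jnp.linspace(0.0, 1.0, 16), jnp.zeros(16)], axis=1)
trajectories = batched_rollout(initial_states)
print(trajectories.shape)
```

In a JAXADI workflow, `pendulum_step` would presumably be replaced by a function converted from a CasADi model, with the `scan` and `vmap` structure left unchanged.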
Performance Benchmarks

The process of benchmarking and evaluating the performance of Jaxadi is described in the benchmarks directory.
Contributing
We welcome contributions! Please see our Contributing Guide for more details.
Citation
If you use JaxADi in your research, please cite it as follows:
@misc{jaxadi2024,
title = {JaxADi: Bridging CasADi and JAX for Efficient Numerical Computing},
author = {Alentev, Igor and Kozlov, Lev and Nedelchev, Simeon},
year = {2024},
url = {https://github.com/based-robotics/jaxadi},
note = {Accessed: [Insert Access Date]}
}
Acknowledgements
This project draws inspiration from cusadi, with a focus on simplicity and JAX integration.
Contact
For questions, issues, or suggestions, please open an issue on our GitHub repository.
We hope JAXADI empowers your numerical computing and optimization tasks! Happy coding!
