Hydrax
Sampling-based model predictive control on GPU with JAX and MuJoCo MJX.

About
Hydrax implements various sampling-based MPC algorithms on GPU. It is heavily inspired by MJPC, but focuses exclusively on sampling-based algorithms, runs on hardware accelerators via JAX and MJX, and includes support for online domain randomization.
Available methods:
| Algorithm | Description | Import |
| --- | --- | --- |
| Predictive sampling | Take the lowest-cost rollout at each iteration. | hydrax.algs.PredictiveSampling |
| MPPI | Take an exponentially weighted average of the rollouts. | hydrax.algs.MPPI |
| Cross Entropy Method | Fit a Gaussian distribution to the n best "elite" rollouts. | hydrax.algs.CEM |
| DIAL-MPC | MPPI with dual-loop, annealed sampling covariance. | hydrax.algs.DIAL |
| Evosax | Any of the 30+ evolution strategies implemented in evosax. Includes CMA-ES, differential evolution, and many more. | hydrax.algs.Evosax |
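As a rough illustration of what "exponentially weighted average" means in the MPPI row above, here is a plain-Python sketch (not hydrax's implementation; the `temperature` parameter and the baseline subtraction are standard MPPI conventions, assumed here for illustration):

```python
import math

def mppi_average(samples, costs, temperature=1.0):
    """Illustrative MPPI-style update: exponentially weight rollouts by cost.

    samples: list of control sequences (lists of floats), one per rollout.
    costs: total cost of each rollout.
    """
    baseline = min(costs)  # subtract the best cost for numerical stability
    weights = [math.exp(-(c - baseline) / temperature) for c in costs]
    total = sum(weights)
    weights = [w / total for w in weights]
    horizon = len(samples[0])
    # The new nominal control is the weighted average of the samples
    return [sum(w * s[t] for w, s in zip(weights, samples))
            for t in range(horizon)]

samples = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
costs = [3.0, 1.0, 2.0]
new_mean = mppi_average(samples, costs)  # pulled toward the low-cost sample
```

Lowering the temperature concentrates the weights on the single best rollout (approaching predictive sampling); raising it flattens them toward a plain average.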
News
- February 15, 2026. Our preferred package manager is now
uv, which is lighter weight and offers improved
reproducibility via
uv.lock. Conda use is still possible, but we recommend switching to uv for the best experience. Note thathydraxnow requires CUDA 13. - April 13, 2024. Large changes to the core
hydraxfunctionality + some breaking changes.- Splines (and their knots) are now the default parameterization of the control signals and decision variables! Before, it was always assumed that every control step applied a zero-order hold. This is now a special case of the new spline parameterization.
- All "time-based" variables are now specified in the controller.
Previously, variables like the planning horizon and number of sim steps
per control step were specified in the task. Now, the main variables to
specify are
plan_horizon(the length of the planning horizon in seconds),num_knots(the number of spline knots to plan with), anddt(the planning time step (in the model XML)). This is a breaking change!
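The zero-order-hold special case mentioned above can be sketched in plain Python (illustrative only; hydrax's actual spline interpolation lives in the library):

```python
def zero_order_hold(knot_times, knots, t):
    """Hold each knot's value until the next knot time (a degenerate spline)."""
    u = knots[0]
    for time, value in zip(knot_times, knots):
        if t >= time:
            u = value
    return u

# Three knots spread over a 1-second plan horizon
knot_times = [0.0, 0.5, 1.0]
knots = [0.0, 2.0, -1.0]
u = zero_order_hold(knot_times, knots, t=0.7)  # holds the second knot's value
```

A smoother spline would interpolate between the same knots instead of holding them, which is why the knot parameterization generalizes the old behavior.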
Setup (uv)
Clone this repository:
```bash
git clone https://github.com/vincekurtz/hydrax.git
cd hydrax
```
Install the package and dependencies:
```bash
uv sync
```
You can use uv to run examples and tests directly:
```bash
uv run python examples/pendulum.py mppi  # pendulum swing up with MPPI
uv run pytest                            # run unit tests
```
Or you can activate the virtual environment (created by uv sync) and run
things directly:
```bash
source .venv/bin/activate
python examples/pendulum.py mppi  # pendulum swing up with MPPI
pytest                            # run unit tests
```
Setup (conda)
Set up a conda env with CUDA support (first time only):

```bash
conda env create -f environment.yml
```
Enter the conda env:
```bash
conda activate hydrax
```
Install the package and dependencies:
```bash
pip install -e .
```
(Optional) Set up pre-commit hooks:
```bash
pre-commit autoupdate
pre-commit install
```
(Optional) Run unit tests:
```bash
pytest
```
Examples
Launch an interactive pendulum swingup simulation with predictive sampling:
```bash
python examples/pendulum.py ps
```
Launch an interactive humanoid standup simulation with MPPI and online domain randomization:

```bash
python examples/humanoid_standup.py
```
Other demos can be found in the examples folder.
Design your own task
Hydrax considers optimal control problems of the form
\begin{align}
\min_{u_t} & \sum_{t=0}^{T} \ell(x_t, u_t) + \phi(x_{T+1}), \\
\mathrm{s.t.}~& x_{t+1} = f(x_t, u_t),
\end{align}
where $x_t$ is the system state and $u_t$ is the control input at time $t$, and the system dynamics $f(x_t, u_t)$ are defined by a MuJoCo MJX model.
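Concretely, this objective just accumulates running costs along a rollout and adds the terminal cost at the end. A toy scalar instance (the dynamics, weights, and controls here are made up purely for illustration):

```python
def total_cost(x0, controls, f, running_cost, terminal_cost):
    """Evaluate sum_t l(x_t, u_t) + phi(x_{T+1}) along one rollout."""
    x, cost = x0, 0.0
    for u in controls:
        cost += running_cost(x, u)
        x = f(x, u)  # x_{t+1} = f(x_t, u_t)
    return cost + terminal_cost(x)

# Toy scalar system x_{t+1} = x_t + u_t with quadratic costs
f = lambda x, u: x + u
l = lambda x, u: x**2 + 0.1 * u**2
phi = lambda x: 10.0 * x**2

cost = total_cost(1.0, [-0.5, -0.5], f, l, phi)  # drives x from 1.0 to 0.0
```

Sampling-based MPC never differentiates this objective; it only needs to evaluate it for many candidate control sequences, which is what makes GPU-parallel rollouts so effective.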
To design a new task, you'll need to specify the cost ($\ell$, $\phi$) and the
dynamics ($f$). You can do this by creating a new class that inherits from
hydrax.task_base.Task:
```python
class MyNewTask(Task):
    def __init__(self, ...):
        # Create or load a mujoco model defining the dynamics (f)
        mj_model = ...
        super().__init__(mj_model, ...)

    def running_cost(self, x: mjx.Data, u: jax.Array) -> float:
        # Implement the running cost (l) here
        return ...

    def terminal_cost(self, x: jax.Array) -> float:
        # Implement the terminal cost (phi) here
        return ...
```
The dynamics ($f$) are specified by a mujoco.MjModel that is passed to the
constructor. Other constructor arguments specify the planning horizon $T$ and
other details.
For the cost, simply implement the running_cost ($\ell$) and terminal_cost
($\phi$) methods.
See hydrax.tasks for some example task implementations.
Implement your own control algorithm
Hydrax considers sampling-based MPC algorithms that follow this generic structure:

The meaning of the parameters $\theta$ differs by algorithm. In predictive sampling, for example, $\theta$ is the mean of the Gaussian distribution from which the controls $U = [u_0, u_1, \dots]$ are sampled.
To implement a new planning algorithm, you'll need to inherit from
hydrax.alg_base.SamplingBasedController and implement
the three methods shown below:
```python
class MyControlAlgorithm(SamplingBasedController):
    def init_params(self) -> Any:
        # Initialize the policy parameters (theta).
        ...
        return params

    def sample_knots(self, params: Any) -> Tuple[jax.Array, Any]:
        # Sample the spline knots U from the policy. Return the samples
        # and the (updated) parameters.
        ...
        return controls, params

    def update_params(self, params: Any, rollouts: Trajectory) -> Any:
        # Update the policy parameters (theta) based on the trajectory data
        # (costs, controls, observations, etc.) stored in the rollouts.
        ...
        return new_params
```
These three methods define a unique sampling-based MPC algorithm. Hydrax takes
care of the rest, including parallelizing rollouts on GPU and collecting the
rollout data in a Trajectory object.
Note: because of
the way JAX handles randomness,
we assume the PRNG key is stored as one of the parameters $\theta$. This is why
sample_knots returns updated parameters along with the control samples
$U^{(1:N)}$.
For some examples, take a look at hydrax.algs.
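To make the three-method shape concrete, here is a toy stand-in in plain Python. It is not hydrax's API: a seeded `random.Random` object plays the role of the JAX PRNG key carried in $\theta$, and `update_params` takes samples and costs directly instead of a `Trajectory`:

```python
import random

class ToyPredictiveSampling:
    """Plain-Python stand-in for the three-method controller interface."""

    def __init__(self, num_samples=16, noise=0.3, num_knots=5):
        self.num_samples = num_samples
        self.noise = noise
        self.num_knots = num_knots

    def init_params(self):
        # theta = (mean knot sequence, rng state); the rng stands in for
        # the PRNG key that hydrax stores in the parameters
        return ([0.0] * self.num_knots, random.Random(0))

    def sample_knots(self, params):
        mean, rng = params
        samples = [[m + rng.gauss(0.0, self.noise) for m in mean]
                   for _ in range(self.num_samples)]
        return samples, (mean, rng)

    def update_params(self, params, samples, costs):
        # Predictive sampling: the lowest-cost sample becomes the new mean
        _, rng = params
        best = min(range(len(costs)), key=costs.__getitem__)
        return (samples[best], rng)

ctrl = ToyPredictiveSampling(num_samples=4, num_knots=2)
params = ctrl.init_params()
samples, params = ctrl.sample_knots(params)
costs = [sum(u**2 for u in s) for s in samples]  # stand-in for rollout costs
params = ctrl.update_params(params, samples, costs)
```

Swapping the body of `update_params` for an exponentially weighted average would turn this into a toy MPPI; that one method is where most algorithms differ.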
Domain Randomization
One benefit of GPU-based simulation is the ability to roll out trajectories with different model parameters in parallel. Such domain randomization can improve robustness and help reduce the sim-to-real gap.
Hydrax provides tools to make online domain randomization easy. In particular,
you can add domain randomization to any task by overriding the
domain_randomize_model and domain_randomize_data methods of a given
Task. For example:
```python
class MyDomainRandomizedTask(Task):
    ...

    def domain_randomize_model(self, rng: jax.Array) -> Dict[str, jax.Array]:
        """Randomize the friction coefficients."""
        n_geoms = self.model.geom_friction.shape[0]
        multiplier = jax.random.uniform(rng, (n_geoms,), minval=0.5, maxval=2.0)
        new_frictions = self.model.geom_friction.at[:, 0].set(
            self.model.geom_friction[:, 0] * multiplier
        )
        return {"geom_friction": new_frictions}

    def domain_randomize_data(self, data: mjx.Data, rng: jax.Array) -> Dict[str, jax.Array]:
        """Randomly shift the measured configurations."""
        shift = 0.005 * jax.random.normal(rng, (self.model.nq,))
        return {"qpos": data.qpos + shift}
```
These methods return a dictionary of randomized parameters, given a particular
random seed (rng). Hydrax takes care of the details of applying these
parameters to the model and data, and performing rollouts in parallel.
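Conceptually, each randomization produces its own copy of the varied parameter, and every control sample is rolled out under every copy. A plain-Python sketch of that idea (illustrative only; the toy `rollout_cost` and the specific numbers are made up, not hydrax internals):

```python
import random

rng = random.Random(0)
num_randomizations = 4

# One randomized friction multiplier per model copy, mirroring
# the domain_randomize_model example above
base_friction = 1.0
frictions = [base_friction * rng.uniform(0.5, 2.0)
             for _ in range(num_randomizations)]

def rollout_cost(friction, controls):
    """Toy rollout whose cost depends on the randomized parameter."""
    return friction * sum(u**2 for u in controls)

controls = [0.5, -0.25]
# The same control sequence is evaluated under every randomized model
costs = [rollout_cost(f, controls) for f in frictions]
```

In hydrax these per-model rollouts run in parallel on the GPU rather than in a Python loop.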
To use a domain randomized task, you'll need to tell the planner how many random
models to use with the num_randomizations flag. For example,
```python
task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
)
```
sets up a predictive sampling controller that rolls out 32 control sequences across 16 domain randomized models.
The resulting Trajectory rollouts will have
dimensions (num_randomizations, num_samples, num_time_steps, ...).
Risk Strategies
With domain randomization, we need to somehow aggregate costs across the
different domains. By default, we take the average cost over the randomizations,
similar to domain randomization in reinforcement learning. Other strategies are
available via the RiskStrategy interface.
For example, to plan using the worst-case maximum cost across randomizations:
```python
from hydrax.risk import WorstCase

...

task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
    risk_strategy=WorstCase(),
)
```
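To see how the choice of risk strategy can change which sample looks best, here is a plain-Python sketch of average versus worst-case aggregation over the randomization axis (the cost numbers are made up):

```python
# costs[i][j]: cost of control sample j under randomized model i
costs = [
    [1.0, 3.0],
    [4.0, 3.5],
]
num_models = len(costs)
num_samples = len(costs[0])

average = [sum(costs[i][j] for i in range(num_models)) / num_models
           for j in range(num_samples)]
worst_case = [max(costs[i][j] for i in range(num_models))
              for j in range(num_samples)]

best_avg = min(range(num_samples), key=average.__getitem__)
best_worst = min(range(num_samples), key=worst_case.__getitem__)
```

Here the average prefers sample 0 (good on average, bad in one domain), while worst-case prefers the more conservative sample 1.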
Available risk strategies:

| Strategy | Description | Import |
| --- | --- | --- |
| Average (default) | Take the average cost across the randomizations. | … |