Nufftax
Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)
<p align="center"> <img src="docs/_static/mri_example.png" alt="MRI reconstruction example" width="100%"> </p>
Why nufftax?
A JAX package for NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library via Foreign Function Interface (FFI), exposing it through custom XLA calls. This approach can lead to:
- Kernel fusion issues on GPU — custom XLA calls act as optimization barriers, preventing XLA from fusing operations
- CUDA version matching — GPU support requires matching CUDA versions between JAX and the library
nufftax takes a different approach, a pure JAX implementation:
- Fully differentiable: gradients w.r.t. both values and sample locations
- Pure JAX: works with `jit`, `grad`, `vmap`, `jvp`, `vjp` with no FFI barriers
- GPU ready: runs on CPU/GPU without code changes, benefits from XLA fusion
- Pallas GPU kernels: fused Triton spreading kernels, with 1D spreading speedups of 4-75x on A100/H100 (2-3x in 2D)
- All NUFFT types: Type 1, 2, 3 in 1D, 2D, 3D
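For orientation, the three transform types can be written as direct sums. The sketch below is a slow O(NM) reference in plain JAX, not nufftax's algorithm. The function names (`modes`, `nudft1d1`, `nudft1d2`, `nudft1d3`), the sign of the exponent, and the centred mode ordering are assumptions chosen for illustration and may differ from the library's conventions.

```python
import jax.numpy as jnp

def modes(n_modes):
    # Centred integer frequencies -n/2, ..., n/2 - 1 (assumed ordering).
    return jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)

def nudft1d1(x, c, n_modes):
    # Type 1: nonuniform points -> uniform modes,
    # f[k] = sum_j c[j] * exp(+i k x[j])  (sign is an assumption).
    k = modes(n_modes)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

def nudft1d2(x, f):
    # Type 2: uniform modes -> nonuniform points,
    # c[j] = sum_k f[k] * exp(-i k x[j]).
    k = modes(f.shape[0])
    return jnp.exp(-1j * x[:, None] * k[None, :]) @ f

def nudft1d3(x, c, s):
    # Type 3: nonuniform points -> nonuniform frequencies,
    # f[m] = sum_j c[j] * exp(+i s[m] x[j]).
    return jnp.exp(1j * s[:, None] * x[None, :]) @ c

x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
f = nudft1d1(x, c, n_modes=8)   # 8 Fourier modes
c_back = nudft1d2(x, f)         # 5 values at the nonuniform points
```

With these sign choices, type 2 is the adjoint (conjugate transpose) of type 1, not its inverse, and type 3 reduces to type 1 when the target frequencies are the integer modes.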
JAX Transformation Support
| Transform | jit | grad/vjp | jvp | vmap |
|-----------|:-----:|:------------:|:-----:|:------:|
| Type 1 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 2 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 3 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
Differentiable inputs:
- Type 1: `grad` w.r.t. `c` (strengths) and `x, y, z` (coordinates)
- Type 2: `grad` w.r.t. `f` (Fourier modes) and `x, y, z` (coordinates)
- Type 3: `grad` w.r.t. `c` (strengths), `x, y, z` (source coordinates), and `s, t, u` (target frequencies)
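To make the coordinate gradients listed above concrete, here is a minimal sketch that differentiates a data-fit loss with respect to the sample locations `x`. It uses a direct-sum stand-in (`nudft1d1`, a hypothetical helper with an assumed sign and mode ordering) in place of the library call, so it shows only how `jit` and `grad` compose, not how nufftax computes the transform.

```python
import jax
import jax.numpy as jnp

def nudft1d1(x, c, n_modes):
    # Direct-sum stand-in for a type 1 transform (assumed sign and mode ordering).
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

x_true = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
target = nudft1d1(x_true, c, 16)

def loss(x):
    # Squared misfit between the modes produced by x and the target modes.
    r = nudft1d1(x, c, 16) - target
    return jnp.sum(r.real ** 2 + r.imag ** 2)

grad_x = jax.jit(jax.grad(loss))  # gradient w.r.t. the real-valued locations

x0 = x_true + 0.05                # perturbed starting locations
g = grad_x(x0)                    # real array with the same shape as x0
x1 = x0 - 1e-4 * g                # one small gradient-descent step on the locations
```

Since `x` is real, the gradient is real and can be used directly in a descent step. The step size here is an arbitrary small value chosen for the sketch.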
GPU Acceleration
On GPU, nufftax automatically dispatches spreading and interpolation to fused Pallas (Triton) kernels when the problem is large enough. This avoids materializing O(M × nspread^d) intermediate tensors and uses atomic scatter-add for spreading.
| Operation | Backend | Speedup vs pure JAX |
|-----------|---------|---------------------|
| 1D spread | A100 | 5–67x (M ≥ 100K) |
| 1D spread | H100 | 4–75x (M ≥ 100K) |
| 2D spread | A100/H100 | 2–3x (M ≥ 100K) |
The dispatch is transparent — no code changes required. On CPU or for small problems, the pure JAX path is used.
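To illustrate the intermediate tensor mentioned above, here is a rough sketch of what a non-fused 1D spreading step can look like in plain JAX. It is not nufftax's code: `es_kernel` and `spread_1d` are hypothetical names, the kernel is the "exponential of semicircle" form from the FINUFFT paper, and the shape parameter `beta = 2.3 * nspread` is an assumption. The `(M, nspread)` arrays `idx` and `weights` are the kind of intermediates a fused kernel would avoid storing.

```python
import jax.numpy as jnp

def es_kernel(z, beta):
    # exp(beta * (sqrt(1 - z^2) - 1)) on |z| <= 1, zero outside.
    inside = jnp.abs(z) <= 1.0
    root = jnp.sqrt(jnp.where(inside, 1.0 - z ** 2, 0.0))
    return jnp.where(inside, jnp.exp(beta * (root - 1.0)), 0.0)

def spread_1d(x, c, n_grid, nspread=6):
    # x: points in [-pi, pi); c: strengths; returns a periodic grid of length n_grid.
    beta = 2.3 * nspread                        # assumed kernel shape parameter
    h = 2 * jnp.pi / n_grid                     # grid spacing
    pos = (x + jnp.pi) / h                      # fractional grid coordinate
    left = jnp.ceil(pos - nspread / 2).astype(jnp.int32)  # leftmost grid index touched
    idx = left[:, None] + jnp.arange(nspread)[None, :]    # (M, nspread) grid indices
    z = (idx - pos[:, None]) / (nspread / 2)    # offsets scaled to [-1, 1)
    weights = es_kernel(z, beta)                # (M, nspread) intermediate
    vals = weights * c[:, None]
    grid = jnp.zeros(n_grid, dtype=vals.dtype)
    # Scatter-add with periodic wrap; repeated indices accumulate.
    return grid.at[idx % n_grid].add(vals)

grid = spread_1d(jnp.array([0.05, -3.1]), jnp.array([1.0 + 0.0j, 0.5 - 0.5j]), n_grid=32)
```

In d dimensions the same pattern touches `nspread**d` grid points per source point, which is where the O(M × nspread^d) memory cost comes from.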
Installation
CPU only:
uv pip install nufftax
With CUDA 12 GPU support:
uv pip install "nufftax[cuda12]"
Development install (from source):
git clone https://github.com/GragasLab/nufftax.git
cd nufftax
uv pip install -e ".[dev]"
This installs test dependencies (pytest, ruff, finufft for comparison testing, pre-commit).
Development install with CUDA 12:
uv pip install -e ".[dev,cuda12]"
With docs dependencies:
uv pip install -e ".[docs]"
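After any of the installs above, one quick way to see which backend JAX has picked up is to query JAX itself. This uses only standard JAX calls and does not depend on nufftax.

```python
import jax

# Name of the default backend JAX selected, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# Devices visible to JAX; with a working CUDA install this should include a GPU device.
devices = jax.devices()

print(backend)
print(devices)
```

If only CPU devices are listed after a CUDA install, the usual suspect is a mismatch between the installed JAX CUDA plugin and the local driver.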
Quick Example
import jax
import jax.numpy as jnp
from nufftax import nufft1d1
# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])
# Compute Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)
# Differentiate through the transform
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)
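One detail worth knowing when using a gradient like `grad_c`: for a real-valued loss of a complex input, `jax.grad` returns the conjugate of the steepest-ascent direction, so a descent step uses `jnp.conj` of the result. The sketch below shows this with a direct-sum stand-in (`nudft1d1`, a hypothetical helper with an assumed sign and mode ordering) so that it runs without nufftax. The step size is arbitrary.

```python
import jax
import jax.numpy as jnp

def nudft1d1(x, c, n_modes):
    # Direct-sum stand-in for a type 1 transform (assumed sign and mode ordering).
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

def loss(c):
    # Total energy of the Fourier modes, a real scalar.
    return jnp.sum(jnp.abs(nudft1d1(x, c, 32)) ** 2)

g = jax.grad(loss)(c)             # complex array, same shape as c
c_next = c - 1e-3 * jnp.conj(g)   # descent step uses the conjugate of JAX's gradient
```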
Documentation
- Quickstart — get running in 5 minutes
- Concepts — understand the mathematics
- Tutorials — MRI reconstruction, spectral analysis, optimization
- API Reference — complete function reference
License
MIT. Algorithm based on FINUFFT by the Flatiron Institute.
Citation
If you use nufftax in your research, please cite:
@software{nufftax,
author = {Gragas and Oudoumanessah, Geoffroy and Iollo, Jacopo},
title = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
url = {https://github.com/GragasLab/nufftax},
year = {2026}
}
@article{finufft,
author = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
title = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
journal = {SIAM J. Sci. Comput.},
volume = {41},
number = {5},
pages = {C479--C504},
year = {2019}
}
