
Probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.


Probabilistic solvers in JAX


Probdiffeq implements adaptive probabilistic numerical solvers for ordinary differential equations (ODEs). It builds on JAX and thus inherits automatic differentiation, vectorisation, and GPU acceleration.

⚠️ Probdiffeq is an active research project. Expect rough edges and sudden API changes.

Features:

  • ⚡ Automatic calibration and step-size adaptation
  • ⚡ Stable implementations of filtering, smoothing, and other estimation strategies
  • ⚡ Custom information operators, dense output, posterior sampling, and prior distributions
  • ⚡ Efficient handling of high-dimensional problems through state-space model factorisations
  • ⚡ Parameter estimation
  • ⚡ Taylor-series estimation with and without automatic differentiation
  • ⚡ Seamless interoperability with Optax, BlackJAX, and other JAX-based libraries
  • ⚡ Numerous examples (basic and advanced) -- see the documentation

Quickstart: See here for a minimal example to get you started.

Contributing: Contributions are very welcome!

  • Browse open issues (look for “good first issue”).
  • Check the developer documentation.
  • Open an issue for feature requests or ideas.

Related projects:

The docs include guidance on migrating from these packages. Missing something? Open an issue or pull request!

You might also like:

  • diffeqzoo: reference implementations of differential equations in NumPy and JAX
  • probfindiff: probabilistic finite-difference methods in JAX

Installation

Install the latest release from PyPI:

pip install probdiffeq

This assumes JAX is already installed.

To install the library with JAX (using the CPU backend):

pip install probdiffeq[cpu]

Compatibility note: For GPU support, install JAX with CUDA following JAX installation instructions.

Versioning: Probdiffeq follows semantic versioning via 0.MINOR.PATCH:

  • PATCH: increased for bug fixes and new features
  • MINOR: increased for breaking changes

See semantic versioning. Notably, Probdiffeq's API is not guaranteed to be stable, but we do our best to follow the versioning scheme so that downstream projects remain reproducible.

Coming from other ODE solver libraries?

This guide helps you get started with Probdiffeq for solving ordinary differential equations (ODEs), especially if you are familiar with other probabilistic or non-probabilistic ODE solvers in Python or Julia.

Probdiffeq is a JAX library that focuses on state-space-model-based formulations of probabilistic IVP solvers. For what this means, have a look at this thesis.

Probabilistic ODE solvers in a nutshell: Unlike traditional solvers that return a single point estimate of the solution, probabilistic solvers return a posterior distribution. This built-in uncertainty quantification reflects the numerical error (and other modelling choices), and helps you make better decisions during the simulation and in downstream tasks, for example, during adaptive time-stepping, parameter estimation, or in physics-informed machine learning applications.
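To make the "posterior distribution" idea concrete, here is a minimal plain-Python sketch of the mechanism underlying a zeroth-order probabilistic solver: a Kalman filter with a once-integrated Wiener-process prior, conditioned step by step on the ODE residual y' − f(y). This is an illustration of the principle only, not Probdiffeq's API; all names and the fixed output scale q are made up for the example.

```python
import math

def f(y):
    """Vector field of the scalar test ODE dy/dt = -y."""
    return -y

def solve_prob(y0, t1, n_steps, q=1.0):
    """Filter a once-integrated Wiener-process prior through the ODE.

    The state is (y, y'). Each step predicts with the prior transition
    and then conditions on the zeroth-order-linearised residual
    y' - f(y) = 0. Returns posterior mean and std of y(t1).
    """
    h = t1 / n_steps
    m0, m1 = y0, f(y0)           # exact initialisation of (y, y')
    p00 = p01 = p11 = 0.0        # symmetric 2x2 posterior covariance
    for _ in range(n_steps):
        # Predict: m <- A m, P <- A P A^T + Q, with A = [[1, h], [0, 1]].
        mp0, mp1 = m0 + h * m1, m1
        pp00 = p00 + 2 * h * p01 + h * h * p11 + q * h**3 / 3
        pp01 = p01 + h * p11 + q * h * h / 2
        pp11 = p11 + q * h
        # Update: condition on y' - f(y) = 0 at the predicted mean
        # (zeroth-order linearisation, observation matrix H = [0, 1]).
        r = mp1 - f(mp0)
        k0 = pp01 / pp11
        m0, m1 = mp0 - k0 * r, mp1 - r
        p00 = pp00 - pp01 * pp01 / pp11
        p01, p11 = 0.0, 0.0
    return m0, math.sqrt(max(p00, 0.0))

mean, std = solve_prob(1.0, 1.0, 100)
print(mean, std)  # mean close to exp(-1); std is a nonzero error estimate
```

The point estimate tracks the true solution exp(−t), and the posterior standard deviation grows with the accumulated discretisation error; a real solver additionally calibrates q instead of fixing it.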

From traditional (non-probabilistic) ODE solvers

If you're coming from traditional ODE solvers like SciPy's integrate.solve_ivp, JAX's jax.experimental.odeint, or Diffrax, you'll notice some fundamental differences:

Key differences:

  • Solutions as distributions: Probdiffeq returns posterior distributions instead of point estimates. You automatically get uncertainty quantification, which you can use for sensitivity analysis, model selection, or downstream decision-making.
  • Fine-grained control: Probdiffeq lets you customise the probabilistic model (prior distribution, calibration method, linearisation order), giving you more control over solver behaviour. Because these modelling choices matter, users typically assemble their own solver from components; one-size-fits-all defaults are rare.
  • Explicit solver modes: Instead of a single solve() function, Probdiffeq offers specialised functions for targeting terminal values, checkpoints, or fixed grids. This is easier to maintain and also enables better performance through more targeted code optimisation and specialised default parameters (e.g., whether time steps should be clipped before checkpoints).

Mapping from Diffrax methods: If you're switching from Diffrax, here's how to achieve similar accuracy levels by adjusting Taylor coefficients and linearisation order:

| Diffrax method | Probdiffeq approach |
|--|--|
| Heun(), Midpoint() | Use 2 Taylor coefficients with zeroth-order linearisation |
| Tsit5(), Dopri5() | Use 5 Taylor coefficients with zeroth-order linearisation |
| Dopri8() | Use 8 Taylor coefficients with zeroth-order linearisation |
| Kvaerno3(), Kvaerno5() | Use 2 to 5 Taylor coefficients with first-order linearisation |
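The number of Taylor coefficients plays the role of a method order: the prior carries a truncated Taylor expansion of the solution, so more coefficients mean a smaller local truncation error, much as with higher-order Runge–Kutta methods. A quick plain-Python sanity check of that scaling on exp (an analogy only, not Probdiffeq code):

```python
import math

def taylor_exp(h, num_coeffs):
    """Truncated Taylor series of exp at 0, evaluated at h.

    Uses num_coeffs terms: 1 + h + ... + h^(num_coeffs-1)/(num_coeffs-1)!.
    """
    return sum(h**i / math.factorial(i) for i in range(num_coeffs))

h = 0.1
err2 = abs(math.exp(h) - taylor_exp(h, 2))  # local error ~ h^2 / 2
err5 = abs(math.exp(h) - taylor_exp(h, 5))  # local error ~ h^5 / 120
print(err2, err5)
```

With two coefficients the local error over a step of 0.1 is around 5e-3; with five it drops below 1e-7, mirroring the gap between Heun()-class and Dopri5()-class methods in the table above.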

Tidbit: Probabilistic solvers based on the once-integrated Wiener/OU processes are closely related to (different versions of) the trapezoidal rule (Schober et al., 2019; Bosch et al., 2023). Higher-order methods connect to more general linear multistep methods (Schober et al., 2019).

  • Schober, M., Särkkä, S., & Hennig, P. (2019). A probabilistic model for the numerical solution of initial value problems. Statistics and Computing, 29(1), 99–122.

  • Bosch, N., Hennig, P., & Tronarp, F. (2023). Probabilistic exponential integrators. Advances in Neural Information Processing Systems, 36, 40450–40467.

Note: Probdiffeq is not a drop-in replacement for these solvers; the probabilistic approach is fundamentally different. However, you can match performance and accuracy levels by tuning the solver configuration (see the examples in the documentation).

From other probabilistic ODE solvers

If you're familiar with other probabilistic solver libraries, here are the comparisons:

From ProbNum (Python, NumPy): ProbNum is a general-purpose probabilistic numerics library, while Probdiffeq specialises in ODE solving with pure JAX. Advantages of Probdiffeq:

  • Greater efficiency due to JAX's JIT compilation and autodiff
  • More mature ODE algorithms (state-space factorisations, improved adaptive time-stepping)
  • Richer outputs (sampling, marginal likelihoods, marginal-likelihood losses, etc.)

From ProbNumDiffEq.jl (Julia): ProbNumDiffEq.jl is an independently developed Julia counterpart of Probdiffeq, with similar features but a slightly different API. Here's how to translate:

| ProbNumDiffEq.jl concept | Probdiffeq concept |
|--|--|
| EK0 / EK1 | constraint_ode_ts0() / constraint_ode_ts1() |
| DynamicDiffusion / FixedDiffusion | solver_dynamic() / solver_mle() |
| IWP(diffusion=x^2) | prior_wiener_integrated(output_scale=x) |
| smooth=true/false | strategy_filter() / strategy_smoother_fixedpoint() / strategy_smoother_fixedinterval() |

Both libraries are actively evolving; consult their latest API documentation if you're unsure about equivalences.

Choose the right solver

Good solvers are problem-dependent. However, some guidelines exist:

Problem characteristics

Choosing the right approach matters because problem size and behaviour directly impact solver efficiency, stability, and the accuracy of the uncertainty quantification.

Dimensionality: For low-dimensional problems, use dense covariances, which track full correlations between state variables and offer the best stability and uncertainty quantification. For larger problems, use blockdiagonal or isotropic state-space models, which are more efficient by tracking only partial uncertainty correlations. However, their uncertainty quantification is typically worse. The general trade-off is between accuracy and speed: dense models scale cubically in the dimension but provide the best accuracy; the other two models scale linearly in the dimension.
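The cubic-versus-linear trade-off can be made concrete with a back-of-the-envelope flop count: a dense covariance couples all d ODE dimensions times the k Taylor coefficients, so one factorisation step costs on the order of (d·k)³, whereas a block-diagonal (or isotropic) model works on d independent k×k blocks. A rough plain-Python cost model (illustrative only, not Probdiffeq internals):

```python
def cost_dense(d, num_coeffs):
    """Cubic cost of factorising one dense (d*k) x (d*k) covariance."""
    return (d * num_coeffs) ** 3

def cost_blockdiag(d, num_coeffs):
    """d independent k x k blocks: linear in the ODE dimension d."""
    return d * num_coeffs**3

for d in (10, 100, 1000):
    print(d, cost_dense(d, 3), cost_blockdiag(d, 3))
```

Growing the dimension by 10x multiplies the dense cost by 1000x but the block-diagonal cost by only 10x, which is why the factorised models are the practical choice for large systems despite their coarser uncertainty quantification.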

Stiffness: Stiff problems have rapid transients or widely separated timescales. For these, use dense state-space models with first-order linearisation (see also the prior recommendations below), and avoid zeroth-order methods and isotropic state-space models. Block-diagonal state-space models with first-order linearisation may suffice for moderately stiff cases, but expect every configuration other than first-order linearisation with a dense state-space model to be less stable than, for example, an implicit Runge–Kutta method.
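The stability gap between zeroth- and first-order linearisation parallels the classic explicit-versus-implicit split: zeroth-order methods ignore the Jacobian and behave like explicit integrators, while first-order methods use it and behave like semi-implicit ones. A plain-Python illustration on the stiff test equation y' = −50y with a step size far outside the explicit stability region (an analogy only, not Probdiffeq code):

```python
lam, h, y0, n = -50.0, 0.1, 1.0, 40

# Explicit step (zeroth-order-like): y <- (1 + h*lam) * y.
# The amplification factor |1 + h*lam| = 4 > 1, so the iteration diverges.
y_explicit = y0
for _ in range(n):
    y_explicit = (1 + h * lam) * y_explicit

# Semi-implicit step (first-order-like): y <- y / (1 - h*lam).
# The amplification factor |1 / (1 - h*lam)| = 1/6 < 1, so it decays.
y_implicit = y0
for _ in range(n):
    y_implicit = y_implicit / (1 - h * lam)

print(abs(y_explicit), abs(y_implicit))  # explicit blows up, implicit decays
```

The true solution decays to essentially zero; only the Jacobian-aware update reproduces that behaviour at this step size, which is the same reason first-order linearisation is recommended for stiff problems above.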
