Torchcde
Differentiable controlled differential equation solvers for PyTorch with GPU support and memory-efficient adjoint backpropagation.
Update: for any new projects, I would now recommend using Diffrax instead. It is much faster, and production-quality. torchcde was its prototype as a research project!
This library provides differentiable GPU-capable solvers for controlled differential equations (CDEs). Backpropagation through the solver or via the adjoint method is supported; the latter allows for improved memory efficiency.
In particular this allows for building Neural Controlled Differential Equation models, which are state-of-the-art models for (arbitrarily irregular!) time series. Neural CDEs can be thought of as a "continuous time RNN".
<p align="center"> <img align="middle" src="./imgs/main.png" width="666" /> </p>
Installation
pip install torchcde
Requires PyTorch >=1.7.
Example
```python
import torch
import torchcde

# Create some data
batch, length, input_channels = 1, 10, 2
hidden_channels = 3
t = torch.linspace(0, 1, length)
t_ = t.unsqueeze(0).unsqueeze(-1).expand(batch, length, 1)
x_ = torch.rand(batch, length, input_channels - 1)
x = torch.cat([t_, x_], dim=2)  # include time as a channel

# Interpolate it
coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
X = torchcde.CubicSpline(coeffs)

# Create the Neural CDE system
class F(torch.nn.Module):
    def __init__(self):
        super(F, self).__init__()
        self.linear = torch.nn.Linear(hidden_channels,
                                      hidden_channels * input_channels)

    def forward(self, t, z):
        return self.linear(z).view(batch, hidden_channels, input_channels)

func = F()
z0 = torch.rand(batch, hidden_channels)

# Integrate it
torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
```
See time_series_classification.py, which demonstrates how to use the library to train a Neural CDE model to predict the chirality of a spiral.
Also see irregular_data.py for demonstrations of how to handle variable-length inputs, irregular sampling, or missing data, all of which can be handled easily without changing the model.
Citation
If you found this library useful, please consider citing
@article{kidger2020neuralcde,
title={{N}eural {C}ontrolled {D}ifferential {E}quations for {I}rregular {T}ime {S}eries},
author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},
journal={Advances in Neural Information Processing Systems},
year={2020}
}
Documentation
The library consists of two main components: (1) integrators for solving controlled differential equations, and (2) ways of constructing controls from data.
Integrators
The library provides the cdeint function, which solves the system of controlled differential equations:
$$\mathrm{d}z(t) = f(t, z(t))\,\mathrm{d}X(t), \qquad z(t_0) = z_0$$
The goal is to find the response z driven by the control X. This can be re-written as the following differential equation:
$$\frac{\mathrm{d}z}{\mathrm{d}t}(t) = f(t, z(t))\,\frac{\mathrm{d}X}{\mathrm{d}t}(t), \qquad z(t_0) = z_0$$
where the right hand side describes a matrix-vector product between f(t, z) and dX/dt(t).
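As a small worked example (chosen here for illustration, not taken from the library), take one hidden channel and one input channel with $f(t, z) = z$, so that the CDE reads $\mathrm{d}z = z\,\mathrm{d}X$. Its solution can be verified by differentiating directly:

$$
z(t) = z_0\, e^{X(t) - X(t_0)} \quad\Longrightarrow\quad \frac{\mathrm{d}z}{\mathrm{d}t}(t) = z_0\, e^{X(t) - X(t_0)}\,\frac{\mathrm{d}X}{\mathrm{d}t}(t) = z(t)\,\frac{\mathrm{d}X}{\mathrm{d}t}(t), \qquad z(t_0) = z_0 .
$$

So in this scalar case the response depends on the control only through the increment $X(t) - X(t_0)$, and not on how quickly $X$ travelled between the two values.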
This is solved by
cdeint(X, func, z0, t, adjoint, backend, **kwargs)
where letting ... denote an arbitrary number of batch dimensions:
- X is a torch.nn.Module with method derivative, such that X.derivative(t) is a Tensor of shape (..., input_channels),
- func is a torch.nn.Module, such that func(t, z) returns a Tensor of shape (..., hidden_channels, input_channels),
- z0 is a Tensor of shape (..., hidden_channels),
- t is a one-dimensional Tensor of times to output z at,
- adjoint is a boolean (defaulting to True),
- backend is a string (defaulting to "torchdiffeq").
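A minimal sketch of how these shapes fit together, using only PyTorch. Func and vector_field are hypothetical names, and a random tensor stands in for X.derivative(t); the point is just the batched matrix-vector product between f(t, z) and dX/dt(t) from the rewritten equation, for an arbitrary number of batch dimensions. Using z.shape[:-1] in the view (instead of a fixed batch size, as in the example above) keeps func valid for any batch shape.

```python
import torch

hidden_channels, input_channels = 5, 2


class Func(torch.nn.Module):
    """Hypothetical f(t, z): (..., hidden_channels) -> (..., hidden_channels, input_channels)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_channels, hidden_channels * input_channels)

    def forward(self, t, z):
        # z.shape[:-1] keeps any number of leading batch dimensions.
        return self.linear(z).view(*z.shape[:-1], hidden_channels, input_channels)


def vector_field(func, t, z, dXdt):
    """dz/dt = f(t, z) dX/dt(t), as a batched matrix-vector product."""
    # (..., hidden, input) @ (..., input, 1) -> (..., hidden, 1) -> (..., hidden)
    return (func(t, z) @ dXdt.unsqueeze(-1)).squeeze(-1)


func = Func()
z = torch.rand(4, 3, hidden_channels)    # here "..." is (4, 3)
dXdt = torch.rand(4, 3, input_channels)  # stands in for X.derivative(t)
dzdt = vector_field(func, torch.tensor(0.0), z, dXdt)
print(dzdt.shape)  # torch.Size([4, 3, 5])
```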
Adjoint backpropagation (which is slower but more memory efficient) can be toggled with adjoint=True/False.
The backend should be either "torchdiffeq" or "torchsde", corresponding to which underlying library to use for the solvers. If using torchsde then the stochastic term is zero -- so the CDE is still reduced to an ODE. This is useful if one library supports a feature that the other doesn't. (For example torchsde supports a reversible solver, the reversible Heun method; at time of writing torchdiffeq does not support any reversible solvers.)
Any additional **kwargs are passed on to torchdiffeq.odeint[_adjoint] or torchsde.sdeint[_adjoint], for example to specify the solver.
Constructing controls
A very common scenario is to construct the continuous control X from discrete data (which may be irregularly sampled with missing values). To support this, we provide three main interpolation schemes:
- Hermite cubic splines with backwards differences
- Linear interpolation
- Rectilinear interpolation
Note that if for some reason you already have a continuous control X then you won't need an interpolation scheme at all!
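For instance, a control given by a formula can be written directly as a module. The sketch below is hypothetical (the class name SineControl and the choice X(t) = (t, sin t) are ours): it provides the derivative method that cdeint requires of X, plus interval and evaluate to mirror the interface of the library's splines.

```python
import math

import torch


class SineControl(torch.nn.Module):
    """Hypothetical analytic control X(t) = (t, sin t) on [t0, t1]; no interpolation needed."""

    def __init__(self, t0=0.0, t1=2 * math.pi):
        super().__init__()
        # A buffer moves with the module under .to(device); usable as t=X.interval.
        self.register_buffer("interval", torch.tensor([t0, t1]))

    def evaluate(self, t):
        t = torch.as_tensor(t, dtype=self.interval.dtype)
        return torch.stack([t, torch.sin(t)], dim=-1)  # shape (*t.shape, 2)

    def derivative(self, t):
        t = torch.as_tensor(t, dtype=self.interval.dtype)
        return torch.stack([torch.ones_like(t), torch.cos(t)], dim=-1)


X = SineControl()
print(X.derivative(0.0))  # tensor([1., 1.])
```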
Hermite cubic splines are usually the best choice, if possible. Linear and rectilinear interpolations are particularly useful in causal settings -- when at inference time the data is arriving over time. We go into further details in the Further Documentation below.
Just demonstrating Hermite cubic splines for now:
```python
import torchcde

coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
# coeffs is a torch.Tensor you can save, load,
# pass through Datasets and DataLoaders etc.
X = torchcde.CubicSpline(coeffs)
```
where:
x is a Tensor of shape (..., length, input_channels), where ... is some number of batch dimensions. Missing data should be represented as a NaN.
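A minimal sketch of preparing such an x, assuming a single series with two features observed at irregular times (the numbers are made up for illustration). Following the example above, time is included as an explicit channel, so the actual observation times are visible to the model; missing observations are simply NaN.

```python
import torch

# Hypothetical irregularly sampled series: 5 observation times, 2 feature channels.
times = torch.tensor([0.0, 0.3, 1.1, 1.5, 2.8])
values = torch.tensor([[1.0, 0.5],
                       [float("nan"), 0.7],  # first feature missing at t = 0.3
                       [1.4, float("nan")],  # second feature missing at t = 1.1
                       [1.2, 0.9],
                       [0.8, 1.0]])

# Include time as a channel, then add a batch dimension: (1, length, input_channels).
x = torch.cat([times.unsqueeze(-1), values], dim=-1).unsqueeze(0)
print(x.shape)  # torch.Size([1, 5, 3])
print(torch.isnan(x).sum().item())  # 2
# x can then be passed to torchcde.hermite_cubic_coefficients_with_backward_differences(x).
```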
The interface provided by CubicSpline is:
- .interval, which gives the time interval the spline is defined over. (Often used as the t argument in cdeint.) This is determined implicitly from the length of the data, and so does not in general correspond to the times at which your data was actually observed. (See the Further documentation note on reparameterisation invariance.)
- .grid_points, which is all of the knots in the spline, so that for example X.evaluate(X.grid_points) will recover the original data.
- .evaluate(t), where t is an any-dimensional Tensor, to evaluate the spline at any (collection of) time(s).
- .derivative(t), where t is an any-dimensional Tensor, to evaluate the derivative of the spline at any (collection of) time(s).
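The same interface can be mimicked by a much simpler interpolant. The sketch below is a hypothetical piecewise-linear stand-in in plain PyTorch (LinearInterpolationSketch is our name; it is not torchcde's linear interpolation), and as an assumption for this sketch only it places knots at 0, 1, ..., length - 1, so its interval depends on the data's length rather than on any observation times. It shows X.evaluate(X.grid_points) recovering the data, and that on each piece [k, k + 1] the derivative is just x[k + 1] - x[k].

```python
import torch


class LinearInterpolationSketch(torch.nn.Module):
    """Hypothetical piecewise-linear control through x of shape (..., length, channels).

    Assumption for this sketch: knots sit at 0, 1, ..., length - 1.
    """

    def __init__(self, x):
        super().__init__()
        length = x.size(-2)
        self.register_buffer("x", x)
        self.register_buffer("grid_points", torch.arange(length, dtype=x.dtype))
        self.register_buffer("interval", torch.tensor([0.0, length - 1.0], dtype=x.dtype))

    def _index(self, t):
        # Index k of the piece [k, k + 1] containing t, clamped to the valid pieces.
        return t.floor().long().clamp(0, self.x.size(-2) - 2)

    def evaluate(self, t):
        t = torch.as_tensor(t, dtype=self.x.dtype)
        k = self._index(t)
        frac = (t - k).unsqueeze(-1)
        return self.x[..., k, :] + frac * (self.x[..., k + 1, :] - self.x[..., k, :])

    def derivative(self, t):
        t = torch.as_tensor(t, dtype=self.x.dtype)
        k = self._index(t)
        return self.x[..., k + 1, :] - self.x[..., k, :]


x = torch.rand(2, 5, 3, dtype=torch.float64)  # (batch, length, channels)
X = LinearInterpolationSketch(x)
print(X.interval)  # the interval [0, 4]
print(torch.allclose(X.evaluate(X.grid_points), x))  # True
```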
Usually hermite_cubic_coefficients_with_backward_differences should be computed as a preprocessing step, whilst CubicSpline should be called inside the forward pass of your model. See time_series_classification.py for a worked example.
Then call:
cdeint(X=X, func=..., z0=..., t=X.interval)
Further documentation
The earlier documentation section should give everything you need to get up and running.
Here we discuss a few more advanced bits of functionality:
- The reparameterisation invariance property of CDEs.
- Other interpolation methods, and the differences between them.
- The use of fixed solvers. (They just work.)
- Stacking CDEs (i.e. controlling one by the output of another).
- Computing logsignatures for the log-ODE method.
Reparameterisation invariance
This is a classical fact about CDEs.
Let $\psi \colon [a, b] \to [c, d]$ be differentiable and strictly increasing, with $\psi(a) = c$ and $\psi(b) = d$. Let $T \in [c, d]$, let $\widetilde{z} = z \circ \psi$, let $\widetilde{X} = X \circ \psi$, and let $\mathcal{T} = \psi^{-1}(T)$. Then substituting $t = \psi(\tau)$ into a CDE (and just using the standard change of variables formula):

$$
\begin{aligned}
\widetilde{z}(\mathcal{T}) &= z(T)\\
&= z(c) + \int_c^T f(z(t))\,\mathrm{d}X(t)\\
&= z(c) + \int_c^T f(z(t))\,\frac{\mathrm{d}X}{\mathrm{d}t}(t)\,\mathrm{d}t\\
&= z(\psi(a)) + \int_a^{\psi^{-1}(T)} f(z(\psi(\tau)))\,\frac{\mathrm{d}X}{\mathrm{d}t}(\psi(\tau))\,\frac{\mathrm{d}\psi}{\mathrm{d}\tau}(\tau)\,\mathrm{d}\tau\\
&= (z\circ\psi)(a) + \int_a^{\psi^{-1}(T)} f((z\circ\psi)(\tau))\,\frac{\mathrm{d}(X\circ\psi)}{\mathrm{d}\tau}(\tau)\,\mathrm{d}\tau\\
&= (z\circ\psi)(a) + \int_a^{\psi^{-1}(T)} f((z\circ\psi)(\tau))\,\mathrm{d}(X\circ\psi)(\tau)\\
&= \widetilde{z}(a) + \int_a^{\mathcal{T}} f(\widetilde{z}(\tau))\,\mathrm{d}\widetilde{X}(\tau)
\end{aligned}
$$

We see that $\widetilde{z}$ also satisfies the neural CDE equation, just with $\widetilde{X}$ as input instead of $X$.
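The invariance can also be checked numerically. The sketch below uses plain PyTorch with hypothetical choices throughout: a made-up vector field $f$, the control $X(t) = (t, \sin 3t)$ and the reparameterisation $\psi(\tau) = \tau^2$ on $[0, 1]$. It solves $\mathrm{d}z/\mathrm{d}t = f(z)\,\mathrm{d}X/\mathrm{d}t$ driven by $X$ up to $T = 1$, and again driven by $\widetilde{X} = X \circ \psi$ up to $\psi^{-1}(1) = 1$, using a hand-written explicit midpoint method rather than any torchcde solver. The two final states should agree up to discretisation error.

```python
import math

import torch

# Hypothetical vector field f(z): R^2 -> (2 x 2) matrices, with fixed weights for reproducibility.
W = torch.tensor([[0.5, -0.2], [0.1, 0.4], [-0.3, 0.2], [0.2, 0.1]], dtype=torch.float64)
b = torch.tensor([0.1, -0.1, 0.2, 0.0], dtype=torch.float64)


def f(z):
    return torch.tanh(W @ z + b).view(2, 2)


def dX_dt(t):
    # Control X(t) = (t, sin 3t), so dX/dt(t) = (1, 3 cos 3t).
    return torch.tensor([1.0, 3.0 * math.cos(3.0 * t)], dtype=torch.float64)


def psi(tau):
    # Strictly increasing map [0, 1] -> [0, 1] with psi(0) = 0 and psi(1) = 1.
    return tau ** 2


def dXtilde_dtau(tau):
    # Chain rule: d(X o psi)/dtau = dX/dt(psi(tau)) * dpsi/dtau(tau), where dpsi/dtau = 2 tau.
    return dX_dt(psi(tau)) * (2.0 * tau)


def solve(control_derivative, z0, t0, t1, steps=4000):
    """Explicit midpoint method for dz/dt = f(z) dX/dt on [t0, t1]."""
    h = (t1 - t0) / steps
    z = z0.clone()
    for k in range(steps):
        t = t0 + k * h
        k1 = f(z) @ control_derivative(t)
        z_mid = z + 0.5 * h * k1
        z = z + h * (f(z_mid) @ control_derivative(t + 0.5 * h))
    return z


z0 = torch.tensor([0.5, -0.3], dtype=torch.float64)
z_T = solve(dX_dt, z0, 0.0, 1.0)             # z(T) with T = 1, driven by X
z_tilde = solve(dXtilde_dtau, z0, 0.0, 1.0)  # z~ at psi^{-1}(1) = 1, driven by X o psi
print(z_T, z_tilde)  # equal up to discretisation error
```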
