Thoad
Lightweight, performant Python 3.12+ automatic differentiation system that leverages PyTorch's computational graph to compute arbitrary-order partial derivatives.
> [!NOTE]
> This package is still in an experimental stage. It may exhibit unstable behavior or produce unexpected results, and is subject to possible future structural modifications.
About
thoad is a lightweight reverse-mode automatic differentiation engine, written entirely in Python, that works over PyTorch's computational graph to compute high-order partial derivatives. Unlike PyTorch's native autograd, which natively computes only first-order partial derivatives, thoad performantly propagates arbitrary-order derivatives throughout the graph, enabling more advanced derivative-based workflows.
Core Features
- Python 3.12+: thoad is implemented in Python 3.12 and is compatible with any higher Python version.
- Built on PyTorch: thoad uses PyTorch as its only dependency and is compatible with more than 70 PyTorch operator backward functions.
- Arbitrary-Order Differentiation: thoad can compute arbitrary-order partial derivatives, including cross partials that involve multiple leaf tensors.
- Adoption of the PyTorch Computational Graph: thoad integrates with PyTorch tensors by adopting their internally traced subgraphs.
- High Performance: thoad's Hessian computation time scales asymptotically better than torch.autograd's, staying closer to jax.jet's performance.
- Non-Sequential Graph Support: Unlike jax.jet, thoad supports differentiation over arbitrary graph topologies, not only sequential chains.
- Non-Scalar Differentiation: Unlike torch.Tensor.backward, thoad allows launching differentiation from non-scalar tensors.
- Support for Backward Hooks: thoad allows registering backward hooks for dynamic tuning of propagated high-order derivatives.
- Diagonal Optimization: thoad detects and avoids duplication of cross diagonal dimensions during back-propagation.
- Symmetry Optimization: Leveraging Schwarz’s theorem, thoad removes redundant derivative block computations.
Installation
thoad can be installed either from PyPI or directly from the GitHub repository.
- From PyPI:

  pip install thoad

- From GitHub (install directly with pip, fetching the latest from the `main` branch):

  pip install git+https://github.com/mntsx/thoad.git

  Or, if you prefer to clone and install in editable mode:

  git clone https://github.com/mntsx/thoad.git
  cd thoad
  pip install -e .
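To confirm the installation, here is a minimal sanity-check sketch. It assumes only that elementwise multiplication and summation are among the supported backward functions, and simply verifies that the package imports and that a basic backward call runs; `thoad.backward` and the `.hgrad` attribute are described in detail below.

```python
import torch
import thoad

# tiny differentiable graph with a single leaf tensor
x = torch.rand(size=(3, 3), requires_grad=True)
y = (x * x).sum()

# run a first-order backward pass through thoad; partials land in x.hgrad
thoad.backward(tensor=y, order=1)
print(x.hgrad[0].shape)
```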
Using the Package
thoad exposes two primary interfaces for computing high-order derivatives:
- `thoad.backward`: a function-based interface that closely resembles `torch.Tensor.backward`. It provides a quick way to compute high-order partial derivatives without needing to manage an explicit controller object, but it offers only the core functionality (derivative computation and storage).
- `thoad.Controller`: a class-based interface that wraps the output tensor's subgraph in a controller object. In addition to performing the same high-order backward pass, it gives access to advanced features such as fetching specific cross partials, inspecting batch-dimension optimizations, overriding backward-function implementations, retaining intermediate partials, and registering custom hooks.
thoad.backward
The thoad.backward function computes high-order partial derivatives of a given output tensor and stores them in each leaf tensor’s .hgrad attribute.
Arguments:
- `tensor`: A PyTorch tensor from which to start the backward pass. This tensor must have `requires_grad=True` and be part of a differentiable graph.
- `order`: A positive integer specifying the maximum order of derivatives to compute.
- `gradient`: A tensor with the same shape as `tensor` used to seed the vector-Jacobian product (i.e., a custom upstream gradient). If omitted, the primal vector space is not reduced.
- `crossings`: A boolean flag (default `False`). If set to `True`, cross partial derivatives (i.e., derivatives that involve more than one distinct leaf tensor) will be computed.
- `groups`: An iterable of disjoint groups of leaf tensors. When `crossings=False`, only those cross partials whose participating leaf tensors all lie within a single group will be computed. If `crossings=True` and `groups` is provided, a `ValueError` is raised (they are mutually exclusive).
- `keep_batch`: A boolean flag (default `False`) that controls how output dimensions are organized in the computed derivatives.
  - When `keep_batch=False`:
    The derivative keeps a single flattened "primal" axis first, followed by one group of axes per derivative order, in differentiation order. Concretely:
    - A single "primal" axis containing every element of the graph output tensor (flattened into one dimension).
    - A group of axes per derivative order, each matching the shape of the respective differentiation-target tensor.

    For an N-th order derivative of a leaf tensor with `input_numel` elements and an output with `output_numel` elements, the derivative shape is:
    - Axis 1: indexes all `output_numel` outputs.
    - Axes 2…(sum(Nj)+1): each indexes all `input_numel` inputs.
  - When `keep_batch=True`:
    The derivative shape follows the same ordering as in the previous case, but includes a series of "independent dimensions" immediately after the "primal" axis:
    - Axis 1 flattens all elements of the output tensor (size = `output_numel`).
    - Axes 2…(k+i) correspond to dimensions shared by multiple input tensors and treated independently throughout the graph. These are dimensions that are only operated on element-wise (e.g., batch dimensions).
    - Axes (k+i+1)…(k+i+sum(Nj)+1): each flattens all `input_numel` elements of the leaf tensor, one axis per derivative order.

  (A sketch contrasting the two layouts follows the example in the next section.)
- `keep_schwarz`: A boolean flag (default `False`). If `True`, symmetric (Schwarz) permutations are retained explicitly instead of being canonicalized/reduced, which is useful for debugging or inspecting non-reduced layouts.
Returns:
- An instance of `thoad.Controller` wrapping the same tensor and graph.
Executing Autodifferentiation via thoad.backward
import torch
import thoad
from torch.nn import functional as F
### Normal PyTorch workflow
X = torch.rand(size=(10,15), requires_grad=True)
Y = torch.rand(size=(15,20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)
### Call thoad backward
order = 2
thoad.backward(tensor=Z, order=order)
### Checks
# check derivative shapes through the attribute aggregated to torch.Tensor: hgrad
for o in range(1, 1 + order):
    assert X.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(X.shape)))
    assert Y.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(Y.shape)))
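The optional arguments described above can be combined in the same call. Below is a hedged sketch that rebuilds the same toy graph, seeds the backward pass with an explicit upstream `gradient`, enables `crossings`, and requests the batched layout via `keep_batch=True`. Since the exact independent dimensions detected (and the effect of the gradient seed on the primal axis) depend on the graph, the resulting shapes are printed rather than asserted.

```python
import torch
import thoad
from torch.nn import functional as F

### Rebuild the toy graph
X = torch.rand(size=(10, 15), requires_grad=True)
Y = torch.rand(size=(15, 20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)

### Seeded, cross-partial, batched backward pass
gradient = torch.ones_like(Z)  # custom upstream gradient, same shape as Z
controller = thoad.backward(
    tensor=Z,
    order=2,
    gradient=gradient,
    crossings=True,
    keep_batch=True,
)

### thoad.backward returns a thoad.Controller wrapping the same tensor and graph
print(type(controller))

### Inspect the stored derivative layouts (axes depend on detected independent dims)
for o in range(1, 3):
    print("X order", o, "->", tuple(X.hgrad[o - 1].shape))
    print("Y order", o, "->", tuple(Y.hgrad[o - 1].shape))
```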
thoad.Controller
The Controller class wraps a tensor’s backward subgraph in a controller object, performing the same core high-order backward pass as thoad.backward while exposing advanced customization, inspection, and override capabilities.
Instantiation
Use the constructor to create a controller for any tensor requiring gradients:
controller = thoad.Controller(tensor=GO) # takes graph output tensor
- `tensor`: A PyTorch `Tensor` with `requires_grad=True` and a non-None `grad_fn`.
Properties
- `.tensor → Tensor`: The output tensor underlying this controller. Setter: replaces the tensor (after validation), rebuilds the internal computation graph, and invalidates any previously computed derivatives.
- `.compatible → bool`: Indicates whether every backward function in the tensor's subgraph has a supported high-order implementation. If `False`, some derivatives may fall back or be unavailable.
- `.index → Dict[Type[torch.autograd.Function], Type[ExtendedAutogradFunction]]`: A mapping from base PyTorch `autograd.Function` classes to thoad's `ExtendedAutogradFunction` implementations. Setter: validates and injects your custom high-order extensions.
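A small sketch inspecting these properties, assuming the `controller` created in the instantiation snippet above:

```python
### Inspect controller properties
print(controller.tensor.shape)   # the wrapped graph output tensor
print(controller.compatible)     # True if every backward fn has a high-order implementation

# mapping from torch.autograd.Function classes to thoad's extended implementations
for base_fn, extended_fn in controller.index.items():
    print(base_fn.__name__, "->", extended_fn.__name__)
```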
Core Methods
.backward(order, gradient=None, crossings=False, groups=None, keep_batch=False, keep_schwarz=False) → None
Performs the high-order backward pass up to the specified derivative order, storing all computed partials in each leaf tensor’s .hgrad attribute.
- `order` (int > 0): maximum derivative order.
- `gradient` (Optional[Tensor]): custom upstream gradient with the same shape as `controller.tensor`.
- `crossings` (bool, default `False`): if `True`, cross partial derivatives across different leaf tensors will be computed.
- `groups` (Optional[Iterable[Iterable[Tensor]]], default `None`): when `crossings=False`, restricts cross partials to those whose leaf tensors all lie within a single group. If `crossings=True` and `groups` is provided, a `ValueError` is raised.
- `keep_batch` (bool, default `False`): controls whether independent output axes are kept separate (batched) or merged (flattened) in stored/retrieved derivatives.
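Putting these pieces together, here is a minimal end-to-end sketch of the Controller interface, reusing the toy attention graph from the earlier examples. Cross partials are enabled only to illustrate the call; retrieving specific cross partials through the controller's fetching features is beyond this snippet.

```python
import torch
import thoad
from torch.nn import functional as F

### Build a small graph (same toy attention example as above)
X = torch.rand(size=(10, 15), requires_grad=True)
Y = torch.rand(size=(15, 20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)

### Wrap the graph output and run the high-order backward pass
controller = thoad.Controller(tensor=Z)
controller.backward(order=2, crossings=True)

### Derivatives land in each leaf tensor's .hgrad attribute, one entry per order
for o in range(1, 3):
    print("X order", o, "->", tuple(X.hgrad[o - 1].shape))
    print("Y order", o, "->", tuple(Y.hgrad[o - 1].shape))
```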