Torchcurves
Parametric differentiable curves with PyTorch for continuous embeddings, shape-restricted models, or KANs
A PyTorch module for vectorized and differentiable parametric curves with learnable coefficients, such as a B-Spline curve with learnable control points, for KANs, continuous embeddings, and shape constraints.
<div align="center"> <p><b>Use cases</b></p> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_dark.png"> <source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_light.png"> <img width="100%" alt="Torchcurves Usecases" src="https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_light.png"> </picture> </div>

It turns out all of the above use cases have one thing in common: they can all be expressed using learnable parametric curves, and that is the tool this library provides.
Learn
A simple "hello world" example - evaluate three two-dimensional b-spline curves at four points:
```python
import torch

import torchcurves as tc

u = torch.rand(4, 3)  # (B, C)
curve = tc.BSplineCurve(
    num_curves=3,  # C
    dim=2,         # D
)
y = curve(u)  # (B, C, D)
print(u.shape, "->", y.shape)  # torch.Size([4, 3]) -> torch.Size([4, 3, 2])
```
For more information:
- Documentation site.
- Example notebooks for you to try out.
Features
- Differentiable: Custom autograd function ensures gradients flow properly through the curve evaluation.
- Vectorized: Vectorized operations for efficient batch and multi-curve evaluation.
- Efficient numerics: the Clenshaw recursion for polynomial bases, the Cox-de Boor recursion for B-splines.
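As an aside, the Clenshaw recursion mentioned above evaluates a polynomial series without explicitly materializing the basis functions. Below is a minimal, self-contained sketch for the Legendre basis; it illustrates the technique only and is not the library's actual implementation:

```python
def legendre_clenshaw(x, coeffs):
    """Evaluate S(x) = sum_k coeffs[k] * P_k(x) via Clenshaw's recursion.

    Uses the Legendre recurrence P_{k+1}(x) = ((2k+1) x P_k(x) - k P_{k-1}(x)) / (k+1),
    i.e. alpha_k(x) = (2k+1) x / (k+1) and beta_k = -k / (k+1).
    Works on plain floats, and broadcasts over torch tensors just as well.
    """
    b_next = b_next2 = 0.0  # b_{k+1} and b_{k+2}
    for k in reversed(range(len(coeffs))):
        alpha = (2 * k + 1) / (k + 1) * x
        beta_next = -(k + 1) / (k + 2)  # beta_{k+1}
        b_next, b_next2 = coeffs[k] + alpha * b_next + beta_next * b_next2, b_next
    # For Legendre, S(x) = b_0 because P_0 = 1 and P_1(x) = alpha_0(x) * P_0
    return b_next


# sanity check against P_2(x) = (3 x^2 - 1) / 2
assert abs(legendre_clenshaw(0.5, [0.0, 0.0, 1.0]) - (-0.125)) < 1e-12
```

Clenshaw evaluates the series in a single backward sweep, which is both cheaper and numerically better behaved than summing explicitly computed basis values.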
Installation
With pip:

```shell
pip install torchcurves
```

With uv:

```shell
uv add torchcurves
```
Use cases
There are examples in the doc/examples directory showing how to build models using this library. Below are a few short code snippets that give a feel for what it offers.
Use case 1 - continuous embeddings
```python
import torch
from torch import nn

import torchcurves as tc


class Net(nn.Module):
    def __init__(self, num_categorical, num_numerical, dim, num_knots=10):
        super().__init__()
        self.cat_emb = nn.Embedding(num_categorical, dim)
        self.num_emb = tc.BSplineCurve(num_numerical, dim, knots_config=num_knots)
        self.embedding_based_model = MySuperDuperModel()  # <-- put your encoder model here

    def forward(self, x_categorical, x_numerical):
        embeddings = torch.cat([
            self.cat_emb(x_categorical),
            self.num_emb(x_numerical)
        ], dim=-2)
        return self.embedding_based_model(embeddings)
```
Use case 2 - monotone functions
Working in online advertising and want to model the probability of winning an ad auction given the bid? We know higher bids must result in a higher win probability, so we need a monotone function. It turns out B-splines are monotone if their coefficient vectors are monotone: want an increasing function? Just make sure the coefficients are increasing. So let's use that.
Below is an example with an auction encoder that encodes the auction into a vector. We then transform that vector into an increasing one and use it as the coefficient vector of a B-spline curve.
```python
import torch
from torch import nn

import torchcurves.functional as tcf


class AuctionWinModel(nn.Module):
    def __init__(self, num_auction_features, num_bid_coefficients):
        super().__init__()
        self.auction_encoder = make_auction_encoder(  # example - an MLP, a transformer, etc.
            input_features=num_auction_features,
            output_features=num_bid_coefficients,
        )
        self.spline_knots = nn.Buffer(tcf.uniform_augmented_knots(
            n_control_points=num_bid_coefficients,
            degree=3,
            k_min=0,
            k_max=1
        ))

    def forward(self, auction_features, bids):
        # map auction features to increasing spline coefficients
        spline_coeffs = self._make_increasing(self.auction_encoder(auction_features))

        # map bids to [0, 1] using the arctan (or any other) normalization
        mapped_bid = tcf.arctan(bids)

        # evaluate the spline at the mapped bids, treating each
        # mini-batch sample as a separate curve
        return tcf.bspline_curves(
            mapped_bid.unsqueeze(0),      # 1 x B (B curves in 1 dimension)
            spline_coeffs.unsqueeze(-1),  # B x C x 1 (B curves with C coefs in 1 dimension)
            self.spline_knots,
            degree=3
        )

    def _make_increasing(self, x):
        # transform a mini-batch of vectors to a mini-batch of increasing vectors
        initial = x[..., :1]
        increments = nn.functional.softplus(x[..., 1:])
        concatenated = torch.concat((initial, increments), dim=-1)
        return torch.cumsum(concatenated, dim=-1)
```
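The `_make_increasing` trick stands on its own: softplus makes every increment strictly positive, and a cumulative sum of positive increments yields a strictly increasing vector. A standalone sketch, in plain PyTorch and independent of the model above:

```python
import torch
import torch.nn.functional as F


def make_increasing(x: torch.Tensor) -> torch.Tensor:
    # keep the first entry as-is, force the remaining entries to be positive increments
    increments = F.softplus(x[..., 1:])
    return torch.cumsum(torch.cat((x[..., :1], increments), dim=-1), dim=-1)


v = make_increasing(torch.tensor([[0.3, -2.0, 1.5, -0.1]]))
# consecutive differences are strictly positive, so each row is increasing
assert (v[..., 1:] > v[..., :-1]).all()
```

Since both softplus and cumsum are differentiable, gradients flow through this reparametrization and the monotonicity constraint is enforced by construction rather than by penalties.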
Now we can train the model to predict the probability of winning an auction given its features and a bid:
```python
import torch.nn.functional as F

for auction_features, bids, win_labels in train_loader:
    win_logits = model(auction_features, bids)
    loss = F.binary_cross_entropy_with_logits(  # or any loss we desire
        win_logits,
        win_labels
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Use case 3 - Kolmogorov-Arnold networks
A KAN [1] based on the B-Spline basis, along the lines of the original paper:
```python
from torch import nn

import torchcurves as tc

input_dim = 2
intermediate_dim = 5
num_control_points = 10

kan = nn.Sequential(
    # layer 1
    tc.BSplineCurve(input_dim, intermediate_dim, knots_config=num_control_points),
    tc.Sum(dim=-2),
    # layer 2
    tc.BSplineCurve(intermediate_dim, intermediate_dim, knots_config=num_control_points),
    tc.Sum(dim=-2),
    # layer 3
    tc.BSplineCurve(intermediate_dim, 1, knots_config=num_control_points),
    tc.Sum(dim=-2),
)
```
Yes, we know: the original KAN paper used a different curve parametrization (B-spline + arcsinh), but the whole point of this repo is to show that KAN activations can be parametrized in arbitrary ways.
For example, here is a KAN based on Legendre polynomials of degree 5:
```python
from torch import nn

import torchcurves as tc

input_dim = 2
intermediate_dim = 5
degree = 5

kan = nn.Sequential(
    # layer 1
    tc.LegendreCurve(input_dim, intermediate_dim, degree=degree),
    tc.Sum(dim=-2),
    # layer 2
    tc.LegendreCurve(intermediate_dim, intermediate_dim, degree=degree),
    tc.Sum(dim=-2),
    # layer 3
    tc.LegendreCurve(intermediate_dim, 1, degree=degree),
    tc.Sum(dim=-2),
)
```
Since KANs are the primary use case for the tc.Sum() layer, the dim=-2 argument can be omitted; it is provided here for clarity.
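To see why these layers compose, recall the hello-world example: a curve layer maps a (B, C) input to a (B, C, D) output, and summing over dim=-2 collapses that to (B, D), which is a valid input for the next curve layer. A plain-torch sketch of the shape algebra, with a random tensor standing in for a curve layer's output and assuming tc.Sum(dim=-2) sums over that axis:

```python
import torch

B, C, D = 4, 2, 5  # batch size, curves per layer (C), curve dimension (D)

curve_out = torch.rand(B, C, D)    # a curve layer maps (B, C) inputs to (B, C, D)
layer_out = curve_out.sum(dim=-2)  # the summation tc.Sum(dim=-2) performs

assert layer_out.shape == (B, D)   # a valid (batch, channels) input for the next layer
```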
Advanced features
The curves provided here rely on their inputs lying in a compact interval, typically [-1, 1]. Arbitrary inputs therefore need to be normalized into this interval. We provide a few simple out-of-the-box normalization strategies, described below.
Rational scaling
This is the default strategy. It computes
x \to \frac{x}{\sqrt{s^2 + x^2}},
and is based on the paper
Wang, Z.Q. and Guo, B.Y., 2004. Modified Legendre rational spectral method for the whole line. Journal of Computational Mathematics, pp.457-474.
In Python it looks like this:
```python
tc.BSplineCurve(num_curves, curve_dim, normalize_fn='rational', normalization_scale=s)
```
Arctan scaling
This strategy computes
x \to \frac{2}{\pi} \arctan(x / s).
This scaling function is, up to constants, the CDF of the Cauchy distribution. It is useful when the inputs are assumed to be heavy-tailed.
In Python it looks like this:
```python
tc.BSplineCurve(num_curves, curve_dim, normalize_fn='arctan', normalization_scale=s)
```
Clamping
The inputs are simply clipped to [-1, 1] after scaling, i.e.
x \to \max(\min(1, x / s), -1)
In Python it looks like this:
```python
tc.BSplineCurve(num_curves, curve_dim, normalize_fn='clamp', normalization_scale=s)
```
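For intuition, the three built-in strategies can be sketched in a few lines of plain Python (scalar versions for illustration; the library operates on tensors). Note that rational and arctan scaling approach ±1 only asymptotically, while clamping is a hard cutoff:

```python
import math


def rational(x, s=1.0):
    # x / sqrt(s^2 + x^2): smooth, odd, tends to +/-1 as x -> +/-inf
    return x / math.sqrt(s * s + x * x)


def arctan(x, s=1.0):
    # (2/pi) * arctan(x/s): the Cauchy CDF up to an affine change of constants
    return (2.0 / math.pi) * math.atan(x / s)


def clamp(x, s=1.0):
    # hard clipping of x/s to [-1, 1]
    return max(min(1.0, x / s), -1.0)


# all three map the whole real line into [-1, 1]
for x in (-1e6, -3.0, 0.0, 0.5, 1e6):
    for f in (rational, arctan, clamp):
        assert -1.0 <= f(x) <= 1.0
```

The practical difference is in the tails: clamping discards all resolution beyond the scale s, while the two smooth variants keep compressing, with arctan compressing heavy tails more gently.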
Custom normalization
Provide a custom function that maps its input to the designated range after scaling. Example:
```python
import torch
from torch import Tensor


def erf_clamp(x: Tensor, scale: float = 1, out_min: float = -1, out_max: float = 1) -> Tensor:
    mapped = torch.special.erf(x / scale)
    return ((mapped + 1) * (out_max - out_min)) / 2 + out_min


tc.BSplineCurve(num_curves, curve_dim, normalize_fn=erf_clamp, normalization_scale=s)
```
Gradient checkpointing for Legendre curves
For large degrees, the backward pass can be memory-intensive. Use
checkpoint_segments to trade compute for memory. Larger values create more
segments (lower memory, higher compute). Set to None to disable. Checkpointing
is applied only when gradients are enabled.
```python
# Functional API
tc.functional.legendre_curves(x, coeffs, checkpoint_segments=4)  # e.g., split into 4 segments
```