# torch-fwht
This is an implementation of the Fast Walsh-Hadamard Transform (FWHT) for PyTorch. The GPU code is implemented with Triton; the CPU code is a simple C++ extension.

This is a work in progress!
## Installation

The only dependencies are Triton and PyTorch.

```shell
git clone https://github.com/arthurfeeney/fwht && cd fwht
pip install --editable .
```

To run the tests:

```shell
pip install pytest
python -m pytest test/
```
## Example

```python
import torch

from fwht import (
    fast_hadamard_transform,
    # can also be used as a module:
    # FastHadamardTransform
)

data = torch.ones(32, device='cuda')
fast_hadamard_transform(data, inplace=True)
assert data[0] == 32
assert torch.all(data[1:] == 0)
```
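For intuition, the result above can be checked against a naive pure-Python FWHT. This sketch is only for illustration (it is not part of the package): it is the standard radix-2 butterfly, which sends a vector of all ones to a vector whose first entry is the length and whose remaining entries are zero.

```python
def fwht_ref(x):
    """Naive in-place radix-2 FWHT (Sylvester ordering), O(n log n).

    Reference sketch for checking outputs; not part of torch-fwht.
    """
    x = list(x)
    n = len(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                # butterfly: combine pairs h apart
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x

# FWHT of the all-ones vector concentrates everything in the first entry.
assert fwht_ref([1] * 32) == [32] + [0] * 31

# The (unnormalized) transform is an involution up to a factor of n.
assert fwht_ref(fwht_ref([3, 1, 4, 1])) == [12, 4, 16, 4]
```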
## Caveats
- If the input size is not a power of two, the kernel has to explicitly zero-pad it to the next power of two. This happens inside the kernel, so it does not need to allocate in global memory; it just does extra compute. Triton requires block sizes to be powers of two, so there may be no way to work around this.
- The implementation assumes the input size can be factorized as a power of 16 times a power of 2 (i.e., the size is $16^m \cdot 2^n$). Otherwise, it gets padded.
- The maximum supported size is currently $16^3 \cdot 2^3 = 32768$, which allows everything to fit in GPU shared memory. Going beyond this size is likely uncommon in machine learning.
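As a sketch of the padding rule described above (this helper is hypothetical and only illustrates the caveats; the kernel's actual internals may differ): every power of two up to the maximum is expressible as $16^m \cdot 2^n$, so the padded size is simply the next power of two.

```python
MAX_SIZE = 16**3 * 2**3  # 32768, the currently supported maximum

def padded_size(n):
    """Hypothetical helper: the size the kernel would pad n up to.

    Rounds n up to the next power of two; every power of two is
    expressible as 16^m * 2^n, so this matches the factorization rule.
    """
    assert 0 < n <= MAX_SIZE, "sizes beyond 16^3 * 2^3 are unsupported"
    p = 1
    while p < n:
        p *= 2
    return p
```

For example, an input of size 1000 would be zero-padded to 1024 inside the kernel, while a size of 4096 needs no padding.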
## Notes on Implementation
The implementation relies on a few facts:
- A Hadamard transform $H_{pq}$ can be expressed as the Kronecker product of two smaller Hadamard transforms: $H_{pq} = \text{kron}(H_p, H_q)$.
- The matrix-vector product with a Kronecker product satisfies $\text{kron}(B^T, A) \, \text{colvec}(M) = \text{colvec}(AMB)$, where $\text{colvec}$ stacks the columns of a matrix into a vector. Since Hadamard matrices are symmetric, this is even simpler and requires no transposes.
- A Hadamard matrix can be constructed with the identity $(H_n)_{i,j} = (-1)^{i \cdot j}$, where $i \cdot j$ is the bitwise dot product of the integers $i$ and $j$. This is useful for building the small Hadamard matrices used in the base case, instead of loading them from global memory. Triton should optimize most of this away during compilation.
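The facts above can be checked with a small pure-Python sketch (illustration only, not the kernel code). It builds Hadamard matrices from the bitwise-dot-product identity, confirms that larger transforms factor as Kronecker products, and verifies the Kronecker matrix-vector identity. The sketch uses row-major flattening, so the identity reads $\text{kron}(A, B)\,\text{vec}(M) = \text{vec}(AMB^T)$; since Hadamard matrices are symmetric, the transpose drops out.

```python
def hadamard(n):
    # (H_n)[i][j] = (-1)^(bitwise dot product of i and j); n a power of two.
    # The bitwise dot product mod 2 is the popcount of i & j.
    return [[(-1) ** bin(i & j).count("1") for j in range(n)] for i in range(n)]

def kron(A, B):
    # Kronecker product: entry (i*p + r, j*q + s) is A[i][j] * B[r][s]
    return [[a * b for a in ra for b in rb] for ra in A for rb in B]

def matmul(A, B):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)]
            for row in A]

H2, H4, H8 = hadamard(2), hadamard(4), hadamard(8)

# Larger Hadamard transforms factor into Kronecker products of smaller ones.
assert kron(H2, H4) == H8

# kron(A, B) @ vec(M) == vec(A @ M @ B) for symmetric A, B (row-major vec).
M = [[1, 2, 3, 4], [5, 6, 7, 8]]
vec_M = [x for row in M for x in row]
lhs = [sum(k * v for k, v in zip(row, vec_M)) for row in kron(H2, H4)]
rhs = [x for row in matmul(matmul(H2, M), H4) for x in row]
assert lhs == rhs
```

This is why the kernel only needs tiny base-case matrices: a length-4096 transform, say, can be applied as repeated small matrix products via the Kronecker identity, with the base matrices generated on the fly from the $(-1)^{i \cdot j}$ formula.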