Difflogic
Train neural networks that distill into logic circuits, using JAX
Install / Use
"Where we're going, we won't need vectors!"
I've been reading about 1-bit quantization for NNs. The idea is pretty fun! The other day, I ran into some great research taking things a step further, using NNs to learn logic circuits. I replicated this research from scratch, and trained a neural network with logic gates in the place of activation functions to learn the 3x3 kernel function for Conway's Game of Life.
I wanted to see if I could speed up inference by extracting and compiling the learned logic circuit. So I wrote some code to extract and compile the trained network to pure C (including some simple optimizations like copy propagation and dead code elimination)! I benchmarked the original NN used for training (running on the GPU) against the extracted 300-line single-threaded C program (running on the CPU). ...
... and compiling the neural network to C resulted in a 1,744x speedup. Yeah, crazy right?
I had a lot of fun. Reproduction steps and development journal below! Enjoy.
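For the curious: the core trick, soft relaxations of all 16 two-input gates mixed by a softmax, can be sketched in plain Python. This is a standalone sketch of the idea, not the actual JAX code in `main.py`:

```python
import math

# The 16 two-input boolean gates, relaxed to real-valued functions
# on probabilities a, b in [0, 1] (standard soft-logic relaxations).
GATES = [
    lambda a, b: 0.0,                      # FALSE
    lambda a, b: a * b,                    # AND
    lambda a, b: a - a * b,                # A AND NOT B
    lambda a, b: a,                        # A
    lambda a, b: b - a * b,                # NOT A AND B
    lambda a, b: b,                        # B
    lambda a, b: a + b - 2 * a * b,        # XOR
    lambda a, b: a + b - a * b,            # OR
    lambda a, b: 1 - (a + b - a * b),      # NOR
    lambda a, b: 1 - (a + b - 2 * a * b),  # XNOR
    lambda a, b: 1 - b,                    # NOT B
    lambda a, b: 1 - b + a * b,            # A OR NOT B
    lambda a, b: 1 - a,                    # NOT A
    lambda a, b: 1 - a + a * b,            # NOT A OR B
    lambda a, b: 1 - a * b,                # NAND
    lambda a, b: 1.0,                      # TRUE
]

def softmax(ws):
    m = max(ws)
    es = [math.exp(w - m) for w in ws]
    s = sum(es)
    return [e / s for e in es]

def soft_gate(a, b, weights):
    """Differentiable gate: softmax mixture over all 16 relaxations."""
    probs = softmax(weights)
    return sum(p * g(a, b) for p, g in zip(probs, GATES))

def hard_gate(a, b, weights):
    """Discretized gate: pick the argmax relaxation, round the inputs."""
    g = GATES[max(range(16), key=lambda i: weights[i])]
    return round(g(round(a), round(b)))
```

During training the gate stays a differentiable mixture; at extraction time you take the argmax and get an ordinary boolean gate.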
resources
- original paper/blog
- training nn in jax
- google colab
- conway gol shader
- flax linen nn docs
- optax optimizer docs
dependencies
- jax
- flax
- optax
- einops
to reproduce

> [!TIP]
> With Nix installed on Linux or macOS, run `nix build github:slightknack/difflogic`

- Clone this repo.
- Create and source a venv.
- Install dependencies listed above using pip.
- Run `python3 main.py`.
  - This will train for 3000 epochs with jit (< 2 minutes).
- Record the `s/epoch` time. Each epoch is 512 samples:
  - On my machine, I get `0.000139 s/epoch`.
    - (I modified `main.py` to not time weight update, otherwise `0.025 s/epoch` is normal.)
- Verify `test_loss_hard: 0` at the end.
- After training, this will produce a file, `gate.c`.
- Compile `gate.c` using your preferred C compiler: `gcc gate.c -O3 -o gate -Wall -Wextra`
- Run with `./gate`.
- For benchmarking, comment out visualization:
  - In `gate.c`, run `C-f` to find `comment out`, three lines.
- Benchmark with `time ./gate`.
  - This runs 100k steps of GOL on a random board.
- Record how long it takes -> `bench_time`:
  - On my machine, the program finishes in `4.08s`.
- Compute the speedup: `(512 * s/epoch) / (bench_time / 100000)`
  - As a lower bound, I got a 1,744x speedup.
  - (When I benched, I modified `main.py` to not record weight update time.)
Hardware: 2023 MacBook Pro M3, 18GB
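For reference, the speedup computation from the steps above, spelled out (the numbers are the ones recorded on my machine):

```python
s_per_epoch = 0.000139      # seconds per training epoch (512 samples), GPU
bench_time = 4.08           # seconds for 100k steps of ./gate, CPU
samples_per_epoch = 512
steps = 100_000

# same formula as above: (512 * s/epoch) / (bench_time / 100000)
speedup = (samples_per_epoch * s_per_epoch) / (bench_time / steps)
print(f"{speedup:.2f}x")    # ~1744x
```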
journal
2025-05-27
- well, I'm back again. Not planning to spend too much more time on this.
- I implemented better renaming of variables for codegen.
- I have a few fun ideas:
  - Try soft weights. I'll duplicate `main.py` to another file, `soft.py`.
  - Reintegration Tracking is a technique for fluid simulation using CA. I've implemented it before. I might try to get it working, because then I could try to learn a logic circuit for fluid simulation, which would be crazy.
  - I've implemented single-threaded bit-parallel Conway's Game of Life in C, but this is so embarrassingly parallel that a GPU compute shader implementation using e.g. WebGPU or slang might be in order.
- I really want to write a blog post. So I'll start on that before I do anything else.
- I'll write it for Anthony, will be fun.
- Still working on it ... sigh
- Okay, I finished and it's tomorrow. That took way too long. Published. Night.
2025-05-26
- back!
- where I left off:
- got a big network converging
- didn't get perfect convergence
- probably due to wiring
- what to try:
- better wiring
- soft dot weights
- extracting discrete network at the end
- planning better wiring
- essentially, we need to generate unique pairs of wires, and shuffle those.
  - If input `m = 2*n` output, we can do:
    - (1, 2), (3, 4), ...
  - If input `m = n` output, we can do:
    - (1, 2), (3, 4), ...
    - (2, 3), (4, 5), ...
    - This is what they do in the paper
    - I think they extend it a little further
  - but anything in `m : [n, 2*n]` is possible
- I will implement this approach:
  - First pair up: (1, 2), (3, 4), ...
  - Once that is exhausted, pair up: (2, 3), (4, 5), ...
  - Once that is exhausted, do random pairs
  - This is simple but should be better than what I have now.
  - Wait, I can just use itertools.combinations!
    - Nevermind, I want a uniform distribution of unique pairs
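That pairing scheme, adjacent pairs, then offset pairs, then random unique pairs, shuffled, can be sketched like so (0-indexed; the function name is mine, not from `main.py`):

```python
import random

def wire_pairs(n_inputs, n_gates, seed=0):
    """Generate unique input-wire pairs for a layer of two-input gates.

    First exhaust adjacent pairs (0,1), (2,3), ..., then offset pairs
    (1,2), (3,4), ..., then fall back to random unique pairs; shuffle.
    Assumes n_gates is well below the number of possible pairs.
    """
    pairs = [(i, i + 1) for i in range(0, n_inputs - 1, 2)]   # adjacent
    pairs += [(i, i + 1) for i in range(1, n_inputs - 1, 2)]  # offset
    rng = random.Random(seed)
    seen = set(pairs)
    while len(pairs) < n_gates:
        a, b = rng.sample(range(n_inputs), 2)  # two distinct wires
        if (a, b) not in seen:
            seen.add((a, b))
            pairs.append((a, b))
    rng.shuffle(pairs)
    return pairs[:n_gates]
```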
- implementing better wiring
- 25 ms/epoch
- no tree: epoch 3000, loss 0.0245, hard 0.123
- with tree: epoch 3000, loss 0.00342, hard 0.00781
- comb with tree: epoch 3000, loss 0.0145, hard 0.0215
- comb no tree: epoch 3000, loss 0.0337, hard 0.0801
- okay, so the fancier wiring did not yield better results
  - maybe I should use the unique wiring used in the paper, exactly as described
    - YESYEYSYEYSYEYS!
    - epoch 3000, loss 9.91e-05, hard 0
    - HARD 0
    - we learned a perfect circuit for conway's!
- mission complete.
- Now all the fun stuff:
  - extracting the network
  - writing a little interactive demo and animations
  - writing a blog post
- let me commit here
- wait, I jumped the gun. I still want to do soft dot:
- the layer generation code is good, I just need to i.e. multiply by 10
- I also need to do the soft dot
- and make the wires trainable
  - I'm going to duplicate `main.py` to create an alternate version with soft dot: `soft.py`
- before I actually implement that, I'm going to implement code for extracting the learned hard network:
- I wrote code to generate some c code, but I really should prune the network before outputting it
- I wrote code to prune, nice
  - let me try a training run, output hard to `gate.c`
  - Output! code has lots of copies, I should probably also implement copy prop, sigh
- going to take a break, commit here.
- ANDs: 69
- ORs: 68
- XORs: 17
- NOTs: 16
- back from break, compiling to c:
- implemented copy prop
- implemented dead code elim
- implemented variable renaming
- implemented formatting with template
- compiles and it works!
  - `python3 main.py`, `gcc gate.c -O3 -o gate`, `./gate`
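For the record, the two cleanup passes are conceptually tiny. Here's a sketch on a toy three-address IR; the `(dst, op, args)` tuple representation is an assumption for illustration, not the actual data structure in `main.py`:

```python
def copy_prop(stmts):
    """Forward-substitute pure copies (dst = src) into later uses.
    Assumes SSA-style statements: each dst is assigned exactly once."""
    env = {}  # alias map: copied name -> original name
    out = []
    for dst, op, args in stmts:
        args = tuple(env.get(a, a) for a in args)  # resolve aliases
        if op == "copy":
            env[dst] = args[0]  # record the alias, emit nothing
        else:
            out.append((dst, op, args))
    return out, env

def dead_code_elim(stmts, live):
    """Drop statements whose results are never used, by a single
    backward pass; `live` is the set of output names to keep."""
    out = []
    for dst, op, args in reversed(stmts):
        if dst in live:
            out.append((dst, op, args))
            live |= set(args)  # inputs of a live statement are live
    return list(reversed(out))
```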
- This is so frickin' cool. I just trained a neural network and compiled the weights to C, at 1 bit quantization.
  - also, note that `cell` is a `uint64_t`, so `conway` in `gate.c` actually performs parallel evaluation of 64 cells at once (e.g. an 8x8 tile)!
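The same bit-parallel trick in miniature, in Python: because the compiled circuit is nothing but bitwise ops, applying it to a 64-bit word evaluates 64 independent cells per pass. A half-adder stands in for the real Conway circuit here:

```python
MASK = (1 << 64) - 1  # emulate a uint64 lane register in Python

def half_adder(a, b):
    """One circuit evaluation, 64 independent bit-lanes at once."""
    return a ^ b, a & b  # sum, carry (XOR and AND per lane)

# Pack one test case per lane: lane i holds the i-th (a, b) pair.
a = 0b1100 & MASK  # lanes 2, 3 have a = 1
b = 0b1010 & MASK  # lanes 1, 3 have b = 1
s, c = half_adder(a, b)
```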
- I'm going to try to benchmark this:
- Well, I implemented a little visualizer TUI thing, which is cool!
- On a 512x512 grid, battery saver on, it runs at a rate of
- 100k steps in 4.08s
- or 41 μs/step
- or 24509 steps/second, fps, whatever
- In comparison to the jax version, for a batch of 512, we were getting about 25 ms/epoch
  - that would mean 512x512 = 12.8 seconds/step
    - Might be unfair, jax does parallelize well
    - I mean, let's test 512x512 batch size...
    - Okay, it's like 6.62 s/epoch ~ 6.62 s/step
    - which is better!
- We went from:
- 6.62 s/step network
- 41 μs/step c program
- That's a 162,249.58x speedup.
  - Okay, that's obviously too good. After fixing up `main.py`:
    - python: `0.000139 s/epoch`
    - compiled: `41 μs/step`
    - speedup: `1,744.31x`
  - A 1744x speedup is still very good!
- I asked my brother to draw a cool person unraveling for the logo.
- Probably about done with this project, I might write a blog post though.
- Cleaned up the README a bit. Night!
2025-05-24
- some things to try:
- use a real optimizer like adamw via optax
- implement sparse static weights
- implement "generalized" gates
- instead of having a gate with 16 relaxations,
- we have one gate with parameters
- idea is to find a better basis that spans the same function space
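One possible reading of the "generalized gate" idea (my interpretation, not spelled out above): every two-input boolean function is a multilinear polynomial in its inputs, so four trainable coefficients, one per input combination, span the same function space as the 16 discrete gates:

```python
def general_gate(a, b, c):
    """Four-coefficient gate: c = (c00, c01, c10, c11) is the gate's
    output on inputs (0,0), (0,1), (1,0), (1,1) respectively.
    Multilinear in a and b, so it is differentiable for soft inputs."""
    c00, c01, c10, c11 = c
    return (c00 * (1 - a) * (1 - b)
          + c01 * (1 - a) * b
          + c10 * a * (1 - b)
          + c11 * a * b)

XOR = (0, 1, 1, 0)
AND = (0, 0, 0, 1)
```

With the coefficients clamped to {0, 1}, you recover exactly the 16 hard gates.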
- where the model is at:
- network layout [9, 48, 48, 1]
- normalized weights and relu
- batch size 20
- step size 0.1
- converges, ~2500 epochs, loss < 1e-7
- ~3ms / epoch
- if I try the full soft gate, everything else the same:
- no dice, 15000 epochs, loss is ~.18-.29
- ~2ms / epoch
- if I try the mini soft gate, AND XOR OR NAND:
- blows up, ~5100 epochs, loss is 0.0611 before
- ~1-2ms / epoch
- using a real optimizer
- adamw via optax
- that wasn't too hard to implement
- I had to disable jit though, I was getting an error
- I should look into that so I can reenable jit
- using hyperparams:
- learning rate: 0.05
- b1, b2: 0.9, 0.99
- weight decay: 1e-2
- training with soft mini gate, seems to be having a hard time converging, loss goes up crazy then comes back down
- 18ms / epoch
- maybe my learning rate is too high?
- or batch size is too small?
- it got down to 0.06 at around 11000
- so obviously it's learning, it's just hard
- I'm going to try again, with batch size: 512:
  - this seems to be converging
    - scratch that, "blowing up less"
  - epoch 5000, loss 0.0695
    - but it's going up and down
  - epoch 7000, loss 0.312
    - case in point
  - no dice, epoch 10000, loss 0.112
    - oscillating between ~0.05 and ~0.15
- And another experiment, with:
  - batch size: 20
    - back to the old baseline, want to see just learning rate
  - learning rate: 0.01
  - it got down to 0.012 then blew up to 0.51
    - I wonder if the learning rate is too high?
  - new min, epoch ~7000, loss is 0.0038
  - epoch 14700, loss ~0.001
  - certainly is converging better
