
Difflogic

Train neural networks that distill into logic circuits, using JAX

Install / Use

/learn @slightknack/Difflogic

README

<img alt="difflogic" src="difflogic.svg">

"Where we're going, we won't need vectors!"

I've been reading about 1-bit quantization for NNs. The idea is pretty fun! The other day, I ran into some great research taking things a step further, using NNs to learn logic circuits. I replicated this research from scratch, and trained a neural network with logic gates in the place of activation functions to learn the 3x3 kernel function for Conway's Game of Life.
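The core trick can be sketched in JAX: each "neuron" is a softmax-weighted mixture of all 16 two-input Boolean functions, each relaxed to real-valued inputs so the whole thing is differentiable. This is an illustrative sketch, not the actual code in main.py; names like `gate_layer` are made up here.

```python
import jax
import jax.numpy as jnp

def soft_gates(a, b):
    """All 16 two-input Boolean functions, relaxed to [0, 1] reals."""
    return jnp.stack([
        jnp.zeros_like(a),       # FALSE
        a * b,                   # AND
        a - a * b,               # A AND NOT B
        a,                       # A
        b - a * b,               # B AND NOT A
        b,                       # B
        a + b - 2 * a * b,       # XOR
        a + b - a * b,           # OR
        1 - (a + b - a * b),     # NOR
        1 - (a + b - 2 * a * b), # XNOR
        1 - b,                   # NOT B
        1 - b + a * b,           # A OR NOT B
        1 - a,                   # NOT A
        1 - a + a * b,           # B OR NOT A
        1 - a * b,               # NAND
        jnp.ones_like(a),        # TRUE
    ])

def gate_layer(logits, a, b):
    """Each output is a softmax mixture over the 16 relaxed gates."""
    probs = jax.nn.softmax(logits, axis=-1)  # (n_gates, 16)
    gates = soft_gates(a, b)                 # (16, n_gates)
    return jnp.einsum('gk,kg->g', probs, gates)

# At extraction time, each mixture hardens to its argmax gate:
# hard_gate = jnp.argmax(logits, axis=-1)
```

Hardening the softmax to an argmax is what turns the trained network into a discrete circuit.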

I wanted to see if I could speed up inference by extracting and compiling the learned logic circuit. So I wrote some code to extract and compile the trained network to pure C (including some simple optimizations like copy propagation and dead code elimination)! I benchmarked the original NN used for training (running on the GPU) against the extracted 300-line single-threaded C program (running on the CPU). ...

... and compiling the neural network to C resulted in a 1,744x speedup. Yeah, crazy right?

I had a lot of fun. Reproduction steps and development journal below! Enjoy.
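The copy propagation and dead code elimination passes mentioned above can be sketched on a toy three-address form. This is an illustration of the two techniques, not the actual codegen in this repo:

```python
def copy_prop(prog):
    """Replace uses of plain copies (x = y) with their source, drop the copy."""
    env, out = {}, []
    for dst, op, args in prog:
        args = tuple(env.get(a, a) for a in args)
        if op == 'copy':
            env[dst] = args[0]   # record the alias instead of emitting it
        else:
            out.append((dst, op, args))
    return out, env

def dead_code_elim(prog, live):
    """Walk backwards, keeping only instructions whose results are used."""
    out = []
    for dst, op, args in reversed(prog):
        if dst in live:
            live |= set(args)
            out.append((dst, op, args))
    return list(reversed(out))

# toy circuit with one copy and one dead gate
prog = [
    ('t0', 'and', ('a', 'b')),
    ('t1', 'copy', ('t0',)),
    ('t2', 'or', ('t1', 'c')),
    ('t3', 'xor', ('a', 'b')),   # dead: never used
]
prog, env = copy_prop(prog)
prog = dead_code_elim(prog, {env.get('t2', 't2')})
```

Running the backward liveness pass from the circuit's output wire prunes gates that never feed a result.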

resources

dependencies

  • jax
  • flax
  • optax
  • einops

to reproduce

[!TIP] With Nix installed on Linux or macOS, run nix build github:slightknack/difflogic

  • Clone this repo.
  • Create and source a venv.
  • Install the dependencies listed above using pip.
  • Run python3 main.py.
    • This will train for 3000 epochs with jit (< 2 minutes).
    • Record the s/epoch time. Each epoch is 512 samples:
      • On my machine, I get 0.000139 s/epoch.
        • (I modified main.py to not time weight update, otherwise 0.025 s/epoch is normal)
    • Verify test_loss_hard: 0 at the end.
    • After training, this will produce a file, gate.c.
  • Compile gate.c using your preferred C compiler:
    • gcc gate.c -O3 -o gate -Wall -Wextra
    • Run with ./gate
  • For benchmarking, comment out visualization
    • In gate.c, search (C-f) for comment out; three lines are marked
  • Benchmark with time ./gate
    • This runs 100k steps of GOL on a random board
    • Record how long it takes -> bench_time:
      • On my machine, program finishes in 4.08s.
  • Compute the speedup:
    • (512 * s/epoch) / (bench_time / 100000)
    • As a lower bound, I got a 1,744x speedup.
      • (When I benched, I modified main.py to not record weight update time.)
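Plugging the numbers above into the formula gives the quoted figure; a quick sanity check (not part of main.py):

```python
s_per_epoch = 0.000139   # JAX, 512 samples per epoch, weight update excluded
bench_time = 4.08        # seconds for 100k steps of the compiled gate.c

c_step = bench_time / 100_000              # ~41 us per 512x512 step
speedup = (512 * s_per_epoch) / c_step     # 512 epochs ~ one full grid
print(f"{speedup:.0f}x")                   # ~1744x
```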

Hardware: 2023 MacBook Pro M3, 18GB

journal

2025-05-27

  • well, I'm back again. Not planning to spend too much more time on this.
  • I implemented better renaming of variables for codegen.
  • I have a few fun ideas:
    • Try soft weights. I'll duplicate main.py to another file, soft.py.
    • Reintegration Tracking is a technique for fluid simulation using CA. I've implemented it before. I might try to get it working, because then I could try to learn a logic circuit for fluid simulation, which would be crazy.
    • I've implemented single-threaded bit-parallel conway's game of life in C, but this is so embarrassingly parallel that a GPU compute shader implementation using e.g. WebGPU or slang might be in order.
  • I really want to write a blog post. So I'll start on that before I do anything else.
    • I'll write it for Anthony, will be fun.
    • Still working on it ... sigh
    • Okay I finished and it's tomorrow that took way too long. Published. Night.

2025-05-26

  • back!
  • where I left off:
    • got a big network converging
    • didn't get perfect convergence
      • probably due to wiring
  • what to try:
    • better wiring
    • soft dot weights
    • extracting discrete network at the end
  • planning better wiring
    • essentially, we need to generate unique pairs of wires, and shuffle those.
      • If input m = 2*n output, we can do:
        • (1, 2), (3, 4), ...
      • If input m = n output, we can do:
        • (1, 2), (3, 4), ...
        • (2, 3), (4, 5), ...
      • This is what they do in the paper
        • I think they extend it a little further
        • but anything in m : [n, 2*n] is possible
    • I will implement this approach:
      • First pair up: (1, 2), (3, 4), ...
      • Once that is exhausted, pair up: (2, 3), (4, 5), ...
      • Once that is exhausted, do random pairs
    • This is simple but should be better than what I have now.
      • Wait, I can just use itertools.combinations!
        • Nevermind, I want a uniform distribution of unique pairs
  • implementing better wiring
    • 25 ms/epoch
    • no tree: epoch 3000, loss 0.0245, hard 0.123
    • with tree: epoch 3000, loss 0.00342, hard 0.00781
    • comb with tree: epoch 3000, loss 0.0145, hard 0.0215
    • comb no tree: epoch 3000, loss 0.0337, hard 0.0801
  • okay, so the fancier wiring did not yield better results
    • maybe I should use the unique wiring used in the paper, exactly as described
      • YESYEYSYEYSYEYS!
      • epoch 3000, loss 9.91e-05, hard 0
        • HARD 0
          • we learned a perfect circuit for conway's!
  • mission complete.
    • Now all the fun stuff:
      • extracting the network
      • writing a little interactive demo and animations
      • writing a blog post
  • let me commit here
  • wait, I jumped the gun. I still want to do soft dot:
    • the layer generation code is good, I just need to i.e. multiply by 10
    • I also need to do the soft dot
    • and make the wires trainable
  • I'm going to duplicate main.py to create an alternate version with soft dot: soft.py:
  • before I actually implement that, I'm going to implement code for extracting the learned hard network:
    • I wrote code to generate some c code, but I really should prune the network before outputting it
    • I wrote code to prune, nice
  • let me try a training run, output hard to gate.c
    • Output! code has lots of copies, I should probably also implement copy prop sigh
  • going to take a break, commit here.
    • ANDs: 69
    • ORs: 68
    • XORs: 17
    • NOTs: 16
  • back from break, compiling to c:
    • implemented copy prop
    • implemented dead code elim
    • implemented variable renaming
    • implemented formatting with template
  • compiles and it works!
    • python3 main.py
    • gcc gate.c -O3 -o gate
    • ./gate
  • This is so frickin' cool. I just trained a neural network and compiled the weights to C, at 1 bit quantization.
    • also, note that cell is a uint64_t, so conway in gate.c actually performs parallel evaluation of 64 cells at once (e.g. an 8x8 tile)!
  • I'm going to try to benchmark this:
    • Well, I implemented a little visualizer TUI thing, which is cool!
    • On a 512x512 grid, battery saver on, it runs at a rate of
      • 100k steps in 4.08s
      • or 41 μs/step
      • or 24509 steps/second, fps, whatever
    • In comparison to the jax version, for a batch of 512, we were getting about 25 ms/epoch
      • that would mean 512x512 = 12.8 seconds/step
        • Might be unfair, jax does parallelize well
        • I mean, let's test 512x512 batch size...
          • Okay, it's like 6.62 s/epoch ~ 6.62 s/step
          • which is better!
    • We went from:
      • 6.62 s/step network
      • 41 μs/step c program
    • That's a 162,249.58x speedup.
  • Okay, that's obviously too good. After fixing up main.py:
    • python:0.000139 s/epoch
    • compiled: 41 μs/step
    • speedup: 1,744.31x.
    • A 1744x speedup is still very good!
  • I asked my brother to draw a cool person unraveling for the logo.
  • Probably about done with this project, I might write a blog post though.
  • Cleaned up the README a bit. Night!
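The wiring scheme settled on in this entry (sequential pairs, then offset pairs, then random unique pairs) might look something like this sketch; `wire_pairs` is a made-up name, not the function in main.py:

```python
import random

def wire_pairs(n_in, n_out, seed=0):
    """Connect n_out gates to unique pairs of the n_in input wires.

    First (0,1), (2,3), ...; then (1,2), (3,4), ...; then random
    unique pairs. Assumes n_out <= C(n_in, 2) total unique pairs.
    """
    rng = random.Random(seed)
    pairs = [(i, i + 1) for i in range(0, n_in - 1, 2)]   # (0,1), (2,3), ...
    pairs += [(i, i + 1) for i in range(1, n_in - 1, 2)]  # (1,2), (3,4), ...
    seen = set(pairs)
    while len(pairs) < n_out:
        p = tuple(sorted(rng.sample(range(n_in), 2)))
        if p not in seen:
            seen.add(p)
            pairs.append(p)
    return pairs[:n_out]
```

This guarantees every input wire is used before any pair repeats, which matches the intent of "a uniform distribution of unique pairs" better than itertools.combinations would.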

2025-05-24

  • some things to try:
    • use a real optimizer like adamw via optax
    • implement sparse static weights
    • implement "generalized" gates
      • instead of having a gate with 16 relaxations,
      • we have one gate with parameters
      • idea is to find a better basis that spans the same function space
  • where the model is at:
    • network layout [9, 48, 48, 1]
    • normalized weights and relu
    • batch size 20
    • step size 0.1
    • converges, ~2500 epochs, loss < 1e-7
    • ~3ms / epoch
  • if I try the full soft gate, everything else the same:
    • no dice, 15000 epochs, loss is ~.18-.29
    • ~2ms / epoch
  • if I try the mini soft gate, AND XOR OR NAND:
    • blows up, ~5100 epochs, loss is 0.0611 before it blows up
    • ~1-2ms / epoch
  • using a real optimizer
    • adamw via optax
      • that wasn't too hard to implement
      • I had to disable jit though, getting an error,
        • I should look into that so I can reenable jit
    • using hyperparams:
      • learning rate: 0.05
      • b1, b2: 0.9, 0.99
      • weight decay: 1e-2
    • training with soft mini gate, seems to be having a hard time converging, loss goes up crazy then comes back down
      • 18ms / epoch
      • maybe my learning rate is too high?
      • or batch size is too small?
      • it got down to 0.06 at around 11000
      • so obviously it's learning, it's just hard
    • I'm going to try again, with batch size: 512:
      • this seems to be converging
        • scratch that, "blowing up less"
      • epoch 5000, loss 0.0695
        • but it's going up and down
      • epoch 7000, loss 0.312
        • case in point
      • no dice, epoch 10000, loss 0.112
        • oscillating between ~0.05 and ~0.15
    • And another experiment, with:
      • batch size: 20
        • back to the old baseline, want to see just learning rate
      • learning rate: 0.01
      • it got down to 0.012 then blew up to 0.51
        • I wonder if the learning rate is too high?
      • new min, epoch ~7000, loss is 0.0038
      • epoch 14700, loss ~0.001
        • certainly is converging better
    • let's try