FlaxDiff
A simple, easy-to-understand library for diffusion models using Flax and Jax. Includes detailed notebooks on DDPM, DDIM, and EDM with simplified mathematical explanations. Made as part of my journey of learning and experimenting with generative AI.
This project is partially supported by Google TPU Research Cloud. I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.
A Versatile and Simple Diffusion Library
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
FlaxDiff is a library of tools (schedulers, samplers, models, etc.) designed and implemented in an easy-to-understand way. The focus is on understandability and readability over performance. I started this project as a hobby to familiarize myself with Flax and Jax and to learn about diffusion and the latest research in generative AI.
Being familiar with TensorFlow 2.0, I initially started this project in Keras, but transitioned to Flax (powered by Jax) for its performance and ease of use. The old notebooks and models, including my first Flax models, are also provided.
The Diffusion_flax_linen.ipynb notebook is my main workspace for experiments. Several checkpoints are uploaded to the pretrained folder along with a copy of the working notebook associated with each checkpoint. You may need to copy the notebook to the working root for it to function properly.
Example Notebooks (from scratch)
In the example notebooks folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
Available Notebooks and Resources
- Diffusion explained (nbviewer link) (local link)
  - WORK IN PROGRESS An in-depth exploration of diffusion-based generative models: DDPM (Denoising Diffusion Probabilistic Models), DDIM (Denoising Diffusion Implicit Models), and the SDE/ODE generalizations of diffusion, with step-by-step explanations and code.
- EDM (Elucidating the Design Space of Diffusion-based Generative Models)
  - TODO A thorough guide to EDM, discussing the innovative approaches and techniques used in this advanced diffusion model.
These notebooks aim to provide an easy-to-understand, step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly: although they may depart from the exact formulations and implementations of the original papers to stay understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
Other resources
- Multi-host data-parallel training script in JAX
  - A training script for multi-host data-parallel training in JAX, intended as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
- TPU utilities for making life easier
  - A collection of utilities and scripts that make working with TPUs easier: a CLI to create/start/stop/set up TPUs, a script to set up TPU VMs (installing everything you need), mounting GCS datasets, and more.
Disclaimer (and About Me)
I worked as a Machine Learning Researcher at Hyperverge from 2019 to 2021, focusing on computer vision, specifically facial anti-spoofing and facial detection and recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, which led me to start this pet project to revisit and relearn the fundamentals and get familiar with the state of the art. My current role primarily involves Golang systems engineering with some applied ML sprinkled in. The code may therefore reflect my learning journey; please forgive any mistakes, and do open an issue to let me know.
Also, some of the text may have been generated with the help of GitHub Copilot, so please excuse any mistakes in the writing.
Index
- A Versatile and Simple Diffusion Library
- Disclaimer (and About Me)
- Features
- Installation of FlaxDiff
- Getting Started with FlaxDiff
- References and Acknowledgements
- Pending things to do list
- Gallery
- Contribution
- License
Features
Schedulers
Implemented in flaxdiff.schedulers:
- LinearNoiseSchedule (flaxdiff.schedulers.LinearNoiseSchedule): A beta-parameterized discrete scheduler.
- CosineNoiseScheduler (flaxdiff.schedulers.CosineNoiseScheduler): A beta-parameterized discrete scheduler.
- ExpNoiseSchedule (flaxdiff.schedulers.ExpNoiseSchedule): A beta-parameterized discrete scheduler.
- CosineContinuousNoiseScheduler (flaxdiff.schedulers.CosineContinuousNoiseScheduler): A continuous scheduler.
- CosineGeneralNoiseScheduler (flaxdiff.schedulers.CosineGeneralNoiseScheduler): A continuous sigma-parameterized cosine scheduler.
- KarrasVENoiseScheduler (flaxdiff.schedulers.KarrasVENoiseScheduler): A sigma-parameterized continuous scheduler proposed by Karras et al. (2022), best suited for inference.
- EDMNoiseScheduler (flaxdiff.schedulers.EDMNoiseScheduler): A sigma-parameterized continuous scheduler based on the EDM formulation (Karras et al. 2022), best suited for training with the KarrasVENoiseScheduler.
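To make the beta-parameterized discrete schedules concrete, here is a minimal NumPy sketch of the cosine schedule from Nichol & Dhariwal (2021). This is an independent illustration of the underlying math, not FlaxDiff's implementation:

```python
import numpy as np

def cosine_alpha_bar(t, T, s=0.008):
    # Cumulative signal fraction alpha_bar(t), following Nichol & Dhariwal (2021)
    f = lambda u: np.cos((u / T + s) / (1 + s) * np.pi / 2) ** 2
    return f(t) / f(0)

T = 1000
alpha_bar = cosine_alpha_bar(np.arange(T + 1), T)
# Discrete betas recovered from ratios of consecutive alpha_bar values,
# clipped for numerical stability near t = T
betas = np.clip(1.0 - alpha_bar[1:] / alpha_bar[:-1], 0.0, 0.999)
```

A scheduler ultimately exposes quantities like these so the trainer can form the noised sample x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.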
Model Predictors
Implemented in flaxdiff.predictors:
- EpsilonPredictor (flaxdiff.predictors.EpsilonPredictor): Predicts the noise added to the data.
- X0Predictor (flaxdiff.predictors.X0Predictor): Predicts the original data from the noisy data.
- VPredictor (flaxdiff.predictors.VPredictor): Predicts a linear combination of the data and noise (v-prediction).
- KarrasEDMPredictor (flaxdiff.predictors.KarrasEDMPredictor): A generalized predictor for the EDM, integrating various parameterizations.
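These predictors differ only in which target the network regresses, and the targets are linearly related. A small NumPy sketch (illustrative only; the variable names are mine, not FlaxDiff's) shows that the epsilon and v parameterizations imply the same clean sample:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)        # clean data
eps = rng.standard_normal(4)       # Gaussian noise
a = np.sqrt(0.7)                   # sqrt(alpha_bar_t): signal scale
s = np.sqrt(1.0 - 0.7)             # sqrt(1 - alpha_bar_t): noise scale

x_t = a * x0 + s * eps             # forward-noised sample
v = a * eps - s * x0               # v-target (Salimans & Ho, 2022)

x0_from_eps = (x_t - s * eps) / a  # what an epsilon prediction implies
x0_from_v = a * x_t - s * v        # what a v prediction implies
```

Since a^2 + s^2 = 1, expanding a * x_t - s * v cancels the noise terms exactly, so both recoveries agree with x0.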
Samplers
Implemented in flaxdiff.samplers:
- DDPMSampler (flaxdiff.samplers.DDPMSampler): Implements the Denoising Diffusion Probabilistic Model (DDPM) sampling process.
- DDIMSampler (flaxdiff.samplers.DDIMSampler): Implements the Denoising Diffusion Implicit Model (DDIM) sampling process.
- EulerSampler (flaxdiff.samplers.EulerSampler): An ODE-solver sampler using Euler's method.
- HeunSampler (flaxdiff.samplers.HeunSampler): An ODE-solver sampler using Heun's second-order method.
- RK4Sampler (flaxdiff.samplers.RK4Sampler): An ODE-solver sampler using the fourth-order Runge-Kutta method.
- MultiStepDPM (flaxdiff.samplers.MultiStepDPM): Implements a multi-step sampling method inspired by the multistep DPM solver as presented in tonyduan/diffusion.
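As a concrete illustration of the deterministic sampling these classes implement, here is a single DDIM update in NumPy, written in standard DDPM notation rather than against FlaxDiff's API:

```python
import numpy as np

def ddim_step(x_t, eps_pred, ab_t, ab_prev):
    # One deterministic DDIM update (eta = 0): estimate x0 from the noise
    # prediction, then re-noise it to the previous (lower) noise level.
    x0_hat = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    return np.sqrt(ab_prev) * x0_hat + np.sqrt(1.0 - ab_prev) * eps_pred

# Sanity check with an "oracle" noise prediction: stepping all the way to
# alpha_bar = 1 (no noise left) must return the clean sample exactly.
rng = np.random.default_rng(1)
x0 = rng.standard_normal(3)
eps = rng.standard_normal(3)
ab_t = 0.5
x_t = np.sqrt(ab_t) * x0 + np.sqrt(1.0 - ab_t) * eps
x_rec = ddim_step(x_t, eps, ab_t, ab_prev=1.0)
```

A full sampler simply iterates this step over a decreasing sequence of noise levels, replacing the oracle with the trained network's prediction.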
Training
Implemented in flaxdiff.trainer:
- DiffusionTrainer (flaxdiff.trainer.DiffusionTrainer): A class designed to facilitate the training of diffusion models. It manages the training loop, loss calculation, and model updates.
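The objective such a trainer optimizes can be sketched framework-free. The following NumPy toy (not FlaxDiff's actual loss code) shows the epsilon-prediction MSE on a randomly noised batch:

```python
import numpy as np

def diffusion_loss(predict_eps, x0, alpha_bar, rng):
    # Epsilon-prediction MSE, the core diffusion training objective:
    # noise the batch at random timesteps, then score the noise prediction.
    t = rng.integers(0, alpha_bar.shape[0], size=x0.shape[0])
    eps = rng.standard_normal(x0.shape)
    ab = alpha_bar[t][:, None]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps
    return np.mean((predict_eps(x_t, t) - eps) ** 2)

rng = np.random.default_rng(0)
alpha_bar = np.linspace(0.999, 0.01, 1000)
x0 = rng.standard_normal((64, 8))
# A predictor that always outputs zero should score roughly E[eps^2] = 1.
loss_zero = diffusion_loss(lambda x_t, t: np.zeros_like(x_t), x0, alpha_bar, rng)
```

In practice the trainer wraps this computation with gradient updates, checkpointing, and (for the text-conditional models) conditioning inputs.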
Models
Implemented in flaxdiff.models:
- UNet (flaxdiff.models.simple_unet.SimpleUNet): A sample UNet architecture for diffusion models.
- Layers: A library of layers including upsampling (flaxdiff.models.simple_unet.Upsample), downsampling (flaxdiff.models.simple_unet.Downsample), time embeddings (flaxdiff.models.simple_unet.FouriedEmbedding), attention (flaxdiff.models.simple_unet.AttentionBlock), and residual blocks (flaxdiff.models.simple_unet.ResidualBlock).
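As an illustration of what a Fourier/sinusoidal time-embedding layer computes, here is a framework-free sketch of the standard transformer-style embedding; the actual FlaxDiff layer may differ in details:

```python
import numpy as np

def sinusoidal_time_embedding(t, dim):
    # Transformer-style sinusoidal features for scalar timesteps:
    # geometrically spaced frequencies, then sin/cos of t * frequency.
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

emb = sinusoidal_time_embedding(np.array([0, 10, 500]), dim=64)
```

The UNet projects these features through small MLPs and injects them into each residual block so the network knows which noise level it is denoising.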
Installation
To install FlaxDiff, you need to have Python 3.10 or higher. Install the required dependencies using:
pip install -r requirements.txt
The models were trained and tested with jax==0.4.28 and flax==0.8.4. After updating to jax==0.4.30 and flax==0.8.5, the models stopped training correctly; some change appears to have broken the training dynamics, so I recommend sticking to the versions pinned in requirements.txt.
Getting Started
Training Example
Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEn
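The snippet above stops at the imports. Rather than guess at FlaxDiff's exact constructor signatures, here is a framework-free toy mirroring what a training run does end to end: noise a batch under a schedule, fit a (here, trivially linear) denoiser to predict the noise, and check the objective beats the untrained baseline:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha_bar = np.linspace(0.999, 0.01, 1000)  # any decreasing schedule works here

# "Dataset" and one noised batch at random timesteps
x0 = rng.standard_normal((512, 1)) * 3.0
t = rng.integers(0, alpha_bar.shape[0], size=512)
eps = rng.standard_normal((512, 1))
ab = alpha_bar[t][:, None]
x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps

# "Train" the simplest possible denoiser, eps_hat = w * x_t, by least squares
w = np.linalg.lstsq(x_t, eps, rcond=None)[0]
loss_trained = np.mean((x_t @ w - eps) ** 2)
loss_untrained = np.mean(eps ** 2)          # the w = 0 baseline
```

In a real run, DiffusionTrainer plays the role of this fitting step, with a scheduler such as EDMNoiseScheduler supplying the noise levels and the UNet replacing the linear model.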