CaloDiffusion
Diffusion models for calorimeter simulation
CaloDiffusion unofficial repository (WIP) - 2.0
Implemented with PyTorch 2.0. Dependencies are listed in pyproject.toml.
Install with
git clone https://github.com/[fork]/CaloDiffusion.git
cd CaloDiffusion
pip install -e .
Installation can be tested with
pip install pytest pytest-dependency
python3 -m pytest tests/test_execution.py
The location of the data used for testing can be set with the --data-dir option. The HGCalShowers and CaloChallenge directories can be set separately with the --hgcalshowers and --calochallenge options, respectively.
The HGCal tests run on a mocked dataset drawn from random distributions. They can be run on their own with python3 -m pytest tests/test_execution.py -m "hgcal", or excluded with -m "not hgcal".
Data
Results are presented using the Fast Calorimeter Data Challenge datasets, which are available for download on Zenodo.
Run the training scripts with
calodif-train \
-d DATA-DIR \
-c CONFIG \
--checkpoint SAVE-DIR \
MODEL-TYPE
- Example configs: config_dataset1.json, config_dataset2.json, config_dataset3.json
- Additional options can be seen with
calodif-train --help
Sampling with the learned model
calodif-inference \
--n-events N \
-c CONFIG \
sample MODEL-TYPE
- Additional options can be found with
calodif-inference --help
Creating the plots shown in the paper
calodif-inference \
--n-events N \
-c CONFIG \
plot \
--generated RESULTS-H5F
Repository Structure and Contributing
This repository is broken into 4 main parts:
1. Scripts
calodiffusion/inference.py and calodiffusion/train.py provide CLI-based inference and training. Consider them the client for the rest of the repository.
Functionality can be seen using --help menus.
2. Train
The base Train class in calodiffusion/train/train.py is an abstract class that loads all the functions and data needed to train a model.
It also contains saving methods.
It is a "driver" class, only providing minimal instructions on how to iterate through batches of data during training or inference.
Subclasses of Train must implement two functions: init_model and training_loop.
init_model returns an initialized Diffusion object (see calodiffusion/models/diffusion.py) to be used in training, and training_loop defines how a single batch of data is processed.
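The driver/subclass split described above can be sketched as follows. This is a minimal illustration, not the repository's code: the constructor signature, the train method, ToyModel, and the "weight" config key are all hypothetical, and a plain Python object stands in for the Diffusion object so the sketch runs without PyTorch.

```python
from abc import ABC, abstractmethod


class Train(ABC):
    """Hypothetical stand-in for the abstract driver class (not the real one)."""

    def __init__(self, config):
        self.config = config
        self.model = self.init_model()

    @abstractmethod
    def init_model(self):
        """Return the initialized model object used in training."""

    @abstractmethod
    def training_loop(self, batch):
        """Process a single batch and return its loss."""

    def train(self, batches):
        # The driver only iterates over batches; per-batch logic lives in the subclass.
        return [self.training_loop(batch) for batch in batches]


class ToyModel:
    """Placeholder for a Diffusion object: scales its input by `weight`."""

    def __init__(self, weight):
        self.weight = weight

    def predict(self, x):
        return self.weight * x


class TrainToy(Train):
    def init_model(self):
        # "weight" is an assumed config key used only for this sketch.
        return ToyModel(weight=self.config["weight"])

    def training_loop(self, batch):
        # Mean squared error between prediction and target over one batch.
        errors = [(self.model.predict(x) - y) ** 2 for x, y in batch]
        return sum(errors) / len(errors)


trainer = TrainToy({"weight": 2.0})
losses = trainer.train([[(1.0, 2.0), (2.0, 4.0)], [(1.0, 3.0)]])
print(losses)
```

Because Train is abstract, instantiating it directly raises a TypeError; only subclasses that define both functions can be constructed.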
calodiffusion/train/evaluation.py contains extra metrics to quantify the success of training.
3. Models
Diffusion Models
The base model class in calodiffusion is Diffusion, defined in calodiffusion/models/diffusion.py.
This abstract class contains methods for loading a specific sampler for inference, selecting the loss, and performing the training and inference forward passes.
It can also define how trained .pt weights are loaded, for models made up of several separately stored components.
Diffusion is meant to hold one or more torch.nn.Module attributes, not to be a subclass of torch.nn.Module itself.
A subclass of Diffusion has 4 required functions:
- init_model - Provide a torch.nn.Module object to be assigned to self.model.
- forward - Define how that model takes data and provides a prediction. Can be as simple as calling model.forward(). Called during training.
- __call__ - Define how denoising is done for a specific model. Called during inference.
- noise_generation - Generate noise for each inference step. The base class provides a generic "default_noise" option, but each subclass must explicitly confirm that it uses this default.
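A subclass with the four required functions might look like the sketch below. The method signatures, the ScaleNet placeholder, and the "scale" config key are assumptions made for illustration. A plain Python object stands in for the torch.nn.Module so the example runs without PyTorch.

```python
import random
from abc import ABC, abstractmethod


class Diffusion(ABC):
    """Hypothetical stand-in for the abstract base class; signatures are assumed."""

    def __init__(self, config):
        self.config = config
        # The network is held as an attribute; Diffusion itself is not a network.
        self.model = self.init_model()

    @abstractmethod
    def init_model(self): ...

    @abstractmethod
    def forward(self, x, t): ...

    @abstractmethod
    def __call__(self, x, t): ...

    @abstractmethod
    def noise_generation(self, n_values): ...

    def default_noise(self, n_values):
        # Generic option: independent standard-normal samples.
        return [random.gauss(0.0, 1.0) for _ in range(n_values)]


class ScaleNet:
    """Plain-Python placeholder for a torch.nn.Module."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x, t):
        return [self.scale * t * v for v in x]


class ToyDiffusion(Diffusion):
    def init_model(self):
        return ScaleNet(self.config["scale"])

    def forward(self, x, t):
        # Training pass: here it simply delegates to the wrapped network.
        return self.model.forward(x, t)

    def __call__(self, x, t):
        # Inference pass: subtract the predicted noise from the input.
        predicted = self.forward(x, t)
        return [v - p for v, p in zip(x, predicted)]

    def noise_generation(self, n_values):
        # Explicitly opt in to the generic default.
        return self.default_noise(n_values)


random.seed(0)
model = ToyDiffusion({"scale": 0.5})
x = model.noise_generation(4)
denoised = model(x, t=1.0)
```

Keeping the network as an attribute rather than inheriting from it lets one Diffusion object coordinate several networks, which matches the composition approach described above.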
Samplers
Samplers are functions used in the denoising process to condition the input at each step.
Additional settings for each sampler can be set using the SAMPLER_OPTIONS entry of the configuration.
The selected sampler is used in diffusion.sample.
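One way per-sampler settings could be forwarded is as keyword arguments to the chosen sampler function, sketched below. Only the SAMPLER_OPTIONS key comes from the text above. The "SAMPLER" key, the sampler names, their parameters, and the sample function are hypothetical and do not describe the repository's actual samplers.

```python
import json


def plain_step(x, predicted_noise, step_size=1.0):
    # Remove a fraction of the predicted noise from each value.
    return [v - step_size * n for v, n in zip(x, predicted_noise)]


def damped_step(x, predicted_noise, step_size=1.0, damping=0.5):
    # Same update, scaled down by an extra damping factor.
    return [v - damping * step_size * n for v, n in zip(x, predicted_noise)]


# Hypothetical registry mapping config names to sampler functions.
SAMPLERS = {"plain": plain_step, "damped": damped_step}

config = json.loads(
    '{"SAMPLER": "damped", "SAMPLER_OPTIONS": {"step_size": 0.5, "damping": 0.5}}'
)


def sample(x, predict_noise, n_steps, config):
    sampler = SAMPLERS[config["SAMPLER"]]
    options = config.get("SAMPLER_OPTIONS", {})
    for _ in range(n_steps):
        # Sampler-specific settings arrive as keyword arguments.
        x = sampler(x, predict_noise(x), **options)
    return x


# Toy noise predictor that treats the whole value as noise.
result = sample([4.0, 8.0], lambda x: list(x), n_steps=2, config=config)
print(result)
```

With this layout, adding a sampler only requires registering a new function, and its options stay in the configuration rather than in code.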
Loss
Loss is calculated in two stages: the loss metric, then the loss calculation. Metrics can be mixed and matched with calculations, and both are set in the config.json file. The metric defines how the prediction is processed so it can be compared with ground truth values, and the calculation defines how the two are compared numerically (L1 loss, MSE, etc.).
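The two-stage split can be sketched as two independent lookup tables, so any metric can be paired with any calculation. The config keys ("metric", "calculation"), the metric names, and the function signatures here are illustrative assumptions, not the repository's actual options.

```python
def noise_prediction(prediction, noise, data):
    # Metric: compare the prediction against the noise that was added.
    return prediction, noise


def data_prediction(prediction, noise, data):
    # Metric: compare the prediction against the clean data.
    return prediction, data


def l1(pred, target):
    # Calculation: mean absolute error.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def mse(pred, target):
    # Calculation: mean squared error.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


METRICS = {"noise_pred": noise_prediction, "data_pred": data_prediction}
CALCULATIONS = {"l1": l1, "mse": mse}


def compute_loss(config, prediction, noise, data):
    # Stage 1: the metric selects what the prediction is compared with.
    pred, target = METRICS[config["metric"]](prediction, noise, data)
    # Stage 2: the calculation turns the pair into a number.
    return CALCULATIONS[config["calculation"]](pred, target)


prediction, noise, data = [1.0, 3.0], [0.0, 1.0], [1.0, 1.0]
loss_a = compute_loss({"metric": "noise_pred", "calculation": "l1"}, prediction, noise, data)
loss_b = compute_loss({"metric": "data_pred", "calculation": "mse"}, prediction, noise, data)
print(loss_a, loss_b)
```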
4. Utils
A catch-all category for small utility functions used across training, inference, and evaluation.
