# 🧠 Latent Diffusion Autoencoders (LDAE)
Official implementation of the paper "Latent Diffusion Autoencoders: Toward Efficient and Meaningful Unsupervised Representation Learning in Medical Imaging" (Medical Image Analysis, 2026).
<p align="center"> <img src="readme_files/LDAE.png" alt="LDAE Framework Overview"/> </p>
<p align="center"> <em>A novel unsupervised diffusion-based foundation model for representation learning in 3D medical imaging</em> </p>

## 📑 Table of Contents
- Overview
- Key Features
- Model Architecture
- Framework Capabilities
- Installation
- Data Preparation
- Usage
- Results
- Code Structure
- Use Cases
- Citation
- Acknowledgements
## 🔍 Overview
Latent Diffusion Autoencoders (LDAE) is a novel unsupervised framework for representation learning in 3D medical imaging. The method compresses 3D MRI scans using an AutoencoderKL, then applies a denoising diffusion process in the compressed latent space, guided by a semantic encoder. The framework is evaluated on Alzheimer's Disease (AD) brain scans from the ADNI dataset.
## 🌟 Key Features

- **First latent diffusion autoencoder framework for 3D medical images** - LDAE learns semantic representations without supervision, suitable for classification, manipulation, and interpolation of brain scans.
- **20x faster inference** than voxel-space diffusion autoencoders, while improving reconstruction fidelity.
- **Semantic control** - Enables manipulation of clinically meaningful features (e.g., Alzheimer's atrophy patterns), smooth interpolation, and scan synthesis from pure noise.
- **Excellent generalizability** - Representations are effective in downstream tasks such as AD classification (ROC-AUC: 89.48%) and age prediction (MAE: 4.16 years).
## 🏗 Model Architecture

LDAE consists of four main components:

- **AutoencoderKL** - Compresses 3D MRI scans (1×128×160×128) to a low-dimensional latent space (3×16×20×16).
- **Latent Diffusion Model (LDM)** - A DDPM trained on the compressed latents, modeling p(z) with a 3D UNet backbone.
- **Semantic Encoder (Enc_ϕ)** - A 2.5D network that processes slices with a 2D CNN and attention aggregation, producing a semantic vector y_sem ∈ ℝ^768.
- **Gradient Estimator (G_ψ)** - A conditional decoder approximating ∇_z log p(y|z), guiding the reverse diffusion.
The framework operates by first compressing the high-dimensional MRI scans into a more manageable latent space, where the diffusion process can operate efficiently. The semantic encoder extracts meaningful clinical features from the MRI scans, which are then used to guide the generative process. This approach allows for both efficient representation learning and semantic control over the generated outputs.
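To make the efficiency argument concrete, the compression factor implied by the shapes above can be computed directly (a quick sketch; the shapes are those quoted in the architecture list):

```python
import numpy as np

# Voxel-space shape of a preprocessed scan (C, D, H, W) and its compressed latent
scan_shape = (1, 128, 160, 128)
latent_shape = (3, 16, 20, 16)

scan_elems = int(np.prod(scan_shape))      # 2,621,440 voxels
latent_elems = int(np.prod(latent_shape))  # 15,360 latent values

# Each diffusion step in latent space touches ~170x fewer elements
compression_factor = scan_elems / latent_elems
print(f"{compression_factor:.1f}x fewer elements per diffusion step")
```

This is why the denoising UNet can afford far more steps (or a far smaller budget) than a voxel-space DAE operating on the full scan.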
## 🧠 Framework Capabilities

### 🔄 Semantic Manipulation
After training a linear classifier on the semantic embeddings y_sem, the weight vector w represents the direction of a clinical trait (e.g., Alzheimer's Disease vs Cognitively Normal).
- Extract the semantic vector from a brain scan with the semantic encoder, `y_sem = Enc_ϕ(x)`, and the stochastic latent vector, `z_T = DDIM_encode(z)`, where `z` is the compressed representation of the scan produced by the AutoencoderKL.
- Modify the scan's semantics: `y_manip = y_sem + α * w`
- Decode with `DDIM_sample(y_manip, z_T)` to simulate progression or regression of clinical traits.
- This enables fine-grained control over clinically relevant features (e.g., ventricle shrinkage, hippocampal recovery).
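The manipulation step itself is just a shift along the classifier direction and can be sketched in plain NumPy; `Enc_ϕ` and `DDIM_sample` are the pipeline components described above and are not reproduced here, and the direction `w` and shift `α` below are hypothetical stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for quantities the LDAE pipeline would produce (hypothetical values):
y_sem = rng.standard_normal(768)   # semantic code, as if from Enc_phi(x)
w = rng.standard_normal(768)       # weight vector of a linear AD-vs-CN classifier
w /= np.linalg.norm(w)             # use the unit direction of the clinical trait

def manipulate(y_sem, w, alpha):
    """Shift a semantic code along a clinical-trait direction."""
    return y_sem + alpha * w

y_ad = manipulate(y_sem, w, alpha=+2.0)   # push toward AD-like semantics
y_cn = manipulate(y_sem, w, alpha=-2.0)   # push toward CN-like semantics

# The shifted codes would then be decoded with DDIM_sample(y_manip, z_T).
```

Because `w` is normalized, the projection of `y_ad - y_sem` onto `w` is exactly `α`, which makes the manipulation strength directly interpretable.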
This capability can be used for counterfactual explanations, showing how a brain would appear with or without disease-related features:
<p align="center"> <img src="readme_files/manipulation.png" alt="Semantic Manipulation for Counterfactual Explanations"/> </p>

### 🎲 Stochastic Variation
With a fixed semantic code y_sem, we can sample multiple noise vectors z_T ~ N(0, I) to generate variations of a brain scan that share the same semantic content; this helps visualize what the semantic code captures:

```
x_hat = DDIM_sample(y_sem, z_T)
```
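A minimal sketch of this sampling loop (the semantic code and the decoder call are stand-ins, not the repository's API):

```python
import numpy as np

rng = np.random.default_rng(42)
latent_shape = (3, 16, 20, 16)   # shape of the AutoencoderKL latent

y_sem = rng.standard_normal(768)  # one fixed semantic code (hypothetical)

# Draw several stochastic latents; in the real pipeline each would be decoded
# with DDIM_sample(y_sem, z_T_i) into a distinct variation of the same scan.
variations = [rng.standard_normal(latent_shape) for _ in range(4)]
print(len(variations), variations[0].shape)
```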
Each sample maintains the same core anatomical structure but varies in low-level details:
<p align="center"> <img src="readme_files/stochastic_variations.png" alt="Stochastic Variations with Fixed Semantics"/> </p>

### 🔄 Interpolation
Given two brain scans A and B, LDAE can interpolate between them in both semantic and stochastic spaces:
```
y_t = (1 - t) * y1 + t * y2   # Linear interpolation (LERP) in semantic space
z_t = SLERP(z1, z2, t)        # Spherical interpolation (SLERP) in latent space
```

These are decoded via `x_hat = DDIM_sample(y_t, z_t)`. This approach is especially useful for simulating disease progression or creating smooth transitions between time points.
Interpolation enables visualization of brain evolution over time and can help simulate missing timepoints in longitudinal studies.
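The two interpolation operators can be written as a self-contained sketch, assuming generic NumPy arrays for the semantic codes and stochastic latents (the decoding call itself is not reproduced here):

```python
import numpy as np

def lerp(a, b, t):
    """Linear interpolation, used for semantic codes y_sem."""
    return (1.0 - t) * a + t * b

def slerp(a, b, t, eps=1e-8):
    """Spherical linear interpolation, used for stochastic latents z_T."""
    a_flat, b_flat = a.ravel(), b.ravel()
    cos_theta = np.dot(a_flat, b_flat) / (
        np.linalg.norm(a_flat) * np.linalg.norm(b_flat) + eps
    )
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    if theta < eps:                       # nearly parallel: fall back to LERP
        return lerp(a, b, t)
    s = np.sin(theta)
    return (np.sin((1.0 - t) * theta) / s) * a + (np.sin(t * theta) / s) * b

rng = np.random.default_rng(0)
y1, y2 = rng.standard_normal(768), rng.standard_normal(768)
z1, z2 = rng.standard_normal((3, 16, 20, 16)), rng.standard_normal((3, 16, 20, 16))

y_mid = lerp(y1, y2, 0.5)   # halfway point in semantic space
z_mid = slerp(z1, z2, 0.5)  # halfway point in stochastic space
# In the real pipeline: x_hat = DDIM_sample(y_mid, z_mid)
```

SLERP is preferred for the noise latents because it keeps intermediate points at a norm consistent with the Gaussian prior, whereas LERP would pass through atypically low-norm regions.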
## 🔧 Installation

### Option 1: Using Docker

```bash
# Build the Docker image
docker build -t ldae .

# Run the container with appropriate volume mounts
docker run -it --gpus all -v /path/to/data:/data -v /path/to/project:/app/LDAE ldae
```
### Option 2: Using Anaconda (Recommended)

```bash
# Create and activate a new conda environment
conda create -n ldae python=3.11
conda activate ldae

# Install dependencies
pip install -r requirements.txt
```
### Pretrained Models
Download the following pretrained models and place them in the models/ directory to run the tutorial:
## 📊 Data Preparation
For data preparation, please refer to the AXIAL repository's data preparation section. While the AXIAL project focused on a specific ADNI dataset collection, this project uses the full ADNI dataset of T1-weighted MRI scans. However, the preparation steps remain the same:
- Download T1-weighted MRI scans from the ADNI database
- Convert the data to BIDS format using Clinica
- Preprocess the scans for bias field correction, brain extraction and registration
- Create a CSV file with paths to preprocessed images and corresponding labels
Note that access to the ADNI dataset requires proper credentials and approval.
### CSV Format

The datamodule (`brain_mr_dm.py`) expects a CSV file with the following columns:

- `subject`: Subject identifier
- `path`: Path to the preprocessed MRI scan
- `diagnosis`: Diagnosis label (e.g., AD, CN, MCI)
- `session`: Session number
- `latent_path`: Path to the compressed latent representation (generated after training the AutoencoderKL)
- `age`: Subject age at the time of the scan
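As an illustration, a CSV with this layout could be assembled with pandas; every path and value below is a hypothetical placeholder, not data from the repository:

```python
import pandas as pd

# Minimal example of the CSV layout expected by brain_mr_dm.py
# (subject ID, paths, and age are hypothetical placeholders)
df = pd.DataFrame([
    {
        "subject": "sub-EXAMPLE001",
        "path": "/data/preprocessed/sub-EXAMPLE001_ses-M00_T1w.nii.gz",
        "diagnosis": "CN",
        "session": "ses-M00",
        "latent_path": "/data/latents/sub-EXAMPLE001_ses-M00.npz",
        "age": 72.4,
    },
])
df.to_csv("dataset.csv", index=False)
print(list(df.columns))
```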
### Generating Compressed Latents

After training the AutoencoderKL (the first stage of the pipeline), you need to generate compressed latent representations for all your MRI scans:

```bash
python scripts/save_compressed_latent.py --csv_path /path/to/dataset.csv --ae_model_path /path/to/autoencoderkl.pth
```
This script will:
- Load the trained AutoencoderKL model
- Process each MRI scan to generate its compressed latent representation
- Save the latents as NPZ files
- Create a new CSV file (`*_with_compressed_latent.csv`) that includes paths to these latent files
Use this updated CSV for subsequent training stages.
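Schematically, the per-scan loop boils down to encode-and-save; here the encoder is stubbed with random values, since the real script loads the trained AutoencoderKL and its actual interfaces may differ:

```python
import numpy as np

def encode_stub(volume):
    """Stand-in for the AutoencoderKL encoder; the real model maps a
    (1, 128, 160, 128) scan to a (3, 16, 20, 16) latent."""
    rng = np.random.default_rng(0)
    return rng.standard_normal((3, 16, 20, 16)).astype(np.float32)

# Per-scan loop, schematically:
volume = np.zeros((1, 128, 160, 128), dtype=np.float32)  # placeholder MRI scan
latent = encode_stub(volume)
np.savez_compressed("latent_example.npz", latent=latent)  # saved as NPZ

loaded = np.load("latent_example.npz")["latent"]
print(loaded.shape)  # (3, 16, 20, 16)
```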
## 🚀 Usage

### Complete Tutorial
A comprehensive tutorial is available in notebooks/tutorial.ipynb, which demonstrates the full capabilities of LDAE:
- Loading and preprocessing brain MRI scans
- Compressing and reconstructing with AutoencoderKL
- Generating stochastic variations with LDAE
- Reconstructing with LDAE
- Classifying scans and estimating age with the semantic encoder
- Generating intermediate scans through interpolation (LERP in semantic space, SLERP in stochastic space)
- Manipulating brain scans in Alzheimer's Disease (AD) or Cognitively Normal (CN) direction
We recommend following this tutorial to understand the full workflow of LDAE.
### Training Pipeline

1. Train the AutoencoderKL:

```bash
python run_cli.py fit --config config/aekl.yaml
```

2. Pre-train the Latent Diffusion Model:

```bash
python run_cli.py fit --config config/ldae_pretrain.yaml
```

3. Train the Representation Learning Component:

```bash
python run_cli.py fit --config config/ldae_repr_learn.yaml
```
### Downstream Tasks Evaluation
Once you have trained the representation learning component (which includes the semantic encoder trained to fill the posterior mean gap), you can evaluate the learned representations on downstream tasks:
