CUTS
[MICCAI 2024] CUTS: A Deep Learning and Topological Framework for Multigranular Unsupervised Medical Image Segmentation
Krishnaswamy Lab, Yale University
This is the authors' PyTorch implementation of CUTS, MICCAI 2024.
The official version is maintained in the Lab GitHub repo.
A Glimpse into the Methods
<img src = "assets/architecture.png" width=800>
Citation
@inproceedings{Liu_CUTS_MICCAI2024,
title = { { CUTS: A Deep Learning and Topological Framework for Multigranular Unsupervised Medical Image Segmentation } },
author = { Liu, Chen and Amodio, Matthew and Shen, Liangbo L. and Gao, Feng and Avesta, Arman and Aneja, Sanjay and Wang, Jay C. and Del Priore, Lucian V. and Krishnaswamy, Smita},
booktitle = {Proceedings of Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
publisher = {Springer Nature Switzerland},
volume = {LNCS 15008},
pages = {155--165},
year = {2024},
month = {October},
}
Repository Hierarchy
UnsupervisedMedicalSeg (CUTS)
├── (*) comparison: other SOTA unsupervised methods for comparison.
|
├── checkpoints: model weights are saved here.
├── config: configuration yaml files.
├── data: folders containing data files.
├── logs: training log files.
├── results: generated results (images, labels, segmentations, figures, etc.).
|
└── src
├── (*) scripts_analysis: scripts for analysis and plotting.
| ├── `generate_baselines.py`
| ├── `generate_kmeans.py`
| ├── `generate_diffusion.py`
| ├── `plot_paper_figure_main.py`
| └── `run_metrics.py`
|
├── (*) `main.py`: unsupervised training of the CUTS encoder.
├── (*) `main_supervised.py`: supervised training of UNet/nnUNet for comparison.
|
├── datasets: defines how to access and process the data in `CUTS/data/`.
├── data_utils
├── model
└── utils
Core files and folders are marked with (*).
Data Provided
As background: I inherited these datasets from a lab member who had since graduated, and they had already been preprocessed by the time I received them. For reproducibility, the berkeley_natural_images, brain_tumor and retina datasets are included as zip files in the `data/` directory. The brain_ventricles dataset exceeds the GitHub size limit and is available on Google Drive instead.
Please note that these datasets have relatively small sample sizes. If you need more samples, consider larger datasets such as those from the BraTS challenge.
The complete set of data can also be downloaded from Hugging Face.
To reproduce the results in the paper.
The following commands use `retina_seed2` as an example (the retina dataset, with the random seed set to 2022).
<details>
<summary>Unzip data</summary>
cd ./data/
unzip retina.zip
</details>
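If `unzip` is not available (for example on some Windows setups), the same extraction can be done with Python's standard-library `zipfile` module. This is a minimal sketch, assuming it is run from the repository root with the archives in `data/`; the helper name `extract_datasets` is hypothetical, and archive names other than `retina.zip` (e.g. `brain_tumor.zip`, `berkeley_natural_images.zip`) are assumptions based on the dataset names above.

```python
import zipfile
from pathlib import Path


def extract_datasets(data_dir, names):
    """Extract <name>.zip archives found in data_dir into data_dir.

    Roughly equivalent to running `unzip <name>.zip` inside data_dir for each name.
    Returns the file names of the archives that were actually extracted.
    """
    data_dir = Path(data_dir)
    extracted = []
    for name in names:
        archive = data_dir / f"{name}.zip"
        if not archive.exists():
            # Skip missing archives (e.g. brain_ventricles, which is hosted on Google Drive).
            print(f"Skipping {archive}: not found")
            continue
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(data_dir)
        extracted.append(archive.name)
    return extracted
```

For example, `extract_datasets("data", ["retina"])` mirrors the two shell commands above.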
<details>
<summary>Activate environment</summary>
conda activate cuts
</details>
<details>
<summary><b>Stage 1.</b> Training the convolutional encoder</summary>
To train a model.
## Under `src`
python main.py --mode train --config ../config/retina_seed2.yaml
To test a model (testing also runs automatically in train mode).
## Under `src`
python main.py --mode test --config ../config/retina_seed2.yaml
</details>
<details>
<summary>(Optional) [Comparison] Training a supervised model</summary>
## Under `src/`
python main_supervised.py --mode train --config ../config/retina_seed2.yaml
</details>
<details>
<summary>(Optional) [Comparison] Training other models</summary>
To train STEGO.
## Under `comparison/STEGO/CUTS_scripts/`
python step01_prepare_data.py --config ../../../config/retina_seed2.yaml
python step02_precompute_knns.py --train-config ./train_config/train_config_retina_seed2.yaml
python step03_train_segmentation.py --train-config ./train_config/train_config_retina_seed2.yaml
python step04_produce_results.py --config ../../../config/retina_seed2.yaml --eval-config ./eval_config/eval_config_retina_seed2.yaml
To train Differentiable Feature Clustering (DFC).
## Under `comparison/DFC/CUTS_scripts/`
python step01_produce_results.py --config ../../../config/retina_seed2.yaml
To use the Segment Anything Model (SAM).
## Under `comparison/SAM/`
mkdir SAM_checkpoint && cd SAM_checkpoint
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
## Under `comparison/SAM/CUTS_scripts/`
python step01_produce_results.py --config ../../../config/retina_seed2.yaml
To use MedSAM.
## Under `comparison/MedSAM/`
mkdir MedSAM_checkpoint && cd MedSAM_checkpoint
## Download the checkpoint from https://drive.google.com/file/d/1ARiB5RkSsWmAB_8mqWnwDF8ZKTtFwsjl/view
## Under `comparison/MedSAM/CUTS_scripts/`
python step01_produce_results.py --config ../../../config/retina_seed2.yaml
To use SAM-Med2D.
## Under `comparison/SAM_Med2D/`
mkdir SAM_Med2D_checkpoint && cd SAM_Med2D_checkpoint
## Download the checkpoint from https://drive.google.com/file/d/1ARiB5RkSsWmAB_8mqWnwDF8ZKTtFwsjl/view
## Under `comparison/SAM_Med2D/CUTS_scripts/`
python step01_produce_results.py --config ../../../config/retina_seed2.yaml
</details>
<details>
<summary><b>Stage 2.</b> Results Generation</summary>
To generate and save the segmentation using spectral k-means.
## Under `src/scripts_analysis`
python generate_kmeans.py --config ../../config/retina_seed2.yaml
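For readers unfamiliar with the technique, the following is a generic NumPy sketch of spectral k-means (normalized spectral clustering in the style of Ng, Jordan and Weiss) applied to an `(N, D)` array of per-pixel embeddings. It is not the code in `generate_kmeans.py`: the function name, the Gaussian bandwidth `sigma` and the farthest-first initialization are illustrative assumptions. A dense `N × N` affinity is only practical for small `N`, such as a downsampled image.

```python
import numpy as np


def spectral_kmeans(features, n_clusters, sigma=1.0, n_iter=100):
    """Cluster rows of `features` (N, D) into `n_clusters` groups via spectral clustering."""
    x = np.asarray(features, dtype=float)
    n = x.shape[0]

    # Gaussian affinity between all pairs of embeddings (dense, O(N^2) memory).
    sq_norms = np.sum(x ** 2, axis=1)
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T, 0.0)
    affinity = np.exp(-sq_dists / (2.0 * sigma ** 2))
    np.fill_diagonal(affinity, 0.0)

    # Symmetric normalized Laplacian: L = I - D^{-1/2} W D^{-1/2}.
    degree = np.maximum(affinity.sum(axis=1), 1e-12)
    d_inv_sqrt = 1.0 / np.sqrt(degree)
    laplacian = np.eye(n) - d_inv_sqrt[:, None] * affinity * d_inv_sqrt[None, :]

    # Spectral embedding: eigenvectors of the smallest eigenvalues (eigh sorts ascending),
    # with each row rescaled to unit length.
    _, eigvecs = np.linalg.eigh(laplacian)
    emb = eigvecs[:, :n_clusters]
    emb = emb / np.maximum(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12)

    # Farthest-first initialization of the k-means centers.
    centers = [emb[0]]
    for _ in range(1, n_clusters):
        dist_to_centers = np.min([np.sum((emb - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(emb[np.argmax(dist_to_centers)])
    centers = np.array(centers)

    # Plain Lloyd iterations (k-means) on the spectral embedding.
    labels = np.zeros(n, dtype=int)
    for _ in range(n_iter):
        dists = np.sum((emb[:, None, :] - centers[None, :, :]) ** 2, axis=2)
        labels = np.argmin(dists, axis=1)
        new_centers = np.array([
            emb[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels
```

Clustering in the eigenvector space rather than on raw embeddings lets k-means separate groups that are connected by the affinity graph but not necessarily compact in the original feature space.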
To generate and save the segmentation using diffusion condensation.
## Under `src/scripts_analysis`
python generate_diffusion.py --config ../../config/retina_seed2.yaml
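Diffusion condensation yields a coarse-graining trajectory rather than a single partition, which is what makes the output multigranular. The sketch below illustrates the general idea under simplifying assumptions and is not `generate_diffusion.py`: the bandwidth `epsilon`, the merge threshold and the function names are illustrative, and published variants adapt the bandwidth as points condense, which is reduced here to an optional constant factor `eps_growth`. At each step every point moves to a weighted average given by a row-stochastic Gaussian diffusion operator, and points that have collapsed to within `merge_threshold` of each other are reported as one cluster.

```python
import numpy as np


def _connected_components(adjacency):
    """Label connected components of a boolean adjacency matrix (0, 1, 2, ...)."""
    n = adjacency.shape[0]
    labels = np.full(n, -1, dtype=int)
    current = 0
    for start in range(n):
        if labels[start] >= 0:
            continue
        labels[start] = current
        stack = [start]
        while stack:
            i = stack.pop()
            # Unlabelled neighbours of i join the current component.
            for j in np.flatnonzero(adjacency[i] & (labels < 0)):
                labels[j] = current
                stack.append(j)
        current += 1
    return labels


def diffusion_condensation(points, epsilon=1.0, eps_growth=1.0,
                           merge_threshold=1e-3, max_iter=100):
    """Return a list of label arrays, one per distinct granularity level (finest first)."""
    x = np.asarray(points, dtype=float).copy()
    history = [np.arange(x.shape[0])]  # level 0: every point is its own cluster
    for _ in range(max_iter):
        # Row-stochastic diffusion operator P = D^{-1} K with a Gaussian kernel K.
        sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=2)
        kernel = np.exp(-sq_dists / epsilon)
        diffusion_op = kernel / kernel.sum(axis=1, keepdims=True)

        # Condensation step: each point moves to its diffusion-weighted average.
        x = diffusion_op @ x

        # Points that have (numerically) collapsed together form one cluster.
        sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=2)
        labels = _connected_components(sq_dists < merge_threshold ** 2)
        if not np.array_equal(labels, history[-1]):
            history.append(labels)  # record only levels where the partition changes
        if labels.max() == 0:
            break  # everything has merged into a single cluster
        epsilon *= eps_growth  # eps_growth > 1 widens the kernel over time
    return history
```

Each entry of the returned list is one granularity level, from one cluster per point down to the coarsest partition reached.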
To generate and save the segmentation using baseline methods.
## Under `src/scripts_analysis`
python generate_baselines.py --config ../../config/retina_seed2.yaml
</details>
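Because the three Stage 2 commands share the same `--config` argument, they can be chained from a small driver. A minimal standard-library sketch, assuming the driver itself runs inside the `cuts` environment (it reuses `sys.executable`) and that `repo_root` points at the repository root; the helper names are hypothetical, and the scripts run in the order listed above.

```python
import subprocess
import sys
from pathlib import Path

STAGE2_SCRIPTS = ["generate_kmeans.py", "generate_diffusion.py", "generate_baselines.py"]


def stage2_commands(config_name):
    """Build the Stage 2 commands for a config name such as 'retina_seed2'."""
    config = f"../../config/{config_name}.yaml"  # relative to src/scripts_analysis
    return [[sys.executable, script, "--config", config] for script in STAGE2_SCRIPTS]


def run_stage2(config_name, repo_root="."):
    """Run each Stage 2 script from src/scripts_analysis, stopping on the first failure."""
    workdir = Path(repo_root) / "src" / "scripts_analysis"
    for cmd in stage2_commands(config_name):
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, cwd=workdir, check=True)
```

For example, `run_stage2("retina_seed2")` issues the same three commands as above.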
<details>
<summary>Results Plotting</summary>
To reproduce the figures in the paper.
A single script, `plot_paper_figure_main.py`, handles this (there were previously two, which have since been merged).
The `--image-idx` argument takes one or more space-separated indices of the images to plot.
Without the `--comparison` flag, only the CUTS results are plotted.
With the `--comparison` flag, a side-by-side comparison against other methods is plotted.
With the `--grayscale` flag, the input and reconstructed images are plotted in grayscale.
With the `--binary` flag, the labels are binarized using the consistent method described in the paper.
With the `--separate` flag, the labels are displayed as separate masks; otherwise they are overlaid. This flag is automatically turned on (and cannot be turned off) for multi-class segmentation.
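To make the flag semantics concrete, here is a hypothetical `argparse` parser mirroring the options described above. It is not taken from `plot_paper_figure_main.py`: `is_multiclass` is an assumed input standing in for however the script detects multi-class datasets such as berkeley, and the `type=int` and `required` settings are guesses.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Sketch of the plotting options")
    parser.add_argument("--config", required=True, help="path to a config yaml")
    # nargs="+" accepts one or more space-separated indices: --image-idx 8 22 89
    parser.add_argument("--image-idx", type=int, nargs="+", required=True)
    parser.add_argument("--comparison", action="store_true",
                        help="plot a side-by-side comparison against other methods")
    parser.add_argument("--grayscale", action="store_true",
                        help="plot input and reconstructed images in grayscale")
    parser.add_argument("--binary", action="store_true", help="binarize the labels")
    parser.add_argument("--separate", action="store_true",
                        help="show labels as separate masks instead of overlays")
    return parser


def parse_plot_args(argv, is_multiclass):
    args = build_parser().parse_args(argv)
    if is_multiclass:
        args.separate = True  # always on for multi-class segmentation
    return args
```

Note that `argparse` stores `--image-idx` as `args.image_idx`.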
## Under `src/scripts_analysis`
## For natural images (berkeley), multi-class segmentation.
### Diffusion condensation trajectory.
python plot_paper_figure_main.py --config ../../config/berkeley_seed2.yaml --image-idx 8 22 89
### Segmentation comparison.
python plot_paper_figure_main.py --config ../../config/berkeley_seed2.yaml --image-idx 8 22 89 --comparison --separate
## For medical images with color (retina), binary segmentation.
### Diffusion condensation trajectory.
python plot_paper_figure_main.py --config ../../config/retina_seed2.yaml --image-idx 4 7 18
### Segmentation comparison (overlay).
python plot_paper_figure_main.py --config ../../config/retina_seed2.yaml --image-idx 4 7 18 --comparison --binary
### Segmentation comparison (non-overlay).
python plot_paper_figure_main.py --config ../../config/retina_seed2.yaml --image-idx 4 7 18 --comparison --binary --separate
## For medical images without color (brain ventricles, brain tumor), binary segmentation.
### Diffusion condensation trajectory.
python plot_paper_figure_main.py --config ../../config/brain_ventricles_seed2.yaml --image-idx 35 41 88 --grayscale
### Segmentation comparison (overlay).
python plot_paper_figure_main.py --config ../../config/brain_ventricles_seed2.yaml --image-idx 35 41 88 --grayscale --comparison --binary
### Segmentation comparison (non-overlay).
python plot_paper_figure_main.py --config ../../config/brain_ventricles_seed2.yaml --image-idx 35 41 88 --grayscale --comparison --binary --separate
### Diffusion condensation trajectory.
python plot_paper_figure_main.py --config ../../config/brain_tumor_seed2.yaml --image-idx 1 25 31 --grayscale
### Segmentation comparison (overlay).
python plot_paper_
