# CWMI
The PyTorch implementation of the CWMI loss proposed in the paper *Complex Wavelet Mutual Information Loss: A Multi-Scale Loss Function for Semantic Segmentation*.
## Abstract
Recent advancements in deep neural networks have significantly enhanced the performance of semantic segmentation. However, class imbalance and instance imbalance remain persistent challenges, particularly in biomedical image analysis, where smaller instances and thin boundaries are often overshadowed by larger structures. To address the multiscale nature of segmented objects, various models have incorporated mechanisms such as spatial attention and feature pyramid networks. Despite these advancements, most loss functions are still primarily pixel-wise, while regional and boundary-focused loss functions often incur high computational costs or are restricted to small-scale regions. To address this limitation, we propose complex wavelet mutual information (CWMI) loss, a novel loss function that leverages mutual information from subband images decomposed by a complex steerable pyramid. Unlike discrete wavelet transforms, the complex steerable pyramid captures features across multiple orientations, remains robust to small translations, and preserves structural similarity across scales. Moreover, mutual information is well-suited for capturing high-dimensional directional features and exhibits greater noise robustness compared to prior wavelet-based loss functions that rely on distance or angle metrics. Extensive experiments on diverse segmentation datasets demonstrate that CWMI loss achieves significant improvements in both pixel-wise accuracy and topological metrics compared to state-of-the-art methods, while introducing minimal computational overhead.
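The core idea can be sketched in a few lines. The snippet below is a conceptual illustration only, not the repository's implementation: it stands in simple horizontal/vertical difference filters for the complex steerable pyramid subbands, and estimates mutual information under a joint-Gaussian assumption, where I(X; Y) = -0.5 · log(1 - ρ²). All function names here are hypothetical.

```python
import numpy as np

def gaussian_mi(x, y, eps=1e-8):
    """Mutual information between two signals under a joint-Gaussian
    assumption: I(X; Y) = -0.5 * log(1 - rho^2)."""
    x = x.ravel() - x.mean()
    y = y.ravel() - y.mean()
    rho = (x @ y) / (np.linalg.norm(x) * np.linalg.norm(y) + eps)
    rho2 = min(rho * rho, 1.0 - eps)
    return -0.5 * np.log(1.0 - rho2)

def oriented_subbands(img):
    """Stand-in for the complex steerable pyramid: horizontal and
    vertical first-difference subbands (a deliberate simplification)."""
    return [np.diff(img, axis=0), np.diff(img, axis=1)]

def cwmi_like_loss(pred, target):
    """Negative summed subband MI: lower when pred and target share
    structure in every orientation subband."""
    return -sum(gaussian_mi(p, t)
                for p, t in zip(oriented_subbands(pred),
                                oriented_subbands(target)))
```

In this sketch the loss is most negative when prediction and ground truth share oriented structure at every subband; the actual CWMI loss operates on complex steerable pyramid coefficients across multiple scales and orientations.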
<p align = "center"> <img src="figures/Figure 1.PNG"> </p>

## Environment
Ensure you have the following dependencies installed:

- Python 3.12.8
- PyTorch 2.5.1+cu121

Additional dependencies are listed in `requirements.txt`. Install them using:

```
pip install -r requirements.txt
```
## Dataset Preparation
Download the datasets and place them in the following directories:
| Dataset | Download Link | Expected Directory |
|--------------|--------------|--------------------|
| SNEMI3D | Zenodo | ./data/snemi3d/ |
| GlaS | Kaggle | ./data/GlaS/Warwick_QU_Dataset/ |
| DRIVE | Kaggle | ./data/DRIVE/Drive_source/ |
| MASS ROAD | Kaggle | ./data/mass_road/mass_road_source/ |
## Usage

### 1. Prepare Datasets

To preprocess the datasets and calculate the mean and standard deviation for each dataset, run:

```
python ./data/dataprepare.py
```
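As a rough illustration of the statistics step (the actual `dataprepare.py` may differ), a single pass over a dataset can accumulate sums and squared sums to derive the mean and standard deviation used to normalize inputs at train time:

```python
import numpy as np

def dataset_mean_std(images):
    """Hypothetical sketch: images is an iterable of 2D float arrays
    (one dataset's images); returns the global pixel mean and std."""
    n, s, s2 = 0, 0.0, 0.0
    for img in images:
        n += img.size
        s += img.sum()
        s2 += np.square(img).sum()
    mean = s / n
    std = np.sqrt(s2 / n - mean ** 2)
    return mean, std
```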
### 2. Train the Model

Open and run the training notebook `train.ipynb`.
### 3. Evaluate the Model

Open and run the evaluation notebook `eval.ipynb`.
## Additional Weight Map-Based Loss Functions

CWMI does not require additional data preparation. However, if you wish to test other weight map-based loss functions, run the corresponding scripts first:

- U-Net weighted cross entropy (WCE) (arXiv:1505.04597)

  ```
  python ./data/map_gen_unet.py
  ```

- ABW loss (arXiv:1905.09226v2)

  ```
  python ./data/map_gen_ABW.py
  ```

- Skea_topo loss (arXiv:2404.18539)

  ```
  python ./data/skeleton_aware_loss_gen.py
  python ./data/skeleton_gen.py
  ```
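For context on what these scripts produce: the U-Net paper (arXiv:1505.04597) defines its weight map as w(x) = w_c(x) + w0 · exp(-(d1(x) + d2(x))² / (2σ²)), where d1 and d2 are the distances to the borders of the nearest and second-nearest instances, with w0 = 10 and σ = 5. The sketch below is a brute-force illustration of that formula, not the repository's `map_gen_unet.py`; for simplicity it takes the class-balancing term w_c(x) as 1 and measures distances to instance pixels rather than instance borders.

```python
import numpy as np

def unet_weight_map(labels, w0=10.0, sigma=5.0):
    """labels: 2D int array, 0 = background, >0 = instance ids.
    Returns a weight map that emphasizes thin gaps between instances."""
    h, w = labels.shape
    ids = [i for i in np.unique(labels) if i != 0]
    ys, xs = np.mgrid[0:h, 0:w]
    # Distance from every pixel to each instance's nearest pixel.
    dists = []
    for i in ids:
        py, px = np.nonzero(labels == i)
        d = np.sqrt((ys[..., None] - py) ** 2 + (xs[..., None] - px) ** 2)
        dists.append(d.min(axis=-1))
    wmap = np.ones((h, w))  # w_c(x) simplified to 1 here
    if len(ids) >= 2:
        dists = np.sort(np.stack(dists, axis=-1), axis=-1)
        d1, d2 = dists[..., 0], dists[..., 1]
        wmap += w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
    return wmap
```

Pixels lying between two close instances receive the largest weights, which is what pushes the network to separate touching objects.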
## Results

<p align = "center"> <img src="figures/Table 1.png"> </p>

### SNEMI3D

<p align = "center"> <img src="figures/Figure 3.PNG"> </p>

### GlaS

<p align = "center"> <img src="figures/Figure 4.PNG"> </p>

### DRIVE

<p align = "center"> <img src="figures/Figure 5.PNG"> </p>

### MASS ROAD

<p align = "center"> <img src="figures/Figure 6.PNG"> </p>

### Computational Cost

<p align = "center"> <img src="figures/Table 4.png"> </p>
