FLEX
Knowledge-Guided Adaptation of Pathology Foundation Models Improves Cross-domain Generalization and Demographic Fairness
Overview
The advent of foundation models has ushered in a transformative era in computational pathology, enabling the extraction of rich, transferable image features for a broad range of downstream pathology tasks. However, site-specific signatures and demographic biases persist in these features, leading to shortcut learning and unfair predictions, ultimately compromising model generalizability and fairness across diverse clinical sites and demographic groups.
This repository implements FLEX, a novel framework that enhances the cross-domain generalization and demographic fairness of pathology foundation models, facilitating accurate diagnosis across diverse pathology tasks. FLEX employs a task-specific information bottleneck, informed by visual and textual domain knowledge (sketched schematically after the list below), to promote:
- Generalizability across clinical settings
- Fairness across demographic groups
- Adaptability to specific pathology tasks
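For intuition only, the sketch below shows one plausible way such an objective could combine a slide-level classification loss with a knowledge-alignment (InfoNCE) term and a bottleneck (KL) term, mirroring the `--w_infonce` and `--w_kl` weights exposed by the training script. This is a schematic under our own assumptions, not the paper's exact formulation.
```python
# Schematic only: one plausible combination of the loss terms suggested by
# the training flags --w_infonce and --w_kl. This is NOT the paper's exact
# objective; tensor shapes and the N(0, I) prior are our own assumptions.
import torch
import torch.nn.functional as F

def flex_style_objective(task_logits, labels, slide_emb, prompt_emb,
                         mu, logvar, w_infonce=14.0, w_kl=14.0):
    # Slide-level classification loss from the MIL head.
    task_loss = F.cross_entropy(task_logits, labels)

    # InfoNCE-style alignment between slide embeddings and the prompt
    # embeddings of their classes (positive pairs on the diagonal).
    sim = slide_emb @ prompt_emb.t() / 0.07          # (B, B) similarities
    targets = torch.arange(sim.size(0), device=sim.device)
    infonce_loss = F.cross_entropy(sim, targets)

    # KL term pulling the bottleneck posterior N(mu, sigma^2) toward a
    # standard normal prior, as in a variational information bottleneck.
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return task_loss + w_infonce * infonce_loss + w_kl * kl_loss
```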

Features
- Cross-domain generalization: Significantly improves diagnostic performance on data from unseen sites
- Demographic fairness: Reduces performance gaps between demographic groups
- Versatility: Compatible with various vision-language models
- Scalability: Adaptable to varying training data sizes
- Seamless integration: Works with multiple instance learning frameworks
Installation
Setup
1. Clone the repository:

   ```bash
   git clone https://github.com/HKU-MedAI/FLEX
   cd FLEX
   ```

2. Create and activate a virtual environment, and install the dependencies:

   ```bash
   conda env create -f environment.yml
   conda activate flex
   ```
Instructions for Use
Data Preparation
Prepare your data in the following structure:
```
Dataset/
├── TCGA-BRCA/
│   ├── features/
│   │   ├── ...
│   ├── tcga-brca_label.csv
│   ├── tcga-brca_label_her2.csv
│   └── ...
├── TCGA-NSCLC/
└── ...
```
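The label CSVs map each case to its task label. As a quick sanity check before training, you can inspect one directly; the column names are not specified here, so examine your own files rather than relying on this sketch.
```python
# Quick sanity check of a task label file. Column names vary by task;
# inspect the CSV shipped with your dataset rather than assuming them.
import pandas as pd

labels = pd.read_csv("Dataset/TCGA-BRCA/tcga-brca_label.csv")
print(labels.head())
print(labels.columns.tolist(), "rows:", len(labels))
```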
Visual Prompts
Organize visual prompts in the following structure:
```
prompts/
├── BRCA/
│   ├── 0/
│   │   ├── image1.png
│   │   └── ...
│   └── 1/
│       ├── image1.png
│       └── ...
├── BRCA_HER2/
└── ...
```
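Under this layout, the class subdirectory names double as class indices, so loading prompts per class is a simple directory walk. A minimal sketch, assuming Pillow is installed and prompts are stored as .png files:
```python
# Minimal sketch: load class-wise visual prompts from the layout above.
# Assumes Pillow is installed and prompts are stored as .png files.
from pathlib import Path
from PIL import Image

def load_visual_prompts(task_dir: str) -> dict[int, list[Image.Image]]:
    prompts = {}
    for class_dir in sorted(Path(task_dir).iterdir()):
        if class_dir.is_dir():
            prompts[int(class_dir.name)] = [
                Image.open(p).convert("RGB")
                for p in sorted(class_dir.glob("*.png"))
            ]
    return prompts

brca_prompts = load_visual_prompts("prompts/BRCA")  # {0: [...], 1: [...]}
```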
Running on Your Data
Generate Data Splits
Generate site-preserved Monte Carlo Cross-Validation (SP-MCCV) splits for your dataset:
```bash
python generate_sitepreserved_splits.py
python generate_sp_mccv_splits.py
```
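Conceptually, site-preserved splitting groups slides by their contributing clinical site (for TCGA, encoded by the tissue source site field of the barcode) and assigns whole sites to either train or test. A minimal illustration using scikit-learn's GroupShuffleSplit; the `case_id` column name is an assumption, and the repository's scripts remain authoritative.
```python
# Illustration only: site-preserved Monte Carlo splitting with sklearn.
# Assumes the label CSV has a 'case_id' column holding TCGA barcodes
# (e.g., 'TCGA-A2-A0CM'), whose second field is the tissue source site.
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

labels = pd.read_csv("Dataset/TCGA-BRCA/tcga-brca_label.csv")
sites = labels["case_id"].str.split("-").str[1]

# Whole sites go to either train or test, never both, so test
# performance reflects truly unseen sites.
mccv = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
for fold, (tr, te) in enumerate(mccv.split(labels, groups=sites)):
    print(f"fold {fold}: {len(tr)} train / {len(te)} test slides")
```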
Extract Features
Due to the large size of WSIs, patch-level features must be extracted first. We recommend using established pipelines like CLAM or TRIDENT. This is a computationally intensive step. Extracted features (e.g., in .h5 format) should be placed in the features/ subdirectory for each dataset.
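For reference, CLAM-style feature files store per-patch embeddings alongside patch coordinates. A minimal read is sketched below; the file name is hypothetical, and the `features`/`coords` dataset keys follow CLAM's HDF5 convention.
```python
# Minimal read of a CLAM-style patch-feature file. The file name is
# hypothetical; 'features' and 'coords' follow CLAM's HDF5 convention.
import h5py

with h5py.File("Dataset/TCGA-BRCA/features/slide_0001.h5", "r") as f:
    feats = f["features"][:]   # (num_patches, feature_dim)
    coords = f["coords"][:]    # (num_patches, 2) patch positions
print(feats.shape, coords.shape)
```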
Prepare Visual and Textual Concepts
- Visual Prompts: As described in our paper, visual prompts are representative patches for each class. We provide the visual prompts used in our experiments in the
prompts/directory. For custom tasks, you will need to generate your own. - Textual Prompts: Textual concepts are defined within the code/configuration files. These are crucial for guiding the information bottleneck. Please refer to
config.py(or similar file) to see how task-specific prompts like "invasive ductal carcinoma" are defined.
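For illustration, task-specific textual concepts could be declared as a simple mapping from task name to class descriptions. This structure is hypothetical; the actual definitions live in the repository's configuration files.
```python
# Hypothetical structure for task-specific textual concepts; the actual
# definitions live in the repo's config files and may be organized
# differently.
TASK_PROMPTS = {
    "BRCA": [
        "invasive ductal carcinoma",   # example prompt quoted in the docs
        "invasive lobular carcinoma",  # illustrative
    ],
    "BRCA_HER2": [
        "HER2-negative breast carcinoma",  # illustrative
        "HER2-positive breast carcinoma",  # illustrative
    ],
}
```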
Train the FLEX model
Train the FLEX model and evaluate the performance:
```bash
bash ./scripts/train_flex.sh
```
Reproduction Instructions
To reproduce the results in our paper:
1. Download the datasets mentioned in the paper (TCGA-BRCA, TCGA-NSCLC, TCGA-STAD, TCGA-CRC).
2. Extract features using [CLAM](https://github.com/mahmoodlab/CLAM) or [TRIDENT](https://github.com/mahmoodlab/TRIDENT).
3. Run the following commands:
```bash
# Generate splits
python generate_sitepreserved_splits.py
python generate_sp_mccv_splits.py

# For each task, modify the task parameter in train_flex.sh and run the
# script to train the model
bash ./scripts/train_flex.sh
```
4. For specific tasks or customizations, refer to the Key Parameters section below.
Key Parameters
- `--task`: Task name (e.g., BRCA, NSCLC, STAD_LAUREN)
- `--data_root_dir`: Path to the data directory
- `--split_suffix`: Split suffix (e.g., sitepre5_fold3)
- `--exp_code`: Experiment code for logging and saving results
- `--model_type`: Model type (default: flex)
- `--base_mil`: Base MIL framework (default: abmil)
- `--slide_align`: Whether to align at the slide level (default: 1)
- `--w_infonce`: Weight for the InfoNCE loss (default: 14)
- `--w_kl`: Weight for the KL loss (default: 14)
- `--len_prompt`: Number of learnable textual prompt tokens
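Putting these together, a run might be launched as below. This is a hypothetical invocation: the entry-point name (`main.py`) and the chosen values are assumptions, and `scripts/train_flex.sh` contains the actual command.
```python
# Hypothetical launch of a training run with the key flags. The entry-point
# name ('main.py') and the flag values are assumptions; see
# scripts/train_flex.sh for the actual command used by the repository.
import subprocess

subprocess.run([
    "python", "main.py",
    "--task", "BRCA",
    "--data_root_dir", "./Dataset",
    "--split_suffix", "sitepre5_fold3",
    "--exp_code", "flex_brca_sitepre5",
    "--model_type", "flex",
    "--base_mil", "abmil",
    "--slide_align", "1",
    "--w_infonce", "14",
    "--w_kl", "14",
    "--len_prompt", "16",
], check=True)
```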
Evaluation Results
FLEX has been evaluated on 16 clinically relevant tasks and demonstrates:
- Improved performance on unseen clinical sites
- Reduced performance gap between seen and unseen sites
- Enhanced fairness across demographic groups
For detailed results, refer to our paper.
License
This project is licensed under the Apache-2.0 license.
Acknowledgments
This project was built on top of several amazing works, including CLAM, CONCH, QuiltNet, PathGen-CLIP, and PreservedSiteCV. We thank the authors for their great work.
