MMedAgent
Learning to Use Medical Tools with Multi-modal Agent
The first multimodal medical AI agent that incorporates a wide spectrum of tools to seamlessly handle various medical tasks across different modalities.
[Paper, EMNLP 2024] [Demo] (Note: the demo link is temporary. Please follow [Build Web UI and Server] to build your own server.)
Binxu Li, Tiankai Yan, Yuanting Pan, Jie Luo, Ruiyang Ji, Jiayuan Ding, Zhe Xu, Shilong Liu, Haoyu Dong*, Zihao Lin*, Yixin Wang*
<div style="text-align: center;"> <img src="imgs/mmedagent.jpg" alt="MMedAgent" style="width: 50%;"/> <img src="imgs/instruction-tuning-data.jpg" alt="Instruction Tuning Data" style="width: 50%;"/> </div>
Current Tool List
| Task | Tool | Data Source | Imaging Modality |
|------|------|-------------|------------------|
| VQA | LLaVA-Med | PMC article<br>60K-IM | MRI, CT, X-ray, Histology, Gross |
| Classification | BiomedCLIP | PMC article<br>60K-IM | MRI, CT, X-ray, Histology, Gross |
| Grounding | Grounding DINO | WORD, etc. | MRI, CT, X-ray, Histology |
| Segmentation with bounding-box prompts (Segmentation) | MedSAM | WORD, etc. | MRI, CT, X-ray, Histology, Gross |
| Segmentation with text prompts (G-Seg) | Grounding DINO + MedSAM | WORD, etc.* | MRI, CT, X-ray, Histology |
| Medical report generation (MRG) | ChatCAD | MIMIC-CXR | X-ray |
| Retrieval augmented generation (RAG) | ChatCAD+ | Merck Manual | -- |
Note: "--" means that the RAG task works on natural language only and does not handle images. "WORD, etc.*" refers to multiple data sources, including WORD, FLARE2021, BRATS, Montgomery County X-ray Set (MC), VinDr-CXR, and Cellseg.
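The table can also be read as a lookup from imaging modality to the tools that can use it. The sketch below encodes the table as a Python dictionary and filters it by modality. `TOOLS` and `tools_for_modality` are hypothetical names for illustration, not part of the MMedAgent codebase. The rows are copied from the table.

```python
# Hypothetical encoding of the tool table above; not an MMedAgent API.
ALL_IMAGING = ("MRI", "CT", "X-ray", "Histology", "Gross")

TOOLS = {
    "VQA": ("LLaVA-Med", ALL_IMAGING),
    "Classification": ("BiomedCLIP", ALL_IMAGING),
    "Grounding": ("Grounding DINO", ALL_IMAGING[:4]),  # no Gross
    "Segmentation": ("MedSAM", ALL_IMAGING),
    "G-Seg": ("Grounding DINO + MedSAM", ALL_IMAGING[:4]),  # no Gross
    "MRG": ("ChatCAD", ("X-ray",)),
    "RAG": ("ChatCAD+", ()),  # natural language only, no image input
}


def tools_for_modality(modality):
    """Return (task, tool) pairs whose tool accepts the given imaging modality.

    Pass None for a text-only query (no image attached).
    """
    if modality is None:
        return [(task, tool) for task, (tool, mods) in TOOLS.items() if not mods]
    return [(task, tool) for task, (tool, mods) in TOOLS.items() if modality in mods]


print(tools_for_modality("Gross"))
# → [('VQA', 'LLaVA-Med'), ('Classification', 'BiomedCLIP'), ('Segmentation', 'MedSAM')]
```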
Usage
- Clone this repo
git clone https://github.com/Wangyixinxin/MMedAgent.git
- Create environment
cd MMedAgent
conda create -n mmedagent python=3.10 -y
conda activate mmedagent
pip install --upgrade pip # enable PEP 660 support
pip install -e .
- Additional packages required for training
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
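A quick sanity check after installation can catch a missing package before a long training run. This minimal sketch rests on several assumptions:

- Python 3.10, as in the conda command.
- The `llava` package installed by `pip install -e .`, since the training and evaluation commands run `llava.*` modules.
- For training only, the import names `deepspeed` and `flash_attn`.

The script and its function names are hypothetical helpers, not part of the repository.

```python
import importlib.util
import sys

# Assumed import names: `llava` comes from `pip install -e .`;
# `deepspeed` and `flash_attn` are needed only for training.
REQUIRED = ["llava"]
TRAIN_EXTRAS = ["deepspeed", "flash_attn"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment(training=False):
    """Collect human-readable problems; an empty list means every check passed."""
    problems = []
    if sys.version_info[:2] != (3, 10):
        problems.append(f"expected Python 3.10, found {sys.version.split()[0]}")
    names = REQUIRED + (TRAIN_EXTRAS if training else [])
    problems += [f"missing module: {name}" for name in missing_modules(names)]
    return problems


if __name__ == "__main__":
    for problem in check_environment(training="--train" in sys.argv):
        print(problem)
```

Save it as, for example, check_env.py (a hypothetical name) and run `python check_env.py --train` inside the `mmedagent` environment. No output means all checks passed.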
Model Download
MMedAgent Checkpoint
Model checkpoints (LoRA) and instruction-tuning data are hosted at https://huggingface.co/andy0207/mmedagent.
Download the model and data as follows:
git lfs install
git clone https://huggingface.co/andy0207/mmedagent
Base Model Download
The LLaVA-Med model weights below are delta weights. Use of the LLaVA-Med checkpoints must comply with the base LLM's model license (LLaMA).
Download the delta weights by following the instructions below, or see [LLaVA-Med](https://github.com/microsoft/LLaVA-Med/tree/v1.0.0) for details.
| Model Description | Model Delta Weights | Size |
| --- | --- | ---: |
| LLaVA-Med | llava_med_in_text_60k_ckpt2_delta.zip | 11.06 GB |
Instructions:
- Download the delta weights above and unzip the files.
- Download the original LLaMA weights (llama-7b in our model) in the Hugging Face format by following the instructions here.
- Run the following script to reconstruct the original LLaVA-Med weights (LLaVA-Med 7B in our model) by applying the delta weights. Set --base to the LLaMA weights from the second step, --delta to the unzipped delta-weights directory from the first step, and --target to the output folder.
python3 -m llava.model.apply_delta \
--base /path/to/llama-7b \
--target ./base_model \
--delta /path/to/llava_med_delta_weights
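Applying delta weights amounts to adding each delta tensor to the base tensor with the same name. The toy sketch below shows the idea on nested Python lists. It is an illustration, not the `llava.model.apply_delta` implementation, which works on PyTorch state dicts.

It makes two assumptions about edge cases that the real script may handle differently:

- Parameters that exist only in the delta (e.g. a multimodal projector) are copied unchanged.
- A delta matrix with extra rows (e.g. an embedding with added tokens) receives base values only in its overlapping top-left block.

```python
def apply_delta(base, delta):
    """Toy sketch: return target weights = delta + base, matched by parameter name.

    base, delta: dicts mapping parameter names to 2-D lists of floats.
    Only names present in `delta` appear in the result.
    """
    target = {}
    for name, d in delta.items():
        out = [list(row) for row in d]  # copy so the inputs are not mutated
        if name in base:
            # Add base values into the overlapping top-left block; extra delta
            # rows (e.g. embeddings for newly added tokens) are kept as-is.
            for i, base_row in enumerate(base[name]):
                for j, value in enumerate(base_row):
                    out[i][j] += value
        target[name] = out
    return target


base = {"w": [[1.0, 2.0], [3.0, 4.0]], "emb": [[1.0, 1.0]]}
delta = {
    "w": [[0.5, 0.5], [0.0, 0.0]],
    "emb": [[1.0, 1.0], [2.0, 2.0]],  # one extra row, e.g. an added token
    "proj": [[9.0]],  # present only in the fine-tuned model
}
print(apply_delta(base, delta))
# → {'w': [[1.5, 2.5], [3.0, 4.0]], 'emb': [[2.0, 2.0], [2.0, 2.0]], 'proj': [[9.0]]}
```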
Train
Training with LoRA:
deepspeed llava/train/train_mem.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./base_model \
--version v1 \
--data_path ./train_data_json/example.jsonl \
--image_folder ./train_images \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length False \
--bf16 True \
--output_dir ./checkpoints/final_model_lora \
--num_train_epochs 30 \
--per_device_train_batch_size 12 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 3000 \
--save_total_limit 2 \
--learning_rate 2e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
Alternatively, run tuning.sh.
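Each optimizer step sees per_device_train_batch_size × gradient_accumulation_steps × (number of GPUs) samples. The sketch below does that arithmetic for the training command. The GPU count and dataset size are hypothetical inputs. The step count is only an approximation, because the Hugging Face Trainer's exact rounding across devices and accumulation steps may differ.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_gpus):
    """Number of samples that contribute to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_gpus


def approx_total_steps(num_examples, per_device_batch, grad_accum_steps, num_gpus, epochs):
    """Approximate optimizer steps for a full run (Trainer rounding may differ slightly)."""
    batch = effective_batch_size(per_device_batch, grad_accum_steps, num_gpus)
    return math.ceil(num_examples / batch) * epochs


# Flags from the command: batch 12 per device, 2 accumulation steps, 30 epochs.
# The GPU counts and the 48,000-example dataset size are hypothetical.
print(effective_batch_size(12, 2, num_gpus=8))  # → 192
print(approx_total_steps(48_000, 12, 2, num_gpus=1, epochs=30))  # → 60000
```

With --save_steps 3000 and --save_total_limit 2, such a run would write a checkpoint every 3,000 optimizer steps and keep only the two most recent ones.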
Evaluation
Merge LoRA weights (if LoRA was enabled during training)
To use our released checkpoint, download the MMedAgent LoRA checkpoint (see Model Download) and set --model-path to that folder.
CUDA_VISIBLE_DEVICES=0 python scripts/merge_lora_weights.py \
--model-path ./checkpoints/final_model_lora \
--model-base ./base_model \
--save-model-path ./llava_med_agent
Alternatively, run merge.sh.
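Merging folds the low-rank update back into the base weights. In the standard LoRA formulation, a weight matrix W becomes W + (alpha / r) · B·A. With the training flags --lora_r 128 and --lora_alpha 256, the scale is 256 / 128 = 2.

The toy sketch below illustrates that arithmetic on small nested lists. It is not the code in scripts/merge_lora_weights.py, and the function names are hypothetical.

```python
def matmul(x, y):
    """Multiply two matrices given as nested lists (x: n x k, y: k x m)."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def merge_lora(w, a, b, alpha, r):
    """Return W + (alpha / r) * B @ A, the standard LoRA merge.

    w: d_out x d_in base weight, b: d_out x r, a: r x d_in.
    """
    scale = alpha / r
    update = matmul(b, a)
    return [[w[i][j] + scale * update[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# Tiny rank-1 example; alpha=2, r=1 gives the same scale (2.0) as 256 / 128.
w = [[1.0, 0.0], [0.0, 1.0]]
b = [[1.0], [0.0]]  # 2 x 1
a = [[1.0, 1.0]]    # 1 x 2
print(merge_lora(w, a, b, alpha=2, r=1))  # → [[3.0, 2.0], [0.0, 1.0]]
```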
Inference
CUDA_VISIBLE_DEVICES=0 python llava/eval/model_vqa.py \
--model-path ./llava_med_agent \
--question-file ./eval_data_json/eval_example.jsonl \
--image-folder ./eval_images \
--answers-file ./eval_data_json/output_agent_eval_example.jsonl \
--temperature 0.2
Alternatively, run eval.sh.
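To run inference on your own images, you need a question file in JSON Lines format. The sketch below assumes the file follows the schema of LLaVA's model_vqa.py, which reads question_id, image, and text from each line. We have not verified MMedAgent's copy of the script against this schema, so compare with ./eval_data_json/eval_example.jsonl first. The helper names are hypothetical.

```python
import json


def write_question_file(path, items):
    """Write one JSON object per line with question_id, image and text fields.

    items: iterable of (image_filename, question_text) pairs; image names are
    assumed to be relative to the folder passed as --image-folder.
    """
    with open(path, "w", encoding="utf-8") as f:
        for qid, (image, text) in enumerate(items):
            record = {"question_id": qid, "image": image, "text": text}
            f.write(json.dumps(record) + "\n")


def read_jsonl(path):
    """Load a JSON Lines file (e.g. the --answers-file output) into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

For example, `write_question_file("my_questions.jsonl", [("case_001.png", "Segment the liver in this CT image.")])` creates a one-question file. Both file names here are placeholders.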
GPT-4o inference
python llava/eval/eval_gpt4o.py \
--api-key "your-api-key" \
--question ./eval_data_json/eval_example.jsonl \
--output ./eval_data_json/output_gpt4o_eval_example.jsonl \
--max-tokens 1024
Alternatively, run eval_gpt4o.sh.
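We have not reproduced how eval_gpt4o.py builds its requests. As a hedged illustration, the sketch below builds a request body in the public OpenAI Chat Completions format for a question with an inline image; `max_tokens` mirrors --max-tokens 1024. The function name and the default model string "gpt-4o" are our assumptions.

```python
import base64
import mimetypes


def build_gpt4o_request(question, image_path=None, max_tokens=1024, model="gpt-4o"):
    """Build a Chat Completions request body, optionally embedding an image as a data URL."""
    content = [{"type": "text", "text": question}]
    if image_path is not None:
        # Fall back to PNG when the extension gives no MIME type (an assumption).
        mime = mimetypes.guess_type(image_path)[0] or "image/png"
        with open(image_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("ascii")
        content.append({"type": "image_url",
                        "image_url": {"url": f"data:{mime};base64,{encoded}"}})
    return {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
    }
```

The returned dict can be unpacked into `client.chat.completions.create(**request)` with the `openai` Python package (v1 or later).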
GPT-4 evaluation
All outputs are assessed by GPT-4 and rated on a scale from 1 to 10 for helpfulness, relevance, accuracy, and level of detail. See our paper for details of the evaluation.
python ./llava/eval/eval_gpt4.py \
--question_input_path ./eval_data_json/eval_example.jsonl \
--input_path ./eval_data_json/output_gpt4o_eval_example.jsonl \
--output_path ./eval_data_json/compare_gpt4o_medagent_reivew.jsonl
Alternatively, run eval_gpt4.sh.
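To summarize the GPT-4 ratings, one common approach is to report the second assistant's mean score as a percentage of the first's. The sketch below assumes each review begins with a line holding the two assistants' 1 to 10 scores (e.g. "8 6"), as in LLaVA's GPT-4 review prompts.

We have not checked the exact field layout of the review file, or which assistant eval_gpt4.py places first, so adapt the parsing accordingly. The helper names are hypothetical.

```python
from statistics import mean


def parse_score_pair(review):
    """Parse two scores from the first line of a review, e.g. "8 6" or "8, 6".

    Returns None when that line does not hold exactly two numbers.
    """
    first_line = review.strip().split("\n")[0].replace(",", " ")
    parts = first_line.split()
    if len(parts) != 2:
        return None
    try:
        return float(parts[0]), float(parts[1])
    except ValueError:
        return None


def relative_score(reviews):
    """Mean score of the second assistant as a percentage of the first's mean score."""
    pairs = [p for p in map(parse_score_pair, reviews) if p is not None]
    if not pairs:
        raise ValueError("no parseable reviews")
    return 100.0 * mean(p[1] for p in pairs) / mean(p[0] for p in pairs)
```

Unparseable reviews are skipped rather than counted as zero, so it is worth logging how many were dropped.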
Data Download
Instruction-tuning Dataset
We built the first open-source instruction-tuning dataset for multi-modal medical agents.
| Data | Size |
| --- | --- |
| instruction_all.json | 97.03 MiB |
Download the data by:
git lfs install
git clone https://huggingface.co/andy0207/mmedagent
Note: The images themselves are not owned by us and therefore not included in our instruction data. If you use any images in the dataset, please follow their original licenses. If you would like to process the data and construct the full dataset, you can follow the instructions provided in the next section.
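Because the images must be obtained separately, it helps to check which referenced images are missing locally before training. This minimal sketch assumes records follow LLaVA's convention of an `image` field that holds a path relative to --image_folder. It also assumes instruction data may come as a .json array (like instruction_all.json) or as .jsonl (like the training example). The function names are hypothetical, and the schema should be confirmed against the downloaded files.

```python
import json
import os


def load_records(path):
    """Load instruction records from a .json array or a .jsonl file."""
    with open(path, encoding="utf-8") as f:
        if path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)


def missing_images(records, image_folder):
    """Return sorted image paths referenced by records but absent under image_folder.

    Records without an "image" key (e.g. text-only RAG samples) are skipped.
    """
    return sorted({
        r["image"] for r in records
        if r.get("image") and not os.path.exists(os.path.join(image_folder, r["image"]))
    })
```

For example, `missing_images(load_records("instruction_all.json"), "./train_images")` lists the images you still need to download and process.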
Tool dataset (Selected)
Grounding task dataset
Please download the following segmentation datasets and refer to the following code to convert the data into the format required for the grounding task.
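A core step in turning segmentation data into grounding data is deriving a bounding box for each labeled structure. The sketch below computes a tight box per label in a 2-D mask. The function names and the inclusive (x_min, y_min, x_max, y_max) pixel convention are our assumptions, not the repository's data format. For 3-D volumes such as WORD or FLARE2021, it would be applied slice by slice.

```python
def mask_to_bbox(mask):
    """Return (x_min, y_min, x_max, y_max) of nonzero pixels, or None if the mask is empty.

    mask: 2-D list of pixel values (rows); coordinates are inclusive pixel indices.
    """
    ys = [y for y, row in enumerate(mask) if any(row)]
    if not ys:
        return None
    xs = [x for x in range(len(mask[0])) if any(row[x] for row in mask)]
    return xs[0], ys[0], xs[-1], ys[-1]


def label_bboxes(mask):
    """Map each nonzero label in a multi-class mask to its bounding box."""
    labels = {v for row in mask for v in row if v}
    return {
        label: mask_to_bbox([[1 if v == label else 0 for v in row] for row in mask])
        for label in sorted(labels)
    }


mask = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 0, 2],
    [0, 0, 0, 2],
]
print(label_bboxes(mask))  # → {1: (1, 1, 2, 2), 2: (3, 2, 3, 3)}
```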
