# MedViTV2
**MedViTV2: Medical Image Classification with KAN-Integrated Transformers and Dilated Neighborhood Attention** (Applied Soft Computing 2025)
## 🔥 News
- [2025.08.27] We have released the pre-trained weights.
- [2025.10.06] Our paper was accepted for publication in Applied Soft Computing.
## Train & Test
### Prepare data
To train or evaluate MedViT models on the 17 medical datasets, follow the "Evaluation" guide.
Important: This code also supports training all TIMM models.
## Introduction
<div align="justify"> Convolutional networks, transformers, hybrid models, and Mamba-based architectures have shown strong performance in medical image classification but are typically designed for clean, labeled data. Real-world clinical datasets, however, often contain corruptions arising from multi-center studies and variations in imaging equipment. To address this, we introduce the Medical Vision Transformer (MedViTV2), the first architecture to integrate Kolmogorov–Arnold Network (KAN) layers into a transformer for generalized medical image classification. We design an efficient KAN block to lower computational cost while improving accuracy over the original MedViT. To overcome scaling fragility, we propose Dilated Neighborhood Attention (DiNA), an adaptation of fused dot-product attention that expands receptive fields and mitigates feature collapse. Additionally, a hierarchical hybrid strategy balances local and global feature perception through efficient stacking of Local and Global Feature Perception blocks. Evaluated on 17 classification and 12 corrupted datasets, MedViTV2 achieved state-of-the-art performance in 27 of 29 benchmarks, improving efficiency by 44% and boosting accuracy by 4.6% on MedMNIST, 5.8% on NonMNIST, and 13.4% on MedMNIST-C. </div>

<div style="text-align: center">
<img src="https://github.com/Omid-Nejati/MedViT-V2/blob/main/Fig/ACC.png" title="MedViT-S" height="60%" width="60%">
</div>

Figure 1. Comparison between MedViTs (V1 and V2), MedMamba, and the baseline ResNets, in terms of Average Accuracy vs. FLOPs trade-off over all MedMNIST datasets. MedViTV2-T/S/L significantly improves average accuracy by 2.6%, 2.5%, and 4.6%, respectively, compared to MedViTV1-T/S/L.

## Overview
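To give an intuition for the DiNA mechanism described above, here is a minimal 1-D toy sketch in which each token attends only to a sparse, dilated neighborhood of other tokens. This is an illustrative simplification, not the repository's implementation (which builds on the NATTEN library and operates on 2-D feature maps); the function name and defaults are hypothetical.

```python
import torch
import torch.nn.functional as F

def dilated_neighborhood_attention_1d(x, kernel_size=3, dilation=2):
    """Toy 1-D dilated neighborhood self-attention.

    x: (batch, length, dim). Each token attends only to tokens at
    distances {0, dilation, 2*dilation, ...} up to (kernel_size//2)*dilation,
    so the receptive field grows with dilation at no extra cost.
    """
    B, L, D = x.shape
    half = kernel_size // 2
    # Full dot-product scores, then mask everything outside the neighborhood.
    scores = x @ x.transpose(1, 2) / D ** 0.5            # (B, L, L)
    idx = torch.arange(L)
    dist = (idx[None, :] - idx[:, None]).abs()           # pairwise distances
    mask = (dist <= half * dilation) & (dist % dilation == 0)
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ x
```

With `dilation=1` this degenerates to ordinary (local) neighborhood attention; larger dilations spread the same number of attended positions over a wider span, which is the receptive-field expansion DiNA exploits.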
<div style="text-align: center">
<img src="https://github.com/Omid-Nejati/MedViT-V2/blob/main/Fig/structure.png" title="MedViT-S" height="75%" width="75%">
</div>

Figure 2. Overall architecture of the proposed Medical Vision Transformer (MedViTV2).

## Visual Examples
You can find a tutorial for visualizing Grad-CAM heatmaps of MedViT in this repository under "visualize".
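As a rough illustration of what that tutorial covers, the sketch below computes a Grad-CAM heatmap using forward/backward hooks. The tiny stand-in CNN and the hooked layer are placeholders; in practice you would load a MedViT checkpoint and hook one of its feature stages.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Stand-in CNN (hypothetical); substitute a MedViT model in practice.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2),
)
model.eval()

# Capture the target layer's activations and the gradients w.r.t. them.
feats, grads = {}, {}
target_layer = model[0]
target_layer.register_forward_hook(lambda m, i, o: feats.update(a=o))
target_layer.register_full_backward_hook(lambda m, gi, go: grads.update(a=go[0]))

x = torch.randn(1, 3, 224, 224)          # placeholder input image
logits = model(x)
logits[0, logits.argmax()].backward()    # gradient of the top class score

# Grad-CAM: channel weights = spatially averaged gradients.
w = grads["a"].mean(dim=(2, 3), keepdim=True)
cam = F.relu((w * feats["a"]).sum(dim=1))            # (1, H, W)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
```

`cam` can then be upsampled to the input resolution and overlaid on the image, as the repository's "visualize" notebook does.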
<br><br>

## Usage
First, clone the repository locally:

```shell
git clone https://github.com/Omid-Nejati/MedViTV2.git
cd MedViTV2
```
Install PyTorch 2.5:

```shell
pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
```
Then, install NATTEN 0.17.3:

```shell
pip install natten==0.17.3+torch250cu124 -f https://shi-labs.com/natten/wheels/
```
Finally, install the remaining requirements:

```shell
pip install -r requirements.txt
```
### Training
To train MedViT-small on BreastMNIST on a single GPU for 100 epochs, run:

```shell
python main.py --model_name 'MedViT_small' --dataset 'breastmnist' --pretrained False
```
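After training, a checkpoint can be scored with a plain PyTorch loop like the sketch below. This `evaluate` helper is hypothetical, not the repository's evaluation script (which also reports metrics such as AUC); it only assumes a model and a loader yielding `(image, label)` batches.

```python
import torch
from torch import nn

@torch.no_grad()
def evaluate(model: nn.Module, loader, device: str = "cpu") -> float:
    """Top-1 accuracy of `model` over (image, label) batches."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        logits = model(images.to(device))
        preds = logits.argmax(dim=1).cpu()
        correct += (preds == labels.view(-1)).sum().item()
        total += labels.numel()
    return correct / total
```

The same loop works with any `torch.utils.data.DataLoader` over a MedMNIST-style dataset of image/label pairs.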
## 📊 Performance Overview
Below is the performance summary of MedViT on various medical imaging datasets.
🔹 Model weights are available now.
| Dataset | Task | MedViTV2-tiny (%) | MedViTV2-small (%) | MedViTV2-base (%) | MedViTV2-large (%) |
|:-----------:|:--------:|:-----------------:|:------------------:|:-----------------:|:------------------:|
| ChestMNIST | Multi-Class (14) | 96.3 (model) | 96.4 (model) | 96.4 (model) | 96.7 (model) |
| PathMNIST | Multi-Class (9) | 95.9 (model) | 96.5 (model) | 97.0 (model) | 97.7 (model) |
| DermaMNIST | Multi-Class (7) | 78.1 (model) | 79.2 (model) | 80.8 (model) | 81.7 (model) |
| OCTMNIST | Multi-Class (4) | 92.7 (model) | 94.2 (model) | 94.4 (model) | 95.2 (model) |
| PneumoniaMNIST | Multi-Class (2) | 95.1 (model) | 96.5 (model) | 96.9 (model) | 97.3 (model) |
| RetinaMNIST | Multi-Class (5) | 54.7 (model) | 56.2 (model) | 57.5 (model) | 57.8 (model) |
| BreastMNIST | Multi-Class (2) | 88.2 (model) | 89.5 (model) | 90.4 (model) | 91.0 (model) |
| BloodMNIST | Multi-Class (8) | 97.9 (model) | 98.5 (model) | 98.5 (model) | 98.7 (model) |
| TissueMNIST | Multi-Class (8) | 69.9 (model) | 70.5 (model) | 71.1 (model) | 71.6 (model) |
| OrganAMNIST | Multi-Class (11) | 95.8 (model) | 96.6 (model) | 96.9 (model) | 97.3 (model) |
| OrganCMNIST | Multi-Class (11) | 93.5 (model) | 95.0 (model) | 95.3 (model) | 96.1 ([model](https://drive.google.com/file/d/1jpPTbcy0ztZxo9XshfU_J0TiRV |
