MedViTV2

MedViTV2: Medical Image Classification with KAN-Integrated Transformers and Dilated Neighborhood Attention (Applied Soft Computing 2025)


<div align="center"> <h1 style="font-family: Arial;">MedViT</h1> <h3>MedViTV2: Medical Image Classification with KAN-Integrated Transformers and Dilated Neighborhood Attention</h3> </div> <div align="center"> <img src="https://github.com/Omid-Nejati/MedViT-V2/blob/main/Fig/cover.jpg" alt="figure4" width="40%" /> </div>

🔥 News

  • [2025.08.27] We have released the pre-trained weights.
  • [2025.10.06] Our paper has been accepted for publication in Applied Soft Computing.

Train & Test: Prepare Data

To train or evaluate MedViT models on the 17 medical datasets, follow the instructions in "Evaluation".

Important: This code also supports training all TIMM models.

Introduction

<div align="justify"> Convolutional networks, transformers, hybrid models, and Mamba-based architectures have shown strong performance in medical image classification but are typically designed for clean, labeled data. Real-world clinical datasets, however, often contain corruptions arising from multi-center studies and variations in imaging equipment. To address this, we introduce the Medical Vision Transformer (MedViTV2), the first architecture to integrate Kolmogorov–Arnold Network (KAN) layers into a transformer for generalized medical image classification. We design an efficient KAN block that lowers computational cost while improving accuracy over the original MedViT. To overcome scaling fragility, we propose Dilated Neighborhood Attention (DiNA), an adaptation of fused dot-product attention that expands receptive fields and mitigates feature collapse. Additionally, a hierarchical hybrid strategy balances local and global feature perception through efficient stacking of Local and Global Feature Perception blocks. Evaluated on 17 classification and 12 corrupted datasets, MedViTV2 achieved state-of-the-art performance in 27 of 29 benchmarks, improving efficiency by 44% and boosting accuracy by 4.6% on MedMNIST, 5.8% on NonMNIST, and 13.4% on MedMNIST-C. </div> <div style="text-align: center"> <img src="https://github.com/Omid-Nejati/MedViT-V2/blob/main/Fig/ACC.png" title="MedViT-S" height="60%" width="60%"> </div> <center>Figure 1. Comparison between MedViTs (V1 and V2), MedMamba, and the baseline ResNets in terms of the average accuracy vs. FLOPs trade-off over all MedMNIST datasets. MedViTV2-T/S/L improve average accuracy by 2.6%, 2.5%, and 4.6%, respectively, compared to MedViTV1-T/S/L.</center>
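The introduction describes DiNA as attention restricted to a dilated neighborhood around each token. As a rough illustration only (not the NATTEN kernels installed below, and not the authors' fused implementation), here is a single-head, 1D NumPy sketch without relative positional biases. Each query attends to `kernel_size` keys spaced `dilation` positions apart, and, following the neighborhood-attention convention, the window is shifted rather than padded near the borders. The function names are our own.

```python
import numpy as np

def neighborhood_indices(i, n, kernel_size, dilation):
    """Indices of the `kernel_size` keys that query i attends to.

    Neighbors are `dilation` positions apart and share i's residue class
    (i % dilation). Near the borders the window is shifted, not padded,
    so every query still sees exactly `kernel_size` keys.
    """
    r = i % dilation
    group_len = len(range(r, n, dilation))  # positions in this dilation group
    if group_len < kernel_size:
        raise ValueError("sequence too short for this kernel_size/dilation")
    p = i // dilation  # position of i inside its group
    start = min(max(p - kernel_size // 2, 0), group_len - kernel_size)
    return r + dilation * np.arange(start, start + kernel_size)

def dilated_neighborhood_attention(q, k, v, kernel_size=3, dilation=1):
    """Single-head 1D dilated neighborhood attention; q, k, v have shape (n, dim)."""
    n, dim = q.shape
    out = np.empty_like(v)
    for i in range(n):
        idx = neighborhood_indices(i, n, kernel_size, dilation)
        scores = k[idx] @ q[i] / np.sqrt(dim)  # (kernel_size,)
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()               # softmax over the neighborhood only
        out[i] = weights @ v[idx]
    return out

if __name__ == "__main__":
    print(neighborhood_indices(5, 8, 3, 1))  # [4 5 6]
    print(neighborhood_indices(5, 8, 3, 2))  # [3 5 7]
```

With the same three keys per query, dilation 2 covers twice the span of dilation 1, which is how dilation enlarges the receptive field without increasing the attention cost per token.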

Overview

<div style="text-align: center"> <img src="https://github.com/Omid-Nejati/MedViT-V2/blob/main/Fig/structure.png" title="MedViT-S" height="75%" width="75%"> </div> <center>Figure 2. Overall architecture of the proposed Medical Vision Transformer (MedViTV2).</center>
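MedViTV2 integrates Kolmogorov–Arnold Network (KAN) layers into its transformer. To illustrate what a KAN layer computes (this is not the authors' efficient KAN block, whose design is given in the paper), the NumPy sketch below implements a forward pass only. Every input-output edge carries its own learnable univariate function, and each output is the sum of those functions over the inputs, instead of a fixed activation applied after a linear map. The class name `RBFKANLayer`, the Gaussian radial-basis parameterization (a swapped-in alternative to the B-splines of the original KAN formulation), and all default hyperparameters are assumptions.

```python
import numpy as np

class RBFKANLayer:
    """Minimal KAN-style layer: y_o = sum_j phi_oj(x_j) (forward pass only).

    Each edge function phi_oj is a learnable mix of a SiLU base term and
    Gaussian radial basis functions, an RBF stand-in for the B-splines of
    the original KAN formulation.
    """

    def __init__(self, in_dim, out_dim, num_centers=8, x_range=(-2.0, 2.0), seed=0):
        rng = np.random.default_rng(seed)
        self.centers = np.linspace(*x_range, num_centers)              # (C,)
        self.width = (x_range[1] - x_range[0]) / (num_centers - 1)
        # One coefficient per (output, input, basis function).
        self.coef = rng.normal(0.0, 0.1, (out_dim, in_dim, num_centers))
        self.base_weight = rng.normal(0.0, 0.1, (out_dim, in_dim))

    def edge_functions(self, x):
        """phi_oj(x_j) for every edge; x: (batch, in_dim) -> (batch, out_dim, in_dim)."""
        basis = np.exp(-(((x[..., None] - self.centers) / self.width) ** 2))  # (B, in, C)
        spline = np.einsum("bic,oic->boi", basis, self.coef)
        silu = x / (1.0 + np.exp(-x))                                           # (B, in)
        return spline + self.base_weight[None] * silu[:, None, :]

    def __call__(self, x):
        # Each output sums its per-edge univariate functions over all inputs.
        return self.edge_functions(x).sum(axis=-1)                              # (B, out)
```

For example, `RBFKANLayer(4, 3)` maps a batch of shape (5, 4) to shape (5, 3). Because every edge function depends on a single input feature, perturbing one feature changes only that feature's edge terms, which is the additive structure that distinguishes KAN layers from MLP layers.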

Visual Examples

You can find a tutorial for visualizing Grad-CAM heatmaps of MedViT in this repository (see "visualize").

<center>Figure 3. Grad-CAM heatmap visualization. We present heatmaps generated from the last three layers of MedViTV1-T, MedViTV2-T, MedViTV1-L, and MedViTV2-L (in that order). Specifically, we use the final GFP, LFP, and normalization layers of these models to produce the heatmaps with Grad-CAM.</center>
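The core Grad-CAM computation behind heatmaps like those in Figure 3 fits in a few lines of NumPy: the gradients of the target class score are averaged over space to weight each channel, the weighted channels are summed, and a ReLU keeps only positive evidence. This sketch assumes the activations and gradients of the chosen layer have already been captured (for example with PyTorch forward and backward hooks) and converted to arrays. `tokens_to_feature_map` is a hypothetical helper for layers that output a token sequence rather than a feature map. This is not the code in the repository's "visualize" tutorial.

```python
import numpy as np

def grad_cam(activations, gradients, eps=1e-8):
    """Grad-CAM map from one layer's activations and their gradients.

    activations, gradients: arrays of shape (channels, H, W), where
    gradients = d(target class score) / d(activations).
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))             # alpha_k: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=1)  # sum_k alpha_k * A_k -> (H, W)
    cam = np.maximum(cam, 0.0)                        # ReLU keeps positive evidence only
    return (cam - cam.min()) / (cam.max() - cam.min() + eps)

def tokens_to_feature_map(tokens, h, w):
    """Reshape transformer tokens of shape (h*w, channels) into (channels, h, w)."""
    return tokens.T.reshape(-1, h, w)
```

Before overlaying the (H, W) map on the input image, upsample it to the image resolution, for example with bilinear interpolation.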

Usage

First, clone the repository locally:

git clone https://github.com/Omid-Nejati/MedViTV2.git
cd MedViTV2

Install PyTorch 2.5 (CUDA 12.4 build):

pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124

Then, install NATTEN 0.17.3 (the wheel must match the PyTorch and CUDA versions above):

pip install natten==0.17.3+torch250cu124 -f https://shi-labs.com/natten/wheels/

Finally, install the remaining requirements:

pip install -r requirements.txt
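After installing, you can confirm that the pinned versions are the ones Python actually sees by querying package metadata. This is a small helper of our own, not part of the repository; the `EXPECTED` mapping simply mirrors the versions in the commands above, and local build tags such as `+torch250cu124` are ignored when comparing.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the install commands above.
EXPECTED = {"torch": "2.5.0", "torchvision": "0.20.0", "torchaudio": "2.5.0", "natten": "0.17.3"}

def check_versions(expected=EXPECTED):
    """Map each package to (installed_version_or_None, matches_expected)."""
    report = {}
    for name, want in expected.items():
        try:
            got = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        # Drop local build tags such as "+torch250cu124" or "+cu124" before comparing.
        report[name] = (got, got.split("+")[0] == want)
    return report

if __name__ == "__main__":
    for name, (got, ok) in check_versions().items():
        print(f"{name:12s} {got or 'not installed':24s} {'OK' if ok else 'MISMATCH'}")
```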

Training

To train MedViT-small on BreastMNIST on a single GPU for 100 epochs, run:

python main.py --model_name 'MedViT_small' --dataset 'breastmnist' --pretrained False
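To repeat the command above for several models or datasets, a short script can generate one command per combination. Only the flags `--model_name`, `--dataset`, and `--pretrained` come from the command above. The model names other than `MedViT_small` and the dataset names other than `breastmnist` are assumptions (MedMNIST-style lowercase names), so check `main.py` for the values it actually accepts.

```python
import shlex

# Hypothetical sweep lists: only "MedViT_small" and "breastmnist" appear in the README.
DATASETS = ["breastmnist", "pneumoniamnist", "dermamnist", "bloodmnist"]
MODELS = ["MedViT_tiny", "MedViT_small"]

def build_commands(models=MODELS, datasets=DATASETS, pretrained=False):
    """Return one `python main.py ...` command string per (model, dataset) pair."""
    commands = []
    for model in models:
        for dataset in datasets:
            args = ["python", "main.py",
                    "--model_name", model,
                    "--dataset", dataset,
                    "--pretrained", str(pretrained)]
            commands.append(shlex.join(args))  # shell-safe quoting where needed
    return commands

if __name__ == "__main__":
    for cmd in build_commands():
        print(cmd)  # or execute with subprocess.run(shlex.split(cmd), check=True)
```

Printing the commands first, rather than launching them directly, makes it easy to review the sweep or paste it into a job scheduler.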

📊 Performance Overview

Below is the performance summary of MedViTV2 on various medical imaging datasets.
🔹 Model weights are available now.

| Dataset | Task | MedViTV2-tiny (%) | MedViTV2-small (%) | MedViTV2-base (%) | MedViTV2-large (%) |
|:-----------:|:--------:|:-----------------------:|:------------------:|:---------------------:|:-----------------------:|
| ChestMNIST | Multi-Label (14) | 96.3 (model) | 96.4 (model) | 96.4 (model) | 96.7 (model) |
| PathMNIST | Multi-Class (9) | 95.9 (model) | 96.5 (model) | 97.0 (model) | 97.7 (model) |
| DermaMNIST | Multi-Class (7) | 78.1 (model) | 79.2 (model) | 80.8 (model) | 81.7 (model) |
| OCTMNIST | Multi-Class (4) | 92.7 (model) | 94.2 (model) | 94.4 (model) | 95.2 (model) |
| PneumoniaMNIST | Binary-Class (2) | 95.1 (model) | 96.5 (model) | 96.9 (model) | 97.3 (model) |
| RetinaMNIST | Multi-Class (5) | 54.7 (model) | 56.2 (model) | 57.5 (model) | 57.8 (model) |
| BreastMNIST | Binary-Class (2) | 88.2 (model) | 89.5 (model) | 90.4 (model) | 91.0 (model) |
| BloodMNIST | Multi-Class (8) | 97.9 (model) | 98.5 (model) | 98.5 (model) | 98.7 (model) |
| TissueMNIST | Multi-Class (8) | 69.9 (model) | 70.5 (model) | 71.1 (model) | 71.6 (model) |
| OrganAMNIST | Multi-Class (11) | 95.8 (model) | 96.6 (model) | 96.9 (model) | 97.3 (model) |
| OrganCMNIST | Multi-Class (11) | 93.5 (model) | 95.0 (model) | 95.3 (model) | 96.1 ([model](https://drive.google.com/file/d/1jpPTbcy0ztZxo9XshfU_J0TiRV
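To compare model sizes at a glance, the scores in the table can be averaged per column. This is plain arithmetic on the 11 rows shown above (the table is cut off after OrganCMNIST), giving unweighted means; these are not the averages reported in the paper over all MedMNIST datasets.

```python
from statistics import mean

# Scores (%) copied from the table above, in row order:
# Chest, Path, Derma, OCT, Pneumonia, Retina, Breast, Blood, Tissue, OrganA, OrganC.
TABLE = {
    "tiny":  [96.3, 95.9, 78.1, 92.7, 95.1, 54.7, 88.2, 97.9, 69.9, 95.8, 93.5],
    "small": [96.4, 96.5, 79.2, 94.2, 96.5, 56.2, 89.5, 98.5, 70.5, 96.6, 95.0],
    "base":  [96.4, 97.0, 80.8, 94.4, 96.9, 57.5, 90.4, 98.5, 71.1, 96.9, 95.3],
    "large": [96.7, 97.7, 81.7, 95.2, 97.3, 57.8, 91.0, 98.7, 71.6, 97.3, 96.1],
}

# Unweighted mean per model size: tiny 87.10, small 88.10, base 88.65, large 89.19.
for size, scores in TABLE.items():
    print(f"{size:5s} mean over {len(scores)} rows: {mean(scores):.2f}%")
```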
