RSMamba
Introduction
This repository is the code implementation of the paper RSMamba: Remote Sensing Image Classification with State Space Model, which is based on the MMPretrain project.
The current branch has been tested on Linux system, PyTorch 2.x and CUDA 12.1, supports Python 3.8+, and is compatible with most CUDA versions.
If you find this project helpful, please give us a star ⭐️, your support is our greatest motivation.
Main Features

- Consistent API interface and usage with MMPretrain
- Open-sourced RSMamba models of different sizes in the paper
- Support for training and testing on multiple datasets
Updates
🌟 2024.03.28 Released the RSMamba project, which is fully consistent with the API interface and usage of MMPretrain.
🌟 2024.03.29 Open-sourced the weight files of RSMamba models of different sizes in the paper.
TODO
- [X] Open-source model training parameters
Table of Contents
- Introduction
- Updates
- TODO
- Table of Contents
- Installation
- Dataset Preparation
- Model Training
- Model Testing
- Image Prediction
- FAQ
- Acknowledgements
- Citation
- License
- Contact Us
Installation
Requirements
- Linux system; Windows is untested, and support depends on whether `causal-conv1d` and `mamba-ssm` can be installed
- Python 3.8+, recommended 3.11
- PyTorch 2.0 or higher, recommended 2.2
- CUDA 11.7 or higher, recommended 12.1
- MMCV 2.0 or higher, recommended 2.1
Environment Installation
It is recommended to use Miniconda for installation. The following commands will create a virtual environment named rsmamba and install PyTorch and MMCV. In the following installation steps, the default installed CUDA version is 12.1. If your CUDA version is not 12.1, please modify it according to the actual situation.
Note: If you are experienced with PyTorch and have already installed it, you can skip to the next section. Otherwise, you can follow the steps below.
Step 0: Install Miniconda.
Step 1: Create a virtual environment named rsmamba and activate it.
conda create -n rsmamba python=3.11 -y
conda activate rsmamba
Step 2: Install PyTorch 2.1.x.
Linux/Windows:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
Or
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia -y
Step 3: Install MMCV 2.1.x.
pip install -U openmim
mim install mmcv==2.1.0
# or
pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
Step 4: Install other dependencies.
pip install -U mat4py ipdb modelindex
pip install transformers==4.39.2
pip install causal-conv1d==1.2.0.post2
pip install mamba-ssm==1.2.0.post1
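After Step 4, a quick way to confirm the environment is a small check script. This is a convenience sketch, not part of the project; note that the import names of `causal-conv1d` and `mamba-ssm` use underscores:

```python
import importlib.util

def check_packages(names):
    """Return {package_name: importable?} without actually importing anything."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

if __name__ == "__main__":
    # Import names of the packages installed above (underscores, not hyphens).
    for name, ok in check_packages(
        ["torch", "mmcv", "transformers", "mamba_ssm", "causal_conv1d"]
    ).items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```

Any line reporting `MISSING` points at the install step to redo.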
Install RSMamba
You can download or clone the RSMamba repository.
git clone git@github.com:KyanChen/RSMamba.git
cd RSMamba
Dataset Preparation
Remote Sensing Image Classification Dataset
We provide the method of preparing the remote sensing image classification dataset used in the paper.
UC Merced Dataset
- Image and annotation download link: UC Merced Dataset.
AID Dataset
- Image and annotation download link: AID Dataset.
NWPU RESISC45 Dataset
- Image and annotation download link: NWPU RESISC45 Dataset.
Note: The data folder of this project provides a small number of image annotation examples for the above datasets.
Organization Method
You can also choose other sources to download the data, but you need to organize the dataset in the following format:
${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/UC
├── airplane
│ ├── airplane01.tif
│ ├── airplane02.tif
│ └── ...
├── ...
├── ...
├── ...
└── ...
Note: In the project folder datainfo, we provide the dataset split files. You can also use a Python script to split the dataset yourself.
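As a hedged sketch of such a split script (the `relative/path label_index` line format and the 80/20 train/validation ratio are assumptions here, not necessarily what the provided split files use), a class-per-folder dataset organized as above could be divided like this:

```python
import random
from pathlib import Path

def make_splits(dataset_root, train_ratio=0.8, seed=0):
    """Split a class-per-folder dataset into train/val lists of 'path label' lines."""
    root = Path(dataset_root)
    # One subfolder per class; sorted order fixes the label indices.
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    rng = random.Random(seed)
    train, val = [], []
    for label, cls in enumerate(classes):
        images = sorted((root / cls).glob("*"))
        rng.shuffle(images)
        cut = int(len(images) * train_ratio)
        for i, img in enumerate(images):
            (train if i < cut else val).append(f"{cls}/{img.name} {label}")
    return train, val

# Example usage: write annotation files next to the dataset.
# train, val = make_splits("/home/username/data/UC")
# Path("train.txt").write_text("\n".join(train))
# Path("val.txt").write_text("\n".join(val))
```

Fixing the random seed keeps the split reproducible across runs.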
Other Datasets
If you want to use other datasets, you can refer to the MMPretrain documentation for dataset preparation.
Model Training
RSMamba Model
Config File and Main Parameter Parsing
We provide the configuration files of RSMamba models with different parameter sizes in the paper, which can be found in the configuration files folder. The Config file is fully consistent with the API interface and usage of MMPretrain. Below we provide an analysis of some of the main parameters. If you want to know more about the parameters, you can refer to the MMPretrain documentation.
Parameter Parsing:

- `work_dir`: The output path of model training; generally no need to modify.
- `code_root`: The root directory of the code; set it to the absolute path of this project's root directory.
- `data_root`: The root directory of the dataset; set it to the absolute path of the dataset root directory.
- `batch_size`: The batch size per GPU; modify according to GPU memory.
- `max_epochs`: The maximum number of training epochs; generally no need to modify.
- `vis_backends/WandbVisBackend`: Configuration of the web-based visualization tool; after uncommenting it, register an account on the wandb official website to view the visualization results in a browser during training.
- `model/backbone/arch`: The type of the model's backbone network; modify according to the selected model, one of `b`, `l`, `h`.
- `model/backbone/path_type`: The path type of the model; modify according to the selected model.
- `default_hooks-CheckpointHook`: Configuration of checkpoint saving during training; generally no need to modify.
- `num_classes`: The number of categories in the dataset; modify according to the dataset.
- `dataset_type`: The type of the dataset; modify according to the dataset.
- `resume`: Whether to resume training; generally no need to modify.
- `load_from`: The path of the model's pre-trained checkpoint; generally no need to modify.
- `data_preprocessor/mean/std`: The mean and standard deviation for data preprocessing; modify according to the dataset (generally no need to modify; refer to the Python script).
Some parameters come from the inheritance value of _base_, you can find them in the basic configuration files folder.
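To make the parameters above concrete, a minimal MMPretrain-style config override might look like the following. This is an illustrative sketch only: the base-config path, the `path_type` value, and the epoch count are assumptions, so consult the actual files under `configs/rsmamba/` for the real values.

```python
# Illustrative config fragment; field names follow MMPretrain conventions.
_base_ = ['./_base_/rsmamba_base.py']   # hypothetical base config path

code_root = '/home/username/RSMamba'    # absolute path to this project
data_root = '/home/username/data/UC'    # absolute path to the dataset
work_dir = f'{code_root}/work_dirs/rsmamba_uc'

batch_size = 32      # per-GPU batch size; reduce if GPU memory is limited
max_epochs = 500     # assumed value; check the real config
num_classes = 21     # UC Merced has 21 scene categories

model = dict(backbone=dict(
    arch='b',                    # one of 'b', 'l', 'h'
    path_type='forward_reverse', # assumed name; see the paper/configs
))

resume = False
load_from = None     # set to a checkpoint path to fine-tune
```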
Single Card Training
python tools/train.py configs/rsmamba/name_to_config.py # name_to_config.py is the configuration file you want to use
Multi-card Training
sh ./tools/dist_train.sh configs/rsmamba/name_to_config.py ${GPU_NUM} # name_to_config.py is the configuration file you want to use, GPU_NUM is the number of GPUs used
Other Image Classification Models
If you want to use other image classification models, you can refer to MMPretrain for model training, or you can put their Config files into the configs folder of this project, and then train them according to the above method.
Model Testing
Single Card Testing:
python tools/test.py configs/rsmamba/name_to_config.py ${CHECKPOINT_FILE} # name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use
Multi-card Testing:
sh ./tools/dist_test.sh configs/rsmamba/name_to_config.py ${CHECKPOINT_FILE} ${GPU_NUM} # name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, GPU_NUM is the number of GPUs used