RSPrompter
This is the PyTorch implementation of our paper "RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model".
Introduction
This repository is the code implementation of the paper RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model, which is based on the MMDetection project.
The current branch has been tested under PyTorch 2.x and CUDA 12.1, supports Python 3.7+, and is compatible with most CUDA versions.
If you find this project helpful, please give us a star ⭐️, your support is our greatest motivation.
Main Features
- An API and usage fully consistent with MMDetection
- Open-source implementations of the SAM-seg, SAM-det, RSPrompter, and other models from the paper
- Tested with AMP, DeepSpeed, and other training methods
- Supports training and testing on multiple datasets
Update Log
🌟 2023.06.29 Released the RSPrompter project, which implements the SAM-seg, SAM-det, RSPrompter and other models in the paper based on Lightning and MMDetection.
🌟 2023.11.25 Updated the code of RSPrompter, which is completely consistent with the API interface and usage method of MMDetection.
🌟 2023.11.26 Added the LoRA efficient fine-tuning method, and made the input image size variable, reducing the memory usage of the model.
🌟 2023.11.26 Provided a reference for the memory usage of each model, see Common Problems for details.
🌟 2023.11.30 Updated the paper; see arXiv for details.
TODO
- [X] Consistent API interface and usage method with MMDetection
- [X] Reduce model memory usage while preserving performance, by shrinking the input image size and applying large-model fine-tuning techniques
- [X] Dynamically variable image size input
- [X] Efficient fine-tuning method in the model
- [ ] Add SAM-cls model
Table of Contents
- Introduction
- Update Log
- TODO
- Table of Contents
- Installation
- Dataset Preparation
- Model Training
- Model Testing
- Image Prediction
- Common Problems
- Acknowledgement
- Citation
- License
- Contact
Installation
Dependencies
- Linux or Windows
- Python 3.7+, recommended 3.10
- PyTorch 2.0 or higher, recommended 2.1
- CUDA 11.7 or higher, recommended 12.1
- MMCV 2.0 or higher, recommended 2.1
Environment Installation
We recommend using Miniconda for installation. The following commands will create a virtual environment named rsprompter and install PyTorch and MMCV.
Note: If you have experience with PyTorch and have already installed it, you can skip to the next section. Otherwise, you can follow these steps to prepare.
Step 0: Install Miniconda.
Step 1: Create a virtual environment named rsprompter and activate it.
conda create -n rsprompter python=3.10 -y
conda activate rsprompter
Step 2: Install PyTorch 2.1.x.
Linux/Windows:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
Or
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
Step 3: Install MMCV 2.1.x.
pip install -U openmim
mim install mmcv==2.1.0
Step 4: Install other dependencies.
pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py
Step 5: [Optional] Install DeepSpeed.
If you want to use DeepSpeed to train the model, you need to install it first; for installation details, refer to the official DeepSpeed documentation.
pip install deepspeed==0.13.4
Note: The support for DeepSpeed under the Windows system is not perfect yet, we recommend that you use DeepSpeed under the Linux system.
Install RSPrompter
Download or clone the RSPrompter repository.
git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter
Dataset Preparation
Basic Instance Segmentation Dataset
We provide the instance segmentation dataset preparation method used in the paper.
WHU Building Dataset
- Image download address: WHU Building Dataset.
- Semantic label to instance label: We provide a conversion script to convert the semantic labels of the WHU building dataset into instance labels.
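The repository ships its own conversion script for this step. Purely as an illustration of the underlying idea (not the project's actual script), here is a minimal sketch that uses scipy.ndimage.label, which is already in the dependency list, to split a binary building mask into per-instance labels via connected components:

```python
import numpy as np
from scipy import ndimage

def semantic_to_instances(mask):
    """Assign a distinct positive id to each connected foreground
    region of a binary semantic mask; background stays 0."""
    labeled, num_instances = ndimage.label(mask > 0)
    return labeled, num_instances

# Example: two disjoint "buildings" yield two instance ids.
mask = np.zeros((6, 6), dtype=int)
mask[0:2, 0:2] = 1
mask[4:6, 4:6] = 1
labeled, n = semantic_to_instances(mask)  # n == 2
```

The project's real script presumably also writes COCO-style annotations; this sketch covers only the instance-labeling step.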
NWPU VHR-10 Dataset
- Image download address: NWPU VHR-10 Dataset.
- Instance label download address: NWPU VHR-10 Instance Label.
SSDD Dataset
- Image download address: SSDD Dataset.
- Instance label download address: SSDD Instance Label.
Note: In the data folder of this project, we provide the instance labels of the above datasets, which you can use directly.
Organization Method
You can also choose other sources to download the data, but you need to organize the dataset in the following format:
${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/NWPU
├── annotations
│ ├── train.json
│ ├── val.json
│ └── test.json
└── images
├── train
├── val
└── test
Note: In the project folder, we provide a folder named data, which contains examples of the organization method of the above datasets.
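To catch layout mistakes early, a small hedged helper (assuming exactly the directory structure shown above) can report which expected entries are missing under a dataset root:

```python
from pathlib import Path

# Expected entries, relative to the dataset root shown above.
EXPECTED = [
    "annotations/train.json",
    "annotations/val.json",
    "annotations/test.json",
    "images/train",
    "images/val",
    "images/test",
]

def missing_entries(dataset_root):
    """Return the expected files/directories absent under dataset_root."""
    root = Path(dataset_root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]

# e.g. missing_entries("/home/username/data/NWPU") returns [] when complete.
```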
Other Datasets
If you want to use other datasets, you can refer to MMDetection documentation to prepare the datasets.
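If you do bring your own COCO-style dataset, the relevant MMDetection 3.x dataloader fragment typically looks like the sketch below; the path, batch size, and filter settings are illustrative assumptions, not this project's actual config:

```python
# Illustrative MMDetection 3.x-style config fragment (values are examples).
dataset_type = 'CocoDataset'
data_root = '/home/username/data/NWPU/'  # example root from the layout above

train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/train.json',
        data_prefix=dict(img='images/train/'),
        filter_cfg=dict(filter_empty_gt=True),
    ),
)
```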
Model Training
SAM-based Model
Config File and Main Parameter Parsing
We provide the configuration files for the SAM-based models used in the paper in the configs/rsprompter folder. The config files follow MMDetection's API and conventions exactly. Below we explain some of the main parameters; for the meaning of the rest, refer to the MMDetection documentation.
Parameter Parsing:
- work_dir: The output path of model training; generally does not need to be modified.
- default_hooks - CheckpointHook: Checkpoint-saving configuration during training; generally does not need to be modified.
- default_hooks - visualization: Visualization configuration during training; comment it out during training and uncomment it during testing.
- vis_backends - WandbVisBackend: Configuration of the network-side visualization tool; after uncommenting it, you need to register an account on the wandb official website, and you can then view visualization results during training in a web browser.
- num_classes: The number of categories in the dataset; must be modified to match your dataset.
- prompt_shape: The shape of the prompt; the first parameter represents $N_p$ and the second represents $K_p$; generally does not need to be modified.
- hf_sam_pretrain_name: The name of the SAM model on HuggingFace Spaces; modify it to your own path; you can use the download script to download the model.
- hf_sam_pretrain_ckpt_path: The checkpoint path of the SAM model on HuggingFace Spaces; modify it to your own path; you can use the download script to download the checkpoint.
- model - decoder_freeze: Whether to freeze the parameters of the SAM decoder; generally does not need to be modified.
- model - neck - feature_aggregator - hidden_channels: The number of hidden channels of the feature aggregator; generally does not need to be modified.
- model - neck - feature_aggregator - select_layers: The layers selected by the feature aggregator; must be modified according to the chosen SAM backbone type.
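Putting these parameters together, here is a hedged sketch of what the corresponding fragment of a configs/rsprompter file might look like; every value below is an illustrative assumption, so consult the actual config files for the real ones:

```python
# Illustrative fragment only -- real values live in configs/rsprompter/*.py.
num_classes = 10                      # e.g. NWPU VHR-10 has 10 categories
prompt_shape = (60, 4)                # (N_p, K_p); illustrative values
hf_sam_pretrain_name = 'facebook/sam-vit-base'           # assumed HF model id
hf_sam_pretrain_ckpt_path = 'pretrain/sam_vit_base.pth'  # set to your own path

model = dict(
    decoder_freeze=True,              # keep the SAM decoder frozen
    neck=dict(
        feature_aggregator=dict(
            hidden_channels=32,                   # illustrative
            select_layers=list(range(1, 13, 2)),  # depends on SAM backbone
        ),
    ),
)
```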