<div align="center"> <h2> RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model </h2> </div> <br> <div align="center"> <img src="resources/RSPrompter.png" width="800"/> </div> <br> <div align="center"> <a href="https://kychen.me/RSPrompter"> <span style="font-size: 20px; ">Project Page</span> </a> &nbsp;&nbsp;&nbsp;&nbsp; <a href="https://arxiv.org/abs/2306.16269"> <span style="font-size: 20px; ">arXiv</span> </a> &nbsp;&nbsp;&nbsp;&nbsp; <a href="https://huggingface.co/spaces/KyanChen/RSPrompter"> <span style="font-size: 20px; ">HFSpace</span> </a> </div> <br> <br>


<br> <br> <div align="center">

English | 简体中文

</div>

Introduction

This repository is the code implementation of the paper RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model, which is based on the MMDetection project.

The current branch has been tested under PyTorch 2.x and CUDA 12.1, supports Python 3.8+, and is compatible with most CUDA versions.

If you find this project helpful, please give us a star ⭐️, your support is our greatest motivation.

<details open> <summary>Main Features</summary>
  • An API and usage pattern fully consistent with MMDetection
  • Open-source implementations of SAM-seg, SAM-det, RSPrompter, and the other models in the paper
  • Support for AMP, DeepSpeed, and other training acceleration methods
  • Support for training and testing on multiple datasets
</details>

Update Log

🌟 2023.06.29 Released the RSPrompter project, which implements the SAM-seg, SAM-det, RSPrompter and other models in the paper based on Lightning and MMDetection.

🌟 2023.11.25 Updated the RSPrompter code to be fully consistent with MMDetection's API and usage.

🌟 2023.11.26 Added the LoRA efficient fine-tuning method, and made the input image size variable, reducing the memory usage of the model.

🌟 2023.11.26 Provided a reference for the memory usage of each model, see Common Problems for details.

🌟 2023.11.30 Updated the paper content, see Arxiv for details.

TODO

  • [X] Consistent API interface and usage method with MMDetection
  • [X] Reduce model memory usage while preserving performance, by shrinking the image input and applying large-model fine-tuning techniques
  • [X] Dynamically variable image size input
  • [X] Efficient fine-tuning method in the model
  • [ ] Add SAM-cls model

Installation

Dependencies

  • Linux or Windows
  • Python 3.8+, recommended 3.10
  • PyTorch 2.0 or higher, recommended 2.1
  • CUDA 11.7 or higher, recommended 12.1
  • MMCV 2.0 or higher, recommended 2.1

Environment Installation

We recommend using Miniconda for installation. The following commands will create a virtual environment named rsprompter and install PyTorch and MMCV.

Note: If you have experience with PyTorch and have already installed it, you can skip to the next section. Otherwise, you can follow these steps to prepare.

<details open>

Step 0: Install Miniconda.

Step 1: Create a virtual environment named rsprompter and activate it.

conda create -n rsprompter python=3.10 -y
conda activate rsprompter

Step 2: Install PyTorch 2.1.x.

Linux/Windows:

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121

Or

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia

Step 3: Install MMCV 2.1.x.

pip install -U openmim
mim install mmcv==2.1.0

Step 4: Install other dependencies.

pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py

Step 5: [Optional] Install DeepSpeed.

If you want to train the model with DeepSpeed, you need to install it first; for installation details, refer to the official DeepSpeed documentation.

pip install deepspeed==0.13.4

Note: DeepSpeed support on Windows is still incomplete; we recommend using DeepSpeed on Linux.

</details>

Install RSPrompter

Download or clone the RSPrompter repository.

git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter

Dataset Preparation

<details open>

Basic Instance Segmentation Dataset

We provide the instance segmentation dataset preparation method used in the paper.

WHU Building Dataset

  • Image download address: WHU Building Dataset

  • Semantic-to-instance label conversion: we provide a conversion script that turns the WHU Building dataset's semantic labels into instance labels.
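The actual conversion script ships with the repository. Purely as an illustration of the underlying idea (this is not the repo's script), a binary semantic mask can be split into separate instances with connected-component labeling:

```python
from collections import deque

def semantic_to_instances(mask):
    """Split a binary semantic mask (list of 0/1 rows) into instance labels
    via 4-connected component labeling (BFS flood fill)."""
    h, w = len(mask), len(mask[0])
    labels = [[0] * w for _ in range(h)]
    next_id = 0
    for y in range(h):
        for x in range(w):
            if mask[y][x] and not labels[y][x]:
                next_id += 1                     # start a new instance
                labels[y][x] = next_id
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy-1, cx), (cy+1, cx), (cy, cx-1), (cy, cx+1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not labels[ny][nx]:
                            labels[ny][nx] = next_id
                            queue.append((ny, nx))
    return next_id, labels

# Two disconnected buildings in one semantic mask -> two instances
mask = [[1, 1, 0, 0],
        [1, 0, 0, 1],
        [0, 0, 0, 1]]
n, labels = semantic_to_instances(mask)
print(n)  # 2
```

Each resulting component would then be encoded as one COCO instance annotation (e.g. with pycocotools, which is among the project's dependencies).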

NWPU VHR-10 Dataset

SSDD Dataset

Note: In the data folder of this project, we provide the instance labels of the above datasets, which you can use directly.

Organization Method

You can also choose other sources to download the data, but you need to organize the dataset in the following format:

${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/NWPU
├── annotations
│   ├── train.json
│   ├── val.json
│   └── test.json
└── images
    ├── train
    ├── val
    └── test

Note: In the project folder, we provide a folder named data, which contains examples of the organization method of the above datasets.
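As a quick sanity check that a dataset follows this layout, a small helper like the following can be used (a generic sketch, not part of the repo; the temporary root directory stands in for a real path such as /home/username/data/NWPU):

```python
import json
import os
import tempfile

def check_dataset_layout(root):
    """Verify the COCO-style layout described above:
    annotations/{train,val,test}.json and images/{train,val,test}/ under root."""
    problems = []
    for split in ("train", "val", "test"):
        ann = os.path.join(root, "annotations", f"{split}.json")
        img_dir = os.path.join(root, "images", split)
        if not os.path.isfile(ann):
            problems.append(f"missing annotation file: {ann}")
        if not os.path.isdir(img_dir):
            problems.append(f"missing image directory: {img_dir}")
    return problems

# Demo: build a minimal tree in a temporary directory and check it
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "annotations"))
for split in ("train", "val", "test"):
    os.makedirs(os.path.join(root, "images", split))
    with open(os.path.join(root, "annotations", f"{split}.json"), "w") as f:
        json.dump({"images": [], "annotations": [], "categories": []}, f)
print(check_dataset_layout(root))  # []
```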

Other Datasets

If you want to use other datasets, you can refer to MMDetection documentation to prepare the datasets.

</details>

Model Training

SAM-based Model

Config File and Main Parameter Parsing

We provide the configuration files for the SAM-based models used in the paper in the configs/rsprompter folder. The config files are fully consistent with MMDetection's API and usage. Below we explain some of the main parameters; for more on their meaning, refer to the MMDetection documentation.

<details open>

Parameter Parsing:

  • work_dir: The output path for model training; generally does not need to be modified.
  • default_hooks-CheckpointHook: Checkpoint-saving configuration during training; generally does not need to be modified.
  • default_hooks-visualization: Visualization configuration during training; comment it out for training and uncomment it for testing.
  • vis_backends-WandbVisBackend: Configuration of the web-based visualization backend; after uncommenting it, you need to register an account on the wandb website, and you can then view the training visualizations in your browser.
  • num_classes: The number of categories in the dataset; modify it to match your dataset.
  • prompt_shape: The shape of the prompt; the first value is $N_p$ and the second is $K_p$. Generally does not need to be modified.
  • hf_sam_pretrain_name: The name of the SAM model on HuggingFace; change it to your own path. You can use the download script to fetch the model.
  • hf_sam_pretrain_ckpt_path: The checkpoint path of the SAM model; change it to your own path. You can use the download script to fetch the checkpoint.
  • model-decoder_freeze: Whether to freeze the SAM decoder parameters; generally does not need to be modified.
  • model-neck-feature_aggregator-hidden_channels: The number of hidden channels in the feature aggregator; generally does not need to be modified.
  • model-neck-feature_aggregator-select_layers: The backbone layers selected by the feature aggregator; modify it according to the chosen SAM backbone type.
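Tying a few of these parameters together, a derived config might look roughly like the following. This is a hypothetical sketch in MMDetection's Python config style: the base-config filename and exact key paths are assumptions and should be checked against the real files under configs/rsprompter; the values shown (10 classes for NWPU VHR-10, a placeholder prompt_shape) are illustrative only.

```python
# Hypothetical derived config for a 10-class dataset (NWPU VHR-10).
_base_ = './rsprompter_anchor_nwpu.py'   # assumed base config name

num_classes = 10          # number of categories in your dataset
prompt_shape = (60, 4)    # (N_p, K_p) -- illustrative; usually left at the defaults

# Paths to the pretrained SAM weights, fetched with the download script
hf_sam_pretrain_name = 'facebook/sam-vit-base'
hf_sam_pretrain_ckpt_path = 'pretrain/sam-vit-base.pth'

model = dict(
    decoder_freeze=True,  # keep the SAM decoder frozen
)
```

MMDetection merges such a file onto its `_base_`, so only the keys you list here override the base values.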
