BrushNet
[ECCV 2024] The official implementation of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"
This repository contains the implementation of the ECCV 2024 paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion".
Keywords: Image Inpainting, Diffusion Models, Image Generation
<p align="center"> <a href="https://tencentarc.github.io/BrushNet/">🌐Project Page</a> | <a href="https://arxiv.org/abs/2403.06976">📜Arxiv</a> | <a href="https://forms.gle/9TgMZ8tm49UYsZ9s5">🗄️Data</a> | <a href="https://drive.google.com/file/d/1IkEBWcd2Fui2WHcckap4QFPcCI0gkHBh/view">📹Video</a> | <a href="https://huggingface.co/spaces/TencentARC/BrushNet">🤗Hugging Face Demo</a> </p>
Xuan Ju<sup>1,2</sup>, Xian Liu<sup>1,2</sup>, Xintao Wang<sup>1*</sup>, Yuxuan Bian<sup>2</sup>, Ying Shan<sup>1</sup>, Qiang Xu<sup>2*</sup><br> <sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author
🔥 Update Log
- [2024/12/17] 📢 📢 BrushEdit is released: an efficient, white-box, free-form image editing tool powered by LLM agents and an all-in-one inpainting model.
- [2024/12/17] 📢 📢 BrushNetX (a stronger BrushNet) models are released.
TODO
- [x] Release training and inference code
- [x] Release checkpoint (sdv1.5)
- [x] Release checkpoint (sdxl). Unfortunately, I only had V100 GPUs for training this checkpoint, which allow only a batch size of 1 and train slowly. As a result, the current checkpoint has been trained for only a small number of steps and does not perform well. Fortunately, yuanhang has volunteered to help train a better version. Please stay tuned! Thanks to yuanhang for his effort!
- [x] Release evaluation code
- [x] Release Gradio demo
- [x] Release ComfyUI demo. Thanks to nullquant (ComfyUI-BrushNet) and kijai (ComfyUI-BrushNet-Wrapper) for helping!
- [x] Release training data. Thanks to random123123 for helping!
- [x] We used BrushNet to participate in the CVPR 2024 GenAI Media Generation Challenge Workshop and won the top prize! The solution is provided in InstructionGuidedEditing.
- [x] Release a new version of the checkpoint (sdxl).
🛠️ Method Overview
BrushNet is a diffusion-based, text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Our architectural design incorporates two key insights: (1) separating the masked image features from the noisy latent reduces the model's learning load, and (2) leveraging dense per-pixel control over the entire pre-trained model makes it better suited to image inpainting. More analysis can be found in the main paper.
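To make "dense per-pixel control" concrete, here is a toy illustration, not code from this repository. It assumes a ControlNet-style connector: branch features pass through a 1x1 convolution whose weights and bias start at zero and are then added to the frozen main-branch features at every pixel, so at initialization the pre-trained model's behaviour is unchanged. The class `ZeroConv1x1`, the function `inject`, and the `scale` blending weight are hypothetical names; feature maps are plain nested lists (`[channels][height][width]`) to keep the sketch dependency-free.

```python
from typing import List

# Feature map as [channels][height][width] nested lists (toy, dependency-free).
FeatureMap = List[List[List[float]]]


class ZeroConv1x1:
    """Hypothetical 1x1 convolution whose weights and bias are initialized to zero."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        self.weight = [[0.0] * in_channels for _ in range(out_channels)]
        self.bias = [0.0] * out_channels

    def __call__(self, x: FeatureMap) -> FeatureMap:
        height, width = len(x[0]), len(x[0][0])
        # A 1x1 convolution is a per-pixel linear map across channels.
        return [
            [
                [
                    self.bias[o] + sum(self.weight[o][i] * x[i][r][c] for i in range(len(x)))
                    for c in range(width)
                ]
                for r in range(height)
            ]
            for o in range(len(self.weight))
        ]


def inject(main: FeatureMap, branch: FeatureMap, conv: ZeroConv1x1, scale: float = 1.0) -> FeatureMap:
    """Add scaled, projected branch features to the frozen main-branch features at every pixel."""
    delta = conv(branch)
    return [
        [[m + scale * d for m, d in zip(m_row, d_row)] for m_row, d_row in zip(m_ch, d_ch)]
        for m_ch, d_ch in zip(main, delta)
    ]


main_features = [[[1.0, 2.0]]]    # 1 channel, 1x2 pixels
branch_features = [[[5.0, 5.0]]]
conv = ZeroConv1x1(in_channels=1, out_channels=1)

# At initialization the injection is a no-op, so the pre-trained model is untouched.
assert inject(main_features, branch_features, conv) == main_features

# Once training moves the weights away from zero, the branch steers every pixel.
conv.weight[0][0] = 0.5
print(inject(main_features, branch_features, conv))  # [[[3.5, 4.5]]]
```

Starting from zero is the design point: training can only gradually add inpainting guidance on top of the frozen model instead of disrupting it at step one.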

🚀 Getting Started
Environment Requirement 🌍
BrushNet has been implemented and tested on PyTorch 1.12.1 with Python 3.9.
Clone the repo:
git clone https://github.com/TencentARC/BrushNet.git
We recommend first using conda to create a virtual environment, then installing PyTorch following the official instructions. For example:
conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
Then, from the repository root, install diffusers (implemented in this repo) with:
pip install -e .
After that, install the required packages through:
cd examples/brushnet/
pip install -r requirements.txt
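Because the code was tested on a specific PyTorch and Python combination, a quick sanity check of the active environment can save debugging time. The sketch below is a hypothetical helper, not part of the repo; it uses only the standard library's `importlib.metadata` to compare installed versions against the ones listed above. Other versions may well work, so treat mismatches as warnings rather than errors.

```python
import sys
from importlib import metadata
from typing import Callable, Dict, Optional

# Versions from the install commands above; other versions are not necessarily broken.
TESTED = {"torch": "1.12.1", "torchvision": "0.13.1", "torchaudio": "0.12.1"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare(
    tested: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, str]:
    """Map each package to 'ok', 'missing', or a 'found X, tested Y' warning."""
    report = {}
    for name, wanted in tested.items():
        found = lookup(name)
        if found is None:
            report[name] = "missing"
        # CUDA wheels such as 1.12.1+cu116 carry a local suffix; compare the public part only.
        elif found.split("+")[0] == wanted:
            report[name] = "ok"
        else:
            report[name] = f"found {found}, tested {wanted}"
    return report


if __name__ == "__main__":
    print(f"python {sys.version_info.major}.{sys.version_info.minor} (tested: 3.9)")
    for pkg, status in compare(TESTED).items():
        print(f"{pkg}: {status}")
```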
Data Download ⬇️
Dataset
You can download BrushData and BrushBench here (as well as the EditBench we re-processed), which are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license. The data should be organized as follows:
|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json
Note: Due to space limits, we only provide part of BrushData on Google Drive. random123123 has kindly uploaded the full dataset to Hugging Face here. Thanks for his help!
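Before training or evaluating, it can help to confirm that the folders match the expected layout. A minimal sketch, assuming the data root is `./data` and that only the entries shown in the tree are required; the helper name `check_data_root` and the `REQUIRED` table are hypothetical and are read off the tree, not taken from the repository's code. The full Hugging Face upload may be organized differently, so adjust as needed.

```python
from pathlib import Path
from typing import List

# Entries from the expected layout; shard names vary, so BrushData only needs some .tar files.
REQUIRED = {
    "BrushBench": ["images", "mapping_file.json"],
    "EditBench": ["images", "mapping_file.json"],
}


def check_data_root(root: Path) -> List[str]:
    """Return human-readable problems; an empty list means the layout looks right."""
    problems = []
    brushdata = root / "BrushData"
    if not brushdata.is_dir():
        problems.append("missing BrushData/")
    elif not any(brushdata.glob("*.tar")):
        problems.append("BrushData/ contains no .tar shards")
    for bench, entries in REQUIRED.items():
        for entry in entries:
            if not (root / bench / entry).exists():
                problems.append(f"missing {bench}/{entry}")
    return problems


if __name__ == "__main__":
    issues = check_data_root(Path("data"))
    print("layout OK" if not issues else "\n".join(issues))
```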
Checkpoints
Checkpoints of BrushNet can be downloaded from here. The ckpt folder contains:
- BrushNet pretrained checkpoints for Stable Diffusion v1.5 (segmentation_mask_brushnet_ckpt and random_mask_brushnet_ckpt)
- A pretrained Stable Diffusion v1.5 checkpoint (e.g., realisticVisionV60B1_v51VAE from Civitai). You can use scripts/convert_original_stable_diffusion_to_diffusers.py to process other models downloaded from Civitai.
- BrushNet pretrained checkpoints for Stable Diffusion XL (segmentation_mask_brushnet_ckpt_sdxl_v1 and random_mask_brushnet_ckpt_sdxl_v0). A better version will be released shortly by yuanhang. Please stay tuned!
- A pretrained Stable Diffusion XL checkpoint (e.g., juggernautXL_juggernautX from Civitai). You can use StableDiffusionXLPipeline.from_single_file("path of safetensors").save_pretrained("path to save", safe_serialization=False) to process other models downloaded from Civitai.
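After conversion, the base model should be a diffusers-format folder whose `model_index.json` names the pipeline components, each stored in a same-named subfolder (such as `vae`). The stdlib-only sketch below, with the hypothetical helper name `missing_components`, reads that file and reports components without a subfolder. It assumes the usual diffusers conventions that keys starting with an underscore are metadata and that absent optional components are recorded as `[null, null]`; the example path is also an assumption based on the layout in this README.

```python
import json
from pathlib import Path
from typing import List


def missing_components(model_dir: Path) -> List[str]:
    """List components named in model_index.json that have no matching subfolder."""
    index = json.loads((model_dir / "model_index.json").read_text())
    missing = []
    for name, spec in index.items():
        if name.startswith("_"):
            continue  # metadata such as _class_name or _diffusers_version
        # Optional components that are absent are stored as [null, null].
        if isinstance(spec, list) and spec and spec[0] is None:
            continue
        if not (model_dir / name).is_dir():
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Hypothetical location of a converted base model.
    target = Path("data/ckpt/realisticVisionV60B1_v51VAE")
    if (target / "model_index.json").exists():
        print(missing_components(target) or "all components present")
```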
The resulting directory structure should look like this:
|-- data
    |-- BrushData
    |-- BrushBench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- segmentation_mask_brushnet_ckpt_sdxl_v0
        |-- random_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt_sdxl_v0
        |-- ...
The segmentation_mask_brushnet_ckpt and segmentation_mask_brushnet_ckpt_sdxl_v0 checkpoints are trained on BrushData with a segmentation prior (masks have the same shape as objects). The random_mask_brushnet_ckpt and random_mask_brushnet_ckpt_sdxl checkpoints are more general and suited to masks of arbitrary shape.
🏃🏼 Running Scripts
Training 🤯
You can train with segmentation masks using the following scripts:
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--checkpointing_steps 10000
# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--output_dir runs/logs/brushnetsdxl_segmentationmask \
--train_data_dir data/BrushData \
--resolution 1024 \
--learning_rate 1e-5 \
--train_batch_size 1 \
--gradient_accumulation_steps 4 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--checkpointing_steps 10000
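The two commands above use different batch settings. Under data-parallel training, one optimizer step sees train_batch_size × gradient_accumulation_steps × (number of processes) samples. The small helper below is hypothetical and pure arithmetic; the number of processes depends on your accelerate configuration, and the SD v1.5 row assumes the script's gradient accumulation defaults to 1 when the flag is omitted.

```python
def effective_batch_size(
    train_batch_size: int,
    gradient_accumulation_steps: int = 1,
    num_processes: int = 1,
) -> int:
    """Samples contributing to one optimizer step under data-parallel training."""
    for value in (train_batch_size, gradient_accumulation_steps, num_processes):
        if value < 1:
            raise ValueError("all factors must be positive integers")
    return train_batch_size * gradient_accumulation_steps * num_processes


# SD v1.5: batch 2, accumulation assumed to default to 1. SDXL: batch 1, 4 accumulation steps.
for gpus in (1, 8):
    sd15 = effective_batch_size(2, 1, gpus)
    sdxl = effective_batch_size(1, 4, gpus)
    print(f"{gpus} GPU(s): sd1.5 -> {sd15}, sdxl -> {sdxl}")
# 1 GPU(s): sd1.5 -> 2, sdxl -> 4
# 8 GPU(s): sd1.5 -> 16, sdxl -> 32
```

Gradient accumulation lets the memory-limited SDXL run reach a larger effective batch without fitting more than one sample per device at a time.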
To use a custom dataset, convert your data to the BrushData format and point --train_data_dir at it.
To train with random masks instead, add --random_mask:
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_randommask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--random_mask
# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
--output_dir runs/logs/brushnetsdxl_randommask \
--train_data_dir data/BrushData \
--resolution 1024 \
--learning_rate 1e-5 \
--train_batch_size 1 \
--gradient_accumulation_steps 4 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--checkpointing_steps 10000 \
--random_mask
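Long multi-line shell commands break silently when a trailing backslash is dropped. One alternative is to assemble the same command as an argument list in Python, which needs no line continuations. The sketch below only uses flags that appear in the commands above; the helper `build_launch_command` is hypothetical, and actually launching still requires accelerate and the repository's training scripts (for example via `subprocess.run(argv, check=True)`).

```python
import shlex
from typing import Dict, List, Union

FlagValue = Union[str, int, float, bool]


def build_launch_command(script: str, flags: Dict[str, FlagValue]) -> List[str]:
    """Turn {"resolution": 512, "random_mask": True} into an accelerate launch argv."""
    argv = ["accelerate", "launch", script]
    for name, value in flags.items():
        if value is True:
            argv.append(f"--{name}")           # boolean switch such as --random_mask
        elif value is not False:
            argv += [f"--{name}", str(value)]  # False means: omit the flag entirely
    return argv


# Mirrors the SD v1.5 random-mask command above.
sd15_random_mask = build_launch_command(
    "examples/brushnet/train_brushnet.py",
    {
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "output_dir": "runs/logs/brushnet_randommask",
        "train_data_dir": "data/BrushData",
        "resolution": 512,
        "learning_rate": "1e-5",  # kept as a string so it matches the shell command exactly
        "train_batch_size": 2,
        "tracker_project_name": "brushnet",
        "report_to": "tensorboard",
        "resume_from_checkpoint": "latest",
        "validation_steps": 300,
        "random_mask": True,
    },
)
print(shlex.join(sd15_random_mask))  # inspect before launching
```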
Inference 📜
You can run inference with the following scripts:
# sd v1.5
python examples/brushnet/test_brushnet.py
# sdxl
python examples/brushnet/test_brushnet_sdxl.py
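Inpainting takes an image plus a mask. A common convention, assumed here rather than taken from the test scripts, is that bright mask pixels mark the region to repaint, and that the masked image given to the model has those pixels blanked out. The dependency-free sketch below uses hypothetical helper names and represents images as nested lists of RGB tuples. It binarizes a mask by summed channel brightness (the cut-off of 255 is an assumption) and blanks the masked pixels; check the inference scripts for the exact preprocessing they apply.

```python
from typing import List, Tuple

RGB = Tuple[int, int, int]
RGBImage = List[List[RGB]]  # rows of RGB pixels, values 0-255


def binarize_mask(mask: RGBImage, threshold: int = 255) -> List[List[int]]:
    """1 where the summed RGB brightness exceeds threshold (region to inpaint), else 0."""
    # The threshold is an assumed cut-off; adjust it to your mask files.
    return [[1 if sum(px) > threshold else 0 for px in row] for row in mask]


def blank_masked_region(image: RGBImage, binary_mask: List[List[int]]) -> RGBImage:
    """Zero out pixels inside the mask so the model only sees the known context."""
    return [
        [(0, 0, 0) if m else px for px, m in zip(img_row, mask_row)]
        for img_row, mask_row in zip(image, binary_mask)
    ]


image = [[(10, 20, 30), (40, 50, 60)]]
mask = [[(0, 0, 0), (255, 255, 255)]]  # right pixel is white, so it will be inpainted
binary = binarize_mask(mask)
print(binary)                              # [[0, 1]]
print(blank_masked_region(image, binary))  # [[(10, 20, 30), (0, 0, 0)]]
```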
Since BrushNet is trained on LAION, its performance is only guaranteed for general scenarios. If you have demanding industrial applications (e.g., product exhibition, virtual try-on), we recommend training on your own high-quality data. We would also appreciate it if you would like to contribute your trained model!
You can also run inference through the Gradio demo:
# sd v1.