Metadata-Version: 2.4
Name: diffusers
Version: 0.27.0.dev0
Summary: State-of-the-art diffusion in PyTorch and JAX.
Home-page: https://github.com/huggingface/diffusers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)
Author-email: patrick@huggingface.co
License: Apache 2.0 License
Keywords: deep learning diffusion jax pytorch stable diffusion audioldm
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Requires-Python: >=3.8.0
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: importlib_metadata
Requires-Dist: filelock
Requires-Dist: huggingface-hub
Requires-Dist: numpy
Requires-Dist: regex!=2019.12.17
Requires-Dist: requests
Requires-Dist: safetensors>=0.3.1
Requires-Dist: Pillow
Provides-Extra: quality
Requires-Dist: urllib3<=2.0.0; extra == "quality"
Requires-Dist: isort>=5.5.4; extra == "quality"
Requires-Dist: ruff==0.1.5; extra == "quality"
Requires-Dist: hf-doc-builder>=0.3.0; extra == "quality"
Provides-Extra: docs
Requires-Dist: hf-doc-builder>=0.3.0; extra == "docs"
Provides-Extra: training
Requires-Dist: accelerate>=0.11.0; extra == "training"
Requires-Dist: datasets; extra == "training"
Requires-Dist: protobuf<4,>=3.20.3; extra == "training"
Requires-Dist: tensorboard; extra == "training"
Requires-Dist: Jinja2; extra == "training"
Requires-Dist: peft>=0.6.0; extra == "training"
Provides-Extra: test
Requires-Dist: compel==0.1.8; extra == "test"
Requires-Dist: GitPython<3.1.19; extra == "test"
Requires-Dist: datasets; extra == "test"
Requires-Dist: Jinja2; extra == "test"
Requires-Dist: invisible-watermark>=0.2.0; extra == "test"
Requires-Dist: k-diffusion>=0.0.12; extra == "test"
Requires-Dist: librosa; extra == "test"
Requires-Dist: parameterized; extra == "test"
Requires-Dist: pytest; extra == "test"
Requires-Dist: pytest-timeout; extra == "test"
Requires-Dist: pytest-xdist; extra == "test"
Requires-Dist: requests-mock==1.10.0; extra == "test"
Requires-Dist: safetensors>=0.3.1; extra == "test"
Requires-Dist: sentencepiece!=0.1.92,>=0.1.91; extra == "test"
Requires-Dist: scipy; extra == "test"
Requires-Dist: torchvision; extra == "test"
Requires-Dist: transformers>=4.25.1; extra == "test"
Provides-Extra: torch
Requires-Dist: torch>=1.4; extra == "torch"
Requires-Dist: accelerate>=0.11.0; extra == "torch"
Provides-Extra: flax
Requires-Dist: jax>=0.4.1; extra == "flax"
Requires-Dist: jaxlib>=0.4.1; extra == "flax"
Requires-Dist: flax>=0.4.1; extra == "flax"
Provides-Extra: dev
Requires-Dist: urllib3<=2.0.0; extra == "dev"
Requires-Dist: isort>=5.5.4; extra == "dev"
Requires-Dist: ruff==0.1.5; extra == "dev"
Requires-Dist: hf-doc-builder>=0.3.0; extra == "dev"
Requires-Dist: compel==0.1.8; extra == "dev"
Requires-Dist: GitPython<3.1.19; extra == "dev"
Requires-Dist: datasets; extra == "dev"
Requires-Dist: Jinja2; extra == "dev"
Requires-Dist: invisible-watermark>=0.2.0; extra == "dev"
Requires-Dist: k-diffusion>=0.0.12; extra == "dev"
Requires-Dist: librosa; extra == "dev"
Requires-Dist: parameterized; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-timeout; extra == "dev"
Requires-Dist: pytest-xdist; extra == "dev"
Requires-Dist: requests-mock==1.10.0; extra == "dev"
Requires-Dist: safetensors>=0.3.1; extra == "dev"
Requires-Dist: sentencepiece!=0.1.92,>=0.1.91; extra == "dev"
Requires-Dist: scipy; extra == "dev"
Requires-Dist: torchvision; extra == "dev"
Requires-Dist: transformers>=4.25.1; extra == "dev"
Requires-Dist: accelerate>=0.11.0; extra == "dev"
Requires-Dist: datasets; extra == "dev"
Requires-Dist: protobuf<4,>=3.20.3; extra == "dev"
Requires-Dist: tensorboard; extra == "dev"
Requires-Dist: Jinja2; extra == "dev"
Requires-Dist: peft>=0.6.0; extra == "dev"
Requires-Dist: hf-doc-builder>=0.3.0; extra == "dev"
Requires-Dist: torch>=1.4; extra == "dev"
Requires-Dist: accelerate>=0.11.0; extra == "dev"
Requires-Dist: jax>=0.4.1; extra == "dev"
Requires-Dist: jaxlib>=0.4.1; extra == "dev"
Requires-Dist: flax>=0.4.1; extra == "dev"
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license
Dynamic: license-file
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# BrushNet

This repository contains the implementation of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion".

Keywords: Image Inpainting, Diffusion Models, Image Generation

> [Xuan Ju](https://github.com/juxuan27)<sup>12</sup>, [Xian Liu](https://alvinliu0.github.io/)<sup>12</sup>, [Xintao Wang](https://xinntao.github.io/)<sup>1*</sup>, [Yuxuan Bian](https://scholar.google.com.hk/citations?user=HzemVzoAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>, [Ying Shan](https://www.linkedin.com/in/YingShanProfile/)<sup>1</sup>, [Qiang Xu](https://cure-lab.github.io/)<sup>2*</sup><br>
> <sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author


<p align="center">
  <a href="https://tencentarc.github.io/BrushNet/">๐ŸŒProject Page</a> |
  <a href="https://arxiv.org/abs/2403.06976">๐Ÿ“œArxiv</a> |
  <a href="https://forms.gle/9TgMZ8tm49UYsZ9s5">๐Ÿ—„๏ธData</a> |
  <a href="https://drive.google.com/file/d/1IkEBWcd2Fui2WHcckap4QFPcCI0gkHBh/view">๐Ÿ“นVideo</a> |
  <a href="https://huggingface.co/spaces/TencentARC/BrushNet">๐Ÿค—Hugging Face Demo</a> |
</p>



**📖 Table of Contents**


  - [🛠️ Method Overview](#️-method-overview)
  - [🚀 Getting Started](#-getting-started)
    - [Environment Requirement 🌍](#environment-requirement-)
    - [Data Download ⬇️](#data-download-️)
  - [🏃🏼 Running Scripts](#-running-scripts)
    - [Training 🤯](#training-)
    - [Inference 📜](#inference-)
    - [Evaluation 📏](#evaluation-)
  - [🤝🏼 Cite Us](#-cite-us)
  - [💖 Acknowledgement](#-acknowledgement)


## TODO


- [x] Release training and inference code
- [x] Release checkpoint (sdv1.5)
- [ ] Release checkpoint (sdxl)
- [x] Release evaluation code
- [x] Release gradio demo

## 🛠️ Method Overview

BrushNet is a diffusion-based, text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Its architectural design incorporates two key insights: (1) separating the masked image features from the noisy latent reduces the model's learning load, and (2) dense per-pixel control over the entire pre-trained model makes it better suited to image inpainting. More analysis can be found in the main paper.

![](examples/brushnet/src/model.png)
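
To make the decomposition concrete, here is a toy sketch of the dual-branch idea. The module names and shapes are made up for illustration and are not the repo's real classes: the extra branch consumes the noisy latent, the VAE latent of the masked image, and the downsampled mask concatenated channel-wise, and its features are added into the frozen backbone for dense per-pixel control.

```python
import torch
import torch.nn as nn

# Toy illustration of the decomposed dual-branch design; these modules are
# schematic stand-ins, NOT the repo's real implementation.
class TinyBrushBranch(nn.Module):
    def __init__(self, in_ch: int, feat_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, feat_ch, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

class TinyFrozenBackbone(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, latent, control):
        # Dense per-pixel control: branch features are summed into the backbone.
        return self.conv(latent + control)

noisy_latent = torch.randn(1, 4, 64, 64)   # latent being denoised
masked_latent = torch.randn(1, 4, 64, 64)  # VAE latent of the masked image
mask = torch.rand(1, 1, 64, 64)            # downsampled inpainting mask

branch = TinyBrushBranch(in_ch=4 + 4 + 1, feat_ch=4)
backbone = TinyFrozenBackbone(ch=4)
control = branch(torch.cat([noisy_latent, masked_latent, mask], dim=1))
denoised = backbone(noisy_latent, control)  # one denoising step, schematically
```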



## 🚀 Getting Started

### Environment Requirement 🌍

BrushNet has been implemented and tested with PyTorch 1.12.1 and Python 3.9.

Clone the repo:

```
git clone https://github.com/TencentARC/BrushNet.git
```

We recommend first using `conda` to create a virtual environment, then installing `pytorch` following the [official instructions](https://pytorch.org/). For example:


```
conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```

Then, you can install the diffusers library implemented in this repo with:

```
pip install -e .
```

After that, you can install the required packages through:

```
cd examples/brushnet/
pip install -r requirements.txt
```
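
A quick sanity check that the editable install and the CUDA build of PyTorch are both visible (a minimal sketch; the exact versions printed depend on your environment):

```python
import torch
import diffusers

# Expect the PyTorch version installed above, and True on a CUDA machine.
print(torch.__version__, torch.cuda.is_available())
# Expect the version of the editable diffusers install from this repo.
print(diffusers.__version__)
```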

### Data Download ⬇️


**Dataset**

You can download BrushData and BrushBench [here](https://forms.gle/9TgMZ8tm49UYsZ9s5) (as well as the EditBench we re-processed); they are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license. The data should be structured as follows:

```
|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json
```


Note: *We provide only part of BrushData due to space limits. Please email juxuan.27@gmail.com if you need the full dataset.*
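
The BrushData shards are plain tar archives, so you can peek at a shard's contents with the standard library before training (the shard path follows the layout above):

```python
import tarfile

# List the first few members of one BrushData shard.
with tarfile.open("data/BrushData/00200.tar") as tar:
    for member in tar.getmembers()[:10]:
        print(member.name, member.size)
```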


**Checkpoints**

Checkpoints of BrushNet can be downloaded from [here](https://drive.google.com/drive/folders/1fqmS1CEOvXCxNWFrsSYd_jHYXxrydh1n?usp=drive_link). The ckpt folder contains our pretrained BrushNet checkpoints and a pretrained Stable Diffusion checkpoint (e.g., realisticVisionV60B1_v51VAE from [Civitai](https://civitai.com/)). You can use `scripts/convert_original_stable_diffusion_to_diffusers.py` to process other models downloaded from Civitai; an example command is shown after the layout below. The data should be structured as follows:



```
|-- data
    |-- BrushData
    |-- BrushDench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt
        |-- ...
```
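
To convert another model downloaded from Civitai into this layout, a command along these lines should work (the input path and output folder are placeholders; run the script with `--help` for the full option list):

```
python scripts/convert_original_stable_diffusion_to_diffusers.py \
--checkpoint_path path/to/downloaded_model.safetensors \
--from_safetensors \
--dump_path data/ckpt/my_converted_model
```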

The checkpoints in `segmentation_mask_brushnet_ckpt` were trained on BrushData with segmentation priors (masks share the shapes of objects), while `random_mask_brushnet_ckpt` provides a more general checkpoint for arbitrary mask shapes.

## ๐Ÿƒ๐Ÿผ Running Scripts


### Training 🤯

You can train with segmentation masks using the script:

```
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300
```

To use a custom dataset, process your own data into the BrushData format and point `--train_data_dir` at it.

You can train with random masks by adding `--random_mask`:

```
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_randommask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--random_mask
```
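
Both commands launch a single process by default. For multi-GPU training, you can describe your hardware once with `accelerate config` and raise the process count with standard `accelerate` flags (a sketch assuming two GPUs; keep the BrushNet flags exactly as above):

```
accelerate config  # answer the prompts once to describe your hardware
accelerate launch --num_processes 2 examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300
```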



### Inference 📜

You can run inference with the script:

```
python examples/brushnet/test_brushnet.py
```
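
`test_brushnet.py` wires a BrushNet checkpoint into a Stable Diffusion pipeline. Below is a minimal sketch of that flow; the class names, argument order, and paths follow that script but should be treated as assumptions — consult `examples/brushnet/test_brushnet.py` for the authoritative version (it also performs extra mask preprocessing and blending).

```python
import torch
from diffusers import BrushNetModel, StableDiffusionBrushNetPipeline
from diffusers.utils import load_image

# Paths follow the checkpoint layout above; swap in your own as needed.
brushnet = BrushNetModel.from_pretrained(
    "data/ckpt/segmentation_mask_brushnet_ckpt", torch_dtype=torch.float16
)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE",
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
).to("cuda")

# Placeholder image/mask paths; the mask marks the region to inpaint.
init_image = load_image("path/to/image.png")
mask_image = load_image("path/to/mask.png")

result = pipe(
    "a beautiful scene",  # text prompt
    init_image,
    mask_image,
    num_inference_steps=50,
    brushnet_conditioning_scale=1.0,
).images[0]
result.save("output.png")
```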

Since BrushNet is trained on LAION, its performance is only assured on general scenarios. We recommend training on your own data (e.g., product exhibition, virtual try-on) if you have high-quality industrial application requirements. We would also appreciate it if you would contribute your trained model!

You can also run inference through the Gradio demo:

```
python examples/brushnet/app_brushnet.py
```


### Evaluation 📏

You can evaluate using the script:

```
python examples/brushnet/evaluate_brushnet.py \
--brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
--image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/inside \
--mapping_file data/BrushBench/mapping_file.json \
--base_dir data/BrushBench \
--mask_key inpainting_mask
```

The `--mask_key` flag indicates which kind of mask to use: `inpainting_mask` for inside inpainting and `outpainting_mask` for outside inpainting. The evaluation results (images and metrics) will be saved to `--image_save_path`.
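
For example, the outside-inpainting counterpart of the command above only swaps the mask key and the save path:

```
python examples/brushnet/evaluate_brushnet.py \
--brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
--image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/outside \
--mapping_file data/BrushBench/mapping_file.json \
--base_dir data/BrushBench \
--mask_key outpainting_mask
```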



*Note that you need to disable the NSFW detector in `src/diffusers/pipelines/brushnet/pipeline_brushnet.py#1261` to get correct evaluation results. Moreover, we find that different machines may generate different images, so we provide the results generated on our machine [here](https://drive.google.com/drive/folders/1dK3oIB2UvswlTtnIS1iHfx4s57MevWdZ?usp=sharing).*


## 🤝🏼 Cite Us

```
@misc{ju2024brushnet,
  title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion}, 
  author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
  year={2024},
  eprint={2403.06976},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```


## 💖 Acknowledgement
<span id="acknowledgement"></span>

Our code is modified from [diffusers](https://github.com/huggingface/diffusers); thanks to all the contributors!