Spaces: Running on Zero

Commit f08d17a · Initial commit (0 parents)

Files changed:
- .gitattributes +35 -0
- LICENSE +201 -0
- README.md +14 -0
- app.py +112 -0
- instructions.json +4 -0
- main.py +95 -0
- main_json.py +92 -0
- requirements.txt +15 -0
- src/flux/block.py +339 -0
- src/flux/condition.py +133 -0
- src/flux/generate.py +322 -0
- src/flux/lora_controller.py +77 -0
- src/flux/pipeline_tools.py +52 -0
- src/flux/transformer.py +252 -0
- src/gradio/gradio_app.py +118 -0
- src/train/callbacks.py +268 -0
- src/train/data.py +401 -0
- src/train/model.py +185 -0
- src/train/train.py +214 -0
- train/README.md +138 -0
- train/config/canny_512.yaml +48 -0
- train/config/cartoon_512.yaml +44 -0
- train/config/fill_1024.yaml +47 -0
- train/config/scene_512.yaml +45 -0
- train/config/sr_512.yaml +48 -0
- train/config/subject_512.yaml +44 -0
- train/requirements.txt +15 -0
- train/script/data_download/data_download1.sh +1 -0
- train/script/data_download/data_download2.sh +3 -0
- train/script/train_canny.sh +13 -0
- train/script/train_cartoon.sh +15 -0
- train/script/train_scene.sh +13 -0
- train/script/train_subject.sh +13 -0
- utils.py +591 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

README.md
ADDED
@@ -0,0 +1,14 @@
---
title: IEAP
emoji: 👀
colorFrom: gray
colorTo: yellow
sdk: gradio
sdk_version: 5.32.0
app_file: app.py
pinned: false
license: apache-2.0
short_description: A demo for IEAP
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py
ADDED
@@ -0,0 +1,112 @@
import gradio as gr
from PIL import Image
from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
import openai
import os
import uuid
from src.flux.generate import generate, seed_everything


def process_image(api_key, seed, image, prompt):
    if not api_key:
        raise gr.Error("❌ Please enter a valid OpenAI API key.")

    openai.api_key = api_key

    # Generate a unique image ID to avoid file name conflict
    image_id = str(uuid.uuid4())
    seed_everything(seed)
    input_path = f"input_{image_id}.png"
    image.save(input_path)

    try:
        uri = encode_image_to_datauri(input_path)
        categories, instructions = cot_with_gpt(uri, prompt)
        # categories = ['Tone Transfer', 'Style Change']
        # instructions = ['Change the time to night', 'Change the style to watercolor']

        if not categories or not instructions:
            raise gr.Error("No editing steps returned by GPT. Try a more specific instruction.")

        intermediate_images = []
        current_image_path = input_path

        for i, (category, instruction) in enumerate(zip(categories, instructions)):
            print(f"[Step {i}] Category: {category} | Instruction: {instruction}")
            step_prefix = f"{image_id}_{i}"

            if category in ('Add', 'Remove', 'Replace'):
                if category == 'Add':
                    edited_image = infer_with_DiT('RoI Editing', current_image_path, instruction, category)
                else:
                    mask_image = roi_localization(current_image_path, instruction, category)
                    edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)

            elif category == 'Action Change':
                mask_image = roi_localization(current_image_path, instruction, category)
                inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
                changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', current_image_path, instruction, category)
                fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

            elif category in ('Move', 'Resize'):
                mask_image, changed_instance, x0, y1, scale = roi_localization(current_image_path, instruction, category)
                inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
                fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

            elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
                edited_image = infer_with_DiT('RoI Editing', current_image_path, instruction, category)

            elif category in ('Tone Transfer', 'Style Change'):
                edited_image = infer_with_DiT('Global Transformation', current_image_path, instruction, category)

            else:
                raise gr.Error(f"Invalid category returned: '{category}'")

            current_image_path = f"{step_prefix}.png"
            edited_image.save(current_image_path)
            intermediate_images.append(edited_image.copy())

        final_result = intermediate_images[-1] if intermediate_images else image
        return intermediate_images, final_result

    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}")


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ IEAP: Image Editing As Programs")

    with gr.Row():
        api_key_input = gr.Textbox(label="🔑 OpenAI API Key", type="password", placeholder="sk-...")

    with gr.Row():
        seed_slider = gr.Slider(
            label="🎲 Random Seed",
            minimum=0,
            maximum=1000000,
            value=3407,
            step=1,
            info="Drag to set the random seed for reproducibility"
        )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(label="Instruction", placeholder="e.g., Move the dog to the left and change its color to blue")
            submit_button = gr.Button("Submit")
        with gr.Column():
            result_gallery = gr.Gallery(label="Intermediate Steps", columns=2, height="auto")
            final_output = gr.Image(label="✅ Final Result")

    submit_button.click(
        fn=process_image,
        inputs=[api_key_input, seed_slider, image_input, prompt_input],
        outputs=[result_gallery, final_output]
    )

if __name__ == "__main__":
    demo.launch()

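For reference, a minimal sketch of driving the pipeline programmatically instead of through the UI (not part of the commit; it assumes the repo root as working directory, the requirements installed, and the weights expected by utils.py available locally — the key, file names, and prompt below are placeholders):

# Hypothetical driver script placed next to app.py.
from PIL import Image
from app import process_image  # importing app builds the Blocks UI but does not launch it

steps, final = process_image(
    api_key="sk-...",                         # OpenAI key, normally typed into the UI textbox
    seed=3407,                                # default value of the Gradio seed slider
    image=Image.open("input.png"),            # placeholder input image
    prompt="Move the dog to the left and change its color to blue",
)
final.save("final.png")                       # `steps` holds the per-step intermediate images
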
instructions.json
ADDED
@@ -0,0 +1,4 @@
{
    "categories": ["Move", "Resize"],
    "instructions": ["Move the woman to the right", "Minify the woman"]
}

main.py
ADDED
@@ -0,0 +1,95 @@
import os
import argparse
from PIL import Image
import openai
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
from src.flux.generate import generate, seed_everything

def main():
    parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
    parser.add_argument("image_path", help="Path to input image")
    parser.add_argument("prompt", help="Original instruction")
    parser.add_argument("--seed", type=int, default=3407, help="Random seed for reproducibility")
    args = parser.parse_args()

    seed_everything(args.seed)

    openai.api_key = "YOUR_API_KEY"

    if not openai.api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set.")

    os.makedirs("results", exist_ok=True)


    ###########################################
    ###        CoT -> instructions          ###
    ###########################################

    uri = encode_image_to_datauri(args.image_path)
    categories, instructions = cot_with_gpt(uri, args.prompt)
    print(categories)
    print(instructions)

    # categories = ['Move', 'Resize']
    # instructions = ['Move the woman to the right', 'Minify the woman']

    ###########################################
    ###     Neural Program Interpreter      ###
    ###########################################
    for i in range(len(categories)):
        if i == 0:
            image = args.image_path
        else:
            image = f"results/{i-1}.png"
        category = categories[i]
        instruction = instructions[i]
        if category in ('Add', 'Remove', 'Replace', 'Action Change', 'Move', 'Resize'):
            if category in ('Add', 'Remove', 'Replace'):
                if category == 'Add':
                    edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
                else:
                    ### RoI Localization
                    mask_image = roi_localization(image, instruction, category)
                    # mask_image.save("mask.png")
                    ### RoI Inpainting
                    edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)
            elif category == 'Action Change':
                ### RoI Localization
                mask_image = roi_localization(image, instruction, category)
                ### RoI Inpainting
                edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')  # inpainted bg
                ### RoI Editing
                changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', image, instruction, category)  # action change
                fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
                ### RoI Compositioning
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
            elif category in ('Move', 'Resize'):
                ### RoI Localization
                mask_image, changed_instance, x0, y1, scale = roi_localization(image, instruction, category)
                ### RoI Inpainting
                edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')  # inpainted bg
                # changed_instance, bottom_left, scale = layout_change(image, instruction)  # move/resize
                fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
                fusion_image.save("fusion.png")
                ### RoI Compositioning
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

        elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
            ### RoI Editing
            edited_image = infer_with_DiT('RoI Editing', image, instruction, category)

        elif category in ('Tone Transfer', 'Style Change'):
            ### Global Transformation
            edited_image = infer_with_DiT('Global Transformation', image, instruction, category)

        else:
            raise ValueError(f"Invalid category: '{category}'")

        image = edited_image
        image.save(f"results/{i}.png")


if __name__ == "__main__":
    main()

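A usage sketch for this CLI entry point (not part of the commit; assumes the repo root as working directory and the same local weights as the demo — the image path and instruction are placeholders):

# Hypothetical invocation of main.py from another Python script.
import subprocess

subprocess.run(
    [
        "python", "main.py",
        "input.png",                        # positional: image_path
        "Move the woman to the right",      # positional: prompt (original instruction)
        "--seed", "3407",                   # optional: seed (default 3407)
    ],
    check=True,
)
# Each step's output is written to results/0.png, results/1.png, ...
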
main_json.py
ADDED
@@ -0,0 +1,92 @@
import os
import argparse
import json
from PIL import Image
import openai
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion


def main():
    parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
    parser.add_argument("image_path", help="Path to input image")
    parser.add_argument("json_path", help="Path to JSON file containing categories and instructions")
    args = parser.parse_args()

    openai.api_key = "YOUR_API_KEY"

    if not openai.api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set.")

    os.makedirs("results", exist_ok=True)


    #######################################################
    ###          Load instructions from JSON            ###
    #######################################################
    try:
        with open(args.json_path, 'r') as f:
            data = json.load(f)
        categories = data.get('categories', [])
        instructions = data.get('instructions', [])

        if not categories or not instructions:
            raise ValueError("JSON file must contain 'categories' and 'instructions' arrays.")

        if len(categories) != len(instructions):
            raise ValueError("Length of 'categories' and 'instructions' must match.")

        print("Loaded instructions from JSON:")
        for i, (cat, instr) in enumerate(zip(categories, instructions)):
            print(f"Step {i+1}: [{cat}] {instr}")

    except Exception as e:
        raise ValueError(f"Failed to load JSON file: {str(e)}")

    ###################################################
    ###        Neural Program Interpreter           ###
    ###################################################
    for i in range(len(categories)):
        if i == 0:
            image = args.image_path
        else:
            image = f"results/{i-1}.png"
        category = categories[i]
        instruction = instructions[i]

        if category in ('Add', 'Remove', 'Replace', 'Action Change', 'Move', 'Resize'):
            if category in ('Add', 'Remove', 'Replace'):
                if category == 'Add':
                    edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
                else:
                    mask_image = roi_localization(image, instruction, category)
                    edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)
            elif category == 'Action Change':
                mask_image = roi_localization(image, instruction, category)
                edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
                changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', image, instruction, category)
                fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
            elif category in ('Move', 'Resize'):
                mask_image, changed_instance, x0, y1, scale = roi_localization(image, instruction, category)
                edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
                fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
                fusion_image.save("fusion.png")
                edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

        elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
            edited_image = infer_with_DiT('RoI Editing', image, instruction, category)

        elif category in ('Tone Transfer', 'Style Change'):
            edited_image = infer_with_DiT('Global Transformation', image, instruction, category)

        else:
            raise ValueError(f"Invalid category: '{category}'")

        image = edited_image
        image.save(f"results/{i}.png")
        print(f"Step {i+1} completed: {category} - {instruction}")


if __name__ == "__main__":
    main()

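A companion usage sketch for the JSON-driven variant (not part of the commit). main_json.py skips the GPT-4o decomposition and executes the categories/instructions listed in a JSON file such as the instructions.json added in this commit; the input image path below is a placeholder:

# Hypothetical invocation of main_json.py from another Python script.
import subprocess

subprocess.run(
    ["python", "main_json.py", "input.png", "instructions.json"],
    check=True,
)
# As in main.py, each step's output is saved as results/{i}.png.
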
requirements.txt
ADDED
@@ -0,0 +1,15 @@
diffusers==0.32.0
transformers==4.42.3
xtuner[deepspeed]==0.1.23
timm==1.0.9
mmdet==3.3.0
hydra-core==1.3.2
ninja==1.11.1
decord==0.6.0
peft==0.11.1
protobuf==5.29.4
sentencepiece==0.2.0
tornado==6.4.2
openai==0.28.0
gradio==5.32.0
opencv-python

src/flux/block.py
ADDED
@@ -0,0 +1,339 @@
import torch
from typing import List, Union, Optional, Dict, Any, Callable
from diffusers.models.attention_processor import Attention, F
from .lora_controller import enable_lora


def attn_forward(
    attn: Attention,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    condition_latents: torch.FloatTensor = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
    cond_rotary_emb: Optional[torch.Tensor] = None,
    model_config: Optional[Dict[str, Any]] = {},
) -> torch.FloatTensor:
    batch_size, _, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    with enable_lora(
        (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
    ):
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
    if encoder_hidden_states is not None:
        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(
                encoder_hidden_states_query_proj
            )
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(
                encoder_hidden_states_key_proj
            )

        # attention
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)

    if condition_latents is not None:
        cond_query = attn.to_q(condition_latents)
        cond_key = attn.to_k(condition_latents)
        cond_value = attn.to_v(condition_latents)

        cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )
        cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )
        if attn.norm_q is not None:
            cond_query = attn.norm_q(cond_query)
        if attn.norm_k is not None:
            cond_key = attn.norm_k(cond_key)

        if cond_rotary_emb is not None:
            cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
            cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)

    if condition_latents is not None:
        query = torch.cat([query, cond_query], dim=2)
        key = torch.cat([key, cond_key], dim=2)
        value = torch.cat([value, cond_value], dim=2)

    if not model_config.get("union_cond_attn", True):
        # If we don't want to use the union condition attention, we need to mask the attention
        # between the hidden states and the condition latents
        attention_mask = torch.ones(
            query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
        )
        condition_n = cond_query.shape[2]
        attention_mask[-condition_n:, :-condition_n] = False
        attention_mask[:-condition_n, -condition_n:] = False
    elif model_config.get("independent_condition", False):
        attention_mask = torch.ones(
            query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
        )
        condition_n = cond_query.shape[2]
        attention_mask[-condition_n:, :-condition_n] = False
    if hasattr(attn, "c_factor"):
        attention_mask = torch.zeros(
            query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
        )
        condition_n = cond_query.shape[2]
        bias = torch.log(attn.c_factor[0])
        attention_mask[-condition_n:, :-condition_n] = bias
        attention_mask[:-condition_n, -condition_n:] = bias
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
    )
    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    if encoder_hidden_states is not None:
        if condition_latents is not None:
            encoder_hidden_states, hidden_states, condition_latents = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[
                    :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
                ],
                hidden_states[:, -condition_latents.shape[1] :],
            )
        else:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

        with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if condition_latents is not None:
            condition_latents = attn.to_out[0](condition_latents)
            condition_latents = attn.to_out[1](condition_latents)

        return (
            (hidden_states, encoder_hidden_states, condition_latents)
            if condition_latents is not None
            else (hidden_states, encoder_hidden_states)
        )
    elif condition_latents is not None:
        # if there are condition_latents, we need to separate the hidden_states and the condition_latents
        hidden_states, condition_latents = (
            hidden_states[:, : -condition_latents.shape[1]],
            hidden_states[:, -condition_latents.shape[1] :],
        )
        return hidden_states, condition_latents
    else:
        return hidden_states


def block_forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    condition_latents: torch.FloatTensor,
    temb: torch.FloatTensor,
    cond_temb: torch.FloatTensor,
    cond_rotary_emb=None,
    image_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = {},
):
    use_cond = condition_latents is not None
    with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
        self.norm1_context(encoder_hidden_states, emb=temb)
    )

    if use_cond:
        (
            norm_condition_latents,
            cond_gate_msa,
            cond_shift_mlp,
            cond_scale_mlp,
            cond_gate_mlp,
        ) = self.norm1(condition_latents, emb=cond_temb)

    # Attention.
    result = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        encoder_hidden_states=norm_encoder_hidden_states,
        condition_latents=norm_condition_latents if use_cond else None,
        image_rotary_emb=image_rotary_emb,
        cond_rotary_emb=cond_rotary_emb if use_cond else None,
    )
    attn_output, context_attn_output = result[:2]
    cond_attn_output = result[2] if use_cond else None

    # Process attention outputs for the `hidden_states`.
    # 1. hidden_states
    attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = hidden_states + attn_output
    # 2. encoder_hidden_states
    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
    encoder_hidden_states = encoder_hidden_states + context_attn_output
    # 3. condition_latents
    if use_cond:
        cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
        condition_latents = condition_latents + cond_attn_output
        if model_config.get("add_cond_attn", False):
            hidden_states += cond_attn_output

    # LayerNorm + MLP.
    # 1. hidden_states
    norm_hidden_states = self.norm2(hidden_states)
    norm_hidden_states = (
        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    )
    # 2. encoder_hidden_states
    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
    )
    # 3. condition_latents
    if use_cond:
        norm_condition_latents = self.norm2(condition_latents)
        norm_condition_latents = (
            norm_condition_latents * (1 + cond_scale_mlp[:, None])
            + cond_shift_mlp[:, None]
        )

    # Feed-forward.
    with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
        # 1. hidden_states
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    # 2. encoder_hidden_states
    context_ff_output = self.ff_context(norm_encoder_hidden_states)
    context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
    # 3. condition_latents
    if use_cond:
        cond_ff_output = self.ff(norm_condition_latents)
        cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output

    # Process feed-forward outputs.
    hidden_states = hidden_states + ff_output
    encoder_hidden_states = encoder_hidden_states + context_ff_output
    if use_cond:
        condition_latents = condition_latents + cond_ff_output

    # Clip to avoid overflow.
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

    return encoder_hidden_states, hidden_states, condition_latents if use_cond else None


def single_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    condition_latents: torch.FloatTensor = None,
    cond_temb: torch.FloatTensor = None,
    cond_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = {},
):

    using_cond = condition_latents is not None
    residual = hidden_states
    with enable_lora(
        (
            self.norm.linear,
            self.proj_mlp,
        ),
        model_config.get("latent_lora", False),
    ):
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        if using_cond:
            residual_cond = condition_latents
            norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
            mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))

    attn_output = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        image_rotary_emb=image_rotary_emb,
        **(
            {
                "condition_latents": norm_condition_latents,
                "cond_rotary_emb": cond_rotary_emb if using_cond else None,
            }
            if using_cond
            else {}
        ),
    )
    if using_cond:
        attn_output, cond_attn_output = attn_output

    with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if using_cond:
            condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
            cond_gate = cond_gate.unsqueeze(1)
            condition_latents = cond_gate * self.proj_out(condition_latents)
            condition_latents = residual_cond + condition_latents

    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)

    return hidden_states if not using_cond else (hidden_states, condition_latents)

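For intuition (not part of the commit): when an attention module carries a c_factor, attn_forward builds a float attn_mask whose off-diagonal blocks between image/text tokens and condition tokens hold log(c_factor); scaled_dot_product_attention adds this to the attention logits, so cross-attention between the two token groups is re-weighted by c_factor (set from condition_scale in generate.py) while within-group attention is untouched. A tiny standalone sketch of that masking pattern, with made-up sizes:

# Toy illustration of the c_factor logit bias (hypothetical token counts, not the real model).
import torch
import torch.nn.functional as F

n, m, heads, dim = 6, 2, 1, 8            # n image/text tokens, m condition tokens
q = k = v = torch.randn(1, heads, n + m, dim)

c_factor = torch.tensor(0.5)              # plays the role of condition_scale
bias = torch.log(c_factor)

attn_mask = torch.zeros(n + m, n + m)     # additive bias on logits; 0 leaves them unchanged
attn_mask[-m:, :-m] = bias                # condition tokens attending to image/text tokens
attn_mask[:-m, -m:] = bias                # image/text tokens attending to condition tokens

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)                          # torch.Size([1, 1, 8, 8])
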
src/flux/condition.py
ADDED
@@ -0,0 +1,133 @@
import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2

from .pipeline_tools import encode_images

condition_dict = {
    "depth": 0,
    "canny": 1,
    "subject": 4,
    "coloring": 6,
    "deblurring": 7,
    "depth_pred": 8,
    "fill": 9,
    "sr": 10,
    "cartoon": 11,
    "scene": 12
}


class Condition(object):
    def __init__(
        self,
        condition_type: str,
        raw_img: Union[Image.Image, torch.Tensor] = None,
        condition: Union[Image.Image, torch.Tensor] = None,
        mask=None,
        position_delta=None,
        position_scale=1.0,
    ) -> None:
        self.condition_type = condition_type
        assert raw_img is not None or condition is not None
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition
        self.position_delta = position_delta
        self.position_scale = position_scale
        # TODO: Add mask support
        assert mask is None, "Mask not supported yet"

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """
        Returns the condition image.
        """
        if condition_type == "depth":
            from transformers import pipeline

            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cuda",
            )
            source_image = raw_img.convert("RGB")
            condition_img = depth_pipe(source_image)["depth"].convert("RGB")
            return condition_img
        elif condition_type == "canny":
            img = np.array(raw_img)
            edges = cv2.Canny(img, 100, 200)
            edges = Image.fromarray(edges).convert("RGB")
            return edges
        elif condition_type == "subject":
            return raw_img
        elif condition_type == "coloring":
            return raw_img.convert("L").convert("RGB")
        elif condition_type == "deblurring":
            condition_image = (
                raw_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(10))
                .convert("RGB")
            )
            return condition_image
        elif condition_type == "fill":
            return raw_img.convert("RGB")
        elif condition_type == "cartoon":
            return raw_img.convert("RGB")
        elif condition_type == "scene":
            return raw_img.convert("RGB")
        return self.condition

    @property
    def type_id(self) -> int:
        """
        Returns the type id of the condition.
        """
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id of the condition.
        """
        return condition_dict[condition_type]

    def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Encodes the condition into tokens, ids and type_id.
        """
        if self.condition_type in [
            "depth",
            "canny",
            "subject",
            "coloring",
            "deblurring",
            "depth_pred",
            "fill",
            "sr",
            "cartoon",
            "scene"
        ]:
            tokens, ids = encode_images(pipe, self.condition)
        else:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        if self.position_delta is None and self.condition_type == "subject":
            self.position_delta = [0, -self.condition.size[0] // 16]
        if self.position_delta is not None:
            ids[:, 1] += self.position_delta[0]
            ids[:, 2] += self.position_delta[1]
        if self.position_scale != 1.0:
            scale_bias = (self.position_scale - 1.0) / 2
            ids[:, 1] *= self.position_scale
            ids[:, 2] *= self.position_scale
            ids[:, 1] += scale_bias
            ids[:, 2] += scale_bias
        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id

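A minimal usage sketch for the Condition class (not part of the commit; it assumes a FLUX pipeline is loaded elsewhere in the pipeline code — the checkpoint name and image path below are illustrative placeholders):

# Hypothetical example of building and encoding a condition.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Build a "canny" condition from a raw image; get_condition() runs cv2.Canny internally.
cond = Condition("canny", raw_img=Image.open("input.png").convert("RGB"))

# encode() packs the condition into latent tokens plus positional ids and a type id,
# which generate() later concatenates with the image latents.
tokens, ids, type_id = cond.encode(pipe)
print(tokens.shape, ids.shape, type_id.shape)
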
src/flux/generate.py
ADDED
@@ -0,0 +1,322 @@
import torch
import yaml, os
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .transformer import tranformer_forward
from .condition import Condition

from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
    np,
)


def get_config(config_path: str = None):
    config_path = config_path or os.environ.get("XFL_CONFIG")
    if not config_path:
        return {}
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def prepare_params(
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    **kwargs: dict,
):
    return (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
    )


def seed_everything(seed: int = 42):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    np.random.seed(seed)


@torch.no_grad()
def generate(
    pipeline: FluxPipeline,
    conditions: List[Condition] = None,
    config_path: str = None,
    model_config: Optional[Dict[str, Any]] = {},
    condition_scale: float = 1.0,
    default_lora: bool = False,
    image_guidance_scale: float = 1.0,
    **params: dict,
):
    model_config = model_config or get_config(config_path).get("model", {})
    # print(model_config)
    if condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            module.c_factor = torch.ones(1, 1) * condition_scale

    self = pipeline
    (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
    ) = prepare_params(**params)

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None)
        if self.joint_attention_kwargs is not None
        else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 4.1. Prepare conditions
condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
|
179 |
+
use_condition = conditions is not None and len(conditions) > 0
|
180 |
+
if use_condition:
|
181 |
+
assert len(conditions) <= 1, "Only one condition is supported for now."
|
182 |
+
if not default_lora:
|
183 |
+
pipeline.set_adapters(conditions[0].condition_type)
|
184 |
+
for condition in conditions:
|
185 |
+
tokens, ids, type_id = condition.encode(self)
|
186 |
+
condition_latents.append(tokens) # [batch_size, token_n, token_dim]
|
187 |
+
condition_ids.append(ids) # [token_n, id_dim(3)]
|
188 |
+
condition_type_ids.append(type_id) # [token_n, 1]
|
189 |
+
condition_latents = torch.cat(condition_latents, dim=1)
|
190 |
+
condition_ids = torch.cat(condition_ids, dim=0)
|
191 |
+
condition_type_ids = torch.cat(condition_type_ids, dim=0)
|
192 |
+
|
193 |
+
# 5. Prepare timesteps
|
194 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
195 |
+
image_seq_len = latents.shape[1]
|
196 |
+
mu = calculate_shift(
|
197 |
+
image_seq_len,
|
198 |
+
self.scheduler.config.base_image_seq_len,
|
199 |
+
self.scheduler.config.max_image_seq_len,
|
200 |
+
self.scheduler.config.base_shift,
|
201 |
+
self.scheduler.config.max_shift,
|
202 |
+
)
|
203 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
204 |
+
self.scheduler,
|
205 |
+
num_inference_steps,
|
206 |
+
device,
|
207 |
+
timesteps,
|
208 |
+
sigmas,
|
209 |
+
mu=mu,
|
210 |
+
)
|
211 |
+
num_warmup_steps = max(
|
212 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
213 |
+
)
|
214 |
+
self._num_timesteps = len(timesteps)
|
215 |
+
|
216 |
+
# 6. Denoising loop
|
217 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
218 |
+
for i, t in enumerate(timesteps):
|
219 |
+
if self.interrupt:
|
220 |
+
continue
|
221 |
+
|
222 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
223 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
224 |
+
|
225 |
+
# handle guidance
|
226 |
+
if self.transformer.config.guidance_embeds:
|
227 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
228 |
+
guidance = guidance.expand(latents.shape[0])
|
229 |
+
else:
|
230 |
+
guidance = None
|
231 |
+
noise_pred = tranformer_forward(
|
232 |
+
self.transformer,
|
233 |
+
model_config=model_config,
|
234 |
+
# Inputs of the condition (new feature)
|
235 |
+
condition_latents=condition_latents if use_condition else None,
|
236 |
+
condition_ids=condition_ids if use_condition else None,
|
237 |
+
condition_type_ids=condition_type_ids if use_condition else None,
|
238 |
+
# Inputs to the original transformer
|
239 |
+
hidden_states=latents,
|
240 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
241 |
+
timestep=timestep / 1000,
|
242 |
+
guidance=guidance,
|
243 |
+
pooled_projections=pooled_prompt_embeds,
|
244 |
+
encoder_hidden_states=prompt_embeds,
|
245 |
+
txt_ids=text_ids,
|
246 |
+
img_ids=latent_image_ids,
|
247 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
248 |
+
return_dict=False,
|
249 |
+
)[0]
|
250 |
+
|
251 |
+
if image_guidance_scale != 1.0:
|
252 |
+
uncondition_latents = condition.encode(self, empty=True)[0]
|
253 |
+
unc_pred = tranformer_forward(
|
254 |
+
self.transformer,
|
255 |
+
model_config=model_config,
|
256 |
+
# Inputs of the condition (new feature)
|
257 |
+
condition_latents=uncondition_latents if use_condition else None,
|
258 |
+
condition_ids=condition_ids if use_condition else None,
|
259 |
+
condition_type_ids=condition_type_ids if use_condition else None,
|
260 |
+
# Inputs to the original transformer
|
261 |
+
hidden_states=latents,
|
262 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
263 |
+
timestep=timestep / 1000,
|
264 |
+
guidance=torch.ones_like(guidance),
|
265 |
+
pooled_projections=pooled_prompt_embeds,
|
266 |
+
encoder_hidden_states=prompt_embeds,
|
267 |
+
txt_ids=text_ids,
|
268 |
+
img_ids=latent_image_ids,
|
269 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
270 |
+
return_dict=False,
|
271 |
+
)[0]
|
272 |
+
|
273 |
+
noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
|
274 |
+
|
275 |
+
# compute the previous noisy sample x_t -> x_t-1
|
276 |
+
latents_dtype = latents.dtype
|
277 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
278 |
+
|
279 |
+
if latents.dtype != latents_dtype:
|
280 |
+
if torch.backends.mps.is_available():
|
281 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
282 |
+
latents = latents.to(latents_dtype)
|
283 |
+
|
284 |
+
if callback_on_step_end is not None:
|
285 |
+
callback_kwargs = {}
|
286 |
+
for k in callback_on_step_end_tensor_inputs:
|
287 |
+
callback_kwargs[k] = locals()[k]
|
288 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
289 |
+
|
290 |
+
latents = callback_outputs.pop("latents", latents)
|
291 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
292 |
+
|
293 |
+
# call the callback, if provided
|
294 |
+
if i == len(timesteps) - 1 or (
|
295 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
296 |
+
):
|
297 |
+
progress_bar.update()
|
298 |
+
|
299 |
+
if output_type == "latent":
|
300 |
+
image = latents
|
301 |
+
|
302 |
+
else:
|
303 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
304 |
+
latents = (
|
305 |
+
latents / self.vae.config.scaling_factor
|
306 |
+
) + self.vae.config.shift_factor
|
307 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
308 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
309 |
+
|
310 |
+
# Offload all models
|
311 |
+
self.maybe_free_model_hooks()
|
312 |
+
|
313 |
+
if condition_scale != 1:
|
314 |
+
for name, module in pipeline.transformer.named_modules():
|
315 |
+
if not name.endswith(".attn"):
|
316 |
+
continue
|
317 |
+
del module.c_factor
|
318 |
+
|
319 |
+
if not return_dict:
|
320 |
+
return (image,)
|
321 |
+
|
322 |
+
return FluxPipelineOutput(images=image)
|
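For reference, a minimal usage sketch of the `generate` entry point defined above, mirroring the call pattern in `src/gradio/gradio_app.py`. The asset path and prompt are placeholders, and the imports assume the repository root is on `PYTHONPATH`.

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load FLUX.1-schnell and the subject LoRA released with this repo.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

seed_everything(42)
# The condition image is encoded into extra tokens; position_delta shifts their 2D position ids.
condition = Condition(
    "subject", Image.open("assets/penguin.jpg").resize((512, 512)), position_delta=(0, 32)
)

result = generate(
    pipe,
    prompt="This item sits on a wooden table in a bright room.",
    conditions=[condition],
    height=512,
    width=512,
    num_inference_steps=8,
)
result.images[0].save("subject_result.jpg")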
src/flux/lora_controller.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
2 |
+
from typing import List, Any, Optional, Type
|
3 |
+
from .condition import condition_dict
|
4 |
+
|
5 |
+
class enable_lora:
|
6 |
+
def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
|
7 |
+
self.activated: bool = activated
|
8 |
+
if activated:
|
9 |
+
return
|
10 |
+
self.lora_modules: List[BaseTunerLayer] = [
|
11 |
+
each for each in lora_modules if isinstance(each, BaseTunerLayer)
|
12 |
+
]
|
13 |
+
self.scales = [
|
14 |
+
{
|
15 |
+
active_adapter: lora_module.scaling[active_adapter]
|
16 |
+
for active_adapter in lora_module.active_adapters
|
17 |
+
}
|
18 |
+
for lora_module in self.lora_modules
|
19 |
+
]
|
20 |
+
|
21 |
+
def __enter__(self) -> None:
|
22 |
+
if self.activated:
|
23 |
+
return
|
24 |
+
|
25 |
+
for lora_module in self.lora_modules:
|
26 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
27 |
+
continue
|
28 |
+
for active_adapter in lora_module.active_adapters:
|
29 |
+
if active_adapter in condition_dict.keys():
|
30 |
+
lora_module.scaling[active_adapter] = 0.0
|
31 |
+
|
32 |
+
def __exit__(
|
33 |
+
self,
|
34 |
+
exc_type: Optional[Type[BaseException]],
|
35 |
+
exc_val: Optional[BaseException],
|
36 |
+
exc_tb: Optional[Any],
|
37 |
+
) -> None:
|
38 |
+
if self.activated:
|
39 |
+
return
|
40 |
+
for i, lora_module in enumerate(self.lora_modules):
|
41 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
42 |
+
continue
|
43 |
+
for active_adapter in lora_module.active_adapters:
|
44 |
+
lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
|
45 |
+
|
46 |
+
|
47 |
+
class set_lora_scale:
|
48 |
+
def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
|
49 |
+
self.lora_modules: List[BaseTunerLayer] = [
|
50 |
+
each for each in lora_modules if isinstance(each, BaseTunerLayer)
|
51 |
+
]
|
52 |
+
self.scales = [
|
53 |
+
{
|
54 |
+
active_adapter: lora_module.scaling[active_adapter]
|
55 |
+
for active_adapter in lora_module.active_adapters
|
56 |
+
}
|
57 |
+
for lora_module in self.lora_modules
|
58 |
+
]
|
59 |
+
self.scale = scale
|
60 |
+
|
61 |
+
def __enter__(self) -> None:
|
62 |
+
for lora_module in self.lora_modules:
|
63 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
64 |
+
continue
|
65 |
+
lora_module.scale_layer(self.scale)
|
66 |
+
|
67 |
+
def __exit__(
|
68 |
+
self,
|
69 |
+
exc_type: Optional[Type[BaseException]],
|
70 |
+
exc_val: Optional[BaseException],
|
71 |
+
exc_tb: Optional[Any],
|
72 |
+
) -> None:
|
73 |
+
for i, lora_module in enumerate(self.lora_modules):
|
74 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
75 |
+
continue
|
76 |
+
for active_adapter in lora_module.active_adapters:
|
77 |
+
lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
|
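A short sketch of how these two context managers are used: `transformer.py` below wraps `x_embedder` in `enable_lora`, and `set_lora_scale` temporarily rescales LoRA layers and restores their original per-adapter scales on exit. The helper below is illustrative only; `transformer` is assumed to be a `FluxTransformer2DModel` that already carries PEFT LoRA adapters.

import torch

from src.flux.lora_controller import enable_lora, set_lora_scale


def embed_with_lora_control(transformer, latents: torch.Tensor, latent_lora: bool, scale: float):
    # Illustrative helper; the adapters must already be attached (e.g. via load_lora_weights).
    modules = (transformer.x_embedder,)

    # When latent_lora is False, the adapters named in condition_dict are zeroed
    # inside the block and restored on exit; when True, enable_lora is a no-op.
    with enable_lora(modules, latent_lora):
        tokens = transformer.x_embedder(latents)

    # Multiply the scaling of every LoRA layer in `modules` for the duration of the block.
    with set_lora_scale(modules, scale):
        scaled_tokens = transformer.x_embedder(latents)

    return tokens, scaled_tokens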
src/flux/pipeline_tools.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
from diffusers.pipelines import FluxPipeline
|
2 |
+
from diffusers.utils import logging
|
3 |
+
from diffusers.pipelines.flux.pipeline_flux import logger
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
|
7 |
+
def encode_images(pipeline: FluxPipeline, images: Tensor):
|
8 |
+
images = pipeline.image_processor.preprocess(images)
|
9 |
+
images = images.to(pipeline.device).to(pipeline.dtype)
|
10 |
+
images = pipeline.vae.encode(images).latent_dist.sample()
|
11 |
+
images = (
|
12 |
+
images - pipeline.vae.config.shift_factor
|
13 |
+
) * pipeline.vae.config.scaling_factor
|
14 |
+
images_tokens = pipeline._pack_latents(images, *images.shape)
|
15 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
16 |
+
images.shape[0],
|
17 |
+
images.shape[2],
|
18 |
+
images.shape[3],
|
19 |
+
pipeline.device,
|
20 |
+
pipeline.dtype,
|
21 |
+
)
|
22 |
+
if images_tokens.shape[1] != images_ids.shape[0]:
|
23 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
24 |
+
images.shape[0],
|
25 |
+
images.shape[2] // 2,
|
26 |
+
images.shape[3] // 2,
|
27 |
+
pipeline.device,
|
28 |
+
pipeline.dtype,
|
29 |
+
)
|
30 |
+
return images_tokens, images_ids
|
31 |
+
|
32 |
+
|
33 |
+
def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
|
34 |
+
# Turn off warnings (CLIP overflow)
|
35 |
+
logger.setLevel(logging.ERROR)
|
36 |
+
(
|
37 |
+
prompt_embeds,
|
38 |
+
pooled_prompt_embeds,
|
39 |
+
text_ids,
|
40 |
+
) = pipeline.encode_prompt(
|
41 |
+
prompt=prompts,
|
42 |
+
prompt_2=None,
|
43 |
+
prompt_embeds=None,
|
44 |
+
pooled_prompt_embeds=None,
|
45 |
+
device=pipeline.device,
|
46 |
+
num_images_per_prompt=1,
|
47 |
+
max_sequence_length=max_sequence_length,
|
48 |
+
lora_scale=None,
|
49 |
+
)
|
50 |
+
# Turn on warnings
|
51 |
+
logger.setLevel(logging.WARNING)
|
52 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
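A hedged sketch of how these two helpers are used on the training side (see `src/train/model.py`): images are VAE-encoded and packed into FLUX latent tokens with their position ids, and the prompts are encoded once per batch. The random batch below is a placeholder for the tensors produced by the datasets in `src/train/data.py`.

import torch
from diffusers.pipelines import FluxPipeline

from src.flux.pipeline_tools import encode_images, prepare_text_input

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder batch: two 512x512 RGB images in [0, 1] and their captions.
images = torch.rand(2, 3, 512, 512)
prompts = ["A beautiful vase on a table.", "A cartoon character in a white background."]

# VAE-encode, shift/scale, then pack into tokens plus 2D latent image ids.
x_0, img_ids = encode_images(pipe, images)

# Encode the prompts with both text encoders (CLIP overflow warnings are silenced inside).
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(pipe, prompts)

print(x_0.shape, img_ids.shape, prompt_embeds.shape, pooled_prompt_embeds.shape)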
src/flux/transformer.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
import torch
|
2 |
+
from diffusers.pipelines import FluxPipeline
|
3 |
+
from typing import List, Union, Optional, Dict, Any, Callable
|
4 |
+
from .block import block_forward, single_block_forward
|
5 |
+
from .lora_controller import enable_lora
|
6 |
+
from accelerate.utils import is_torch_version
|
7 |
+
from diffusers.models.transformers.transformer_flux import (
|
8 |
+
FluxTransformer2DModel,
|
9 |
+
Transformer2DModelOutput,
|
10 |
+
USE_PEFT_BACKEND,
|
11 |
+
scale_lora_layers,
|
12 |
+
unscale_lora_layers,
|
13 |
+
logger,
|
14 |
+
)
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
def prepare_params(
|
19 |
+
hidden_states: torch.Tensor,
|
20 |
+
encoder_hidden_states: torch.Tensor = None,
|
21 |
+
pooled_projections: torch.Tensor = None,
|
22 |
+
timestep: torch.LongTensor = None,
|
23 |
+
img_ids: torch.Tensor = None,
|
24 |
+
txt_ids: torch.Tensor = None,
|
25 |
+
guidance: torch.Tensor = None,
|
26 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
27 |
+
controlnet_block_samples=None,
|
28 |
+
controlnet_single_block_samples=None,
|
29 |
+
return_dict: bool = True,
|
30 |
+
**kwargs: dict,
|
31 |
+
):
|
32 |
+
return (
|
33 |
+
hidden_states,
|
34 |
+
encoder_hidden_states,
|
35 |
+
pooled_projections,
|
36 |
+
timestep,
|
37 |
+
img_ids,
|
38 |
+
txt_ids,
|
39 |
+
guidance,
|
40 |
+
joint_attention_kwargs,
|
41 |
+
controlnet_block_samples,
|
42 |
+
controlnet_single_block_samples,
|
43 |
+
return_dict,
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def tranformer_forward(
|
48 |
+
transformer: FluxTransformer2DModel,
|
49 |
+
condition_latents: torch.Tensor,
|
50 |
+
condition_ids: torch.Tensor,
|
51 |
+
condition_type_ids: torch.Tensor,
|
52 |
+
model_config: Optional[Dict[str, Any]] = {},
|
53 |
+
c_t=0,
|
54 |
+
**params: dict,
|
55 |
+
):
|
56 |
+
self = transformer
|
57 |
+
use_condition = condition_latents is not None
|
58 |
+
|
59 |
+
(
|
60 |
+
hidden_states,
|
61 |
+
encoder_hidden_states,
|
62 |
+
pooled_projections,
|
63 |
+
timestep,
|
64 |
+
img_ids,
|
65 |
+
txt_ids,
|
66 |
+
guidance,
|
67 |
+
joint_attention_kwargs,
|
68 |
+
controlnet_block_samples,
|
69 |
+
controlnet_single_block_samples,
|
70 |
+
return_dict,
|
71 |
+
) = prepare_params(**params)
|
72 |
+
|
73 |
+
if joint_attention_kwargs is not None:
|
74 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
75 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
76 |
+
else:
|
77 |
+
lora_scale = 1.0
|
78 |
+
|
79 |
+
if USE_PEFT_BACKEND:
|
80 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
81 |
+
scale_lora_layers(self, lora_scale)
|
82 |
+
else:
|
83 |
+
if (
|
84 |
+
joint_attention_kwargs is not None
|
85 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
86 |
+
):
|
87 |
+
logger.warning(
|
88 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
89 |
+
)
|
90 |
+
|
91 |
+
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
|
92 |
+
hidden_states = self.x_embedder(hidden_states)
|
93 |
+
condition_latents = self.x_embedder(condition_latents) if use_condition else None
|
94 |
+
|
95 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
96 |
+
|
97 |
+
if guidance is not None:
|
98 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
99 |
+
else:
|
100 |
+
guidance = None
|
101 |
+
|
102 |
+
temb = (
|
103 |
+
self.time_text_embed(timestep, pooled_projections)
|
104 |
+
if guidance is None
|
105 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
106 |
+
)
|
107 |
+
|
108 |
+
cond_temb = (
|
109 |
+
self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
|
110 |
+
if guidance is None
|
111 |
+
else self.time_text_embed(
|
112 |
+
torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
|
113 |
+
)
|
114 |
+
)
|
115 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
116 |
+
|
117 |
+
if txt_ids.ndim == 3:
|
118 |
+
logger.warning(
|
119 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
120 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
121 |
+
)
|
122 |
+
txt_ids = txt_ids[0]
|
123 |
+
if img_ids.ndim == 3:
|
124 |
+
logger.warning(
|
125 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
126 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
127 |
+
)
|
128 |
+
img_ids = img_ids[0]
|
129 |
+
|
130 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
131 |
+
image_rotary_emb = self.pos_embed(ids)
|
132 |
+
if use_condition:
|
133 |
+
# condition_ids[:, :1] = condition_type_ids
|
134 |
+
cond_rotary_emb = self.pos_embed(condition_ids)
|
135 |
+
|
136 |
+
# hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
|
137 |
+
|
138 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
139 |
+
if self.training and self.gradient_checkpointing:
|
140 |
+
ckpt_kwargs: Dict[str, Any] = (
|
141 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
142 |
+
)
|
143 |
+
encoder_hidden_states, hidden_states, condition_latents = (
|
144 |
+
torch.utils.checkpoint.checkpoint(
|
145 |
+
block_forward,
|
146 |
+
self=block,
|
147 |
+
model_config=model_config,
|
148 |
+
hidden_states=hidden_states,
|
149 |
+
encoder_hidden_states=encoder_hidden_states,
|
150 |
+
condition_latents=condition_latents if use_condition else None,
|
151 |
+
temb=temb,
|
152 |
+
cond_temb=cond_temb if use_condition else None,
|
153 |
+
cond_rotary_emb=cond_rotary_emb if use_condition else None,
|
154 |
+
image_rotary_emb=image_rotary_emb,
|
155 |
+
**ckpt_kwargs,
|
156 |
+
)
|
157 |
+
)
|
158 |
+
|
159 |
+
else:
|
160 |
+
encoder_hidden_states, hidden_states, condition_latents = block_forward(
|
161 |
+
block,
|
162 |
+
model_config=model_config,
|
163 |
+
hidden_states=hidden_states,
|
164 |
+
encoder_hidden_states=encoder_hidden_states,
|
165 |
+
condition_latents=condition_latents if use_condition else None,
|
166 |
+
temb=temb,
|
167 |
+
cond_temb=cond_temb if use_condition else None,
|
168 |
+
cond_rotary_emb=cond_rotary_emb if use_condition else None,
|
169 |
+
image_rotary_emb=image_rotary_emb,
|
170 |
+
)
|
171 |
+
|
172 |
+
# controlnet residual
|
173 |
+
if controlnet_block_samples is not None:
|
174 |
+
interval_control = len(self.transformer_blocks) / len(
|
175 |
+
controlnet_block_samples
|
176 |
+
)
|
177 |
+
interval_control = int(np.ceil(interval_control))
|
178 |
+
hidden_states = (
|
179 |
+
hidden_states
|
180 |
+
+ controlnet_block_samples[index_block // interval_control]
|
181 |
+
)
|
182 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
183 |
+
|
184 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
185 |
+
if self.training and self.gradient_checkpointing:
|
186 |
+
ckpt_kwargs: Dict[str, Any] = (
|
187 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
188 |
+
)
|
189 |
+
result = torch.utils.checkpoint.checkpoint(
|
190 |
+
single_block_forward,
|
191 |
+
self=block,
|
192 |
+
model_config=model_config,
|
193 |
+
hidden_states=hidden_states,
|
194 |
+
temb=temb,
|
195 |
+
image_rotary_emb=image_rotary_emb,
|
196 |
+
**(
|
197 |
+
{
|
198 |
+
"condition_latents": condition_latents,
|
199 |
+
"cond_temb": cond_temb,
|
200 |
+
"cond_rotary_emb": cond_rotary_emb,
|
201 |
+
}
|
202 |
+
if use_condition
|
203 |
+
else {}
|
204 |
+
),
|
205 |
+
**ckpt_kwargs,
|
206 |
+
)
|
207 |
+
|
208 |
+
else:
|
209 |
+
result = single_block_forward(
|
210 |
+
block,
|
211 |
+
model_config=model_config,
|
212 |
+
hidden_states=hidden_states,
|
213 |
+
temb=temb,
|
214 |
+
image_rotary_emb=image_rotary_emb,
|
215 |
+
**(
|
216 |
+
{
|
217 |
+
"condition_latents": condition_latents,
|
218 |
+
"cond_temb": cond_temb,
|
219 |
+
"cond_rotary_emb": cond_rotary_emb,
|
220 |
+
}
|
221 |
+
if use_condition
|
222 |
+
else {}
|
223 |
+
),
|
224 |
+
)
|
225 |
+
if use_condition:
|
226 |
+
hidden_states, condition_latents = result
|
227 |
+
else:
|
228 |
+
hidden_states = result
|
229 |
+
|
230 |
+
# controlnet residual
|
231 |
+
if controlnet_single_block_samples is not None:
|
232 |
+
interval_control = len(self.single_transformer_blocks) / len(
|
233 |
+
controlnet_single_block_samples
|
234 |
+
)
|
235 |
+
interval_control = int(np.ceil(interval_control))
|
236 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
237 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
238 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
239 |
+
)
|
240 |
+
|
241 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
242 |
+
|
243 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
244 |
+
output = self.proj_out(hidden_states)
|
245 |
+
|
246 |
+
if USE_PEFT_BACKEND:
|
247 |
+
# remove `lora_scale` from each PEFT layer
|
248 |
+
unscale_lora_layers(self, lora_scale)
|
249 |
+
|
250 |
+
if not return_dict:
|
251 |
+
return (output,)
|
252 |
+
return Transformer2DModelOutput(sample=output)
|
src/gradio/gradio_app.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image, ImageDraw, ImageFont
|
4 |
+
from diffusers.pipelines import FluxPipeline
|
5 |
+
from diffusers import FluxTransformer2DModel
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from ..flux.condition import Condition
|
9 |
+
from ..flux.generate import seed_everything, generate
|
10 |
+
|
11 |
+
pipe = None
|
12 |
+
use_int8 = False
|
13 |
+
|
14 |
+
|
15 |
+
def get_gpu_memory():
|
16 |
+
return torch.cuda.get_device_properties(0).total_memory / 1024**3
|
17 |
+
|
18 |
+
|
19 |
+
def init_pipeline():
|
20 |
+
global pipe
|
21 |
+
if use_int8 or get_gpu_memory() < 33:
|
22 |
+
transformer_model = FluxTransformer2DModel.from_pretrained(
|
23 |
+
"sayakpaul/flux.1-schell-int8wo-improved",
|
24 |
+
torch_dtype=torch.bfloat16,
|
25 |
+
use_safetensors=False,
|
26 |
+
)
|
27 |
+
pipe = FluxPipeline.from_pretrained(
|
28 |
+
"black-forest-labs/FLUX.1-schnell",
|
29 |
+
transformer=transformer_model,
|
30 |
+
torch_dtype=torch.bfloat16,
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
pipe = FluxPipeline.from_pretrained(
|
34 |
+
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
35 |
+
)
|
36 |
+
pipe = pipe.to("cuda")
|
37 |
+
pipe.load_lora_weights(
|
38 |
+
"Yuanshi/OminiControl",
|
39 |
+
weight_name="omini/subject_512.safetensors",
|
40 |
+
adapter_name="subject",
|
41 |
+
)
|
42 |
+
|
43 |
+
# Optional: Load additional LoRA weights
|
44 |
+
#pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")
|
45 |
+
|
46 |
+
|
47 |
+
def process_image_and_text(image, text):
|
48 |
+
# center crop image
|
49 |
+
w, h, min_size = image.size[0], image.size[1], min(image.size)
|
50 |
+
image = image.crop(
|
51 |
+
(
|
52 |
+
(w - min_size) // 2,
|
53 |
+
(h - min_size) // 2,
|
54 |
+
(w + min_size) // 2,
|
55 |
+
(h + min_size) // 2,
|
56 |
+
)
|
57 |
+
)
|
58 |
+
image = image.resize((512, 512))
|
59 |
+
|
60 |
+
condition = Condition("subject", image, position_delta=(0, 32))
|
61 |
+
|
62 |
+
if pipe is None:
|
63 |
+
init_pipeline()
|
64 |
+
|
65 |
+
result_img = generate(
|
66 |
+
pipe,
|
67 |
+
prompt=text.strip(),
|
68 |
+
conditions=[condition],
|
69 |
+
num_inference_steps=8,
|
70 |
+
height=512,
|
71 |
+
width=512,
|
72 |
+
).images[0]
|
73 |
+
|
74 |
+
return result_img
|
75 |
+
|
76 |
+
|
77 |
+
def get_samples():
|
78 |
+
sample_list = [
|
79 |
+
{
|
80 |
+
"image": "assets/oranges.jpg",
|
81 |
+
"text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"image": "assets/penguin.jpg",
|
85 |
+
"text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"image": "assets/rc_car.jpg",
|
89 |
+
"text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"image": "assets/clock.jpg",
|
93 |
+
"text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"image": "assets/tshirt.jpg",
|
97 |
+
"text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.",
|
98 |
+
},
|
99 |
+
]
|
100 |
+
return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
|
101 |
+
|
102 |
+
|
103 |
+
demo = gr.Interface(
|
104 |
+
fn=process_image_and_text,
|
105 |
+
inputs=[
|
106 |
+
gr.Image(type="pil"),
|
107 |
+
gr.Textbox(lines=2),
|
108 |
+
],
|
109 |
+
outputs=gr.Image(type="pil"),
|
110 |
+
title="OminiControl / Subject driven generation",
|
111 |
+
examples=get_samples(),
|
112 |
+
)
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
init_pipeline()
|
116 |
+
demo.launch(
|
117 |
+
debug=True,
|
118 |
+
)
|
src/train/callbacks.py
ADDED
@@ -0,0 +1,268 @@
1 |
+
import lightning as L
|
2 |
+
from PIL import Image, ImageFilter, ImageDraw
|
3 |
+
import numpy as np
|
4 |
+
from transformers import pipeline
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
|
9 |
+
try:
|
10 |
+
import wandb
|
11 |
+
except ImportError:
|
12 |
+
wandb = None
|
13 |
+
|
14 |
+
from ..flux.condition import Condition
|
15 |
+
from ..flux.generate import generate
|
16 |
+
|
17 |
+
|
18 |
+
class TrainingCallback(L.Callback):
|
19 |
+
def __init__(self, run_name, training_config: dict = {}):
|
20 |
+
self.run_name, self.training_config = run_name, training_config
|
21 |
+
|
22 |
+
self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
|
23 |
+
self.save_interval = training_config.get("save_interval", 1000)
|
24 |
+
self.sample_interval = training_config.get("sample_interval", 1000)
|
25 |
+
self.save_path = training_config.get("save_path", "./output")
|
26 |
+
|
27 |
+
self.wandb_config = training_config.get("wandb", None)
|
28 |
+
self.use_wandb = (
|
29 |
+
wandb is not None and os.environ.get("WANDB_API_KEY") is not None
|
30 |
+
)
|
31 |
+
|
32 |
+
self.total_steps = 0
|
33 |
+
|
34 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
35 |
+
gradient_size = 0
|
36 |
+
max_gradient_size = 0
|
37 |
+
count = 0
|
38 |
+
for _, param in pl_module.named_parameters():
|
39 |
+
if param.grad is not None:
|
40 |
+
gradient_size += param.grad.norm(2).item()
|
41 |
+
max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
|
42 |
+
count += 1
|
43 |
+
if count > 0:
|
44 |
+
gradient_size /= count
|
45 |
+
|
46 |
+
self.total_steps += 1
|
47 |
+
|
48 |
+
# Print training progress every n steps
|
49 |
+
if self.use_wandb:
|
50 |
+
report_dict = {
|
51 |
+
"steps": batch_idx,
|
52 |
+
"steps": self.total_steps,
|
53 |
+
"epoch": trainer.current_epoch,
|
54 |
+
"gradient_size": gradient_size,
|
55 |
+
}
|
56 |
+
loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
|
57 |
+
report_dict["loss"] = loss_value
|
58 |
+
report_dict["t"] = pl_module.last_t
|
59 |
+
wandb.log(report_dict)
|
60 |
+
|
61 |
+
if self.total_steps % self.print_every_n_steps == 0:
|
62 |
+
print(
|
63 |
+
f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
|
64 |
+
)
|
65 |
+
|
66 |
+
# Save LoRA weights at specified intervals
|
67 |
+
if self.total_steps % self.save_interval == 0:
|
68 |
+
print(
|
69 |
+
f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
|
70 |
+
)
|
71 |
+
pl_module.save_lora(
|
72 |
+
f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
|
73 |
+
)
|
74 |
+
|
75 |
+
# Generate and save a sample image at specified intervals
|
76 |
+
if self.total_steps % self.sample_interval == 0:
|
77 |
+
print(
|
78 |
+
f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
|
79 |
+
)
|
80 |
+
self.generate_a_sample(
|
81 |
+
trainer,
|
82 |
+
pl_module,
|
83 |
+
f"{self.save_path}/{self.run_name}/output",
|
84 |
+
f"lora_{self.total_steps}",
|
85 |
+
batch["condition_type"][
|
86 |
+
0
|
87 |
+
], # Use the condition type from the current batch
|
88 |
+
)
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def generate_a_sample(
|
92 |
+
self,
|
93 |
+
trainer,
|
94 |
+
pl_module,
|
95 |
+
save_path,
|
96 |
+
file_name,
|
97 |
+
condition_type="super_resolution",
|
98 |
+
):
|
99 |
+
# TODO: change this two variables to parameters
|
100 |
+
condition_size = trainer.training_config["dataset"]["condition_size"]
|
101 |
+
target_size = trainer.training_config["dataset"]["target_size"]
|
102 |
+
position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)
|
103 |
+
|
104 |
+
generator = torch.Generator(device=pl_module.device)
|
105 |
+
generator.manual_seed(42)
|
106 |
+
|
107 |
+
test_list = []
|
108 |
+
|
109 |
+
if condition_type == "subject":
|
110 |
+
test_list.extend(
|
111 |
+
[
|
112 |
+
(
|
113 |
+
Image.open("assets/test_in.jpg"),
|
114 |
+
[0, -32],
|
115 |
+
"Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
|
116 |
+
),
|
117 |
+
(
|
118 |
+
Image.open("assets/test_out.jpg"),
|
119 |
+
[0, -32],
|
120 |
+
"In a bright room. It is placed on a table.",
|
121 |
+
),
|
122 |
+
]
|
123 |
+
)
|
124 |
+
elif condition_type == "scene":
|
125 |
+
test_list.extend(
|
126 |
+
[
|
127 |
+
(
|
128 |
+
Image.open("assets/a2759.jpg"),
|
129 |
+
[0, -32],
|
130 |
+
"change the color of the plane to red",
|
131 |
+
),
|
132 |
+
(
|
133 |
+
Image.open("assets/clock.jpg"),
|
134 |
+
[0, -32],
|
135 |
+
"turn the color of the clock to blue",
|
136 |
+
),
|
137 |
+
]
|
138 |
+
)
|
139 |
+
elif condition_type == "canny":
|
140 |
+
condition_img = Image.open("assets/vase_hq.jpg").resize(
|
141 |
+
(condition_size, condition_size)
|
142 |
+
)
|
143 |
+
condition_img = np.array(condition_img)
|
144 |
+
condition_img = cv2.Canny(condition_img, 100, 200)
|
145 |
+
condition_img = Image.fromarray(condition_img).convert("RGB")
|
146 |
+
test_list.append(
|
147 |
+
(
|
148 |
+
condition_img,
|
149 |
+
[0, 0],
|
150 |
+
"A beautiful vase on a table.",
|
151 |
+
{"position_scale": position_scale} if position_scale != 1.0 else {},
|
152 |
+
)
|
153 |
+
)
|
154 |
+
elif condition_type == "coloring":
|
155 |
+
condition_img = (
|
156 |
+
Image.open("assets/vase_hq.jpg")
|
157 |
+
.resize((condition_size, condition_size))
|
158 |
+
.convert("L")
|
159 |
+
.convert("RGB")
|
160 |
+
)
|
161 |
+
test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
|
162 |
+
elif condition_type == "depth":
|
163 |
+
if not hasattr(self, "deepth_pipe"):
|
164 |
+
self.depth_pipe = pipeline(
|
165 |
+
task="depth-estimation",
|
166 |
+
model="LiheYoung/depth-anything-small-hf",
|
167 |
+
device="cpu",
|
168 |
+
)
|
169 |
+
condition_img = (
|
170 |
+
Image.open("assets/vase_hq.jpg")
|
171 |
+
.resize((condition_size, condition_size))
|
172 |
+
.convert("RGB")
|
173 |
+
)
|
174 |
+
condition_img = self.depth_pipe(condition_img)["depth"].convert("RGB")
|
175 |
+
test_list.append(
|
176 |
+
(
|
177 |
+
condition_img,
|
178 |
+
[0, 0],
|
179 |
+
"A beautiful vase on a table.",
|
180 |
+
{"position_scale": position_scale} if position_scale != 1.0 else {},
|
181 |
+
)
|
182 |
+
)
|
183 |
+
elif condition_type == "depth_pred":
|
184 |
+
condition_img = (
|
185 |
+
Image.open("assets/vase_hq.jpg")
|
186 |
+
.resize((condition_size, condition_size))
|
187 |
+
.convert("RGB")
|
188 |
+
)
|
189 |
+
test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
|
190 |
+
elif condition_type == "deblurring":
|
191 |
+
blur_radius = 5
|
192 |
+
image = Image.open("./assets/vase_hq.jpg")
|
193 |
+
condition_img = (
|
194 |
+
image.convert("RGB")
|
195 |
+
.resize((condition_size, condition_size))
|
196 |
+
.filter(ImageFilter.GaussianBlur(blur_radius))
|
197 |
+
.convert("RGB")
|
198 |
+
)
|
199 |
+
test_list.append(
|
200 |
+
(
|
201 |
+
condition_img,
|
202 |
+
[0, 0],
|
203 |
+
"A beautiful vase on a table.",
|
204 |
+
{"position_scale": position_scale} if position_scale != 1.0 else {},
|
205 |
+
)
|
206 |
+
)
|
207 |
+
elif condition_type == "fill":
|
208 |
+
condition_img = (
|
209 |
+
Image.open("./assets/vase_hq.jpg")
|
210 |
+
.resize((condition_size, condition_size))
|
211 |
+
.convert("RGB")
|
212 |
+
)
|
213 |
+
mask = Image.new("L", condition_img.size, 0)
|
214 |
+
draw = ImageDraw.Draw(mask)
|
215 |
+
a = condition_img.size[0] // 4
|
216 |
+
b = a * 3
|
217 |
+
draw.rectangle([a, a, b, b], fill=255)
|
218 |
+
condition_img = Image.composite(
|
219 |
+
condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
|
220 |
+
)
|
221 |
+
test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
|
222 |
+
elif condition_type == "sr":
|
223 |
+
condition_img = (
|
224 |
+
Image.open("assets/vase_hq.jpg")
|
225 |
+
.resize((condition_size, condition_size))
|
226 |
+
.convert("RGB")
|
227 |
+
)
|
228 |
+
test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
|
229 |
+
elif condition_type == "cartoon":
|
230 |
+
condition_img = (
|
231 |
+
Image.open("assets/cartoon_boy.png")
|
232 |
+
.resize((condition_size, condition_size))
|
233 |
+
.convert("RGB")
|
234 |
+
)
|
235 |
+
test_list.append(
|
236 |
+
(
|
237 |
+
condition_img,
|
238 |
+
[0, -16],
|
239 |
+
"A cartoon character in a white background. He is looking right, and running.",
|
240 |
+
)
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
raise NotImplementedError
|
244 |
+
|
245 |
+
if not os.path.exists(save_path):
|
246 |
+
os.makedirs(save_path)
|
247 |
+
for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
|
248 |
+
condition = Condition(
|
249 |
+
condition_type=condition_type,
|
250 |
+
condition=condition_img.resize(
|
251 |
+
(condition_size, condition_size)
|
252 |
+
).convert("RGB"),
|
253 |
+
position_delta=position_delta,
|
254 |
+
**(others[0] if others else {}),
|
255 |
+
)
|
256 |
+
res = generate(
|
257 |
+
pl_module.flux_pipe,
|
258 |
+
prompt=prompt,
|
259 |
+
conditions=[condition],
|
260 |
+
height=target_size,
|
261 |
+
width=target_size,
|
262 |
+
generator=generator,
|
263 |
+
model_config=pl_module.model_config,
|
264 |
+
default_lora=True,
|
265 |
+
)
|
266 |
+
res.images[0].save(
|
267 |
+
os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
|
268 |
+
)
|
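A minimal wiring sketch for this callback, assuming an `OminiModel` from `src/train/model.py` and a `training_config` dict shaped like the YAML files under `train/config/`; the specific values and LoRA targets below are illustrative, not the repo's exact defaults. Note that `generate_a_sample` reads the dataset sizes from `trainer.training_config`, so the training script is expected to attach that attribute.

import lightning as L

from src.train.callbacks import TrainingCallback
from src.train.model import OminiModel

training_config = {
    "print_every_n_steps": 10,
    "save_interval": 1000,
    "sample_interval": 1000,
    "save_path": "./output",
    "dataset": {"condition_size": 512, "target_size": 512},
}

model = OminiModel(
    flux_pipe_id="black-forest-labs/FLUX.1-schnell",
    lora_config={"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k", "to_v"]},
    gradient_checkpointing=True,
)

trainer = L.Trainer(
    max_steps=10_000,
    accumulate_grad_batches=1,
    callbacks=[TrainingCallback("demo_run", training_config)],
)
trainer.training_config = training_config  # read by generate_a_sample()

# trainer.fit(model, train_dataloaders=train_loader)  # build train_loader from src/train/data.py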
src/train/data.py
ADDED
@@ -0,0 +1,401 @@
1 |
+
from PIL import Image, ImageFilter, ImageDraw
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torchvision.transforms as T
|
6 |
+
import random
|
7 |
+
|
8 |
+
|
9 |
+
class Subject200KDataset(Dataset):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
base_dataset,
|
13 |
+
condition_size: int = 512,
|
14 |
+
target_size: int = 512,
|
15 |
+
image_size: int = 512,
|
16 |
+
padding: int = 0,
|
17 |
+
condition_type: str = "subject",
|
18 |
+
drop_text_prob: float = 0.1,
|
19 |
+
drop_image_prob: float = 0.1,
|
20 |
+
return_pil_image: bool = False,
|
21 |
+
):
|
22 |
+
self.base_dataset = base_dataset
|
23 |
+
self.condition_size = condition_size
|
24 |
+
self.target_size = target_size
|
25 |
+
self.image_size = image_size
|
26 |
+
self.padding = padding
|
27 |
+
self.condition_type = condition_type
|
28 |
+
self.drop_text_prob = drop_text_prob
|
29 |
+
self.drop_image_prob = drop_image_prob
|
30 |
+
self.return_pil_image = return_pil_image
|
31 |
+
|
32 |
+
self.to_tensor = T.ToTensor()
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return len(self.base_dataset) * 2
|
36 |
+
|
37 |
+
def __getitem__(self, idx):
|
38 |
+
# If target is 0, left image is target, right image is condition
|
39 |
+
target = idx % 2
|
40 |
+
item = self.base_dataset[idx // 2]
|
41 |
+
|
42 |
+
# Crop the image to target and condition
|
43 |
+
image = item["image"]
|
44 |
+
left_img = image.crop(
|
45 |
+
(
|
46 |
+
self.padding,
|
47 |
+
self.padding,
|
48 |
+
self.image_size + self.padding,
|
49 |
+
self.image_size + self.padding,
|
50 |
+
)
|
51 |
+
)
|
52 |
+
right_img = image.crop(
|
53 |
+
(
|
54 |
+
self.image_size + self.padding * 2,
|
55 |
+
self.padding,
|
56 |
+
self.image_size * 2 + self.padding * 2,
|
57 |
+
self.image_size + self.padding,
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
# Get the target and condition image
|
62 |
+
target_image, condition_img = (
|
63 |
+
(left_img, right_img) if target == 0 else (right_img, left_img)
|
64 |
+
)
|
65 |
+
|
66 |
+
# Resize the image
|
67 |
+
condition_img = condition_img.resize(
|
68 |
+
(self.condition_size, self.condition_size)
|
69 |
+
).convert("RGB")
|
70 |
+
target_image = target_image.resize(
|
71 |
+
(self.target_size, self.target_size)
|
72 |
+
).convert("RGB")
|
73 |
+
|
74 |
+
# Get the description
|
75 |
+
description = item["description"][
|
76 |
+
"description_0" if target == 0 else "description_1"
|
77 |
+
]
|
78 |
+
|
79 |
+
# Randomly drop text or image
|
80 |
+
drop_text = random.random() < self.drop_text_prob
|
81 |
+
drop_image = random.random() < self.drop_image_prob
|
82 |
+
if drop_text:
|
83 |
+
description = ""
|
84 |
+
if drop_image:
|
85 |
+
condition_img = Image.new(
|
86 |
+
"RGB", (self.condition_size, self.condition_size), (0, 0, 0)
|
87 |
+
)
|
88 |
+
|
89 |
+
return {
|
90 |
+
"image": self.to_tensor(target_image),
|
91 |
+
"condition": self.to_tensor(condition_img),
|
92 |
+
"condition_type": self.condition_type,
|
93 |
+
"description": description,
|
94 |
+
# 16 is the downscale factor of the image
|
95 |
+
"position_delta": np.array([0, -self.condition_size // 16]),
|
96 |
+
**({"pil_image": image} if self.return_pil_image else {}),
|
97 |
+
}
|
98 |
+
|
99 |
+
class SceneDataset(Dataset):
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
base_dataset,
|
103 |
+
condition_size: int = 512,
|
104 |
+
target_size: int = 512,
|
105 |
+
image_size: int = 512,
|
106 |
+
padding: int = 0,
|
107 |
+
condition_type: str = "scene",
|
108 |
+
drop_text_prob: float = 0.1,
|
109 |
+
drop_image_prob: float = 0.1,
|
110 |
+
return_pil_image: bool = False,
|
111 |
+
):
|
112 |
+
self.base_dataset = base_dataset
|
113 |
+
self.condition_size = condition_size
|
114 |
+
self.target_size = target_size
|
115 |
+
self.image_size = image_size
|
116 |
+
self.padding = padding
|
117 |
+
self.condition_type = condition_type
|
118 |
+
self.drop_text_prob = drop_text_prob
|
119 |
+
self.drop_image_prob = drop_image_prob
|
120 |
+
self.return_pil_image = return_pil_image
|
121 |
+
|
122 |
+
self.to_tensor = T.ToTensor()
|
123 |
+
|
124 |
+
def __len__(self):
|
125 |
+
return len(self.base_dataset)
|
126 |
+
|
127 |
+
def __getitem__(self, idx):
|
128 |
+
# If target is 0, left image is target, right image is condition
|
129 |
+
# target = idx % 2
|
130 |
+
target = 1
|
131 |
+
item = self.base_dataset[idx]  # __len__ returns len(base_dataset), so index directly
|
132 |
+
|
133 |
+
# Crop the image to target and condition
|
134 |
+
imageA = item["imageA"]
|
135 |
+
imageB = item["imageB"]
|
136 |
+
|
137 |
+
left_img = imageA
|
138 |
+
right_img = imageB
|
139 |
+
|
140 |
+
# Get the target and condition image
|
141 |
+
target_image, condition_img = (
|
142 |
+
(left_img, right_img) if target == 0 else (right_img, left_img)
|
143 |
+
)
|
144 |
+
|
145 |
+
# Resize the image
|
146 |
+
condition_img = condition_img.resize(
|
147 |
+
(self.condition_size, self.condition_size)
|
148 |
+
).convert("RGB")
|
149 |
+
target_image = target_image.resize(
|
150 |
+
(self.target_size, self.target_size)
|
151 |
+
).convert("RGB")
|
152 |
+
|
153 |
+
# Get the description
|
154 |
+
description = item["prompt"]
|
155 |
+
|
156 |
+
# Randomly drop text or image
|
157 |
+
drop_text = random.random() < self.drop_text_prob
|
158 |
+
drop_image = random.random() < self.drop_image_prob
|
159 |
+
if drop_text:
|
160 |
+
description = ""
|
161 |
+
if drop_image:
|
162 |
+
condition_img = Image.new(
|
163 |
+
"RGB", (self.condition_size, self.condition_size), (0, 0, 0)
|
164 |
+
)
|
165 |
+
|
166 |
+
return {
|
167 |
+
"image": self.to_tensor(target_image),
|
168 |
+
"condition": self.to_tensor(condition_img),
|
169 |
+
"condition_type": self.condition_type,
|
170 |
+
"description": description,
|
171 |
+
"position_delta": np.array([0, -self.condition_size // 16]),
|
172 |
+
**({"pil_image": [target_image, condition_img]} if self.return_pil_image else {}),
|
173 |
+
}
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
class ImageConditionDataset(Dataset):
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
base_dataset,
|
182 |
+
condition_size: int = 512,
|
183 |
+
target_size: int = 512,
|
184 |
+
condition_type: str = "canny",
|
185 |
+
drop_text_prob: float = 0.1,
|
186 |
+
drop_image_prob: float = 0.1,
|
187 |
+
return_pil_image: bool = False,
|
188 |
+
position_scale=1.0,
|
189 |
+
):
|
190 |
+
self.base_dataset = base_dataset
|
191 |
+
self.condition_size = condition_size
|
192 |
+
self.target_size = target_size
|
193 |
+
self.condition_type = condition_type
|
194 |
+
self.drop_text_prob = drop_text_prob
|
195 |
+
self.drop_image_prob = drop_image_prob
|
196 |
+
self.return_pil_image = return_pil_image
|
197 |
+
self.position_scale = position_scale
|
198 |
+
|
199 |
+
self.to_tensor = T.ToTensor()
|
200 |
+
|
201 |
+
def __len__(self):
|
202 |
+
return len(self.base_dataset)
|
203 |
+
|
204 |
+
@property
|
205 |
+
def depth_pipe(self):
|
206 |
+
if not hasattr(self, "_depth_pipe"):
|
207 |
+
from transformers import pipeline
|
208 |
+
|
209 |
+
self._depth_pipe = pipeline(
|
210 |
+
task="depth-estimation",
|
211 |
+
model="LiheYoung/depth-anything-small-hf",
|
212 |
+
device="cpu",
|
213 |
+
)
|
214 |
+
return self._depth_pipe
|
215 |
+
|
216 |
+
def _get_canny_edge(self, img):
|
217 |
+
resize_ratio = self.condition_size / max(img.size)
|
218 |
+
img = img.resize(
|
219 |
+
(int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
|
220 |
+
)
|
221 |
+
img_np = np.array(img)
|
222 |
+
img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
223 |
+
edges = cv2.Canny(img_gray, 100, 200)
|
224 |
+
return Image.fromarray(edges).convert("RGB")
|
225 |
+
|
226 |
+
def __getitem__(self, idx):
|
227 |
+
image = self.base_dataset[idx]["jpg"]
|
228 |
+
image = image.resize((self.target_size, self.target_size)).convert("RGB")
|
229 |
+
description = self.base_dataset[idx]["json"]["prompt"]
|
230 |
+
|
231 |
+
enable_scale = random.random() < 1
|
232 |
+
if not enable_scale:
|
233 |
+
condition_size = int(self.condition_size * self.position_scale)
|
234 |
+
position_scale = 1.0
|
235 |
+
else:
|
236 |
+
condition_size = self.condition_size
|
237 |
+
position_scale = self.position_scale
|
238 |
+
|
239 |
+
# Get the condition image
|
240 |
+
position_delta = np.array([0, 0])
|
241 |
+
if self.condition_type == "canny":
|
242 |
+
condition_img = self._get_canny_edge(image)
|
243 |
+
elif self.condition_type == "coloring":
|
244 |
+
condition_img = (
|
245 |
+
image.resize((condition_size, condition_size))
|
246 |
+
.convert("L")
|
247 |
+
.convert("RGB")
|
248 |
+
)
|
249 |
+
elif self.condition_type == "deblurring":
|
250 |
+
blur_radius = random.randint(1, 10)
|
251 |
+
condition_img = (
|
252 |
+
image.convert("RGB")
|
253 |
+
.filter(ImageFilter.GaussianBlur(blur_radius))
|
254 |
+
.resize((condition_size, condition_size))
|
255 |
+
.convert("RGB")
|
256 |
+
)
|
257 |
+
elif self.condition_type == "depth":
|
258 |
+
condition_img = self.depth_pipe(image)["depth"].convert("RGB")
|
259 |
+
condition_img = condition_img.resize((condition_size, condition_size))
|
260 |
+
elif self.condition_type == "depth_pred":
|
261 |
+
condition_img = image
|
262 |
+
image = self.depth_pipe(condition_img)["depth"].convert("RGB")
|
263 |
+
description = f"[depth] {description}"
|
264 |
+
elif self.condition_type == "fill":
|
265 |
+
condition_img = image.resize((condition_size, condition_size)).convert(
|
266 |
+
"RGB"
|
267 |
+
)
|
268 |
+
w, h = image.size
|
269 |
+
x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
|
270 |
+
y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
|
271 |
+
mask = Image.new("L", image.size, 0)
|
272 |
+
draw = ImageDraw.Draw(mask)
|
273 |
+
draw.rectangle([x1, y1, x2, y2], fill=255)
|
274 |
+
if random.random() > 0.5:
|
275 |
+
mask = Image.eval(mask, lambda a: 255 - a)
|
276 |
+
condition_img = Image.composite(
|
277 |
+
image, Image.new("RGB", image.size, (0, 0, 0)), mask
|
278 |
+
)
|
279 |
+
elif self.condition_type == "sr":
|
280 |
+
condition_img = image.resize((condition_size, condition_size)).convert(
|
281 |
+
"RGB"
|
282 |
+
)
|
283 |
+
position_delta = np.array([0, -condition_size // 16])
|
284 |
+
|
285 |
+
else:
|
286 |
+
raise ValueError(f"Condition type {self.condition_type} not implemented")
|
287 |
+
|
288 |
+
# Randomly drop text or image
|
289 |
+
drop_text = random.random() < self.drop_text_prob
|
290 |
+
drop_image = random.random() < self.drop_image_prob
|
291 |
+
if drop_text:
|
292 |
+
description = ""
|
293 |
+
if drop_image:
|
294 |
+
condition_img = Image.new(
|
295 |
+
"RGB", (condition_size, condition_size), (0, 0, 0)
|
296 |
+
)
|
297 |
+
|
298 |
+
return {
|
299 |
+
"image": self.to_tensor(image),
|
300 |
+
"condition": self.to_tensor(condition_img),
|
301 |
+
"condition_type": self.condition_type,
|
302 |
+
"description": description,
|
303 |
+
"position_delta": position_delta,
|
304 |
+
**({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
|
305 |
+
**({"position_scale": position_scale} if position_scale != 1.0 else {}),
|
306 |
+
}
|
307 |
+
|
308 |
+
|
309 |
+
class CartoonDataset(Dataset):
|
310 |
+
def __init__(
|
311 |
+
self,
|
312 |
+
base_dataset,
|
313 |
+
condition_size: int = 1024,
|
314 |
+
target_size: int = 1024,
|
315 |
+
image_size: int = 1024,
|
316 |
+
padding: int = 0,
|
317 |
+
condition_type: str = "cartoon",
|
318 |
+
drop_text_prob: float = 0.1,
|
319 |
+
drop_image_prob: float = 0.1,
|
320 |
+
return_pil_image: bool = False,
|
321 |
+
):
|
322 |
+
self.base_dataset = base_dataset
|
323 |
+
self.condition_size = condition_size
|
324 |
+
self.target_size = target_size
|
325 |
+
self.image_size = image_size
|
326 |
+
self.padding = padding
|
327 |
+
self.condition_type = condition_type
|
328 |
+
self.drop_text_prob = drop_text_prob
|
329 |
+
self.drop_image_prob = drop_image_prob
|
330 |
+
self.return_pil_image = return_pil_image
|
331 |
+
|
332 |
+
self.to_tensor = T.ToTensor()
|
333 |
+
|
334 |
+
def __len__(self):
|
335 |
+
return len(self.base_dataset)
|
336 |
+
|
337 |
+
def __getitem__(self, idx):
|
338 |
+
data = self.base_dataset[idx]
|
339 |
+
condition_img = data["condition"]
|
340 |
+
target_image = data["target"]
|
341 |
+
|
342 |
+
# Tag
|
343 |
+
tag = data["tags"][0]
|
344 |
+
|
345 |
+
target_description = data["target_description"]
|
346 |
+
|
347 |
+
description = {
|
348 |
+
"lion": "lion like animal",
|
349 |
+
"bear": "bear like animal",
|
350 |
+
"gorilla": "gorilla like animal",
|
351 |
+
"dog": "dog like animal",
|
352 |
+
"elephant": "elephant like animal",
|
353 |
+
"eagle": "eagle like bird",
|
354 |
+
"tiger": "tiger like animal",
|
355 |
+
"owl": "owl like bird",
|
356 |
+
"woman": "woman",
|
357 |
+
"parrot": "parrot like bird",
|
358 |
+
"mouse": "mouse like animal",
|
359 |
+
"man": "man",
|
360 |
+
"pigeon": "pigeon like bird",
|
361 |
+
"girl": "girl",
|
362 |
+
"panda": "panda like animal",
|
363 |
+
"crocodile": "crocodile like animal",
|
364 |
+
"rabbit": "rabbit like animal",
|
365 |
+
"boy": "boy",
|
366 |
+
"monkey": "monkey like animal",
|
367 |
+
"cat": "cat like animal",
|
368 |
+
}
|
369 |
+
|
370 |
+
# Resize the image
|
371 |
+
condition_img = condition_img.resize(
|
372 |
+
(self.condition_size, self.condition_size)
|
373 |
+
).convert("RGB")
|
374 |
+
target_image = target_image.resize(
|
375 |
+
(self.target_size, self.target_size)
|
376 |
+
).convert("RGB")
|
377 |
+
|
378 |
+
# Process datum to create description
|
379 |
+
description = data.get(
|
380 |
+
"description",
|
381 |
+
f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.",
|
382 |
+
)
|
383 |
+
|
384 |
+
# Randomly drop text or image
|
385 |
+
drop_text = random.random() < self.drop_text_prob
|
386 |
+
drop_image = random.random() < self.drop_image_prob
|
387 |
+
if drop_text:
|
388 |
+
description = ""
|
389 |
+
if drop_image:
|
390 |
+
condition_img = Image.new(
|
391 |
+
"RGB", (self.condition_size, self.condition_size), (0, 0, 0)
|
392 |
+
)
|
393 |
+
|
394 |
+
return {
|
395 |
+
"image": self.to_tensor(target_image),
|
396 |
+
"condition": self.to_tensor(condition_img),
|
397 |
+
"condition_type": self.condition_type,
|
398 |
+
"description": description,
|
399 |
+
# 16 is the downscale factor of the image
|
400 |
+
"position_delta": np.array([0, -16]),
|
401 |
+
}
|
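The sketch below is not part of the committed files; it is a minimal usage example of `CartoonDataset`, assuming the repository root is on `PYTHONPATH` and that the `saquiboye/oye-cartoon` dataset exposes the `condition`, `target`, `tags`, and `target_description` fields that `__getitem__` reads (this is exactly how `src/train/train.py` wires it up).

```python
# Hypothetical usage sketch of CartoonDataset; not part of the commit.
from datasets import load_dataset
from torch.utils.data import DataLoader
from src.train.data import CartoonDataset

base = load_dataset("saquiboye/oye-cartoon", split="train")
cartoon = CartoonDataset(base, condition_size=512, target_size=512)

loader = DataLoader(cartoon, batch_size=1, shuffle=True)
batch = next(iter(loader))
# Tensors are shaped (B, 3, size, size); position_delta collates to [[0, -16]].
print(batch["image"].shape, batch["condition"].shape, batch["position_delta"])
```

The batch keys ("image", "condition", "condition_type", "description", "position_delta") are the ones `OminiModel.step` consumes.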
src/train/model.py
ADDED
@@ -0,0 +1,185 @@
import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict

import prodigyopt

from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input


class OminiModel(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = {},
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the Flux pipeline
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad)
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
        return list(lora_layers)

    def save_lora(self, path: str):
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )

    def configure_optimizers(self):
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters
        self.trainable_params = self.lora_layers

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError

        return optimizer

    def training_step(self, batch, batch_idx):
        step_loss = self.step(batch)
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def step(self, batch):
        imgs = batch["image"]
        conditions = batch["condition"]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Prepare conditions
            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)

            # Add position delta
            condition_ids[:, 1] += position_delta[0]
            condition_ids[:, 2] += position_delta[1]

            if position_scale != 1.0:
                scale_bias = (position_scale - 1.0) / 2
                condition_ids[:, 1] *= position_scale
                condition_ids[:, 2] *= position_scale
                condition_ids[:, 1] += scale_bias
                condition_ids[:, 2] += scale_bias

            # Prepare condition type
            condition_type_ids = torch.tensor(
                [
                    Condition.get_type_id(condition_type)
                    for condition_type in condition_types
                ]
            ).to(self.device)
            condition_type_ids = (
                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
            ).unsqueeze(1)

            # Prepare guidance
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Forward pass
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Compute loss
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()
        return loss
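Below is a minimal construction sketch, not part of the commit, showing how the YAML blocks later in this commit (model / lora_config / optimizer, mirroring `train/config/subject_512.yaml`) map onto `OminiModel`'s constructor arguments. The shortened `target_modules` pattern is an illustrative placeholder, not the regex used by the configs; loading `FLUX.1-dev` here is heavy and requires a GPU.

```python
# Hypothetical OminiModel construction sketch; values mirror the configs below.
import torch
from src.train.model import OminiModel

model = OminiModel(
    flux_pipe_id="black-forest-labs/FLUX.1-dev",
    lora_config={
        "r": 4,
        "lora_alpha": 4,
        "init_lora_weights": "gaussian",
        "target_modules": ".*attn.to_q|.*attn.to_k|.*attn.to_v",  # shortened for illustration
    },
    dtype=torch.bfloat16,
    optimizer_config={"type": "Prodigy", "params": {"lr": 1}},
    model_config={"union_cond_attn": True, "add_cond_attn": False, "latent_lora": True},
    gradient_checkpointing=True,
)
model.save_lora("runs/example")  # serializes only the LoRA adapter weights via safetensors
```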
src/train/train.py
ADDED
@@ -0,0 +1,214 @@
from torch.utils.data import DataLoader
import torch
import lightning as L
import yaml
import os
import time
import re

from datasets import load_dataset

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset, SceneDataset
from .model import OminiModel
from .callbacks import TrainingCallback
import safetensors.torch
from peft import PeftModel

import os
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset

from torchvision import transforms
from torch.utils.data import DataLoader

class LocalSubjectsDataset(Dataset):
    def __init__(self, csv_file, image_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        self.features = {
            'imageA': 'PIL.Image',
            'prompt': 'str',
            'imageB': 'PIL.Image'
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get the file names of image A, the prompt, and image B
        imgA_value = self.data.iloc[idx]['imageA']
        if isinstance(imgA_value, pd.Series):
            imgA_value = imgA_value.values[0]
        imgA_name = os.path.join(self.image_dir, str(imgA_value))

        prompt = self.data.iloc[idx]['prompt']
        imgB_value = self.data.iloc[idx]['imageB']
        if isinstance(imgB_value, pd.Series):
            imgB_value = imgB_value.values[0]
        imgB_name = os.path.join(self.image_dir, str(imgB_value))

        imageA = Image.open(imgA_name).convert("RGB")
        imageB = Image.open(imgB_name).convert("RGB")

        if self.transform:
            imageA = self.transform(imageA)
            imageB = self.transform(imageB)

        sample = {'imageA': imageA, 'prompt': prompt, 'imageB': imageB}
        return sample

transform = transforms.Compose([
    transforms.Resize((600, 600)),
    # transforms.ToTensor(),
])


def get_rank():
    try:
        rank = int(os.environ.get("LOCAL_RANK"))
    except:
        rank = 0
    return rank


def get_config():
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def init_wandb(wandb_config, run_name):
    import wandb
    wandb.init(
        project=wandb_config["project"],
        name=run_name,
        config={},
    )


def main():
    # Initialize
    is_main_process, rank = get_rank() == 0, get_rank()
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WandB
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "scene":
        dataset = LocalSubjectsDataset(csv_file='csv_path', image_dir='images_path', transform=transform)
        data_valid = dataset
        print(data_valid.features)
        print(len(data_valid))
        print(training_config["dataset"])
        dataset = SceneDataset(
            data_valid,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    elif training_config["dataset"]["type"] == "img":
        # Load dataset text-to-image-2M
        dataset = load_dataset(
            "webdataset",
            data_files={"train": training_config["dataset"]["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
            position_scale=training_config["dataset"].get("position_scale", 1.0),
        )
    elif training_config["dataset"]["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    elif training_config["dataset"]["type"] == "scene":
        dataset = dataset
    else:
        raise NotImplementedError

    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )
    print("Trainloader generated.")

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device=f"cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )

    setattr(trainer, "training_config", training_config)

    # Save config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}")
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
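`LocalSubjectsDataset` above reads a CSV with `imageA`, `prompt`, and `imageB` columns plus an image directory; the `csv_path` and `images_path` strings in `main()` are placeholders left by the author. The snippet below is a hypothetical helper (file names are made up) for producing such a CSV, not part of the commit.

```python
# Hypothetical CSV preparation for LocalSubjectsDataset; file names and paths are placeholders.
import pandas as pd

rows = [
    {"imageA": "pair_0001_condition.png", "prompt": "move the red car to the left",  "imageB": "pair_0001_target.png"},
    {"imageA": "pair_0002_condition.png", "prompt": "change the weather to foggy",   "imageB": "pair_0002_target.png"},
]
pd.DataFrame(rows).to_csv("csv_path", index=False)

# train.py then consumes it as:
#   dataset = LocalSubjectsDataset(csv_file='csv_path', image_dir='images_path', transform=transform)
```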
train/README.md
ADDED
@@ -0,0 +1,138 @@
# OminiControl Training 🛠️

## Preparation

### Setup
1. **Environment**
    ```bash
    conda create -n omini python=3.10
    conda activate omini
    ```
2. **Requirements**
    ```bash
    pip install -r train/requirements.txt
    ```

### Dataset
1. Download dataset [Subject200K](https://huggingface.co/datasets/Yuanshi/Subjects200K). (**subject-driven generation**)
    ```
    bash train/script/data_download/data_download1.sh
    ```
2. Download dataset [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M). (**spatial control task**)
    ```
    bash train/script/data_download/data_download2.sh
    ```
    **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. Remember to update the config file to specify the training data accordingly.

## Training

### Start training
**Config file path**: `./train/config`

**Scripts path**: `./train/script`

1. Subject-driven generation
    ```bash
    bash train/script/train_subject.sh
    ```
2. Spatial control task
    ```bash
    bash train/script/train_canny.sh
    ```

**Note**: Detailed WandB settings and GPU settings can be found in the script files and the config files.

### Other spatial control tasks
This repository supports 7 spatial control tasks:
1. Canny edge to image (`canny`)
2. Image colorization (`coloring`)
3. Image deblurring (`deblurring`)
4. Depth map to image (`depth`)
5. Image to depth map (`depth_pred`)
6. Image inpainting (`fill`)
7. Super resolution (`sr`)

You can modify the `condition_type` parameter in the config file `config/canny_512.yaml` to switch between different tasks.

### Customize your own task
You can customize your own task by constructing a new dataset and modifying the training code.

<details>
<summary>Instructions</summary>

1. **Dataset**:

    Construct a new dataset with the following format: (`src/train/data.py`)
    ```python
    class MyDataset(Dataset):
        def __init__(self, ...):
            ...
        def __len__(self):
            ...
        def __getitem__(self, idx):
            ...
            return {
                "image": image,
                "condition": condition_img,
                "condition_type": "your_condition_type",
                "description": description,
                "position_delta": position_delta
            }
    ```
    **Note:** For spatial control tasks, set the `position_delta` to be `[0, 0]`. For non-spatial control tasks, set `position_delta` to be `[0, -condition_width // 16]`.
2. **Condition**:

    Add a new condition type in the `Condition` class. (`src/flux/condition.py`)
    ```python
    condition_dict = {
        ...
        "your_condition_type": your_condition_id_number,  # Add your condition type here
    }
    ...
    if condition_type in [
        ...
        "your_condition_type",  # Add your condition type here
    ]:
        ...
    ```
3. **Test**:

    Add a new test function for your task. (`src/train/callbacks.py`)
    ```python
    if self.condition_type == "your_condition_type":
        condition_img = (
            Image.open("images/vase.jpg")
            .resize((condition_size, condition_size))
            .convert("RGB")
        )
        ...
        test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
    ```

4. **Import relevant dataset in the training script**

    Update the file in the following section. (`src/train/train.py`)
    ```python
    from .data import (
        ImageConditionDataset,
        Subject200KDateset,
        MyDataset
    )
    ...

    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "your_condition_type":
        ...
    ```

</details>

## Hardware requirement
**Note**: Memory optimization (like dynamic T5 model loading) is pending implementation.

**Recommended**
- Hardware: 2x NVIDIA H100 GPUs
- Memory: ~80GB GPU memory

**Minimal**
- Hardware: 1x NVIDIA L20 GPU
- Memory: ~48GB GPU memory
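As a concrete check of the `position_delta` rule in the README above: with a 512-pixel-wide condition image, the non-spatial offset is `[0, -512 // 16] = [0, -32]`, which matches the `(0, -32)` used in `utils.py` and the `[0, -16]` used by `CartoonDataset` for its 256-token-wide latent. The tiny sketch below (the helper name is hypothetical, not part of the repo) just spells out that arithmetic.

```python
# Worked example of the position_delta rule; position_delta_for is an illustrative helper only.
def position_delta_for(condition_width: int, spatial: bool) -> list:
    return [0, 0] if spatial else [0, -condition_width // 16]

print(position_delta_for(512, spatial=True))   # [0, 0]   (canny, depth, fill, ...)
print(position_delta_for(512, spatial=False))  # [0, -32] (subject-driven / scene conditions)
```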
train/config/canny_512.yaml
ADDED
@@ -0,0 +1,48 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: false

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 5
  save_interval: 1000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: true
  save_path: "runs"

  # Specify the type of condition to use.
  # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill"]
  condition_type: "canny"
  dataset:
    type: "img"
    urls:
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
    cache_name: "data_512_2M"
    condition_size: 512
    target_size: 512
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  wandb:
    project: "OminiControl"

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
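The `lora_config` block above is passed verbatim to `peft.LoraConfig` by `OminiModel.init_lora` (via `transformer.add_adapter(LoraConfig(**lora_config))`), and the `optimizer` block is unpacked into `prodigyopt.Prodigy`. A minimal sketch of that mapping, assuming the YAML is read from the path used by the training scripts:

```python
# Sketch of how the YAML config above maps onto peft; not part of the commit.
import yaml
from peft import LoraConfig

with open("train/config/canny_512.yaml") as f:
    cfg = yaml.safe_load(f)

lora_cfg = LoraConfig(**cfg["train"]["lora_config"])
print(lora_cfg.r, lora_cfg.lora_alpha)  # 4 4
# OminiModel.init_lora then calls: self.transformer.add_adapter(lora_cfg)
# and configure_optimizers builds: prodigyopt.Prodigy(params, **cfg["train"]["optimizer"]["params"])
```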
train/config/cartoon_512.yaml
ADDED
@@ -0,0 +1,44 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: false

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 8
  save_interval: 1000
  sample_interval: 100
  max_steps: 15000
  gradient_checkpointing: false
  save_path: "runs"

  condition_type: "cartoon"
  dataset:
    type: "cartoon"
    condition_size: 512
    target_size: 512
    image_size: 512
    padding: 0
    drop_text_prob: 0.1
    drop_image_prob: 0.0

  wandb:
    project: "OminiControl"

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 2
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
train/config/fill_1024.yaml
ADDED
@@ -0,0 +1,47 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: false

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 5
  save_interval: 1000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: true
  save_path: "runs"

  # Specify the type of condition to use.
  # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill"]
  condition_type: "fill"
  dataset:
    type: "img"
    urls:
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_1024_10K/data_000000.tar"
    cache_name: "data_1024_10K"
    condition_size: 1024
    target_size: 1024
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  wandb:
    project: "OminiControl"

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
train/config/scene_512.yaml
ADDED
@@ -0,0 +1,45 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: true

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 5
  save_interval: 2000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: false
  save_path: "save_path"

  condition_type: "scene"
  dataset:
    type: "scene"
    condition_size: 512
    target_size: 512
    image_size: 512
    padding: 8
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  wandb:
    project: "OminiControl"

  lora_config:
    r: 128
    lora_alpha: 128
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"


  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
train/config/sr_512.yaml
ADDED
@@ -0,0 +1,48 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: false

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 5
  save_interval: 1000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: true
  save_path: "runs"

  # Specify the type of condition to use.
  # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"]
  condition_type: "sr"
  dataset:
    type: "img"
    urls:
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
    cache_name: "data_512_2M"
    condition_size: 256
    target_size: 512
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  wandb:
    project: "OminiControl"

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
train/config/subject_512.yaml
ADDED
@@ -0,0 +1,44 @@
flux_path: "black-forest-labs/FLUX.1-dev"
dtype: "bfloat16"

model:
  union_cond_attn: true
  add_cond_attn: false
  latent_lora: true

train:
  batch_size: 1
  accumulate_grad_batches: 1
  dataloader_workers: 5
  save_interval: 1000
  sample_interval: 100
  max_steps: -1
  gradient_checkpointing: true
  save_path: "runs"

  condition_type: "subject"
  dataset:
    type: "subject"
    condition_size: 512
    target_size: 512
    image_size: 512
    padding: 8
    drop_text_prob: 0.1
    drop_image_prob: 0.1

  wandb:
    project: "OminiControl"

  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

  optimizer:
    type: "Prodigy"
    params:
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01
train/requirements.txt
ADDED
@@ -0,0 +1,15 @@
diffusers==0.31.0
transformers
peft
opencv-python
protobuf
sentencepiece
gradio
jupyter
torchao

lightning
datasets
torchvision
prodigyopt
wandb
train/script/data_download/data_download1.sh
ADDED
@@ -0,0 +1 @@
huggingface-cli download --repo-type dataset Yuanshi/Subjects200K
train/script/data_download/data_download2.sh
ADDED
@@ -0,0 +1,3 @@
huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_512_2M/data_000045.tar
huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_512_2M/data_000046.tar
huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_1024_10K/data_000000.tar
train/script/train_canny.sh
ADDED
@@ -0,0 +1,13 @@
# Specify the config file path and the GPU devices to use
# export CUDA_VISIBLE_DEVICES=0,1

# Specify the config file path
export XFL_CONFIG=./train/config/canny_512.yaml

# Specify the WANDB API key
# export WANDB_API_KEY='YOUR_WANDB_API_KEY'

echo $XFL_CONFIG
export TOKENIZERS_PARALLELISM=true

accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_cartoon.sh
ADDED
@@ -0,0 +1,15 @@
# Specify the config file path and the GPU devices to use
# export CUDA_VISIBLE_DEVICES=0,1

# Specify the config file path
export XFL_CONFIG=./train/config/cartoon_512.yaml

export HF_HUB_CACHE=./cache

# Specify the WANDB API key
# export WANDB_API_KEY='YOUR_WANDB_API_KEY'

echo $XFL_CONFIG
export TOKENIZERS_PARALLELISM=true

accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_scene.sh
ADDED
@@ -0,0 +1,13 @@
# Specify the config file path and the GPU devices to use
# export CUDA_VISIBLE_DEVICES=0,1

# Specify the config file path
export XFL_CONFIG=./train/config/scene_512.yaml

# Specify the WANDB API key
# export WANDB_API_KEY='YOUR_WANDB_API_KEY'

echo $XFL_CONFIG
export TOKENIZERS_PARALLELISM=true

accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_subject.sh
ADDED
@@ -0,0 +1,13 @@
# Specify the config file path and the GPU devices to use
# export CUDA_VISIBLE_DEVICES=0,1

# Specify the config file path
export XFL_CONFIG=./train/config/subject_512.yaml

# Specify the WANDB API key
# export WANDB_API_KEY='YOUR_WANDB_API_KEY'

echo $XFL_CONFIG
export TOKENIZERS_PARALLELISM=true

accelerate launch --main_process_port 41353 -m src.train.train
utils.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from diffusers.pipelines import FluxPipeline
|
4 |
+
from src.flux.condition import Condition
|
5 |
+
from PIL import Image
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
import base64
|
10 |
+
import io
|
11 |
+
import re
|
12 |
+
from PIL import Image, ImageFilter
|
13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
14 |
+
from scipy.ndimage import binary_dilation
|
15 |
+
import cv2
|
16 |
+
import openai
|
17 |
+
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
|
18 |
+
|
19 |
+
|
20 |
+
from src.flux.generate import generate, seed_everything
|
21 |
+
|
22 |
+
try:
|
23 |
+
from mmengine.visualization import Visualizer
|
24 |
+
except ImportError:
|
25 |
+
Visualizer = None
|
26 |
+
print("Warning: mmengine is not installed, visualization is disabled.")
|
27 |
+
|
28 |
+
import re
|
29 |
+
|
30 |
+
def encode_image_to_datauri(path, size=(512, 512)):
|
31 |
+
with Image.open(path).convert('RGB') as img:
|
32 |
+
img = img.resize(size, Image.LANCZOS)
|
33 |
+
buffer = io.BytesIO()
|
34 |
+
img.save(buffer, format='PNG')
|
35 |
+
b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
36 |
+
return b64
|
37 |
+
# return f"data:image/png;base64,{b64}"
|
38 |
+
|
39 |
+
|
40 |
+
@retry(
|
41 |
+
reraise=True,
|
42 |
+
wait=wait_exponential(min=1, max=60),
|
43 |
+
stop=stop_after_attempt(6),
|
44 |
+
retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
|
45 |
+
)
|
46 |
+
def cot_with_gpt(image_uri, instruction):
|
47 |
+
response = openai.ChatCompletion.create(
|
48 |
+
model="gpt-4o",
|
49 |
+
messages=[
|
50 |
+
{
|
51 |
+
"role": "user",
|
52 |
+
"content": [
|
53 |
+
{"type": "text", "text": f'''
|
54 |
+
Now you are an expert in image editing. Based on the given single image, what atomic image editing instructions should be if the user wants to {instruction}? Let's think step by step.
|
55 |
+
Atomic instructions include 13 categories as follows:
|
56 |
+
- Add: e.g.: add a car on the road
|
57 |
+
- Remove: e.g.: remove the sofa in the image
|
58 |
+
- Color Change: e.g.: change the color of the shoes to blue
|
59 |
+
- Material Change: e.g.: change the material of the sign like stone
|
60 |
+
- Action Change: e.g.: change the action of the boy to raising hands
|
61 |
+
- Expression Change: e.g.: change the expression to smile
|
62 |
+
- Replace: e.g.: replace the coffee with an apple
|
63 |
+
- Background Change: e.g.: change the background into forest
|
64 |
+
- Appearance Change: e.g.: make the cup have a floral pattern
|
65 |
+
- Move: e.g.: move the plane to the left
|
66 |
+
- Resize: e.g.: enlarge the clock
|
67 |
+
- Tone Transfer: e.g.: change the weather to foggy
|
68 |
+
- Style Change: e.g.: make the style of the image to cartoon
|
69 |
+
Respond *only* with a numbered list.
|
70 |
+
Each line must begin with the category in square brackets, then the instruction. Please strictly follow the atomic categories.
|
71 |
+
The operation (what) and the target (to what) are crystal clear.
|
72 |
+
Do not split replace to add and remove.
|
73 |
+
For example:
|
74 |
+
“1. [Add] add a car on the road\n
|
75 |
+
2. [Color Change] change the color of the shoes to blue\n
|
76 |
+
3. [Move] move the lamp to the left\n"
|
77 |
+
Do not include any extra text, explanations, JSON or markdown—just the list.
|
78 |
+
'''},
|
79 |
+
{
|
80 |
+
"type": "image_url",
|
81 |
+
"image_url": {
|
82 |
+
"url": f"data:image/jpeg;base64,{image_uri}"
|
83 |
+
}
|
84 |
+
},
|
85 |
+
],
|
86 |
+
}
|
87 |
+
],
|
88 |
+
max_tokens=300,
|
89 |
+
)
|
90 |
+
text = response.choices[0].message.content.strip()
|
91 |
+
print(text)
|
92 |
+
|
93 |
+
categories, instructions = extract_instructions(text)
|
94 |
+
return categories, instructions
|
95 |
+
|
96 |
+
|
97 |
+
def extract_instructions(text):
|
98 |
+
categories = []
|
99 |
+
instructions = []
|
100 |
+
|
101 |
+
pattern = r'^\s*\d+\.\s*\[(.*?)\]\s*(.*?)$'
|
102 |
+
|
103 |
+
for line in text.split('\n'):
|
104 |
+
line = line.strip()
|
105 |
+
if not line:
|
106 |
+
continue
|
107 |
+
|
108 |
+
match = re.match(pattern, line)
|
109 |
+
if match:
|
110 |
+
category = match.group(1).strip()
|
111 |
+
instruction = match.group(2).strip()
|
112 |
+
|
113 |
+
if category and instruction:
|
114 |
+
categories.append(category)
|
115 |
+
instructions.append(instruction)
|
116 |
+
|
117 |
+
return categories, instructions
|
118 |
+
|
119 |
+
def extract_last_bbox(result):
|
120 |
+
pattern = r'\[?<span data-type="inline-math" data-value="XCcoW15cJ10rKVwnLFxzKlxbXHMqKFxkKylccyosXHMqKFxkKylccyosXHMqKFxkKylccyosXHMqKFxkKylccypcXQ=="></span>\]?'
|
121 |
+
matches = re.findall(pattern, result)
|
122 |
+
|
123 |
+
if not matches:
|
124 |
+
simple_pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]'
|
125 |
+
simple_matches = re.findall(simple_pattern, result)
|
126 |
+
if simple_matches:
|
127 |
+
x0, y0, x1, y1 = map(int, simple_matches[-1])
|
128 |
+
return [x0, y0, x1, y1]
|
129 |
+
else:
|
130 |
+
print(f"No bounding boxes found, please try again: {result}")
|
131 |
+
return None
|
132 |
+
|
133 |
+
last_match = matches[-1]
|
134 |
+
x0, y0, x1, y1 = map(int, last_match[1:])
|
135 |
+
return x0, y0, x1, y1
|
136 |
+
|
137 |
+
|
138 |
+
def infer_with_DiT(task, image, instruction, category):
|
139 |
+
# seed_everything(3407)
|
140 |
+
|
141 |
+
if task == 'RoI Inpainting':
|
142 |
+
if category == 'Add' or category == 'Replace':
|
143 |
+
lora_path = "weights/add.safetensors"
|
144 |
+
added = extract_object_with_gpt(instruction)
|
145 |
+
instruction_dit = f"add {added} on the black region"
|
146 |
+
elif category == 'Remove' or category == 'Action Change':
|
147 |
+
lora_path = "weights/remove.safetensors"
|
148 |
+
instruction_dit = f"Fill the hole of the image"
|
149 |
+
|
150 |
+
condition = Condition("scene", image, position_delta=(0, 0))
|
151 |
+
elif task == 'RoI Editing':
|
152 |
+
image = Image.open(image).convert('RGB').resize((512, 512))
|
153 |
+
condition = Condition("scene", image, position_delta=(0, -32))
|
154 |
+
instruction_dit = instruction
|
155 |
+
if category == 'Action Change':
|
156 |
+
lora_path = "weights/action.safetensors"
|
157 |
+
elif category == 'Expression Change':
|
158 |
+
lora_path = "weights/expression.safetensors"
|
159 |
+
elif category == 'Add':
|
160 |
+
lora_path = "weights/addition.safetensors"
|
161 |
+
elif category == 'Material Change':
|
162 |
+
lora_path = "weights/material.safetensors"
|
163 |
+
elif category == 'Color Change':
|
164 |
+
lora_path = "weights/color.safetensors"
|
165 |
+
elif category == 'Background Change':
|
166 |
+
lora_path = "weights/bg.safetensors"
|
167 |
+
elif category == 'Appearance Change':
|
168 |
+
lora_path = "weights/appearance.safetensors"
|
169 |
+
|
170 |
+
elif task == 'RoI Compositioning':
|
171 |
+
lora_path = "weights/fusion.safetensors"
|
172 |
+
condition = Condition("scene", image, position_delta=(0, 0))
|
173 |
+
instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
|
174 |
+
|
175 |
+
elif task == 'Global Transformation':
|
176 |
+
image = Image.open(image).convert('RGB').resize((512, 512))
|
177 |
+
instruction_dit = instruction
|
178 |
+
lora_path = "weights/overall.safetensors"
|
179 |
+
|
180 |
+
condition = Condition("scene", image, position_delta=(0, -32))
|
181 |
+
else:
|
182 |
+
raise ValueError(f"Invalid task: '{task}'")
|
183 |
+
pipe = FluxPipeline.from_pretrained(
|
184 |
+
"black-forest-labs/FLUX.1-dev",
|
185 |
+
torch_dtype=torch.bfloat16
|
186 |
+
)
|
187 |
+
|
188 |
+
pipe = pipe.to("cuda")
|
189 |
+
|
190 |
+
pipe.load_lora_weights(
|
191 |
+
"Cicici1109/IEAP",
|
192 |
+
weight_name=lora_path,
|
193 |
+
adapter_name="scene",
|
194 |
+
)
|
195 |
+
result_img = generate(
|
196 |
+
pipe,
|
197 |
+
prompt=instruction_dit,
|
198 |
+
conditions=[condition],
|
199 |
+
config_path = "train/config/scene_512.yaml",
|
200 |
+
num_inference_steps=28,
|
201 |
+
height=512,
|
202 |
+
width=512,
|
203 |
+
).images[0]
|
204 |
+
# result_img
|
205 |
+
if task == 'RoI Editing' and category == 'Action Change':
|
206 |
+
text_roi = extract_object_with_gpt(instruction)
|
207 |
+
instruction_loc = f"<image>Please segment {text_roi}."
|
208 |
+
# (model, tokenizer, image_path, instruction, work_dir, dilate):
|
209 |
+
img = result_img
|
210 |
+
# print(f"Instruction: {instruction_loc}")
|
211 |
+
|
212 |
+
model, tokenizer = load_model("ByteDance/Sa2VA-8B")
|
213 |
+
|
214 |
+
result = model.predict_forward(
|
215 |
+
image=img,
|
216 |
+
text=instruction_loc,
|
217 |
+
tokenizer=tokenizer,
|
218 |
+
)
|
219 |
+
|
220 |
+
prediction = result['prediction']
|
221 |
+
# print(f"Model Output: {prediction}")
|
222 |
+
|
223 |
+
if '[SEG]' in prediction and 'prediction_masks' in result:
|
224 |
+
pred_mask = result['prediction_masks'][0]
|
225 |
+
pred_mask_np = np.squeeze(np.array(pred_mask))
|
226 |
+
|
227 |
+
## obtain region bbox
|
228 |
+
rows = np.any(pred_mask_np, axis=1)
|
229 |
+
cols = np.any(pred_mask_np, axis=0)
|
230 |
+
if not np.any(rows) or not np.any(cols):
|
231 |
+
print("Warning: Mask is empty, cannot compute bounding box")
|
232 |
+
return img
|
233 |
+
|
234 |
+
y0, y1 = np.where(rows)[0][[0, -1]]
|
235 |
+
x0, x1 = np.where(cols)[0][[0, -1]]
|
236 |
+
|
237 |
+
changed_instance = crop_masked_region(result_img, pred_mask_np)
|
238 |
+
|
239 |
+
return changed_instance, x0, y1, 1
|
240 |
+
|
241 |
+
|
242 |
+
return result_img
|
243 |
+
|
244 |
+
def load_model(model_path):
|
245 |
+
model = AutoModelForCausalLM.from_pretrained(
|
246 |
+
model_path,
|
247 |
+
torch_dtype="auto",
|
248 |
+
device_map="auto",
|
249 |
+
trust_remote_code=True
|
250 |
+
).eval()
|
251 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
252 |
+
return model, tokenizer
|
253 |
+
|
254 |
+
def extract_object_with_gpt(instruction):
|
255 |
+
system_prompt = (
|
256 |
+
"You are a helpful assistant that extracts the object or target being edited in an image editing instruction. "
|
257 |
+
"Only return a concise noun phrase describing the object. "
|
258 |
+
"Examples:\n"
|
259 |
+
"- Input: 'Remove the dog' → Output: 'the dog'\n"
|
260 |
+
"- Input: 'Add a hat on the dog' → Output: 'a hat'\n"
|
261 |
+
"- Input: 'Replace the biggest bear with a tiger' → Output: 'the biggest bear'\n"
|
262 |
+
"- Input: 'Change the action of the girl to riding' → Output: 'the girl'\n"
|
263 |
+
"- Input: 'Move the red car on the lake' → Output: 'the red car'\n"
|
264 |
+
"- Input: 'Minify the carrot on the rabbit's hand' → Output: 'the carrot on the rabbit's hand'\n"
|
265 |
+
"- Input: 'Swap the location of the dog and the cat' → Output: 'the dog and the cat'\n"
|
266 |
+
"Now extract the object for this instruction:"
|
267 |
+
)
|
268 |
+
|
269 |
+
try:
|
270 |
+
response = openai.ChatCompletion.create(
|
271 |
+
model="gpt-3.5-turbo",
|
272 |
+
messages=[
|
273 |
+
{"role": "system", "content": system_prompt},
|
274 |
+
{"role": "user", "content": instruction}
|
275 |
+
],
|
276 |
+
temperature=0.2,
|
277 |
+
max_tokens=20,
|
278 |
+
)
|
279 |
+
object_phrase = response.choices[0].message['content'].strip().strip('"')
|
280 |
+
print(f"Identified object: {object_phrase}")
|
281 |
+
return object_phrase
|
282 |
+
except Exception as e:
|
283 |
+
print(f"GPT extraction failed: {e}")
|
284 |
+
return instruction
|
285 |
+
|
286 |
+
def extract_region_with_gpt(instruction):
|
287 |
+
system_prompt = (
|
288 |
+
"You are a helpful assistant that extracts target region being edited in an image editing instruction. "
|
289 |
+
"Only return a concise noun phrase describing the target region. "
|
290 |
+
"Examples:\n"
|
291 |
+
"- Input: 'Add a red hat to the man on the left' → Output: 'the man on the left'\n"
|
292 |
+
"- Input: 'Add a cat beside the dog' → Output: 'the dog'\n"
|
293 |
+
"Now extract the target region for this instruction:"
|
294 |
+
)
|
295 |
+
|
296 |
+
try:
|
297 |
+
response = openai.ChatCompletion.create(
|
298 |
+
model="gpt-3.5-turbo",
|
299 |
+
messages=[
|
300 |
+
{"role": "system", "content": system_prompt},
|
301 |
+
{"role": "user", "content": instruction}
|
302 |
+
],
|
303 |
+
temperature=0.2,
|
304 |
+
max_tokens=20,
|
305 |
+
)
|
306 |
+
object_phrase = response.choices[0].message['content'].strip().strip('"')
|
307 |
+
# print(f"Identified object: {object_phrase}")
|
308 |
+
return object_phrase
|
309 |
+
except Exception as e:
|
310 |
+
print(f"GPT extraction failed: {e}")
|
311 |
+
return instruction
|
312 |
+
|
313 |
+
def get_masked(mask, image):
|
314 |
+
if mask.shape[:2] != image.size[::-1]:
|
315 |
+
raise ValueError(f"Mask size {mask.shape[:2]} does not match image size {image.size}")
|
316 |
+
|
317 |
+
image_array = np.array(image)
|
318 |
+
image_array[mask] = [0, 0, 0]
|
319 |
+
|
320 |
+
return Image.fromarray(image_array)
|
321 |
+
|
322 |
+
def bbox_to_mask(x0, y0, x1, y1, image_shape=(512, 512), fill_value=True):
|
323 |
+
height, width = image_shape
|
324 |
+
|
325 |
+
mask = np.zeros((height, width), dtype=bool)
|
326 |
+
|
327 |
+
x0 = max(0, int(x0))
|
328 |
+
y0 = max(0, int(y0))
|
329 |
+
x1 = min(width, int(x1))
|
330 |
+
y1 = min(height, int(y1))
|
331 |
+
|
332 |
+
if x0 >= x1 or y0 >= y1:
|
333 |
+
print("Warning: Invalid bounding box coordinates")
|
334 |
+
return mask
|
335 |
+
|
336 |
+
mask[y0:y1, x0:x1] = fill_value
|
337 |
+
|
338 |
+
return mask
|
339 |
+
|
340 |
+
def combine_bbox(text, x0, y0, x1, y1):
|
341 |
+
bbox = [x0, y0, x1, y1]
|
342 |
+
return [(text, bbox)]
|
343 |
+
|
344 |
+
def crop_masked_region(image, pred_mask_np):
|
345 |
+
if not isinstance(image, Image.Image):
|
346 |
+
raise ValueError("The input image is not a PIL Image object")
|
347 |
+
if not isinstance(pred_mask_np, np.ndarray) or pred_mask_np.dtype != bool:
|
348 |
+
raise ValueError("pred_mask_np must be a NumPy array of boolean type")
|
349 |
+
if pred_mask_np.shape[:2] != image.size[::-1]:
|
350 |
+
raise ValueError(f"Mask size {pred_mask_np.shape[:2]} does not match image size {image.size}")
|
351 |
+
|
352 |
+
image_rgba = image.convert("RGBA")
|
353 |
+
image_array = np.array(image_rgba)
|
354 |
+
|
355 |
+
rows = np.any(pred_mask_np, axis=1)
|
356 |
+
cols = np.any(pred_mask_np, axis=0)
|
357 |
+
|
358 |
+
if not np.any(rows) or not np.any(cols):
|
359 |
+
print("Warning: Mask is empty, cannot compute bounding box")
|
360 |
+
return image_rgba
|
361 |
+
|
362 |
+
y0, y1 = np.where(rows)[0][[0, -1]]
|
363 |
+
x0, x1 = np.where(cols)[0][[0, -1]]
|
364 |
+
|
365 |
+
cropped_image = image_array[y0:y1+1, x0:x1+1].copy()
|
366 |
+
cropped_mask = pred_mask_np[y0:y1+1, x0:x1+1]
|
367 |
+
|
368 |
+
alpha_channel = np.ones(cropped_mask.shape, dtype=np.uint8) * 255
|
369 |
+
alpha_channel[~cropped_mask] = 0
|
370 |
+
|
371 |
+
cropped_image[:, :, 3] = alpha_channel
|
372 |
+
|
373 |
+
return Image.fromarray(cropped_image, mode='RGBA')
|
374 |
+
|
def roi_localization(image, instruction, category):  # add, remove, replace, action change, move, resize
    model, tokenizer = load_model("ByteDance/Sa2VA-8B")
    if category == 'Add':
        text_roi = extract_region_with_gpt(instruction)
    else:
        text_roi = extract_object_with_gpt(instruction)
    instruction_loc = f"<image>Please segment {text_roi}."
    img = Image.open(image).convert('RGB').resize((512, 512))
    print(f"Processing image: {os.path.basename(image)}, Instruction: {instruction_loc}")

    result = model.predict_forward(
        image=img,
        text=instruction_loc,
        tokenizer=tokenizer,
    )

    prediction = result['prediction']
    # print(f"Model Output: {prediction}")

    if '[SEG]' in prediction and 'prediction_masks' in result:
        pred_mask = result['prediction_masks'][0]
        pred_mask_np = np.squeeze(np.array(pred_mask))
        if category == 'Add':
            ## obtain region bbox
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
                print("Warning: Mask is empty, cannot compute bounding box")
                return img

            y0, y1 = np.where(rows)[0][[0, -1]]
            x0, x1 = np.where(cols)[0][[0, -1]]

            ## obtain inpainting bbox
            bbox = combine_bbox(text_roi, x0, y0, x1, y1)  # TODO: what if there are multiple regions?
            # print(bbox)
            x0, y0, x1, y1 = layout_add(bbox, instruction)
            mask = bbox_to_mask(x0, y0, x1, y1)
            ## make it black
            masked_img = get_masked(mask, img)
        elif category == 'Move' or category == 'Resize':
            dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
            masked_img = get_masked(dilated_original_mask, img)
            ## obtain region bbox
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
                print("Warning: Mask is empty, cannot compute bounding box")
                return img

            y0, y1 = np.where(rows)[0][[0, -1]]
            x0, x1 = np.where(cols)[0][[0, -1]]

            ## obtain inpainting bbox
            bbox = combine_bbox(text_roi, x0, y0, x1, y1)  # TODO: what if there are multiple regions?
            # print(bbox)
            x0_new, y0_new, x1_new, y1_new = layout_change(bbox, instruction)
            scale = (y1_new - y0_new) / (y1 - y0)
            # print(scale)
            changed_instance = crop_masked_region(img, pred_mask_np)

            return masked_img, changed_instance, x0_new, y1_new, scale
        else:
            dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
            masked_img = get_masked(dilated_original_mask, img)

        return masked_img

    else:
        print("No valid mask found in the prediction.")
        return None

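# Illustrative sketch, not part of the original utils.py: the return value of roi_localization
# depends on the edit category: a single masked image for 'Add' and removal-style edits, and a
# 5-tuple (masked image, RGBA cut-out, new x, new bottom y, scale) for 'Move'/'Resize'. The image
# path and instructions are hypothetical, and running this needs the Sa2VA checkpoint plus an
# OpenAI API key.
def _example_roi_localization():
    masked_add = roi_localization("room.jpg", "Add a lamp next to the sofa", "Add")
    masked_move, instance, x_new, y_bottom, scale = roi_localization(
        "room.jpg", "Move the sofa to the left", "Move")
    return masked_add, masked_move, instance, x_new, y_bottom, scale
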
def fusion(background, foreground, x, y, scale):
    background = background.convert("RGBA")
    bg_width, bg_height = background.size

    fg_width, fg_height = foreground.size
    new_size = (int(fg_width * scale), int(fg_height * scale))
    foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)

    # (x, y) is the bottom-left anchor of the pasted foreground on the background.
    left = x
    top = y - new_size[1]

    canvas = Image.new('RGBA', (bg_width, bg_height), (0, 0, 0, 0))
    canvas.paste(foreground_resized, (left, top), foreground_resized)
    masked_foreground = process_edge(canvas, left, top, new_size)
    result = Image.alpha_composite(background, masked_foreground)

    return result

def process_edge(canvas, left, top, size):
    width, height = size

    region = canvas.crop((left, top, left + width, top + height))
    alpha = region.getchannel('A')

    # Dilate and erode the alpha channel; their difference is a thin band along the edge.
    dilated_alpha = alpha.filter(ImageFilter.MaxFilter(5))
    eroded_alpha = alpha.filter(ImageFilter.MinFilter(3))

    edge_mask = Image.new('L', (width, height), 0)
    edge_pixels = edge_mask.load()
    dilated_pixels = dilated_alpha.load()
    eroded_pixels = eroded_alpha.load()

    for y in range(height):
        for x in range(width):
            if dilated_pixels[x, y] > 0 and eroded_pixels[x, y] == 0:
                edge_pixels[x, y] = 255

    # Paint the edge band black so the pasted region gets a clean outline.
    black_edge = Image.new('RGBA', (width, height), (0, 0, 0, 0))
    black_edge.putalpha(edge_mask)

    canvas.paste(black_edge, (left, top), black_edge)

    return canvas

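# Illustrative sketch, not part of the original utils.py: fusion anchors the foreground at its
# bottom-left corner, so (x, y) is where the bottom-left of the scaled cut-out lands on the
# background, and process_edge adds a thin black outline around the pasted region. File names,
# coordinates, and the scale below are hypothetical.
def _example_fusion():
    background = Image.open("background.png").convert("RGBA").resize((512, 512))
    cutout = Image.open("cutout.png").convert("RGBA")      # e.g. the output of crop_masked_region
    return fusion(background, cutout, x=120, y=400, scale=1.2)
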
def combine_text_and_bbox(text_roi, x0, y0, x1, y1):
    return [(text_roi, [x0, y0, x1, y1])]

@retry(
    reraise=True,
    wait=wait_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
)
def layout_add(bbox, instruction):
    response = openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f'''
                    You are an intelligent bounding box editor. I will provide you with the current bounding boxes and an add editing instruction.
                    Your task is to determine the new bounding box of the added object. Let's think step by step.
                    The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512].
                    The bounding boxes should not go beyond the image boundaries. The new box must be large enough to reasonably encompass the added object in a visually appropriate way, allowing for partial overlap with existing objects when it comes to accessories such as a hat, a necklace, etc.
                    Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, bottom-right x coordinate, bottom-right y coordinate]).
                    Only return the bounding box of the newly added object. Do not include the existing bounding boxes.
                    Please consider the semantic information of the layout and preserve semantic relations.
                    If needed, you can make reasonable guesses. Please refer to the examples below:
                    Input bounding boxes: [('a green car', [21, 281, 232, 440])]
                    Editing instruction: Add a bird on the green car.
                    Output bounding boxes: [('a bird', [80, 150, 180, 281])]
                    Input bounding boxes: [('stool', [300, 350, 380, 450])]
                    Editing instruction: Add a cat to the left of the stool.
                    Output bounding boxes: [('a cat', [180, 250, 300, 450])]

                    Here are some examples to illustrate appropriate overlapping for better visual effects:
                    Input bounding boxes: [('the white cat', [200, 300, 320, 420])]
                    Editing instruction: Add a hat on the white cat.
                    Output bounding boxes: [('a hat', [200, 150, 320, 330])]
                    Now, the current bounding boxes are {bbox}, and the instruction is {instruction}.
                    '''},
                ],
            }
        ],
        max_tokens=1000,
    )

    result = response.choices[0].message.content.strip()

    bbox = extract_last_bbox(result)
    return bbox

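# Illustrative sketch, not part of the original utils.py: layout_add takes the segmentation-derived
# anchor box plus the add instruction and is expected to return a single [x0, y0, x1, y1] box
# parsed out of the GPT-4o reply by extract_last_bbox. Running it requires a configured OpenAI API
# key; the box and instruction below are hypothetical.
def _example_layout_add():
    anchor = combine_bbox('a green car', 21, 281, 232, 440)
    return layout_add(anchor, "Add a bird on the green car")   # e.g. something like [80, 150, 180, 281]
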
@retry(
    reraise=True,
    wait=wait_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
)
def layout_change(bbox, instruction):
    response = openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f'''
                    You are an intelligent bounding box editor. I will provide you with the current bounding boxes and the editing instruction.
                    Your task is to generate the new bounding boxes after editing.
                    The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512].
                    The bounding boxes should not overlap or go beyond the image boundaries.
                    Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, bottom-right x coordinate, bottom-right y coordinate]).
                    Do not add new objects or delete any object provided in the bounding boxes. Do not change the size or the shape of any object unless the instruction requires it.
                    Please consider the semantic information of the layout.
                    When resizing, keep the bottom-left corner fixed by default. When swapping locations, move each object according to its center point.
                    If needed, you can make reasonable guesses. Please refer to the examples below:

                    Input bounding boxes: [('a car', [21, 281, 232, 440])]
                    Editing instruction: Move the car to the right.
                    Output bounding boxes: [('a car', [121, 281, 332, 440])]

                    Input bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [200, 200, 300, 230])]
                    Editing instruction: Move the pillow to the left side of the bed.
                    Output bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [70, 270, 170, 300])]

                    Input bounding boxes: [("dog", [150, 250, 250, 300])]
                    Editing instruction: Enlarge the dog.
                    Output bounding boxes: [("dog", [150, 225, 300, 300])]

                    Input bounding boxes: [("chair", [100, 350, 200, 450]), ("lamp", [300, 200, 360, 300])]
                    Editing instruction: Swap the location of the chair and the lamp.
                    Output bounding boxes: [("chair", [280, 200, 380, 300]), ("lamp", [120, 350, 180, 450])]

                    Now, the current bounding boxes are {bbox}, and the instruction is {instruction}. Let's think step by step, and output the edited layout.
                    '''},
                ],
            }
        ],
        max_tokens=1000,
    )
    result = response.choices[0].message.content.strip()

    bbox = extract_last_bbox(result)
    return bbox
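
# Illustrative end-to-end sketch, not part of the original utils.py: how the helpers above are
# expected to compose for a 'Move' edit. roi_localization masks out the original object and
# returns the RGBA cut-out with its new anchor and scale; fusion then pastes the cut-out back
# onto the edited background. The path and instruction are hypothetical, and the real pipeline
# would first run an inpainting model on masked_img to produce the background.
def _example_move_edit():
    masked_img, instance, x_new, y_bottom, scale = roi_localization(
        "scene.jpg", "Move the car to the right", "Move")
    background = masked_img  # placeholder; normally the inpainted version of this image
    return fusion(background, instance, x_new, y_bottom, scale)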