Yollm committed
Commit f08d17a · 0 Parent(s)

Initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: IEAP
3
+ emoji: 👀
4
+ colorFrom: gray
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.32.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: A demo for IEAP
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,112 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
4
+ import openai
5
+ import os
6
+ import uuid
7
+ from src.flux.generate import generate, seed_everything
8
+
9
+
10
+ def process_image(api_key, seed, image, prompt):
11
+ if not api_key:
12
+ raise gr.Error("❌ Please enter a valid OpenAI API key.")
13
+
14
+ openai.api_key = api_key
15
+
16
+ # Generate a unique image ID to avoid file name conflict
17
+ image_id = str(uuid.uuid4())
18
+ seed_everything(seed)
19
+ input_path = f"input_{image_id}.png"
20
+ image.save(input_path)
21
+
22
+ try:
23
+ uri = encode_image_to_datauri(input_path)
24
+ categories, instructions = cot_with_gpt(uri, prompt)
25
+ # categories = ['Tone Transfer', 'Style Change']
26
+ # instructions = ['Change the time to night', 'Change the style to watercolor']
27
+
28
+ if not categories or not instructions:
29
+ raise gr.Error("No editing steps returned by GPT. Try a more specific instruction.")
30
+
31
+ intermediate_images = []
32
+ current_image_path = input_path
33
+
34
+ for i, (category, instruction) in enumerate(zip(categories, instructions)):
35
+ print(f"[Step {i}] Category: {category} | Instruction: {instruction}")
36
+ step_prefix = f"{image_id}_{i}"
37
+
38
+ if category in ('Add', 'Remove', 'Replace'):
39
+ if category == 'Add':
40
+ edited_image = infer_with_DiT('RoI Editing', current_image_path, instruction, category)
41
+ else:
42
+ mask_image = roi_localization(current_image_path, instruction, category)
43
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)
44
+
45
+ elif category == 'Action Change':
46
+ mask_image = roi_localization(current_image_path, instruction, category)
47
+ inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
48
+ changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', current_image_path, instruction, category)
49
+ fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
50
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
51
+
52
+ elif category in ('Move', 'Resize'):
53
+ mask_image, changed_instance, x0, y1, scale = roi_localization(current_image_path, instruction, category)
54
+ inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
55
+ fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
56
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
57
+
58
+ elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
59
+ edited_image = infer_with_DiT('RoI Editing', current_image_path, instruction, category)
60
+
61
+ elif category in ('Tone Transfer', 'Style Change'):
62
+ edited_image = infer_with_DiT('Global Transformation', current_image_path, instruction, category)
63
+
64
+ else:
65
+ raise gr.Error(f"Invalid category returned: '{category}'")
66
+
67
+ current_image_path = f"{step_prefix}.png"
68
+ edited_image.save(current_image_path)
69
+ intermediate_images.append(edited_image.copy())
70
+
71
+ final_result = intermediate_images[-1] if intermediate_images else image
72
+ return intermediate_images, final_result
73
+
74
+ except Exception as e:
75
+ raise gr.Error(f"Processing failed: {str(e)}")
76
+
77
+
78
+ # Gradio UI
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("## 🖼️ IEAP: Image Editing As Programs")
81
+
82
+ with gr.Row():
83
+ api_key_input = gr.Textbox(label="🔑 OpenAI API Key", type="password", placeholder="sk-...")
84
+
85
+ with gr.Row():
86
+ seed_slider = gr.Slider(
87
+ label="🎲 Random Seed",
88
+ minimum=0,
89
+ maximum=1000000,
90
+ value=3407,
91
+ step=1,
92
+ info="Drag to set the random seed for reproducibility"
93
+ )
94
+
95
+ with gr.Row():
96
+ with gr.Column():
97
+ image_input = gr.Image(type="pil", label="Upload Image")
98
+ prompt_input = gr.Textbox(label="Instruction", placeholder="e.g., Move the dog to the left and change its color to blue")
99
+ submit_button = gr.Button("Submit")
100
+ with gr.Column():
101
+ result_gallery = gr.Gallery(label="Intermediate Steps", columns=2, height="auto")
102
+ final_output = gr.Image(label="✅ Final Result")
103
+
104
+ submit_button.click(
105
+ fn=process_image,
106
+ inputs=[api_key_input, seed_slider, image_input, prompt_input],
107
+ outputs=[result_gallery, final_output]
108
+ )
109
+
110
+ if __name__ == "__main__":
111
+ demo.launch(
112
+ )
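The Gradio wiring above is the standard Blocks pattern: the click handler receives the four inputs and returns the gallery of intermediate steps plus the final image. A minimal, self-contained sketch of the same pattern with a hypothetical stand-in editor (fake_process just echoes the input; no GPT or DiT calls are made):

import gradio as gr
from PIL import Image

def fake_process(api_key: str, seed: int, image: Image.Image, prompt: str):
    # Stand-in for process_image: pretend each "step" is the unchanged input.
    if not api_key:
        raise gr.Error("Please enter an API key.")
    steps = [image, image]           # intermediate results
    return steps, steps[-1]          # (gallery, final image)

with gr.Blocks() as demo:
    api_key = gr.Textbox(label="API Key", type="password")
    seed = gr.Slider(0, 1_000_000, value=3407, step=1, label="Seed")
    image = gr.Image(type="pil", label="Upload Image")
    prompt = gr.Textbox(label="Instruction")
    run = gr.Button("Submit")
    gallery = gr.Gallery(label="Intermediate Steps")
    final = gr.Image(label="Final Result")
    run.click(fake_process, inputs=[api_key, seed, image, prompt], outputs=[gallery, final])

if __name__ == "__main__":
    demo.launch()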
instructions.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "categories": ["Move", "Resize"],
3
+ "instructions": ["Move the woman to the right", "Minify the woman"]
4
+ }
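The file holds two parallel arrays that main_json.py (below) zips into editing steps, so they must have the same length. A small sketch that writes a plan in this schema (file name and steps are illustrative):

import json

plan = {
    "categories": ["Move", "Resize"],
    "instructions": ["Move the woman to the right", "Minify the woman"],
}
# main_json.py requires the two arrays to be non-empty and aligned one-to-one.
assert len(plan["categories"]) == len(plan["instructions"])

with open("instructions.json", "w") as f:
    json.dump(plan, f, indent=2)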
main.py ADDED
@@ -0,0 +1,95 @@
1
+ import os
2
+ import argparse
3
+ from PIL import Image
4
+ import openai
5
+ from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
6
+ from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
7
+ from src.flux.generate import generate, seed_everything
8
+
9
+ def main():
10
+ parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
11
+ parser.add_argument("image_path", help="Path to input image")
12
+ parser.add_argument("prompt", help="Original instruction")
13
+ parser.add_argument("--seed", type=int, default=3407, help="Random seed for reproducibility")
14
+ args = parser.parse_args()
15
+
16
+ seed_everything(args.seed)
17
+
18
+ openai.api_key = "YOUR_API_KEY"
19
+
20
+ if not openai.api_key:
21
+ raise ValueError("OPENAI_API_KEY environment variable not set.")
22
+
23
+ os.makedirs("results", exist_ok=True)
24
+
25
+
26
+ ###########################################
27
+ ### CoT -> instructions ###
28
+ ###########################################
29
+
30
+ uri = encode_image_to_datauri(args.image_path)
31
+ categories, instructions = cot_with_gpt(uri, args.prompt)
32
+ print(categories)
33
+ print(instructions)
34
+
35
+ # categories = ['Move', 'Resize']
36
+ # instructions = ['Move the woman to the right', 'Minify the woman']
37
+
38
+ ###########################################
39
+ ### Neural Program Interpreter ###
40
+ ###########################################
41
+ for i in range(len(categories)):
42
+ if i == 0:
43
+ image = args.image_path
44
+ else:
45
+ image = f"results/{i-1}.png"
46
+ category = categories[i]
47
+ instruction = instructions[i]
48
+ if category in ('Add', 'Remove', 'Replace', 'Action Change', 'Move', 'Resize'):
49
+ if category in ('Add', 'Remove', 'Replace'):
50
+ if category == 'Add':
51
+ edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
52
+ else:
53
+ ### RoI Localization
54
+ mask_image = roi_localization(image, instruction, category)
55
+ # mask_image.save("mask.png")
56
+ ### RoI Inpainting
57
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)
58
+ elif category == 'Action Change':
59
+ ### RoI Localization
60
+ mask_image = roi_localization(image, instruction, category)
61
+ ### RoI Inpainting
62
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove') # inpainted bg
63
+ ### RoI Editing
64
+ changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', image, instruction, category) # action change
65
+ fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
66
+ ### RoI Compositioning
67
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
68
+ elif category in ('Move', 'Resize'):
69
+ ### RoI Localization
70
+ mask_image, changed_instance, x0, y1, scale = roi_localization(image, instruction, category)
71
+ ### RoI Inpainting
72
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove') # inpainted bg
73
+ # changed_instance, bottom_left, scale = layout_change(image, instruction) # move/resize
74
+ fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
75
+ fusion_image.save("fusion.png")
76
+ ### RoI Compositioning
77
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
78
+
79
+ elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
80
+ ### RoI Editing
81
+ edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
82
+
83
+ elif category in ('Tone Transfer', 'Style Change'):
84
+ ### Global Transformation
85
+ edited_image = infer_with_DiT('Global Transformation', image, instruction, category)
86
+
87
+ else:
88
+ raise ValueError(f"Invalid category: '{category}'")
89
+
90
+ image = edited_image
91
+ image.save(f"results/{i}.png")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()
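The interpreter loop above routes each (category, instruction) pair to a fixed sequence of atomic operations. A compact sketch of that routing as a dispatch table (CATEGORY_PROGRAMS and plan are illustrative names, not part of the repo; operation names are taken from the branches above, with "Fusion" standing for the fusion() helper):

# Category -> ordered list of atomic operations, as dispatched above.
CATEGORY_PROGRAMS = {
    "Add":               ["RoI Editing"],
    "Remove":            ["RoI Localization", "RoI Inpainting"],
    "Replace":           ["RoI Localization", "RoI Inpainting"],
    "Action Change":     ["RoI Localization", "RoI Inpainting", "RoI Editing", "Fusion", "RoI Compositioning"],
    "Move":              ["RoI Localization", "RoI Inpainting", "Fusion", "RoI Compositioning"],
    "Resize":            ["RoI Localization", "RoI Inpainting", "Fusion", "RoI Compositioning"],
    "Appearance Change": ["RoI Editing"],
    "Background Change": ["RoI Editing"],
    "Color Change":      ["RoI Editing"],
    "Material Change":   ["RoI Editing"],
    "Expression Change": ["RoI Editing"],
    "Tone Transfer":     ["Global Transformation"],
    "Style Change":      ["Global Transformation"],
}

def plan(categories, instructions):
    """Yield (instruction, operation) pairs in execution order."""
    for category, instruction in zip(categories, instructions):
        if category not in CATEGORY_PROGRAMS:
            raise ValueError(f"Invalid category: '{category}'")
        for op in CATEGORY_PROGRAMS[category]:
            yield instruction, op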
main_json.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+ from PIL import Image
5
+ import openai
6
+ from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
7
+ from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
8
+
9
+
10
+ def main():
11
+ parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
12
+ parser.add_argument("image_path", help="Path to input image")
13
+ parser.add_argument("json_path", help="Path to JSON file containing categories and instructions")
14
+ args = parser.parse_args()
15
+
16
+ openai.api_key = "YOUR_API_KEY"
17
+
18
+ if not openai.api_key:
19
+ raise ValueError("OPENAI_API_KEY environment variable not set.")
20
+
21
+ os.makedirs("results", exist_ok=True)
22
+
23
+
24
+ #######################################################
25
+ ### Load instructions from JSON ###
26
+ #######################################################
27
+ try:
28
+ with open(args.json_path, 'r') as f:
29
+ data = json.load(f)
30
+ categories = data.get('categories', [])
31
+ instructions = data.get('instructions', [])
32
+
33
+ if not categories or not instructions:
34
+ raise ValueError("JSON file must contain 'categories' and 'instructions' arrays.")
35
+
36
+ if len(categories) != len(instructions):
37
+ raise ValueError("Length of 'categories' and 'instructions' must match.")
38
+
39
+ print("Loaded instructions from JSON:")
40
+ for i, (cat, instr) in enumerate(zip(categories, instructions)):
41
+ print(f"Step {i+1}: [{cat}] {instr}")
42
+
43
+ except Exception as e:
44
+ raise ValueError(f"Failed to load JSON file: {str(e)}")
45
+
46
+ ###################################################
47
+ ### Neural Program Interpreter ###
48
+ ###################################################
49
+ for i in range(len(categories)):
50
+ if i == 0:
51
+ image = args.image_path
52
+ else:
53
+ image = f"results/{i-1}.png"
54
+ category = categories[i]
55
+ instruction = instructions[i]
56
+
57
+ if category in ('Add', 'Remove', 'Replace', 'Action Change', 'Move', 'Resize'):
58
+ if category in ('Add', 'Remove', 'Replace'):
59
+ if category == 'Add':
60
+ edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
61
+ else:
62
+ mask_image = roi_localization(image, instruction, category)
63
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, category)
64
+ elif category == 'Action Change':
65
+ mask_image = roi_localization(image, instruction, category)
66
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
67
+ changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', image, instruction, category)
68
+ fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
69
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
70
+ elif category in ('Move', 'Resize'):
71
+ mask_image, changed_instance, x0, y1, scale = roi_localization(image, instruction, category)
72
+ edited_image = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
73
+ fusion_image = fusion(edited_image, changed_instance, x0, y1, scale)
74
+ fusion_image.save("fusion.png")
75
+ edited_image = infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)
76
+
77
+ elif category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
78
+ edited_image = infer_with_DiT('RoI Editing', image, instruction, category)
79
+
80
+ elif category in ('Tone Transfer', 'Style Change'):
81
+ edited_image = infer_with_DiT('Global Transformation', image, instruction, category)
82
+
83
+ else:
84
+ raise ValueError(f"Invalid category: '{category}'")
85
+
86
+ image = edited_image
87
+ image.save(f"results/{i}.png")
88
+ print(f"Step {i+1} completed: {category} - {instruction}")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()
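A hedged usage sketch of the two CLI entry points, driven from Python via subprocess (input.png and instructions.json are placeholder paths; both scripts write their per-step outputs to results/<i>.png):

import subprocess

# Plan the edit steps with GPT at run time:
subprocess.run(["python", "main.py", "input.png",
                "Move the woman to the right and minify her", "--seed", "3407"], check=True)

# Or replay a pre-written plan from JSON (see instructions.json above):
subprocess.run(["python", "main_json.py", "input.png", "instructions.json"], check=True)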
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ diffusers==0.32.0
2
+ transformers==4.42.3
3
+ xtuner[deepspeed]==0.1.23
4
+ timm==1.0.9
5
+ mmdet==3.3.0
6
+ hydra-core==1.3.2
7
+ ninja==1.11.1
8
+ decord==0.6.0
9
+ peft==0.11.1
10
+ protobuf==5.29.4
11
+ sentencepiece==0.2.0
12
+ tornado==6.4.2
13
+ openai==0.28.0
14
+ gradio==5.32.0
15
+ opencv-python
src/flux/block.py ADDED
@@ -0,0 +1,339 @@
1
+ import torch
2
+ from typing import List, Union, Optional, Dict, Any, Callable
3
+ from diffusers.models.attention_processor import Attention, F
4
+ from .lora_controller import enable_lora
5
+
6
+
7
+ def attn_forward(
8
+ attn: Attention,
9
+ hidden_states: torch.FloatTensor,
10
+ encoder_hidden_states: torch.FloatTensor = None,
11
+ condition_latents: torch.FloatTensor = None,
12
+ attention_mask: Optional[torch.FloatTensor] = None,
13
+ image_rotary_emb: Optional[torch.Tensor] = None,
14
+ cond_rotary_emb: Optional[torch.Tensor] = None,
15
+ model_config: Optional[Dict[str, Any]] = {},
16
+ ) -> torch.FloatTensor:
17
+ batch_size, _, _ = (
18
+ hidden_states.shape
19
+ if encoder_hidden_states is None
20
+ else encoder_hidden_states.shape
21
+ )
22
+
23
+ with enable_lora(
24
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25
+ ):
26
+ # `sample` projections.
27
+ query = attn.to_q(hidden_states)
28
+ key = attn.to_k(hidden_states)
29
+ value = attn.to_v(hidden_states)
30
+
31
+ inner_dim = key.shape[-1]
32
+ head_dim = inner_dim // attn.heads
33
+
34
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
35
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37
+
38
+ if attn.norm_q is not None:
39
+ query = attn.norm_q(query)
40
+ if attn.norm_k is not None:
41
+ key = attn.norm_k(key)
42
+
43
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
44
+ if encoder_hidden_states is not None:
45
+ # `context` projections.
46
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
47
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
48
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
49
+
50
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
51
+ batch_size, -1, attn.heads, head_dim
52
+ ).transpose(1, 2)
53
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
54
+ batch_size, -1, attn.heads, head_dim
55
+ ).transpose(1, 2)
56
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+
60
+ if attn.norm_added_q is not None:
61
+ encoder_hidden_states_query_proj = attn.norm_added_q(
62
+ encoder_hidden_states_query_proj
63
+ )
64
+ if attn.norm_added_k is not None:
65
+ encoder_hidden_states_key_proj = attn.norm_added_k(
66
+ encoder_hidden_states_key_proj
67
+ )
68
+
69
+ # attention
70
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
71
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
72
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
73
+
74
+ if image_rotary_emb is not None:
75
+ from diffusers.models.embeddings import apply_rotary_emb
76
+
77
+ query = apply_rotary_emb(query, image_rotary_emb)
78
+ key = apply_rotary_emb(key, image_rotary_emb)
79
+
80
+ if condition_latents is not None:
81
+ cond_query = attn.to_q(condition_latents)
82
+ cond_key = attn.to_k(condition_latents)
83
+ cond_value = attn.to_v(condition_latents)
84
+
85
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
86
+ 1, 2
87
+ )
88
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
90
+ 1, 2
91
+ )
92
+ if attn.norm_q is not None:
93
+ cond_query = attn.norm_q(cond_query)
94
+ if attn.norm_k is not None:
95
+ cond_key = attn.norm_k(cond_key)
96
+
97
+ if cond_rotary_emb is not None:
98
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
99
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
100
+
101
+ if condition_latents is not None:
102
+ query = torch.cat([query, cond_query], dim=2)
103
+ key = torch.cat([key, cond_key], dim=2)
104
+ value = torch.cat([value, cond_value], dim=2)
105
+
106
+ if not model_config.get("union_cond_attn", True):
107
+ # If we don't want to use the union condition attention, we need to mask the attention
108
+ # between the hidden states and the condition latents
109
+ attention_mask = torch.ones(
110
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
111
+ )
112
+ condition_n = cond_query.shape[2]
113
+ attention_mask[-condition_n:, :-condition_n] = False
114
+ attention_mask[:-condition_n, -condition_n:] = False
115
+ elif model_config.get("independent_condition", False):
116
+ attention_mask = torch.ones(
117
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
118
+ )
119
+ condition_n = cond_query.shape[2]
120
+ attention_mask[-condition_n:, :-condition_n] = False
121
+ if hasattr(attn, "c_factor"):
122
+ attention_mask = torch.zeros(
123
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
124
+ )
125
+ condition_n = cond_query.shape[2]
126
+ bias = torch.log(attn.c_factor[0])
127
+ attention_mask[-condition_n:, :-condition_n] = bias
128
+ attention_mask[:-condition_n, -condition_n:] = bias
129
+ hidden_states = F.scaled_dot_product_attention(
130
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
131
+ )
132
+ hidden_states = hidden_states.transpose(1, 2).reshape(
133
+ batch_size, -1, attn.heads * head_dim
134
+ )
135
+ hidden_states = hidden_states.to(query.dtype)
136
+
137
+ if encoder_hidden_states is not None:
138
+ if condition_latents is not None:
139
+ encoder_hidden_states, hidden_states, condition_latents = (
140
+ hidden_states[:, : encoder_hidden_states.shape[1]],
141
+ hidden_states[
142
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
143
+ ],
144
+ hidden_states[:, -condition_latents.shape[1] :],
145
+ )
146
+ else:
147
+ encoder_hidden_states, hidden_states = (
148
+ hidden_states[:, : encoder_hidden_states.shape[1]],
149
+ hidden_states[:, encoder_hidden_states.shape[1] :],
150
+ )
151
+
152
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
153
+ # linear proj
154
+ hidden_states = attn.to_out[0](hidden_states)
155
+ # dropout
156
+ hidden_states = attn.to_out[1](hidden_states)
157
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
158
+
159
+ if condition_latents is not None:
160
+ condition_latents = attn.to_out[0](condition_latents)
161
+ condition_latents = attn.to_out[1](condition_latents)
162
+
163
+ return (
164
+ (hidden_states, encoder_hidden_states, condition_latents)
165
+ if condition_latents is not None
166
+ else (hidden_states, encoder_hidden_states)
167
+ )
168
+ elif condition_latents is not None:
169
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
170
+ hidden_states, condition_latents = (
171
+ hidden_states[:, : -condition_latents.shape[1]],
172
+ hidden_states[:, -condition_latents.shape[1] :],
173
+ )
174
+ return hidden_states, condition_latents
175
+ else:
176
+ return hidden_states
177
+
178
+
179
+ def block_forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ encoder_hidden_states: torch.FloatTensor,
183
+ condition_latents: torch.FloatTensor,
184
+ temb: torch.FloatTensor,
185
+ cond_temb: torch.FloatTensor,
186
+ cond_rotary_emb=None,
187
+ image_rotary_emb=None,
188
+ model_config: Optional[Dict[str, Any]] = {},
189
+ ):
190
+ use_cond = condition_latents is not None
191
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
192
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
193
+ hidden_states, emb=temb
194
+ )
195
+
196
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
197
+ self.norm1_context(encoder_hidden_states, emb=temb)
198
+ )
199
+
200
+ if use_cond:
201
+ (
202
+ norm_condition_latents,
203
+ cond_gate_msa,
204
+ cond_shift_mlp,
205
+ cond_scale_mlp,
206
+ cond_gate_mlp,
207
+ ) = self.norm1(condition_latents, emb=cond_temb)
208
+
209
+ # Attention.
210
+ result = attn_forward(
211
+ self.attn,
212
+ model_config=model_config,
213
+ hidden_states=norm_hidden_states,
214
+ encoder_hidden_states=norm_encoder_hidden_states,
215
+ condition_latents=norm_condition_latents if use_cond else None,
216
+ image_rotary_emb=image_rotary_emb,
217
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
218
+ )
219
+ attn_output, context_attn_output = result[:2]
220
+ cond_attn_output = result[2] if use_cond else None
221
+
222
+ # Process attention outputs for the `hidden_states`.
223
+ # 1. hidden_states
224
+ attn_output = gate_msa.unsqueeze(1) * attn_output
225
+ hidden_states = hidden_states + attn_output
226
+ # 2. encoder_hidden_states
227
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
228
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
229
+ # 3. condition_latents
230
+ if use_cond:
231
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
232
+ condition_latents = condition_latents + cond_attn_output
233
+ if model_config.get("add_cond_attn", False):
234
+ hidden_states += cond_attn_output
235
+
236
+ # LayerNorm + MLP.
237
+ # 1. hidden_states
238
+ norm_hidden_states = self.norm2(hidden_states)
239
+ norm_hidden_states = (
240
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
241
+ )
242
+ # 2. encoder_hidden_states
243
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
244
+ norm_encoder_hidden_states = (
245
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
246
+ )
247
+ # 3. condition_latents
248
+ if use_cond:
249
+ norm_condition_latents = self.norm2(condition_latents)
250
+ norm_condition_latents = (
251
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
252
+ + cond_shift_mlp[:, None]
253
+ )
254
+
255
+ # Feed-forward.
256
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
257
+ # 1. hidden_states
258
+ ff_output = self.ff(norm_hidden_states)
259
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
260
+ # 2. encoder_hidden_states
261
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
262
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
263
+ # 3. condition_latents
264
+ if use_cond:
265
+ cond_ff_output = self.ff(norm_condition_latents)
266
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
267
+
268
+ # Process feed-forward outputs.
269
+ hidden_states = hidden_states + ff_output
270
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
271
+ if use_cond:
272
+ condition_latents = condition_latents + cond_ff_output
273
+
274
+ # Clip to avoid overflow.
275
+ if encoder_hidden_states.dtype == torch.float16:
276
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
277
+
278
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
279
+
280
+
281
+ def single_block_forward(
282
+ self,
283
+ hidden_states: torch.FloatTensor,
284
+ temb: torch.FloatTensor,
285
+ image_rotary_emb=None,
286
+ condition_latents: torch.FloatTensor = None,
287
+ cond_temb: torch.FloatTensor = None,
288
+ cond_rotary_emb=None,
289
+ model_config: Optional[Dict[str, Any]] = {},
290
+ ):
291
+
292
+ using_cond = condition_latents is not None
293
+ residual = hidden_states
294
+ with enable_lora(
295
+ (
296
+ self.norm.linear,
297
+ self.proj_mlp,
298
+ ),
299
+ model_config.get("latent_lora", False),
300
+ ):
301
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
302
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
303
+ if using_cond:
304
+ residual_cond = condition_latents
305
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
306
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
307
+
308
+ attn_output = attn_forward(
309
+ self.attn,
310
+ model_config=model_config,
311
+ hidden_states=norm_hidden_states,
312
+ image_rotary_emb=image_rotary_emb,
313
+ **(
314
+ {
315
+ "condition_latents": norm_condition_latents,
316
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
317
+ }
318
+ if using_cond
319
+ else {}
320
+ ),
321
+ )
322
+ if using_cond:
323
+ attn_output, cond_attn_output = attn_output
324
+
325
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
326
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
327
+ gate = gate.unsqueeze(1)
328
+ hidden_states = gate * self.proj_out(hidden_states)
329
+ hidden_states = residual + hidden_states
330
+ if using_cond:
331
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
332
+ cond_gate = cond_gate.unsqueeze(1)
333
+ condition_latents = cond_gate * self.proj_out(condition_latents)
334
+ condition_latents = residual_cond + condition_latents
335
+
336
+ if hidden_states.dtype == torch.float16:
337
+ hidden_states = hidden_states.clip(-65504, 65504)
338
+
339
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
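attn_forward above appends the condition tokens to the text/image token sequence and, when a c_factor is set, biases the attention between the two groups by log(c_factor) via an additive attn_mask. A shape-level sketch of that mechanism with toy tensors (all dimensions are made up; no diffusers dependency):

import torch
import torch.nn.functional as F

batch, heads, dim = 1, 2, 8
n_img, n_cond = 6, 3          # text/image tokens vs. condition tokens
c_factor = 1.5                # condition strength, as in module.c_factor

q = torch.randn(batch, heads, n_img + n_cond, dim)
k = torch.randn(batch, heads, n_img + n_cond, dim)
v = torch.randn(batch, heads, n_img + n_cond, dim)

# Additive bias: log(c_factor) on the image<->condition blocks, 0 elsewhere.
bias = torch.zeros(n_img + n_cond, n_img + n_cond)
bias[-n_cond:, :-n_cond] = torch.log(torch.tensor(c_factor))
bias[:-n_cond, -n_cond:] = torch.log(torch.tensor(c_factor))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 2, 9, 8])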
src/flux/condition.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+ from typing import Optional, Union, List, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from .pipeline_tools import encode_images
9
+
10
+ condition_dict = {
11
+ "depth": 0,
12
+ "canny": 1,
13
+ "subject": 4,
14
+ "coloring": 6,
15
+ "deblurring": 7,
16
+ "depth_pred": 8,
17
+ "fill": 9,
18
+ "sr": 10,
19
+ "cartoon": 11,
20
+ "scene": 12
21
+ }
22
+
23
+
24
+ class Condition(object):
25
+ def __init__(
26
+ self,
27
+ condition_type: str,
28
+ raw_img: Union[Image.Image, torch.Tensor] = None,
29
+ condition: Union[Image.Image, torch.Tensor] = None,
30
+ mask=None,
31
+ position_delta=None,
32
+ position_scale=1.0,
33
+ ) -> None:
34
+ self.condition_type = condition_type
35
+ assert raw_img is not None or condition is not None
36
+ if raw_img is not None:
37
+ self.condition = self.get_condition(condition_type, raw_img)
38
+ else:
39
+ self.condition = condition
40
+ self.position_delta = position_delta
41
+ self.position_scale = position_scale
42
+ # TODO: Add mask support
43
+ assert mask is None, "Mask not supported yet"
44
+
45
+ def get_condition(
46
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
47
+ ) -> Union[Image.Image, torch.Tensor]:
48
+ """
49
+ Returns the condition image.
50
+ """
51
+ if condition_type == "depth":
52
+ from transformers import pipeline
53
+
54
+ depth_pipe = pipeline(
55
+ task="depth-estimation",
56
+ model="LiheYoung/depth-anything-small-hf",
57
+ device="cuda",
58
+ )
59
+ source_image = raw_img.convert("RGB")
60
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
61
+ return condition_img
62
+ elif condition_type == "canny":
63
+ img = np.array(raw_img)
64
+ edges = cv2.Canny(img, 100, 200)
65
+ edges = Image.fromarray(edges).convert("RGB")
66
+ return edges
67
+ elif condition_type == "subject":
68
+ return raw_img
69
+ elif condition_type == "coloring":
70
+ return raw_img.convert("L").convert("RGB")
71
+ elif condition_type == "deblurring":
72
+ condition_image = (
73
+ raw_img.convert("RGB")
74
+ .filter(ImageFilter.GaussianBlur(10))
75
+ .convert("RGB")
76
+ )
77
+ return condition_image
78
+ elif condition_type == "fill":
79
+ return raw_img.convert("RGB")
80
+ elif condition_type == "cartoon":
81
+ return raw_img.convert("RGB")
82
+ elif condition_type == "scene":
83
+ return raw_img.convert("RGB")
84
+ return self.condition
85
+
86
+ @property
87
+ def type_id(self) -> int:
88
+ """
89
+ Returns the type id of the condition.
90
+ """
91
+ return condition_dict[self.condition_type]
92
+
93
+ @classmethod
94
+ def get_type_id(cls, condition_type: str) -> int:
95
+ """
96
+ Returns the type id of the condition.
97
+ """
98
+ return condition_dict[condition_type]
99
+
100
+ def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
101
+ """
102
+ Encodes the condition into tokens, ids and type_id.
103
+ """
104
+ if self.condition_type in [
105
+ "depth",
106
+ "canny",
107
+ "subject",
108
+ "coloring",
109
+ "deblurring",
110
+ "depth_pred",
111
+ "fill",
112
+ "sr",
113
+ "cartoon",
114
+ "scene"
115
+ ]:
116
+ tokens, ids = encode_images(pipe, self.condition)
117
+ else:
118
+ raise NotImplementedError(
119
+ f"Condition type {self.condition_type} not implemented"
120
+ )
121
+ if self.position_delta is None and self.condition_type == "subject":
122
+ self.position_delta = [0, -self.condition.size[0] // 16]
123
+ if self.position_delta is not None:
124
+ ids[:, 1] += self.position_delta[0]
125
+ ids[:, 2] += self.position_delta[1]
126
+ if self.position_scale != 1.0:
127
+ scale_bias = (self.position_scale - 1.0) / 2
128
+ ids[:, 1] *= self.position_scale
129
+ ids[:, 2] *= self.position_scale
130
+ ids[:, 1] += scale_bias
131
+ ids[:, 2] += scale_bias
132
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
133
+ return tokens, ids, type_id
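Condition.get_condition derives the conditioning image directly from the raw image; for the "canny" type this is just an OpenCV edge map. A minimal standalone sketch of that branch (the input path is a placeholder):

import numpy as np
import cv2
from PIL import Image

raw_img = Image.open("input.png").convert("RGB")   # placeholder path
edges = cv2.Canny(np.array(raw_img), 100, 200)     # same thresholds as above
condition_img = Image.fromarray(edges).convert("RGB")
condition_img.save("canny_condition.png")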
src/flux/generate.py ADDED
@@ -0,0 +1,322 @@
1
+ import torch
2
+ import yaml, os
3
+ from diffusers.pipelines import FluxPipeline
4
+ from typing import List, Union, Optional, Dict, Any, Callable
5
+ from .transformer import tranformer_forward
6
+ from .condition import Condition
7
+
8
+ from diffusers.pipelines.flux.pipeline_flux import (
9
+ FluxPipelineOutput,
10
+ calculate_shift,
11
+ retrieve_timesteps,
12
+ np,
13
+ )
14
+
15
+
16
+ def get_config(config_path: str = None):
17
+ config_path = config_path or os.environ.get("XFL_CONFIG")
18
+ if not config_path:
19
+ return {}
20
+ with open(config_path, "r") as f:
21
+ config = yaml.safe_load(f)
22
+ return config
23
+
24
+
25
+ def prepare_params(
26
+ prompt: Union[str, List[str]] = None,
27
+ prompt_2: Optional[Union[str, List[str]]] = None,
28
+ height: Optional[int] = 512,
29
+ width: Optional[int] = 512,
30
+ num_inference_steps: int = 28,
31
+ timesteps: List[int] = None,
32
+ guidance_scale: float = 3.5,
33
+ num_images_per_prompt: Optional[int] = 1,
34
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35
+ latents: Optional[torch.FloatTensor] = None,
36
+ prompt_embeds: Optional[torch.FloatTensor] = None,
37
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38
+ output_type: Optional[str] = "pil",
39
+ return_dict: bool = True,
40
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
41
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
42
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
43
+ max_sequence_length: int = 512,
44
+ **kwargs: dict,
45
+ ):
46
+ return (
47
+ prompt,
48
+ prompt_2,
49
+ height,
50
+ width,
51
+ num_inference_steps,
52
+ timesteps,
53
+ guidance_scale,
54
+ num_images_per_prompt,
55
+ generator,
56
+ latents,
57
+ prompt_embeds,
58
+ pooled_prompt_embeds,
59
+ output_type,
60
+ return_dict,
61
+ joint_attention_kwargs,
62
+ callback_on_step_end,
63
+ callback_on_step_end_tensor_inputs,
64
+ max_sequence_length,
65
+ )
66
+
67
+
68
+ def seed_everything(seed: int = 42):
69
+ torch.backends.cudnn.deterministic = True
70
+ torch.manual_seed(seed)
71
+ np.random.seed(seed)
72
+
73
+
74
+ @torch.no_grad()
75
+ def generate(
76
+ pipeline: FluxPipeline,
77
+ conditions: List[Condition] = None,
78
+ config_path: str = None,
79
+ model_config: Optional[Dict[str, Any]] = {},
80
+ condition_scale: float = 1.0,
81
+ default_lora: bool = False,
82
+ image_guidance_scale: float = 1.0,
83
+ **params: dict,
84
+ ):
85
+ model_config = model_config or get_config(config_path).get("model", {})
86
+ # print(model_config)
87
+ if condition_scale != 1:
88
+ for name, module in pipeline.transformer.named_modules():
89
+ if not name.endswith(".attn"):
90
+ continue
91
+ module.c_factor = torch.ones(1, 1) * condition_scale
92
+
93
+ self = pipeline
94
+ (
95
+ prompt,
96
+ prompt_2,
97
+ height,
98
+ width,
99
+ num_inference_steps,
100
+ timesteps,
101
+ guidance_scale,
102
+ num_images_per_prompt,
103
+ generator,
104
+ latents,
105
+ prompt_embeds,
106
+ pooled_prompt_embeds,
107
+ output_type,
108
+ return_dict,
109
+ joint_attention_kwargs,
110
+ callback_on_step_end,
111
+ callback_on_step_end_tensor_inputs,
112
+ max_sequence_length,
113
+ ) = prepare_params(**params)
114
+
115
+ height = height or self.default_sample_size * self.vae_scale_factor
116
+ width = width or self.default_sample_size * self.vae_scale_factor
117
+
118
+ # 1. Check inputs. Raise error if not correct
119
+ self.check_inputs(
120
+ prompt,
121
+ prompt_2,
122
+ height,
123
+ width,
124
+ prompt_embeds=prompt_embeds,
125
+ pooled_prompt_embeds=pooled_prompt_embeds,
126
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
127
+ max_sequence_length=max_sequence_length,
128
+ )
129
+
130
+ self._guidance_scale = guidance_scale
131
+ self._joint_attention_kwargs = joint_attention_kwargs
132
+ self._interrupt = False
133
+
134
+ # 2. Define call parameters
135
+ if prompt is not None and isinstance(prompt, str):
136
+ batch_size = 1
137
+ elif prompt is not None and isinstance(prompt, list):
138
+ batch_size = len(prompt)
139
+ else:
140
+ batch_size = prompt_embeds.shape[0]
141
+
142
+ device = self._execution_device
143
+
144
+ lora_scale = (
145
+ self.joint_attention_kwargs.get("scale", None)
146
+ if self.joint_attention_kwargs is not None
147
+ else None
148
+ )
149
+ (
150
+ prompt_embeds,
151
+ pooled_prompt_embeds,
152
+ text_ids,
153
+ ) = self.encode_prompt(
154
+ prompt=prompt,
155
+ prompt_2=prompt_2,
156
+ prompt_embeds=prompt_embeds,
157
+ pooled_prompt_embeds=pooled_prompt_embeds,
158
+ device=device,
159
+ num_images_per_prompt=num_images_per_prompt,
160
+ max_sequence_length=max_sequence_length,
161
+ lora_scale=lora_scale,
162
+ )
163
+
164
+ # 4. Prepare latent variables
165
+ num_channels_latents = self.transformer.config.in_channels // 4
166
+ latents, latent_image_ids = self.prepare_latents(
167
+ batch_size * num_images_per_prompt,
168
+ num_channels_latents,
169
+ height,
170
+ width,
171
+ prompt_embeds.dtype,
172
+ device,
173
+ generator,
174
+ latents,
175
+ )
176
+
177
+ # 4.1. Prepare conditions
178
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
179
+ use_condition = conditions is not None and len(conditions) > 0
180
+ if use_condition:
181
+ assert len(conditions) <= 1, "Only one condition is supported for now."
182
+ if not default_lora:
183
+ pipeline.set_adapters(conditions[0].condition_type)
184
+ for condition in conditions:
185
+ tokens, ids, type_id = condition.encode(self)
186
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
187
+ condition_ids.append(ids) # [token_n, id_dim(3)]
188
+ condition_type_ids.append(type_id) # [token_n, 1]
189
+ condition_latents = torch.cat(condition_latents, dim=1)
190
+ condition_ids = torch.cat(condition_ids, dim=0)
191
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
192
+
193
+ # 5. Prepare timesteps
194
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
195
+ image_seq_len = latents.shape[1]
196
+ mu = calculate_shift(
197
+ image_seq_len,
198
+ self.scheduler.config.base_image_seq_len,
199
+ self.scheduler.config.max_image_seq_len,
200
+ self.scheduler.config.base_shift,
201
+ self.scheduler.config.max_shift,
202
+ )
203
+ timesteps, num_inference_steps = retrieve_timesteps(
204
+ self.scheduler,
205
+ num_inference_steps,
206
+ device,
207
+ timesteps,
208
+ sigmas,
209
+ mu=mu,
210
+ )
211
+ num_warmup_steps = max(
212
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
213
+ )
214
+ self._num_timesteps = len(timesteps)
215
+
216
+ # 6. Denoising loop
217
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
218
+ for i, t in enumerate(timesteps):
219
+ if self.interrupt:
220
+ continue
221
+
222
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
223
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
224
+
225
+ # handle guidance
226
+ if self.transformer.config.guidance_embeds:
227
+ guidance = torch.tensor([guidance_scale], device=device)
228
+ guidance = guidance.expand(latents.shape[0])
229
+ else:
230
+ guidance = None
231
+ noise_pred = tranformer_forward(
232
+ self.transformer,
233
+ model_config=model_config,
234
+ # Inputs of the condition (new feature)
235
+ condition_latents=condition_latents if use_condition else None,
236
+ condition_ids=condition_ids if use_condition else None,
237
+ condition_type_ids=condition_type_ids if use_condition else None,
238
+ # Inputs to the original transformer
239
+ hidden_states=latents,
240
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
241
+ timestep=timestep / 1000,
242
+ guidance=guidance,
243
+ pooled_projections=pooled_prompt_embeds,
244
+ encoder_hidden_states=prompt_embeds,
245
+ txt_ids=text_ids,
246
+ img_ids=latent_image_ids,
247
+ joint_attention_kwargs=self.joint_attention_kwargs,
248
+ return_dict=False,
249
+ )[0]
250
+
251
+ if image_guidance_scale != 1.0:
252
+ uncondition_latents = condition.encode(self, empty=True)[0]
253
+ unc_pred = tranformer_forward(
254
+ self.transformer,
255
+ model_config=model_config,
256
+ # Inputs of the condition (new feature)
257
+ condition_latents=uncondition_latents if use_condition else None,
258
+ condition_ids=condition_ids if use_condition else None,
259
+ condition_type_ids=condition_type_ids if use_condition else None,
260
+ # Inputs to the original transformer
261
+ hidden_states=latents,
262
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
263
+ timestep=timestep / 1000,
264
+ guidance=torch.ones_like(guidance),
265
+ pooled_projections=pooled_prompt_embeds,
266
+ encoder_hidden_states=prompt_embeds,
267
+ txt_ids=text_ids,
268
+ img_ids=latent_image_ids,
269
+ joint_attention_kwargs=self.joint_attention_kwargs,
270
+ return_dict=False,
271
+ )[0]
272
+
273
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
274
+
275
+ # compute the previous noisy sample x_t -> x_t-1
276
+ latents_dtype = latents.dtype
277
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
278
+
279
+ if latents.dtype != latents_dtype:
280
+ if torch.backends.mps.is_available():
281
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
282
+ latents = latents.to(latents_dtype)
283
+
284
+ if callback_on_step_end is not None:
285
+ callback_kwargs = {}
286
+ for k in callback_on_step_end_tensor_inputs:
287
+ callback_kwargs[k] = locals()[k]
288
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
289
+
290
+ latents = callback_outputs.pop("latents", latents)
291
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
292
+
293
+ # call the callback, if provided
294
+ if i == len(timesteps) - 1 or (
295
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
296
+ ):
297
+ progress_bar.update()
298
+
299
+ if output_type == "latent":
300
+ image = latents
301
+
302
+ else:
303
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
304
+ latents = (
305
+ latents / self.vae.config.scaling_factor
306
+ ) + self.vae.config.shift_factor
307
+ image = self.vae.decode(latents, return_dict=False)[0]
308
+ image = self.image_processor.postprocess(image, output_type=output_type)
309
+
310
+ # Offload all models
311
+ self.maybe_free_model_hooks()
312
+
313
+ if condition_scale != 1:
314
+ for name, module in pipeline.transformer.named_modules():
315
+ if not name.endswith(".attn"):
316
+ continue
317
+ del module.c_factor
318
+
319
+ if not return_dict:
320
+ return (image,)
321
+
322
+ return FluxPipelineOutput(images=image)
src/flux/lora_controller.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from peft.tuners.tuners_utils import BaseTunerLayer
2
+ from typing import List, Any, Optional, Type
3
+ from .condition import condition_dict
4
+
5
+ class enable_lora:
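+ # Context manager: when `activated` is False, entering it zeroes the LoRA scaling
+ # of every active adapter registered in `condition_dict` and restores the saved
+ # scales on exit; when `activated` is True it is a no-op.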
6
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7
+ self.activated: bool = activated
8
+ if activated:
9
+ return
10
+ self.lora_modules: List[BaseTunerLayer] = [
11
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
12
+ ]
13
+ self.scales = [
14
+ {
15
+ active_adapter: lora_module.scaling[active_adapter]
16
+ for active_adapter in lora_module.active_adapters
17
+ }
18
+ for lora_module in self.lora_modules
19
+ ]
20
+
21
+ def __enter__(self) -> None:
22
+ if self.activated:
23
+ return
24
+
25
+ for lora_module in self.lora_modules:
26
+ if not isinstance(lora_module, BaseTunerLayer):
27
+ continue
28
+ for active_adapter in lora_module.active_adapters:
29
+ if active_adapter in condition_dict.keys():
30
+ lora_module.scaling[active_adapter] = 0.0
31
+
32
+ def __exit__(
33
+ self,
34
+ exc_type: Optional[Type[BaseException]],
35
+ exc_val: Optional[BaseException],
36
+ exc_tb: Optional[Any],
37
+ ) -> None:
38
+ if self.activated:
39
+ return
40
+ for i, lora_module in enumerate(self.lora_modules):
41
+ if not isinstance(lora_module, BaseTunerLayer):
42
+ continue
43
+ for active_adapter in lora_module.active_adapters:
44
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
45
+
46
+
47
+ class set_lora_scale:
48
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
49
+ self.lora_modules: List[BaseTunerLayer] = [
50
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
51
+ ]
52
+ self.scales = [
53
+ {
54
+ active_adapter: lora_module.scaling[active_adapter]
55
+ for active_adapter in lora_module.active_adapters
56
+ }
57
+ for lora_module in self.lora_modules
58
+ ]
59
+ self.scale = scale
60
+
61
+ def __enter__(self) -> None:
62
+ for lora_module in self.lora_modules:
63
+ if not isinstance(lora_module, BaseTunerLayer):
64
+ continue
65
+ lora_module.scale_layer(self.scale)
66
+
67
+ def __exit__(
68
+ self,
69
+ exc_type: Optional[Type[BaseException]],
70
+ exc_val: Optional[BaseException],
71
+ exc_tb: Optional[Any],
72
+ ) -> None:
73
+ for i, lora_module in enumerate(self.lora_modules):
74
+ if not isinstance(lora_module, BaseTunerLayer):
75
+ continue
76
+ for active_adapter in lora_module.active_adapters:
77
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
src/flux/pipeline_tools.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers.pipelines import FluxPipeline
2
+ from diffusers.utils import logging
3
+ from diffusers.pipelines.flux.pipeline_flux import logger
4
+ from torch import Tensor
5
+
6
+
7
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
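+ # Encode images with the Flux VAE, normalize the latents with the VAE shift/scaling
+ # factors, pack them into the transformer's token layout, and build the matching
+ # latent image ids (falling back to half-resolution ids if the token count and id
+ # count disagree).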
8
+ images = pipeline.image_processor.preprocess(images)
9
+ images = images.to(pipeline.device).to(pipeline.dtype)
10
+ images = pipeline.vae.encode(images).latent_dist.sample()
11
+ images = (
12
+ images - pipeline.vae.config.shift_factor
13
+ ) * pipeline.vae.config.scaling_factor
14
+ images_tokens = pipeline._pack_latents(images, *images.shape)
15
+ images_ids = pipeline._prepare_latent_image_ids(
16
+ images.shape[0],
17
+ images.shape[2],
18
+ images.shape[3],
19
+ pipeline.device,
20
+ pipeline.dtype,
21
+ )
22
+ if images_tokens.shape[1] != images_ids.shape[0]:
23
+ images_ids = pipeline._prepare_latent_image_ids(
24
+ images.shape[0],
25
+ images.shape[2] // 2,
26
+ images.shape[3] // 2,
27
+ pipeline.device,
28
+ pipeline.dtype,
29
+ )
30
+ return images_tokens, images_ids
31
+
32
+
33
+ def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
34
+ # Turn off warnings (CLIP overflow)
35
+ logger.setLevel(logging.ERROR)
36
+ (
37
+ prompt_embeds,
38
+ pooled_prompt_embeds,
39
+ text_ids,
40
+ ) = pipeline.encode_prompt(
41
+ prompt=prompts,
42
+ prompt_2=None,
43
+ prompt_embeds=None,
44
+ pooled_prompt_embeds=None,
45
+ device=pipeline.device,
46
+ num_images_per_prompt=1,
47
+ max_sequence_length=max_sequence_length,
48
+ lora_scale=None,
49
+ )
50
+ # Turn on warnings
51
+ logger.setLevel(logging.WARNING)
52
+ return prompt_embeds, pooled_prompt_embeds, text_ids
src/flux/transformer.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from accelerate.utils import is_torch_version
7
+ from diffusers.models.transformers.transformer_flux import (
8
+ FluxTransformer2DModel,
9
+ Transformer2DModelOutput,
10
+ USE_PEFT_BACKEND,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ c_t=0,
54
+ **params: dict,
55
+ ):
56
+ self = transformer
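+ # Mirrors FluxTransformer2DModel.forward with two additions: a condition branch
+ # (condition_latents get their own rotary embeddings from condition_ids and a
+ # fixed-timestep embedding cond_temb built at t = c_t) and LoRA toggling of the
+ # x_embedder via the enable_lora context manager.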
57
+ use_condition = condition_latents is not None
58
+
59
+ (
60
+ hidden_states,
61
+ encoder_hidden_states,
62
+ pooled_projections,
63
+ timestep,
64
+ img_ids,
65
+ txt_ids,
66
+ guidance,
67
+ joint_attention_kwargs,
68
+ controlnet_block_samples,
69
+ controlnet_single_block_samples,
70
+ return_dict,
71
+ ) = prepare_params(**params)
72
+
73
+ if joint_attention_kwargs is not None:
74
+ joint_attention_kwargs = joint_attention_kwargs.copy()
75
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76
+ else:
77
+ lora_scale = 1.0
78
+
79
+ if USE_PEFT_BACKEND:
80
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
81
+ scale_lora_layers(self, lora_scale)
82
+ else:
83
+ if (
84
+ joint_attention_kwargs is not None
85
+ and joint_attention_kwargs.get("scale", None) is not None
86
+ ):
87
+ logger.warning(
88
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89
+ )
90
+
91
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92
+ hidden_states = self.x_embedder(hidden_states)
93
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
94
+
95
+ timestep = timestep.to(hidden_states.dtype) * 1000
96
+
97
+ if guidance is not None:
98
+ guidance = guidance.to(hidden_states.dtype) * 1000
99
+ else:
100
+ guidance = None
101
+
102
+ temb = (
103
+ self.time_text_embed(timestep, pooled_projections)
104
+ if guidance is None
105
+ else self.time_text_embed(timestep, guidance, pooled_projections)
106
+ )
107
+
108
+ cond_temb = (
109
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
110
+ if guidance is None
111
+ else self.time_text_embed(
112
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
113
+ )
114
+ )
115
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
116
+
117
+ if txt_ids.ndim == 3:
118
+ logger.warning(
119
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
120
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
121
+ )
122
+ txt_ids = txt_ids[0]
123
+ if img_ids.ndim == 3:
124
+ logger.warning(
125
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
126
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
127
+ )
128
+ img_ids = img_ids[0]
129
+
130
+ ids = torch.cat((txt_ids, img_ids), dim=0)
131
+ image_rotary_emb = self.pos_embed(ids)
132
+ if use_condition:
133
+ # condition_ids[:, :1] = condition_type_ids
134
+ cond_rotary_emb = self.pos_embed(condition_ids)
135
+
136
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
137
+
138
+ for index_block, block in enumerate(self.transformer_blocks):
139
+ if self.training and self.gradient_checkpointing:
140
+ ckpt_kwargs: Dict[str, Any] = (
141
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
142
+ )
143
+ encoder_hidden_states, hidden_states, condition_latents = (
144
+ torch.utils.checkpoint.checkpoint(
145
+ block_forward,
146
+ self=block,
147
+ model_config=model_config,
148
+ hidden_states=hidden_states,
149
+ encoder_hidden_states=encoder_hidden_states,
150
+ condition_latents=condition_latents if use_condition else None,
151
+ temb=temb,
152
+ cond_temb=cond_temb if use_condition else None,
153
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
154
+ image_rotary_emb=image_rotary_emb,
155
+ **ckpt_kwargs,
156
+ )
157
+ )
158
+
159
+ else:
160
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
161
+ block,
162
+ model_config=model_config,
163
+ hidden_states=hidden_states,
164
+ encoder_hidden_states=encoder_hidden_states,
165
+ condition_latents=condition_latents if use_condition else None,
166
+ temb=temb,
167
+ cond_temb=cond_temb if use_condition else None,
168
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
169
+ image_rotary_emb=image_rotary_emb,
170
+ )
171
+
172
+ # controlnet residual
173
+ if controlnet_block_samples is not None:
174
+ interval_control = len(self.transformer_blocks) / len(
175
+ controlnet_block_samples
176
+ )
177
+ interval_control = int(np.ceil(interval_control))
178
+ hidden_states = (
179
+ hidden_states
180
+ + controlnet_block_samples[index_block // interval_control]
181
+ )
182
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
183
+
184
+ for index_block, block in enumerate(self.single_transformer_blocks):
185
+ if self.training and self.gradient_checkpointing:
186
+ ckpt_kwargs: Dict[str, Any] = (
187
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188
+ )
189
+ result = torch.utils.checkpoint.checkpoint(
190
+ single_block_forward,
191
+ self=block,
192
+ model_config=model_config,
193
+ hidden_states=hidden_states,
194
+ temb=temb,
195
+ image_rotary_emb=image_rotary_emb,
196
+ **(
197
+ {
198
+ "condition_latents": condition_latents,
199
+ "cond_temb": cond_temb,
200
+ "cond_rotary_emb": cond_rotary_emb,
201
+ }
202
+ if use_condition
203
+ else {}
204
+ ),
205
+ **ckpt_kwargs,
206
+ )
207
+
208
+ else:
209
+ result = single_block_forward(
210
+ block,
211
+ model_config=model_config,
212
+ hidden_states=hidden_states,
213
+ temb=temb,
214
+ image_rotary_emb=image_rotary_emb,
215
+ **(
216
+ {
217
+ "condition_latents": condition_latents,
218
+ "cond_temb": cond_temb,
219
+ "cond_rotary_emb": cond_rotary_emb,
220
+ }
221
+ if use_condition
222
+ else {}
223
+ ),
224
+ )
225
+ if use_condition:
226
+ hidden_states, condition_latents = result
227
+ else:
228
+ hidden_states = result
229
+
230
+ # controlnet residual
231
+ if controlnet_single_block_samples is not None:
232
+ interval_control = len(self.single_transformer_blocks) / len(
233
+ controlnet_single_block_samples
234
+ )
235
+ interval_control = int(np.ceil(interval_control))
236
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
237
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
238
+ + controlnet_single_block_samples[index_block // interval_control]
239
+ )
240
+
241
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
242
+
243
+ hidden_states = self.norm_out(hidden_states, temb)
244
+ output = self.proj_out(hidden_states)
245
+
246
+ if USE_PEFT_BACKEND:
247
+ # remove `lora_scale` from each PEFT layer
248
+ unscale_lora_layers(self, lora_scale)
249
+
250
+ if not return_dict:
251
+ return (output,)
252
+ return Transformer2DModelOutput(sample=output)
src/gradio/gradio_app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ from diffusers.pipelines import FluxPipeline
5
+ from diffusers import FluxTransformer2DModel
6
+ import numpy as np
7
+
8
+ from ..flux.condition import Condition
9
+ from ..flux.generate import seed_everything, generate
10
+
11
+ pipe = None
12
+ use_int8 = False
13
+
14
+
15
+ def get_gpu_memory():
16
+ return torch.cuda.get_device_properties(0).total_memory / 1024**3
17
+
18
+
19
+ def init_pipeline():
20
+ global pipe
21
+ if use_int8 or get_gpu_memory() < 33:
22
+ transformer_model = FluxTransformer2DModel.from_pretrained(
23
+ "sayakpaul/flux.1-schell-int8wo-improved",
24
+ torch_dtype=torch.bfloat16,
25
+ use_safetensors=False,
26
+ )
27
+ pipe = FluxPipeline.from_pretrained(
28
+ "black-forest-labs/FLUX.1-schnell",
29
+ transformer=transformer_model,
30
+ torch_dtype=torch.bfloat16,
31
+ )
32
+ else:
33
+ pipe = FluxPipeline.from_pretrained(
34
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
35
+ )
36
+ pipe = pipe.to("cuda")
37
+ pipe.load_lora_weights(
38
+ "Yuanshi/OminiControl",
39
+ weight_name="omini/subject_512.safetensors",
40
+ adapter_name="subject",
41
+ )
42
+
43
+ # Optional: Load additional LoRA weights
44
+ #pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")
45
+
46
+
47
+ def process_image_and_text(image, text):
48
+ # center crop image
49
+ w, h, min_size = image.size[0], image.size[1], min(image.size)
50
+ image = image.crop(
51
+ (
52
+ (w - min_size) // 2,
53
+ (h - min_size) // 2,
54
+ (w + min_size) // 2,
55
+ (h + min_size) // 2,
56
+ )
57
+ )
58
+ image = image.resize((512, 512))
59
+
60
+ condition = Condition("subject", image, position_delta=(0, 32))
61
+
62
+ if pipe is None:
63
+ init_pipeline()
64
+
65
+ result_img = generate(
66
+ pipe,
67
+ prompt=text.strip(),
68
+ conditions=[condition],
69
+ num_inference_steps=8,
70
+ height=512,
71
+ width=512,
72
+ ).images[0]
73
+
74
+ return result_img
75
+
76
+
77
+ def get_samples():
78
+ sample_list = [
79
+ {
80
+ "image": "assets/oranges.jpg",
81
+ "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
82
+ },
83
+ {
84
+ "image": "assets/penguin.jpg",
85
+ "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
86
+ },
87
+ {
88
+ "image": "assets/rc_car.jpg",
89
+ "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
90
+ },
91
+ {
92
+ "image": "assets/clock.jpg",
93
+ "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
94
+ },
95
+ {
96
+ "image": "assets/tshirt.jpg",
97
+ "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.",
98
+ },
99
+ ]
100
+ return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
101
+
102
+
103
+ demo = gr.Interface(
104
+ fn=process_image_and_text,
105
+ inputs=[
106
+ gr.Image(type="pil"),
107
+ gr.Textbox(lines=2),
108
+ ],
109
+ outputs=gr.Image(type="pil"),
110
+ title="OminiControl / Subject driven generation",
111
+ examples=get_samples(),
112
+ )
113
+
114
+ if __name__ == "__main__":
115
+ init_pipeline()
116
+ demo.launch(
117
+ debug=True,
118
+ )
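+
+ # Note: the relative imports above assume this file is run as a module from the
+ # repository root, e.g. `python -m src.gradio.gradio_app` (assuming `src` is an
+ # importable package).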
src/train/callbacks.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ from PIL import Image, ImageFilter, ImageDraw
3
+ import numpy as np
4
+ from transformers import pipeline
5
+ import cv2
6
+ import torch
7
+ import os
8
+
9
+ try:
10
+ import wandb
11
+ except ImportError:
12
+ wandb = None
13
+
14
+ from ..flux.condition import Condition
15
+ from ..flux.generate import generate
16
+
17
+
18
+ class TrainingCallback(L.Callback):
19
+ def __init__(self, run_name, training_config: dict = {}):
20
+ self.run_name, self.training_config = run_name, training_config
21
+
22
+ self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
23
+ self.save_interval = training_config.get("save_interval", 1000)
24
+ self.sample_interval = training_config.get("sample_interval", 1000)
25
+ self.save_path = training_config.get("save_path", "./output")
26
+
27
+ self.wandb_config = training_config.get("wandb", None)
28
+ self.use_wandb = (
29
+ wandb is not None and os.environ.get("WANDB_API_KEY") is not None
30
+ )
31
+
32
+ self.total_steps = 0
33
+
34
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
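+ # Per-step bookkeeping: track the mean and max gradient L2 norms, log to WandB when
+ # enabled, print a smoothed loss every `print_every_n_steps`, save LoRA weights every
+ # `save_interval` steps, and render a validation sample every `sample_interval` steps.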
35
+ gradient_size = 0
36
+ max_gradient_size = 0
37
+ count = 0
38
+ for _, param in pl_module.named_parameters():
39
+ if param.grad is not None:
40
+ gradient_size += param.grad.norm(2).item()
41
+ max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
42
+ count += 1
43
+ if count > 0:
44
+ gradient_size /= count
45
+
46
+ self.total_steps += 1
47
+
48
+ # Print training progress every n steps
49
+ if self.use_wandb:
50
+ report_dict = {
51
+ "steps": batch_idx,
52
+ "steps": self.total_steps,
53
+ "epoch": trainer.current_epoch,
54
+ "gradient_size": gradient_size,
55
+ }
56
+ loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
57
+ report_dict["loss"] = loss_value
58
+ report_dict["t"] = pl_module.last_t
59
+ wandb.log(report_dict)
60
+
61
+ if self.total_steps % self.print_every_n_steps == 0:
62
+ print(
63
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
64
+ )
65
+
66
+ # Save LoRA weights at specified intervals
67
+ if self.total_steps % self.save_interval == 0:
68
+ print(
69
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
70
+ )
71
+ pl_module.save_lora(
72
+ f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
73
+ )
74
+
75
+ # Generate and save a sample image at specified intervals
76
+ if self.total_steps % self.sample_interval == 0:
77
+ print(
78
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
79
+ )
80
+ self.generate_a_sample(
81
+ trainer,
82
+ pl_module,
83
+ f"{self.save_path}/{self.run_name}/output",
84
+ f"lora_{self.total_steps}",
85
+ batch["condition_type"][
86
+ 0
87
+ ], # Use the condition type from the current batch
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate_a_sample(
92
+ self,
93
+ trainer,
94
+ pl_module,
95
+ save_path,
96
+ file_name,
97
+ condition_type="super_resolution",
98
+ ):
99
+ # TODO: change this two variables to parameters
100
+ condition_size = trainer.training_config["dataset"]["condition_size"]
101
+ target_size = trainer.training_config["dataset"]["target_size"]
102
+ position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)
103
+
104
+ generator = torch.Generator(device=pl_module.device)
105
+ generator.manual_seed(42)
106
+
107
+ test_list = []
108
+
109
+ if condition_type == "subject":
110
+ test_list.extend(
111
+ [
112
+ (
113
+ Image.open("assets/test_in.jpg"),
114
+ [0, -32],
115
+ "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
116
+ ),
117
+ (
118
+ Image.open("assets/test_out.jpg"),
119
+ [0, -32],
120
+ "In a bright room. It is placed on a table.",
121
+ ),
122
+ ]
123
+ )
124
+ elif condition_type == "scene":
125
+ test_list.extend(
126
+ [
127
+ (
128
+ Image.open("assets/a2759.jpg"),
129
+ [0, -32],
130
+ "change the color of the plane to red",
131
+ ),
132
+ (
133
+ Image.open("assets/clock.jpg"),
134
+ [0, -32],
135
+ "turn the color of the clock to blue",
136
+ ),
137
+ ]
138
+ )
139
+ elif condition_type == "canny":
140
+ condition_img = Image.open("assets/vase_hq.jpg").resize(
141
+ (condition_size, condition_size)
142
+ )
143
+ condition_img = np.array(condition_img)
144
+ condition_img = cv2.Canny(condition_img, 100, 200)
145
+ condition_img = Image.fromarray(condition_img).convert("RGB")
146
+ test_list.append(
147
+ (
148
+ condition_img,
149
+ [0, 0],
150
+ "A beautiful vase on a table.",
151
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
152
+ )
153
+ )
154
+ elif condition_type == "coloring":
155
+ condition_img = (
156
+ Image.open("assets/vase_hq.jpg")
157
+ .resize((condition_size, condition_size))
158
+ .convert("L")
159
+ .convert("RGB")
160
+ )
161
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
162
+ elif condition_type == "depth":
163
+ if not hasattr(self, "deepth_pipe"):
164
+ self.deepth_pipe = pipeline(
165
+ task="depth-estimation",
166
+ model="LiheYoung/depth-anything-small-hf",
167
+ device="cpu",
168
+ )
169
+ condition_img = (
170
+ Image.open("assets/vase_hq.jpg")
171
+ .resize((condition_size, condition_size))
172
+ .convert("RGB")
173
+ )
174
+ condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB")
175
+ test_list.append(
176
+ (
177
+ condition_img,
178
+ [0, 0],
179
+ "A beautiful vase on a table.",
180
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
181
+ )
182
+ )
183
+ elif condition_type == "depth_pred":
184
+ condition_img = (
185
+ Image.open("assets/vase_hq.jpg")
186
+ .resize((condition_size, condition_size))
187
+ .convert("RGB")
188
+ )
189
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
190
+ elif condition_type == "deblurring":
191
+ blur_radius = 5
192
+ image = Image.open("./assets/vase_hq.jpg")
193
+ condition_img = (
194
+ image.convert("RGB")
195
+ .resize((condition_size, condition_size))
196
+ .filter(ImageFilter.GaussianBlur(blur_radius))
197
+ .convert("RGB")
198
+ )
199
+ test_list.append(
200
+ (
201
+ condition_img,
202
+ [0, 0],
203
+ "A beautiful vase on a table.",
204
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
205
+ )
206
+ )
207
+ elif condition_type == "fill":
208
+ condition_img = (
209
+ Image.open("./assets/vase_hq.jpg")
210
+ .resize((condition_size, condition_size))
211
+ .convert("RGB")
212
+ )
213
+ mask = Image.new("L", condition_img.size, 0)
214
+ draw = ImageDraw.Draw(mask)
215
+ a = condition_img.size[0] // 4
216
+ b = a * 3
217
+ draw.rectangle([a, a, b, b], fill=255)
218
+ condition_img = Image.composite(
219
+ condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
220
+ )
221
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
222
+ elif condition_type == "sr":
223
+ condition_img = (
224
+ Image.open("assets/vase_hq.jpg")
225
+ .resize((condition_size, condition_size))
226
+ .convert("RGB")
227
+ )
228
+ test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
229
+ elif condition_type == "cartoon":
230
+ condition_img = (
231
+ Image.open("assets/cartoon_boy.png")
232
+ .resize((condition_size, condition_size))
233
+ .convert("RGB")
234
+ )
235
+ test_list.append(
236
+ (
237
+ condition_img,
238
+ [0, -16],
239
+ "A cartoon character in a white background. He is looking right, and running.",
240
+ )
241
+ )
242
+ else:
243
+ raise NotImplementedError
244
+
245
+ if not os.path.exists(save_path):
246
+ os.makedirs(save_path)
247
+ for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
248
+ condition = Condition(
249
+ condition_type=condition_type,
250
+ condition=condition_img.resize(
251
+ (condition_size, condition_size)
252
+ ).convert("RGB"),
253
+ position_delta=position_delta,
254
+ **(others[0] if others else {}),
255
+ )
256
+ res = generate(
257
+ pl_module.flux_pipe,
258
+ prompt=prompt,
259
+ conditions=[condition],
260
+ height=target_size,
261
+ width=target_size,
262
+ generator=generator,
263
+ model_config=pl_module.model_config,
264
+ default_lora=True,
265
+ )
266
+ res.images[0].save(
267
+ os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
268
+ )
src/train/data.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFilter, ImageDraw
2
+ import cv2
3
+ import numpy as np
4
+ from torch.utils.data import Dataset
5
+ import torchvision.transforms as T
6
+ import random
7
+
8
+
9
+ class Subject200KDataset(Dataset):
10
+ def __init__(
11
+ self,
12
+ base_dataset,
13
+ condition_size: int = 512,
14
+ target_size: int = 512,
15
+ image_size: int = 512,
16
+ padding: int = 0,
17
+ condition_type: str = "subject",
18
+ drop_text_prob: float = 0.1,
19
+ drop_image_prob: float = 0.1,
20
+ return_pil_image: bool = False,
21
+ ):
22
+ self.base_dataset = base_dataset
23
+ self.condition_size = condition_size
24
+ self.target_size = target_size
25
+ self.image_size = image_size
26
+ self.padding = padding
27
+ self.condition_type = condition_type
28
+ self.drop_text_prob = drop_text_prob
29
+ self.drop_image_prob = drop_image_prob
30
+ self.return_pil_image = return_pil_image
31
+
32
+ self.to_tensor = T.ToTensor()
33
+
34
+ def __len__(self):
35
+ return len(self.base_dataset) * 2
36
+
37
+ def __getitem__(self, idx):
38
+ # If target is 0, left image is target, right image is condition
39
+ target = idx % 2
40
+ item = self.base_dataset[idx // 2]
41
+
42
+ # Crop the image to target and condition
43
+ image = item["image"]
44
+ left_img = image.crop(
45
+ (
46
+ self.padding,
47
+ self.padding,
48
+ self.image_size + self.padding,
49
+ self.image_size + self.padding,
50
+ )
51
+ )
52
+ right_img = image.crop(
53
+ (
54
+ self.image_size + self.padding * 2,
55
+ self.padding,
56
+ self.image_size * 2 + self.padding * 2,
57
+ self.image_size + self.padding,
58
+ )
59
+ )
60
+
61
+ # Get the target and condition image
62
+ target_image, condition_img = (
63
+ (left_img, right_img) if target == 0 else (right_img, left_img)
64
+ )
65
+
66
+ # Resize the image
67
+ condition_img = condition_img.resize(
68
+ (self.condition_size, self.condition_size)
69
+ ).convert("RGB")
70
+ target_image = target_image.resize(
71
+ (self.target_size, self.target_size)
72
+ ).convert("RGB")
73
+
74
+ # Get the description
75
+ description = item["description"][
76
+ "description_0" if target == 0 else "description_1"
77
+ ]
78
+
79
+ # Randomly drop text or image
80
+ drop_text = random.random() < self.drop_text_prob
81
+ drop_image = random.random() < self.drop_image_prob
82
+ if drop_text:
83
+ description = ""
84
+ if drop_image:
85
+ condition_img = Image.new(
86
+ "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
87
+ )
88
+
89
+ return {
90
+ "image": self.to_tensor(target_image),
91
+ "condition": self.to_tensor(condition_img),
92
+ "condition_type": self.condition_type,
93
+ "description": description,
94
+ # 16 is the downscale factor of the image
95
+ "position_delta": np.array([0, -self.condition_size // 16]),
96
+ **({"pil_image": image} if self.return_pil_image else {}),
97
+ }
98
+
99
+ class SceneDataset(Dataset):
100
+ def __init__(
101
+ self,
102
+ base_dataset,
103
+ condition_size: int = 512,
104
+ target_size: int = 512,
105
+ image_size: int = 512,
106
+ padding: int = 0,
107
+ condition_type: str = "scene",
108
+ drop_text_prob: float = 0.1,
109
+ drop_image_prob: float = 0.1,
110
+ return_pil_image: bool = False,
111
+ ):
112
+ self.base_dataset = base_dataset
113
+ self.condition_size = condition_size
114
+ self.target_size = target_size
115
+ self.image_size = image_size
116
+ self.padding = padding
117
+ self.condition_type = condition_type
118
+ self.drop_text_prob = drop_text_prob
119
+ self.drop_image_prob = drop_image_prob
120
+ self.return_pil_image = return_pil_image
121
+
122
+ self.to_tensor = T.ToTensor()
123
+
124
+ def __len__(self):
125
+ return len(self.base_dataset)
126
+
127
+ def __getitem__(self, idx):
128
+ # If target is 0, left image is target, right image is condition
129
+ # target = idx % 2
130
+ target = 1
131
+ item = self.base_dataset[idx]  # __len__ is len(base_dataset), so index directly (no left/right pairing as in Subject200K)
132
+
133
+ # Crop the image to target and condition
134
+ imageA = item["imageA"]
135
+ imageB = item["imageB"]
136
+
137
+ left_img = imageA
138
+ right_img = imageB
139
+
140
+ # Get the target and condition image
141
+ target_image, condition_img = (
142
+ (left_img, right_img) if target == 0 else (right_img, left_img)
143
+ )
144
+
145
+ # Resize the image
146
+ condition_img = condition_img.resize(
147
+ (self.condition_size, self.condition_size)
148
+ ).convert("RGB")
149
+ target_image = target_image.resize(
150
+ (self.target_size, self.target_size)
151
+ ).convert("RGB")
152
+
153
+ # Get the description
154
+ description = item["prompt"]
155
+
156
+ # Randomly drop text or image
157
+ drop_text = random.random() < self.drop_text_prob
158
+ drop_image = random.random() < self.drop_image_prob
159
+ if drop_text:
160
+ description = ""
161
+ if drop_image:
162
+ condition_img = Image.new(
163
+ "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
164
+ )
165
+
166
+ return {
167
+ "image": self.to_tensor(target_image),
168
+ "condition": self.to_tensor(condition_img),
169
+ "condition_type": self.condition_type,
170
+ "description": description,
171
+ "position_delta": np.array([0, -self.condition_size // 16]),
172
+ **({"pil_image": [target_image, condition_img]} if self.return_pil_image else {}),
173
+ }
174
+
175
+
176
+
177
+
178
+ class ImageConditionDataset(Dataset):
179
+ def __init__(
180
+ self,
181
+ base_dataset,
182
+ condition_size: int = 512,
183
+ target_size: int = 512,
184
+ condition_type: str = "canny",
185
+ drop_text_prob: float = 0.1,
186
+ drop_image_prob: float = 0.1,
187
+ return_pil_image: bool = False,
188
+ position_scale=1.0,
189
+ ):
190
+ self.base_dataset = base_dataset
191
+ self.condition_size = condition_size
192
+ self.target_size = target_size
193
+ self.condition_type = condition_type
194
+ self.drop_text_prob = drop_text_prob
195
+ self.drop_image_prob = drop_image_prob
196
+ self.return_pil_image = return_pil_image
197
+ self.position_scale = position_scale
198
+
199
+ self.to_tensor = T.ToTensor()
200
+
201
+ def __len__(self):
202
+ return len(self.base_dataset)
203
+
204
+ @property
205
+ def depth_pipe(self):
206
+ if not hasattr(self, "_depth_pipe"):
207
+ from transformers import pipeline
208
+
209
+ self._depth_pipe = pipeline(
210
+ task="depth-estimation",
211
+ model="LiheYoung/depth-anything-small-hf",
212
+ device="cpu",
213
+ )
214
+ return self._depth_pipe
215
+
216
+ def _get_canny_edge(self, img):
217
+ resize_ratio = self.condition_size / max(img.size)
218
+ img = img.resize(
219
+ (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
220
+ )
221
+ img_np = np.array(img)
222
+ img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
223
+ edges = cv2.Canny(img_gray, 100, 200)
224
+ return Image.fromarray(edges).convert("RGB")
225
+
226
+ def __getitem__(self, idx):
227
+ image = self.base_dataset[idx]["jpg"]
228
+ image = image.resize((self.target_size, self.target_size)).convert("RGB")
229
+ description = self.base_dataset[idx]["json"]["prompt"]
230
+
231
+ enable_scale = random.random() < 1
232
+ if not enable_scale:
233
+ condition_size = int(self.condition_size * self.position_scale)
234
+ position_scale = 1.0
235
+ else:
236
+ condition_size = self.condition_size
237
+ position_scale = self.position_scale
238
+
239
+ # Get the condition image
240
+ position_delta = np.array([0, 0])
241
+ if self.condition_type == "canny":
242
+ condition_img = self._get_canny_edge(image)
243
+ elif self.condition_type == "coloring":
244
+ condition_img = (
245
+ image.resize((condition_size, condition_size))
246
+ .convert("L")
247
+ .convert("RGB")
248
+ )
249
+ elif self.condition_type == "deblurring":
250
+ blur_radius = random.randint(1, 10)
251
+ condition_img = (
252
+ image.convert("RGB")
253
+ .filter(ImageFilter.GaussianBlur(blur_radius))
254
+ .resize((condition_size, condition_size))
255
+ .convert("RGB")
256
+ )
257
+ elif self.condition_type == "depth":
258
+ condition_img = self.depth_pipe(image)["depth"].convert("RGB")
259
+ condition_img = condition_img.resize((condition_size, condition_size))
260
+ elif self.condition_type == "depth_pred":
261
+ condition_img = image
262
+ image = self.depth_pipe(condition_img)["depth"].convert("RGB")
263
+ description = f"[depth] {description}"
264
+ elif self.condition_type == "fill":
265
+ condition_img = image.resize((condition_size, condition_size)).convert(
266
+ "RGB"
267
+ )
268
+ w, h = image.size
269
+ x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
270
+ y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
271
+ mask = Image.new("L", image.size, 0)
272
+ draw = ImageDraw.Draw(mask)
273
+ draw.rectangle([x1, y1, x2, y2], fill=255)
274
+ if random.random() > 0.5:
275
+ mask = Image.eval(mask, lambda a: 255 - a)
276
+ condition_img = Image.composite(
277
+ image, Image.new("RGB", image.size, (0, 0, 0)), mask
278
+ )
279
+ elif self.condition_type == "sr":
280
+ condition_img = image.resize((condition_size, condition_size)).convert(
281
+ "RGB"
282
+ )
283
+ position_delta = np.array([0, -condition_size // 16])
284
+
285
+ else:
286
+ raise ValueError(f"Condition type {self.condition_type} not implemented")
287
+
288
+ # Randomly drop text or image
289
+ drop_text = random.random() < self.drop_text_prob
290
+ drop_image = random.random() < self.drop_image_prob
291
+ if drop_text:
292
+ description = ""
293
+ if drop_image:
294
+ condition_img = Image.new(
295
+ "RGB", (condition_size, condition_size), (0, 0, 0)
296
+ )
297
+
298
+ return {
299
+ "image": self.to_tensor(image),
300
+ "condition": self.to_tensor(condition_img),
301
+ "condition_type": self.condition_type,
302
+ "description": description,
303
+ "position_delta": position_delta,
304
+ **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
305
+ **({"position_scale": position_scale} if position_scale != 1.0 else {}),
306
+ }
307
+
308
+
309
+ class CartoonDataset(Dataset):
310
+ def __init__(
311
+ self,
312
+ base_dataset,
313
+ condition_size: int = 1024,
314
+ target_size: int = 1024,
315
+ image_size: int = 1024,
316
+ padding: int = 0,
317
+ condition_type: str = "cartoon",
318
+ drop_text_prob: float = 0.1,
319
+ drop_image_prob: float = 0.1,
320
+ return_pil_image: bool = False,
321
+ ):
322
+ self.base_dataset = base_dataset
323
+ self.condition_size = condition_size
324
+ self.target_size = target_size
325
+ self.image_size = image_size
326
+ self.padding = padding
327
+ self.condition_type = condition_type
328
+ self.drop_text_prob = drop_text_prob
329
+ self.drop_image_prob = drop_image_prob
330
+ self.return_pil_image = return_pil_image
331
+
332
+ self.to_tensor = T.ToTensor()
333
+
334
+ def __len__(self):
335
+ return len(self.base_dataset)
336
+
337
+ def __getitem__(self, idx):
338
+ data = self.base_dataset[idx]
339
+ condition_img = data["condition"]
340
+ target_image = data["target"]
341
+
342
+ # Tag
343
+ tag = data["tags"][0]
344
+
345
+ target_description = data["target_description"]
346
+
347
+ description = {
348
+ "lion": "lion like animal",
349
+ "bear": "bear like animal",
350
+ "gorilla": "gorilla like animal",
351
+ "dog": "dog like animal",
352
+ "elephant": "elephant like animal",
353
+ "eagle": "eagle like bird",
354
+ "tiger": "tiger like animal",
355
+ "owl": "owl like bird",
356
+ "woman": "woman",
357
+ "parrot": "parrot like bird",
358
+ "mouse": "mouse like animal",
359
+ "man": "man",
360
+ "pigeon": "pigeon like bird",
361
+ "girl": "girl",
362
+ "panda": "panda like animal",
363
+ "crocodile": "crocodile like animal",
364
+ "rabbit": "rabbit like animal",
365
+ "boy": "boy",
366
+ "monkey": "monkey like animal",
367
+ "cat": "cat like animal",
368
+ }
369
+
370
+ # Resize the image
371
+ condition_img = condition_img.resize(
372
+ (self.condition_size, self.condition_size)
373
+ ).convert("RGB")
374
+ target_image = target_image.resize(
375
+ (self.target_size, self.target_size)
376
+ ).convert("RGB")
377
+
378
+ # Process datum to create description
379
+ description = data.get(
380
+ "description",
381
+ f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.",
382
+ )
383
+
384
+ # Randomly drop text or image
385
+ drop_text = random.random() < self.drop_text_prob
386
+ drop_image = random.random() < self.drop_image_prob
387
+ if drop_text:
388
+ description = ""
389
+ if drop_image:
390
+ condition_img = Image.new(
391
+ "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
392
+ )
393
+
394
+ return {
395
+ "image": self.to_tensor(target_image),
396
+ "condition": self.to_tensor(condition_img),
397
+ "condition_type": self.condition_type,
398
+ "description": description,
399
+ # 16 is the downscale factor of the image
400
+ "position_delta": np.array([0, -16]),
401
+ }
src/train/model.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ from diffusers.pipelines import FluxPipeline
3
+ import torch
4
+ from peft import LoraConfig, get_peft_model_state_dict
5
+
6
+ import prodigyopt
7
+
8
+ from ..flux.transformer import tranformer_forward
9
+ from ..flux.condition import Condition
10
+ from ..flux.pipeline_tools import encode_images, prepare_text_input
11
+
12
+
13
+ class OminiModel(L.LightningModule):
14
+ def __init__(
15
+ self,
16
+ flux_pipe_id: str,
17
+ lora_path: str = None,
18
+ lora_config: dict = None,
19
+ device: str = "cuda",
20
+ dtype: torch.dtype = torch.bfloat16,
21
+ model_config: dict = {},
22
+ optimizer_config: dict = None,
23
+ gradient_checkpointing: bool = False,
24
+ ):
25
+ # Initialize the LightningModule
26
+ super().__init__()
27
+ self.model_config = model_config
28
+ self.optimizer_config = optimizer_config
29
+
30
+ # Load the Flux pipeline
31
+ self.flux_pipe: FluxPipeline = (
32
+ FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
33
+ )
34
+ self.transformer = self.flux_pipe.transformer
35
+ self.transformer.gradient_checkpointing = gradient_checkpointing
36
+ self.transformer.train()
37
+
38
+ # Freeze the Flux pipeline
39
+ self.flux_pipe.text_encoder.requires_grad_(False).eval()
40
+ self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
41
+ self.flux_pipe.vae.requires_grad_(False).eval()
42
+
43
+ # Initialize LoRA layers
44
+ self.lora_layers = self.init_lora(lora_path, lora_config)
45
+
46
+ self.to(device).to(dtype)
47
+
48
+ def init_lora(self, lora_path: str, lora_config: dict):
49
+ assert lora_path or lora_config
50
+ if lora_path:
51
+ # TODO: Implement this
52
+ raise NotImplementedError
53
+ else:
54
+ self.transformer.add_adapter(LoraConfig(**lora_config))
55
+ # TODO: Check if this is correct (p.requires_grad)
56
+ lora_layers = filter(
57
+ lambda p: p.requires_grad, self.transformer.parameters()
58
+ )
59
+ return list(lora_layers)
60
+
61
+ def save_lora(self, path: str):
62
+ FluxPipeline.save_lora_weights(
63
+ save_directory=path,
64
+ transformer_lora_layers=get_peft_model_state_dict(self.transformer),
65
+ safe_serialization=True,
66
+ )
67
+
68
+ def configure_optimizers(self):
69
+ # Freeze the transformer
70
+ self.transformer.requires_grad_(False)
71
+ opt_config = self.optimizer_config
72
+
73
+ # Set the trainable parameters
74
+ self.trainable_params = self.lora_layers
75
+
76
+ # Unfreeze trainable parameters
77
+ for p in self.trainable_params:
78
+ p.requires_grad_(True)
79
+
80
+ # Initialize the optimizer
81
+ if opt_config["type"] == "AdamW":
82
+ optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
83
+ elif opt_config["type"] == "Prodigy":
84
+ optimizer = prodigyopt.Prodigy(
85
+ self.trainable_params,
86
+ **opt_config["params"],
87
+ )
88
+ elif opt_config["type"] == "SGD":
89
+ optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
90
+ else:
91
+ raise NotImplementedError
92
+
93
+ return optimizer
94
+
95
+ def training_step(self, batch, batch_idx):
96
+ step_loss = self.step(batch)
97
+ self.log_loss = (
98
+ step_loss.item()
99
+ if not hasattr(self, "log_loss")
100
+ else self.log_loss * 0.95 + step_loss.item() * 0.05
101
+ )
102
+ return step_loss
103
+
104
+ def step(self, batch):
105
+ imgs = batch["image"]
106
+ conditions = batch["condition"]
107
+ condition_types = batch["condition_type"]
108
+ prompts = batch["description"]
109
+ position_delta = batch["position_delta"][0]
110
+ position_scale = float(batch.get("position_scale", [1.0])[0])
111
+
112
+ # Prepare inputs
113
+ with torch.no_grad():
114
+ # Prepare image input
115
+ x_0, img_ids = encode_images(self.flux_pipe, imgs)
116
+
117
+ # Prepare text input
118
+ prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
119
+ self.flux_pipe, prompts
120
+ )
121
+
122
+ # Prepare t and x_t
123
+ t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
124
+ x_1 = torch.randn_like(x_0).to(self.device)
125
+ t_ = t.unsqueeze(1).unsqueeze(1)
126
+ x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
127
+
128
+ # Prepare conditions
129
+ condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
130
+
131
+ # Add position delta
132
+ condition_ids[:, 1] += position_delta[0]
133
+ condition_ids[:, 2] += position_delta[1]
134
+
135
+ if position_scale != 1.0:
136
+ scale_bias = (position_scale - 1.0) / 2
137
+ condition_ids[:, 1] *= position_scale
138
+ condition_ids[:, 2] *= position_scale
139
+ condition_ids[:, 1] += scale_bias
140
+ condition_ids[:, 2] += scale_bias
141
+
142
+ # Prepare condition type
143
+ condition_type_ids = torch.tensor(
144
+ [
145
+ Condition.get_type_id(condition_type)
146
+ for condition_type in condition_types
147
+ ]
148
+ ).to(self.device)
149
+ condition_type_ids = (
150
+ torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
151
+ ).unsqueeze(1)
152
+
153
+ # Prepare guidance
154
+ guidance = (
155
+ torch.ones_like(t).to(self.device)
156
+ if self.transformer.config.guidance_embeds
157
+ else None
158
+ )
159
+
160
+ # Forward pass
161
+ transformer_out = tranformer_forward(
162
+ self.transformer,
163
+ # Model config
164
+ model_config=self.model_config,
165
+ # Inputs of the condition (new feature)
166
+ condition_latents=condition_latents,
167
+ condition_ids=condition_ids,
168
+ condition_type_ids=condition_type_ids,
169
+ # Inputs to the original transformer
170
+ hidden_states=x_t,
171
+ timestep=t,
172
+ guidance=guidance,
173
+ pooled_projections=pooled_prompt_embeds,
174
+ encoder_hidden_states=prompt_embeds,
175
+ txt_ids=text_ids,
176
+ img_ids=img_ids,
177
+ joint_attention_kwargs=None,
178
+ return_dict=False,
179
+ )
180
+ pred = transformer_out[0]
181
+
182
+ # Compute loss
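+ # Flow-matching objective: x_t = (1 - t) * x_0 + t * x_1 interpolates the image
+ # latents x_0 toward noise x_1, so the regression target is the constant velocity
+ # (x_1 - x_0) along that straight path.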
183
+ loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
184
+ self.last_t = t.mean().item()
185
+ return loss
src/train/train.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import torch
3
+ import lightning as L
4
+ import yaml
5
+ import os
6
+ import time
7
+ import re
8
+
9
+ from datasets import load_dataset
10
+
11
+ from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset, SceneDataset
12
+ from .model import OminiModel
13
+ from .callbacks import TrainingCallback
14
+ import safetensors.torch
15
+ from peft import PeftModel
16
+
17
+ import os
18
+ from PIL import Image
19
+ import pandas as pd
20
+ from torch.utils.data import Dataset
21
+
22
+ from torchvision import transforms
23
+ from torch.utils.data import DataLoader
24
+
25
+ class LocalSubjectsDataset(Dataset):
26
+ def __init__(self, csv_file, image_dir, transform=None):
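+ # The CSV is expected to provide `imageA`, `prompt`, and `imageB` columns, where the
+ # image columns hold file names resolved relative to `image_dir`.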
27
+ self.data = pd.read_csv(csv_file)
28
+ self.image_dir = image_dir
29
+ self.transform = transform
30
+ self.features = {
31
+ 'imageA': 'PIL.Image',
32
+ 'prompt': 'str',
33
+ 'imageB': 'PIL.Image'
34
+ }
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, idx):
40
+ # Get the file names for image A and image B, and the prompt text
41
+ imgA_value = self.data.iloc[idx]['imageA']
42
+ if isinstance(imgA_value, pd.Series):
43
+ imgA_value = imgA_value.values[0]
44
+ imgA_name = os.path.join(self.image_dir, str(imgA_value))
45
+
46
+ prompt = self.data.iloc[idx]['prompt']
47
+ imgB_value = self.data.iloc[idx]['imageB']
48
+ if isinstance(imgB_value, pd.Series):
49
+ imgB_value = imgB_value.values[0]
50
+ imgB_name = os.path.join(self.image_dir, str(imgB_value))
51
+
52
+ imageA = Image.open(imgA_name).convert("RGB")
53
+ imageB = Image.open(imgB_name).convert("RGB")
54
+
55
+ if self.transform:
56
+ imageA = self.transform(imageA)
57
+ imageB = self.transform(imageB)
58
+
59
+ sample = {'imageA': imageA, 'prompt': prompt, 'imageB': imageB}
60
+ return sample
61
+
62
+ transform = transforms.Compose([
63
+ transforms.Resize((600, 600)),
64
+ # transforms.ToTensor(),
65
+ ])
66
+
67
+
68
+ def get_rank():
69
+ try:
70
+ rank = int(os.environ.get("LOCAL_RANK"))
71
+ except:
72
+ rank = 0
73
+ return rank
74
+
75
+
76
+ def get_config():
77
+ config_path = os.environ.get("XFL_CONFIG")
78
+ assert config_path is not None, "Please set the XFL_CONFIG environment variable"
79
+ with open(config_path, "r") as f:
80
+ config = yaml.safe_load(f)
81
+ return config
82
+
83
+
84
+ def init_wandb(wandb_config, run_name):
85
+ import wandb
86
+ wandb.init(
87
+ project=wandb_config["project"],
88
+ name=run_name,
89
+ config={},
90
+ )
91
+
92
+
93
+ def main():
94
+ # Initialize
95
+ is_main_process, rank = get_rank() == 0, get_rank()
96
+ torch.cuda.set_device(rank)
97
+ config = get_config()
98
+ training_config = config["train"]
99
+ run_name = time.strftime("%Y%m%d-%H%M%S")
100
+
101
+ # Initialize WandB
102
+ wandb_config = training_config.get("wandb", None)
103
+ if wandb_config is not None and is_main_process:
104
+ init_wandb(wandb_config, run_name)
105
+
106
+ print("Rank:", rank)
107
+ if is_main_process:
108
+ print("Config:", config)
109
+
110
+ # Initialize dataset and dataloader
111
+ if training_config["dataset"]["type"] == "scene":
112
+ dataset = LocalSubjectsDataset(csv_file='csv_path', image_dir='images_path', transform=transform)
113
+ data_valid = dataset
114
+ print(data_valid.features)
115
+ print(len(data_valid))
116
+ print(training_config["dataset"])
117
+ dataset = SceneDataset(
118
+ data_valid,
119
+ condition_size=training_config["dataset"]["condition_size"],
120
+ target_size=training_config["dataset"]["target_size"],
121
+ image_size=training_config["dataset"]["image_size"],
122
+ padding=training_config["dataset"]["padding"],
123
+ condition_type=training_config["condition_type"],
124
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
125
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
126
+ )
127
+ elif training_config["dataset"]["type"] == "img":
128
+ # Load dataset text-to-image-2M
129
+ dataset = load_dataset(
130
+ "webdataset",
131
+ data_files={"train": training_config["dataset"]["urls"]},
132
+ split="train",
133
+ cache_dir="cache/t2i2m",
134
+ num_proc=32,
135
+ )
136
+ dataset = ImageConditionDataset(
137
+ dataset,
138
+ condition_size=training_config["dataset"]["condition_size"],
139
+ target_size=training_config["dataset"]["target_size"],
140
+ condition_type=training_config["condition_type"],
141
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
142
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
143
+ position_scale=training_config["dataset"].get("position_scale", 1.0),
144
+ )
145
+ elif training_config["dataset"]["type"] == "cartoon":
146
+ dataset = load_dataset("saquiboye/oye-cartoon", split="train")
147
+ dataset = CartoonDataset(
148
+ dataset,
149
+ condition_size=training_config["dataset"]["condition_size"],
150
+ target_size=training_config["dataset"]["target_size"],
151
+ image_size=training_config["dataset"]["image_size"],
152
+ padding=training_config["dataset"]["padding"],
153
+ condition_type=training_config["condition_type"],
154
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
155
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
156
+ )
157
+ elif training_config["dataset"]["type"] == "scene":
158
+ dataset = dataset
159
+ else:
160
+ raise NotImplementedError
161
+
162
+ print("Dataset length:", len(dataset))
163
+ train_loader = DataLoader(
164
+ dataset,
165
+ batch_size=training_config["batch_size"],
166
+ shuffle=True,
167
+ num_workers=training_config["dataloader_workers"],
168
+ )
169
+ print("Trainloader generated.")
170
+
171
+ # Initialize model
172
+ trainable_model = OminiModel(
173
+ flux_pipe_id=config["flux_path"],
174
+ lora_config=training_config["lora_config"],
175
+ device=f"cuda",
176
+ dtype=getattr(torch, config["dtype"]),
177
+ optimizer_config=training_config["optimizer"],
178
+ model_config=config.get("model", {}),
179
+ gradient_checkpointing=training_config.get("gradient_checkpointing", False),
180
+ )
181
+
182
+ training_callbacks = (
183
+ [TrainingCallback(run_name, training_config=training_config)]
184
+ if is_main_process
185
+ else []
186
+ )
187
+
188
+ # Initialize trainer
189
+ trainer = L.Trainer(
190
+ accumulate_grad_batches=training_config["accumulate_grad_batches"],
191
+ callbacks=training_callbacks,
192
+ enable_checkpointing=False,
193
+ enable_progress_bar=False,
194
+ logger=False,
195
+ max_steps=training_config.get("max_steps", -1),
196
+ max_epochs=training_config.get("max_epochs", -1),
197
+ gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
198
+ )
199
+
200
+ setattr(trainer, "training_config", training_config)
201
+
202
+ # Save config
203
+ save_path = training_config.get("save_path", "./output")
204
+ if is_main_process:
205
+ os.makedirs(f"{save_path}/{run_name}")
206
+ with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
207
+ yaml.dump(config, f)
208
+
209
+ # Start training
210
+ trainer.fit(trainable_model, train_loader)
211
+
212
+
213
+ if __name__ == "__main__":
214
+ main()
train/README.md ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OminiControl Training 🛠️
2
+
3
+ ## Preparation
4
+
5
+ ### Setup
6
+ 1. **Environment**
7
+ ```bash
8
+ conda create -n omini python=3.10
9
+ conda activate omini
10
+ ```
11
+ 2. **Requirements**
12
+ ```bash
13
+ pip install -r train/requirements.txt
14
+ ```
15
+
16
+ ### Dataset
17
+ 1. Download dataset [Subject200K](https://huggingface.co/datasets/Yuanshi/Subjects200K). (**subject-driven generation**)
18
+ ```
19
+ bash train/script/data_download/data_download1.sh
20
+ ```
21
+ 2. Download dataset [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M). (**spatial control task**)
22
+ ```
23
+ bash train/script/data_download/data_download2.sh
24
+ ```
25
+ **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. Remember to update the config file to specify the training data accordingly.
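+ For example, after downloading extra shards, extend the `urls` list under `dataset` in your config (a minimal sketch based on `train/config/canny_512.yaml`):
+ ```yaml
+ dataset:
+   type: "img"
+   urls:
+     - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
+     - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
+     # append any additional shards you downloaded here
+ ```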
26
+
27
+ ## Training
28
+
29
+ ### Start training
30
+ **Config file path**: `./train/config`
31
+
32
+ **Scripts path**: `./train/script`
33
+
34
+ 1. Subject-driven generation
35
+ ```bash
36
+ bash train/script/train_subject.sh
37
+ ```
38
+ 2. Spatial control task
39
+ ```bash
40
+ bash train/script/train_canny.sh
41
+ ```
42
+
43
+ **Note**: Detailed WandB and GPU settings can be found in the script files and the config files.
44
+
45
+ ### Other spatial control tasks
46
+ This repository supports the following spatial control tasks:
47
+ 1. Canny edge to image (`canny`)
48
+ 2. Image colorization (`coloring`)
49
+ 3. Image deblurring (`deblurring`)
50
+ 4. Depth map to image (`depth`)
51
+ 5. Image to depth map (`depth_pred`)
52
+ 6. Image inpainting (`fill`)
53
+ 7. Super resolution (`sr`)
54
+
55
+ You can modify the `condition_type` parameter in config file `config/canny_512.yaml` to switch between different tasks.
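+ For example, switching this config to depth-to-image training is a one-line change (a minimal sketch; every other value stays as in `canny_512.yaml`):
+ ```yaml
+ # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"]
+ condition_type: "depth"
+ ```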
56
+
57
+ ### Customize your own task
58
+ You can customize your own task by constructing a new dataset and modifying the training code.
59
+
60
+ <details>
61
+ <summary>Instructions</summary>
62
+
63
+ 1. **Dataset** :
64
+
65
+ Construct a new dataset with the following format: (`src/train/data.py`)
66
+ ```python
67
+ class MyDataset(Dataset):
68
+ def __init__(self, ...):
69
+ ...
70
+ def __len__(self):
71
+ ...
72
+ def __getitem__(self, idx):
73
+ ...
74
+ return {
75
+ "image": image,
76
+ "condition": condition_img,
77
+ "condition_type": "your_condition_type",
78
+ "description": description,
79
+ "position_delta": position_delta
80
+ }
81
+ ```
82
+ **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks (e.g. subject-driven generation), set it to `[0, -condition_width // 16]` (see the sketch at the end of these instructions).
83
+ 2. **Condition**:
84
+
85
+ Add a new condition type in the `Condition` class. (`src/flux/condition.py`)
86
+ ```python
87
+ condition_dict = {
88
+ ...
89
+ "your_condition_type": your_condition_id_number, # Add your condition type here
90
+ }
91
+ ...
92
+ if condition_type in [
93
+ ...
94
+ "your_condition_type", # Add your condition type here
95
+ ]:
96
+ ...
97
+ ```
98
+ 3. **Test**:
99
+
100
+ Add a new test function for your task. (`src/train/callbacks.py`)
101
+ ```python
102
+ if self.condition_type == "your_condition_type":
103
+ condition_img = (
104
+ Image.open("images/vase.jpg")
105
+ .resize((condition_size, condition_size))
106
+ .convert("RGB")
107
+ )
108
+ ...
109
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
110
+ ```
111
+
112
+ 4. **Import the relevant dataset in the training script**
113
+ Update the imports and the dataset/dataloader initialization in `src/train/train.py`, as in the following snippet (see also the sketch below).
114
+ ```python
115
+ from .data import (
116
+ ImageConditionDataset,
117
+ Subject200KDateset,
118
+ MyDataset
119
+ )
120
+ ...
121
+
122
+ # Initialize dataset and dataloader
123
+ if training_config["dataset"]["type"] == "your_condition_type":
124
+ ...
125
+ ```
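+
+ A minimal sketch of what this branch might look like is shown below. It is illustrative only: it assumes `MyDataset` accepts the fields of the config's `dataset` section as keyword arguments and that the dataloader mirrors the existing ones; adapt the names to your implementation.
+ ```python
+ from torch.utils.data import DataLoader
+
+ # `training_config` is the parsed `train` section of the YAML config.
+ if training_config["dataset"]["type"] == "your_condition_type":
+     dataset = MyDataset(
+         condition_size=training_config["dataset"]["condition_size"],
+         target_size=training_config["dataset"]["target_size"],
+         condition_type=training_config["condition_type"],
+         drop_text_prob=training_config["dataset"]["drop_text_prob"],
+         drop_image_prob=training_config["dataset"]["drop_image_prob"],
+     )
+     # For a non-spatial task with a 512px condition image, __getitem__ should
+     # return position_delta = [0, -512 // 16] = [0, -32]; spatial tasks use [0, 0].
+     train_loader = DataLoader(
+         dataset,
+         batch_size=training_config["batch_size"],
+         num_workers=training_config["dataloader_workers"],
+         shuffle=True,
+     )
+ ```
+ The existing dataset classes imported above (`src/train/data.py`) are good references for the expected sample format.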
126
+
127
+ </details>
128
+
129
+ ## Hardware requirement
130
+ **Note**: Memory optimization (like dynamic T5 model loading) is pending implementation.
131
+
132
+ **Recommended**
133
+ - Hardware: 2x NVIDIA H100 GPUs
134
+ - Memory: ~80GB GPU memory
135
+
136
+ **Minimal**
137
+ - Hardware: 1x NVIDIA L20 GPU
138
+ - Memory: ~48GB GPU memory
train/config/canny_512.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: false
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 5
13
+ save_interval: 1000
14
+ sample_interval: 100
15
+ max_steps: -1
16
+ gradient_checkpointing: true
17
+ save_path: "runs"
18
+
19
+ # Specify the type of condition to use.
20
+ # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill"]
21
+ condition_type: "canny"
22
+ dataset:
23
+ type: "img"
24
+ urls:
25
+ - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
26
+ - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
27
+ cache_name: "data_512_2M"
28
+ condition_size: 512
29
+ target_size: 512
30
+ drop_text_prob: 0.1
31
+ drop_image_prob: 0.1
32
+
33
+ wandb:
34
+ project: "OminiControl"
35
+
36
+ lora_config:
37
+ r: 4
38
+ lora_alpha: 4
39
+ init_lora_weights: "gaussian"
40
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
41
+
42
+ optimizer:
43
+ type: "Prodigy"
44
+ params:
45
+ lr: 1
46
+ use_bias_correction: true
47
+ safeguard_warmup: true
48
+ weight_decay: 0.01
train/config/cartoon_512.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: false
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 8
13
+ save_interval: 1000
14
+ sample_interval: 100
15
+ max_steps: 15000
16
+ gradient_checkpointing: false
17
+ save_path: "runs"
18
+
19
+ condition_type: "cartoon"
20
+ dataset:
21
+ type: "cartoon"
22
+ condition_size: 512
23
+ target_size: 512
24
+ image_size: 512
25
+ padding: 0
26
+ drop_text_prob: 0.1
27
+ drop_image_prob: 0.0
28
+
29
+ wandb:
30
+ project: "OminiControl"
31
+
32
+ lora_config:
33
+ r: 4
34
+ lora_alpha: 4
35
+ init_lora_weights: "gaussian"
36
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
37
+
38
+ optimizer:
39
+ type: "Prodigy"
40
+ params:
41
+ lr: 2
42
+ use_bias_correction: true
43
+ safeguard_warmup: true
44
+ weight_decay: 0.01
train/config/fill_1024.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: false
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 5
13
+ save_interval: 1000
14
+ sample_interval: 100
15
+ max_steps: -1
16
+ gradient_checkpointing: true
17
+ save_path: "runs"
18
+
19
+ # Specify the type of condition to use.
20
+ # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill"]
21
+ condition_type: "fill"
22
+ dataset:
23
+ type: "img"
24
+ urls:
25
+ - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_1024_10K/data_000000.tar"
26
+ cache_name: "data_1024_10K"
27
+ condition_size: 1024
28
+ target_size: 1024
29
+ drop_text_prob: 0.1
30
+ drop_image_prob: 0.1
31
+
32
+ wandb:
33
+ project: "OminiControl"
34
+
35
+ lora_config:
36
+ r: 4
37
+ lora_alpha: 4
38
+ init_lora_weights: "gaussian"
39
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
40
+
41
+ optimizer:
42
+ type: "Prodigy"
43
+ params:
44
+ lr: 1
45
+ use_bias_correction: true
46
+ safeguard_warmup: true
47
+ weight_decay: 0.01
train/config/scene_512.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: true
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 5
13
+ save_interval: 2000
14
+ sample_interval: 100
15
+ max_steps: -1
16
+ gradient_checkpointing: false
17
+ save_path: "save_path"
18
+
19
+ condition_type: "scene"
20
+ dataset:
21
+ type: "scene"
22
+ condition_size: 512
23
+ target_size: 512
24
+ image_size: 512
25
+ padding: 8
26
+ drop_text_prob: 0.1
27
+ drop_image_prob: 0.1
28
+
29
+ wandb:
30
+ project: "OminiControl"
31
+
32
+ lora_config:
33
+ r: 128
34
+ lora_alpha: 128
35
+ init_lora_weights: "gaussian"
36
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
37
+
38
+
39
+ optimizer:
40
+ type: "Prodigy"
41
+ params:
42
+ lr: 1
43
+ use_bias_correction: true
44
+ safeguard_warmup: true
45
+ weight_decay: 0.01
train/config/sr_512.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: false
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 5
13
+ save_interval: 1000
14
+ sample_interval: 100
15
+ max_steps: -1
16
+ gradient_checkpointing: true
17
+ save_path: "runs"
18
+
19
+ # Specify the type of condition to use.
20
+ # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"]
21
+ condition_type: "sr"
22
+ dataset:
23
+ type: "img"
24
+ urls:
25
+ - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
26
+ - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
27
+ cache_name: "data_512_2M"
28
+ condition_size: 256
29
+ target_size: 512
30
+ drop_text_prob: 0.1
31
+ drop_image_prob: 0.1
32
+
33
+ wandb:
34
+ project: "OminiControl"
35
+
36
+ lora_config:
37
+ r: 4
38
+ lora_alpha: 4
39
+ init_lora_weights: "gaussian"
40
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
41
+
42
+ optimizer:
43
+ type: "Prodigy"
44
+ params:
45
+ lr: 1
46
+ use_bias_correction: true
47
+ safeguard_warmup: true
48
+ weight_decay: 0.01
train/config/subject_512.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flux_path: "black-forest-labs/FLUX.1-dev"
2
+ dtype: "bfloat16"
3
+
4
+ model:
5
+ union_cond_attn: true
6
+ add_cond_attn: false
7
+ latent_lora: true
8
+
9
+ train:
10
+ batch_size: 1
11
+ accumulate_grad_batches: 1
12
+ dataloader_workers: 5
13
+ save_interval: 1000
14
+ sample_interval: 100
15
+ max_steps: -1
16
+ gradient_checkpointing: true
17
+ save_path: "runs"
18
+
19
+ condition_type: "subject"
20
+ dataset:
21
+ type: "subject"
22
+ condition_size: 512
23
+ target_size: 512
24
+ image_size: 512
25
+ padding: 8
26
+ drop_text_prob: 0.1
27
+ drop_image_prob: 0.1
28
+
29
+ wandb:
30
+ project: "OminiControl"
31
+
32
+ lora_config:
33
+ r: 4
34
+ lora_alpha: 4
35
+ init_lora_weights: "gaussian"
36
+ target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
37
+
38
+ optimizer:
39
+ type: "Prodigy"
40
+ params:
41
+ lr: 1
42
+ use_bias_correction: true
43
+ safeguard_warmup: true
44
+ weight_decay: 0.01
train/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.31.0
2
+ transformers
3
+ peft
4
+ opencv-python
5
+ protobuf
6
+ sentencepiece
7
+ gradio
8
+ jupyter
9
+ torchao
10
+
11
+ lightning
12
+ datasets
13
+ torchvision
14
+ prodigyopt
15
+ wandb
train/script/data_download/data_download1.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ huggingface-cli download --repo-type dataset Yuanshi/Subjects200K
train/script/data_download/data_download2.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_512_2M/data_000045.tar
2
+ huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_512_2M/data_000046.tar
3
+ huggingface-cli download --repo-type dataset jackyhate/text-to-image-2M data_1024_10K/data_000000.tar
train/script/train_canny.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Specify the config file path and the GPU devices to use
2
+ # export CUDA_VISIBLE_DEVICES=0,1
3
+
4
+ # Specify the config file path
5
+ export XFL_CONFIG=./train/config/canny_512.yaml
6
+
7
+ # Specify the WANDB API key
8
+ # export WANDB_API_KEY='YOUR_WANDB_API_KEY'
9
+
10
+ echo $XFL_CONFIG
11
+ export TOKENIZERS_PARALLELISM=true
12
+
13
+ accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_cartoon.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Specify the config file path and the GPU devices to use
2
+ # export CUDA_VISIBLE_DEVICES=0,1
3
+
4
+ # Specify the config file path
5
+ export XFL_CONFIG=./train/config/cartoon_512.yaml
6
+
7
+ export HF_HUB_CACHE=./cache
8
+
9
+ # Specify the WANDB API key
10
+ # export WANDB_API_KEY='YOUR_WANDB_API_KEY'
11
+
12
+ echo $XFL_CONFIG
13
+ export TOKENIZERS_PARALLELISM=true
14
+
15
+ accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_scene.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Specify the config file path and the GPU devices to use
2
+ # export CUDA_VISIBLE_DEVICES=0,1
3
+
4
+ # Specify the config file path
5
+ export XFL_CONFIG=./train/config/scene_512.yaml
6
+
7
+ # Specify the WANDB API key
8
+ # export WANDB_API_KEY='YOUR_WANDB_API_KEY'
9
+
10
+ echo $XFL_CONFIG
11
+ export TOKENIZERS_PARALLELISM=true
12
+
13
+ accelerate launch --main_process_port 41353 -m src.train.train
train/script/train_subject.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Specify the config file path and the GPU devices to use
2
+ # export CUDA_VISIBLE_DEVICES=0,1
3
+
4
+ # Specify the config file path
5
+ export XFL_CONFIG=./train/config/subject_512.yaml
6
+
7
+ # Specify the WANDB API key
8
+ # export WANDB_API_KEY='YOUR_WANDB_API_KEY'
9
+
10
+ echo $XFL_CONFIG
11
+ export TOKENIZERS_PARALLELISM=true
12
+
13
+ accelerate launch --main_process_port 41353 -m src.train.train
utils.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from diffusers.pipelines import FluxPipeline
4
+ from src.flux.condition import Condition
5
+ from PIL import Image
6
+ import argparse
7
+ import os
8
+ import json
9
+ import base64
10
+ import io
11
+ import re
12
+ from PIL import Image, ImageFilter
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from scipy.ndimage import binary_dilation
15
+ import cv2
16
+ import openai
17
+ from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
18
+
19
+
20
+ from src.flux.generate import generate, seed_everything
21
+
22
+ try:
23
+ from mmengine.visualization import Visualizer
24
+ except ImportError:
25
+ Visualizer = None
26
+ print("Warning: mmengine is not installed, visualization is disabled.")
27
+
28
+ import re
29
+
30
+ def encode_image_to_datauri(path, size=(512, 512)):
31
+ with Image.open(path).convert('RGB') as img:
32
+ img = img.resize(size, Image.LANCZOS)
33
+ buffer = io.BytesIO()
34
+ img.save(buffer, format='PNG')
35
+ b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
36
+ return b64
37
+ # return f"data:image/png;base64,{b64}"
38
+
39
+
40
+ @retry(
41
+ reraise=True,
42
+ wait=wait_exponential(min=1, max=60),
43
+ stop=stop_after_attempt(6),
44
+ retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
45
+ )
46
+ def cot_with_gpt(image_uri, instruction):
47
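+ # Ask GPT-4o to decompose the requested edit into atomic editing instructions;
+ # returns parallel lists (categories, instructions) parsed from the numbered-list reply.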
+ response = openai.ChatCompletion.create(
48
+ model="gpt-4o",
49
+ messages=[
50
+ {
51
+ "role": "user",
52
+ "content": [
53
+ {"type": "text", "text": f'''
54
+ Now you are an expert in image editing. Based on the given single image, what atomic image editing instructions should be if the user wants to {instruction}? Let's think step by step.
55
+ Atomic instructions include 13 categories as follows:
56
+ - Add: e.g.: add a car on the road
57
+ - Remove: e.g.: remove the sofa in the image
58
+ - Color Change: e.g.: change the color of the shoes to blue
59
+ - Material Change: e.g.: change the material of the sign like stone
60
+ - Action Change: e.g.: change the action of the boy to raising hands
61
+ - Expression Change: e.g.: change the expression to smile
62
+ - Replace: e.g.: replace the coffee with an apple
63
+ - Background Change: e.g.: change the background into forest
64
+ - Appearance Change: e.g.: make the cup have a floral pattern
65
+ - Move: e.g.: move the plane to the left
66
+ - Resize: e.g.: enlarge the clock
67
+ - Tone Transfer: e.g.: change the weather to foggy
68
+ - Style Change: e.g.: make the style of the image to cartoon
69
+ Respond *only* with a numbered list.
70
+ Each line must begin with the category in square brackets, then the instruction. Please strictly follow the atomic categories.
71
+ The operation (what) and the target (to what) are crystal clear.
72
+ Do not split replace to add and remove.
73
+ For example:
74
+ “1. [Add] add a car on the road\n
75
+ 2. [Color Change] change the color of the shoes to blue\n
76
+ 3. [Move] move the lamp to the left\n"
77
+ Do not include any extra text, explanations, JSON or markdown—just the list.
78
+ '''},
79
+ {
80
+ "type": "image_url",
81
+ "image_url": {
82
+ "url": f"data:image/jpeg;base64,{image_uri}"
83
+ }
84
+ },
85
+ ],
86
+ }
87
+ ],
88
+ max_tokens=300,
89
+ )
90
+ text = response.choices[0].message.content.strip()
91
+ print(text)
92
+
93
+ categories, instructions = extract_instructions(text)
94
+ return categories, instructions
95
+
96
+
97
+ def extract_instructions(text):
98
+ categories = []
99
+ instructions = []
100
+
101
+ pattern = r'^\s*\d+\.\s*\[(.*?)\]\s*(.*?)$'
102
+
103
+ for line in text.split('\n'):
104
+ line = line.strip()
105
+ if not line:
106
+ continue
107
+
108
+ match = re.match(pattern, line)
109
+ if match:
110
+ category = match.group(1).strip()
111
+ instruction = match.group(2).strip()
112
+
113
+ if category and instruction:
114
+ categories.append(category)
115
+ instructions.append(instruction)
116
+
117
+ return categories, instructions
118
+
119
+ def extract_last_bbox(result):
120
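+ # Matches entries like ('name', [x0, y0, x1, y1]), optionally wrapped in outer brackets;
+ # group 1 is the object name, groups 2-5 are the coordinates.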
+ pattern = r'\[?\(\'([^\']+)\',\s*\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]\)\]?'
121
+ matches = re.findall(pattern, result)
122
+
123
+ if not matches:
124
+ simple_pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]'
125
+ simple_matches = re.findall(simple_pattern, result)
126
+ if simple_matches:
127
+ x0, y0, x1, y1 = map(int, simple_matches[-1])
128
+ return [x0, y0, x1, y1]
129
+ else:
130
+ print(f"No bounding boxes found, please try again: {result}")
131
+ return None
132
+
133
+ last_match = matches[-1]
134
+ x0, y0, x1, y1 = map(int, last_match[1:])
135
+ return x0, y0, x1, y1
136
+
137
+
138
+ def infer_with_DiT(task, image, instruction, category):
139
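+ # Run the FLUX pipeline with a task-specific LoRA. Returns the edited image, except for
+ # RoI Editing with Action Change, where it returns (changed_instance, x0, y1, 1)
+ # describing the re-posed object and its placement.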
+ # seed_everything(3407)
140
+
141
+ if task == 'RoI Inpainting':
142
+ if category == 'Add' or category == 'Replace':
143
+ lora_path = "weights/add.safetensors"
144
+ added = extract_object_with_gpt(instruction)
145
+ instruction_dit = f"add {added} on the black region"
146
+ elif category == 'Remove' or category == 'Action Change':
147
+ lora_path = "weights/remove.safetensors"
148
+ instruction_dit = f"Fill the hole of the image"
149
+
150
+ condition = Condition("scene", image, position_delta=(0, 0))
151
+ elif task == 'RoI Editing':
152
+ image = Image.open(image).convert('RGB').resize((512, 512))
153
+ condition = Condition("scene", image, position_delta=(0, -32))
154
+ instruction_dit = instruction
155
+ if category == 'Action Change':
156
+ lora_path = "weights/action.safetensors"
157
+ elif category == 'Expression Change':
158
+ lora_path = "weights/expression.safetensors"
159
+ elif category == 'Add':
160
+ lora_path = "weights/addition.safetensors"
161
+ elif category == 'Material Change':
162
+ lora_path = "weights/material.safetensors"
163
+ elif category == 'Color Change':
164
+ lora_path = "weights/color.safetensors"
165
+ elif category == 'Background Change':
166
+ lora_path = "weights/bg.safetensors"
167
+ elif category == 'Appearance Change':
168
+ lora_path = "weights/appearance.safetensors"
169
+
170
+ elif task == 'RoI Compositioning':
171
+ lora_path = "weights/fusion.safetensors"
172
+ condition = Condition("scene", image, position_delta=(0, 0))
173
+ instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
174
+
175
+ elif task == 'Global Transformation':
176
+ image = Image.open(image).convert('RGB').resize((512, 512))
177
+ instruction_dit = instruction
178
+ lora_path = "weights/overall.safetensors"
179
+
180
+ condition = Condition("scene", image, position_delta=(0, -32))
181
+ else:
182
+ raise ValueError(f"Invalid task: '{task}'")
183
+ pipe = FluxPipeline.from_pretrained(
184
+ "black-forest-labs/FLUX.1-dev",
185
+ torch_dtype=torch.bfloat16
186
+ )
187
+
188
+ pipe = pipe.to("cuda")
189
+
190
+ pipe.load_lora_weights(
191
+ "Cicici1109/IEAP",
192
+ weight_name=lora_path,
193
+ adapter_name="scene",
194
+ )
195
+ result_img = generate(
196
+ pipe,
197
+ prompt=instruction_dit,
198
+ conditions=[condition],
199
+ config_path="train/config/scene_512.yaml",
200
+ num_inference_steps=28,
201
+ height=512,
202
+ width=512,
203
+ ).images[0]
204
+ # result_img
205
+ if task == 'RoI Editing' and category == 'Action Change':
206
+ text_roi = extract_object_with_gpt(instruction)
207
+ instruction_loc = f"<image>Please segment {text_roi}."
208
+ # (model, tokenizer, image_path, instruction, work_dir, dilate):
209
+ img = result_img
210
+ # print(f"Instruction: {instruction_loc}")
211
+
212
+ model, tokenizer = load_model("ByteDance/Sa2VA-8B")
213
+
214
+ result = model.predict_forward(
215
+ image=img,
216
+ text=instruction_loc,
217
+ tokenizer=tokenizer,
218
+ )
219
+
220
+ prediction = result['prediction']
221
+ # print(f"Model Output: {prediction}")
222
+
223
+ if '[SEG]' in prediction and 'prediction_masks' in result:
224
+ pred_mask = result['prediction_masks'][0]
225
+ pred_mask_np = np.squeeze(np.array(pred_mask))
226
+
227
+ ## obtain region bbox
228
+ rows = np.any(pred_mask_np, axis=1)
229
+ cols = np.any(pred_mask_np, axis=0)
230
+ if not np.any(rows) or not np.any(cols):
231
+ print("Warning: Mask is empty, cannot compute bounding box")
232
+ return img
233
+
234
+ y0, y1 = np.where(rows)[0][[0, -1]]
235
+ x0, x1 = np.where(cols)[0][[0, -1]]
236
+
237
+ changed_instance = crop_masked_region(result_img, pred_mask_np)
238
+
239
+ return changed_instance, x0, y1, 1
240
+
241
+
242
+ return result_img
243
+
244
+ def load_model(model_path):
245
+ model = AutoModelForCausalLM.from_pretrained(
246
+ model_path,
247
+ torch_dtype="auto",
248
+ device_map="auto",
249
+ trust_remote_code=True
250
+ ).eval()
251
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
252
+ return model, tokenizer
253
+
254
+ def extract_object_with_gpt(instruction):
255
+ system_prompt = (
256
+ "You are a helpful assistant that extracts the object or target being edited in an image editing instruction. "
257
+ "Only return a concise noun phrase describing the object. "
258
+ "Examples:\n"
259
+ "- Input: 'Remove the dog' → Output: 'the dog'\n"
260
+ "- Input: 'Add a hat on the dog' → Output: 'a hat'\n"
261
+ "- Input: 'Replace the biggest bear with a tiger' → Output: 'the biggest bear'\n"
262
+ "- Input: 'Change the action of the girl to riding' → Output: 'the girl'\n"
263
+ "- Input: 'Move the red car on the lake' → Output: 'the red car'\n"
264
+ "- Input: 'Minify the carrot on the rabbit's hand' → Output: 'the carrot on the rabbit's hand'\n"
265
+ "- Input: 'Swap the location of the dog and the cat' → Output: 'the dog and the cat'\n"
266
+ "Now extract the object for this instruction:"
267
+ )
268
+
269
+ try:
270
+ response = openai.ChatCompletion.create(
271
+ model="gpt-3.5-turbo",
272
+ messages=[
273
+ {"role": "system", "content": system_prompt},
274
+ {"role": "user", "content": instruction}
275
+ ],
276
+ temperature=0.2,
277
+ max_tokens=20,
278
+ )
279
+ object_phrase = response.choices[0].message['content'].strip().strip('"')
280
+ print(f"Identified object: {object_phrase}")
281
+ return object_phrase
282
+ except Exception as e:
283
+ print(f"GPT extraction failed: {e}")
284
+ return instruction
285
+
286
+ def extract_region_with_gpt(instruction):
287
+ system_prompt = (
288
+ "You are a helpful assistant that extracts target region being edited in an image editing instruction. "
289
+ "Only return a concise noun phrase describing the target region. "
290
+ "Examples:\n"
291
+ "- Input: 'Add a red hat to the man on the left' → Output: 'the man on the left'\n"
292
+ "- Input: 'Add a cat beside the dog' → Output: 'the dog'\n"
293
+ "Now extract the target region for this instruction:"
294
+ )
295
+
296
+ try:
297
+ response = openai.ChatCompletion.create(
298
+ model="gpt-3.5-turbo",
299
+ messages=[
300
+ {"role": "system", "content": system_prompt},
301
+ {"role": "user", "content": instruction}
302
+ ],
303
+ temperature=0.2,
304
+ max_tokens=20,
305
+ )
306
+ object_phrase = response.choices[0].message['content'].strip().strip('"')
307
+ # print(f"Identified object: {object_phrase}")
308
+ return object_phrase
309
+ except Exception as e:
310
+ print(f"GPT extraction failed: {e}")
311
+ return instruction
312
+
313
+ def get_masked(mask, image):
314
+ if mask.shape[:2] != image.size[::-1]:
315
+ raise ValueError(f"Mask size {mask.shape[:2]} does not match image size {image.size}")
316
+
317
+ image_array = np.array(image)
318
+ image_array[mask] = [0, 0, 0]
319
+
320
+ return Image.fromarray(image_array)
321
+
322
+ def bbox_to_mask(x0, y0, x1, y1, image_shape=(512, 512), fill_value=True):
323
+ height, width = image_shape
324
+
325
+ mask = np.zeros((height, width), dtype=bool)
326
+
327
+ x0 = max(0, int(x0))
328
+ y0 = max(0, int(y0))
329
+ x1 = min(width, int(x1))
330
+ y1 = min(height, int(y1))
331
+
332
+ if x0 >= x1 or y0 >= y1:
333
+ print("Warning: Invalid bounding box coordinates")
334
+ return mask
335
+
336
+ mask[y0:y1, x0:x1] = fill_value
337
+
338
+ return mask
339
+
340
+ def combine_bbox(text, x0, y0, x1, y1):
341
+ bbox = [x0, y0, x1, y1]
342
+ return [(text, bbox)]
343
+
344
+ def crop_masked_region(image, pred_mask_np):
345
+ if not isinstance(image, Image.Image):
346
+ raise ValueError("The input image is not a PIL Image object")
347
+ if not isinstance(pred_mask_np, np.ndarray) or pred_mask_np.dtype != bool:
348
+ raise ValueError("pred_mask_np must be a NumPy array of boolean type")
349
+ if pred_mask_np.shape[:2] != image.size[::-1]:
350
+ raise ValueError(f"Mask size {pred_mask_np.shape[:2]} does not match image size {image.size}")
351
+
352
+ image_rgba = image.convert("RGBA")
353
+ image_array = np.array(image_rgba)
354
+
355
+ rows = np.any(pred_mask_np, axis=1)
356
+ cols = np.any(pred_mask_np, axis=0)
357
+
358
+ if not np.any(rows) or not np.any(cols):
359
+ print("Warning: Mask is empty, cannot compute bounding box")
360
+ return image_rgba
361
+
362
+ y0, y1 = np.where(rows)[0][[0, -1]]
363
+ x0, x1 = np.where(cols)[0][[0, -1]]
364
+
365
+ cropped_image = image_array[y0:y1+1, x0:x1+1].copy()
366
+ cropped_mask = pred_mask_np[y0:y1+1, x0:x1+1]
367
+
368
+ alpha_channel = np.ones(cropped_mask.shape, dtype=np.uint8) * 255
369
+ alpha_channel[~cropped_mask] = 0
370
+
371
+ cropped_image[:, :, 3] = alpha_channel
372
+
373
+ return Image.fromarray(cropped_image, mode='RGBA')
374
+
375
+ def roi_localization(image, instruction, category): # add, remove, replace, action change, move, resize
376
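+ # Localize the region of interest with Sa2VA segmentation. Returns a masked image for
+ # inpainting-style edits, a (masked_img, changed_instance, x0_new, y1_new, scale) tuple
+ # for Move/Resize, or None if no mask is found.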
+ model, tokenizer = load_model("ByteDance/Sa2VA-8B")
377
+ if category == 'Add':
378
+ text_roi = extract_region_with_gpt(instruction)
379
+ else:
380
+ text_roi = extract_object_with_gpt(instruction)
381
+ instruction_loc = f"<image>Please segment {text_roi}."
382
+ img = Image.open(image).convert('RGB').resize((512, 512))
383
+ print(f"Processing image: {os.path.basename(image)}, Instruction: {instruction_loc}")
384
+
385
+ result = model.predict_forward(
386
+ image=img,
387
+ text=instruction_loc,
388
+ tokenizer=tokenizer,
389
+ )
390
+
391
+ prediction = result['prediction']
392
+ # print(f"Model Output: {prediction}")
393
+
394
+ if '[SEG]' in prediction and 'prediction_masks' in result:
395
+ pred_mask = result['prediction_masks'][0]
396
+ pred_mask_np = np.squeeze(np.array(pred_mask))
397
+ if category == 'Add':
398
+ ## obtain region bbox
399
+ rows = np.any(pred_mask_np, axis=1)
400
+ cols = np.any(pred_mask_np, axis=0)
401
+ if not np.any(rows) or not np.any(cols):
402
+ print("Warning: Mask is empty, cannot compute bounding box")
403
+ return img
404
+
405
+ y0, y1 = np.where(rows)[0][[0, -1]]
406
+ x0, x1 = np.where(cols)[0][[0, -1]]
407
+
408
+ ## obtain inpainting bbox
409
+ bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
410
+ # print(bbox)
411
+ x0, y0, x1, y1 = layout_add(bbox, instruction)
412
+ mask = bbox_to_mask(x0, y0, x1, y1)
413
+ ## make it black
414
+ masked_img = get_masked(mask, img)
415
+ elif category == 'Move' or category == 'Resize':
416
+ dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
417
+ masked_img = get_masked(dilated_original_mask, img)
418
+ ## obtain region bbox
419
+ rows = np.any(pred_mask_np, axis=1)
420
+ cols = np.any(pred_mask_np, axis=0)
421
+ if not np.any(rows) or not np.any(cols):
422
+ print("Warning: Mask is empty, cannot compute bounding box")
423
+ return img
424
+
425
+ y0, y1 = np.where(rows)[0][[0, -1]]
426
+ x0, x1 = np.where(cols)[0][[0, -1]]
427
+
428
+ ## obtain inpainting bbox
429
+ bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
430
+ # print(bbox)
431
+ x0_new, y0_new, x1_new, y1_new = layout_change(bbox, instruction)
432
+ scale = (y1_new - y0_new) / (y1 - y0)
433
+ # print(scale)
434
+ changed_instance = crop_masked_region(img, pred_mask_np)
435
+
436
+ return masked_img, changed_instance, x0_new, y1_new, scale
437
+ else:
438
+ dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
439
+ masked_img = get_masked(dilated_original_mask, img)
440
+
441
+ return masked_img
442
+
443
+ else:
444
+ print("No valid mask found in the prediction.")
445
+ return None
446
+
447
+ def fusion(background, foreground, x, y, scale):
448
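+ # Composite the foreground onto the background: scale the foreground, place it with its
+ # left edge at x and its bottom edge at y, then blend via process_edge and alpha compositing.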
+ background = background.convert("RGBA")
449
+ bg_width, bg_height = background.size
450
+
451
+ fg_width, fg_height = foreground.size
452
+ new_size = (int(fg_width * scale), int(fg_height * scale))
453
+ foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
454
+
455
+ left = x
456
+ top = y - new_size[1]
457
+
458
+ canvas = Image.new('RGBA', (bg_width, bg_height), (0, 0, 0, 0))
459
+ canvas.paste(foreground_resized, (left, top), foreground_resized)
460
+ masked_foreground = process_edge(canvas, left, top, new_size)
461
+ result = Image.alpha_composite(background, masked_foreground)
462
+
463
+ return result
464
+
465
+ def process_edge(canvas, left, top, size):
466
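+ # Blacken a thin band around the pasted object (where the dilated alpha is opaque but the
+ # eroded alpha is not) so the seam can later be inpainted/blended.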
+ width, height = size
467
+
468
+ region = canvas.crop((left, top, left + width, top + height))
469
+ alpha = region.getchannel('A')
470
+
471
+ dilated_alpha = alpha.filter(ImageFilter.MaxFilter(5))
472
+ eroded_alpha = alpha.filter(ImageFilter.MinFilter(3))
473
+
474
+ edge_mask = Image.new('L', (width, height), 0)
475
+ edge_pixels = edge_mask.load()
476
+ dilated_pixels = dilated_alpha.load()
477
+ eroded_pixels = eroded_alpha.load()
478
+
479
+ for y in range(height):
480
+ for x in range(width):
481
+ if dilated_pixels[x, y] > 0 and eroded_pixels[x, y] == 0:
482
+ edge_pixels[x, y] = 255
483
+
484
+ black_edge = Image.new('RGBA', (width, height), (0, 0, 0, 0))
485
+ black_edge.putalpha(edge_mask)
486
+
487
+ canvas.paste(black_edge, (left, top), black_edge)
488
+
489
+ return canvas
490
+
491
+ def combine_text_and_bbox(text_roi, x0, y0, x1, y1):
492
+ return [(text_roi, [x0, y0, x1, y1])]
493
+
494
+ @retry(
495
+ reraise=True,
496
+ wait=wait_exponential(min=1, max=60),
497
+ stop=stop_after_attempt(6),
498
+ retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
499
+ )
500
+ def layout_add(bbox, instruction):
501
+ response = openai.ChatCompletion.create(
502
+ model="gpt-4o",
503
+ messages=[
504
+ {
505
+ "role": "user",
506
+ "content": [
507
+ {"type": "text", "text": f'''
508
+ You are an intelligent bounding box editor. I will provide you with the current bounding boxes and an add editing instruction.
509
+ Your task is to determine the new bounding box of the added object. Let's think step by step.
510
+ The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512].
511
+ The bounding boxes should not go beyond the image boundaries. The new box must be large enough to reasonably encompass the added object in a visually appropriate way, allowing for partial overlap with existing objects when it comes to accessories like hats, necklaces, etc.
512
+ Each bounding box should be in the format of (object name,[top-left x coordinate, top-left y coordinate, bottom-right x coordinate, bottom-right y coordinate]).
513
+ Only return the bounding box of the newly added object. Do not include the existing bounding boxes.
514
+ Please consider the semantic information of the layout, preserve semantic relations.
515
+ If needed, you can make reasonable guesses. Please refer to the examples below:
516
+ Input bounding boxes: [('a green car', [21, 281, 232, 440])]
517
+ Editing instruction: Add a bird on the green car.
518
+ Output bounding boxes: [('a bird', [80, 150, 180, 281])]
519
+ Input bounding boxes: [('stool', [300, 350, 380, 450])]
520
+ Editing instruction: Add a cat to the left of the stool.
521
+ Output bounding boxes: [('a cat', [180, 250, 300, 450])]
522
+
523
+ Here are some examples to illustrate appropriate overlapping for better visual effects:
524
+ Input bounding boxes: [('the white cat', [200, 300, 320, 420])]
525
+ Editing instruction: Add a hat on the white cat.
526
+ Output bounding boxes: [('a hat', [200, 150, 320, 330])]
527
+ Now, the current bounding boxes is {bbox}, the instruction is {instruction}.
528
+ '''},
529
+ ],
530
+ }
531
+ ],
532
+ max_tokens=1000,
533
+ )
534
+
535
+ result = response.choices[0].message.content.strip()
536
+
537
+ bbox = extract_last_bbox(result)
538
+ return bbox
539
+
540
+ @retry(
541
+ reraise=True,
542
+ wait=wait_exponential(min=1, max=60),
543
+ stop=stop_after_attempt(6),
544
+ retry=retry_if_exception_type((openai.error.RateLimitError, openai.error.APIError))
545
+ )
546
+ def layout_change(bbox, instruction):
547
+ response = openai.ChatCompletion.create(
548
+ model="gpt-4o",
549
+ messages=[
550
+ {
551
+ "role": "user",
552
+ "content": [
553
+ {"type": "text", "text": f'''
554
+ You are an intelligent bounding box editor. I will provide you with the current bounding boxes and the editing instruction.
555
+ Your task is to generate the new bounding boxes after editing.
556
+ The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512].
557
+ The bounding boxes should not overlap or go beyond the image boundaries.
558
+ Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, bottom-right x coordinate, bottom-right y coordinate]).
559
+ Do not add new objects or delete any object provided in the bounding boxes. Do not change the size or the shape of any object unless the instruction requires so.
560
+ Please consider the semantic information of the layout.
561
+ When resizing, keep the bottom-left corner fixed by default. When swapping locations, change according to the center point.
562
+ If needed, you can make reasonable guesses. Please refer to the examples below:
563
+
564
+ Input bounding boxes: [('a car', [21, 281, 232, 440])]
565
+ Editing instruction: Move the car to the right.
566
+ Output bounding boxes: [('a car', [121, 281, 332, 440])]
567
+
568
+ Input bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [200, 200, 300, 230])]
569
+ Editing instruction: Move the pillow to the left side of the bed.
570
+ Output bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [70, 270, 170, 300])]
571
+
572
+ Input bounding boxes: [("dog", [150, 250, 250, 300])]
573
+ Editing instruction: Enlarge the dog.
574
+ Output bounding boxes: [("dog", [150, 225, 300, 300])]
575
+
576
+ Input bounding boxes: [("chair", [100, 350, 200, 450]), ("lamp", [300, 200, 360, 300])]
577
+ Editing instruction: Swap the location of the chair and the lamp.
578
+ Output bounding boxes: [("chair", [280, 200, 380, 300]), ("lamp", [120, 350, 180, 450])]
579
+
580
+
581
+ Now, the current bounding boxes is {bbox}, the instruction is {instruction}. Let's think step by step, and output the edited layout.
582
+ '''},
583
+ ],
584
+ }
585
+ ],
586
+ max_tokens=1000,
587
+ )
588
+ result = response.choices[0].message.content.strip()
589
+
590
+ bbox = extract_last_bbox(result)
591
+ return bbox