mboss Mark Boss committed on
Commit
5a1f586
·
1 Parent(s): 270e8e8

Initial commit


Co-authored-by: Mark Boss <hello@markboss.me>

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,179 @@
1
+ IP-Adapter/
2
+ models/
3
+ sdxl_models/
4
+ .gradio/
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # UV
103
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ #uv.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121
+ .pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ #.idea/
174
+
175
+ # Ruff stuff:
176
+ .ruff_cache/
177
+
178
+ # PyPI configuration file
179
+ .pypirc
LICENSE.md ADDED
@@ -0,0 +1,51 @@
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+ Last Updated: July 5, 2024
3
+
4
+
5
+ I. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+
10
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
11
+
12
+
13
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
14
+
15
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
18
+
19
+ III. COMMERCIAL USE LICENSE
20
+
21
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
22
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
23
+
24
+ IV. GENERAL TERMS
25
+
26
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
27
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
28
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
29
+ c. Intellectual Property.
30
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
31
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
32
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
33
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
34
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
35
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
38
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
39
+
40
+ V. DEFINITIONS
41
+
42
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
43
+ "Agreement" means this Stability AI Community License Agreement.
44
+ "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
45
+ "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
46
+ "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
47
+ "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
48
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
49
+ "Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
50
+ "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
51
+ "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,26 @@
1
  ---
2
- title: Marble
3
- emoji: 🏢
4
- colorFrom: gray
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.30.0
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: MARBLE
3
+ emoji:
4
+ colorFrom: indigo
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.32.1
8
+ app_file: gradio_demo.py
9
  pinned: false
10
  ---
11
 
12
+ * **Repository**: [https://github.com/Stability-AI/marble](https://github.com/Stability-AI/marble)
13
+ * **Tech report**: [https://marblecontrol.github.io/static/MARBLE.pdf](https://marblecontrol.github.io/static/MARBLE.pdf)
14
+ * **Project page**: [https://marblecontrol.github.io](https://marblecontrol.github.io)
15
+
16
+ ## Citation
17
+ If you find MARBLE helpful in your research/applications, please cite using this BibTeX:
18
+
19
+ ```bibtex
20
+ @article{cheng2024marble,
21
+ title={MARBLE: Material Recomposition and Blending in CLIP-Space},
22
+ author={Cheng, Ta-Ying and Sharma, Prafull and Boss, Mark and Jampani, Varun},
23
+ journal={CVPR},
24
+ year={2025}
25
+ }
26
+ ```
gradio_demo.py ADDED
@@ -0,0 +1,272 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ from marble import (
5
+ get_session,
6
+ run_blend,
7
+ run_parametric_control,
8
+ setup_control_mlps,
9
+ setup_pipeline,
10
+ )
11
+
12
+ # Setup the pipeline and control MLPs
13
+ control_mlps = setup_control_mlps()
14
+ ip_adapter = setup_pipeline()
15
+ get_session()
16
+
17
+ # Load example images
18
+ EXAMPLE_IMAGES = {
19
+ "blend": {
20
+ "target": "input_images/context_image/beetle.png",
21
+ "texture1": "input_images/texture/low_roughness.png",
22
+ "texture2": "input_images/texture/high_roughness.png",
23
+ },
24
+ "parametric": {
25
+ "target": "input_images/context_image/toy_car.png",
26
+ "texture": "input_images/texture/metal_bowl.png",
27
+ },
28
+ }
29
+
30
+
31
+ def blend_images(target_image, texture1, texture2, edit_strength):
32
+ """Blend between two texture images"""
33
+ result = run_blend(
34
+ ip_adapter, target_image, texture1, texture2, edit_strength=edit_strength
35
+ )
36
+ return result
37
+
38
+
39
+ def parametric_control(
40
+ target_image,
41
+ texture_image,
42
+ control_type,
43
+ metallic_strength,
44
+ roughness_strength,
45
+ transparency_strength,
46
+ glow_strength,
47
+ ):
48
+ """Apply parametric control based on selected control type"""
49
+ edit_mlps = {}
50
+
51
+ if control_type == "Roughness + Metallic":
52
+ edit_mlps = {
53
+ control_mlps["metallic"]: metallic_strength,
54
+ control_mlps["roughness"]: roughness_strength,
55
+ }
56
+ elif control_type == "Transparency":
57
+ edit_mlps = {
58
+ control_mlps["transparency"]: transparency_strength,
59
+ }
60
+ elif control_type == "Glow":
61
+ edit_mlps = {
62
+ control_mlps["glow"]: glow_strength,
63
+ }
64
+
65
+ # Use target image as texture if no texture is provided
66
+ texture_to_use = texture_image if texture_image is not None else target_image
67
+
68
+ result = run_parametric_control(
69
+ ip_adapter,
70
+ target_image,
71
+ edit_mlps,
72
+ texture_to_use,
73
+ )
74
+ return result
75
+
76
+
77
+ # Create the Gradio interface
78
+ with gr.Blocks(
79
+ title="MARBLE: Material Recomposition and Blending in CLIP-Space"
80
+ ) as demo:
81
+ gr.Markdown(
82
+ """
83
+ # MARBLE: Material Recomposition and Blending in CLIP-Space
84
+
85
+ <div style="display: flex; justify-content: flex-start; gap: 10px;>
86
+ <a href="https://arxiv.org/abs/"><img src="https://img.shields.io/badge/Arxiv-2501.04689-B31B1B.svg"></a>
87
+ <a href="https://github.com/Stability-AI/marble"><img src="https://img.shields.io/badge/Github-Marble-B31B1B.svg"></a>
88
+ </div>
89
+
90
+ MARBLE is a tool for material recomposition and blending in CLIP-Space.
91
+ We provide two modes of operation:
92
+ - **Texture Blending**: Blend the material properties of two texture images and apply the result to a target image.
93
+ - **Parametric Control**: Apply parametric material control to a target image. You can either provide a texture image, transferring its material properties to the target, or provide only the target image and edit its material properties directly.
94
+ """
95
+ )
96
+
97
+ with gr.Row(variant="panel"):
98
+ with gr.Tabs():
99
+ with gr.TabItem("Texture Blending"):
100
+ with gr.Row(equal_height=False):
101
+ with gr.Column():
102
+ with gr.Row():
103
+ texture1 = gr.Image(label="Texture 1", type="pil")
104
+ texture2 = gr.Image(label="Texture 2", type="pil")
105
+ edit_strength = gr.Slider(
106
+ minimum=0.0,
107
+ maximum=1.0,
108
+ value=0.5,
109
+ step=0.1,
110
+ label="Blend Strength",
111
+ )
112
+ with gr.Column():
113
+ with gr.Row():
114
+ target_image = gr.Image(label="Target Image", type="pil")
115
+ blend_output = gr.Image(label="Blended Result")
116
+ blend_btn = gr.Button("Blend Textures")
117
+
118
+ # Add examples for blending
119
+ gr.Examples(
120
+ examples=[
121
+ [
122
+ Image.open(EXAMPLE_IMAGES["blend"]["target"]),
123
+ Image.open(EXAMPLE_IMAGES["blend"]["texture1"]),
124
+ Image.open(EXAMPLE_IMAGES["blend"]["texture2"]),
125
+ 0.5,
126
+ ]
127
+ ],
128
+ inputs=[target_image, texture1, texture2, edit_strength],
129
+ outputs=blend_output,
130
+ fn=blend_images,
131
+ cache_examples=True,
132
+ )
133
+
134
+ blend_btn.click(
135
+ fn=blend_images,
136
+ inputs=[target_image, texture1, texture2, edit_strength],
137
+ outputs=blend_output,
138
+ )
139
+
140
+ with gr.TabItem("Parametric Control"):
141
+ with gr.Row(equal_height=False):
142
+ with gr.Column():
143
+ with gr.Row():
144
+ target_image_pc = gr.Image(label="Target Image", type="pil")
145
+ texture_image_pc = gr.Image(
146
+ label="Texture Image (Optional - uses target image if not provided)",
147
+ type="pil",
148
+ )
149
+ control_type = gr.Dropdown(
150
+ choices=["Roughness + Metallic", "Transparency", "Glow"],
151
+ value="Roughness + Metallic",
152
+ label="Control Type",
153
+ )
154
+
155
+ metallic_strength = gr.Slider(
156
+ minimum=-20,
157
+ maximum=20,
158
+ value=0,
159
+ step=0.1,
160
+ label="Metallic Strength",
161
+ visible=True,
162
+ )
163
+ roughness_strength = gr.Slider(
164
+ minimum=-1,
165
+ maximum=1,
166
+ value=0,
167
+ step=0.1,
168
+ label="Roughness Strength",
169
+ visible=True,
170
+ )
171
+ transparency_strength = gr.Slider(
172
+ minimum=0,
173
+ maximum=4,
174
+ value=0,
175
+ step=0.1,
176
+ label="Transparency Strength",
177
+ visible=False,
178
+ )
179
+ glow_strength = gr.Slider(
180
+ minimum=0,
181
+ maximum=3,
182
+ value=0,
183
+ step=0.1,
184
+ label="Glow Strength",
185
+ visible=False,
186
+ )
187
+ control_btn = gr.Button("Apply Control")
188
+
189
+ with gr.Column():
190
+ control_output = gr.Image(label="Result")
191
+
192
+ def update_slider_visibility(control_type):
193
+ return [
194
+ gr.update(visible=control_type == "Roughness + Metallic"),
195
+ gr.update(visible=control_type == "Roughness + Metallic"),
196
+ gr.update(visible=control_type == "Transparency"),
197
+ gr.update(visible=control_type == "Glow"),
198
+ ]
199
+
200
+ control_type.change(
201
+ fn=update_slider_visibility,
202
+ inputs=[control_type],
203
+ outputs=[
204
+ metallic_strength,
205
+ roughness_strength,
206
+ transparency_strength,
207
+ glow_strength,
208
+ ],
209
+ show_progress=False,
210
+ )
211
+
212
+ # Add examples for parametric control
213
+ gr.Examples(
214
+ examples=[
215
+ [
216
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
217
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
218
+ "Roughness + Metallic",
219
+ 0, # metallic_strength
220
+ 0, # roughness_strength
221
+ 0, # transparency_strength
222
+ 0, # glow_strength
223
+ ],
224
+ [
225
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
226
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
227
+ "Roughness + Metallic",
228
+ 20, # metallic_strength
229
+ 0, # roughness_strength
230
+ 0, # transparency_strength
231
+ 0, # glow_strength
232
+ ],
233
+ [
234
+ Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
235
+ Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
236
+ "Roughness + Metallic",
237
+ 0, # metallic_strength
238
+ 1, # roughness_strength
239
+ 0, # transparency_strength
240
+ 0, # glow_strength
241
+ ],
242
+ ],
243
+ inputs=[
244
+ target_image_pc,
245
+ texture_image_pc,
246
+ control_type,
247
+ metallic_strength,
248
+ roughness_strength,
249
+ transparency_strength,
250
+ glow_strength,
251
+ ],
252
+ outputs=control_output,
253
+ fn=parametric_control,
254
+ cache_examples=True,
255
+ )
256
+
257
+ control_btn.click(
258
+ fn=parametric_control,
259
+ inputs=[
260
+ target_image_pc,
261
+ texture_image_pc,
262
+ control_type,
263
+ metallic_strength,
264
+ roughness_strength,
265
+ transparency_strength,
266
+ glow_strength,
267
+ ],
268
+ outputs=control_output,
269
+ )
270
+
271
+ if __name__ == "__main__":
272
+ demo.launch()
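
For reference, the `marble` helpers imported at the top of `gradio_demo.py` can also be driven headlessly. The sketch below is not part of the commit: it reuses only the names shown above (`setup_control_mlps`, `setup_pipeline`, `get_session`, `run_blend`, `run_parametric_control`) and the example images added under `input_images/`, and it assumes those functions accept and return `PIL.Image` objects, as the Gradio callbacks and output components suggest.

```python
# Hedged sketch: headless use of the helpers that gradio_demo.py imports.
# Assumes run_blend / run_parametric_control take PIL images and return a
# PIL.Image, mirroring how the Gradio callbacks above call them.
from PIL import Image

from marble import (
    get_session,
    run_blend,
    run_parametric_control,
    setup_control_mlps,
    setup_pipeline,
)

control_mlps = setup_control_mlps()
ip_adapter = setup_pipeline()
get_session()

# Texture blending: interpolate the material of two textures onto a target.
target = Image.open("input_images/context_image/beetle.png")
low_rough = Image.open("input_images/texture/low_roughness.png")
high_rough = Image.open("input_images/texture/high_roughness.png")
blended = run_blend(ip_adapter, target, low_rough, high_rough, edit_strength=0.5)
blended.save("blend_result.png")

# Parametric control: push a single attribute (here metallic) via its MLP;
# the demo's "Roughness + Metallic" mode passes both MLPs in this dict.
car = Image.open("input_images/context_image/toy_car.png")
bowl = Image.open("input_images/texture/metal_bowl.png")
edited = run_parametric_control(ip_adapter, car, {control_mlps["metallic"]: 20.0}, bowl)
edited.save("metallic_result.png")
```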
input_images/context_image/beetle.png ADDED

Git LFS Details

  • SHA256: f2ccc9844ec757f7654885a630a9f12f35c34bd6dfb7e9b6d55e8f356a5730b0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
input_images/context_image/genart_teapot.jpg ADDED

Git LFS Details

  • SHA256: af9b997d9ef41591a7957a457ab1bf3cec9b76411781f9b4908ec23280a7a738
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
input_images/context_image/toy_car.png ADDED

Git LFS Details

  • SHA256: 277cdad8fc902e5a902b676ec35ee128c09ef044cc5044d63ea0f188fd60ba86
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
input_images/context_image/white_car_night.jpg ADDED

Git LFS Details

  • SHA256: ada69d89974837a31fa5121a96991e895bb7e010561631c4e03cbaf1496928a1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.29 MB
input_images/depth/beetle.png ADDED

Git LFS Details

  • SHA256: 8677786cdc418c2c8f5c90f98f23db9bb4fe9e5af18b3a4257841e21ea415a52
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
input_images/depth/toy_car.png ADDED

Git LFS Details

  • SHA256: bb745360933739d8a44086f5220458044b827a7eb2aa25803ad7710a3cd30916
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
input_images/texture/high_roughness.png ADDED

Git LFS Details

  • SHA256: 50b9a3ef6c8a487a361da627a282d94a1d2570ddca8b4c05cfb072b937d046c8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
input_images/texture/low_roughness.png ADDED

Git LFS Details

  • SHA256: f61fc7a84c18e61ab9a5ed6b14c2f5f1ced1ade5dee8e779678c65bae8a5ffbb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
input_images/texture/metal_bowl.png ADDED

Git LFS Details

  • SHA256: fec45d29a16022aab1a8a0a86943c4b3ed5ffe57af1b3da01b8e139eec464c24
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
ip_adapter_instantstyle/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2
+
3
+ __all__ = [
4
+ "IPAdapter",
5
+ "IPAdapterPlus",
6
+ "IPAdapterPlusXL",
7
+ "IPAdapterXL",
8
+ "IPAdapterFull",
9
+ ]
ip_adapter_instantstyle/attention_processor.py ADDED
@@ -0,0 +1,562 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapter.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ the weight scale of image prompt.
90
+ num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+ self.skip = skip
102
+
103
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+
106
+ def __call__(
107
+ self,
108
+ attn,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ attention_mask=None,
112
+ temb=None,
113
+ ):
114
+ residual = hidden_states
115
+
116
+ if attn.spatial_norm is not None:
117
+ hidden_states = attn.spatial_norm(hidden_states, temb)
118
+
119
+ input_ndim = hidden_states.ndim
120
+
121
+ if input_ndim == 4:
122
+ batch_size, channel, height, width = hidden_states.shape
123
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ batch_size, sequence_length, _ = (
126
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
+ )
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ else:
138
+ # get encoder_hidden_states, ip_hidden_states
139
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
140
+ encoder_hidden_states, ip_hidden_states = (
141
+ encoder_hidden_states[:, :end_pos, :],
142
+ encoder_hidden_states[:, end_pos:, :],
143
+ )
144
+ if attn.norm_cross:
145
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
+
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+
150
+ query = attn.head_to_batch_dim(query)
151
+ key = attn.head_to_batch_dim(key)
152
+ value = attn.head_to_batch_dim(value)
153
+
154
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
+ hidden_states = torch.bmm(attention_probs, value)
156
+ hidden_states = attn.batch_to_head_dim(hidden_states)
157
+
158
+ if not self.skip:
159
+ # for ip-adapter
160
+ ip_key = self.to_k_ip(ip_hidden_states)
161
+ ip_value = self.to_v_ip(ip_hidden_states)
162
+
163
+ ip_key = attn.head_to_batch_dim(ip_key)
164
+ ip_value = attn.head_to_batch_dim(ip_value)
165
+
166
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
167
+ self.attn_map = ip_attention_probs
168
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
+
171
+ hidden_states = hidden_states + self.scale * ip_hidden_states
172
+
173
+ # linear proj
174
+ hidden_states = attn.to_out[0](hidden_states)
175
+ # dropout
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ if input_ndim == 4:
179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
+
181
+ if attn.residual_connection:
182
+ hidden_states = hidden_states + residual
183
+
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+
186
+ return hidden_states
187
+
188
+
189
+ class AttnProcessor2_0(torch.nn.Module):
190
+ r"""
191
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size=None,
197
+ cross_attention_dim=None,
198
+ ):
199
+ super().__init__()
200
+ if not hasattr(F, "scaled_dot_product_attention"):
201
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
202
+
203
+ def __call__(
204
+ self,
205
+ attn,
206
+ hidden_states,
207
+ encoder_hidden_states=None,
208
+ attention_mask=None,
209
+ temb=None,
210
+ ):
211
+ residual = hidden_states
212
+
213
+ if attn.spatial_norm is not None:
214
+ hidden_states = attn.spatial_norm(hidden_states, temb)
215
+
216
+ input_ndim = hidden_states.ndim
217
+
218
+ if input_ndim == 4:
219
+ batch_size, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
+
222
+ batch_size, sequence_length, _ = (
223
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
224
+ )
225
+
226
+ if attention_mask is not None:
227
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
228
+ # scaled_dot_product_attention expects attention_mask shape to be
229
+ # (batch, heads, source_length, target_length)
230
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
231
+
232
+ if attn.group_norm is not None:
233
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
+
242
+ key = attn.to_k(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states)
244
+
245
+ inner_dim = key.shape[-1]
246
+ head_dim = inner_dim // attn.heads
247
+
248
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
254
+ # TODO: add support for attn.scale when we move to Torch 2.1
255
+ hidden_states = F.scaled_dot_product_attention(
256
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
257
+ )
258
+
259
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
260
+ hidden_states = hidden_states.to(query.dtype)
261
+
262
+ # linear proj
263
+ hidden_states = attn.to_out[0](hidden_states)
264
+ # dropout
265
+ hidden_states = attn.to_out[1](hidden_states)
266
+
267
+ if input_ndim == 4:
268
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
269
+
270
+ if attn.residual_connection:
271
+ hidden_states = hidden_states + residual
272
+
273
+ hidden_states = hidden_states / attn.rescale_output_factor
274
+
275
+ return hidden_states
276
+
277
+
278
+ class IPAttnProcessor2_0(torch.nn.Module):
279
+ r"""
280
+ Attention processor for IP-Adapter for PyTorch 2.0.
281
+ Args:
282
+ hidden_size (`int`):
283
+ The hidden size of the attention layer.
284
+ cross_attention_dim (`int`):
285
+ The number of channels in the `encoder_hidden_states`.
286
+ scale (`float`, defaults to 1.0):
287
+ the weight scale of image prompt.
288
+ num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
289
+ The context length of the image features.
290
+ """
291
+
292
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
293
+ super().__init__()
294
+
295
+ if not hasattr(F, "scaled_dot_product_attention"):
296
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
297
+
298
+ self.hidden_size = hidden_size
299
+ self.cross_attention_dim = cross_attention_dim
300
+ self.scale = scale
301
+ self.num_tokens = num_tokens
302
+ self.skip = skip
303
+
304
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
305
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
306
+
307
+ def __call__(
308
+ self,
309
+ attn,
310
+ hidden_states,
311
+ encoder_hidden_states=None,
312
+ attention_mask=None,
313
+ temb=None,
314
+ ):
315
+ residual = hidden_states
316
+
317
+ if attn.spatial_norm is not None:
318
+ hidden_states = attn.spatial_norm(hidden_states, temb)
319
+
320
+ input_ndim = hidden_states.ndim
321
+
322
+ if input_ndim == 4:
323
+ batch_size, channel, height, width = hidden_states.shape
324
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
325
+
326
+ batch_size, sequence_length, _ = (
327
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
328
+ )
329
+
330
+ if attention_mask is not None:
331
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
332
+ # scaled_dot_product_attention expects attention_mask shape to be
333
+ # (batch, heads, source_length, target_length)
334
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
335
+
336
+ if attn.group_norm is not None:
337
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
338
+
339
+ query = attn.to_q(hidden_states)
340
+
341
+ if encoder_hidden_states is None:
342
+ encoder_hidden_states = hidden_states
343
+ else:
344
+ # get encoder_hidden_states, ip_hidden_states
345
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
346
+ encoder_hidden_states, ip_hidden_states = (
347
+ encoder_hidden_states[:, :end_pos, :],
348
+ encoder_hidden_states[:, end_pos:, :],
349
+ )
350
+ if attn.norm_cross:
351
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
352
+
353
+ key = attn.to_k(encoder_hidden_states)
354
+ value = attn.to_v(encoder_hidden_states)
355
+
356
+ inner_dim = key.shape[-1]
357
+ head_dim = inner_dim // attn.heads
358
+
359
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
+
361
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
+
364
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
365
+ # TODO: add support for attn.scale when we move to Torch 2.1
366
+ hidden_states = F.scaled_dot_product_attention(
367
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
368
+ )
369
+
370
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
371
+ hidden_states = hidden_states.to(query.dtype)
372
+
373
+ if not self.skip:
374
+ # for ip-adapter
375
+ ip_key = self.to_k_ip(ip_hidden_states)
376
+ ip_value = self.to_v_ip(ip_hidden_states)
377
+
378
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+
381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
+ # TODO: add support for attn.scale when we move to Torch 2.1
383
+ ip_hidden_states = F.scaled_dot_product_attention(
384
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
385
+ )
386
+ with torch.no_grad():
387
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
388
+ #print(self.attn_map.shape)
389
+
390
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
391
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
392
+
393
+ hidden_states = hidden_states + self.scale * ip_hidden_states
394
+
395
+ # linear proj
396
+ hidden_states = attn.to_out[0](hidden_states)
397
+ # dropout
398
+ hidden_states = attn.to_out[1](hidden_states)
399
+
400
+ if input_ndim == 4:
401
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
402
+
403
+ if attn.residual_connection:
404
+ hidden_states = hidden_states + residual
405
+
406
+ hidden_states = hidden_states / attn.rescale_output_factor
407
+
408
+ return hidden_states
409
+
410
+
411
+ ## for controlnet
412
+ class CNAttnProcessor:
413
+ r"""
414
+ Default processor for performing attention-related computations.
415
+ """
416
+
417
+ def __init__(self, num_tokens=4):
418
+ self.num_tokens = num_tokens
419
+
420
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
421
+ residual = hidden_states
422
+
423
+ if attn.spatial_norm is not None:
424
+ hidden_states = attn.spatial_norm(hidden_states, temb)
425
+
426
+ input_ndim = hidden_states.ndim
427
+
428
+ if input_ndim == 4:
429
+ batch_size, channel, height, width = hidden_states.shape
430
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
431
+
432
+ batch_size, sequence_length, _ = (
433
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
434
+ )
435
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
436
+
437
+ if attn.group_norm is not None:
438
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
439
+
440
+ query = attn.to_q(hidden_states)
441
+
442
+ if encoder_hidden_states is None:
443
+ encoder_hidden_states = hidden_states
444
+ else:
445
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
446
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
447
+ if attn.norm_cross:
448
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
449
+
450
+ key = attn.to_k(encoder_hidden_states)
451
+ value = attn.to_v(encoder_hidden_states)
452
+
453
+ query = attn.head_to_batch_dim(query)
454
+ key = attn.head_to_batch_dim(key)
455
+ value = attn.head_to_batch_dim(value)
456
+
457
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
458
+ hidden_states = torch.bmm(attention_probs, value)
459
+ hidden_states = attn.batch_to_head_dim(hidden_states)
460
+
461
+ # linear proj
462
+ hidden_states = attn.to_out[0](hidden_states)
463
+ # dropout
464
+ hidden_states = attn.to_out[1](hidden_states)
465
+
466
+ if input_ndim == 4:
467
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
468
+
469
+ if attn.residual_connection:
470
+ hidden_states = hidden_states + residual
471
+
472
+ hidden_states = hidden_states / attn.rescale_output_factor
473
+
474
+ return hidden_states
475
+
476
+
477
+ class CNAttnProcessor2_0:
478
+ r"""
479
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
480
+ """
481
+
482
+ def __init__(self, num_tokens=4):
483
+ if not hasattr(F, "scaled_dot_product_attention"):
484
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
485
+ self.num_tokens = num_tokens
486
+
487
+ def __call__(
488
+ self,
489
+ attn,
490
+ hidden_states,
491
+ encoder_hidden_states=None,
492
+ attention_mask=None,
493
+ temb=None,
494
+ ):
495
+ residual = hidden_states
496
+
497
+ if attn.spatial_norm is not None:
498
+ hidden_states = attn.spatial_norm(hidden_states, temb)
499
+
500
+ input_ndim = hidden_states.ndim
501
+
502
+ if input_ndim == 4:
503
+ batch_size, channel, height, width = hidden_states.shape
504
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
+ )
509
+
510
+ if attention_mask is not None:
511
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
512
+ # scaled_dot_product_attention expects attention_mask shape to be
513
+ # (batch, heads, source_length, target_length)
514
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
515
+
516
+ if attn.group_norm is not None:
517
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
518
+
519
+ query = attn.to_q(hidden_states)
520
+
521
+ if encoder_hidden_states is None:
522
+ encoder_hidden_states = hidden_states
523
+ else:
524
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
525
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
526
+ if attn.norm_cross:
527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ inner_dim = key.shape[-1]
533
+ head_dim = inner_dim // attn.heads
534
+
535
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
536
+
537
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
538
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
+
540
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
541
+ # TODO: add support for attn.scale when we move to Torch 2.1
542
+ hidden_states = F.scaled_dot_product_attention(
543
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
544
+ )
545
+
546
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
547
+ hidden_states = hidden_states.to(query.dtype)
548
+
549
+ # linear proj
550
+ hidden_states = attn.to_out[0](hidden_states)
551
+ # dropout
552
+ hidden_states = attn.to_out[1](hidden_states)
553
+
554
+ if input_ndim == 4:
555
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
556
+
557
+ if attn.residual_connection:
558
+ hidden_states = hidden_states + residual
559
+
560
+ hidden_states = hidden_states / attn.rescale_output_factor
561
+
562
+ return hidden_states
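
The processors above implement decoupled cross-attention: the last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens, attended through the separate `to_k_ip`/`to_v_ip` projections, and the result is added to the text branch scaled by `scale`. The snippet below is an illustrative sketch of that combination step only (head reshaping and the linear projections are omitted, and all shapes are made-up assumptions):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: latent queries attending to 77 text tokens plus
# num_tokens = 4 image-prompt tokens appended by the IP-Adapter.
num_tokens, dim, scale = 4, 64, 1.0
query = torch.randn(1, 1024, dim)
encoder_hidden_states = torch.randn(1, 77 + num_tokens, dim)

# Split off the image-prompt tokens, as IPAttnProcessor2_0.__call__ does.
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_tokens = encoder_hidden_states[:, :end_pos, :]
ip_tokens = encoder_hidden_states[:, end_pos:, :]

# In the real processor the keys/values come from attn.to_k / attn.to_v for
# the text branch and from to_k_ip / to_v_ip for the image branch; the raw
# tokens stand in for both projections here.
text_out = F.scaled_dot_product_attention(query, text_tokens, text_tokens)
ip_out = F.scaled_dot_product_attention(query, ip_tokens, ip_tokens)

# Decoupled cross-attention: text result plus the scaled image-prompt result.
hidden_states = text_out + scale * ip_out
print(hidden_states.shape)  # torch.Size([1, 1024, 64])
```

The `skip` flag short-circuits the image branch, which is how the `target_blocks` selection in `ip_adapter.py` below restricts the image prompt to the chosen attention layers.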
ip_adapter_instantstyle/ip_adapter.py ADDED
@@ -0,0 +1,858 @@
1
+ import os
2
+ import glob
3
+ from typing import List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from diffusers import StableDiffusionPipeline
8
+ from diffusers.pipelines.controlnet import MultiControlNetModel
9
+ from PIL import Image
10
+ from safetensors import safe_open
11
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
12
+
13
+ from .utils import is_torch2_available, get_generator
14
+
15
+ L = 4
16
+
17
+
18
+ def pos_encode(x, L):
19
+ pos_encode = []
20
+
21
+ for freq in range(L):
22
+ pos_encode.append(torch.cos(2**freq * torch.pi * x))
23
+ pos_encode.append(torch.sin(2**freq * torch.pi * x))
24
+ pos_encode = torch.cat(pos_encode, dim=1)
25
+ return pos_encode
26
+
27
+
28
+ if is_torch2_available():
29
+ from .attention_processor import (
30
+ AttnProcessor2_0 as AttnProcessor,
31
+ )
32
+ from .attention_processor import (
33
+ CNAttnProcessor2_0 as CNAttnProcessor,
34
+ )
35
+ from .attention_processor import (
36
+ IPAttnProcessor2_0 as IPAttnProcessor,
37
+ )
38
+ else:
39
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
40
+ from .resampler import Resampler
41
+
42
+
43
+ class ImageProjModel(torch.nn.Module):
44
+ """Projection Model"""
45
+
46
+ def __init__(
47
+ self,
48
+ cross_attention_dim=1024,
49
+ clip_embeddings_dim=1024,
50
+ clip_extra_context_tokens=4,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.generator = None
55
+ self.cross_attention_dim = cross_attention_dim
56
+ self.clip_extra_context_tokens = clip_extra_context_tokens
57
+ self.proj = torch.nn.Linear(
58
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
59
+ )
60
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
61
+
62
+ def forward(self, image_embeds):
63
+ embeds = image_embeds
64
+ clip_extra_context_tokens = self.proj(embeds).reshape(
65
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
66
+ )
67
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
68
+ return clip_extra_context_tokens
69
+
70
+
71
+ class MLPProjModel(torch.nn.Module):
72
+ """SD model with image prompt"""
73
+
74
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
75
+ super().__init__()
76
+
77
+ self.proj = torch.nn.Sequential(
78
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
79
+ torch.nn.GELU(),
80
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
81
+ torch.nn.LayerNorm(cross_attention_dim),
82
+ )
83
+
84
+ def forward(self, image_embeds):
85
+ clip_extra_context_tokens = self.proj(image_embeds)
86
+ return clip_extra_context_tokens
87
+
88
+
89
+ class IPAdapter:
90
+ def __init__(
91
+ self,
92
+ sd_pipe,
93
+ image_encoder_path,
94
+ ip_ckpt,
95
+ device,
96
+ num_tokens=4,
97
+ target_blocks=["block"],
98
+ ):
99
+ self.device = device
100
+ self.image_encoder_path = image_encoder_path
101
+ self.ip_ckpt = ip_ckpt
102
+ self.num_tokens = num_tokens
103
+ self.target_blocks = target_blocks
104
+
105
+ self.pipe = sd_pipe.to(self.device)
106
+ self.set_ip_adapter()
107
+
108
+ # load image encoder
109
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
110
+ self.image_encoder_path
111
+ ).to(self.device, dtype=torch.float16)
112
+ self.clip_image_processor = CLIPImageProcessor()
113
+ # image proj model
114
+ self.image_proj_model = self.init_proj()
115
+
116
+ self.load_ip_adapter()
117
+
118
+ def init_proj(self):
119
+ image_proj_model = ImageProjModel(
120
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
121
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
122
+ clip_extra_context_tokens=self.num_tokens,
123
+ ).to(self.device, dtype=torch.float16)
124
+ return image_proj_model
125
+
126
+ def set_ip_adapter(self):
127
+ unet = self.pipe.unet
128
+ attn_procs = {}
129
+ for name in unet.attn_processors.keys():
130
+ cross_attention_dim = (
131
+ None
132
+ if name.endswith("attn1.processor")
133
+ else unet.config.cross_attention_dim
134
+ )
135
+ if name.startswith("mid_block"):
136
+ hidden_size = unet.config.block_out_channels[-1]
137
+ elif name.startswith("up_blocks"):
138
+ block_id = int(name[len("up_blocks.")])
139
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
140
+ elif name.startswith("down_blocks"):
141
+ block_id = int(name[len("down_blocks.")])
142
+ hidden_size = unet.config.block_out_channels[block_id]
143
+ if cross_attention_dim is None:
144
+ attn_procs[name] = AttnProcessor()
145
+ else:
146
+ selected = False
147
+ for block_name in self.target_blocks:
148
+ if block_name in name:
149
+ selected = True
150
+ break
151
+ if selected:
152
+ attn_procs[name] = IPAttnProcessor(
153
+ hidden_size=hidden_size,
154
+ cross_attention_dim=cross_attention_dim,
155
+ scale=1.0,
156
+ num_tokens=self.num_tokens,
157
+ ).to(self.device, dtype=torch.float16)
158
+ else:
159
+ attn_procs[name] = IPAttnProcessor(
160
+ hidden_size=hidden_size,
161
+ cross_attention_dim=cross_attention_dim,
162
+ scale=1.0,
163
+ num_tokens=self.num_tokens,
164
+ skip=True,
165
+ ).to(self.device, dtype=torch.float16)
166
+ unet.set_attn_processor(attn_procs)
167
+ if hasattr(self.pipe, "controlnet"):
168
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
169
+ for controlnet in self.pipe.controlnet.nets:
170
+ controlnet.set_attn_processor(
171
+ CNAttnProcessor(num_tokens=self.num_tokens)
172
+ )
173
+ else:
174
+ self.pipe.controlnet.set_attn_processor(
175
+ CNAttnProcessor(num_tokens=self.num_tokens)
176
+ )
177
+
178
+ def load_ip_adapter(self):
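+ # The checkpoint is either a .safetensors file or a torch .bin; both provide "image_proj" and "ip_adapter" weight groups.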
179
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
180
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
181
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
182
+ for key in f.keys():
183
+ if key.startswith("image_proj."):
184
+ state_dict["image_proj"][key.replace("image_proj.", "")] = (
185
+ f.get_tensor(key)
186
+ )
187
+ elif key.startswith("ip_adapter."):
188
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = (
189
+ f.get_tensor(key)
190
+ )
191
+ else:
192
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
193
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
194
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
195
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
196
+
197
+ @torch.inference_mode()
198
+ def get_image_embeds(
199
+ self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None
200
+ ):
201
+ if pil_image is not None:
202
+ if isinstance(pil_image, Image.Image):
203
+ pil_image = [pil_image]
204
+ clip_image = self.clip_image_processor(
205
+ images=pil_image, return_tensors="pt"
206
+ ).pixel_values
207
+ clip_image_embeds = self.image_encoder(
208
+ clip_image.to(self.device, dtype=torch.float16)
209
+ ).image_embeds
210
+ else:
211
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
212
+
213
+ if content_prompt_embeds is not None:
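+ # Optionally subtract a content embedding so the remaining image embedding de-emphasizes that content (InstantStyle-style content removal).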
214
+ print(clip_image_embeds.shape)
215
+ print(content_prompt_embeds.shape)
216
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
217
+
218
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
219
+ uncond_image_prompt_embeds = self.image_proj_model(
220
+ torch.zeros_like(clip_image_embeds)
221
+ )
222
+ return image_prompt_embeds, uncond_image_prompt_embeds
223
+
224
+ @torch.inference_mode()
225
+ def generate_image_edit_dir(
226
+ self,
227
+ pil_image=None,
228
+ content_prompt_embeds=None,
229
+ edit_mlps: dict[torch.nn.Module, float] = None,
230
+ ):
231
+ print("Combining multiple MLPs!")
232
+ if pil_image is not None:
233
+ if isinstance(pil_image, Image.Image):
234
+ pil_image = [pil_image]
235
+ clip_image = self.clip_image_processor(
236
+ images=pil_image, return_tensors="pt"
237
+ ).pixel_values
238
+ clip_image_embeds = self.image_encoder(
239
+ clip_image.to(self.device, dtype=torch.float16)
240
+ ).image_embeds
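+ # Each control MLP maps the CLIP image embedding plus a scalar strength to an edit direction; the directions are summed and added to the embedding below.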
241
+ pred_editing_dirs = [
242
+ net(
243
+ clip_image_embeds,
244
+ torch.Tensor([strength]).to(self.device, dtype=torch.float16),
245
+ )
246
+ for net, strength in edit_mlps.items()
247
+ ]
248
+
249
+ clip_image_embeds = clip_image_embeds + sum(pred_editing_dirs)
250
+
251
+ if content_prompt_embeds is not None:
252
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
253
+
254
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
255
+ uncond_image_prompt_embeds = self.image_proj_model(
256
+ torch.zeros_like(clip_image_embeds)
257
+ )
258
+ return image_prompt_embeds, uncond_image_prompt_embeds
259
+
260
+ @torch.inference_mode()
261
+ def get_image_edit_dir(
262
+ self,
263
+ start_image=None,
264
+ pil_image=None,
265
+ pil_image2=None,
266
+ content_prompt_embeds=None,
267
+ edit_strength=1.0,
268
+ ):
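+ # start_image, pil_image and pil_image2 are all expected despite the None defaults; the blend below uses all three embeddings.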
269
+ print("Blending Two Materials!")
270
+ if pil_image is not None:
271
+ if isinstance(pil_image, Image.Image):
272
+ pil_image = [pil_image]
273
+ clip_image = self.clip_image_processor(
274
+ images=pil_image, return_tensors="pt"
275
+ ).pixel_values
276
+ clip_image_embeds = self.image_encoder(
277
+ clip_image.to(self.device, dtype=torch.float16)
278
+ ).image_embeds
279
+
280
+ if pil_image2 is not None:
281
+ if isinstance(pil_image2, Image.Image):
282
+ pil_image2 = [pil_image2]
283
+ clip_image2 = self.clip_image_processor(
284
+ images=pil_image2, return_tensors="pt"
285
+ ).pixel_values
286
+ clip_image_embeds2 = self.image_encoder(
287
+ clip_image2.to(self.device, dtype=torch.float16)
288
+ ).image_embeds
289
+
290
+ if start_image is not None:
291
+ if isinstance(start_image, Image.Image):
292
+ start_image = [start_image]
293
+ clip_image_start = self.clip_image_processor(
294
+ images=start_image, return_tensors="pt"
295
+ ).pixel_values
296
+ clip_image_embeds_start = self.image_encoder(
297
+ clip_image_start.to(self.device, dtype=torch.float16)
298
+ ).image_embeds
299
+
300
+ if content_prompt_embeds is not None:
301
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
302
+ clip_image_embeds2 = clip_image_embeds2 - content_prompt_embeds
303
+
304
+ # clip_image_embeds += edit_strength * (clip_image_embeds2 - clip_image_embeds)
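+ # Blend in CLIP space: start from the start_image embedding and move edit_strength of the way along the direction from material 1 (pil_image) to material 2 (pil_image2).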
305
+ clip_image_embeds = clip_image_embeds_start + edit_strength * (
306
+ clip_image_embeds2 - clip_image_embeds
307
+ )
308
+
309
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
310
+ uncond_image_prompt_embeds = self.image_proj_model(
311
+ torch.zeros_like(clip_image_embeds)
312
+ )
313
+ return image_prompt_embeds, uncond_image_prompt_embeds
314
+
315
+ def set_scale(self, scale):
316
+ for attn_processor in self.pipe.unet.attn_processors.values():
317
+ if isinstance(attn_processor, IPAttnProcessor):
318
+ attn_processor.scale = scale
319
+
325
+ def generate(
326
+ self,
327
+ pil_image=None,
328
+ clip_image_embeds=None,
329
+ prompt=None,
330
+ negative_prompt=None,
331
+ scale=1.0,
332
+ num_samples=4,
333
+ seed=None,
334
+ guidance_scale=7.5,
335
+ num_inference_steps=30,
336
+ neg_content_emb=None,
337
+ **kwargs,
338
+ ):
339
+ self.set_scale(scale)
340
+
341
+ if pil_image is not None:
342
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
343
+ else:
344
+ num_prompts = clip_image_embeds.size(0)
345
+
346
+ if prompt is None:
347
+ prompt = "best quality, high quality"
348
+ if negative_prompt is None:
349
+ negative_prompt = (
350
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
351
+ )
352
+
353
+ if not isinstance(prompt, List):
354
+ prompt = [prompt] * num_prompts
355
+ if not isinstance(negative_prompt, List):
356
+ negative_prompt = [negative_prompt] * num_prompts
357
+
358
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
359
+ pil_image=pil_image,
360
+ clip_image_embeds=clip_image_embeds,
361
+ content_prompt_embeds=neg_content_emb,
362
+ )
363
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
364
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
365
+ image_prompt_embeds = image_prompt_embeds.view(
366
+ bs_embed * num_samples, seq_len, -1
367
+ )
368
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
369
+ 1, num_samples, 1
370
+ )
371
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
372
+ bs_embed * num_samples, seq_len, -1
373
+ )
374
+
375
+ with torch.inference_mode():
376
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
377
+ prompt,
378
+ device=self.device,
379
+ num_images_per_prompt=num_samples,
380
+ do_classifier_free_guidance=True,
381
+ negative_prompt=negative_prompt,
382
+ )
383
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
384
+ negative_prompt_embeds = torch.cat(
385
+ [negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
386
+ )
387
+
388
+ generator = get_generator(seed, self.device)
389
+
390
+ images = self.pipe(
391
+ prompt_embeds=prompt_embeds,
392
+ negative_prompt_embeds=negative_prompt_embeds,
393
+ guidance_scale=guidance_scale,
394
+ num_inference_steps=num_inference_steps,
395
+ generator=generator,
396
+ **kwargs,
397
+ ).images
398
+
399
+ return images
400
+
401
+
402
+ class IPAdapterXL(IPAdapter):
403
+ """SDXL"""
404
+
405
+ def generate(
406
+ self,
407
+ pil_image,
408
+ prompt=None,
409
+ negative_prompt=None,
410
+ scale=1.0,
411
+ num_samples=4,
412
+ seed=None,
413
+ num_inference_steps=30,
414
+ neg_content_emb=None,
415
+ neg_content_prompt=None,
416
+ neg_content_scale=1.0,
417
+ clip_strength=1.0,
418
+ **kwargs,
419
+ ):
420
+ self.set_scale(scale)
421
+
422
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
423
+
424
+ if prompt is None:
425
+ prompt = "best quality, high quality"
426
+ if negative_prompt is None:
427
+ negative_prompt = (
428
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
429
+ )
430
+
431
+ if not isinstance(prompt, List):
432
+ prompt = [prompt] * num_prompts
433
+ if not isinstance(negative_prompt, List):
434
+ negative_prompt = [negative_prompt] * num_prompts
435
+
436
+ if neg_content_emb is None:
437
+ if neg_content_prompt is not None:
438
+ with torch.inference_mode():
439
+ (
440
+ prompt_embeds_, # torch.Size([1, 77, 2048])
441
+ negative_prompt_embeds_,
442
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
443
+ negative_pooled_prompt_embeds_,
444
+ ) = self.pipe.encode_prompt(
445
+ neg_content_prompt,
446
+ num_images_per_prompt=num_samples,
447
+ do_classifier_free_guidance=True,
448
+ negative_prompt=negative_prompt,
449
+ )
450
+ pooled_prompt_embeds_ *= neg_content_scale
451
+ else:
452
+ pooled_prompt_embeds_ = neg_content_emb
453
+ else:
454
+ pooled_prompt_embeds_ = None
455
+
456
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
457
+ pil_image, content_prompt_embeds=pooled_prompt_embeds_
458
+ )
459
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
460
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
461
+ image_prompt_embeds = image_prompt_embeds.view(
462
+ bs_embed * num_samples, seq_len, -1
463
+ )
464
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
465
+ 1, num_samples, 1
466
+ )
467
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
468
+ bs_embed * num_samples, seq_len, -1
469
+ )
470
+ print("CLIP Strength is {}".format(clip_strength))
471
+ image_prompt_embeds *= clip_strength
472
+ uncond_image_prompt_embeds *= clip_strength
473
+
474
+ with torch.inference_mode():
475
+ (
476
+ prompt_embeds,
477
+ negative_prompt_embeds,
478
+ pooled_prompt_embeds,
479
+ negative_pooled_prompt_embeds,
480
+ ) = self.pipe.encode_prompt(
481
+ prompt,
482
+ num_images_per_prompt=num_samples,
483
+ do_classifier_free_guidance=True,
484
+ negative_prompt=negative_prompt,
485
+ )
486
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
487
+ negative_prompt_embeds = torch.cat(
488
+ [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
489
+ )
490
+
491
+ self.generator = get_generator(seed, self.device)
492
+
493
+ images = self.pipe(
494
+ prompt_embeds=prompt_embeds,
495
+ negative_prompt_embeds=negative_prompt_embeds,
496
+ pooled_prompt_embeds=pooled_prompt_embeds,
497
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
498
+ num_inference_steps=num_inference_steps,
499
+ generator=self.generator,
500
+ **kwargs,
501
+ ).images
502
+
503
+ return images
504
+
505
+ def generate_parametric_edits(
506
+ self,
507
+ pil_image,
508
+ edit_mlps: dict[torch.nn.Module, float],
509
+ prompt=None,
510
+ negative_prompt=None,
511
+ scale=1.0,
512
+ num_samples=4,
513
+ seed=None,
514
+ num_inference_steps=30,
515
+ neg_content_emb=None,
516
+ neg_content_prompt=None,
517
+ neg_content_scale=1.0,
518
+ **kwargs,
519
+ ):
520
+ self.set_scale(scale)
521
+
522
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
523
+
524
+ if prompt is None:
525
+ prompt = "best quality, high quality"
526
+ if negative_prompt is None:
527
+ negative_prompt = (
528
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
529
+ )
530
+
531
+ if not isinstance(prompt, List):
532
+ prompt = [prompt] * num_prompts
533
+ if not isinstance(negative_prompt, List):
534
+ negative_prompt = [negative_prompt] * num_prompts
535
+
536
+ if neg_content_emb is None:
537
+ if neg_content_prompt is not None:
538
+ with torch.inference_mode():
539
+ (
540
+ prompt_embeds_, # torch.Size([1, 77, 2048])
541
+ negative_prompt_embeds_,
542
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
543
+ negative_pooled_prompt_embeds_,
544
+ ) = self.pipe.encode_prompt(
545
+ neg_content_prompt,
546
+ num_images_per_prompt=num_samples,
547
+ do_classifier_free_guidance=True,
548
+ negative_prompt=negative_prompt,
549
+ )
550
+ pooled_prompt_embeds_ *= neg_content_scale
551
+ else:
552
+ pooled_prompt_embeds_ = neg_content_emb
553
+ else:
554
+ pooled_prompt_embeds_ = None
555
+ image_prompt_embeds, uncond_image_prompt_embeds = self.generate_image_edit_dir(
556
+ pil_image, content_prompt_embeds=pooled_prompt_embeds_, edit_mlps=edit_mlps
557
+ )
558
+
559
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
560
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
561
+ image_prompt_embeds = image_prompt_embeds.view(
562
+ bs_embed * num_samples, seq_len, -1
563
+ )
564
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
565
+ 1, num_samples, 1
566
+ )
567
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
568
+ bs_embed * num_samples, seq_len, -1
569
+ )
570
+
571
+ with torch.inference_mode():
572
+ (
573
+ prompt_embeds,
574
+ negative_prompt_embeds,
575
+ pooled_prompt_embeds,
576
+ negative_pooled_prompt_embeds,
577
+ ) = self.pipe.encode_prompt(
578
+ prompt,
579
+ num_images_per_prompt=num_samples,
580
+ do_classifier_free_guidance=True,
581
+ negative_prompt=negative_prompt,
582
+ )
583
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
584
+ negative_prompt_embeds = torch.cat(
585
+ [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
586
+ )
587
+
588
+ self.generator = get_generator(seed, self.device)
589
+
590
+ images = self.pipe(
591
+ prompt_embeds=prompt_embeds,
592
+ negative_prompt_embeds=negative_prompt_embeds,
593
+ pooled_prompt_embeds=pooled_prompt_embeds,
594
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
595
+ num_inference_steps=num_inference_steps,
596
+ generator=self.generator,
597
+ **kwargs,
598
+ ).images
599
+
600
+ return images
601
+
602
+ def generate_edit(
603
+ self,
604
+ start_image,
605
+ pil_image,
606
+ pil_image2,
607
+ prompt=None,
608
+ negative_prompt=None,
609
+ scale=1.0,
610
+ num_samples=4,
611
+ seed=None,
612
+ num_inference_steps=30,
613
+ neg_content_emb=None,
614
+ neg_content_prompt=None,
615
+ neg_content_scale=1.0,
616
+ edit_strength=1.0,
617
+ **kwargs,
618
+ ):
619
+ self.set_scale(scale)
620
+
621
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
622
+
623
+ if prompt is None:
624
+ prompt = "best quality, high quality"
625
+ if negative_prompt is None:
626
+ negative_prompt = (
627
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
628
+ )
629
+
630
+ if not isinstance(prompt, List):
631
+ prompt = [prompt] * num_prompts
632
+ if not isinstance(negative_prompt, List):
633
+ negative_prompt = [negative_prompt] * num_prompts
634
+
635
+ if neg_content_emb is None:
636
+ if neg_content_prompt is not None:
637
+ with torch.inference_mode():
638
+ (
639
+ prompt_embeds_, # torch.Size([1, 77, 2048])
640
+ negative_prompt_embeds_,
641
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
642
+ negative_pooled_prompt_embeds_,
643
+ ) = self.pipe.encode_prompt(
644
+ neg_content_prompt,
645
+ num_images_per_prompt=num_samples,
646
+ do_classifier_free_guidance=True,
647
+ negative_prompt=negative_prompt,
648
+ )
649
+ pooled_prompt_embeds_ *= neg_content_scale
650
+ else:
651
+ pooled_prompt_embeds_ = neg_content_emb
652
+ else:
653
+ pooled_prompt_embeds_ = None
654
+
655
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_edit_dir(
656
+ start_image,
657
+ pil_image,
658
+ pil_image2,
659
+ content_prompt_embeds=pooled_prompt_embeds_,
660
+ edit_strength=edit_strength,
661
+ )
662
+
663
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
664
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
665
+ image_prompt_embeds = image_prompt_embeds.view(
666
+ bs_embed * num_samples, seq_len, -1
667
+ )
668
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
669
+ 1, num_samples, 1
670
+ )
671
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
672
+ bs_embed * num_samples, seq_len, -1
673
+ )
674
+
675
+ with torch.inference_mode():
676
+ (
677
+ prompt_embeds,
678
+ negative_prompt_embeds,
679
+ pooled_prompt_embeds,
680
+ negative_pooled_prompt_embeds,
681
+ ) = self.pipe.encode_prompt(
682
+ prompt,
683
+ num_images_per_prompt=num_samples,
684
+ do_classifier_free_guidance=True,
685
+ negative_prompt=negative_prompt,
686
+ )
687
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
688
+ negative_prompt_embeds = torch.cat(
689
+ [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
690
+ )
691
+
692
+ self.generator = get_generator(seed, self.device)
693
+
694
+ images = self.pipe(
695
+ prompt_embeds=prompt_embeds,
696
+ negative_prompt_embeds=negative_prompt_embeds,
697
+ pooled_prompt_embeds=pooled_prompt_embeds,
698
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
699
+ num_inference_steps=num_inference_steps,
700
+ generator=self.generator,
701
+ **kwargs,
702
+ ).images
703
+
704
+ return images
705
+
706
+
707
+ class IPAdapterPlus(IPAdapter):
708
+ """IP-Adapter with fine-grained features"""
709
+
710
+ def init_proj(self):
711
+ image_proj_model = Resampler(
712
+ dim=self.pipe.unet.config.cross_attention_dim,
713
+ depth=4,
714
+ dim_head=64,
715
+ heads=12,
716
+ num_queries=self.num_tokens,
717
+ embedding_dim=self.image_encoder.config.hidden_size,
718
+ output_dim=self.pipe.unet.config.cross_attention_dim,
719
+ ff_mult=4,
720
+ ).to(self.device, dtype=torch.float16)
721
+ return image_proj_model
722
+
723
+ @torch.inference_mode()
724
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
725
+ if isinstance(pil_image, Image.Image):
726
+ pil_image = [pil_image]
727
+ clip_image = self.clip_image_processor(
728
+ images=pil_image, return_tensors="pt"
729
+ ).pixel_values
730
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
731
+ clip_image_embeds = self.image_encoder(
732
+ clip_image, output_hidden_states=True
733
+ ).hidden_states[-2]
734
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
735
+ uncond_clip_image_embeds = self.image_encoder(
736
+ torch.zeros_like(clip_image), output_hidden_states=True
737
+ ).hidden_states[-2]
738
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
739
+ return image_prompt_embeds, uncond_image_prompt_embeds
740
+
741
+
742
+ class IPAdapterFull(IPAdapterPlus):
743
+ """IP-Adapter with full features"""
744
+
745
+ def init_proj(self):
746
+ image_proj_model = MLPProjModel(
747
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
748
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
749
+ ).to(self.device, dtype=torch.float16)
750
+ return image_proj_model
751
+
752
+
753
+ class IPAdapterPlusXL(IPAdapter):
754
+ """SDXL"""
755
+
756
+ def init_proj(self):
757
+ image_proj_model = Resampler(
758
+ dim=1280,
759
+ depth=4,
760
+ dim_head=64,
761
+ heads=20,
762
+ num_queries=self.num_tokens,
763
+ embedding_dim=self.image_encoder.config.hidden_size,
764
+ output_dim=self.pipe.unet.config.cross_attention_dim,
765
+ ff_mult=4,
766
+ ).to(self.device, dtype=torch.float16)
767
+ return image_proj_model
768
+
769
+ @torch.inference_mode()
770
+ def get_image_embeds(self, pil_image):
771
+ if isinstance(pil_image, Image.Image):
772
+ pil_image = [pil_image]
773
+ clip_image = self.clip_image_processor(
774
+ images=pil_image, return_tensors="pt"
775
+ ).pixel_values
776
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
777
+ clip_image_embeds = self.image_encoder(
778
+ clip_image, output_hidden_states=True
779
+ ).hidden_states[-2]
780
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
781
+ uncond_clip_image_embeds = self.image_encoder(
782
+ torch.zeros_like(clip_image), output_hidden_states=True
783
+ ).hidden_states[-2]
784
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
785
+ return image_prompt_embeds, uncond_image_prompt_embeds
786
+
787
+ def generate(
788
+ self,
789
+ pil_image,
790
+ prompt=None,
791
+ negative_prompt=None,
792
+ scale=1.0,
793
+ num_samples=4,
794
+ seed=None,
795
+ num_inference_steps=30,
796
+ **kwargs,
797
+ ):
798
+ self.set_scale(scale)
799
+
800
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
801
+
802
+ if prompt is None:
803
+ prompt = "best quality, high quality"
804
+ if negative_prompt is None:
805
+ negative_prompt = (
806
+ "monochrome, lowres, bad anatomy, worst quality, low quality"
807
+ )
808
+
809
+ if not isinstance(prompt, List):
810
+ prompt = [prompt] * num_prompts
811
+ if not isinstance(negative_prompt, List):
812
+ negative_prompt = [negative_prompt] * num_prompts
813
+
814
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
815
+ pil_image
816
+ )
817
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
818
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
819
+ image_prompt_embeds = image_prompt_embeds.view(
820
+ bs_embed * num_samples, seq_len, -1
821
+ )
822
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
823
+ 1, num_samples, 1
824
+ )
825
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
826
+ bs_embed * num_samples, seq_len, -1
827
+ )
828
+
829
+ with torch.inference_mode():
830
+ (
831
+ prompt_embeds,
832
+ negative_prompt_embeds,
833
+ pooled_prompt_embeds,
834
+ negative_pooled_prompt_embeds,
835
+ ) = self.pipe.encode_prompt(
836
+ prompt,
837
+ num_images_per_prompt=num_samples,
838
+ do_classifier_free_guidance=True,
839
+ negative_prompt=negative_prompt,
840
+ )
841
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
842
+ negative_prompt_embeds = torch.cat(
843
+ [negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
844
+ )
845
+
846
+ generator = get_generator(seed, self.device)
847
+
848
+ images = self.pipe(
849
+ prompt_embeds=prompt_embeds,
850
+ negative_prompt_embeds=negative_prompt_embeds,
851
+ pooled_prompt_embeds=pooled_prompt_embeds,
852
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
853
+ num_inference_steps=num_inference_steps,
854
+ generator=generator,
855
+ **kwargs,
856
+ ).images
857
+
858
+ return images
ip_adapter_instantstyle/resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head); the reshape keeps this shape but makes the tensor contiguous
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
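+ # Learned latent queries, repeated per batch element, iteratively attend to the projected image features (perceiver-resampler style).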
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
ip_adapter_instantstyle/utils.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ attn_maps = {}
7
+ def hook_fn(name):
8
+ def forward_hook(module, input, output):
9
+ if hasattr(module.processor, "attn_map"):
10
+ attn_maps[name] = module.processor.attn_map
11
+ del module.processor.attn_map
12
+
13
+ return forward_hook
14
+
15
+ def register_cross_attention_hook(unet):
16
+ for name, module in unet.named_modules():
17
+ if name.split('.')[-1].startswith('attn2'):
18
+ module.register_forward_hook(hook_fn(name))
19
+
20
+ return unet
21
+
22
+ def upscale(attn_map, target_size):
23
+ attn_map = torch.mean(attn_map, dim=0)
24
+ attn_map = attn_map.permute(1,0)
25
+ temp_size = None
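+ # Find the downsampling factor whose token count matches the stored attention map; the extra factor of 8 below corresponds to the latent-space resolution.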
26
+
27
+ for i in range(0,5):
28
+ scale = 2 ** i
29
+ if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
30
+ temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31
+ break
32
+
33
+ assert temp_size is not None, "temp_size must not be None"
34
+
35
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36
+
37
+ attn_map = F.interpolate(
38
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
39
+ size=target_size,
40
+ mode='bilinear',
41
+ align_corners=False
42
+ )[0]
43
+
44
+ attn_map = torch.softmax(attn_map, dim=0)
45
+ return attn_map
46
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47
+
48
+ idx = 0 if instance_or_negative else 1
49
+ net_attn_maps = []
50
+
51
+ for name, attn_map in attn_maps.items():
52
+ attn_map = attn_map.cpu() if detach else attn_map
53
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54
+ attn_map = upscale(attn_map, image_size)
55
+ net_attn_maps.append(attn_map)
56
+
57
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58
+
59
+ return net_attn_maps
60
+
61
+ def attnmaps2images(net_attn_maps):
62
+
63
+ #total_attn_scores = 0
64
+ images = []
65
+
66
+ for attn_map in net_attn_maps:
67
+ attn_map = attn_map.cpu().numpy()
68
+ #total_attn_scores += attn_map.mean().item()
69
+
70
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
72
+ #print("norm: ", normalized_attn_map.shape)
73
+ image = Image.fromarray(normalized_attn_map)
74
+
75
+ #image = fix_save_attn_map(attn_map)
76
+ images.append(image)
77
+
78
+ #print(total_attn_scores)
79
+ return images
80
+ def is_torch2_available():
81
+ return hasattr(F, "scaled_dot_product_attention")
82
+
83
+ def get_generator(seed, device):
84
+
85
+ if seed is not None:
86
+ if isinstance(seed, list):
87
+ generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88
+ else:
89
+ generator = torch.Generator(device).manual_seed(seed)
90
+ else:
91
+ generator = None
92
+
93
+ return generator
marble.py ADDED
@@ -0,0 +1,287 @@
1
+ import os
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
7
+ from huggingface_hub import hf_hub_download, list_repo_files
8
+ from PIL import Image, ImageChops, ImageEnhance
9
+ from rembg import new_session, remove
10
+ from transformers import DPTForDepthEstimation, DPTImageProcessor
11
+
12
+ from ip_adapter_instantstyle import IPAdapterXL
13
+ from ip_adapter_instantstyle.utils import register_cross_attention_hook
14
+ from parametric_control_mlp import control_mlp
15
+
16
+ file_dir = os.path.dirname(os.path.abspath(__file__))
17
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
18
+ image_encoder_path = "models/image_encoder"
19
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
20
+ controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
21
+
22
+ # Cache for rembg sessions
23
+ _session_cache = None
24
+ CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]
25
+
26
+
27
+ def get_session():
28
+ global _session_cache
29
+ if _session_cache is None:
30
+ _session_cache = new_session()
31
+ return _session_cache
32
+
33
+
34
+ def setup_control_mlps(
35
+ features: int = 1024, device: str = "cuda", dtype: torch.dtype = torch.float16
36
+ ) -> Dict[str, torch.nn.Module]:
37
+ ret = {}
38
+ for mlp in CONTROL_MLPS:
39
+ ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
40
+ return ret
41
+
42
+
43
+ def setup_control_mlp(
44
+ material_parameter: str,
45
+ features: int = 1024,
46
+ device: str = "cuda",
47
+ dtype: torch.dtype = torch.float16,
48
+ ):
49
+ net = control_mlp(features)
50
+ net.load_state_dict(
51
+ torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
52
+ )
53
+ net.to(device, dtype=dtype)
54
+ net.eval()
55
+ return net
56
+
57
+
58
+ def download_ip_adapter():
59
+ repo_id = "h94/IP-Adapter"
60
+ target_folders = ["models/", "sdxl_models/"]
61
+ local_dir = file_dir
62
+
63
+ # Check if folders exist and contain files
64
+ folders_exist = all(
65
+ os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
66
+ )
67
+
68
+ if folders_exist:
69
+ # Check if any of the target folders are empty
70
+ folders_empty = any(
71
+ len(os.listdir(os.path.join(local_dir, folder))) == 0
72
+ for folder in target_folders
73
+ )
74
+ if not folders_empty:
75
+ print("IP-Adapter files already downloaded. Skipping download.")
76
+ return
77
+
78
+ # List all files in the repo
79
+ all_files = list_repo_files(repo_id)
80
+
81
+ # Filter for files in the desired folders
82
+ filtered_files = [
83
+ f for f in all_files if any(f.startswith(folder) for folder in target_folders)
84
+ ]
85
+
86
+ # Download each file
87
+ for file_path in filtered_files:
88
+ local_path = hf_hub_download(
89
+ repo_id=repo_id,
90
+ filename=file_path,
91
+ local_dir=local_dir,
92
+ local_dir_use_symlinks=False,
93
+ )
94
+ print(f"Downloaded: {file_path} to {local_path}")
95
+
96
+
97
+ def setup_pipeline(
98
+ device: str = "cuda",
99
+ dtype: torch.dtype = torch.float16,
100
+ ):
101
+ download_ip_adapter()
102
+
103
+ cur_block = ("up", 0, 1)
104
+
105
+ controlnet = ControlNetModel.from_pretrained(
106
+ controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
107
+ ).to(device)
108
+
109
+ pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
110
+ base_model_path,
111
+ controlnet=controlnet,
112
+ use_safetensors=True,
113
+ torch_dtype=dtype,
114
+ add_watermarker=False,
115
+ ).to(device)
116
+
117
+ pipe.unet = register_cross_attention_hook(pipe.unet)
118
+
119
+ block_name = (
120
+ cur_block[0]
121
+ + "_blocks."
122
+ + str(cur_block[1])
123
+ + ".attentions."
124
+ + str(cur_block[2])
125
+ )
126
+
127
+ print("Testing block {}".format(block_name))
128
+
129
+ return IPAdapterXL(
130
+ pipe,
131
+ os.path.join(file_dir, image_encoder_path),
132
+ os.path.join(file_dir, ip_ckpt),
133
+ device,
134
+ target_blocks=[block_name],
135
+ )
136
+
137
+
138
+ def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
139
+ image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
140
+ model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
141
+ model.to(device, dtype=dtype)
142
+ model.eval()
143
+ return model, image_processor
144
+
145
+
146
+ def run_dpt_depth(
147
+ image: Image.Image, model, processor, device: str = "cuda"
148
+ ) -> Image.Image:
149
+ """Run DPT depth estimation on an image."""
150
+ # Prepare image
151
+ inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
152
+
153
+ # Get depth prediction
154
+ with torch.no_grad():
155
+ depth_map = model(**inputs).predicted_depth
156
+
157
+ # Now normalize to 0-1 range
158
+ depth_map = (depth_map - depth_map.min()) / (
159
+ depth_map.max() - depth_map.min() + 1e-7
160
+ )
161
+ depth_map = depth_map.clip(0, 1) * 255
162
+
163
+ # Convert to PIL Image
164
+ depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
165
+ return Image.fromarray(depth_map).resize((1024, 1024))
166
+
167
+
168
+ def prepare_mask(image: Image.Image) -> Image.Image:
169
+ """Prepare mask from image using rembg."""
170
+ rm_bg = remove(image, session=get_session())
171
+ target_mask = (
172
+ rm_bg.convert("RGB")
173
+ .point(lambda x: 0 if x < 1 else 255)
174
+ .convert("L")
175
+ .convert("RGB")
176
+ )
177
+ return target_mask.resize((1024, 1024))
178
+
179
+
180
+ def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
181
+ """Prepare initial image for inpainting."""
182
+
183
+ # Create grayscale version
184
+ gray_image = image.convert("L").convert("RGB")
185
+ gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)
186
+
187
+ # Create mask inversions
188
+ invert_mask = ImageChops.invert(mask)
189
+
190
+ # Combine images
191
+ grayscale_img = ImageChops.darker(gray_image, mask)
192
+ img_black_mask = ImageChops.darker(image, invert_mask)
193
+ init_img = ImageChops.lighter(img_black_mask, grayscale_img)
194
+
195
+ return init_img.resize((1024, 1024))
196
+
197
+
198
+ def run_parametric_control(
199
+ ip_model,
200
+ target_image: Image.Image,
201
+ edit_mlps: dict[torch.nn.Module, float],
202
+ texture_image: Image.Image = None,
203
+ num_inference_steps: int = 30,
204
+ seed: int = 42,
205
+ depth_map: Image.Image = None,
206
+ mask: Image.Image = None,
207
+ ) -> Image.Image:
208
+ """Run parametric control with metallic and roughness adjustments."""
209
+ # Get depth map
210
+ if depth_map is None:
211
+ model, processor = get_dpt_model()
212
+ depth_map = run_dpt_depth(target_image, model, processor)
213
+ else:
214
+ depth_map = depth_map.resize((1024, 1024))
215
+
216
+ # Prepare mask and init image
217
+ if mask is None:
218
+ mask = prepare_mask(target_image)
219
+ else:
220
+ mask = mask.resize((1024, 1024))
221
+
222
+ if texture_image is None:
223
+ texture_image = target_image
224
+
225
+ init_img = prepare_init_image(target_image, mask)
226
+
227
+ # Generate edit
228
+ images = ip_model.generate_parametric_edits(
229
+ texture_image,
230
+ image=init_img,
231
+ control_image=depth_map,
232
+ mask_image=mask,
233
+ controlnet_conditioning_scale=1.0,
234
+ num_samples=1,
235
+ num_inference_steps=num_inference_steps,
236
+ seed=seed,
237
+ edit_mlps=edit_mlps,
238
+ strength=1.0,
239
+ )
240
+
241
+ return images[0]
242
+
243
+
244
+ def run_blend(
245
+ ip_model,
246
+ target_image: Image.Image,
247
+ texture_image1: Image.Image,
248
+ texture_image2: Image.Image,
249
+ edit_strength: float = 0.0,
250
+ num_inference_steps: int = 20,
251
+ seed: int = 1,
252
+ depth_map: Image.Image = None,
253
+ mask: Image.Image = None,
254
+ ) -> Image.Image:
255
+ """Run blending between two texture images."""
256
+ # Get depth map
257
+ if depth_map is None:
258
+ model, processor = get_dpt_model()
259
+ depth_map = run_dpt_depth(target_image, model, processor)
260
+ else:
261
+ depth_map = depth_map.resize((1024, 1024))
262
+
263
+ # Prepare mask and init image
264
+ if mask is None:
265
+ mask = prepare_mask(target_image)
266
+ else:
267
+ mask = mask.resize((1024, 1024))
268
+ init_img = prepare_init_image(target_image, mask)
269
+
270
+ # Generate edit
271
+ images = ip_model.generate_edit(
272
+ start_image=texture_image1,
273
+ pil_image=texture_image1,
274
+ pil_image2=texture_image2,
275
+ image=init_img,
276
+ control_image=depth_map,
277
+ mask_image=mask,
278
+ controlnet_conditioning_scale=1.0,
279
+ num_samples=1,
280
+ num_inference_steps=num_inference_steps,
281
+ seed=seed,
282
+ edit_strength=edit_strength,
283
+ clip_strength=1.0,
284
+ strength=1.0,
285
+ )
286
+
287
+ return images[0]
model_weights/glow.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4a0d131d8878dd96f10dcd3476ece28664d49836562c22495781c13b6d8eea6
3
+ size 12600283
model_weights/metallic.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b6b1f879902533d49db1ee4a5630d13cc7ff2ec4eba57c0a5d5955a13fbb12c
3
+ size 12600447
model_weights/roughness.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:635a69cfe7c63f4166e3170c82e6e72c3e721ef395eb700807a03b5dc56b01a8
3
+ size 12600467
model_weights/transparency.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b38d72f57017878030c401428272906fa2bc6a87049165923943b7621329b798
3
+ size 12600397
parametric_control_mlp.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class control_mlp(nn.Module):
6
+ def __init__(self, embedding_size):
7
+ super(control_mlp, self).__init__()
8
+
9
+ self.fc1 = nn.Linear(embedding_size, 1024)
10
+ self.fc2 = nn.Linear(1024, 2048)
11
+ self.relu = nn.ReLU()
12
+
13
+ self.edit_strength_fc1 = nn.Linear(1, 128)
14
+ self.edit_strength_fc2 = nn.Linear(128, 2)
15
+
16
+ def forward(self, x, edit_strength):
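+ # x: CLIP image embedding; edit_strength: per-sample scalar. The embedding is expanded to two 1024-d halves and the strength branch predicts two gains; the gain-weighted sum of the halves is the predicted edit direction.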
17
+ x = self.relu(self.fc1(x))
18
+ x = self.fc2(x)
19
+
20
+ edit_strength = self.relu(self.edit_strength_fc1(edit_strength.unsqueeze(1)))
21
+ edit_strength = self.edit_strength_fc2(edit_strength)
22
+
23
+ edit_strength1, edit_strength2 = edit_strength[:, 0], edit_strength[:, 1]
24
+ # print(edit_strength1.shape)
25
+ # exit()
26
+
27
+ output = (
28
+ edit_strength1.unsqueeze(1) * x[:, :1024]
29
+ + edit_strength2.unsqueeze(1) * x[:, 1024:]
30
+ )
31
+
32
+ return output
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ diffusers==0.30
2
+ rembg[gpu]
3
+ einops==0.7.0
4
+ transformers==4.27.4
5
+ opencv-python==4.7.0.68
6
+ accelerate==0.26.1
7
+ timm==0.6.12
8
+ torch==2.3.0
9
+ torchvision==0.18.0
10
+ huggingface_hub==0.30.2