Initial commit
Co-authored-by: Mark Boss <hello@markboss.me>
- .gitattributes +2 -0
- .gitignore +179 -0
- LICENSE.md +51 -0
- README.md +21 -7
- gradio_demo.py +272 -0
- input_images/context_image/beetle.png +3 -0
- input_images/context_image/genart_teapot.jpg +3 -0
- input_images/context_image/toy_car.png +3 -0
- input_images/context_image/white_car_night.jpg +3 -0
- input_images/depth/beetle.png +3 -0
- input_images/depth/toy_car.png +3 -0
- input_images/texture/high_roughness.png +3 -0
- input_images/texture/low_roughness.png +3 -0
- input_images/texture/metal_bowl.png +3 -0
- ip_adapter_instantstyle/__init__.py +9 -0
- ip_adapter_instantstyle/attention_processor.py +562 -0
- ip_adapter_instantstyle/ip_adapter.py +858 -0
- ip_adapter_instantstyle/resampler.py +158 -0
- ip_adapter_instantstyle/utils.py +93 -0
- marble.py +287 -0
- model_weights/glow.pt +3 -0
- model_weights/metallic.pt +3 -0
- model_weights/roughness.pt +3 -0
- model_weights/transparency.pt +3 -0
- parametric_control_mlp.py +32 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,179 @@
IP-Adapter/
models/
sdxl_models/
.gradio/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
LICENSE.md
ADDED
@@ -0,0 +1,51 @@
STABILITY AI COMMUNITY LICENSE AGREEMENT
Last Updated: July 5, 2024

I. INTRODUCTION

This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.

This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).

By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.

II. RESEARCH & NON-COMMERCIAL USE LICENSE

Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.

III. COMMERCIAL USE LICENSE

Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise), which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.

IV. GENERAL TERMS

Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
c. Intellectual Property.
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.

V. DEFINITIONS

"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
"Agreement" means this Stability AI Community License Agreement.
"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
"Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including "fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
"Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
"Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
"Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
"Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md
CHANGED
@@ -1,12 +1,26 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: MARBLE
+emoji: ✨
+colorFrom: indigo
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
-app_file:
+sdk_version: 5.32.1
+app_file: gradio_demo.py
 pinned: false
 ---
 
-
+* **Repository**: [https://github.com/Stability-AI/marble](https://github.com/Stability-AI/marble)
+* **Tech report**: [https://marblecontrol.github.io/static/MARBLE.pdf](https://marblecontrol.github.io/static/MARBLE.pdf)
+* **Project page**: [https://marblecontrol.github.io](https://marblecontrol.github.io)
+
+## Citation
+If you find MARBLE helpful in your research/applications, please cite using this BibTeX:
+
+```bibtex
+@article{cheng2024marble,
+  title={MARBLE: Material Recomposition and Blending in CLIP-Space},
+  author={Cheng, Ta-Ying and Sharma, Prafull and Boss, Mark and Jampani, Varun},
+  journal={CVPR},
+  year={2025}
+}
+```
gradio_demo.py
ADDED
@@ -0,0 +1,272 @@
import gradio as gr
from PIL import Image

from marble import (
    get_session,
    run_blend,
    run_parametric_control,
    setup_control_mlps,
    setup_pipeline,
)

# Setup the pipeline and control MLPs
control_mlps = setup_control_mlps()
ip_adapter = setup_pipeline()
get_session()

# Load example images
EXAMPLE_IMAGES = {
    "blend": {
        "target": "input_images/context_image/beetle.png",
        "texture1": "input_images/texture/low_roughness.png",
        "texture2": "input_images/texture/high_roughness.png",
    },
    "parametric": {
        "target": "input_images/context_image/toy_car.png",
        "texture": "input_images/texture/metal_bowl.png",
    },
}


def blend_images(target_image, texture1, texture2, edit_strength):
    """Blend between two texture images"""
    result = run_blend(
        ip_adapter, target_image, texture1, texture2, edit_strength=edit_strength
    )
    return result


def parametric_control(
    target_image,
    texture_image,
    control_type,
    metallic_strength,
    roughness_strength,
    transparency_strength,
    glow_strength,
):
    """Apply parametric control based on selected control type"""
    edit_mlps = {}

    if control_type == "Roughness + Metallic":
        edit_mlps = {
            control_mlps["metallic"]: metallic_strength,
            control_mlps["roughness"]: roughness_strength,
        }
    elif control_type == "Transparency":
        edit_mlps = {
            control_mlps["transparency"]: transparency_strength,
        }
    elif control_type == "Glow":
        edit_mlps = {
            control_mlps["glow"]: glow_strength,
        }

    # Use target image as texture if no texture is provided
    texture_to_use = texture_image if texture_image is not None else target_image

    result = run_parametric_control(
        ip_adapter,
        target_image,
        edit_mlps,
        texture_to_use,
    )
    return result


# Create the Gradio interface
with gr.Blocks(
    title="MARBLE: Material Recomposition and Blending in CLIP-Space"
) as demo:
    gr.Markdown(
        """
# MARBLE: Material Recomposition and Blending in CLIP-Space

<div style="display: flex; justify-content: flex-start; gap: 10px;">
<a href="https://arxiv.org/abs/"><img src="https://img.shields.io/badge/Arxiv-2501.04689-B31B1B.svg"></a>
<a href="https://github.com/Stability-AI/marble"><img src="https://img.shields.io/badge/Github-Marble-B31B1B.svg"></a>
</div>

MARBLE is a tool for material recomposition and blending in CLIP-Space.
We provide two modes of operation:
- **Texture Blending**: Blend the material properties of two texture images and apply it to a target image.
- **Parametric Control**: Apply parametric material control to a target image. You can either provide a texture image, transferring the material properties of the texture to the original image, or you can just provide a target image, and edit the material properties of the original image.
"""
    )

    with gr.Row(variant="panel"):
        with gr.Tabs():
            with gr.TabItem("Texture Blending"):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        with gr.Row():
                            texture1 = gr.Image(label="Texture 1", type="pil")
                            texture2 = gr.Image(label="Texture 2", type="pil")
                        edit_strength = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                            label="Blend Strength",
                        )
                    with gr.Column():
                        with gr.Row():
                            target_image = gr.Image(label="Target Image", type="pil")
                            blend_output = gr.Image(label="Blended Result")
                        blend_btn = gr.Button("Blend Textures")

                # Add examples for blending
                gr.Examples(
                    examples=[
                        [
                            Image.open(EXAMPLE_IMAGES["blend"]["target"]),
                            Image.open(EXAMPLE_IMAGES["blend"]["texture1"]),
                            Image.open(EXAMPLE_IMAGES["blend"]["texture2"]),
                            0.5,
                        ]
                    ],
                    inputs=[target_image, texture1, texture2, edit_strength],
                    outputs=blend_output,
                    fn=blend_images,
                    cache_examples=True,
                )

                blend_btn.click(
                    fn=blend_images,
                    inputs=[target_image, texture1, texture2, edit_strength],
                    outputs=blend_output,
                )

            with gr.TabItem("Parametric Control"):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        with gr.Row():
                            target_image_pc = gr.Image(label="Target Image", type="pil")
                            texture_image_pc = gr.Image(
                                label="Texture Image (Optional - uses target image if not provided)",
                                type="pil",
                            )
                        control_type = gr.Dropdown(
                            choices=["Roughness + Metallic", "Transparency", "Glow"],
                            value="Roughness + Metallic",
                            label="Control Type",
                        )

                        metallic_strength = gr.Slider(
                            minimum=-20,
                            maximum=20,
                            value=0,
                            step=0.1,
                            label="Metallic Strength",
                            visible=True,
                        )
                        roughness_strength = gr.Slider(
                            minimum=-1,
                            maximum=1,
                            value=0,
                            step=0.1,
                            label="Roughness Strength",
                            visible=True,
                        )
                        transparency_strength = gr.Slider(
                            minimum=0,
                            maximum=4,
                            value=0,
                            step=0.1,
                            label="Transparency Strength",
                            visible=False,
                        )
                        glow_strength = gr.Slider(
                            minimum=0,
                            maximum=3,
                            value=0,
                            step=0.1,
                            label="Glow Strength",
                            visible=False,
                        )
                        control_btn = gr.Button("Apply Control")

                    with gr.Column():
                        control_output = gr.Image(label="Result")

                def update_slider_visibility(control_type):
                    return [
                        gr.update(visible=control_type == "Roughness + Metallic"),
                        gr.update(visible=control_type == "Roughness + Metallic"),
                        gr.update(visible=control_type == "Transparency"),
                        gr.update(visible=control_type == "Glow"),
                    ]

                control_type.change(
                    fn=update_slider_visibility,
                    inputs=[control_type],
                    outputs=[
                        metallic_strength,
                        roughness_strength,
                        transparency_strength,
                        glow_strength,
                    ],
                    show_progress=False,
                )

                # Add examples for parametric control
                gr.Examples(
                    examples=[
                        [
                            Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
                            Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
                            "Roughness + Metallic",
                            0,  # metallic_strength
                            0,  # roughness_strength
                            0,  # transparency_strength
                            0,  # glow_strength
                        ],
                        [
                            Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
                            Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
                            "Roughness + Metallic",
                            20,  # metallic_strength
                            0,  # roughness_strength
                            0,  # transparency_strength
                            0,  # glow_strength
                        ],
                        [
                            Image.open(EXAMPLE_IMAGES["parametric"]["target"]),
                            Image.open(EXAMPLE_IMAGES["parametric"]["texture"]),
                            "Roughness + Metallic",
                            0,  # metallic_strength
                            1,  # roughness_strength
                            0,  # transparency_strength
                            0,  # glow_strength
                        ],
                    ],
                    inputs=[
                        target_image_pc,
                        texture_image_pc,
                        control_type,
                        metallic_strength,
                        roughness_strength,
                        transparency_strength,
                        glow_strength,
                    ],
                    outputs=control_output,
                    fn=parametric_control,
                    cache_examples=True,
                )

                control_btn.click(
                    fn=parametric_control,
                    inputs=[
                        target_image_pc,
                        texture_image_pc,
                        control_type,
                        metallic_strength,
                        roughness_strength,
                        transparency_strength,
                        glow_strength,
                    ],
                    outputs=control_output,
                )

if __name__ == "__main__":
    demo.launch()
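
The two demo entry points can also be driven directly from Python without the UI. The sketch below reuses only the calls made in `gradio_demo.py` above; the image paths come from the bundled `EXAMPLE_IMAGES`, and anything about `marble.py` beyond these call signatures (e.g. that the results are PIL images) is an assumption.

```python
# Minimal sketch: drive MARBLE without the Gradio UI, reusing only the calls
# that gradio_demo.py itself makes. Assumes run_blend / run_parametric_control
# return PIL images, as the demo feeds their results to gr.Image outputs.
from PIL import Image

from marble import run_blend, run_parametric_control, setup_control_mlps, setup_pipeline

control_mlps = setup_control_mlps()
ip_adapter = setup_pipeline()

target = Image.open("input_images/context_image/beetle.png")
tex_low = Image.open("input_images/texture/low_roughness.png")
tex_high = Image.open("input_images/texture/high_roughness.png")

# Texture blending: interpolate the material properties of the two textures.
blended = run_blend(ip_adapter, target, tex_low, tex_high, edit_strength=0.5)
blended.save("blend_result.png")

# Parametric control: push a single attribute with its pretrained control MLP.
edit_mlps = {control_mlps["metallic"]: 10.0}
edited = run_parametric_control(ip_adapter, target, edit_mlps, target)
edited.save("metallic_result.png")
```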
input_images/context_image/beetle.png
ADDED
(binary image stored via Git LFS)

input_images/context_image/genart_teapot.jpg
ADDED
(binary image stored via Git LFS)

input_images/context_image/toy_car.png
ADDED
(binary image stored via Git LFS)

input_images/context_image/white_car_night.jpg
ADDED
(binary image stored via Git LFS)

input_images/depth/beetle.png
ADDED
(binary image stored via Git LFS)

input_images/depth/toy_car.png
ADDED
(binary image stored via Git LFS)

input_images/texture/high_roughness.png
ADDED
(binary image stored via Git LFS)

input_images/texture/low_roughness.png
ADDED
(binary image stored via Git LFS)

input_images/texture/metal_bowl.png
ADDED
(binary image stored via Git LFS)
ip_adapter_instantstyle/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull

__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "IPAdapterFull",
]
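
The package mirrors the upstream IP-Adapter / InstantStyle layout. A hypothetical construction sketch follows; since `ip_adapter.py` itself is not rendered in this view, the constructor signature and the checkpoint/encoder paths follow the upstream project and are assumptions, not anything confirmed by this commit.

```python
# Hypothetical usage based on the upstream IP-Adapter / InstantStyle API;
# the IPAdapterXL signature and the sdxl_models/ paths are assumptions.
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline

from ip_adapter_instantstyle import IPAdapterXL

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
ip_model = IPAdapterXL(
    pipe,
    "sdxl_models/image_encoder",        # assumed image-encoder path
    "sdxl_models/ip-adapter_sdxl.bin",  # assumed IP-Adapter checkpoint
    "cuda",
)

style = Image.open("input_images/texture/metal_bowl.png")
images = ip_model.generate(pil_image=style, num_samples=1, num_inference_steps=30)
images[0].save("ip_adapter_result.png")
```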
ip_adapter_instantstyle/attention_processor.py
ADDED
@@ -0,0 +1,562 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
                # print(self.attn_map.shape)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
ip_adapter_instantstyle/ip_adapter.py
ADDED
@@ -0,0 +1,858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from diffusers import StableDiffusionPipeline
|
8 |
+
from diffusers.pipelines.controlnet import MultiControlNetModel
|
9 |
+
from PIL import Image
|
10 |
+
from safetensors import safe_open
|
11 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
12 |
+
|
13 |
+
from .utils import is_torch2_available, get_generator
|
14 |
+
|
15 |
+
L = 4
|
16 |
+
|
17 |
+
|
18 |
+
def pos_encode(x, L):
|
19 |
+
pos_encode = []
|
20 |
+
|
21 |
+
for freq in range(L):
|
22 |
+
pos_encode.append(torch.cos(2**freq * torch.pi * x))
|
23 |
+
pos_encode.append(torch.sin(2**freq * torch.pi * x))
|
24 |
+
pos_encode = torch.cat(pos_encode, dim=1)
|
25 |
+
return pos_encode
|
26 |
+
|
27 |
+
|
28 |
+
if is_torch2_available():
|
29 |
+
from .attention_processor import (
|
30 |
+
AttnProcessor2_0 as AttnProcessor,
|
31 |
+
)
|
32 |
+
from .attention_processor import (
|
33 |
+
CNAttnProcessor2_0 as CNAttnProcessor,
|
34 |
+
)
|
35 |
+
from .attention_processor import (
|
36 |
+
IPAttnProcessor2_0 as IPAttnProcessor,
|
37 |
+
)
|
38 |
+
else:
|
39 |
+
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
|
40 |
+
from .resampler import Resampler
|
41 |
+
|
42 |
+
|
43 |
+
class ImageProjModel(torch.nn.Module):
|
44 |
+
"""Projection Model"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
cross_attention_dim=1024,
|
49 |
+
clip_embeddings_dim=1024,
|
50 |
+
clip_extra_context_tokens=4,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.generator = None
|
55 |
+
self.cross_attention_dim = cross_attention_dim
|
56 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
57 |
+
self.proj = torch.nn.Linear(
|
58 |
+
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
|
59 |
+
)
|
60 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
61 |
+
|
62 |
+
def forward(self, image_embeds):
|
63 |
+
embeds = image_embeds
|
64 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
65 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
66 |
+
)
|
67 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
68 |
+
return clip_extra_context_tokens
|
69 |
+
|
70 |
+
|
71 |
+
class MLPProjModel(torch.nn.Module):
|
72 |
+
"""SD model with image prompt"""
|
73 |
+
|
74 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.proj = torch.nn.Sequential(
|
78 |
+
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
|
79 |
+
torch.nn.GELU(),
|
80 |
+
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
|
81 |
+
torch.nn.LayerNorm(cross_attention_dim),
|
82 |
+
)
|
83 |
+
|
84 |
+
def forward(self, image_embeds):
|
85 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
86 |
+
return clip_extra_context_tokens
|
87 |
+
|
88 |
+
|
89 |
+
class IPAdapter:
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
sd_pipe,
|
93 |
+
image_encoder_path,
|
94 |
+
ip_ckpt,
|
95 |
+
device,
|
96 |
+
num_tokens=4,
|
97 |
+
target_blocks=["block"],
|
98 |
+
):
|
99 |
+
self.device = device
|
100 |
+
self.image_encoder_path = image_encoder_path
|
101 |
+
self.ip_ckpt = ip_ckpt
|
102 |
+
self.num_tokens = num_tokens
|
103 |
+
self.target_blocks = target_blocks
|
104 |
+
|
105 |
+
self.pipe = sd_pipe.to(self.device)
|
106 |
+
self.set_ip_adapter()
|
107 |
+
|
108 |
+
# load image encoder
|
109 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
110 |
+
self.image_encoder_path
|
111 |
+
).to(self.device, dtype=torch.float16)
|
112 |
+
self.clip_image_processor = CLIPImageProcessor()
|
113 |
+
# image proj model
|
114 |
+
self.image_proj_model = self.init_proj()
|
115 |
+
|
116 |
+
self.load_ip_adapter()
|
117 |
+
|
118 |
+
def init_proj(self):
|
119 |
+
image_proj_model = ImageProjModel(
|
120 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
121 |
+
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
122 |
+
clip_extra_context_tokens=self.num_tokens,
|
123 |
+
).to(self.device, dtype=torch.float16)
|
124 |
+
return image_proj_model
|
125 |
+
|
126 |
+
def set_ip_adapter(self):
|
127 |
+
unet = self.pipe.unet
|
128 |
+
attn_procs = {}
|
129 |
+
for name in unet.attn_processors.keys():
|
130 |
+
cross_attention_dim = (
|
131 |
+
None
|
132 |
+
if name.endswith("attn1.processor")
|
133 |
+
else unet.config.cross_attention_dim
|
134 |
+
)
|
135 |
+
if name.startswith("mid_block"):
|
136 |
+
hidden_size = unet.config.block_out_channels[-1]
|
137 |
+
elif name.startswith("up_blocks"):
|
138 |
+
block_id = int(name[len("up_blocks.")])
|
139 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
140 |
+
elif name.startswith("down_blocks"):
|
141 |
+
block_id = int(name[len("down_blocks.")])
|
142 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
143 |
+
if cross_attention_dim is None:
|
144 |
+
attn_procs[name] = AttnProcessor()
|
145 |
+
else:
|
146 |
+
selected = False
|
147 |
+
for block_name in self.target_blocks:
|
148 |
+
if block_name in name:
|
149 |
+
selected = True
|
150 |
+
break
|
151 |
+
if selected:
|
152 |
+
attn_procs[name] = IPAttnProcessor(
|
153 |
+
hidden_size=hidden_size,
|
154 |
+
cross_attention_dim=cross_attention_dim,
|
155 |
+
scale=1.0,
|
156 |
+
num_tokens=self.num_tokens,
|
157 |
+
).to(self.device, dtype=torch.float16)
|
158 |
+
else:
|
159 |
+
attn_procs[name] = IPAttnProcessor(
|
160 |
+
hidden_size=hidden_size,
|
161 |
+
cross_attention_dim=cross_attention_dim,
|
162 |
+
scale=1.0,
|
163 |
+
num_tokens=self.num_tokens,
|
164 |
+
skip=True,
|
165 |
+
).to(self.device, dtype=torch.float16)
|
166 |
+
unet.set_attn_processor(attn_procs)
|
167 |
+
if hasattr(self.pipe, "controlnet"):
|
168 |
+
if isinstance(self.pipe.controlnet, MultiControlNetModel):
|
169 |
+
for controlnet in self.pipe.controlnet.nets:
|
170 |
+
controlnet.set_attn_processor(
|
171 |
+
CNAttnProcessor(num_tokens=self.num_tokens)
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
self.pipe.controlnet.set_attn_processor(
|
175 |
+
CNAttnProcessor(num_tokens=self.num_tokens)
|
176 |
+
)
|
177 |
+
|
178 |
+
def load_ip_adapter(self):
|
179 |
+
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
|
180 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
181 |
+
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
|
182 |
+
for key in f.keys():
|
183 |
+
if key.startswith("image_proj."):
|
184 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = (
|
185 |
+
f.get_tensor(key)
|
186 |
+
)
|
187 |
+
elif key.startswith("ip_adapter."):
|
188 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = (
|
189 |
+
f.get_tensor(key)
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
193 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
194 |
+
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
195 |
+
ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
|
196 |
+
|
197 |
+
@torch.inference_mode()
|
198 |
+
def get_image_embeds(
|
199 |
+
self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None
|
200 |
+
):
|
201 |
+
if pil_image is not None:
|
202 |
+
if isinstance(pil_image, Image.Image):
|
203 |
+
pil_image = [pil_image]
|
204 |
+
clip_image = self.clip_image_processor(
|
205 |
+
images=pil_image, return_tensors="pt"
|
206 |
+
).pixel_values
|
207 |
+
clip_image_embeds = self.image_encoder(
|
208 |
+
clip_image.to(self.device, dtype=torch.float16)
|
209 |
+
).image_embeds
|
210 |
+
else:
|
211 |
+
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
|
212 |
+
|
213 |
+
if content_prompt_embeds is not None:
|
214 |
+
print(clip_image_embeds.shape)
|
215 |
+
print(content_prompt_embeds.shape)
|
216 |
+
clip_image_embeds = clip_image_embeds - content_prompt_embeds
|
217 |
+
|
218 |
+
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
219 |
+
uncond_image_prompt_embeds = self.image_proj_model(
|
220 |
+
torch.zeros_like(clip_image_embeds)
|
221 |
+
)
|
222 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
223 |
+
|
224 |
+
@torch.inference_mode()
|
225 |
+
def generate_image_edit_dir(
|
226 |
+
self,
|
227 |
+
pil_image=None,
|
228 |
+
content_prompt_embeds=None,
|
229 |
+
edit_mlps: dict[torch.nn.Module, float] = None,
|
230 |
+
):
|
231 |
+
print("Combining multiple MLPs!")
|
232 |
+
if pil_image is not None:
|
233 |
+
if isinstance(pil_image, Image.Image):
|
234 |
+
pil_image = [pil_image]
|
235 |
+
clip_image = self.clip_image_processor(
|
236 |
+
images=pil_image, return_tensors="pt"
|
237 |
+
).pixel_values
|
238 |
+
clip_image_embeds = self.image_encoder(
|
239 |
+
clip_image.to(self.device, dtype=torch.float16)
|
240 |
+
).image_embeds
|
241 |
+
pred_editing_dirs = [
|
242 |
+
net(
|
243 |
+
clip_image_embeds,
|
244 |
+
torch.Tensor([strength]).to(self.device, dtype=torch.float16),
|
245 |
+
)
|
246 |
+
for net, strength in edit_mlps.items()
|
247 |
+
]
|
248 |
+
|
249 |
+
clip_image_embeds = clip_image_embeds + sum(pred_editing_dirs)
|
250 |
+
|
251 |
+
if content_prompt_embeds is not None:
|
252 |
+
clip_image_embeds = clip_image_embeds - content_prompt_embeds
|
253 |
+
|
254 |
+
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
255 |
+
uncond_image_prompt_embeds = self.image_proj_model(
|
256 |
+
torch.zeros_like(clip_image_embeds)
|
257 |
+
)
|
258 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
259 |
+
|
260 |
+
@torch.inference_mode()
|
261 |
+
def get_image_edit_dir(
|
262 |
+
self,
|
263 |
+
start_image=None,
|
264 |
+
pil_image=None,
|
265 |
+
pil_image2=None,
|
266 |
+
content_prompt_embeds=None,
|
267 |
+
edit_strength=1.0,
|
268 |
+
):
|
269 |
+
print("Blending Two Materials!")
|
270 |
+
if pil_image is not None:
|
271 |
+
if isinstance(pil_image, Image.Image):
|
272 |
+
pil_image = [pil_image]
|
273 |
+
clip_image = self.clip_image_processor(
|
274 |
+
images=pil_image, return_tensors="pt"
|
275 |
+
).pixel_values
|
276 |
+
clip_image_embeds = self.image_encoder(
|
277 |
+
clip_image.to(self.device, dtype=torch.float16)
|
278 |
+
).image_embeds
|
279 |
+
|
280 |
+
if pil_image2 is not None:
|
281 |
+
if isinstance(pil_image2, Image.Image):
|
282 |
+
pil_image2 = [pil_image2]
|
283 |
+
clip_image2 = self.clip_image_processor(
|
284 |
+
images=pil_image2, return_tensors="pt"
|
285 |
+
).pixel_values
|
286 |
+
clip_image_embeds2 = self.image_encoder(
|
287 |
+
clip_image2.to(self.device, dtype=torch.float16)
|
288 |
+
).image_embeds
|
289 |
+
|
290 |
+
if start_image is not None:
|
291 |
+
if isinstance(start_image, Image.Image):
|
292 |
+
start_image = [start_image]
|
293 |
+
clip_image_start = self.clip_image_processor(
|
294 |
+
images=start_image, return_tensors="pt"
|
295 |
+
).pixel_values
|
296 |
+
clip_image_embeds_start = self.image_encoder(
|
297 |
+
clip_image_start.to(self.device, dtype=torch.float16)
|
298 |
+
).image_embeds
|
299 |
+
|
300 |
+
if content_prompt_embeds is not None:
|
301 |
+
clip_image_embeds = clip_image_embeds - content_prompt_embeds
|
302 |
+
clip_image_embeds2 = clip_image_embeds2 - content_prompt_embeds
|
303 |
+
|
304 |
+
# clip_image_embeds += edit_strength * (clip_image_embeds2 - clip_image_embeds)
|
305 |
+
clip_image_embeds = clip_image_embeds_start + edit_strength * (
|
306 |
+
clip_image_embeds2 - clip_image_embeds
|
307 |
+
)
|
308 |
+
|
309 |
+
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
310 |
+
uncond_image_prompt_embeds = self.image_proj_model(
|
311 |
+
torch.zeros_like(clip_image_embeds)
|
312 |
+
)
|
313 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
314 |
+
|
315 |
+
def set_scale(self, scale):
|
316 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
317 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
318 |
+
attn_processor.scale = scale
|
319 |
+
|
320 |
+
def set_scale(self, scale):
|
321 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
322 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
323 |
+
attn_processor.scale = scale
|
324 |
+
|
325 |
+
def generate(
|
326 |
+
self,
|
327 |
+
pil_image=None,
|
328 |
+
clip_image_embeds=None,
|
329 |
+
prompt=None,
|
330 |
+
negative_prompt=None,
|
331 |
+
scale=1.0,
|
332 |
+
num_samples=4,
|
333 |
+
seed=None,
|
334 |
+
guidance_scale=7.5,
|
335 |
+
num_inference_steps=30,
|
336 |
+
neg_content_emb=None,
|
337 |
+
**kwargs,
|
338 |
+
):
|
339 |
+
self.set_scale(scale)
|
340 |
+
|
341 |
+
if pil_image is not None:
|
342 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
343 |
+
else:
|
344 |
+
num_prompts = clip_image_embeds.size(0)
|
345 |
+
|
346 |
+
if prompt is None:
|
347 |
+
prompt = "best quality, high quality"
|
348 |
+
if negative_prompt is None:
|
349 |
+
negative_prompt = (
|
350 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
351 |
+
)
|
352 |
+
|
353 |
+
if not isinstance(prompt, List):
|
354 |
+
prompt = [prompt] * num_prompts
|
355 |
+
if not isinstance(negative_prompt, List):
|
356 |
+
negative_prompt = [negative_prompt] * num_prompts
|
357 |
+
|
358 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
|
359 |
+
pil_image=pil_image,
|
360 |
+
clip_image_embeds=clip_image_embeds,
|
361 |
+
content_prompt_embeds=neg_content_emb,
|
362 |
+
)
|
363 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
364 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
365 |
+
image_prompt_embeds = image_prompt_embeds.view(
|
366 |
+
bs_embed * num_samples, seq_len, -1
|
367 |
+
)
|
368 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
|
369 |
+
1, num_samples, 1
|
370 |
+
)
|
371 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
|
372 |
+
bs_embed * num_samples, seq_len, -1
|
373 |
+
)
|
374 |
+
|
375 |
+
with torch.inference_mode():
|
376 |
+
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
|
377 |
+
prompt,
|
378 |
+
device=self.device,
|
379 |
+
num_images_per_prompt=num_samples,
|
380 |
+
do_classifier_free_guidance=True,
|
381 |
+
negative_prompt=negative_prompt,
|
382 |
+
)
|
383 |
+
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
384 |
+
negative_prompt_embeds = torch.cat(
|
385 |
+
[negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
|
386 |
+
)
|
387 |
+
|
388 |
+
generator = get_generator(seed, self.device)
|
389 |
+
|
390 |
+
images = self.pipe(
|
391 |
+
prompt_embeds=prompt_embeds,
|
392 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
393 |
+
guidance_scale=guidance_scale,
|
394 |
+
num_inference_steps=num_inference_steps,
|
395 |
+
generator=generator,
|
396 |
+
**kwargs,
|
397 |
+
).images
|
398 |
+
|
399 |
+
return images
|
400 |
+
|
401 |
+
|
402 |
+
class IPAdapterXL(IPAdapter):
|
403 |
+
"""SDXL"""
|
404 |
+
|
405 |
+
def generate(
|
406 |
+
self,
|
407 |
+
pil_image,
|
408 |
+
prompt=None,
|
409 |
+
negative_prompt=None,
|
410 |
+
scale=1.0,
|
411 |
+
num_samples=4,
|
412 |
+
seed=None,
|
413 |
+
num_inference_steps=30,
|
414 |
+
neg_content_emb=None,
|
415 |
+
neg_content_prompt=None,
|
416 |
+
neg_content_scale=1.0,
|
417 |
+
clip_strength=1.0,
|
418 |
+
**kwargs,
|
419 |
+
):
|
420 |
+
self.set_scale(scale)
|
421 |
+
|
422 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
423 |
+
|
424 |
+
if prompt is None:
|
425 |
+
prompt = "best quality, high quality"
|
426 |
+
if negative_prompt is None:
|
427 |
+
negative_prompt = (
|
428 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
429 |
+
)
|
430 |
+
|
431 |
+
if not isinstance(prompt, List):
|
432 |
+
prompt = [prompt] * num_prompts
|
433 |
+
if not isinstance(negative_prompt, List):
|
434 |
+
negative_prompt = [negative_prompt] * num_prompts
|
435 |
+
|
436 |
+
if neg_content_emb is None:
|
437 |
+
if neg_content_prompt is not None:
|
438 |
+
with torch.inference_mode():
|
439 |
+
(
|
440 |
+
prompt_embeds_, # torch.Size([1, 77, 2048])
|
441 |
+
negative_prompt_embeds_,
|
442 |
+
pooled_prompt_embeds_, # torch.Size([1, 1280])
|
443 |
+
negative_pooled_prompt_embeds_,
|
444 |
+
) = self.pipe.encode_prompt(
|
445 |
+
neg_content_prompt,
|
446 |
+
num_images_per_prompt=num_samples,
|
447 |
+
do_classifier_free_guidance=True,
|
448 |
+
negative_prompt=negative_prompt,
|
449 |
+
)
|
450 |
+
pooled_prompt_embeds_ *= neg_content_scale
|
451 |
+
else:
|
452 |
+
pooled_prompt_embeds_ = neg_content_emb
|
453 |
+
else:
|
454 |
+
pooled_prompt_embeds_ = None
|
455 |
+
|
456 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
|
457 |
+
pil_image, content_prompt_embeds=pooled_prompt_embeds_
|
458 |
+
)
|
459 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
460 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
461 |
+
image_prompt_embeds = image_prompt_embeds.view(
|
462 |
+
bs_embed * num_samples, seq_len, -1
|
463 |
+
)
|
464 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
|
465 |
+
1, num_samples, 1
|
466 |
+
)
|
467 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
|
468 |
+
bs_embed * num_samples, seq_len, -1
|
469 |
+
)
|
470 |
+
print("CLIP Strength is {}".format(clip_strength))
|
471 |
+
image_prompt_embeds *= clip_strength
|
472 |
+
uncond_image_prompt_embeds *= clip_strength
|
473 |
+
|
474 |
+
with torch.inference_mode():
|
475 |
+
(
|
476 |
+
prompt_embeds,
|
477 |
+
negative_prompt_embeds,
|
478 |
+
pooled_prompt_embeds,
|
479 |
+
negative_pooled_prompt_embeds,
|
480 |
+
) = self.pipe.encode_prompt(
|
481 |
+
prompt,
|
482 |
+
num_images_per_prompt=num_samples,
|
483 |
+
do_classifier_free_guidance=True,
|
484 |
+
negative_prompt=negative_prompt,
|
485 |
+
)
|
486 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
487 |
+
negative_prompt_embeds = torch.cat(
|
488 |
+
[negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
|
489 |
+
)
|
490 |
+
|
491 |
+
self.generator = get_generator(seed, self.device)
|
492 |
+
|
493 |
+
images = self.pipe(
|
494 |
+
prompt_embeds=prompt_embeds,
|
495 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
496 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
497 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
498 |
+
num_inference_steps=num_inference_steps,
|
499 |
+
generator=self.generator,
|
500 |
+
**kwargs,
|
501 |
+
).images
|
502 |
+
|
503 |
+
return images
|
504 |
+
|
505 |
+
def generate_parametric_edits(
|
506 |
+
self,
|
507 |
+
pil_image,
|
508 |
+
edit_mlps: dict[torch.nn.Module, float],
|
509 |
+
prompt=None,
|
510 |
+
negative_prompt=None,
|
511 |
+
scale=1.0,
|
512 |
+
num_samples=4,
|
513 |
+
seed=None,
|
514 |
+
num_inference_steps=30,
|
515 |
+
neg_content_emb=None,
|
516 |
+
neg_content_prompt=None,
|
517 |
+
neg_content_scale=1.0,
|
518 |
+
**kwargs,
|
519 |
+
):
|
520 |
+
self.set_scale(scale)
|
521 |
+
|
522 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
523 |
+
|
524 |
+
if prompt is None:
|
525 |
+
prompt = "best quality, high quality"
|
526 |
+
if negative_prompt is None:
|
527 |
+
negative_prompt = (
|
528 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
529 |
+
)
|
530 |
+
|
531 |
+
if not isinstance(prompt, List):
|
532 |
+
prompt = [prompt] * num_prompts
|
533 |
+
if not isinstance(negative_prompt, List):
|
534 |
+
negative_prompt = [negative_prompt] * num_prompts
|
535 |
+
|
536 |
+
if neg_content_emb is None:
|
537 |
+
if neg_content_prompt is not None:
|
538 |
+
with torch.inference_mode():
|
539 |
+
(
|
540 |
+
prompt_embeds_, # torch.Size([1, 77, 2048])
|
541 |
+
negative_prompt_embeds_,
|
542 |
+
pooled_prompt_embeds_, # torch.Size([1, 1280])
|
543 |
+
negative_pooled_prompt_embeds_,
|
544 |
+
) = self.pipe.encode_prompt(
|
545 |
+
neg_content_prompt,
|
546 |
+
num_images_per_prompt=num_samples,
|
547 |
+
do_classifier_free_guidance=True,
|
548 |
+
negative_prompt=negative_prompt,
|
549 |
+
)
|
550 |
+
pooled_prompt_embeds_ *= neg_content_scale
|
551 |
+
else:
|
552 |
+
pooled_prompt_embeds_ = neg_content_emb
|
553 |
+
else:
|
554 |
+
pooled_prompt_embeds_ = None
|
555 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.generate_image_edit_dir(
|
556 |
+
pil_image, content_prompt_embeds=pooled_prompt_embeds_, edit_mlps=edit_mlps
|
557 |
+
)
|
558 |
+
|
559 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
560 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
561 |
+
image_prompt_embeds = image_prompt_embeds.view(
|
562 |
+
bs_embed * num_samples, seq_len, -1
|
563 |
+
)
|
564 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
|
565 |
+
1, num_samples, 1
|
566 |
+
)
|
567 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
|
568 |
+
bs_embed * num_samples, seq_len, -1
|
569 |
+
)
|
570 |
+
|
571 |
+
with torch.inference_mode():
|
572 |
+
(
|
573 |
+
prompt_embeds,
|
574 |
+
negative_prompt_embeds,
|
575 |
+
pooled_prompt_embeds,
|
576 |
+
negative_pooled_prompt_embeds,
|
577 |
+
) = self.pipe.encode_prompt(
|
578 |
+
prompt,
|
579 |
+
num_images_per_prompt=num_samples,
|
580 |
+
do_classifier_free_guidance=True,
|
581 |
+
negative_prompt=negative_prompt,
|
582 |
+
)
|
583 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
584 |
+
negative_prompt_embeds = torch.cat(
|
585 |
+
[negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
|
586 |
+
)
|
587 |
+
|
588 |
+
self.generator = get_generator(seed, self.device)
|
589 |
+
|
590 |
+
images = self.pipe(
|
591 |
+
prompt_embeds=prompt_embeds,
|
592 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
593 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
594 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
595 |
+
num_inference_steps=num_inference_steps,
|
596 |
+
generator=self.generator,
|
597 |
+
**kwargs,
|
598 |
+
).images
|
599 |
+
|
600 |
+
return images
|
601 |
+
|
602 |
+
def generate_edit(
|
603 |
+
self,
|
604 |
+
start_image,
|
605 |
+
pil_image,
|
606 |
+
pil_image2,
|
607 |
+
prompt=None,
|
608 |
+
negative_prompt=None,
|
609 |
+
scale=1.0,
|
610 |
+
num_samples=4,
|
611 |
+
seed=None,
|
612 |
+
num_inference_steps=30,
|
613 |
+
neg_content_emb=None,
|
614 |
+
neg_content_prompt=None,
|
615 |
+
neg_content_scale=1.0,
|
616 |
+
edit_strength=1.0,
|
617 |
+
**kwargs,
|
618 |
+
):
|
619 |
+
self.set_scale(scale)
|
620 |
+
|
621 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
622 |
+
|
623 |
+
if prompt is None:
|
624 |
+
prompt = "best quality, high quality"
|
625 |
+
if negative_prompt is None:
|
626 |
+
negative_prompt = (
|
627 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
628 |
+
)
|
629 |
+
|
630 |
+
if not isinstance(prompt, List):
|
631 |
+
prompt = [prompt] * num_prompts
|
632 |
+
if not isinstance(negative_prompt, List):
|
633 |
+
negative_prompt = [negative_prompt] * num_prompts
|
634 |
+
|
635 |
+
if neg_content_emb is None:
|
636 |
+
if neg_content_prompt is not None:
|
637 |
+
with torch.inference_mode():
|
638 |
+
(
|
639 |
+
prompt_embeds_, # torch.Size([1, 77, 2048])
|
640 |
+
negative_prompt_embeds_,
|
641 |
+
pooled_prompt_embeds_, # torch.Size([1, 1280])
|
642 |
+
negative_pooled_prompt_embeds_,
|
643 |
+
) = self.pipe.encode_prompt(
|
644 |
+
neg_content_prompt,
|
645 |
+
num_images_per_prompt=num_samples,
|
646 |
+
do_classifier_free_guidance=True,
|
647 |
+
negative_prompt=negative_prompt,
|
648 |
+
)
|
649 |
+
pooled_prompt_embeds_ *= neg_content_scale
|
650 |
+
else:
|
651 |
+
pooled_prompt_embeds_ = neg_content_emb
|
652 |
+
else:
|
653 |
+
pooled_prompt_embeds_ = None
|
654 |
+
|
655 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_edit_dir(
|
656 |
+
start_image,
|
657 |
+
pil_image,
|
658 |
+
pil_image2,
|
659 |
+
content_prompt_embeds=pooled_prompt_embeds_,
|
660 |
+
edit_strength=edit_strength,
|
661 |
+
)
|
662 |
+
|
663 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
664 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
665 |
+
image_prompt_embeds = image_prompt_embeds.view(
|
666 |
+
bs_embed * num_samples, seq_len, -1
|
667 |
+
)
|
668 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
|
669 |
+
1, num_samples, 1
|
670 |
+
)
|
671 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
|
672 |
+
bs_embed * num_samples, seq_len, -1
|
673 |
+
)
|
674 |
+
|
675 |
+
with torch.inference_mode():
|
676 |
+
(
|
677 |
+
prompt_embeds,
|
678 |
+
negative_prompt_embeds,
|
679 |
+
pooled_prompt_embeds,
|
680 |
+
negative_pooled_prompt_embeds,
|
681 |
+
) = self.pipe.encode_prompt(
|
682 |
+
prompt,
|
683 |
+
num_images_per_prompt=num_samples,
|
684 |
+
do_classifier_free_guidance=True,
|
685 |
+
negative_prompt=negative_prompt,
|
686 |
+
)
|
687 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
688 |
+
negative_prompt_embeds = torch.cat(
|
689 |
+
[negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
|
690 |
+
)
|
691 |
+
|
692 |
+
self.generator = get_generator(seed, self.device)
|
693 |
+
|
694 |
+
images = self.pipe(
|
695 |
+
prompt_embeds=prompt_embeds,
|
696 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
697 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
698 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
699 |
+
num_inference_steps=num_inference_steps,
|
700 |
+
generator=self.generator,
|
701 |
+
**kwargs,
|
702 |
+
).images
|
703 |
+
|
704 |
+
return images
|
705 |
+
|
706 |
+
|
707 |
+
class IPAdapterPlus(IPAdapter):
|
708 |
+
"""IP-Adapter with fine-grained features"""
|
709 |
+
|
710 |
+
def init_proj(self):
|
711 |
+
image_proj_model = Resampler(
|
712 |
+
dim=self.pipe.unet.config.cross_attention_dim,
|
713 |
+
depth=4,
|
714 |
+
dim_head=64,
|
715 |
+
heads=12,
|
716 |
+
num_queries=self.num_tokens,
|
717 |
+
embedding_dim=self.image_encoder.config.hidden_size,
|
718 |
+
output_dim=self.pipe.unet.config.cross_attention_dim,
|
719 |
+
ff_mult=4,
|
720 |
+
).to(self.device, dtype=torch.float16)
|
721 |
+
return image_proj_model
|
722 |
+
|
723 |
+
@torch.inference_mode()
|
724 |
+
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
|
725 |
+
if isinstance(pil_image, Image.Image):
|
726 |
+
pil_image = [pil_image]
|
727 |
+
clip_image = self.clip_image_processor(
|
728 |
+
images=pil_image, return_tensors="pt"
|
729 |
+
).pixel_values
|
730 |
+
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
731 |
+
clip_image_embeds = self.image_encoder(
|
732 |
+
clip_image, output_hidden_states=True
|
733 |
+
).hidden_states[-2]
|
734 |
+
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
735 |
+
uncond_clip_image_embeds = self.image_encoder(
|
736 |
+
torch.zeros_like(clip_image), output_hidden_states=True
|
737 |
+
).hidden_states[-2]
|
738 |
+
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
739 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
740 |
+
|
741 |
+
|
742 |
+
class IPAdapterFull(IPAdapterPlus):
|
743 |
+
"""IP-Adapter with full features"""
|
744 |
+
|
745 |
+
def init_proj(self):
|
746 |
+
image_proj_model = MLPProjModel(
|
747 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
748 |
+
clip_embeddings_dim=self.image_encoder.config.hidden_size,
|
749 |
+
).to(self.device, dtype=torch.float16)
|
750 |
+
return image_proj_model
|
751 |
+
|
752 |
+
|
753 |
+
class IPAdapterPlusXL(IPAdapter):
|
754 |
+
"""SDXL"""
|
755 |
+
|
756 |
+
def init_proj(self):
|
757 |
+
image_proj_model = Resampler(
|
758 |
+
dim=1280,
|
759 |
+
depth=4,
|
760 |
+
dim_head=64,
|
761 |
+
heads=20,
|
762 |
+
num_queries=self.num_tokens,
|
763 |
+
embedding_dim=self.image_encoder.config.hidden_size,
|
764 |
+
output_dim=self.pipe.unet.config.cross_attention_dim,
|
765 |
+
ff_mult=4,
|
766 |
+
).to(self.device, dtype=torch.float16)
|
767 |
+
return image_proj_model
|
768 |
+
|
769 |
+
@torch.inference_mode()
|
770 |
+
def get_image_embeds(self, pil_image):
|
771 |
+
if isinstance(pil_image, Image.Image):
|
772 |
+
pil_image = [pil_image]
|
773 |
+
clip_image = self.clip_image_processor(
|
774 |
+
images=pil_image, return_tensors="pt"
|
775 |
+
).pixel_values
|
776 |
+
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
777 |
+
clip_image_embeds = self.image_encoder(
|
778 |
+
clip_image, output_hidden_states=True
|
779 |
+
).hidden_states[-2]
|
780 |
+
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
781 |
+
uncond_clip_image_embeds = self.image_encoder(
|
782 |
+
torch.zeros_like(clip_image), output_hidden_states=True
|
783 |
+
).hidden_states[-2]
|
784 |
+
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
785 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
786 |
+
|
787 |
+
def generate(
|
788 |
+
self,
|
789 |
+
pil_image,
|
790 |
+
prompt=None,
|
791 |
+
negative_prompt=None,
|
792 |
+
scale=1.0,
|
793 |
+
num_samples=4,
|
794 |
+
seed=None,
|
795 |
+
num_inference_steps=30,
|
796 |
+
**kwargs,
|
797 |
+
):
|
798 |
+
self.set_scale(scale)
|
799 |
+
|
800 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
801 |
+
|
802 |
+
if prompt is None:
|
803 |
+
prompt = "best quality, high quality"
|
804 |
+
if negative_prompt is None:
|
805 |
+
negative_prompt = (
|
806 |
+
"monochrome, lowres, bad anatomy, worst quality, low quality"
|
807 |
+
)
|
808 |
+
|
809 |
+
if not isinstance(prompt, List):
|
810 |
+
prompt = [prompt] * num_prompts
|
811 |
+
if not isinstance(negative_prompt, List):
|
812 |
+
negative_prompt = [negative_prompt] * num_prompts
|
813 |
+
|
814 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
|
815 |
+
pil_image
|
816 |
+
)
|
817 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
818 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
819 |
+
image_prompt_embeds = image_prompt_embeds.view(
|
820 |
+
bs_embed * num_samples, seq_len, -1
|
821 |
+
)
|
822 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
|
823 |
+
1, num_samples, 1
|
824 |
+
)
|
825 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
|
826 |
+
bs_embed * num_samples, seq_len, -1
|
827 |
+
)
|
828 |
+
|
829 |
+
with torch.inference_mode():
|
830 |
+
(
|
831 |
+
prompt_embeds,
|
832 |
+
negative_prompt_embeds,
|
833 |
+
pooled_prompt_embeds,
|
834 |
+
negative_pooled_prompt_embeds,
|
835 |
+
) = self.pipe.encode_prompt(
|
836 |
+
prompt,
|
837 |
+
num_images_per_prompt=num_samples,
|
838 |
+
do_classifier_free_guidance=True,
|
839 |
+
negative_prompt=negative_prompt,
|
840 |
+
)
|
841 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
842 |
+
negative_prompt_embeds = torch.cat(
|
843 |
+
[negative_prompt_embeds, uncond_image_prompt_embeds], dim=1
|
844 |
+
)
|
845 |
+
|
846 |
+
generator = get_generator(seed, self.device)
|
847 |
+
|
848 |
+
images = self.pipe(
|
849 |
+
prompt_embeds=prompt_embeds,
|
850 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
851 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
852 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
853 |
+
num_inference_steps=num_inference_steps,
|
854 |
+
generator=generator,
|
855 |
+
**kwargs,
|
856 |
+
).images
|
857 |
+
|
858 |
+
return images
|
ip_adapter_instantstyle/resampler.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
|
11 |
+
|
12 |
+
# FFN
|
13 |
+
def FeedForward(dim, mult=4):
|
14 |
+
inner_dim = int(dim * mult)
|
15 |
+
return nn.Sequential(
|
16 |
+
nn.LayerNorm(dim),
|
17 |
+
nn.Linear(dim, inner_dim, bias=False),
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(inner_dim, dim, bias=False),
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def reshape_tensor(x, heads):
|
24 |
+
bs, length, width = x.shape
|
25 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
26 |
+
x = x.view(bs, length, heads, -1)
|
27 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
28 |
+
x = x.transpose(1, 2)
|
29 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
30 |
+
x = x.reshape(bs, heads, length, -1)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class PerceiverAttention(nn.Module):
|
35 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
36 |
+
super().__init__()
|
37 |
+
self.scale = dim_head**-0.5
|
38 |
+
self.dim_head = dim_head
|
39 |
+
self.heads = heads
|
40 |
+
inner_dim = dim_head * heads
|
41 |
+
|
42 |
+
self.norm1 = nn.LayerNorm(dim)
|
43 |
+
self.norm2 = nn.LayerNorm(dim)
|
44 |
+
|
45 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
46 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
47 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
48 |
+
|
49 |
+
def forward(self, x, latents):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): image features
|
53 |
+
shape (b, n1, D)
|
54 |
+
latent (torch.Tensor): latent features
|
55 |
+
shape (b, n2, D)
|
56 |
+
"""
|
57 |
+
x = self.norm1(x)
|
58 |
+
latents = self.norm2(latents)
|
59 |
+
|
60 |
+
b, l, _ = latents.shape
|
61 |
+
|
62 |
+
q = self.to_q(latents)
|
63 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
64 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
65 |
+
|
66 |
+
q = reshape_tensor(q, self.heads)
|
67 |
+
k = reshape_tensor(k, self.heads)
|
68 |
+
v = reshape_tensor(v, self.heads)
|
69 |
+
|
70 |
+
# attention
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
72 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
73 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
74 |
+
out = weight @ v
|
75 |
+
|
76 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
77 |
+
|
78 |
+
return self.to_out(out)
|
79 |
+
|
80 |
+
|
81 |
+
class Resampler(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
dim=1024,
|
85 |
+
depth=8,
|
86 |
+
dim_head=64,
|
87 |
+
heads=16,
|
88 |
+
num_queries=8,
|
89 |
+
embedding_dim=768,
|
90 |
+
output_dim=1024,
|
91 |
+
ff_mult=4,
|
92 |
+
max_seq_len: int = 257, # CLIP tokens + CLS token
|
93 |
+
apply_pos_emb: bool = False,
|
94 |
+
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
98 |
+
|
99 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
100 |
+
|
101 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
102 |
+
|
103 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
104 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
105 |
+
|
106 |
+
self.to_latents_from_mean_pooled_seq = (
|
107 |
+
nn.Sequential(
|
108 |
+
nn.LayerNorm(dim),
|
109 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
110 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
111 |
+
)
|
112 |
+
if num_latents_mean_pooled > 0
|
113 |
+
else None
|
114 |
+
)
|
115 |
+
|
116 |
+
self.layers = nn.ModuleList([])
|
117 |
+
for _ in range(depth):
|
118 |
+
self.layers.append(
|
119 |
+
nn.ModuleList(
|
120 |
+
[
|
121 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
122 |
+
FeedForward(dim=dim, mult=ff_mult),
|
123 |
+
]
|
124 |
+
)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
if self.pos_emb is not None:
|
129 |
+
n, device = x.shape[1], x.device
|
130 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
131 |
+
x = x + pos_emb
|
132 |
+
|
133 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
134 |
+
|
135 |
+
x = self.proj_in(x)
|
136 |
+
|
137 |
+
if self.to_latents_from_mean_pooled_seq:
|
138 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
139 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
140 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
141 |
+
|
142 |
+
for attn, ff in self.layers:
|
143 |
+
latents = attn(x, latents) + latents
|
144 |
+
latents = ff(latents) + latents
|
145 |
+
|
146 |
+
latents = self.proj_out(latents)
|
147 |
+
return self.norm_out(latents)
|
148 |
+
|
149 |
+
|
150 |
+
def masked_mean(t, *, dim, mask=None):
|
151 |
+
if mask is None:
|
152 |
+
return t.mean(dim=dim)
|
153 |
+
|
154 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
155 |
+
mask = rearrange(mask, "b n -> b n 1")
|
156 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
157 |
+
|
158 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
ip_adapter_instantstyle/utils.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
attn_maps = {}
|
7 |
+
def hook_fn(name):
|
8 |
+
def forward_hook(module, input, output):
|
9 |
+
if hasattr(module.processor, "attn_map"):
|
10 |
+
attn_maps[name] = module.processor.attn_map
|
11 |
+
del module.processor.attn_map
|
12 |
+
|
13 |
+
return forward_hook
|
14 |
+
|
15 |
+
def register_cross_attention_hook(unet):
|
16 |
+
for name, module in unet.named_modules():
|
17 |
+
if name.split('.')[-1].startswith('attn2'):
|
18 |
+
module.register_forward_hook(hook_fn(name))
|
19 |
+
|
20 |
+
return unet
|
21 |
+
|
22 |
+
def upscale(attn_map, target_size):
|
23 |
+
attn_map = torch.mean(attn_map, dim=0)
|
24 |
+
attn_map = attn_map.permute(1,0)
|
25 |
+
temp_size = None
|
26 |
+
|
27 |
+
for i in range(0,5):
|
28 |
+
scale = 2 ** i
|
29 |
+
if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
|
30 |
+
temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
|
31 |
+
break
|
32 |
+
|
33 |
+
assert temp_size is not None, "temp_size cannot is None"
|
34 |
+
|
35 |
+
attn_map = attn_map.view(attn_map.shape[0], *temp_size)
|
36 |
+
|
37 |
+
attn_map = F.interpolate(
|
38 |
+
attn_map.unsqueeze(0).to(dtype=torch.float32),
|
39 |
+
size=target_size,
|
40 |
+
mode='bilinear',
|
41 |
+
align_corners=False
|
42 |
+
)[0]
|
43 |
+
|
44 |
+
attn_map = torch.softmax(attn_map, dim=0)
|
45 |
+
return attn_map
|
46 |
+
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
|
47 |
+
|
48 |
+
idx = 0 if instance_or_negative else 1
|
49 |
+
net_attn_maps = []
|
50 |
+
|
51 |
+
for name, attn_map in attn_maps.items():
|
52 |
+
attn_map = attn_map.cpu() if detach else attn_map
|
53 |
+
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
|
54 |
+
attn_map = upscale(attn_map, image_size)
|
55 |
+
net_attn_maps.append(attn_map)
|
56 |
+
|
57 |
+
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
|
58 |
+
|
59 |
+
return net_attn_maps
|
60 |
+
|
61 |
+
def attnmaps2images(net_attn_maps):
|
62 |
+
|
63 |
+
#total_attn_scores = 0
|
64 |
+
images = []
|
65 |
+
|
66 |
+
for attn_map in net_attn_maps:
|
67 |
+
attn_map = attn_map.cpu().numpy()
|
68 |
+
#total_attn_scores += attn_map.mean().item()
|
69 |
+
|
70 |
+
normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
|
71 |
+
normalized_attn_map = normalized_attn_map.astype(np.uint8)
|
72 |
+
#print("norm: ", normalized_attn_map.shape)
|
73 |
+
image = Image.fromarray(normalized_attn_map)
|
74 |
+
|
75 |
+
#image = fix_save_attn_map(attn_map)
|
76 |
+
images.append(image)
|
77 |
+
|
78 |
+
#print(total_attn_scores)
|
79 |
+
return images
|
80 |
+
def is_torch2_available():
|
81 |
+
return hasattr(F, "scaled_dot_product_attention")
|
82 |
+
|
83 |
+
def get_generator(seed, device):
|
84 |
+
|
85 |
+
if seed is not None:
|
86 |
+
if isinstance(seed, list):
|
87 |
+
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
|
88 |
+
else:
|
89 |
+
generator = torch.Generator(device).manual_seed(seed)
|
90 |
+
else:
|
91 |
+
generator = None
|
92 |
+
|
93 |
+
return generator
|
marble.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
|
7 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
8 |
+
from PIL import Image, ImageChops, ImageEnhance
|
9 |
+
from rembg import new_session, remove
|
10 |
+
from transformers import DPTForDepthEstimation, DPTImageProcessor
|
11 |
+
|
12 |
+
from ip_adapter_instantstyle import IPAdapterXL
|
13 |
+
from ip_adapter_instantstyle.utils import register_cross_attention_hook
|
14 |
+
from parametric_control_mlp import control_mlp
|
15 |
+
|
16 |
+
file_dir = os.path.dirname(os.path.abspath(__file__))
|
17 |
+
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
18 |
+
image_encoder_path = "models/image_encoder"
|
19 |
+
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
|
20 |
+
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
|
21 |
+
|
22 |
+
# Cache for rembg sessions
|
23 |
+
_session_cache = None
|
24 |
+
CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]
|
25 |
+
|
26 |
+
|
27 |
+
def get_session():
|
28 |
+
global _session_cache
|
29 |
+
if _session_cache is None:
|
30 |
+
_session_cache = new_session()
|
31 |
+
return _session_cache
|
32 |
+
|
33 |
+
|
34 |
+
def setup_control_mlps(
|
35 |
+
features: int = 1024, device: str = "cuda", dtype: torch.dtype = torch.float16
|
36 |
+
) -> Dict[str, torch.nn.Module]:
|
37 |
+
ret = {}
|
38 |
+
for mlp in CONTROL_MLPS:
|
39 |
+
ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
|
40 |
+
return ret
|
41 |
+
|
42 |
+
|
43 |
+
def setup_control_mlp(
|
44 |
+
material_parameter: str,
|
45 |
+
features: int = 1024,
|
46 |
+
device: str = "cuda",
|
47 |
+
dtype: torch.dtype = torch.float16,
|
48 |
+
):
|
49 |
+
net = control_mlp(features)
|
50 |
+
net.load_state_dict(
|
51 |
+
torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
|
52 |
+
)
|
53 |
+
net.to(device, dtype=dtype)
|
54 |
+
net.eval()
|
55 |
+
return net
|
56 |
+
|
57 |
+
|
58 |
+
def download_ip_adapter():
|
59 |
+
repo_id = "h94/IP-Adapter"
|
60 |
+
target_folders = ["models/", "sdxl_models/"]
|
61 |
+
local_dir = file_dir
|
62 |
+
|
63 |
+
# Check if folders exist and contain files
|
64 |
+
folders_exist = all(
|
65 |
+
os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
|
66 |
+
)
|
67 |
+
|
68 |
+
if folders_exist:
|
69 |
+
# Check if any of the target folders are empty
|
70 |
+
folders_empty = any(
|
71 |
+
len(os.listdir(os.path.join(local_dir, folder))) == 0
|
72 |
+
for folder in target_folders
|
73 |
+
)
|
74 |
+
if not folders_empty:
|
75 |
+
print("IP-Adapter files already downloaded. Skipping download.")
|
76 |
+
return
|
77 |
+
|
78 |
+
# List all files in the repo
|
79 |
+
all_files = list_repo_files(repo_id)
|
80 |
+
|
81 |
+
# Filter for files in the desired folders
|
82 |
+
filtered_files = [
|
83 |
+
f for f in all_files if any(f.startswith(folder) for folder in target_folders)
|
84 |
+
]
|
85 |
+
|
86 |
+
# Download each file
|
87 |
+
for file_path in filtered_files:
|
88 |
+
local_path = hf_hub_download(
|
89 |
+
repo_id=repo_id,
|
90 |
+
filename=file_path,
|
91 |
+
local_dir=local_dir,
|
92 |
+
local_dir_use_symlinks=False,
|
93 |
+
)
|
94 |
+
print(f"Downloaded: {file_path} to {local_path}")
|
95 |
+
|
96 |
+
|
97 |
+
def setup_pipeline(
|
98 |
+
device: str = "cuda",
|
99 |
+
dtype: torch.dtype = torch.float16,
|
100 |
+
):
|
101 |
+
download_ip_adapter()
|
102 |
+
|
103 |
+
cur_block = ("up", 0, 1)
|
104 |
+
|
105 |
+
controlnet = ControlNetModel.from_pretrained(
|
106 |
+
controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
|
107 |
+
).to(device)
|
108 |
+
|
109 |
+
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
|
110 |
+
base_model_path,
|
111 |
+
controlnet=controlnet,
|
112 |
+
use_safetensors=True,
|
113 |
+
torch_dtype=dtype,
|
114 |
+
add_watermarker=False,
|
115 |
+
).to(device)
|
116 |
+
|
117 |
+
pipe.unet = register_cross_attention_hook(pipe.unet)
|
118 |
+
|
119 |
+
block_name = (
|
120 |
+
cur_block[0]
|
121 |
+
+ "_blocks."
|
122 |
+
+ str(cur_block[1])
|
123 |
+
+ ".attentions."
|
124 |
+
+ str(cur_block[2])
|
125 |
+
)
|
126 |
+
|
127 |
+
print("Testing block {}".format(block_name))
|
128 |
+
|
129 |
+
return IPAdapterXL(
|
130 |
+
pipe,
|
131 |
+
os.path.join(file_dir, image_encoder_path),
|
132 |
+
os.path.join(file_dir, ip_ckpt),
|
133 |
+
device,
|
134 |
+
target_blocks=[block_name],
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
|
139 |
+
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
140 |
+
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
|
141 |
+
model.to(device, dtype=dtype)
|
142 |
+
model.eval()
|
143 |
+
return model, image_processor
|
144 |
+
|
145 |
+
|
146 |
+
def run_dpt_depth(
|
147 |
+
image: Image.Image, model, processor, device: str = "cuda"
|
148 |
+
) -> Image.Image:
|
149 |
+
"""Run DPT depth estimation on an image."""
|
150 |
+
# Prepare image
|
151 |
+
inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
|
152 |
+
|
153 |
+
# Get depth prediction
|
154 |
+
with torch.no_grad():
|
155 |
+
depth_map = model(**inputs).predicted_depth
|
156 |
+
|
157 |
+
# Now normalize to 0-1 range
|
158 |
+
depth_map = (depth_map - depth_map.min()) / (
|
159 |
+
depth_map.max() - depth_map.min() + 1e-7
|
160 |
+
)
|
161 |
+
depth_map = depth_map.clip(0, 1) * 255
|
162 |
+
|
163 |
+
# Convert to PIL Image
|
164 |
+
depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
|
165 |
+
return Image.fromarray(depth_map).resize((1024, 1024))
|
166 |
+
|
167 |
+
|
168 |
+
def prepare_mask(image: Image.Image) -> Image.Image:
|
169 |
+
"""Prepare mask from image using rembg."""
|
170 |
+
rm_bg = remove(image, session=get_session())
|
171 |
+
target_mask = (
|
172 |
+
rm_bg.convert("RGB")
|
173 |
+
.point(lambda x: 0 if x < 1 else 255)
|
174 |
+
.convert("L")
|
175 |
+
.convert("RGB")
|
176 |
+
)
|
177 |
+
return target_mask.resize((1024, 1024))
|
178 |
+
|
179 |
+
|
180 |
+
def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
181 |
+
"""Prepare initial image for inpainting."""
|
182 |
+
|
183 |
+
# Create grayscale version
|
184 |
+
gray_image = image.convert("L").convert("RGB")
|
185 |
+
gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)
|
186 |
+
|
187 |
+
# Create mask inversions
|
188 |
+
invert_mask = ImageChops.invert(mask)
|
189 |
+
|
190 |
+
# Combine images
|
191 |
+
grayscale_img = ImageChops.darker(gray_image, mask)
|
192 |
+
img_black_mask = ImageChops.darker(image, invert_mask)
|
193 |
+
init_img = ImageChops.lighter(img_black_mask, grayscale_img)
|
194 |
+
|
195 |
+
return init_img.resize((1024, 1024))
|
196 |
+
|
197 |
+
|
198 |
+
def run_parametric_control(
|
199 |
+
ip_model,
|
200 |
+
target_image: Image.Image,
|
201 |
+
edit_mlps: dict[torch.nn.Module, float],
|
202 |
+
texture_image: Image.Image = None,
|
203 |
+
num_inference_steps: int = 30,
|
204 |
+
seed: int = 42,
|
205 |
+
depth_map: Image.Image = None,
|
206 |
+
mask: Image.Image = None,
|
207 |
+
) -> Image.Image:
|
208 |
+
"""Run parametric control with metallic and roughness adjustments."""
|
209 |
+
# Get depth map
|
210 |
+
if depth_map is None:
|
211 |
+
model, processor = get_dpt_model()
|
212 |
+
depth_map = run_dpt_depth(target_image, model, processor)
|
213 |
+
else:
|
214 |
+
depth_map = depth_map.resize((1024, 1024))
|
215 |
+
|
216 |
+
# Prepare mask and init image
|
217 |
+
if mask is None:
|
218 |
+
mask = prepare_mask(target_image)
|
219 |
+
else:
|
220 |
+
mask = mask.resize((1024, 1024))
|
221 |
+
|
222 |
+
if texture_image is None:
|
223 |
+
texture_image = target_image
|
224 |
+
|
225 |
+
init_img = prepare_init_image(target_image, mask)
|
226 |
+
|
227 |
+
# Generate edit
|
228 |
+
images = ip_model.generate_parametric_edits(
|
229 |
+
texture_image,
|
230 |
+
image=init_img,
|
231 |
+
control_image=depth_map,
|
232 |
+
mask_image=mask,
|
233 |
+
controlnet_conditioning_scale=1.0,
|
234 |
+
num_samples=1,
|
235 |
+
num_inference_steps=num_inference_steps,
|
236 |
+
seed=seed,
|
237 |
+
edit_mlps=edit_mlps,
|
238 |
+
strength=1.0,
|
239 |
+
)
|
240 |
+
|
241 |
+
return images[0]
|
242 |
+
|
243 |
+
|
244 |
+
def run_blend(
|
245 |
+
ip_model,
|
246 |
+
target_image: Image.Image,
|
247 |
+
texture_image1: Image.Image,
|
248 |
+
texture_image2: Image.Image,
|
249 |
+
edit_strength: float = 0.0,
|
250 |
+
num_inference_steps: int = 20,
|
251 |
+
seed: int = 1,
|
252 |
+
depth_map: Image.Image = None,
|
253 |
+
mask: Image.Image = None,
|
254 |
+
) -> Image.Image:
|
255 |
+
"""Run blending between two texture images."""
|
256 |
+
# Get depth map
|
257 |
+
if depth_map is None:
|
258 |
+
model, processor = get_dpt_model()
|
259 |
+
depth_map = run_dpt_depth(target_image, model, processor)
|
260 |
+
else:
|
261 |
+
depth_map = depth_map.resize((1024, 1024))
|
262 |
+
|
263 |
+
# Prepare mask and init image
|
264 |
+
if mask is None:
|
265 |
+
mask = prepare_mask(target_image)
|
266 |
+
else:
|
267 |
+
mask = mask.resize((1024, 1024))
|
268 |
+
init_img = prepare_init_image(target_image, mask)
|
269 |
+
|
270 |
+
# Generate edit
|
271 |
+
images = ip_model.generate_edit(
|
272 |
+
start_image=texture_image1,
|
273 |
+
pil_image=texture_image1,
|
274 |
+
pil_image2=texture_image2,
|
275 |
+
image=init_img,
|
276 |
+
control_image=depth_map,
|
277 |
+
mask_image=mask,
|
278 |
+
controlnet_conditioning_scale=1.0,
|
279 |
+
num_samples=1,
|
280 |
+
num_inference_steps=num_inference_steps,
|
281 |
+
seed=seed,
|
282 |
+
edit_strength=edit_strength,
|
283 |
+
clip_strength=1.0,
|
284 |
+
strength=1.0,
|
285 |
+
)
|
286 |
+
|
287 |
+
return images[0]
|
model_weights/glow.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4a0d131d8878dd96f10dcd3476ece28664d49836562c22495781c13b6d8eea6
|
3 |
+
size 12600283
|
model_weights/metallic.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b6b1f879902533d49db1ee4a5630d13cc7ff2ec4eba57c0a5d5955a13fbb12c
|
3 |
+
size 12600447
|
model_weights/roughness.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:635a69cfe7c63f4166e3170c82e6e72c3e721ef395eb700807a03b5dc56b01a8
|
3 |
+
size 12600467
|
model_weights/transparency.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b38d72f57017878030c401428272906fa2bc6a87049165923943b7621329b798
|
3 |
+
size 12600397
|
parametric_control_mlp.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class control_mlp(nn.Module):
|
6 |
+
def __init__(self, embedding_size):
|
7 |
+
super(control_mlp, self).__init__()
|
8 |
+
|
9 |
+
self.fc1 = nn.Linear(embedding_size, 1024)
|
10 |
+
self.fc2 = nn.Linear(1024, 2048)
|
11 |
+
self.relu = nn.ReLU()
|
12 |
+
|
13 |
+
self.edit_strength_fc1 = nn.Linear(1, 128)
|
14 |
+
self.edit_strength_fc2 = nn.Linear(128, 2)
|
15 |
+
|
16 |
+
def forward(self, x, edit_strength):
|
17 |
+
x = self.relu(self.fc1(x))
|
18 |
+
x = self.fc2(x)
|
19 |
+
|
20 |
+
edit_strength = self.relu(self.edit_strength_fc1(edit_strength.unsqueeze(1)))
|
21 |
+
edit_strength = self.edit_strength_fc2(edit_strength)
|
22 |
+
|
23 |
+
edit_strength1, edit_strength2 = edit_strength[:, 0], edit_strength[:, 1]
|
24 |
+
# print(edit_strength1.shape)
|
25 |
+
# exit()
|
26 |
+
|
27 |
+
output = (
|
28 |
+
edit_strength1.unsqueeze(1) * x[:, :1024]
|
29 |
+
+ edit_strength2.unsqueeze(1) * x[:, 1024:]
|
30 |
+
)
|
31 |
+
|
32 |
+
return output
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.30
|
2 |
+
rembg[gpu]
|
3 |
+
einops==0.7.0
|
4 |
+
transformers==4.27.4
|
5 |
+
opencv-python==4.7.0.68
|
6 |
+
accelerate==0.26.1
|
7 |
+
timm==0.6.12
|
8 |
+
torch==2.3.0
|
9 |
+
torchvision==0.18.0
|
10 |
+
huggingface_hub==0.30.2
|