Commit · faed889
Parent(s): 1fe897a

extend CLIP text encoder to 97 tokens

Files changed:
- adaface/adaface_wrapper.py  +21 -3
- adaface/util.py             +20 -0
- app.py                      +29 -23
adaface/adaface_wrapper.py CHANGED

@@ -14,7 +14,7 @@ from diffusers import (
     LCMScheduler,
 )
 from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
-from adaface.util import UNetEnsemble
+from adaface.util import UNetEnsemble, extend_nn_embedding
 from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
 from safetensors.torch import load_file as safetensors_load_file

@@ -27,7 +27,7 @@ class AdaFaceWrapper(nn.Module):
                  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
                  enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
                  num_inference_steps=50, subject_string='z', negative_prompt=None,
-                 use_840k_vae=False, use_ds_text_encoder=False,
+                 max_prompt_length=77, use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
                  attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,

@@ -56,6 +56,9 @@ class AdaFaceWrapper(nn.Module):
 
         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
+
+        self.max_prompt_length = max_prompt_length
+
         self.use_840k_vae = use_840k_vae
         self.use_ds_text_encoder = use_ds_text_encoder
         self.main_unet_filepath = main_unet_filepath

@@ -199,6 +202,21 @@ class AdaFaceWrapper(nn.Module):
 
             pipeline.unet = unet2
 
+        # Extending prompt length is for SD 1.5 only.
+        if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
+            # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_length, 768]
+            # We reuse the last EL position embeddings for the new position embeddings.
+            # If we use the "neat" way, i.e., initialize CLIPTextModel with a CLIPTextConfig with
+            # a larger max_position_embeddings, and set ignore_mismatched_sizes=True,
+            # then the old position embeddings won't be loaded from the pretrained ckpt,
+            # leading to degenerated performance.
+            EL = self.max_prompt_length - 77
+            # position_embedding.weight: [77, 768] -> [max_length, 768]
+            new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
+                                                         pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
+            pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
+            pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")

@@ -715,7 +733,7 @@ class AdaFaceWrapper(nn.Module):
                 ref_img_strength=0.8, generator=None,
                 ablate_prompt_only_placeholders=False,
                 ablate_prompt_no_placeholders=False,
-                ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', '
+                ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img1', 'img2'.
                 nonmix_prompt_emb_weight=0,
                 repeat_prompt_for_each_encoder=True,
                 verbose=False):
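Note: the snippet below is a minimal, self-contained sketch (not part of this commit) of the same idea, namely extending the CLIP text encoder's position embedding from 77 to 97 rows by reusing the last 20 pretrained rows rather than re-initializing a larger table. Loading "openai/clip-vit-large-patch14" via transformers' CLIPTextModel directly is an assumption for illustration; the commit performs this surgery on the text encoder already inside the diffusers pipeline.

    import torch
    import torch.nn as nn
    from transformers import CLIPTextModel

    # Assumed stand-in for the SD 1.5 text encoder used by the pipeline.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings   = text_encoder.text_model.embeddings

    max_prompt_length = 97
    EL = max_prompt_length - 77                     # number of extra positions

    old_pos_emb = embeddings.position_embedding     # nn.Embedding(77, 768)
    new_pos_emb = nn.Embedding(max_prompt_length, old_pos_emb.embedding_dim,
                               device=old_pos_emb.weight.device,
                               dtype=old_pos_emb.weight.dtype)
    # Keep all pretrained rows, and fill the EL new slots with the last EL
    # pretrained rows (the same choice as extend_nn_embedding in this commit).
    new_pos_emb.weight.data[:77] = old_pos_emb.weight.data
    new_pos_emb.weight.data[77:] = old_pos_emb.weight.data[-EL:]

    embeddings.position_embedding = new_pos_emb
    embeddings.position_ids = torch.arange(max_prompt_length).unsqueeze(0)

    # The encoder now accepts 97-token inputs: prints torch.Size([1, 97, 768]).
    dummy_ids = torch.zeros(1, max_prompt_length, dtype=torch.long)
    print(text_encoder(input_ids=dummy_ids).last_hidden_state.shape)

Presumably the tokenizer's model_max_length also has to be raised to match for full 97-token prompts; that part is not shown in this hunk.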
adaface/util.py CHANGED

@@ -73,6 +73,26 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
 
 
+# new_token_embeddings: [new_num_tokens, 768].
+def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
+    emb_dim        = old_nn_embedding.embedding_dim
+    num_old_tokens = old_nn_embedding.num_embeddings
+    num_new_tokens = new_token_embeddings.shape[0]
+    num_tokens2    = num_old_tokens + num_new_tokens
+
+    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
+                                    device=old_nn_embedding.weight.device,
+                                    dtype=old_nn_embedding.weight.dtype)
+
+    old_num_tokens = old_nn_embedding.weight.shape[0]
+    # Copy the first old_num_tokens embeddings from old_nn_embedding to new_nn_embedding.
+    new_nn_embedding.weight.data[:old_num_tokens] = old_nn_embedding.weight.data
+    # Copy the new embeddings to new_nn_embedding.
+    new_nn_embedding.weight.data[old_num_tokens:] = new_token_embeddings
+
+    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
+    return new_nn_embedding
+
 # Revised from RevGrad, by removing the grad negation.
 class ScaleGrad(torch.autograd.Function):
     @staticmethod
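Note: a quick sanity check of extend_nn_embedding, not part of the commit, assuming the function is importable as shown in the import hunk above. The toy sizes mirror the 77 -> 97 extension used for the position embedding.

    import torch
    import torch.nn as nn
    from adaface.util import extend_nn_embedding

    old_emb  = nn.Embedding(77, 768)
    # Reuse the last 20 rows as the initialization of the 20 new positions.
    extended = extend_nn_embedding(old_emb, old_emb.weight[-20:])

    assert extended.num_embeddings == 97
    # Rows 0..76 are copied verbatim; rows 77..96 repeat rows 57..76.
    assert torch.equal(extended.weight[:77], old_emb.weight)
    assert torch.equal(extended.weight[77:], old_emb.weight[57:])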
app.py CHANGED

@@ -50,17 +50,33 @@ parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=Tr
                     "If False, the q lora only updates query2.")
 parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
                     help="Whether to show the checkbox for disabling AdaFace")
+parser.add_argument('--show_ablate_prompt_embed_type', type=str2bool, nargs="?", const=True, default=False,
+                    help="Whether to show the dropdown for ablate prompt embeddings type")
 parser.add_argument('--extra_save_dir', type=str, default=None, help="Directory to save the generated images")
 parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
                     help="Only test the UI layout, and skip loadding the adaface model")
+parser.add_argument('--max_prompt_length', type=int, default=97,
+                    help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")
 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")
 args = parser.parse_args()
 
-
-
-
+
+if is_running_on_hf_space():
+    args.device = 'cuda:0'
+    is_on_hf_space = True
+else:
+    if args.gpu is None:
+        args.device = "cuda"
+    else:
+        args.device = f"cuda:{args.gpu}"
+    is_on_hf_space = False
+
 os.makedirs("/tmp/gradio", exist_ok=True)
+from huggingface_hub import snapshot_download
+if is_on_hf_space:
+    large_files = ["models/*", "models/**/*"]
+    snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")
 
 model_style_type2base_model_path = {
     "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",

@@ -75,22 +91,13 @@ MAX_SEED = np.iinfo(np.int32).max
 global adaface
 adaface = None
 
-if is_running_on_hf_space():
-    args.device = 'cuda:0'
-    is_on_hf_space = True
-else:
-    if args.gpu is None:
-        args.device = "cuda"
-    else:
-        args.device = f"cuda:{args.gpu}"
-    is_on_hf_space = False
-
 if not args.test_ui_only:
     adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
                              adaface_encoder_types=args.adaface_encoder_types,
                              adaface_ckpt_paths=args.adaface_ckpt_path,
                              adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
                              enabled_encoders=args.enabled_encoders,
+                             max_prompt_length=args.max_prompt_length,
                              unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                              unet_uses_attn_lora=args.unet_uses_attn_lora,
                              attn_lora_layer_names=args.attn_lora_layer_names,

@@ -120,7 +127,7 @@ def remove_back_to_files():
 @spaces.GPU
 def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                    num_images, prompt, negative_prompt, gender, highlight_face,
-                   ablate_prompt_embed_type,
+                   ablate_prompt_embed_type,
                    composition_level, seed, disable_adaface, subj_name_sig, progress=gr.Progress(track_tqdm=True)):
 
     global adaface, args

@@ -168,6 +175,12 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
     else:
         prompt = gender + ", " + prompt
 
+    if ablate_prompt_embed_type != "ada":
+        # Find the prompt_emb_type index in adaface_encoder_types
+        # adaface_encoder_types: ["consistentID", "arc2face"]
+        ablate_prompt_embed_index = args.adaface_encoder_types.index(ablate_prompt_embed_type) + 1
+        ablate_prompt_embed_type  = f"img{ablate_prompt_embed_index}"
+
     generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
     samples = adaface(noise, prompt, negative_prompt=negative_prompt,
                       guidance_scale=guidance_scale,

@@ -175,7 +188,6 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                       repeat_prompt_for_each_encoder=(composition_level >= 1),
                       ablate_prompt_no_placeholders=disable_adaface,
                       ablate_prompt_embed_type=ablate_prompt_embed_type,
-                      nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
                       verbose=True)
 
     session_signature = ",".join(image_paths + [prompt, str(seed)])

@@ -387,14 +399,8 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
                                       minimum=0, maximum=2, step=1, value=0)
 
         ablate_prompt_embed_type = gr.Dropdown(label="Ablate prompt embeddings type",
-                                               choices=["ada", "
+                                               choices=["ada", "arc2face", "consistentID"], value="ada", visible=args.show_ablate_prompt_embed_type,
                                                info="Use this type of prompt embeddings for ablation study")
-
-        nonmix_prompt_emb_weight = gr.Slider(label="Weight of ada-nonmix ID embeddings",
-                                             minimum=0.0, maximum=0.5, step=0.1, value=0,
-                                             info="Weight of ada-nonmix ID embeddings in the prompt embeddings",
-                                             visible=False)
-
 
         subj_name_sig = gr.Textbox(
             label="Nickname of Subject (optional; used to name saved images)",

@@ -497,7 +503,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
     'fn': generate_image,
     'inputs': [img_files, img_files2, guidance_scale, perturb_std, num_images, prompt,
                negative_prompt, gender, highlight_face, ablate_prompt_embed_type,
-
+               composition_level, seed, disable_adaface, subj_name_sig],
     'outputs': [out_gallery]
 }
 submit.click(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
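Note: the new generate_image logic maps the UI dropdown choice to the "imgN" embedding type that AdaFaceWrapper expects, indexed by the encoder's position in adaface_encoder_types. Below is a hypothetical standalone restatement of that mapping; the helper name and the example encoder list are illustrative only, taken from the diff's own comment.

    # Example value of --adaface_encoder_types, as in the diff's comment.
    adaface_encoder_types = ["consistentID", "arc2face"]

    def to_ablate_embed_type(choice: str) -> str:
        # "ada" passes through; encoder names become "img<1-based index>".
        if choice == "ada":
            return "ada"
        return f"img{adaface_encoder_types.index(choice) + 1}"

    assert to_ablate_embed_type("consistentID") == "img1"
    assert to_ablate_embed_type("arc2face")     == "img2"

With the flags added by this commit, the longer prompt length and the ablation dropdown would be enabled with something like "python app.py --max_prompt_length 97 --show_ablate_prompt_embed_type"; the exact invocation is illustrative.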