adaface-neurips committed
Commit faed889 · Parent: 1fe897a

extend CLIP text encoder to 97 tokens

Files changed (3):
  1. adaface/adaface_wrapper.py  +21 -3
  2. adaface/util.py             +20 -0
  3. app.py                      +29 -23
adaface/adaface_wrapper.py CHANGED
@@ -14,7 +14,7 @@ from diffusers import (
     LCMScheduler,
 )
 from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
-from adaface.util import UNetEnsemble
+from adaface.util import UNetEnsemble, extend_nn_embedding
 from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
 from safetensors.torch import load_file as safetensors_load_file
@@ -27,7 +27,7 @@ class AdaFaceWrapper(nn.Module):
                  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
                  enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
                  num_inference_steps=50, subject_string='z', negative_prompt=None,
-                 use_840k_vae=False, use_ds_text_encoder=False,
+                 max_prompt_length=77, use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
                  attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
@@ -56,6 +56,9 @@ class AdaFaceWrapper(nn.Module):

         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
+
+        self.max_prompt_length = max_prompt_length
+
         self.use_840k_vae = use_840k_vae
         self.use_ds_text_encoder = use_ds_text_encoder
         self.main_unet_filepath = main_unet_filepath
@@ -199,6 +202,21 @@ class AdaFaceWrapper(nn.Module):

             pipeline.unet = unet2

+        # Extending prompt length is for SD 1.5 only.
+        if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
+            # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_length, 768]
+            # We reuse the last EL position embeddings for the new position embeddings.
+            # If we use the "neat" way, i.e., initialize CLIPTextModel with a CLIPTextConfig with
+            # a larger max_position_embeddings, and set ignore_mismatched_sizes=True,
+            # then the old position embeddings won't be loaded from the pretrained ckpt,
+            # leading to degenerated performance.
+            EL = self.max_prompt_length - 77
+            # position_embedding.weight: [77, 768] -> [max_length, 768]
+            new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
+                                                         pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
+            pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
+            pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")
@@ -715,7 +733,7 @@ class AdaFaceWrapper(nn.Module):
                 ref_img_strength=0.8, generator=None,
                 ablate_prompt_only_placeholders=False,
                 ablate_prompt_no_placeholders=False,
-                ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', 'img'
+                ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', 'img1', 'img2'.
                 nonmix_prompt_emb_weight=0,
                 repeat_prompt_for_each_encoder=True,
                 verbose=False):
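
Note: below is a minimal standalone sketch of what this hunk does, assuming SD 1.5's CLIP text encoder ("openai/clip-vit-large-patch14", 77 positions of width 768) and a hypothetical max_prompt_length of 97; the committed code routes the same steps through the extend_nn_embedding() helper added in adaface/util.py below.

import torch
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # assumption: the SD 1.5 text encoder
embeddings = text_encoder.text_model.embeddings
max_prompt_length = 97                 # hypothetical value; app.py defaults to 97 in this commit
EL = max_prompt_length - 77            # 20 extra positions

old_pos_emb = embeddings.position_embedding                      # nn.Embedding(77, 768)
new_pos_emb = torch.nn.Embedding(max_prompt_length, old_pos_emb.embedding_dim,
                                 dtype=old_pos_emb.weight.dtype)
with torch.no_grad():
    new_pos_emb.weight[:77] = old_pos_emb.weight
    # Reuse the last EL pretrained position embeddings instead of random init,
    # so positions 77..96 start from trained values.
    new_pos_emb.weight[77:] = old_pos_emb.weight[-EL:]

embeddings.position_embedding = new_pos_emb
embeddings.position_ids = torch.arange(max_prompt_length).unsqueeze(0)

Whether the tokenizer's model_max_length is raised to match is not visible in this diff.
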
adaface/util.py CHANGED
@@ -73,6 +73,26 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))


+# new_token_embeddings: [new_num_tokens, 768].
+def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
+    emb_dim        = old_nn_embedding.embedding_dim
+    num_old_tokens = old_nn_embedding.num_embeddings
+    num_new_tokens = new_token_embeddings.shape[0]
+    num_tokens2    = num_old_tokens + num_new_tokens
+
+    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
+                                    device=old_nn_embedding.weight.device,
+                                    dtype=old_nn_embedding.weight.dtype)
+
+    old_num_tokens = old_nn_embedding.weight.shape[0]
+    # Copy the first old_num_tokens embeddings from old_nn_embedding to new_nn_embedding.
+    new_nn_embedding.weight.data[:old_num_tokens] = old_nn_embedding.weight.data
+    # Copy the new embeddings to new_nn_embedding.
+    new_nn_embedding.weight.data[old_num_tokens:] = new_token_embeddings
+
+    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
+    return new_nn_embedding
+
 # Revised from RevGrad, by removing the grad negation.
 class ScaleGrad(torch.autograd.Function):
     @staticmethod
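
Note: a quick usage sketch of the new helper with toy inputs, mirroring how adaface_wrapper.py reuses the last 20 pretrained rows.

import torch
import torch.nn as nn
from adaface.util import extend_nn_embedding

old_emb = nn.Embedding(77, 768)                  # stands in for the CLIP position embedding
extra   = old_emb.weight.data[-20:]              # reuse the last 20 rows, as in adaface_wrapper.py
new_emb = extend_nn_embedding(old_emb, extra)    # prints: Extended nn.Embedding from 77 to 97 tokens.

assert new_emb.weight.shape == (97, 768)
assert torch.equal(new_emb.weight[:77], old_emb.weight)
assert torch.equal(new_emb.weight[77:], extra)
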
app.py CHANGED
@@ -50,17 +50,33 @@ parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=Tr
                     "If False, the q lora only updates query2.")
 parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
                     help="Whether to show the checkbox for disabling AdaFace")
+parser.add_argument('--show_ablate_prompt_embed_type', type=str2bool, nargs="?", const=True, default=False,
+                    help="Whether to show the dropdown for ablate prompt embeddings type")
 parser.add_argument('--extra_save_dir', type=str, default=None, help="Directory to save the generated images")
 parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
                     help="Only test the UI layout, and skip loadding the adaface model")
+parser.add_argument('--max_prompt_length', type=int, default=97,
+                    help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")
 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")
 args = parser.parse_args()

-from huggingface_hub import snapshot_download
-large_files = ["models/*", "models/**/*"]
-snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")
+
+if is_running_on_hf_space():
+    args.device = 'cuda:0'
+    is_on_hf_space = True
+else:
+    if args.gpu is None:
+        args.device = "cuda"
+    else:
+        args.device = f"cuda:{args.gpu}"
+    is_on_hf_space = False
+
 os.makedirs("/tmp/gradio", exist_ok=True)
+from huggingface_hub import snapshot_download
+if is_on_hf_space:
+    large_files = ["models/*", "models/**/*"]
+    snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")

 model_style_type2base_model_path = {
     "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
@@ -75,22 +91,13 @@ MAX_SEED = np.iinfo(np.int32).max
 global adaface
 adaface = None

-if is_running_on_hf_space():
-    args.device = 'cuda:0'
-    is_on_hf_space = True
-else:
-    if args.gpu is None:
-        args.device = "cuda"
-    else:
-        args.device = f"cuda:{args.gpu}"
-    is_on_hf_space = False
-
 if not args.test_ui_only:
     adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
                              adaface_encoder_types=args.adaface_encoder_types,
                              adaface_ckpt_paths=args.adaface_ckpt_path,
                              adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
                              enabled_encoders=args.enabled_encoders,
+                             max_prompt_length=args.max_prompt_length,
                              unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                              unet_uses_attn_lora=args.unet_uses_attn_lora,
                              attn_lora_layer_names=args.attn_lora_layer_names,
@@ -120,7 +127,7 @@ def remove_back_to_files():
 @spaces.GPU
 def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                    num_images, prompt, negative_prompt, gender, highlight_face,
-                   ablate_prompt_embed_type, nonmix_prompt_emb_weight,
+                   ablate_prompt_embed_type,
                    composition_level, seed, disable_adaface, subj_name_sig, progress=gr.Progress(track_tqdm=True)):

     global adaface, args
@@ -168,6 +175,12 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
     else:
         prompt = gender + ", " + prompt

+    if ablate_prompt_embed_type != "ada":
+        # Find the prompt_emb_type index in adaface_encoder_types
+        # adaface_encoder_types: ["consistentID", "arc2face"]
+        ablate_prompt_embed_index = args.adaface_encoder_types.index(ablate_prompt_embed_type) + 1
+        ablate_prompt_embed_type = f"img{ablate_prompt_embed_index}"
+
     generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
     samples = adaface(noise, prompt, negative_prompt=negative_prompt,
                       guidance_scale=guidance_scale,
@@ -175,7 +188,6 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                       repeat_prompt_for_each_encoder=(composition_level >= 1),
                       ablate_prompt_no_placeholders=disable_adaface,
                       ablate_prompt_embed_type=ablate_prompt_embed_type,
-                      nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
                       verbose=True)

     session_signature = ",".join(image_paths + [prompt, str(seed)])
@@ -387,14 +399,8 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
                                      minimum=0, maximum=2, step=1, value=0)

         ablate_prompt_embed_type = gr.Dropdown(label="Ablate prompt embeddings type",
-                                               choices=["ada", "ada-nonmix", "img"], value="ada", visible=False,
+                                               choices=["ada", "arc2face", "consistentID"], value="ada", visible=args.show_ablate_prompt_embed_type,
                                                info="Use this type of prompt embeddings for ablation study")
-
-        nonmix_prompt_emb_weight = gr.Slider(label="Weight of ada-nonmix ID embeddings",
-                                             minimum=0.0, maximum=0.5, step=0.1, value=0,
-                                             info="Weight of ada-nonmix ID embeddings in the prompt embeddings",
-                                             visible=False)
-

         subj_name_sig = gr.Textbox(
             label="Nickname of Subject (optional; used to name saved images)",
@@ -497,7 +503,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
         'fn': generate_image,
         'inputs': [img_files, img_files2, guidance_scale, perturb_std, num_images, prompt,
                    negative_prompt, gender, highlight_face, ablate_prompt_embed_type,
-                   nonmix_prompt_emb_weight, composition_level, seed, disable_adaface, subj_name_sig],
+                   composition_level, seed, disable_adaface, subj_name_sig],
         'outputs': [out_gallery]
     }
     submit.click(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
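
Note: the new block in generate_image() maps an encoder name chosen in the dropdown to the "img<N>" embedding type expected by AdaFaceWrapper; a small sketch of the resulting values, assuming the ordering adaface_encoder_types = ["consistentID", "arc2face"] noted in the commit's in-code comment.

adaface_encoder_types = ["consistentID", "arc2face"]    # assumed ordering, per the in-code comment

def map_ablate_type(ablate_prompt_embed_type):
    # "ada" passes through; encoder names become 1-based "img<N>" indices.
    if ablate_prompt_embed_type != "ada":
        idx = adaface_encoder_types.index(ablate_prompt_embed_type) + 1
        ablate_prompt_embed_type = f"img{idx}"
    return ablate_prompt_embed_type

print(map_ablate_type("consistentID"))   # img1
print(map_ablate_type("arc2face"))       # img2
print(map_ablate_type("ada"))            # ada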