adaface-neurips committed
Commit 737c1a0 · 1 Parent(s): 8776445

Update code
adaface/adaface_wrapper.py CHANGED
@@ -30,7 +30,7 @@ class AdaFaceWrapper(nn.Module):
  use_840k_vae=False, use_ds_text_encoder=False,
  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
- attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
+ attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
  device='cuda', is_training=False):
  '''
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
@@ -52,7 +52,7 @@ class AdaFaceWrapper(nn.Module):
  self.q_lora_updates_query = q_lora_updates_query
  self.use_lcm = use_lcm
  self.subject_string = subject_string
- self.shrink_cross_attn = shrink_cross_attn
+ self.normalize_cross_attn = normalize_cross_attn

  self.default_scheduler_name = default_scheduler_name
  self.num_inference_steps = num_inference_steps if not use_lcm else 4
@@ -189,10 +189,10 @@ class AdaFaceWrapper(nn.Module):
  pipeline.unet = unet_ensemble

  print(f"Loaded pipeline from {self.base_model_path}.")
- if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
+ if not remove_unet and (self.unet_uses_attn_lora or self.normalize_cross_attn):
  unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
  attn_lora_layer_names=self.attn_lora_layer_names,
- shrink_cross_attn=self.shrink_cross_attn,
+ normalize_cross_attn=self.normalize_cross_attn,
  q_lora_updates_query=self.q_lora_updates_query)

  pipeline.unet = unet2
@@ -294,12 +294,11 @@ class AdaFaceWrapper(nn.Module):
  def load_unet_loras(self, unet, unet_lora_modules_state_dict,
  use_attn_lora=True, use_ffn_lora=False,
  attn_lora_layer_names=['q', 'k', 'v', 'out'],
- shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
+ normalize_cross_attn=False,
  q_lora_updates_query=False):
  attn_capture_procs, attn_opt_modules = \
  set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
  lora_rank=192, lora_scale_down=8,
- cross_attn_shrink_factor=cross_attn_shrink_factor,
  q_lora_updates_query=q_lora_updates_query)
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
  if use_ffn_lora:
@@ -343,16 +342,17 @@ class AdaFaceWrapper(nn.Module):
  print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
  self.outfeat_capture_blocks.append(unet.up_blocks[3])

- # If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
+ # If normalize_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
  # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
  set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
  use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
- shrink_cross_attn=shrink_cross_attn)
+ normalize_cross_attn=normalize_cross_attn, mix_attn_mats_in_batch=False,
+ res_hidden_states_gradscale=0)

  return unet

  def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
- shrink_cross_attn=False, q_lora_updates_query=False):
+ normalize_cross_attn=False, q_lora_updates_query=False):
  unet_lora_weight_found = False
  if isinstance(self.adaface_ckpt_paths, str):
  adaface_ckpt_paths = [self.adaface_ckpt_paths]
@@ -360,7 +360,7 @@ class AdaFaceWrapper(nn.Module):
  adaface_ckpt_paths = self.adaface_ckpt_paths

  for adaface_ckpt_path in adaface_ckpt_paths:
- ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
+ ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
  if 'unet_lora_modules' in ckpt_dict:
  unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
  print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
@@ -379,7 +379,7 @@ class AdaFaceWrapper(nn.Module):
  unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
  use_attn_lora=use_attn_lora,
  attn_lora_layer_names=attn_lora_layer_names,
- shrink_cross_attn=shrink_cross_attn,
+ normalize_cross_attn=normalize_cross_attn,
  q_lora_updates_query=q_lora_updates_query)
  unet.unets[i] = unet_
  print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
@@ -387,7 +387,7 @@ class AdaFaceWrapper(nn.Module):
  unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
  use_attn_lora=use_attn_lora,
  attn_lora_layer_names=attn_lora_layer_names,
- shrink_cross_attn=shrink_cross_attn,
+ normalize_cross_attn=normalize_cross_attn,
  q_lora_updates_query=q_lora_updates_query)

  return unet
@@ -612,8 +612,9 @@ class AdaFaceWrapper(nn.Module):
  # Scan prompt and replace tokens in self.placeholder_token_ids
  # with the corresponding image embeddings.
  prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
+ # prompt_embeds are the ada embeddings.
  prompt_embeds2 = prompt_embeds.clone()
- if alt_prompt_embed_type == 'img':
+ if alt_prompt_embed_type.startswith('img'):
  if self.img_prompt_embs is None:
  print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
  return prompt_embeds
@@ -628,17 +629,18 @@ class AdaFaceWrapper(nn.Module):
  breakpoint()

  repl_tokens = {}
+ ada_emb_weight = alt_prompt_emb_weights[0]
  for i in range(len(prompt_tokens)):
  if prompt_tokens[i] in self.all_placeholder_tokens:
  encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
  if prompt_tokens[i] in sublist), 0)
- alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
- prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
+ alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx + 1]
+ prompt_embeds2[:, i] = prompt_embeds2[:, i] * ada_emb_weight \
  + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
  repl_tokens[prompt_tokens[i]] = 1

  repl_token_count = len(repl_tokens)
- if np.all(np.array(alt_prompt_emb_weights) == 1):
+ if ada_emb_weight == 0:
  print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
  else:
  print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
@@ -650,7 +652,7 @@ class AdaFaceWrapper(nn.Module):
  placeholder_tokens_pos='append',
  ablate_prompt_only_placeholders=False,
  ablate_prompt_no_placeholders=False,
- ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img', 'img1', 'img2'.
  nonmix_prompt_emb_weight=0,
  repeat_prompt_for_each_encoder=True,
  device=None, verbose=False):
@@ -678,14 +680,25 @@ class AdaFaceWrapper(nn.Module):
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)

- if ablate_prompt_embed_type != 'ada':
+ if ablate_prompt_embed_type.startswith('img'):
  alt_prompt_embed_type = ablate_prompt_embed_type
- alt_prompt_emb_weights = (1, 1)
+ if alt_prompt_embed_type == 'img1':
+ # The mixing weights of ada, img1, and img2 are 0, 1, and 0.
+ alt_prompt_emb_weights = (0, 1, 0)
+ elif alt_prompt_embed_type == 'img2':
+ # The mixing weights of ada, img1, and img2 are 0, 0, and 1.
+ alt_prompt_emb_weights = (0, 0, 1)
+ else:
+ # The mixing weights of ada, img1, and img2 are 0, 1, and 1.
+ alt_prompt_emb_weights = (0, 1, 1)
  elif nonmix_prompt_emb_weight > 0:
  alt_prompt_embed_type = 'ada-nonmix'
- alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
+ # The mixing weight of ada is 1 - nonmix_prompt_emb_weight, instead of 1 - nonmix_prompt_emb_weight * 2.
+ # It means ada is mixed by this weight with both img1 and img2.
+ alt_prompt_emb_weights = (1 - nonmix_prompt_emb_weight, nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
  else:
- alt_prompt_emb_weights = (0, 0)
+ # Don't change the prompt embeddings. So we set all the mixing weights to 0.
+ alt_prompt_emb_weights = (0, 0, 0)

  if sum(alt_prompt_emb_weights) > 0:
  prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
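
Note on the hunks above: alt_prompt_emb_weights grows from a pair to a triple (w_ada, w_img1, w_img2), and each placeholder token's embedding is now blended as w_ada * ada_emb + w_img * image_emb, where w_img is selected by the encoder that owns the token. The snippet below is a minimal, self-contained sketch of that mixing convention, not the repository code; the tensor shapes, the placeholder bookkeeping, and the helper name mix_placeholder_embeddings are illustrative assumptions.

import torch

def mix_placeholder_embeddings(prompt_embeds, repl_embeddings, placeholder_positions,
                               encoder_ids, alt_prompt_emb_weights):
    # prompt_embeds:          [B, 77, 768] ada prompt embeddings.
    # repl_embeddings:        [B, N, 768] replacement (e.g. image-prompt) embeddings, one per placeholder.
    # placeholder_positions:  positions of the N placeholder tokens in the tokenized prompt.
    # encoder_ids:            which encoder (0 -> img1, 1 -> img2) each placeholder belongs to.
    # alt_prompt_emb_weights: (w_ada, w_img1, w_img2).
    w_ada = alt_prompt_emb_weights[0]
    mixed = prompt_embeds.clone()
    for j, pos in enumerate(placeholder_positions):
        w_img = alt_prompt_emb_weights[encoder_ids[j] + 1]
        mixed[:, pos] = prompt_embeds[:, pos] * w_ada + repl_embeddings[:, j] * w_img
    return mixed

# 'img1' ablation: weights (0, 1, 0) keep only the first encoder's image embeddings.
mixed = mix_placeholder_embeddings(torch.randn(1, 77, 768), torch.randn(1, 4, 768),
                                   [5, 6, 7, 8], [0, 0, 1, 1], (0, 1, 0))
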
adaface/diffusers_attn_lora_capture.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
  from typing import Optional, Tuple, Dict, Any
  from diffusers.models.attention_processor import Attention, AttnProcessor2_0
  from diffusers.utils import logging, is_torch_version, deprecate
- from diffusers.utils.torch_utils import fourier_filter
  # UNet is a diffusers PeftAdapterMixin instance.
  from diffusers.loaders.peft import PeftAdapterMixin
  from peft import LoraConfig, get_peft_model
@@ -12,7 +11,6 @@ import peft.tuners.lora as peft_lora
  from peft.tuners.lora.dora import DoraLinearLayer
  from einops import rearrange
  import math, re
- import numpy as np
  from peft.tuners.tuners_utils import BaseTunerLayer


@@ -28,7 +26,7 @@ class ScaleGrad(torch.autograd.Function):
  ctx.save_for_backward(alpha_, debug)
  output = input_
  if debug:
- print(f"input: {input_.abs().mean().item()}")
+ print(f"input: {input_.abs().mean().detach().item()}")
  return output

  @staticmethod
@@ -38,7 +36,7 @@ class ScaleGrad(torch.autograd.Function):
  if ctx.needs_input_grad[0]:
  grad_output2 = grad_output * alpha_
  if debug:
- print(f"grad_output2: {grad_output2.abs().mean().item()}")
+ print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
  else:
  grad_output2 = None
  return grad_output2, None, None
@@ -77,36 +75,11 @@ def split_indices_by_instance(indices, as_dict=False):
  indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
  return indices_by_instance

- # If do_sum, returned emb_attns is 3D. Otherwise 4D.
- # indices are applied on the first 2 dims of attn_mat.
- def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
- indices_by_instance = split_indices_by_instance(indices)
-
- # emb_attns[0]: [1, 9, 8, 64]
- # 8: 8 attention heads. Last dim 64: number of image tokens.
- emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
- if all_token_weights is not None:
- # all_token_weights: [4, 77].
- # token_weights_by_instance[0]: [1, 9, 1, 1].
- token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
- else:
- token_weights = [ 1 ] * len(indices_by_instance)
-
- # Apply token weights.
- emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
-
- # sum among K_subj_i subj embeddings -> [1, 8, 64]
- if do_sum:
- emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
- elif do_mean:
- emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
-
- emb_attns = torch.cat(emb_attns, dim=0)
- return emb_attns
-
  # Slow implementation equivalent to F.scaled_dot_product_attention.
- def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
- shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
+ def scaled_dot_product_attention(query, key, value, cross_attn_scale_factor,
+ attn_mask=None, dropout_p=0.0,
+ subj_indices=None, normalize_cross_attn=False,
+ mix_attn_mats_in_batch=False,
  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
  B, L, S = query.size(0), query.size(-2), key.size(-2)
  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
@@ -128,21 +101,39 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
  key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
  value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

- attn_weight = query @ key.transpose(-2, -1) * scale_factor
-
- if shrink_cross_attn:
- cross_attn_scale = cross_attn_shrink_factor
- else:
- cross_attn_scale = 1
-
- # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
- attn_weight += attn_bias
- attn_score = attn_weight
- attn_weight = torch.softmax(attn_weight, dim=-1)
- # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
- # But this is intended, as we want to scale down the impact of the subject embeddings
- # in the computed attention output tensors.
- attn_weight = attn_weight * cross_attn_scale
+ attn_score = query @ key.transpose(-2, -1) * scale_factor
+
+ # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_score.
+ attn_score += attn_bias
+ if mix_attn_mats_in_batch:
+ # The instances in the batch are [sc, mc]. We average their attn scores,
+ # and apply to both instances.
+ # attn_score: [2, 8, 4096, 77] -> [1, 8, 4096, 77] -> [2, 8, 4096, 77].
+ # If BLOCK_SIZE > 1, attn_score.shape[0] = 2 * BLOCK_SIZE.
+ if attn_score.shape[0] %2 != 0:
+ breakpoint()
+ attn_score_sc, attn_score_mc = attn_score.chunk(2, dim=0)
+ # Cut off the grad flow from the SC instance to the MC instance.
+ attn_score = (attn_score_sc + attn_score_mc.detach()) / 2
+ attn_score = attn_score.repeat(2, 1, 1, 1)
+ elif normalize_cross_attn:
+ if subj_indices is None:
+ breakpoint()
+ subj_indices_B, subj_indices_N = subj_indices
+ subj_attn_score = attn_score[subj_indices_B, :, :, subj_indices_N]
+ # Normalize the attention score of the subject tokens to have mean 0 across tokens,
+ # so that positive and negative scores are balanced.
+ subj_attn_score = subj_attn_score - subj_attn_score.mean(dim=2, keepdim=True).detach()
+ # cross_attn_scale is a learnable parameter, so the score will be scaled appropriately.
+ # Scale up the BP'ed gradient to cross_attn_scale_factor by 10x.
+ ca_scale_grad_scaler = gen_gradient_scaler(10)
+ subj_attn_score = subj_attn_score * ca_scale_grad_scaler(cross_attn_scale_factor)
+ attn_score2 = attn_score.clone()
+ attn_score2[subj_indices_B, :, :, subj_indices_N] = subj_attn_score
+ attn_score = attn_score2
+ # Otherwise, do nothing to attn_score.
+
+ attn_weight = torch.softmax(attn_score, dim=-1)
  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
  output = attn_weight @ value
  return output, attn_score, attn_weight
@@ -156,23 +147,25 @@ class AttnProcessor_LoRA_Capture(nn.Module):
  def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
  lora_uses_dora=True, lora_proj_layers=None,
  lora_rank: int = 192, lora_alpha: float = 16,
- cross_attn_shrink_factor: float = 0.5,
  q_lora_updates_query=False, attn_proc_idx=-1):
  super().__init__()

  self.global_enable_lora = enable_lora
  self.attn_proc_idx = attn_proc_idx
  # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
- # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
- self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
+ # By default, normalize_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
+ self.reset_attn_cache_and_flags(capture_ca_activations, False, False, enable_lora)
  self.lora_rank = lora_rank
  self.lora_alpha = lora_alpha
  self.lora_scale = self.lora_alpha / self.lora_rank
- self.cross_attn_shrink_factor = cross_attn_shrink_factor
  self.q_lora_updates_query = q_lora_updates_query

  self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
  if self.global_enable_lora:
+ # enable_lora = True iff this is a cross-attn layer in the last 3 up blocks.
+ # Since we only use cross_attn_scale_factor on cross-attn layers,
+ # we only use cross_attn_scale_factor when enable_lora is True.
+ self.cross_attn_scale_factor = nn.Parameter(torch.tensor(0.8), requires_grad=True)
  for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
  if lora_layer_name == 'q':
  self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
@@ -188,9 +181,10 @@ class AttnProcessor_LoRA_Capture(nn.Module):
  use_dora=lora_uses_dora, lora_dropout=0.1)

  # LoRA layers can be enabled/disabled dynamically.
- def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
+ def reset_attn_cache_and_flags(self, capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch, enable_lora):
  self.capture_ca_activations = capture_ca_activations
- self.shrink_cross_attn = shrink_cross_attn
+ self.normalize_cross_attn = normalize_cross_attn
+ self.mix_attn_mats_in_batch = mix_attn_mats_in_batch
  self.cached_activations = {}
  # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
  self.enable_lora = enable_lora and self.global_enable_lora
@@ -312,11 +306,14 @@ class AttnProcessor_LoRA_Capture(nn.Module):
  breakpoint()

  # the output of sdp = (batch, num_heads, seq_len, head_dim)
- if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
+ if is_cross_attn and (self.capture_ca_activations or self.normalize_cross_attn):
  hidden_states, attn_score, attn_prob = \
  scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
- dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
- cross_attn_shrink_factor=self.cross_attn_shrink_factor)
+ dropout_p=0.0, subj_indices=subj_indices,
+ normalize_cross_attn=self.normalize_cross_attn,
+ cross_attn_scale_factor=self.cross_attn_scale_factor,
+ mix_attn_mats_in_batch=self.mix_attn_mats_in_batch)
+
  else:
  # Use the faster implementation of scaled_dot_product_attention
  # when not capturing the activations or suppressing the subject attention.
@@ -452,7 +449,7 @@ def CrossAttnUpBlock2D_forward_capture(
  # Adapted from ConsistentIDPipeline:set_ip_adapter().
  # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
  def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
- lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
+ lora_rank=192, lora_scale_down=8,
  q_lora_updates_query=False):
  attn_procs = {}
  attn_capture_procs = {}
@@ -502,7 +499,6 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
  lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
  # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
  lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
- cross_attn_shrink_factor=cross_attn_shrink_factor,
  q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)

  attn_proc_idx += 1
@@ -513,6 +509,11 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
  attn_capture_procs[name] = attn_capture_proc

  if use_attn_lora:
+ cross_attn_scale_factor_name = name + "_cross_attn_scale_factor"
+ # Put cross_attn_scale_factor in attn_opt_modules, so that we can optimize and save/load it.
+ attn_opt_modules[cross_attn_scale_factor_name] = attn_capture_proc.cross_attn_scale_factor
+
+ # Put LoRA layers in attn_opt_modules, so that we can optimize and save/load them.
  for subname, module in attn_capture_proc.named_modules():
  if isinstance(module, peft_lora.LoraLayer):
  # ModuleDict doesn't allow "." in the key.
@@ -537,7 +538,7 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
  return attn_capture_procs, attn_opt_modules

  # NOTE: cross-attn layers are included in the returned lora_modules.
- def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
+ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=True, lora_rank=192, lora_alpha=16):
  # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
  # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
  # Cannot set to conv.+ as it will match added adapter module names, including
@@ -592,15 +593,18 @@ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=1
  def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
  outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
  use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
- shrink_cross_attn, res_hidden_states_gradscale):
+ normalize_cross_attn, mix_attn_mats_in_batch, res_hidden_states_gradscale):
  # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
- for attn_capture_proc in attn_capture_procs:
- attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
+ for i, attn_capture_proc in enumerate(attn_capture_procs):
+ attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch,
+ enable_lora=use_attn_lora)
  # outfeat_capture_blocks only contains the last up block, up_blocks[3].
  # It contains 3 FFN layers. We want to capture their output features.
  for block in outfeat_capture_blocks:
  block.capture_outfeats = capture_ca_activations

+ # res_hidden_states_gradscale_blocks contain the second to the last up blocks, up_blocks[1:].
+ # It's only used to set res_hidden_states_gradscale, and doesn't capture anything.
  for block in res_hidden_states_gradscale_blocks:
  block.res_hidden_states_gradscale = res_hidden_states_gradscale

@@ -639,6 +643,7 @@ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat
  block.cached_outfeats = {}
  block.capture_outfeats = False

+
  for layer_idx in captured_layer_indices:
  # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
  # 23, 24 -> 1, 2 (!! not 0, 1 !!)
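
Note on the scaled_dot_product_attention rewrite above: instead of shrinking all subject-token attention probabilities by a fixed factor, the new normalize_cross_attn path zero-centers the subject tokens' pre-softmax attention scores over the spatial positions and rescales them with the learnable cross_attn_scale_factor before the softmax. A simplified stand-alone sketch follows (assumptions: no attn_bias, no gradient scaler on the scale factor, random inputs; this is not the repository implementation).

import torch

def normalize_subject_scores(attn_score, subj_indices, cross_attn_scale_factor):
    # attn_score: [B, H, Q, T] pre-softmax cross-attention scores.
    # subj_indices: (batch_idx, token_idx) tensors selecting the subject tokens.
    subj_b, subj_t = subj_indices
    subj_score = attn_score[subj_b, :, :, subj_t]                   # [K, H, Q]
    subj_score = subj_score - subj_score.mean(dim=2, keepdim=True)  # zero mean over spatial positions
    subj_score = subj_score * cross_attn_scale_factor               # learnable rescaling
    out = attn_score.clone()
    out[subj_b, :, :, subj_t] = subj_score
    return out

attn_score   = torch.randn(2, 8, 4096, 77)
subj_indices = (torch.tensor([0, 0]), torch.tensor([5, 6]))
scale        = torch.nn.Parameter(torch.tensor(0.8))
attn_prob    = torch.softmax(normalize_subject_scores(attn_score, subj_indices, scale), dim=-1)
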
adaface/face_id_to_ada_prompt.py CHANGED
@@ -603,9 +603,13 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
  '''
  # Use the same model as ID2AdaPrompt does.
  # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
- # Note there are two "models" in the path.
+ # Note there's a second "model" in the path.
+ # Note DO use CUDAExecutionProvider during training and CPUExecutionProvider during inference.
+ # Otherwise, CPUExecutionProvider will hang DDP training,
+ # and CUDAExecutionProvider will cause OOM on huggingface spaces.
+ self.onnx_providers = ['CUDAExecutionProvider'] if self.is_training else ['CPUExecutionProvider']
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
- providers=['CPUExecutionProvider'])
+ providers=self.onnx_providers)
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
  print(f'Arc2Face Face encoder loaded on CPU.')

@@ -642,7 +646,6 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):

  def _apply(self, fn):
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
- return
  # A dirty hack to get the device of the model, passed from
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
  test_tensor = torch.zeros(1) # Create a test tensor
@@ -654,16 +657,14 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):

  if str(device) == 'cpu':
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
- providers=['CPUExecutionProvider'])
+ providers=['CPUExecutionProvider'])
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
  else:
  device_id = device.index
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
  providers=['CUDAExecutionProvider'],
- provider_options=[{"device_id": device_id,
- "cudnn_conv_algo_search": "HEURISTIC",
- "gpu_mem_limit": 2 * 1024**3
- }])
+ provider_options=[{'device_id': device_id,
+ 'cudnn_conv_algo_search': 'HEURISTIC'}])
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))

  self.device = device
@@ -739,8 +740,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
  # but diffusers will call .to(dtype) in .from_single_file(),
  # and at that moment, the consistentID specific modules are not loaded yet.
  pipe = ConsistentIDPipeline.from_single_file(base_model_path)
- pipe.load_ConsistentID_model(consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
- bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
+ pipe.load_ConsistentID_model(consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
+ bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
  pipe.to(dtype=self.dtype)
  # Since the passed-in pipe is None, this should be called during inference,
  # when the teacher ConsistentIDPipeline is not initialized.
@@ -791,7 +792,6 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):

  def _apply(self, fn):
  super()._apply(fn) # Call the parent _apply to handle parameters and buffers
- return
  # A dirty hack to get the device of the model, passed from
  # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
  test_tensor = torch.zeros(1) # Create a test tensor
@@ -809,10 +809,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
  device_id = device.index
  self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
  providers=['CUDAExecutionProvider'],
- provider_options=[{"device_id": device_id,
- "cudnn_conv_algo_search": "HEURISTIC",
- "gpu_mem_limit": 2 * 1024**3
- }])
+ provider_options=[{'device_id': device_id,
+ 'cudnn_conv_algo_search': 'HEURISTIC'}])
  self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))

  self.device = device
@@ -1277,7 +1275,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
  # No faces are found in the images, so return None embeddings.
  # We don't want to return an all-zero embedding, which is useless.
  if num_available_id_vecs == 0:
- return None, [0]
+ return None, None, [0]

  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
  # during inference, we average across the batch dim.
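
Note on the FaceAnalysis changes above: the commit settles on CUDAExecutionProvider while training (CPUExecutionProvider can hang DDP training) and CPUExecutionProvider at inference (CUDAExecutionProvider can OOM on Hugging Face Spaces), and drops the gpu_mem_limit provider option. Below is a minimal sketch of that provider-selection pattern, using the same insightface calls as the diff; the helper build_face_app is illustrative and not part of the repository.

from insightface.app import FaceAnalysis

def build_face_app(is_training: bool, device_id: int = 0) -> FaceAnalysis:
    if is_training:
        # Training: run detection/embedding on the GPU.
        app = FaceAnalysis(name='antelopev2', root='models/insightface',
                           providers=['CUDAExecutionProvider'],
                           provider_options=[{'device_id': device_id,
                                              'cudnn_conv_algo_search': 'HEURISTIC'}])
        app.prepare(ctx_id=device_id, det_size=(512, 512))
    else:
        # Inference (e.g. on Spaces): stay on the CPU.
        app = FaceAnalysis(name='antelopev2', root='models/insightface',
                           providers=['CPUExecutionProvider'])
        app.prepare(ctx_id=0, det_size=(512, 512))
    return app
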
adaface/unet_teachers.py CHANGED
@@ -62,46 +62,41 @@ class UNetTeacher(nn.Module):
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
  # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
  def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
- num_denoising_steps=1, same_t_noise_across_instances=False,
+ num_denoising_steps=1, force_uses_cfg=False, same_t_noise_across_instances=False,
  global_t_lb=0, global_t_ub=1000):
  assert num_denoising_steps <= 10

- if self.p_uses_cfg > 0:
+ # force_uses_cfg overrides p_uses_cfg.
+ if force_uses_cfg > 0:
+ self.uses_cfg = True
+ elif self.p_uses_cfg > 0:
  self.uses_cfg = np.random.rand() < self.p_uses_cfg
- if self.uses_cfg:
- # Randomly sample a cfg_scale from cfg_scale_range.
- self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
- if self.cfg_scale == 1:
- self.uses_cfg = False
-
- if self.uses_cfg:
- print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
- if negative_context is not None:
- negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
-
- # if negative_context is None, then teacher_context is a combination of
- # (one or multiple if unet_ensemble) pos_context and neg_context.
- # If negative_context is not None, then teacher_context is only pos_context.
- else:
- self.cfg_scale = 1
- print("Teacher does not use CFG.")
-
- # If negative_context is None, then teacher_context is a combination of
- # (one or multiple if unet_ensemble) pos_context and neg_context.
- # Since not uses_cfg, we only need pos_context.
- # If negative_context is not None, then teacher_context is only pos_context.
- if negative_context is None:
- teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
  else:
  # p_uses_cfg = 0. Never use CFG.
  self.uses_cfg = False
- # In this case, the student only passes pos_context to the teacher,
- # so no need to split teacher_context into pos_context and neg_context.
- # self.cfg_scale will be accessed by the student,
- # so we need to make sure it is always set correctly,
- # in case someday we want to switch from CFG to non-CFG during runtime.
  self.cfg_scale = 1

+ if self.uses_cfg:
+ # Randomly sample a cfg_scale from cfg_scale_range.
+ self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
+ print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
+ if negative_context is not None:
+ negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
+
+ # if negative_context is None, then teacher_context is a combination of
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
+ # If negative_context is not None, then teacher_context is only pos_context.
+ else:
+ self.cfg_scale = 1
+ print("Teacher does not use CFG.")
+
+ # If negative_context is None, then teacher_context is either a combination of
+ # (one or multiple if unet_ensemble) pos_context and neg_context, or only pos_context.
+ # Since not uses_cfg, we only need pos_context.
+ # If negative_context is not None, then teacher_context is only pos_context.
+ if negative_context is None:
+ teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
+
  is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
  if self.name == 'unet_ensemble':
  # teacher_context is a list of teacher contexts.
@@ -199,14 +194,20 @@ class UNetTeacher(nn.Module):
  teacher_pos_contexts = []
  # teacher_context is a list of teacher contexts.
  for teacher_context_i in teacher_context:
- pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
- if pos_context.shape[0] != BS:
- breakpoint()
+ if teacher_context_i.shape[0] == BS * 2:
+ pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
+ elif teacher_context_i.shape[0] == BS:
+ pos_context = teacher_context_i
+ else:
+ breakpoint()
  teacher_pos_contexts.append(pos_context)
  teacher_context = teacher_pos_contexts
  else:
- pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
- if pos_context.shape[0] != BS:
+ if teacher_context.shape[0] == BS * 2:
+ pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
+ elif teacher_context.shape[0] == BS:
+ pos_context = teacher_context
+ else:
  breakpoint()
  teacher_context = pos_context
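
Note on the second hunk above: instead of always chunking teacher_context into (pos, neg) halves, the batch size now decides whether a negative context is present. A stand-alone sketch of that convention (simplified: it raises instead of calling breakpoint()):

import torch

def extract_pos_context(teacher_context: torch.Tensor, BS: int) -> torch.Tensor:
    if teacher_context.shape[0] == BS * 2:
        # pos and neg contexts are stacked along the batch dim.
        pos_context, _neg_context = torch.chunk(teacher_context, 2, dim=0)
    elif teacher_context.shape[0] == BS:
        # Already positive-only.
        pos_context = teacher_context
    else:
        raise ValueError(f"Unexpected context batch size {teacher_context.shape[0]}")
    return pos_context

ctx = torch.randn(8, 77, 768)             # (pos, neg) stacked for BS = 4
print(extract_pos_context(ctx, 4).shape)  # torch.Size([4, 77, 768])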
 
adaface/util.py CHANGED
@@ -48,7 +48,7 @@ def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=Fals
  ts = ts + noise

  if verbose:
- print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}")
+ print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).detach().item():.03f}")

  return ts

@@ -69,7 +69,7 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
  # Compute it manually.
  l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
  norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
- print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item()))
+ print("L1: %.4f, L2: %.4f" %(l1_loss.detach().item(), l2_loss.detach().item()))
  print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))


@@ -80,7 +80,7 @@ class ScaleGrad(torch.autograd.Function):
  ctx.save_for_backward(alpha_, debug)
  output = input_
  if debug:
- print(f"input: {input_.abs().mean().item()}")
+ print(f"input: {input_.abs().mean().detach().item()}")
  return output

  @staticmethod
@@ -90,7 +90,7 @@ class ScaleGrad(torch.autograd.Function):
  if ctx.needs_input_grad[0]:
  grad_output2 = grad_output * alpha_
  if debug:
- print(f"grad_output2: {grad_output2.abs().mean().item()}")
+ print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
  else:
  grad_output2 = None
  return grad_output2, None, None
@@ -232,8 +232,8 @@ def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetens
  # consistentID specific modules are still in fp32. Will be converted to fp16
  # later with .to(device, torch_dtype) by the caller.
  pipe.load_ConsistentID_model(
- consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
- bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
+ consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
+ bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
  )
  # Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
  # because we've overloaded .to() to convert consistentID specific modules as well,
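
Note on the debug-print changes above: the prints now call .detach() before .item(), a defensive habit that keeps logging away from tensors carrying autograd history. For context, the ScaleGrad/gen_gradient_scaler pair touched by these hunks follows the usual identity-forward, scaled-backward pattern; a simplified sketch (without the debug flag and save_for_backward bookkeeping of the repository version):

import torch

class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha):
        ctx.alpha = alpha
        return input_                            # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.alpha, None     # scale gradients by alpha

def gen_gradient_scaler(alpha: float):
    return lambda x: ScaleGrad.apply(x, alpha)

x = torch.randn(3, requires_grad=True)
gen_gradient_scaler(10.0)(x).sum().backward()
print(x.grad)                                    # gradients scaled by 10
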
app.py CHANGED
@@ -24,11 +24,16 @@ parser = argparse.ArgumentParser()
24
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
25
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
26
  parser.add_argument('--adaface_ckpt_path', type=str,
27
- default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt')
28
  parser.add_argument('--model_style_type', type=str, default='photorealistic',
29
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
30
- parser.add_argument("--guidance_scale", type=float, default=8.0,
31
- help="The guidance scale for the diffusion model. Default: 8.0")
 
 
 
 
 
32
 
33
  parser.add_argument('--gpu', type=int, default=None)
34
  parser.add_argument('--ip', type=str, default="0.0.0.0")
@@ -70,7 +75,8 @@ adaface_base_model_path = model_style_type2base_model_path["photorealistic"]
70
  id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
71
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
72
  adaface_encoder_types=args.adaface_encoder_types,
73
- adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu')
 
74
 
75
  basedir = os.getcwd()
76
  savedir = os.path.join(basedir,'samples')
@@ -80,22 +86,22 @@ os.makedirs(savedir, exist_ok=True)
80
  #os.system(f"rm -rf gradio_cached_examples/")
81
 
82
  def swap_to_gallery(images):
83
- # Update uploaded_files_gallery, show files, hide clear_button_column
84
  # Or:
85
- # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
86
  return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
87
 
88
  def remove_back_to_files():
89
- # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
90
  # Or:
91
- # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
92
- return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")
93
 
94
  def get_clicked_image(data: gr.SelectData):
95
  return data.index
96
 
97
  @spaces.GPU
98
- def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale, out_image_count=4):
99
  if uploaded_image_paths is None:
100
  print("No image uploaded")
101
  return None, None, None
@@ -112,7 +118,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
112
  with torch.no_grad():
113
  adaface_subj_embs = \
114
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
115
- update_text_encoder=True)
116
 
117
  if adaface_subj_embs is None:
118
  raise gr.Error(f"Failed to detect any faces! Please try with other images")
@@ -127,6 +133,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
127
  else:
128
  prompt = "face portrait, " + prompt
129
 
 
130
  guidance_scale = min(guidance_scale, 5)
131
 
132
  # samples: A list of PIL Image instances.
@@ -134,7 +141,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
134
  samples = adaface(noise, prompt, placeholder_tokens_pos='append',
135
  guidance_scale=guidance_scale,
136
  out_image_count=out_image_count,
137
- repeat_prompt_for_each_encoder=True,
138
  verbose=True)
139
 
140
  face_paths = []
@@ -145,7 +152,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
145
  sample.save(face_path)
146
  print(f"Generated init image: {face_path}")
147
 
148
- # Update uploaded_init_img_gallery, update and hide init_img_files, hide init_clear_button_column
149
  return gr.update(value=face_paths, visible=True), gr.update(value=face_paths, visible=False), gr.update(visible=True)
150
 
151
  @spaces.GPU(duration=90)
@@ -153,7 +160,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
153
  init_image_strength, init_image_final_weight,
154
  prompt, negative_prompt, num_steps, video_length, guidance_scale,
155
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
156
- highlight_face, is_adaface_enabled, adaface_power_scale,
157
  id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
158
 
159
  global adaface, id_animator
@@ -195,10 +202,19 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
195
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
196
  update_text_encoder=True)
197
 
 
 
 
 
 
 
 
 
198
  # adaface_prompt_embeds: [1, 77, 768].
199
  adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
200
  adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
201
- repeat_prompt_for_each_encoder=True,
 
202
  verbose=True)
203
 
204
  # ID-Animator Image Embedding Initial and End Scales
@@ -267,13 +283,14 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
267
  <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Text-space Face Encoder for Face Synthesis and Processing</b>.<br>
268
 
269
  ❗️**NOTE**❗️
270
- - Support switching between three model styles: **Realistic**, **Photorealistic** and **Anime**. **Realistic** is less realistic than **Photorealistic** but has better motions.
271
  - If you change the model style, please wait for 20~30 seconds for loading new model weight before the model begins to generate images/videos.
272
 
273
  ❗️**Tips**❗️
274
  - You can upload one or more subject images for generating ID-specific video.
275
- - If the face loses focus, try enabling "Highlight face".
276
- - If the motion is weird, e.g., the prompt is "... running", try increasing the number of sampling steps.
 
277
  - Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
278
  - AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
279
  AdaFace
@@ -285,16 +302,16 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
285
 
286
  with gr.Row():
287
  with gr.Column():
288
- files = gr.File(
289
  label="Drag / Select 1 or more photos of a person's face",
290
  file_types=["image"],
291
  file_count="multiple"
292
  )
293
- files.GRADIO_CACHE = "/tmp/gradio"
294
  image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
295
- uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
296
  with gr.Column(visible=False) as clear_button_column:
297
- remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=files, size="sm")
298
 
299
  init_img_files = gr.File(
300
  label="[Optional] Generate 4 images and select 1 image",
@@ -305,7 +322,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
305
  init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
306
  # Although there's only one image, we still use columns=3, to scale down the image size.
307
  # Otherwise it will occupy the full width, and the gallery won't show the whole image.
308
- uploaded_init_img_gallery = gr.Gallery(label="Init image", visible=False, columns=3, rows=1, height=200)
309
  # placeholder is just a hint, not the real value, so we use "value='0'" instead of "placeholder='0'".
310
  init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
311
 
@@ -320,7 +337,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
320
  allow_custom_value=True,
321
  choices=[
322
  "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
323
- "portrait, walking on the beach, sunset",
324
  "portrait, in a white apron and chef hat, garnishing a gourmet dish",
325
  "portrait, dancing pose among folks in a park, waving hands",
326
  "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
@@ -328,18 +345,21 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
328
  "portrait, night view of tokyo street, neon light",
329
  "portrait, playing guitar on a boat, ocean waves",
330
  "portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
331
- "portrait, celebrating new year, fireworks",
332
- "portrait, running pose in a park",
333
  "portrait, in space suit, space helmet, walking on mars",
334
  "portrait, in superman costume, the sky ablaze with hues of orange and purple"
335
  ])
336
 
337
- highlight_face = gr.Checkbox(label="Highlight face", value=False,
338
  info="Enhance the facial features by prepending 'face portrait' to the prompt",
339
  visible=True)
340
-
341
  init_image_strength = gr.Slider(
342
- label="Init Image Strength",
343
  info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
344
  minimum=0,
345
  maximum=3,
@@ -352,7 +372,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
352
  minimum=0,
353
  maximum=2,
354
  step=0.025,
355
- value=0.1,
356
  )
357
 
358
  model_style_type = gr.Dropdown(
@@ -415,7 +435,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
415
  minimum=0.8,
416
  maximum=1.2,
417
  step=0.05,
418
- value=1.1,
419
  visible=True,
420
  )
421
 
@@ -443,7 +463,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
443
  minimum=0,
444
  maximum=1,
445
  step=0.1,
446
- value=0.1,
447
  )
448
 
449
  id_animator_anneal_steps = gr.Slider(
@@ -464,13 +484,13 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
464
  with gr.Column():
465
  result_video = gr.Video(label="Generated Animation", interactive=False)
466
 
467
- files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files_gallery, clear_button_column, files])
468
- remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, files, init_img_selected_idx])
469
 
470
  init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files,
471
- outputs=[uploaded_init_img_gallery, init_clear_button_column, init_img_files])
472
  remove_init_and_reupload.click(fn=remove_back_to_files,
473
- outputs=[uploaded_init_img_gallery, init_clear_button_column,
474
  init_img_files, init_img_selected_idx])
475
  gen_init.click(fn=check_prompt_and_model_type,
476
  inputs=[prompt, model_style_type],outputs=None).success(
@@ -479,10 +499,11 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
479
  outputs=seed,
480
  queue=False,
481
  api_name=False,
482
- ).then(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, highlight_face,
483
  guidance_scale],
484
- outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
485
- uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
486
 
487
  submit.click(fn=check_prompt_and_model_type,
488
  inputs=[prompt, model_style_type],outputs=None).success(
@@ -493,11 +514,11 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
493
  api_name=False,
494
  ).then(
495
  fn=generate_video,
496
- inputs=[image_container, files,
497
  init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
498
  prompt, negative_prompt, num_steps, video_length, guidance_scale,
499
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
500
- highlight_face, is_adaface_enabled,
501
  adaface_power_scale, id_animator_anneal_steps],
502
  outputs=[result_video]
503
  )
 
24
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
25
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
26
  parser.add_argument('--adaface_ckpt_path', type=str,
27
+ default='models/adaface/VGGface2_HQ_masks2025-05-22T17-51-19_zero3-ada-1000.pt')
28
  parser.add_argument('--model_style_type', type=str, default='photorealistic',
29
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
30
+ parser.add_argument("--guidance_scale", type=float, default=6.0,
31
+ help="The guidance scale for the diffusion model. Default: 6.0")
32
+ parser.add_argument('--num_inference_steps', type=int, default=50,
33
+ help="The number of denoising steps for image generation (NOT FOR VIDEOS). Default: 50")
34
+ parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
35
+ choices=["ada", "arc2face", "consistentID"],
36
+ help="Ablation: use the image ID embeddings instead of the Ada embeddings")
37
 
38
  parser.add_argument('--gpu', type=int, default=None)
39
  parser.add_argument('--ip', type=str, default="0.0.0.0")
 
75
  id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
76
  adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
77
  adaface_encoder_types=args.adaface_encoder_types,
78
+ adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu',
79
+ num_inference_steps=args.num_inference_steps)
80
 
81
  basedir = os.getcwd()
82
  savedir = os.path.join(basedir,'samples')
 
86
  #os.system(f"rm -rf gradio_cached_examples/")
87
 
88
  def swap_to_gallery(images):
89
+ # Show uploaded_ref_files_gallery, show clear_button_column, hide ref_files
90
  # Or:
91
+ # Show generated_init_img_gallery, show init_clear_button_column, hide init_img_files
92
  return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
93
 
94
  def remove_back_to_files():
95
+ # Hide uploaded_ref_files_gallery, hide clear_button_column, show ref_files, reset init_img_selected_idx
96
  # Or:
97
+ # Hide generated_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
98
+ return gr.update(value=None, visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")
99
 
100
  def get_clicked_image(data: gr.SelectData):
101
  return data.index
102
 
103
  @spaces.GPU
104
+ def gen_init_images(uploaded_image_paths, prompt, highlight_face, enhance_composition, guidance_scale, out_image_count=4):
105
  if uploaded_image_paths is None:
106
  print("No image uploaded")
107
  return None, None, None
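Aside: a minimal, self-contained sketch (not part of this commit) of the show/hide protocol that swap_to_gallery() and remove_back_to_files() rely on. The tuple a handler returns is matched positionally to the listener's outputs list, which is why the same helpers can drive both the reference-image widgets and the init-image widgets; the component names below are hypothetical.

import gradio as gr

def show_gallery(images):
    # (gallery, clear_column, file_box) -- matched positionally to `outputs` below.
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)

with gr.Blocks() as sketch_demo:
    file_box = gr.File(file_count="multiple", file_types=["image"])
    gallery  = gr.Gallery(label="Subject images", visible=False, columns=3)
    with gr.Column(visible=False) as clear_column:
        gr.ClearButton(value="Remove and re-upload", components=file_box, size="sm")
    file_box.upload(fn=show_gallery, inputs=file_box,
                    outputs=[gallery, clear_column, file_box])
# sketch_demo.launch()  # uncomment to try the wiring locally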
 
118
  with torch.no_grad():
119
  adaface_subj_embs = \
120
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
121
+ update_text_encoder=True)
122
 
123
  if adaface_subj_embs is None:
124
  raise gr.Error("Failed to detect any faces! Please try other images.")
 
133
  else:
134
  prompt = "face portrait, " + prompt
135
 
136
+ # Cap guidance_scale at 5.0 for init image generation.
137
  guidance_scale = min(guidance_scale, 5)
138
 
139
  # samples: A list of PIL Image instances.
 
141
  samples = adaface(noise, prompt, placeholder_tokens_pos='append',
142
  guidance_scale=guidance_scale,
143
  out_image_count=out_image_count,
144
+ repeat_prompt_for_each_encoder=enhance_composition,
145
  verbose=True)
146
 
147
  face_paths = []
 
152
  sample.save(face_path)
153
  print(f"Generated init image: {face_path}")
154
 
155
+ # Show generated_init_img_gallery, set and hide init_img_files, show init_clear_button_column
156
  return gr.update(value=face_paths, visible=True), gr.update(value=face_paths, visible=False), gr.update(visible=True)
157
 
158
  @spaces.GPU(duration=90)
 
160
  init_image_strength, init_image_final_weight,
161
  prompt, negative_prompt, num_steps, video_length, guidance_scale,
162
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
163
+ highlight_face, enhance_composition, is_adaface_enabled, adaface_power_scale,
164
  id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
165
 
166
  global adaface, id_animator
 
202
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
203
  update_text_encoder=True)
204
 
205
+ if args.ablate_prompt_embed_type != "ada":
206
+ # Find the prompt_emb_type index in adaface_encoder_types
207
+ # adaface_encoder_types: ["consistentID", "arc2face"]
208
+ ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type)
209
+ ablate_prompt_embed_type = f"img{ablate_prompt_embed_index}"
210
+ else:
211
+ ablate_prompt_embed_type = "ada"
212
+
213
  # adaface_prompt_embeds: [1, 77, 768].
214
  adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
215
  adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
216
+ ablate_prompt_embed_type=ablate_prompt_embed_type,
217
+ repeat_prompt_for_each_encoder=enhance_composition,
218
  verbose=True)
219
 
220
  # ID-Animator Image Embedding Initial and End Scales
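Aside: a worked sketch (not part of this commit) of the ablation mapping above. With the default --adaface_encoder_types ["consistentID", "arc2face"], choosing --ablate_prompt_embed_type arc2face resolves to the image-embedding token "img1", while the default "ada" keeps the Ada embeddings; the helper name below is hypothetical.

adaface_encoder_types = ["consistentID", "arc2face"]  # default from the argparse block

def resolve_ablate_embed_type(ablate_prompt_embed_type):
    if ablate_prompt_embed_type != "ada":
        # e.g. "arc2face" -> index 1 -> "img1"
        return f"img{adaface_encoder_types.index(ablate_prompt_embed_type)}"
    return "ada"

assert resolve_ablate_embed_type("consistentID") == "img0"
assert resolve_ablate_embed_type("arc2face")     == "img1"
assert resolve_ablate_embed_type("ada")          == "ada"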
 
283
  <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Text-space Face Encoder for Face Synthesis and Processing</b>.<br>
284
 
285
  ❗️**NOTE**❗️
286
+ - Supports switching among three model styles: **Photorealistic**, **Realistic** and **Anime**.
287
  - If you change the model style, please wait 20~30 seconds for the new model weights to load before images/videos are generated.
288
 
289
  ❗️**Tips**❗️
290
  - You can upload one or more subject images for generating ID-specific video.
291
+ - "Highlight face" will make the face more prominent in the generated video.
292
+ - "Enhance Composition" improves the overall scene composition of the generated video.
293
+ - "Highlight face" and "Enhance Composition" can be used together.
294
  - Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
295
  - AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
296
  AdaFace
 
302
 
303
  with gr.Row():
304
  with gr.Column():
305
+ ref_files = gr.File(
306
  label="Drag / Select 1 or more photos of a person's face",
307
  file_types=["image"],
308
  file_count="multiple"
309
  )
310
+ ref_files.GRADIO_CACHE = "/tmp/gradio"
311
  image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
312
+ uploaded_ref_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
313
  with gr.Column(visible=False) as clear_button_column:
314
+ remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=ref_files, size="sm")
315
 
316
  init_img_files = gr.File(
317
  label="[Optional] Generate 4 images and select 1 image",
 
322
  init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
323
  # Although there's only one image, we still use columns=3, to scale down the image size.
324
  # Otherwise it will occupy the full width, and the gallery won't show the whole image.
325
+ generated_init_img_gallery = gr.Gallery(label="Init image", visible=False, columns=3, rows=1, height=200)
326
  # placeholder is just a hint, not the real value, so we use "value='0'" instead of "placeholder='0'".
327
  init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
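Aside: a minimal sketch (not part of this commit) of how the hidden index textbox gets filled. A .select() handler whose parameter is annotated with gr.SelectData receives the clicked item's index, mirroring get_clicked_image() above; the names below are hypothetical.

import gradio as gr

def on_gallery_select(evt: gr.SelectData):
    return evt.index  # index of the clicked gallery item, stored as text

with gr.Blocks() as select_sketch:
    init_gallery = gr.Gallery(label="Init image", columns=3)
    selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
    init_gallery.select(fn=on_gallery_select, inputs=None, outputs=selected_idx)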
328
 
 
337
  allow_custom_value=True,
338
  choices=[
339
  "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
340
+ "portrait, walking on the beach, front of face, sunset",
341
  "portrait, in a white apron and chef hat, garnishing a gourmet dish",
342
  "portrait, dancing pose among folks in a park, waving hands",
343
  "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
 
345
  "portrait, night view of tokyo street, neon light",
346
  "portrait, playing guitar on a boat, ocean waves",
347
  "portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
348
+ "portrait, celebrating new year alone, fireworks",
349
+ "portrait, running pose in a park, front view",
350
  "portrait, in space suit, space helmet, walking on mars",
351
  "portrait, in superman costume, the sky ablaze with hues of orange and purple"
352
  ])
353
 
354
+ highlight_face = gr.Checkbox(label="Highlight face", value=True,
355
  info="Enhance the facial features by prepending 'face portrait' to the prompt",
356
  visible=True)
357
+ enhance_composition = gr.Checkbox(label="Enhance Composition", value=False,
358
+ info="Enhance the overall composition of the generated video",
359
+ visible=True)
360
+
361
  init_image_strength = gr.Slider(
362
+ label="Beginning Strength of Init Image",
363
  info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
364
  minimum=0,
365
  maximum=3,
 
372
  minimum=0,
373
  maximum=2,
374
  step=0.025,
375
+ value=0.5,
376
  )
377
 
378
  model_style_type = gr.Dropdown(
 
435
  minimum=0.8,
436
  maximum=1.2,
437
  step=0.05,
438
+ value=1.05,
439
  visible=True,
440
  )
441
 
 
463
  minimum=0,
464
  maximum=1,
465
  step=0.1,
466
+ value=0.5,
467
  )
468
 
469
  id_animator_anneal_steps = gr.Slider(
 
484
  with gr.Column():
485
  result_video = gr.Video(label="Generated Animation", interactive=False)
486
 
487
+ ref_files.upload(fn=swap_to_gallery, inputs=ref_files, outputs=[uploaded_ref_files_gallery, clear_button_column, ref_files])
488
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_ref_files_gallery, clear_button_column, ref_files, init_img_selected_idx])
489
 
490
  init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files,
491
+ outputs=[generated_init_img_gallery, init_clear_button_column, init_img_files])
492
  remove_init_and_reupload.click(fn=remove_back_to_files,
493
+ outputs=[generated_init_img_gallery, init_clear_button_column,
494
  init_img_files, init_img_selected_idx])
495
  gen_init.click(fn=check_prompt_and_model_type,
496
  inputs=[prompt, model_style_type],outputs=None).success(
 
499
  outputs=seed,
500
  queue=False,
501
  api_name=False,
502
+ ).then(fn=gen_init_images, inputs=[uploaded_ref_files_gallery, prompt,
503
+ highlight_face, enhance_composition,
504
  guidance_scale],
505
+ outputs=[generated_init_img_gallery, init_img_files, init_clear_button_column])
506
+ generated_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
507
 
508
  submit.click(fn=check_prompt_and_model_type,
509
  inputs=[prompt, model_style_type],outputs=None).success(
 
514
  api_name=False,
515
  ).then(
516
  fn=generate_video,
517
+ inputs=[image_container, ref_files,
518
  init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
519
  prompt, negative_prompt, num_steps, video_length, guidance_scale,
520
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
521
+ highlight_face, enhance_composition, is_adaface_enabled,
522
  adaface_power_scale, id_animator_anneal_steps],
523
  outputs=[result_video]
524
  )
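Aside: a minimal sketch (not part of this commit) of the event chain used above: .success() fires only if the preceding handler raised no gr.Error (the prompt/model check), and .then() runs once the previous step finishes. All names below are hypothetical.

import random
import gradio as gr

def check_inputs(prompt):
    if not prompt.strip():
        raise gr.Error("Please enter a prompt first.")

def maybe_randomize_seed(randomize, seed):
    return random.randint(0, 2**31 - 1) if randomize else seed

def run_generation(prompt, seed):
    return f"generated with prompt={prompt!r}, seed={seed}"

with gr.Blocks() as chain_sketch:
    prompt    = gr.Textbox(label="Prompt")
    randomize = gr.Checkbox(label="Randomize seed", value=True)
    seed      = gr.Number(label="Seed", value=0, precision=0)
    result    = gr.Textbox(label="Result")
    go        = gr.Button("Generate")

    go.click(fn=check_inputs, inputs=prompt, outputs=None).success(
        fn=maybe_randomize_seed, inputs=[randomize, seed], outputs=seed, queue=False,
    ).then(
        fn=run_generation, inputs=[prompt, seed], outputs=result,
    )
# chain_sketch.launch()  # uncomment to try the chain locally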