Commit 737c1a0 · Parent: 8776445
Update code

Changed files:
- adaface/adaface_wrapper.py +34 -21
- adaface/diffusers_attn_lora_capture.py +67 -62
- adaface/face_id_to_ada_prompt.py +14 -16
- adaface/unet_teachers.py +37 -36
- adaface/util.py +6 -6
- app.py +62 -41
adaface/adaface_wrapper.py
CHANGED
@@ -30,7 +30,7 @@ class AdaFaceWrapper(nn.Module):
|
|
30 |
use_840k_vae=False, use_ds_text_encoder=False,
|
31 |
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
|
32 |
enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
|
33 |
-
attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
34 |
device='cuda', is_training=False):
|
35 |
'''
|
36 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
@@ -52,7 +52,7 @@ class AdaFaceWrapper(nn.Module):
|
|
52 |
self.q_lora_updates_query = q_lora_updates_query
|
53 |
self.use_lcm = use_lcm
|
54 |
self.subject_string = subject_string
|
55 |
-
self.
|
56 |
|
57 |
self.default_scheduler_name = default_scheduler_name
|
58 |
self.num_inference_steps = num_inference_steps if not use_lcm else 4
|
@@ -189,10 +189,10 @@ class AdaFaceWrapper(nn.Module):
|
|
189 |
pipeline.unet = unet_ensemble
|
190 |
|
191 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
192 |
-
if not remove_unet and (self.unet_uses_attn_lora or self.
|
193 |
unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
|
194 |
attn_lora_layer_names=self.attn_lora_layer_names,
|
195 |
-
|
196 |
q_lora_updates_query=self.q_lora_updates_query)
|
197 |
|
198 |
pipeline.unet = unet2
|
@@ -294,12 +294,11 @@ class AdaFaceWrapper(nn.Module):
|
|
294 |
def load_unet_loras(self, unet, unet_lora_modules_state_dict,
|
295 |
use_attn_lora=True, use_ffn_lora=False,
|
296 |
attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
297 |
-
|
298 |
q_lora_updates_query=False):
|
299 |
attn_capture_procs, attn_opt_modules = \
|
300 |
set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
|
301 |
lora_rank=192, lora_scale_down=8,
|
302 |
-
cross_attn_shrink_factor=cross_attn_shrink_factor,
|
303 |
q_lora_updates_query=q_lora_updates_query)
|
304 |
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
|
305 |
if use_ffn_lora:
|
@@ -343,16 +342,17 @@ class AdaFaceWrapper(nn.Module):
|
|
343 |
print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
|
344 |
self.outfeat_capture_blocks.append(unet.up_blocks[3])
|
345 |
|
346 |
-
# If
|
347 |
# but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
|
348 |
set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
|
349 |
use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
|
350 |
-
|
|
|
351 |
|
352 |
return unet
|
353 |
|
354 |
def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
355 |
-
|
356 |
unet_lora_weight_found = False
|
357 |
if isinstance(self.adaface_ckpt_paths, str):
|
358 |
adaface_ckpt_paths = [self.adaface_ckpt_paths]
|
@@ -360,7 +360,7 @@ class AdaFaceWrapper(nn.Module):
|
|
360 |
adaface_ckpt_paths = self.adaface_ckpt_paths
|
361 |
|
362 |
for adaface_ckpt_path in adaface_ckpt_paths:
|
363 |
-
ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
|
364 |
if 'unet_lora_modules' in ckpt_dict:
|
365 |
unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
|
366 |
print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
|
@@ -379,7 +379,7 @@ class AdaFaceWrapper(nn.Module):
|
|
379 |
unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
|
380 |
use_attn_lora=use_attn_lora,
|
381 |
attn_lora_layer_names=attn_lora_layer_names,
|
382 |
-
|
383 |
q_lora_updates_query=q_lora_updates_query)
|
384 |
unet.unets[i] = unet_
|
385 |
print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
|
@@ -387,7 +387,7 @@ class AdaFaceWrapper(nn.Module):
|
|
387 |
unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
|
388 |
use_attn_lora=use_attn_lora,
|
389 |
attn_lora_layer_names=attn_lora_layer_names,
|
390 |
-
|
391 |
q_lora_updates_query=q_lora_updates_query)
|
392 |
|
393 |
return unet
|
@@ -612,8 +612,9 @@ class AdaFaceWrapper(nn.Module):
|
|
612 |
# Scan prompt and replace tokens in self.placeholder_token_ids
|
613 |
# with the corresponding image embeddings.
|
614 |
prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
|
|
|
615 |
prompt_embeds2 = prompt_embeds.clone()
|
616 |
-
if alt_prompt_embed_type
|
617 |
if self.img_prompt_embs is None:
|
618 |
print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
|
619 |
return prompt_embeds
|
@@ -628,17 +629,18 @@ class AdaFaceWrapper(nn.Module):
|
|
628 |
breakpoint()
|
629 |
|
630 |
repl_tokens = {}
|
|
|
631 |
for i in range(len(prompt_tokens)):
|
632 |
if prompt_tokens[i] in self.all_placeholder_tokens:
|
633 |
encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
|
634 |
if prompt_tokens[i] in sublist), 0)
|
635 |
-
alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
|
636 |
-
prompt_embeds2[:, i] = prompt_embeds2[:, i] *
|
637 |
+ repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
|
638 |
repl_tokens[prompt_tokens[i]] = 1
|
639 |
|
640 |
repl_token_count = len(repl_tokens)
|
641 |
-
if
|
642 |
print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
|
643 |
else:
|
644 |
print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
|
@@ -650,7 +652,7 @@ class AdaFaceWrapper(nn.Module):
|
|
650 |
placeholder_tokens_pos='append',
|
651 |
ablate_prompt_only_placeholders=False,
|
652 |
ablate_prompt_no_placeholders=False,
|
653 |
-
ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
|
654 |
nonmix_prompt_emb_weight=0,
|
655 |
repeat_prompt_for_each_encoder=True,
|
656 |
device=None, verbose=False):
|
@@ -678,14 +680,25 @@ class AdaFaceWrapper(nn.Module):
|
|
678 |
prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
|
679 |
self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
|
680 |
|
681 |
-
if ablate_prompt_embed_type
|
682 |
alt_prompt_embed_type = ablate_prompt_embed_type
|
683 |
-
|
|
684 |
elif nonmix_prompt_emb_weight > 0:
|
685 |
alt_prompt_embed_type = 'ada-nonmix'
|
686 |
-
|
|
|
|
|
687 |
else:
|
688 |
-
|
|
|
689 |
|
690 |
if sum(alt_prompt_emb_weights) > 0:
|
691 |
prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
|
|
|
30 |
use_840k_vae=False, use_ds_text_encoder=False,
|
31 |
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
|
32 |
enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
|
33 |
+
attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
|
34 |
device='cuda', is_training=False):
|
35 |
'''
|
36 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
|
|
52 |
self.q_lora_updates_query = q_lora_updates_query
|
53 |
self.use_lcm = use_lcm
|
54 |
self.subject_string = subject_string
|
55 |
+
self.normalize_cross_attn = normalize_cross_attn
|
56 |
|
57 |
self.default_scheduler_name = default_scheduler_name
|
58 |
self.num_inference_steps = num_inference_steps if not use_lcm else 4
|
|
|
189 |
pipeline.unet = unet_ensemble
|
190 |
|
191 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
192 |
+
if not remove_unet and (self.unet_uses_attn_lora or self.normalize_cross_attn):
|
193 |
unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
|
194 |
attn_lora_layer_names=self.attn_lora_layer_names,
|
195 |
+
normalize_cross_attn=self.normalize_cross_attn,
|
196 |
q_lora_updates_query=self.q_lora_updates_query)
|
197 |
|
198 |
pipeline.unet = unet2
|
|
|
294 |
def load_unet_loras(self, unet, unet_lora_modules_state_dict,
|
295 |
use_attn_lora=True, use_ffn_lora=False,
|
296 |
attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
297 |
+
normalize_cross_attn=False,
|
298 |
q_lora_updates_query=False):
|
299 |
attn_capture_procs, attn_opt_modules = \
|
300 |
set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
|
301 |
lora_rank=192, lora_scale_down=8,
|
|
|
302 |
q_lora_updates_query=q_lora_updates_query)
|
303 |
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
|
304 |
if use_ffn_lora:
|
|
|
342 |
print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
|
343 |
self.outfeat_capture_blocks.append(unet.up_blocks[3])
|
344 |
|
345 |
+
# If normalize_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
|
346 |
# but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
|
347 |
set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
|
348 |
use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
|
349 |
+
normalize_cross_attn=normalize_cross_attn, mix_attn_mats_in_batch=False,
|
350 |
+
res_hidden_states_gradscale=0)
|
351 |
|
352 |
return unet
|
353 |
|
354 |
def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
355 |
+
normalize_cross_attn=False, q_lora_updates_query=False):
|
356 |
unet_lora_weight_found = False
|
357 |
if isinstance(self.adaface_ckpt_paths, str):
|
358 |
adaface_ckpt_paths = [self.adaface_ckpt_paths]
|
|
|
360 |
adaface_ckpt_paths = self.adaface_ckpt_paths
|
361 |
|
362 |
for adaface_ckpt_path in adaface_ckpt_paths:
|
363 |
+
ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
|
364 |
if 'unet_lora_modules' in ckpt_dict:
|
365 |
unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
|
366 |
print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
|
|
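The `torch.load` call in the hunk above now passes `weights_only=False` explicitly. Newer PyTorch (2.6+) defaults to `weights_only=True`, which refuses to unpickle non-tensor Python objects, so checkpoints that bundle plain dicts or config objects need the explicit flag. A minimal, hedged sketch (the helper name is illustrative, not the repo's API):

```python
# Sketch: load a mixed checkpoint under the post-2.6 torch.load defaults.
# Only do this for checkpoints from a trusted source; weights_only=False
# executes pickled code embedded in the file.
import torch

def load_ckpt(path: str) -> dict:
    return torch.load(path, map_location="cpu", weights_only=False)
```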
|
379 |
unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
|
380 |
use_attn_lora=use_attn_lora,
|
381 |
attn_lora_layer_names=attn_lora_layer_names,
|
382 |
+
normalize_cross_attn=normalize_cross_attn,
|
383 |
q_lora_updates_query=q_lora_updates_query)
|
384 |
unet.unets[i] = unet_
|
385 |
print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
|
|
|
387 |
unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
|
388 |
use_attn_lora=use_attn_lora,
|
389 |
attn_lora_layer_names=attn_lora_layer_names,
|
390 |
+
normalize_cross_attn=normalize_cross_attn,
|
391 |
q_lora_updates_query=q_lora_updates_query)
|
392 |
|
393 |
return unet
|
|
|
612 |
# Scan prompt and replace tokens in self.placeholder_token_ids
|
613 |
# with the corresponding image embeddings.
|
614 |
prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
|
615 |
+
# prompt_embeds are the ada embeddings.
|
616 |
prompt_embeds2 = prompt_embeds.clone()
|
617 |
+
if alt_prompt_embed_type.startswith('img'):
|
618 |
if self.img_prompt_embs is None:
|
619 |
print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
|
620 |
return prompt_embeds
|
|
|
629 |
breakpoint()
|
630 |
|
631 |
repl_tokens = {}
|
632 |
+
ada_emb_weight = alt_prompt_emb_weights[0]
|
633 |
for i in range(len(prompt_tokens)):
|
634 |
if prompt_tokens[i] in self.all_placeholder_tokens:
|
635 |
encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
|
636 |
if prompt_tokens[i] in sublist), 0)
|
637 |
+
alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx + 1]
|
638 |
+
prompt_embeds2[:, i] = prompt_embeds2[:, i] * ada_emb_weight \
|
639 |
+ repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
|
640 |
repl_tokens[prompt_tokens[i]] = 1
|
641 |
|
642 |
repl_token_count = len(repl_tokens)
|
643 |
+
if ada_emb_weight == 0:
|
644 |
print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
|
645 |
else:
|
646 |
print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
|
|
|
652 |
placeholder_tokens_pos='append',
|
653 |
ablate_prompt_only_placeholders=False,
|
654 |
ablate_prompt_no_placeholders=False,
|
655 |
+
ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img', 'img1', 'img2'.
|
656 |
nonmix_prompt_emb_weight=0,
|
657 |
repeat_prompt_for_each_encoder=True,
|
658 |
device=None, verbose=False):
|
|
|
680 |
prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
|
681 |
self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
|
682 |
|
683 |
+
if ablate_prompt_embed_type.startswith('img'):
|
684 |
alt_prompt_embed_type = ablate_prompt_embed_type
|
685 |
+
if alt_prompt_embed_type == 'img1':
|
686 |
+
# The mixing weights of ada, img1, and img2 are 0, 1, and 0.
|
687 |
+
alt_prompt_emb_weights = (0, 1, 0)
|
688 |
+
elif alt_prompt_embed_type == 'img2':
|
689 |
+
# The mixing weights of ada, img1, and img2 are 0, 0, and 1.
|
690 |
+
alt_prompt_emb_weights = (0, 0, 1)
|
691 |
+
else:
|
692 |
+
# The mixing weights of ada, img1, and img2 are 0, 1, and 1.
|
693 |
+
alt_prompt_emb_weights = (0, 1, 1)
|
694 |
elif nonmix_prompt_emb_weight > 0:
|
695 |
alt_prompt_embed_type = 'ada-nonmix'
|
696 |
+
# The mixing weight of ada is 1 - nonmix_prompt_emb_weight, instead of 1 - nonmix_prompt_emb_weight * 2.
|
697 |
+
# It means ada is mixed by this weight with both img1 and img2.
|
698 |
+
alt_prompt_emb_weights = (1 - nonmix_prompt_emb_weight, nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
|
699 |
else:
|
700 |
+
# Don't change the prompt embeddings. So we set all the mixing weights to 0.
|
701 |
+
alt_prompt_emb_weights = (0, 0, 0)
|
702 |
|
703 |
if sum(alt_prompt_emb_weights) > 0:
|
704 |
prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
|
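The hunks above mix the ada prompt embeddings with image-prompt embeddings token by token, driven by a weight tuple over (ada, img1, img2) such as (0, 1, 0), (0, 0, 1), (0, 1, 1), (1 - w, w, w), or (0, 0, 0) for "leave unchanged". A self-contained sketch of the per-token blend, with illustrative names rather than the repo's API:

```python
# Sketch of the blending rule applied at each placeholder token position:
# the ada embedding is scaled by the ada weight and added to the replacement
# (image-prompt) embedding scaled by its per-encoder weight.
import torch

def blend_placeholder_embedding(ada_emb: torch.Tensor,
                                repl_emb: torch.Tensor,
                                ada_weight: float,
                                repl_weight: float) -> torch.Tensor:
    # e.g. weights (0, 1, 0) keep only the first image-prompt embedding,
    # while (1 - w, w, w) mixes ada with both image-prompt embeddings.
    return ada_emb * ada_weight + repl_emb * repl_weight

mixed = blend_placeholder_embedding(torch.randn(1, 768), torch.randn(1, 768),
                                    ada_weight=0.8, repl_weight=0.2)
```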
adaface/diffusers_attn_lora_capture.py
CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
|
|
4 |
from typing import Optional, Tuple, Dict, Any
|
5 |
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
6 |
from diffusers.utils import logging, is_torch_version, deprecate
|
7 |
-
from diffusers.utils.torch_utils import fourier_filter
|
8 |
# UNet is a diffusers PeftAdapterMixin instance.
|
9 |
from diffusers.loaders.peft import PeftAdapterMixin
|
10 |
from peft import LoraConfig, get_peft_model
|
@@ -12,7 +11,6 @@ import peft.tuners.lora as peft_lora
|
|
12 |
from peft.tuners.lora.dora import DoraLinearLayer
|
13 |
from einops import rearrange
|
14 |
import math, re
|
15 |
-
import numpy as np
|
16 |
from peft.tuners.tuners_utils import BaseTunerLayer
|
17 |
|
18 |
|
@@ -28,7 +26,7 @@ class ScaleGrad(torch.autograd.Function):
|
|
28 |
ctx.save_for_backward(alpha_, debug)
|
29 |
output = input_
|
30 |
if debug:
|
31 |
-
print(f"input: {input_.abs().mean().item()}")
|
32 |
return output
|
33 |
|
34 |
@staticmethod
|
@@ -38,7 +36,7 @@ class ScaleGrad(torch.autograd.Function):
|
|
38 |
if ctx.needs_input_grad[0]:
|
39 |
grad_output2 = grad_output * alpha_
|
40 |
if debug:
|
41 |
-
print(f"grad_output2: {grad_output2.abs().mean().item()}")
|
42 |
else:
|
43 |
grad_output2 = None
|
44 |
return grad_output2, None, None
|
@@ -77,36 +75,11 @@ def split_indices_by_instance(indices, as_dict=False):
|
|
77 |
indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
|
78 |
return indices_by_instance
|
79 |
|
80 |
-
# If do_sum, returned emb_attns is 3D. Otherwise 4D.
|
81 |
-
# indices are applied on the first 2 dims of attn_mat.
|
82 |
-
def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
|
83 |
-
indices_by_instance = split_indices_by_instance(indices)
|
84 |
-
|
85 |
-
# emb_attns[0]: [1, 9, 8, 64]
|
86 |
-
# 8: 8 attention heads. Last dim 64: number of image tokens.
|
87 |
-
emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
|
88 |
-
if all_token_weights is not None:
|
89 |
-
# all_token_weights: [4, 77].
|
90 |
-
# token_weights_by_instance[0]: [1, 9, 1, 1].
|
91 |
-
token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
|
92 |
-
else:
|
93 |
-
token_weights = [ 1 ] * len(indices_by_instance)
|
94 |
-
|
95 |
-
# Apply token weights.
|
96 |
-
emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
|
97 |
-
|
98 |
-
# sum among K_subj_i subj embeddings -> [1, 8, 64]
|
99 |
-
if do_sum:
|
100 |
-
emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
|
101 |
-
elif do_mean:
|
102 |
-
emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
|
103 |
-
|
104 |
-
emb_attns = torch.cat(emb_attns, dim=0)
|
105 |
-
return emb_attns
|
106 |
-
|
107 |
# Slow implementation equivalent to F.scaled_dot_product_attention.
|
108 |
-
def scaled_dot_product_attention(query, key, value,
|
109 |
-
|
|
|
|
|
110 |
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
|
111 |
B, L, S = query.size(0), query.size(-2), key.size(-2)
|
112 |
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
@@ -128,21 +101,39 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|
128 |
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
129 |
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
130 |
|
131 |
-
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
146 |
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
147 |
output = attn_weight @ value
|
148 |
return output, attn_score, attn_weight
|
@@ -156,23 +147,25 @@ class AttnProcessor_LoRA_Capture(nn.Module):
|
|
156 |
def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
|
157 |
lora_uses_dora=True, lora_proj_layers=None,
|
158 |
lora_rank: int = 192, lora_alpha: float = 16,
|
159 |
-
cross_attn_shrink_factor: float = 0.5,
|
160 |
q_lora_updates_query=False, attn_proc_idx=-1):
|
161 |
super().__init__()
|
162 |
|
163 |
self.global_enable_lora = enable_lora
|
164 |
self.attn_proc_idx = attn_proc_idx
|
165 |
# reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
|
166 |
-
# By default,
|
167 |
-
self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
|
168 |
self.lora_rank = lora_rank
|
169 |
self.lora_alpha = lora_alpha
|
170 |
self.lora_scale = self.lora_alpha / self.lora_rank
|
171 |
-
self.cross_attn_shrink_factor = cross_attn_shrink_factor
|
172 |
self.q_lora_updates_query = q_lora_updates_query
|
173 |
|
174 |
self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
|
175 |
if self.global_enable_lora:
|
|
|
|
|
|
|
|
|
176 |
for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
|
177 |
if lora_layer_name == 'q':
|
178 |
self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
|
@@ -188,9 +181,10 @@ class AttnProcessor_LoRA_Capture(nn.Module):
|
|
188 |
use_dora=lora_uses_dora, lora_dropout=0.1)
|
189 |
|
190 |
# LoRA layers can be enabled/disabled dynamically.
|
191 |
-
def reset_attn_cache_and_flags(self, capture_ca_activations,
|
192 |
self.capture_ca_activations = capture_ca_activations
|
193 |
-
self.
|
|
|
194 |
self.cached_activations = {}
|
195 |
# Only enable LoRA for the next call(s) if global_enable_lora is set to True.
|
196 |
self.enable_lora = enable_lora and self.global_enable_lora
|
@@ -312,11 +306,14 @@ class AttnProcessor_LoRA_Capture(nn.Module):
|
|
312 |
breakpoint()
|
313 |
|
314 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
315 |
-
if is_cross_attn and (self.capture_ca_activations or self.
|
316 |
hidden_states, attn_score, attn_prob = \
|
317 |
scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
|
318 |
-
dropout_p=0.0,
|
319 |
-
|
|
|
|
|
|
|
320 |
else:
|
321 |
# Use the faster implementation of scaled_dot_product_attention
|
322 |
# when not capturing the activations or suppressing the subject attention.
|
@@ -452,7 +449,7 @@ def CrossAttnUpBlock2D_forward_capture(
|
|
452 |
# Adapted from ConsistentIDPipeline:set_ip_adapter().
|
453 |
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
|
454 |
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
455 |
-
lora_rank=192, lora_scale_down=8,
|
456 |
q_lora_updates_query=False):
|
457 |
attn_procs = {}
|
458 |
attn_capture_procs = {}
|
@@ -502,7 +499,6 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
|
|
502 |
lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
|
503 |
# LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
|
504 |
lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
|
505 |
-
cross_attn_shrink_factor=cross_attn_shrink_factor,
|
506 |
q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
|
507 |
|
508 |
attn_proc_idx += 1
|
@@ -513,6 +509,11 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
|
|
513 |
attn_capture_procs[name] = attn_capture_proc
|
514 |
|
515 |
if use_attn_lora:
|
|
|
|
|
|
|
|
|
|
|
516 |
for subname, module in attn_capture_proc.named_modules():
|
517 |
if isinstance(module, peft_lora.LoraLayer):
|
518 |
# ModuleDict doesn't allow "." in the key.
|
@@ -537,7 +538,7 @@ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k',
|
|
537 |
return attn_capture_procs, attn_opt_modules
|
538 |
|
539 |
# NOTE: cross-attn layers are included in the returned lora_modules.
|
540 |
-
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=
|
541 |
# target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
|
542 |
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
|
543 |
# Cannot set to conv.+ as it will match added adapter module names, including
|
@@ -592,15 +593,18 @@ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=1
|
|
592 |
def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
|
593 |
outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
|
594 |
use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
|
595 |
-
|
596 |
# For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
|
597 |
-
for attn_capture_proc in attn_capture_procs:
|
598 |
-
attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations,
|
|
|
599 |
# outfeat_capture_blocks only contains the last up block, up_blocks[3].
|
600 |
# It contains 3 FFN layers. We want to capture their output features.
|
601 |
for block in outfeat_capture_blocks:
|
602 |
block.capture_outfeats = capture_ca_activations
|
603 |
|
|
|
|
|
604 |
for block in res_hidden_states_gradscale_blocks:
|
605 |
block.res_hidden_states_gradscale = res_hidden_states_gradscale
|
606 |
|
@@ -639,6 +643,7 @@ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat
|
|
639 |
block.cached_outfeats = {}
|
640 |
block.capture_outfeats = False
|
641 |
|
|
|
642 |
for layer_idx in captured_layer_indices:
|
643 |
# Subtract 22 from ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
|
644 |
# 23, 24 -> 1, 2 (!! not 0, 1 !!)
|
|
|
4 |
from typing import Optional, Tuple, Dict, Any
|
5 |
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
6 |
from diffusers.utils import logging, is_torch_version, deprecate
|
|
|
7 |
# UNet is a diffusers PeftAdapterMixin instance.
|
8 |
from diffusers.loaders.peft import PeftAdapterMixin
|
9 |
from peft import LoraConfig, get_peft_model
|
|
|
11 |
from peft.tuners.lora.dora import DoraLinearLayer
|
12 |
from einops import rearrange
|
13 |
import math, re
|
|
|
14 |
from peft.tuners.tuners_utils import BaseTunerLayer
|
15 |
|
16 |
|
|
|
26 |
ctx.save_for_backward(alpha_, debug)
|
27 |
output = input_
|
28 |
if debug:
|
29 |
+
print(f"input: {input_.abs().mean().detach().item()}")
|
30 |
return output
|
31 |
|
32 |
@staticmethod
|
|
|
36 |
if ctx.needs_input_grad[0]:
|
37 |
grad_output2 = grad_output * alpha_
|
38 |
if debug:
|
39 |
+
print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
|
40 |
else:
|
41 |
grad_output2 = None
|
42 |
return grad_output2, None, None
|
|
|
75 |
indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
|
76 |
return indices_by_instance
|
77 |
|
|
78 |
# Slow implementation equivalent to F.scaled_dot_product_attention.
|
79 |
+
def scaled_dot_product_attention(query, key, value, cross_attn_scale_factor,
|
80 |
+
attn_mask=None, dropout_p=0.0,
|
81 |
+
subj_indices=None, normalize_cross_attn=False,
|
82 |
+
mix_attn_mats_in_batch=False,
|
83 |
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
|
84 |
B, L, S = query.size(0), query.size(-2), key.size(-2)
|
85 |
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
|
|
101 |
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
102 |
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
103 |
|
104 |
+
attn_score = query @ key.transpose(-2, -1) * scale_factor
|
105 |
|
106 |
+
# attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_score.
|
107 |
+
attn_score += attn_bias
|
108 |
+
if mix_attn_mats_in_batch:
|
109 |
+
# The instances in the batch are [sc, mc]. We average their attn scores,
|
110 |
+
# and apply to both instances.
|
111 |
+
# attn_score: [2, 8, 4096, 77] -> [1, 8, 4096, 77] -> [2, 8, 4096, 77].
|
112 |
+
# If BLOCK_SIZE > 1, attn_score.shape[0] = 2 * BLOCK_SIZE.
|
113 |
+
if attn_score.shape[0] %2 != 0:
|
114 |
+
breakpoint()
|
115 |
+
attn_score_sc, attn_score_mc = attn_score.chunk(2, dim=0)
|
116 |
+
# Cut off the grad flow from the SC instance to the MC instance.
|
117 |
+
attn_score = (attn_score_sc + attn_score_mc.detach()) / 2
|
118 |
+
attn_score = attn_score.repeat(2, 1, 1, 1)
|
119 |
+
elif normalize_cross_attn:
|
120 |
+
if subj_indices is None:
|
121 |
+
breakpoint()
|
122 |
+
subj_indices_B, subj_indices_N = subj_indices
|
123 |
+
subj_attn_score = attn_score[subj_indices_B, :, :, subj_indices_N]
|
124 |
+
# Normalize the attention score of the subject tokens to have mean 0 across tokens,
|
125 |
+
# so that positive and negative scores are balanced.
|
126 |
+
subj_attn_score = subj_attn_score - subj_attn_score.mean(dim=2, keepdim=True).detach()
|
127 |
+
# cross_attn_scale is a learnable parameter, so the score will be scaled appropriately.
|
128 |
+
# Scale up the BP'ed gradient to cross_attn_scale_factor by 10x.
|
129 |
+
ca_scale_grad_scaler = gen_gradient_scaler(10)
|
130 |
+
subj_attn_score = subj_attn_score * ca_scale_grad_scaler(cross_attn_scale_factor)
|
131 |
+
attn_score2 = attn_score.clone()
|
132 |
+
attn_score2[subj_indices_B, :, :, subj_indices_N] = subj_attn_score
|
133 |
+
attn_score = attn_score2
|
134 |
+
# Otherwise, do nothing to attn_score.
|
135 |
+
|
136 |
+
attn_weight = torch.softmax(attn_score, dim=-1)
|
137 |
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
138 |
output = attn_weight @ value
|
139 |
return output, attn_score, attn_weight
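When `normalize_cross_attn` is enabled, the slow attention path above zero-centers the subject tokens' attention scores over the image-query dimension and rescales them with the learnable `cross_attn_scale_factor` (the 10x gradient boost on that factor is omitted here). A hedged, standalone sketch of that step, with illustrative tensor names:

```python
# Sketch: normalize only the subject tokens' cross-attention scores.
import torch

def normalize_subject_scores(attn_score: torch.Tensor,
                             subj_indices: tuple,
                             cross_attn_scale: torch.Tensor) -> torch.Tensor:
    # attn_score: [B, heads, n_queries, n_text_tokens];
    # subj_indices: (batch indices, token indices), two equal-length LongTensors.
    subj_b, subj_n = subj_indices
    subj_score = attn_score[subj_b, :, :, subj_n]          # [K, heads, n_queries]
    # Zero-center across the image-query dimension so positive and negative
    # scores balance, then rescale by the learnable factor.
    subj_score = subj_score - subj_score.mean(dim=2, keepdim=True).detach()
    out = attn_score.clone()
    out[subj_b, :, :, subj_n] = subj_score * cross_attn_scale
    return out

scores = torch.randn(2, 8, 64, 77)
idx = (torch.tensor([0, 1]), torch.tensor([5, 5]))
normalized = normalize_subject_scores(scores, idx, torch.tensor(0.8))
```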
|
|
|
147 |
def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
|
148 |
lora_uses_dora=True, lora_proj_layers=None,
|
149 |
lora_rank: int = 192, lora_alpha: float = 16,
|
|
|
150 |
q_lora_updates_query=False, attn_proc_idx=-1):
|
151 |
super().__init__()
|
152 |
|
153 |
self.global_enable_lora = enable_lora
|
154 |
self.attn_proc_idx = attn_proc_idx
|
155 |
# reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
|
156 |
+
# By default, normalize_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
|
157 |
+
self.reset_attn_cache_and_flags(capture_ca_activations, False, False, enable_lora)
|
158 |
self.lora_rank = lora_rank
|
159 |
self.lora_alpha = lora_alpha
|
160 |
self.lora_scale = self.lora_alpha / self.lora_rank
|
|
|
161 |
self.q_lora_updates_query = q_lora_updates_query
|
162 |
|
163 |
self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
|
164 |
if self.global_enable_lora:
|
165 |
+
# enable_lora = True iff this is a cross-attn layer in the last 3 up blocks.
|
166 |
+
# Since we only use cross_attn_scale_factor on cross-attn layers,
|
167 |
+
# we only use cross_attn_scale_factor when enable_lora is True.
|
168 |
+
self.cross_attn_scale_factor = nn.Parameter(torch.tensor(0.8), requires_grad=True)
|
169 |
for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
|
170 |
if lora_layer_name == 'q':
|
171 |
self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
|
|
|
181 |
use_dora=lora_uses_dora, lora_dropout=0.1)
|
182 |
|
183 |
# LoRA layers can be enabled/disabled dynamically.
|
184 |
+
def reset_attn_cache_and_flags(self, capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch, enable_lora):
|
185 |
self.capture_ca_activations = capture_ca_activations
|
186 |
+
self.normalize_cross_attn = normalize_cross_attn
|
187 |
+
self.mix_attn_mats_in_batch = mix_attn_mats_in_batch
|
188 |
self.cached_activations = {}
|
189 |
# Only enable LoRA for the next call(s) if global_enable_lora is set to True.
|
190 |
self.enable_lora = enable_lora and self.global_enable_lora
|
|
|
306 |
breakpoint()
|
307 |
|
308 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
309 |
+
if is_cross_attn and (self.capture_ca_activations or self.normalize_cross_attn):
|
310 |
hidden_states, attn_score, attn_prob = \
|
311 |
scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
|
312 |
+
dropout_p=0.0, subj_indices=subj_indices,
|
313 |
+
normalize_cross_attn=self.normalize_cross_attn,
|
314 |
+
cross_attn_scale_factor=self.cross_attn_scale_factor,
|
315 |
+
mix_attn_mats_in_batch=self.mix_attn_mats_in_batch)
|
316 |
+
|
317 |
else:
|
318 |
# Use the faster implementation of scaled_dot_product_attention
|
319 |
# when not capturing the activations or suppressing the subject attention.
|
|
|
449 |
# Adapted from ConsistentIDPipeline:set_ip_adapter().
|
450 |
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
|
451 |
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
452 |
+
lora_rank=192, lora_scale_down=8,
|
453 |
q_lora_updates_query=False):
|
454 |
attn_procs = {}
|
455 |
attn_capture_procs = {}
|
|
|
499 |
lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
|
500 |
# LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
|
501 |
lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
|
|
|
502 |
q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
|
503 |
|
504 |
attn_proc_idx += 1
|
|
|
509 |
attn_capture_procs[name] = attn_capture_proc
|
510 |
|
511 |
if use_attn_lora:
|
512 |
+
cross_attn_scale_factor_name = name + "_cross_attn_scale_factor"
|
513 |
+
# Put cross_attn_scale_factor in attn_opt_modules, so that we can optimize and save/load it.
|
514 |
+
attn_opt_modules[cross_attn_scale_factor_name] = attn_capture_proc.cross_attn_scale_factor
|
515 |
+
|
516 |
+
# Put LoRA layers in attn_opt_modules, so that we can optimize and save/load them.
|
517 |
for subname, module in attn_capture_proc.named_modules():
|
518 |
if isinstance(module, peft_lora.LoraLayer):
|
519 |
# ModuleDict doesn't allow "." in the key.
|
|
|
538 |
return attn_capture_procs, attn_opt_modules
|
539 |
|
540 |
# NOTE: cross-attn layers are included in the returned lora_modules.
|
541 |
+
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=True, lora_rank=192, lora_alpha=16):
|
542 |
# target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
|
543 |
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
|
544 |
# Cannot set to conv.+ as it will match added adapter module names, including
|
|
|
593 |
def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
|
594 |
outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
|
595 |
use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
|
596 |
+
normalize_cross_attn, mix_attn_mats_in_batch, res_hidden_states_gradscale):
|
597 |
# For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
|
598 |
+
for i, attn_capture_proc in enumerate(attn_capture_procs):
|
599 |
+
attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, normalize_cross_attn, mix_attn_mats_in_batch,
|
600 |
+
enable_lora=use_attn_lora)
|
601 |
# outfeat_capture_blocks only contains the last up block, up_blocks[3].
|
602 |
# It contains 3 FFN layers. We want to capture their output features.
|
603 |
for block in outfeat_capture_blocks:
|
604 |
block.capture_outfeats = capture_ca_activations
|
605 |
|
606 |
+
# res_hidden_states_gradscale_blocks contain the second to the last up blocks, up_blocks[1:].
|
607 |
+
# It's only used to set res_hidden_states_gradscale, and doesn't capture anything.
|
608 |
for block in res_hidden_states_gradscale_blocks:
|
609 |
block.res_hidden_states_gradscale = res_hidden_states_gradscale
|
610 |
|
|
|
643 |
block.cached_outfeats = {}
|
644 |
block.capture_outfeats = False
|
645 |
|
646 |
+
|
647 |
for layer_idx in captured_layer_indices:
|
648 |
# Subtract 22 from ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
|
649 |
# 23, 24 -> 1, 2 (!! not 0, 1 !!)
|
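Besides the LoRA layers, the setup code above registers each processor's learnable `cross_attn_scale_factor` in `attn_opt_modules` so it is optimized and saved/loaded together with the LoRA weights. A hedged sketch of that pattern with illustrative names (not the repo's exact API):

```python
# Sketch: gather per-processor learnable scalars into one container so a single
# state_dict and optimizer cover them.
import torch
import torch.nn as nn

class DummyAttnProc(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_attn_scale_factor = nn.Parameter(torch.tensor(0.8))

procs = {"up_blocks.3.attentions.0.processor": DummyAttnProc(),
         "up_blocks.3.attentions.1.processor": DummyAttnProc()}
# Container keys cannot contain '.', mirroring the ModuleDict caveat in the diff.
attn_opt_params = nn.ParameterDict({
    name.replace(".", "_") + "_cross_attn_scale_factor": proc.cross_attn_scale_factor
    for name, proc in procs.items()
})
optimizer = torch.optim.AdamW(attn_opt_params.parameters(), lr=1e-4)
state = attn_opt_params.state_dict()  # checkpointed alongside the LoRA weights
```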
adaface/face_id_to_ada_prompt.py
CHANGED
@@ -603,9 +603,13 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
603 |
'''
|
604 |
# Use the same model as ID2AdaPrompt does.
|
605 |
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
606 |
-
# Note there
|
|
|
|
|
|
|
|
|
607 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
608 |
-
|
609 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
610 |
print(f'Arc2Face Face encoder loaded on CPU.')
|
611 |
|
@@ -642,7 +646,6 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
642 |
|
643 |
def _apply(self, fn):
|
644 |
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
645 |
-
return
|
646 |
# A dirty hack to get the device of the model, passed from
|
647 |
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
648 |
test_tensor = torch.zeros(1) # Create a test tensor
|
@@ -654,16 +657,14 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
654 |
|
655 |
if str(device) == 'cpu':
|
656 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
657 |
-
|
658 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
659 |
else:
|
660 |
device_id = device.index
|
661 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
662 |
providers=['CUDAExecutionProvider'],
|
663 |
-
provider_options=[{
|
664 |
-
|
665 |
-
"gpu_mem_limit": 2 * 1024**3
|
666 |
-
}])
|
667 |
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
668 |
|
669 |
self.device = device
|
@@ -739,8 +740,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
739 |
# but diffusers will call .to(dtype) in .from_single_file(),
|
740 |
# and at that moment, the consistentID specific modules are not loaded yet.
|
741 |
pipe = ConsistentIDPipeline.from_single_file(base_model_path)
|
742 |
-
pipe.load_ConsistentID_model(consistentID_weight_path="
|
743 |
-
bise_net_weight_path="
|
744 |
pipe.to(dtype=self.dtype)
|
745 |
# Since the passed-in pipe is None, this should be called during inference,
|
746 |
# when the teacher ConsistentIDPipeline is not initialized.
|
@@ -791,7 +792,6 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
791 |
|
792 |
def _apply(self, fn):
|
793 |
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
794 |
-
return
|
795 |
# A dirty hack to get the device of the model, passed from
|
796 |
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
797 |
test_tensor = torch.zeros(1) # Create a test tensor
|
@@ -809,10 +809,8 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
809 |
device_id = device.index
|
810 |
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
811 |
providers=['CUDAExecutionProvider'],
|
812 |
-
provider_options=[{
|
813 |
-
|
814 |
-
"gpu_mem_limit": 2 * 1024**3
|
815 |
-
}])
|
816 |
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
817 |
|
818 |
self.device = device
|
@@ -1277,7 +1275,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1277 |
# No faces are found in the images, so return None embeddings.
|
1278 |
# We don't want to return an all-zero embedding, which is useless.
|
1279 |
if num_available_id_vecs == 0:
|
1280 |
-
return None, [0]
|
1281 |
|
1282 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1283 |
# during inference, we average across the batch dim.
|
|
|
603 |
'''
|
604 |
# Use the same model as ID2AdaPrompt does.
|
605 |
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
606 |
+
# Note there's a second "model" in the path.
|
607 |
+
# Note DO use CUDAExecutionProvider during training and CPUExecutionProvider during inference.
|
608 |
+
# Otherwise, CPUExecutionProvider will hang DDP training,
|
609 |
+
# and CUDAExecutionProvider will cause OOM on huggingface spaces.
|
610 |
+
self.onnx_providers = ['CUDAExecutionProvider'] if self.is_training else ['CPUExecutionProvider']
|
611 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
612 |
+
providers=self.onnx_providers)
|
613 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
614 |
print(f'Arc2Face Face encoder loaded on CPU.')
|
615 |
|
|
|
646 |
|
647 |
def _apply(self, fn):
|
648 |
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
|
|
649 |
# A dirty hack to get the device of the model, passed from
|
650 |
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
651 |
test_tensor = torch.zeros(1) # Create a test tensor
|
|
|
657 |
|
658 |
if str(device) == 'cpu':
|
659 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
660 |
+
providers=['CPUExecutionProvider'])
|
661 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
662 |
else:
|
663 |
device_id = device.index
|
664 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
665 |
providers=['CUDAExecutionProvider'],
|
666 |
+
provider_options=[{'device_id': device_id,
|
667 |
+
'cudnn_conv_algo_search': 'HEURISTIC'}])
|
|
|
|
|
668 |
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
669 |
|
670 |
self.device = device
|
|
|
740 |
# but diffusers will call .to(dtype) in .from_single_file(),
|
741 |
# and at that moment, the consistentID specific modules are not loaded yet.
|
742 |
pipe = ConsistentIDPipeline.from_single_file(base_model_path)
|
743 |
+
pipe.load_ConsistentID_model(consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
|
744 |
+
bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth")
|
745 |
pipe.to(dtype=self.dtype)
|
746 |
# Since the passed-in pipe is None, this should be called during inference,
|
747 |
# when the teacher ConsistentIDPipeline is not initialized.
|
|
|
792 |
|
793 |
def _apply(self, fn):
|
794 |
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
|
|
795 |
# A dirty hack to get the device of the model, passed from
|
796 |
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
797 |
test_tensor = torch.zeros(1) # Create a test tensor
|
|
|
809 |
device_id = device.index
|
810 |
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
811 |
providers=['CUDAExecutionProvider'],
|
812 |
+
provider_options=[{'device_id': device_id,
|
813 |
+
'cudnn_conv_algo_search': 'HEURISTIC'}])
|
|
|
|
|
814 |
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
815 |
|
816 |
self.device = device
|
|
|
1275 |
# No faces are found in the images, so return None embeddings.
|
1276 |
# We don't want to return an all-zero embedding, which is useless.
|
1277 |
if num_available_id_vecs == 0:
|
1278 |
+
return None, None, [0]
|
1279 |
|
1280 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1281 |
# during inference, we average across the batch dim.
|
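The change above selects ONNX execution providers by phase: CUDA while training (CPUExecutionProvider can hang DDP training) and CPU at inference (CUDA sessions can run out of memory on shared Spaces GPUs). A hedged sketch assuming the `insightface` package; the wrapper function itself is illustrative:

```python
# Sketch: build an insightface FaceAnalysis app with phase-dependent providers.
from insightface.app import FaceAnalysis

def build_face_app(is_training: bool, device_id: int = 0) -> FaceAnalysis:
    if is_training:
        # CUDA for training; pass the device and a cuDNN search heuristic.
        app = FaceAnalysis(name="antelopev2", root="models/insightface",
                           providers=["CUDAExecutionProvider"],
                           provider_options=[{"device_id": device_id,
                                              "cudnn_conv_algo_search": "HEURISTIC"}])
        app.prepare(ctx_id=device_id, det_size=(512, 512))
    else:
        # CPU for inference to avoid GPU OOM on shared hardware.
        app = FaceAnalysis(name="antelopev2", root="models/insightface",
                           providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(512, 512))
    return app
```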
adaface/unet_teachers.py
CHANGED
@@ -62,46 +62,41 @@ class UNetTeacher(nn.Module):
|
|
62 |
# t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
|
63 |
# same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
|
64 |
def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
|
65 |
-
num_denoising_steps=1, same_t_noise_across_instances=False,
|
66 |
global_t_lb=0, global_t_ub=1000):
|
67 |
assert num_denoising_steps <= 10
|
68 |
|
69 |
-
|
|
|
|
|
|
|
70 |
self.uses_cfg = np.random.rand() < self.p_uses_cfg
|
71 |
-
if self.uses_cfg:
|
72 |
-
# Randomly sample a cfg_scale from cfg_scale_range.
|
73 |
-
self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
|
74 |
-
if self.cfg_scale == 1:
|
75 |
-
self.uses_cfg = False
|
76 |
-
|
77 |
-
if self.uses_cfg:
|
78 |
-
print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
|
79 |
-
if negative_context is not None:
|
80 |
-
negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
|
81 |
-
|
82 |
-
# if negative_context is None, then teacher_context is a combination of
|
83 |
-
# (one or multiple if unet_ensemble) pos_context and neg_context.
|
84 |
-
# If negative_context is not None, then teacher_context is only pos_context.
|
85 |
-
else:
|
86 |
-
self.cfg_scale = 1
|
87 |
-
print("Teacher does not use CFG.")
|
88 |
-
|
89 |
-
# If negative_context is None, then teacher_context is a combination of
|
90 |
-
# (one or multiple if unet_ensemble) pos_context and neg_context.
|
91 |
-
# Since not uses_cfg, we only need pos_context.
|
92 |
-
# If negative_context is not None, then teacher_context is only pos_context.
|
93 |
-
if negative_context is None:
|
94 |
-
teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
|
95 |
else:
|
96 |
# p_uses_cfg = 0. Never use CFG.
|
97 |
self.uses_cfg = False
|
98 |
-
# In this case, the student only passes pos_context to the teacher,
|
99 |
-
# so no need to split teacher_context into pos_context and neg_context.
|
100 |
-
# self.cfg_scale will be accessed by the student,
|
101 |
-
# so we need to make sure it is always set correctly,
|
102 |
-
# in case someday we want to switch from CFG to non-CFG during runtime.
|
103 |
self.cfg_scale = 1
|
104 |
|
|
|
|
|
|
|
|
105 |
is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
|
106 |
if self.name == 'unet_ensemble':
|
107 |
# teacher_context is a list of teacher contexts.
|
@@ -199,14 +194,20 @@ class UNetTeacher(nn.Module):
|
|
199 |
teacher_pos_contexts = []
|
200 |
# teacher_context is a list of teacher contexts.
|
201 |
for teacher_context_i in teacher_context:
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
205 |
teacher_pos_contexts.append(pos_context)
|
206 |
teacher_context = teacher_pos_contexts
|
207 |
else:
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
breakpoint()
|
211 |
teacher_context = pos_context
|
212 |
|
|
|
62 |
# t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
|
63 |
# same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
|
64 |
def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
|
65 |
+
num_denoising_steps=1, force_uses_cfg=False, same_t_noise_across_instances=False,
|
66 |
global_t_lb=0, global_t_ub=1000):
|
67 |
assert num_denoising_steps <= 10
|
68 |
|
69 |
+
# force_uses_cfg overrides p_uses_cfg.
|
70 |
+
if force_uses_cfg > 0:
|
71 |
+
self.uses_cfg = True
|
72 |
+
elif self.p_uses_cfg > 0:
|
73 |
self.uses_cfg = np.random.rand() < self.p_uses_cfg
|
|
|
|
74 |
else:
|
75 |
# p_uses_cfg = 0. Never use CFG.
|
76 |
self.uses_cfg = False
|
|
|
|
|
|
|
|
|
|
|
77 |
self.cfg_scale = 1
|
78 |
|
79 |
+
if self.uses_cfg:
|
80 |
+
# Randomly sample a cfg_scale from cfg_scale_range.
|
81 |
+
self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
|
82 |
+
print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
|
83 |
+
if negative_context is not None:
|
84 |
+
negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
|
85 |
+
|
86 |
+
# if negative_context is None, then teacher_context is a combination of
|
87 |
+
# (one or multiple if unet_ensemble) pos_context and neg_context.
|
88 |
+
# If negative_context is not None, then teacher_context is only pos_context.
|
89 |
+
else:
|
90 |
+
self.cfg_scale = 1
|
91 |
+
print("Teacher does not use CFG.")
|
92 |
+
|
93 |
+
# If negative_context is None, then teacher_context is either a combination of
|
94 |
+
# (one or multiple if unet_ensemble) pos_context and neg_context, or only pos_context.
|
95 |
+
# Since not uses_cfg, we only need pos_context.
|
96 |
+
# If negative_context is not None, then teacher_context is only pos_context.
|
97 |
+
if negative_context is None:
|
98 |
+
teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
|
99 |
+
|
100 |
is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
|
101 |
if self.name == 'unet_ensemble':
|
102 |
# teacher_context is a list of teacher contexts.
|
|
|
194 |
teacher_pos_contexts = []
|
195 |
# teacher_context is a list of teacher contexts.
|
196 |
for teacher_context_i in teacher_context:
|
197 |
+
if teacher_context_i.shape[0] == BS * 2:
|
198 |
+
pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
|
199 |
+
elif teacher_context_i.shape[0] == BS:
|
200 |
+
pos_context = teacher_context_i
|
201 |
+
else:
|
202 |
+
breakpoint()
|
203 |
teacher_pos_contexts.append(pos_context)
|
204 |
teacher_context = teacher_pos_contexts
|
205 |
else:
|
206 |
+
if teacher_context.shape[0] == BS * 2:
|
207 |
+
pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
|
208 |
+
elif teacher_context.shape[0] == BS:
|
209 |
+
pos_context = teacher_context
|
210 |
+
else:
|
211 |
breakpoint()
|
212 |
teacher_context = pos_context
|
213 |
|
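The reworked block above lets `force_uses_cfg` override the stochastic `p_uses_cfg` decision and only samples a CFG scale when CFG is actually enabled. A hedged, standalone sketch of the decision logic (the function name is illustrative; the real method also handles `negative_context` and UNet ensembles):

```python
# Sketch: decide whether the teacher uses classifier-free guidance and at what scale.
import numpy as np

def decide_cfg(p_uses_cfg: float, cfg_scale_range=(1.3, 2.0), force_uses_cfg=False):
    if force_uses_cfg:
        uses_cfg = True
    elif p_uses_cfg > 0:
        uses_cfg = np.random.rand() < p_uses_cfg
    else:
        uses_cfg = False
    # cfg_scale == 1 means "no CFG"; only sample a scale when CFG is on.
    cfg_scale = np.random.uniform(*cfg_scale_range) if uses_cfg else 1.0
    return uses_cfg, cfg_scale

uses_cfg, cfg_scale = decide_cfg(p_uses_cfg=0.5, force_uses_cfg=True)
```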
adaface/util.py
CHANGED
@@ -48,7 +48,7 @@ def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=Fals
|
|
48 |
ts = ts + noise
|
49 |
|
50 |
if verbose:
|
51 |
-
print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}")
|
52 |
|
53 |
return ts
|
54 |
|
@@ -69,7 +69,7 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
|
|
69 |
# Compute it manually.
|
70 |
l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
|
71 |
norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
|
72 |
-
print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item()))
|
73 |
print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
|
74 |
|
75 |
|
@@ -80,7 +80,7 @@ class ScaleGrad(torch.autograd.Function):
|
|
80 |
ctx.save_for_backward(alpha_, debug)
|
81 |
output = input_
|
82 |
if debug:
|
83 |
-
print(f"input: {input_.abs().mean().item()}")
|
84 |
return output
|
85 |
|
86 |
@staticmethod
|
@@ -90,7 +90,7 @@ class ScaleGrad(torch.autograd.Function):
|
|
90 |
if ctx.needs_input_grad[0]:
|
91 |
grad_output2 = grad_output * alpha_
|
92 |
if debug:
|
93 |
-
print(f"grad_output2: {grad_output2.abs().mean().item()}")
|
94 |
else:
|
95 |
grad_output2 = None
|
96 |
return grad_output2, None, None
|
@@ -232,8 +232,8 @@ def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetens
|
|
232 |
# consistentID specific modules are still in fp32. Will be converted to fp16
|
233 |
# later with .to(device, torch_dtype) by the caller.
|
234 |
pipe.load_ConsistentID_model(
|
235 |
-
consistentID_weight_path="
|
236 |
-
bise_net_weight_path="
|
237 |
)
|
238 |
# Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
|
239 |
# because we've overloaded .to() to convert consistentID specific modules as well,
|
|
|
48 |
ts = ts + noise
|
49 |
|
50 |
if verbose:
|
51 |
+
print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).detach().item():.03f}")
|
52 |
|
53 |
return ts
|
54 |
|
|
|
69 |
# Compute it manually.
|
70 |
l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt()
|
71 |
norms = torch.norm(embeddings, dim=1).detach().cpu().numpy()
|
72 |
+
print("L1: %.4f, L2: %.4f" %(l1_loss.detach().item(), l2_loss.detach().item()))
|
73 |
print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
|
74 |
|
75 |
|
|
|
80 |
ctx.save_for_backward(alpha_, debug)
|
81 |
output = input_
|
82 |
if debug:
|
83 |
+
print(f"input: {input_.abs().mean().detach().item()}")
|
84 |
return output
|
85 |
|
86 |
@staticmethod
|
|
|
90 |
if ctx.needs_input_grad[0]:
|
91 |
grad_output2 = grad_output * alpha_
|
92 |
if debug:
|
93 |
+
print(f"grad_output2: {grad_output2.abs().mean().detach().item()}")
|
94 |
else:
|
95 |
grad_output2 = None
|
96 |
return grad_output2, None, None
|
|
|
232 |
# consistentID specific modules are still in fp32. Will be converted to fp16
|
233 |
# later with .to(device, torch_dtype) by the caller.
|
234 |
pipe.load_ConsistentID_model(
|
235 |
+
consistentID_weight_path="models/ConsistentID/ConsistentID-v1.bin",
|
236 |
+
bise_net_weight_path="models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
|
237 |
)
|
238 |
# Avoid passing dtype to ConsistentIDPipeline.from_single_file(),
|
239 |
# because we've overloaded .to() to convert consistentID specific modules as well,
|
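`ScaleGrad` in util.py is an identity in the forward pass that rescales the gradient on the way back; the attention code uses it through `gen_gradient_scaler`. A minimal, hedged re-implementation of the pattern (not the repo's exact code, which also supports a debug flag):

```python
# Sketch: forward is identity, backward multiplies the incoming gradient by alpha.
import torch

class ScaleGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha: float):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: (x, alpha); alpha gets no gradient.
        return grad_output * ctx.alpha, None

def gen_gradient_scaler(alpha: float):
    return lambda x: ScaleGradFn.apply(x, alpha)

x = torch.randn(3, requires_grad=True)
gen_gradient_scaler(10.0)(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 10.0))
```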
app.py
CHANGED
@@ -24,11 +24,16 @@ parser = argparse.ArgumentParser()
|
|
24 |
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
|
25 |
choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
|
26 |
parser.add_argument('--adaface_ckpt_path', type=str,
|
27 |
-
default='models/adaface/VGGface2_HQ_masks2025-
|
28 |
parser.add_argument('--model_style_type', type=str, default='photorealistic',
|
29 |
choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
|
30 |
-
parser.add_argument("--guidance_scale", type=float, default=
|
31 |
-
help="The guidance scale for the diffusion model. Default:
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
parser.add_argument('--gpu', type=int, default=None)
|
34 |
parser.add_argument('--ip', type=str, default="0.0.0.0")
|
@@ -70,7 +75,8 @@ adaface_base_model_path = model_style_type2base_model_path["photorealistic"]
|
|
70 |
id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
|
71 |
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
|
72 |
adaface_encoder_types=args.adaface_encoder_types,
|
73 |
-
adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu'
|
|
|
74 |
|
75 |
basedir = os.getcwd()
|
76 |
savedir = os.path.join(basedir,'samples')
|
@@ -80,22 +86,22 @@ os.makedirs(savedir, exist_ok=True)
|
|
80 |
#os.system(f"rm -rf gradio_cached_examples/")
|
81 |
|
82 |
def swap_to_gallery(images):
|
83 |
-
# Update
|
84 |
# Or:
|
85 |
-
# Update
|
86 |
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
|
87 |
|
88 |
def remove_back_to_files():
|
89 |
-
# Hide
|
90 |
# Or:
|
91 |
-
# Hide
|
92 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")
|
93 |
|
94 |
def get_clicked_image(data: gr.SelectData):
|
95 |
return data.index
|
96 |
|
97 |
@spaces.GPU
|
98 |
-
def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale, out_image_count=4):
|
99 |
if uploaded_image_paths is None:
|
100 |
print("No image uploaded")
|
101 |
return None, None, None
|
@@ -112,7 +118,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
|
|
112 |
with torch.no_grad():
|
113 |
adaface_subj_embs = \
|
114 |
adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
|
115 |
-
|
116 |
|
117 |
if adaface_subj_embs is None:
|
118 |
raise gr.Error(f"Failed to detect any faces! Please try with other images")
|
@@ -127,6 +133,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
|
|
127 |
else:
|
128 |
prompt = "face portrait, " + prompt
|
129 |
|
|
|
130 |
guidance_scale = min(guidance_scale, 5)
|
131 |
|
132 |
# samples: A list of PIL Image instances.
|
@@ -134,7 +141,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
|
|
134 |
samples = adaface(noise, prompt, placeholder_tokens_pos='append',
|
135 |
guidance_scale=guidance_scale,
|
136 |
out_image_count=out_image_count,
|
137 |
-
repeat_prompt_for_each_encoder=
|
138 |
verbose=True)
|
139 |
|
140 |
face_paths = []
|
@@ -145,7 +152,7 @@ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale
|
|
145 |
sample.save(face_path)
|
146 |
print(f"Generated init image: {face_path}")
|
147 |
|
148 |
-
# Update
|
149 |
return gr.update(value=face_paths, visible=True), gr.update(value=face_paths, visible=False), gr.update(visible=True)
|
150 |
|
151 |
@spaces.GPU(duration=90)
|
@@ -153,7 +160,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
153 |
init_image_strength, init_image_final_weight,
|
154 |
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
155 |
seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
|
156 |
-
highlight_face, is_adaface_enabled, adaface_power_scale,
|
157 |
id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
|
158 |
|
159 |
global adaface, id_animator
|
@@ -195,10 +202,19 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
195 |
adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
|
196 |
update_text_encoder=True)
|
197 |
|
|
|
|
198 |
# adaface_prompt_embeds: [1, 77, 768].
|
199 |
adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
|
200 |
adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
|
201 |
-
|
|
|
202 |
verbose=True)
|
203 |
|
204 |
# ID-Animator Image Embedding Initial and End Scales
|
@@ -267,13 +283,14 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
267 |
<b>Official demo</b> for our working paper <b>AdaFace: A Versatile Text-space Face Encoder for Face Synthesis and Processing</b>.<br>
|
268 |
|
269 |
❗️**NOTE**❗️
|
270 |
-
- Support switching between three model styles: **
|
271 |
- If you change the model style, please allow 20~30 seconds for the new model weights to load before image/video generation begins.
|
272 |
|
273 |
❗️**Tips**❗️
|
274 |
- You can upload one or more subject images for generating ID-specific video.
|
275 |
-
-
|
276 |
-
-
|
|
|
277 |
- Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
|
278 |
- AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
|
279 |
AdaFace
|
@@ -285,16 +302,16 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
285 |
|
286 |
with gr.Row():
|
287 |
with gr.Column():
|
288 |
-
|
289 |
label="Drag / Select 1 or more photos of a person's face",
|
290 |
file_types=["image"],
|
291 |
file_count="multiple"
|
292 |
)
|
293 |
-
|
294 |
image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
|
295 |
-
|
296 |
with gr.Column(visible=False) as clear_button_column:
|
297 |
-
remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=
|
298 |
|
299 |
init_img_files = gr.File(
|
300 |
label="[Optional] Generate 4 images and select 1 image",
|
@@ -305,7 +322,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
305 |
init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
|
306 |
# Although there's only one image, we still use columns=3, to scale down the image size.
|
307 |
# Otherwise it will occupy the full width, and the gallery won't show the whole image.
|
308 |
-
|
309 |
# placeholder is just hint, not the real value. So we use "value='0'" instead of "placeholder='0'".
|
310 |
init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
|
311 |
|
@@ -320,7 +337,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
320 |
allow_custom_value=True,
|
321 |
choices=[
|
322 |
"portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
|
323 |
-
"portrait, walking on the beach, sunset",
|
324 |
"portrait, in a white apron and chef hat, garnishing a gourmet dish",
|
325 |
"portrait, dancing pose among folks in a park, waving hands",
|
326 |
"portrait, in iron man costume, the sky ablaze with hues of orange and purple",
|
@@ -328,18 +345,21 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
328 |
"portrait, night view of tokyo street, neon light",
|
329 |
"portrait, playing guitar on a boat, ocean waves",
|
330 |
"portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
|
331 |
-
"portrait, celebrating new year, fireworks",
|
332 |
-
"portrait, running pose in a park",
|
333 |
"portrait, in space suit, space helmet, walking on mars",
|
334 |
"portrait, in superman costume, the sky ablaze with hues of orange and purple"
|
335 |
])
|
336 |
|
337 |
-
highlight_face = gr.Checkbox(label="Highlight face", value=
|
338 |
info="Enhance the facial features by prepending 'face portrait' to the prompt",
|
339 |
visible=True)
|
340 |
-
|
|
|
|
|
|
|
341 |
init_image_strength = gr.Slider(
|
342 |
-
label="Init Image
|
343 |
info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
|
344 |
minimum=0,
|
345 |
maximum=3,
|
@@ -352,7 +372,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
352 |
minimum=0,
|
353 |
maximum=2,
|
354 |
step=0.025,
|
355 |
-
value=0.
|
356 |
)
|
357 |
|
358 |
model_style_type = gr.Dropdown(
|
@@ -415,7 +435,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
415 |
minimum=0.8,
|
416 |
maximum=1.2,
|
417 |
step=0.05,
|
418 |
-
value=1.
|
419 |
visible=True,
|
420 |
)
|
421 |
|
@@ -443,7 +463,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
443 |
minimum=0,
|
444 |
maximum=1,
|
445 |
step=0.1,
|
446 |
-
value=0.
|
447 |
)
|
448 |
|
449 |
id_animator_anneal_steps = gr.Slider(
|
@@ -464,13 +484,13 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
464 |
with gr.Column():
|
465 |
result_video = gr.Video(label="Generated Animation", interactive=False)
|
466 |
|
467 |
-
|
468 |
-
remove_and_reupload.click(fn=remove_back_to_files, outputs=[
|
469 |
|
470 |
init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files,
|
471 |
-
outputs=[
|
472 |
remove_init_and_reupload.click(fn=remove_back_to_files,
|
473 |
-
outputs=[
|
474 |
init_img_files, init_img_selected_idx])
|
475 |
gen_init.click(fn=check_prompt_and_model_type,
|
476 |
inputs=[prompt, model_style_type],outputs=None).success(
|
@@ -479,10 +499,11 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
479 |
outputs=seed,
|
480 |
queue=False,
|
481 |
api_name=False,
|
482 |
-
).then(fn=gen_init_images, inputs=[
|
|
|
483 |
guidance_scale],
|
484 |
-
outputs=[
|
485 |
-
|
486 |
|
487 |
submit.click(fn=check_prompt_and_model_type,
|
488 |
inputs=[prompt, model_style_type],outputs=None).success(
|
@@ -493,11 +514,11 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
493 |
api_name=False,
|
494 |
).then(
|
495 |
fn=generate_video,
|
496 |
-
inputs=[image_container,
|
497 |
init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
|
498 |
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
499 |
seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
|
500 |
-
highlight_face, is_adaface_enabled,
|
501 |
adaface_power_scale, id_animator_anneal_steps],
|
502 |
outputs=[result_video]
|
503 |
)
|
|
|
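The event-wiring hunks above all follow the same Gradio swap pattern: an upload on the file input fills a gallery, reveals a clear-button column, and hides the file input, while the clear button reverses the swap. Below is a minimal, self-contained sketch of that pattern. It reuses the component and handler names from the diff but simplifies the handlers to three outputs, so treat it as an illustration rather than the app's exact code.

import gradio as gr

def swap_to_gallery(files):
    # Fill and show the gallery, reveal the clear-button column, hide the File input.
    return gr.update(value=files, visible=True), gr.update(visible=True), gr.update(visible=False)

def remove_back_to_files():
    # Empty and hide the gallery, hide the clear-button column, show the File input again.
    return gr.update(value=None, visible=False), gr.update(visible=False), gr.update(value=None, visible=True)

with gr.Blocks() as demo:
    ref_files = gr.File(label="Drag / Select 1 or more photos of a person's face",
                        file_types=["image"], file_count="multiple")
    uploaded_ref_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3)
    with gr.Column(visible=False) as clear_button_column:
        remove_and_reupload = gr.ClearButton(value="Remove and upload subject images",
                                             components=ref_files, size="sm")

    ref_files.upload(fn=swap_to_gallery, inputs=ref_files,
                     outputs=[uploaded_ref_files_gallery, clear_button_column, ref_files])
    remove_and_reupload.click(fn=remove_back_to_files,
                              outputs=[uploaded_ref_files_gallery, clear_button_column, ref_files])

if __name__ == "__main__":
    demo.launch()

In the app itself the same two handlers are shared with the init-image gallery, which is why the real remove_back_to_files returns a fourth update that resets init_img_selected_idx.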
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                    choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
parser.add_argument('--adaface_ckpt_path', type=str,
+                   default='models/adaface/VGGface2_HQ_masks2025-05-22T17-51-19_zero3-ada-1000.pt')
parser.add_argument('--model_style_type', type=str, default='photorealistic',
                    choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
+parser.add_argument("--guidance_scale", type=float, default=6.0,
+                    help="The guidance scale for the diffusion model. Default: 6.0")
+parser.add_argument('--num_inference_steps', type=int, default=50,
+                    help="The number of denoising steps for image generation (NOT FOR VIDEOS). Default: 50")
+parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
+                    choices=["ada", "arc2face", "consistentID"],
+                    help="Ablate to use the image ID embs instead of Ada embs")

parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--ip', type=str, default="0.0.0.0")

id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
                         adaface_encoder_types=args.adaface_encoder_types,
+                        adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu',
+                        num_inference_steps=args.num_inference_steps)

basedir = os.getcwd()
savedir = os.path.join(basedir,'samples')

#os.system(f"rm -rf gradio_cached_examples/")

def swap_to_gallery(images):
+   # Update uploaded_ref_files_gallery, show ref_files, hide clear_button_column
    # Or:
+   # Update generated_init_img_gallery, show init_img_files, hide init_clear_button_column
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)

def remove_back_to_files():
+   # Hide uploaded_ref_files_gallery, show clear_button_column, hide ref_files, reset init_img_selected_idx
    # Or:
+   # Hide generated_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
+   return gr.update(value=None, visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")

def get_clicked_image(data: gr.SelectData):
    return data.index

@spaces.GPU
+def gen_init_images(uploaded_image_paths, prompt, highlight_face, enhance_composition, guidance_scale, out_image_count=4):
    if uploaded_image_paths is None:
        print("No image uploaded")
        return None, None, None
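For reference, the new --ablate_prompt_embed_type option interacts with --adaface_encoder_types through the small index-to-"img{i}" mapping shown in generate_video above. The following self-contained sketch reproduces that mapping with the argument declarations copied from the diff; the hard-coded parse_args list is only there to make the example runnable.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                    choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
                    choices=["ada", "arc2face", "consistentID"],
                    help="Ablate to use the image ID embs instead of Ada embs")

# Simulate launching the demo with "--ablate_prompt_embed_type arc2face".
args = parser.parse_args(["--ablate_prompt_embed_type", "arc2face"])

if args.ablate_prompt_embed_type != "ada":
    # The ablated encoder is addressed by its position in adaface_encoder_types:
    # "arc2face" sits at index 1, so the prompt embedding type becomes "img1".
    ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type)
    ablate_prompt_embed_type = f"img{ablate_prompt_embed_index}"
else:
    ablate_prompt_embed_type = "ada"

print(ablate_prompt_embed_type)  # -> img1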