Spaces:
Configuration error
Configuration error
import torch | |
from torch import nn | |
import torchvision.transforms as T | |
import torch.nn.functional as F | |
import os | |
import math | |
import folder_paths | |
import comfy.utils | |
from insightface.app import FaceAnalysis | |
from facexlib.parsing import init_parsing_model | |
from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
from comfy.ldm.modules.attention import optimized_attention | |
from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD | |
from .encoders import IDEncoder | |
# Directories (below ComfyUI's models directory) used by this extension.
INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface")
MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid")

# Register the "pulid" model folder. Search paths already registered under
# that name are kept; otherwise MODELS_DIR becomes the only search path.
if "pulid" in folder_paths.folder_names_and_paths:
    current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
else:
    current_paths = [MODELS_DIR]
folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
class PulidModel(nn.Module):
    """Container for the two parts of a PuLID checkpoint.

    ``image_proj_model`` is an ``IDEncoder`` loaded from ``model["image_proj"]``
    and ``ip_layers`` is a ``To_KV`` built from ``model["ip_adapter"]``.
    """

    def __init__(self, model):
        """Build both sub-modules from ``model``.

        Args:
            model: mapping indexed with the keys ``"image_proj"`` and
                ``"ip_adapter"``; it is also kept as ``self.model``.
        """
        super().__init__()
        self.model = model

        encoder = self.init_id_adapter()
        self.image_proj_model = encoder
        encoder.load_state_dict(model["image_proj"])

        self.ip_layers = To_KV(model["ip_adapter"])

    def init_id_adapter(self):
        """Return a freshly constructed ``IDEncoder``."""
        return IDEncoder()

    def get_image_embeds(self, face_embed, clip_embeds):
        """Return the result of running the ID encoder on both embeddings."""
        return self.image_proj_model(face_embed, clip_embeds)
class To_KV(nn.Module):
    """Bias-free linear projections built from a state dict of weight matrices."""

    def __init__(self, state_dict):
        """Create one ``nn.Linear`` per entry of ``state_dict``.

        The module name is the entry's key with ``".weight"`` removed and the
        remaining dots replaced by underscores. Each value is used directly as
        the layer's weight data, with ``shape[0]`` as the output size and
        ``shape[1]`` as the input size.
        """
        super().__init__()
        self.to_kvs = nn.ModuleDict()

        for key, value in state_dict.items():
            name = key.replace(".weight", "").replace(".", "_")
            out_features = value.shape[0]
            in_features = value.shape[1]

            layer = nn.Linear(in_features, out_features, bias=False)
            self.to_kvs[name] = layer
            layer.weight.data = value
def tensor_to_image(tensor):
    """Convert a float tensor to a uint8 NumPy array with the last axis reversed.

    Values are multiplied by 255, clamped to [0, 255] and truncated to bytes on
    the CPU. The first three entries of the last axis are then reordered as
    2, 1, 0 (presumably RGB -> BGR; confirm against callers).
    """
    as_bytes = tensor.mul(255).clamp(0, 255).byte().cpu()
    flipped = as_bytes[..., [2, 1, 0]]
    return flipped.numpy()
def image_to_tensor(image):
    """Convert a NumPy image to a float tensor in [0, 1] with the last axis reversed.

    The array is converted to float, divided by 255 and clamped to [0, 1]; the
    first three entries of the last axis are then reordered as 2, 1, 0
    (presumably BGR -> RGB; confirm against callers).
    """
    scaled = torch.from_numpy(image).float() / 255.
    clamped = torch.clamp(scaled, 0, 1)
    return clamped[..., [2, 1, 0]]
def tensor_to_size(source, dest_size):
    """Make the first dimension of ``source`` exactly ``dest_size`` long.

    Args:
        source: tensor to resize along dimension 0.
        dest_size: target length, or a tensor whose first dimension gives it.

    Returns:
        ``source`` unchanged when the length already matches, a truncated
        slice when it is too long, or ``source`` extended with copies of its
        last entry when it is too short.
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]

    current = source.shape[0]
    if current > dest_size:
        return source[:dest_size]
    if current < dest_size:
        # Repeat the last entry only along dim 0; other dims keep their size.
        repeats = [dest_size - current] + [1] * (source.dim() - 1)
        filler = source[-1:].repeat(repeats)
        return torch.cat((source, filler), dim=0)
    return source
def set_model_patch_replace(model, patch_kwargs, key):
    """Attach ``pulid_attention`` as an ``attn2`` replacement patch under ``key``.

    The ``transformer_options`` dict and its nested ``patches_replace`` /
    ``attn2`` dicts are shallow-copied before being modified.

    * If ``key`` has no replacement yet, a new ``Attn2Replace`` is stored in
      the copy and the copy is written back to ``model.model_options``.
    * If ``key`` already has one, the existing object is extended in place via
      ``add`` and the copied options are not written back.
    """
    options = model.model_options["transformer_options"].copy()

    replacements = options["patches_replace"].copy() if "patches_replace" in options else {}
    options["patches_replace"] = replacements

    attn2 = replacements["attn2"].copy() if "attn2" in replacements else {}
    replacements["attn2"] = attn2

    if key in attn2:
        attn2[key].add(pulid_attention, **patch_kwargs)
    else:
        attn2[key] = Attn2Replace(pulid_attention, **patch_kwargs)
        model.model_options["transformer_options"] = options
class Attn2Replace:
    """Callable attention replacement that adds callback outputs to base attention.

    ``callback`` and ``kwargs`` are parallel lists: entry ``i`` of ``kwargs``
    holds the keyword arguments passed to entry ``i`` of ``callback``.
    """

    def __init__(self, callback=None, **kwargs):
        """Start the lists with a single callback and its keyword arguments."""
        self.callback = [callback]
        self.kwargs = [kwargs]

    def add(self, callback, **kwargs):
        """Append another callback; its keyword arguments are also set as attributes."""
        self.callback.append(callback)
        self.kwargs.append(kwargs)

        for name, value in kwargs.items():
            setattr(self, name, value)

    def __call__(self, q, k, v, extra_options):
        """Return base attention plus the output of every callback active at this sigma.

        A callback is active when ``sigma_end <= sigma <= sigma_start`` for its
        own keyword arguments. Sigma is the first element of
        ``extra_options["sigmas"]``; when that key is missing a very large value
        is used instead. The result is cast back to the dtype of ``q``.
        """
        dtype = q.dtype
        out = optimized_attention(q, k, v, extra_options["n_heads"])

        if 'sigmas' in extra_options:
            sigma = extra_options["sigmas"].detach().cpu()[0].item()
        else:
            sigma = 999999999.9

        for callback, kwargs in zip(self.callback, self.kwargs):
            if kwargs["sigma_start"] >= sigma >= kwargs["sigma_end"]:
                out = out + callback(out, q, k, v, extra_options, **kwargs)

        return out.to(dtype=dtype)
def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond=None, uncond=None, weight=1.0, ortho=False, ortho_v2=False, mask=None, **kwargs):
    """Compute the PuLID attention term for one cross-attention layer.

    ``cond`` and ``uncond`` are projected with the ``<module_key>_to_k_ip`` and
    ``<module_key>_to_v_ip`` layers of ``pulid.ip_layers`` and attended to with
    ``q``. The result is then scaled by ``weight``, optionally after one of the
    two projection modes, and optionally multiplied by ``mask``.

    Args:
        out: base attention output; only read here.
        q: attention queries (assumes (batch, tokens, channels) — TODO confirm).
        k, v: unused; accepted for the callback signature.
        extra_options: must provide ``"cond_or_uncond"``, ``"original_shape"``
            and ``"n_heads"``.
        module_key: prefix selecting the projection layers.
        pulid: object exposing ``ip_layers.to_kvs``.
        cond, uncond: embeddings; ``cond_or_uncond`` entries index the pair
            ``(cond, uncond)``.
        weight: scale applied to the result.
        ortho: subtract the projection of the result onto ``out``; takes
            precedence over ``ortho_v2``.
        ortho_v2: like ``ortho``, but the projection is scaled by an
            attention-derived factor.
        mask: optional mask resized to the token grid (assumes
            (batch, height, width) — TODO confirm).
        **kwargs: ignored extra patch arguments.

    Returns:
        The PuLID attention term, cast to the dtype of ``q``.
    """
    dtype = q.dtype
    seq_len = q.shape[1]
    cond_or_uncond = extra_options["cond_or_uncond"]
    batch_prompt = q.shape[0] // len(cond_or_uncond)
    _, _, oh, ow = extra_options["original_shape"]

    to_kvs = pulid.ip_layers.to_kvs
    to_k = to_kvs[module_key + "_to_k_ip"]
    to_v = to_kvs[module_key + "_to_v_ip"]

    # Index 0 holds the cond projection and index 1 the uncond projection,
    # matching how cond_or_uncond entries are used as indices below.
    keys = (
        to_k(cond).repeat(batch_prompt, 1, 1),
        to_k(uncond).repeat(batch_prompt, 1, 1),
    )
    values = (
        to_v(cond).repeat(batch_prompt, 1, 1),
        to_v(uncond).repeat(batch_prompt, 1, 1),
    )
    ip_k = torch.cat([keys[i] for i in cond_or_uncond], dim=0)
    ip_v = torch.cat([values[i] for i in cond_or_uncond], dim=0)

    out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])

    if ortho or ortho_v2:
        # Both projection modes work in float32.
        out = out.to(dtype=torch.float32)
        out_ip = out_ip.to(dtype=torch.float32)

        if not ortho:
            attn_map = q @ ip_k.transpose(-2, -1)
            attn_mean = attn_map.softmax(dim=-1).mean(dim=1, keepdim=True)
            attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)

        projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out)

        if ortho:
            orthogonal = out_ip - projection
        else:
            orthogonal = out_ip + (attn_mean - 1) * projection
        out_ip = weight * orthogonal
    else:
        out_ip = out_ip * weight

    if mask is not None:
        # Derive a (mask_h, mask_w) grid for seq_len tokens from the original
        # height/width; mask_h is rounded up when seq_len is not a multiple.
        scaled_h = oh / math.sqrt(oh * ow / seq_len)
        mask_h = int(scaled_h) + int((seq_len % int(scaled_h)) != 0)
        mask_w = seq_len // mask_h

        mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
        mask = tensor_to_size(mask, batch_prompt)
        mask = mask.repeat(len(cond_or_uncond), 1, 1)
        mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])

        # covers cases where extreme aspect ratios can cause the mask to have a wrong size
        mask_len = mask_h * mask_w
        if mask_len < seq_len:
            pad_len = seq_len - mask_len
            pad_before = pad_len // 2
            pad_after = pad_len - pad_before
            mask = F.pad(mask, (0, 0, pad_before, pad_after), value=0.0)
        elif mask_len > seq_len:
            crop_start = (mask_len - seq_len) // 2
            mask = mask[:, crop_start:crop_start+seq_len, :]

        out_ip = out_ip * mask

    return out_ip.to(dtype=dtype)
def to_gray(img):
    """Return a 3-channel grayscale version of ``img``.

    Channels 0, 1 and 2 along dimension 1 are combined with the weights
    0.299, 0.587 and 0.114, and the single resulting channel is repeated
    three times. Indexing and ``repeat(1, 3, 1, 1)`` require a 4-D input.
    """
    first, second, third = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * first + 0.587 * second + 0.114 * third
    return luma.repeat(1, 3, 1, 1)
""" | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
Nodes | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
""" | |
class PulidModelLoader:
    """Node that loads a PuLID checkpoint and wraps it in a ``PulidModel``."""

    # INPUT_TYPES takes the class as its only argument, so it must be a
    # classmethod to be callable as ``PulidModelLoader.INPUT_TYPES()``.
    @classmethod
    def INPUT_TYPES(s):
        """Offer the files registered under the "pulid" folder as the only input."""
        return {"required": {"pulid_file": (folder_paths.get_filename_list("pulid"), )}}

    RETURN_TYPES = ("PULID",)
    FUNCTION = "load_model"
    CATEGORY = "pulid"

    def load_model(self, pulid_file):
        """Load ``pulid_file`` and return a one-tuple holding the built ``PulidModel``.

        A ``.safetensors`` checkpoint is a flat mapping, so its keys are split
        into the ``"image_proj"`` and ``"ip_adapter"`` groups by prefix. Keys
        with neither prefix are dropped. Other checkpoints are passed through
        unchanged.
        """
        ckpt_path = folder_paths.get_full_path("pulid", pulid_file)
        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

        if ckpt_path.lower().endswith(".safetensors"):
            st_model = {"image_proj": {}, "ip_adapter": {}}
            for key, tensor in model.items():
                for group, members in st_model.items():
                    prefix = group + "."
                    if key.startswith(prefix):
                        # Strip only the leading prefix, not later occurrences.
                        members[key[len(prefix):]] = tensor
                        break
            model = st_model

        # Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node
        model = PulidModel(model)

        return (model,)
class PulidInsightFaceLoader:
    """Node that creates and prepares an InsightFace ``FaceAnalysis`` model."""

    # INPUT_TYPES takes the class as its only argument, so it must be a
    # classmethod to be callable as ``PulidInsightFaceLoader.INPUT_TYPES()``.
    @classmethod
    def INPUT_TYPES(s):
        """Offer the execution provider choice as the only input."""
        return {
            "required": {
                "provider": (["CPU", "CUDA", "ROCM"], ),
            },
        }

    RETURN_TYPES = ("FACEANALYSIS",)
    FUNCTION = "load_insightface"
    CATEGORY = "pulid"

    def load_insightface(self, provider):
        """Return a one-tuple holding a prepared "antelopev2" ``FaceAnalysis``.

        Args:
            provider: provider prefix; ``"ExecutionProvider"`` is appended to
                form the provider name passed to ``FaceAnalysis``.
        """
        model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',])  # alternative to buffalo_l
        model.prepare(ctx_id=0, det_size=(640, 640))

        return (model,)
class PulidEvaClipLoader:
    """Node that creates the EVA02-CLIP-L-14-336 model and returns its visual part."""

    # INPUT_TYPES takes the class as its only argument, so it must be a
    # classmethod to be callable as ``PulidEvaClipLoader.INPUT_TYPES()``.
    @classmethod
    def INPUT_TYPES(s):
        """This node takes no inputs."""
        return {
            "required": {},
        }

    RETURN_TYPES = ("EVA_CLIP",)
    FUNCTION = "load_eva_clip"
    CATEGORY = "pulid"

    def load_eva_clip(self):
        """Return a one-tuple holding the visual model with normalisation stats set.

        ``image_mean`` / ``image_std`` are read from the model, falling back to
        the OpenAI dataset constants. A scalar value is expanded to a 3-tuple.
        The resolved values are stored as attributes on the model.
        """
        from .eva_clip.factory import create_model_and_transforms

        model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)

        model = model.visual

        eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)
        eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)
        if not isinstance(eva_transform_mean, (list, tuple)):
            eva_transform_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            eva_transform_std = (eva_transform_std,) * 3

        # Set as attributes: the model object does not support item assignment,
        # and the fallback constants must be stored on it too.
        model.image_mean = eva_transform_mean
        model.image_std = eva_transform_std

        return (model,)
class ApplyPulid:
    """Node that patches a model's cross-attention layers with PuLID identity embeddings."""

    # INPUT_TYPES takes the class as its only argument, so it must be a
    # classmethod to be callable as ``ApplyPulid.INPUT_TYPES()``.
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs."""
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "method": (["fidelity", "style", "neutral"],),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_pulid"
    CATEGORY = "pulid"

    def apply_pulid(self, model, pulid, eva_clip, face_analysis, image, weight, start_at, end_at, method=None, noise=0.0, fidelity=None, projection=None, attn_mask=None):
        """Return a one-tuple holding a clone of ``model`` patched with PuLID attention.

        For every image in ``image`` an InsightFace embedding and EVA-CLIP
        features of the aligned, background-whitened grayscale face are
        computed and passed through ``pulid.get_image_embeds``. The per-image
        results are averaged, optionally extended with extra tokens, and
        registered as ``attn2`` replacement patches on the clone.

        Args:
            model: model to clone and patch.
            pulid: PuLID model providing ``get_image_embeds``.
            eva_clip: EVA-CLIP visual model with ``image_size``, ``image_mean``
                and ``image_std`` attributes.
            face_analysis: InsightFace analysis object.
            image: reference image batch (assumes float RGB values in [0, 1]
                with channels last — TODO confirm).
            weight: scale stored in the patch arguments.
            start_at, end_at: sampling percentages converted to sigmas.
            method: ``"fidelity"`` or ``"style"`` choose a projection mode and
                a token count; anything else disables both.
            noise: when non-zero, random values scaled by ``noise`` replace
                zeros in the unconditional embedding and the extra tokens.
            fidelity: when given, overrides the number of extra tokens.
            projection: ``"ortho_v2"`` / ``"ortho"`` alternative to ``method``.
            attn_mask: optional mask stored in the patch arguments.

        Raises:
            Exception: if InsightFace or facexlib detects no face in an image.
        """
        work_model = model.clone()

        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32

        eva_clip.to(device, dtype=dtype)
        pulid_model = pulid.to(device, dtype=dtype)

        if attn_mask is not None:
            # Bring the mask to three dimensions before moving it to the device.
            if attn_mask.dim() > 3:
                attn_mask = attn_mask.squeeze(-1)
            elif attn_mask.dim() < 3:
                attn_mask = attn_mask.unsqueeze(0)
            attn_mask = attn_mask.to(device, dtype=dtype)

        if method == "fidelity" or projection == "ortho_v2":
            num_zero = 8
            ortho = False
            ortho_v2 = True
        elif method == "style" or projection == "ortho":
            num_zero = 16
            ortho = True
            ortho_v2 = False
        else:
            num_zero = 0
            ortho = False
            ortho_v2 = False

        if fidelity is not None:
            num_zero = fidelity

        image = tensor_to_image(image)

        face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )
        face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device)

        # Parsing labels treated as background and painted white below.
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        cond = []
        uncond = []

        for i in range(image.shape[0]):
            # get insightface embeddings
            iface_embeds = None
            # Retry detection with progressively smaller input sizes.
            for edge in range(640, 256, -64):
                face_analysis.det_model.input_size = (edge, edge)
                face = face_analysis.get(image[i])
                if face:
                    # Use the largest detected face. The previous code sorted
                    # descending and took the last item, i.e. the smallest.
                    face = max(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
                    iface_embeds = torch.from_numpy(face.embedding).unsqueeze(0).to(device, dtype=dtype)
                    break
            else:
                raise Exception('insightface: No face detected.')

            # get eva_clip embeddings
            face_helper.clean_all()
            face_helper.read_image(image[i])
            face_helper.get_face_landmarks_5(only_center_face=True)
            face_helper.align_warp_face()

            if len(face_helper.cropped_faces) == 0:
                raise Exception('facexlib: No face detected.')

            face = face_helper.cropped_faces[0]
            face = image_to_tensor(face).unsqueeze(0).permute(0, 3, 1, 2).to(device)
            parsing_out = face_helper.face_parse(T.functional.normalize(face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            bg = sum(parsing_out == label for label in bg_label).bool()
            white_image = torch.ones_like(face)
            face_features_image = torch.where(bg, white_image, to_gray(face))
            # apparently MPS only supports NEAREST interpolation?
            face_features_image = T.functional.resize(face_features_image, eva_clip.image_size, T.InterpolationMode.BICUBIC if 'cuda' in device.type else T.InterpolationMode.NEAREST).to(device, dtype=dtype)
            face_features_image = T.functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)

            id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
            id_cond_vit = id_cond_vit.to(device, dtype=dtype)
            id_vit_hidden = [hidden.to(device, dtype=dtype) for hidden in id_vit_hidden]

            # L2-normalise the CLIP embedding along dim 1.
            id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))

            # combine embeddings
            id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)
            if noise == 0:
                id_uncond = torch.zeros_like(id_cond)
                id_vit_hidden_uncond = [torch.zeros_like(hidden) for hidden in id_vit_hidden]
            else:
                id_uncond = torch.rand_like(id_cond) * noise
                id_vit_hidden_uncond = [torch.rand_like(hidden) * noise for hidden in id_vit_hidden]

            cond.append(pulid_model.get_image_embeds(id_cond, id_vit_hidden))
            uncond.append(pulid_model.get_image_embeds(id_uncond, id_vit_hidden_uncond))

        # average embeddings
        cond = torch.cat(cond).to(device, dtype=dtype)
        uncond = torch.cat(uncond).to(device, dtype=dtype)
        if cond.shape[0] > 1:
            cond = torch.mean(cond, dim=0, keepdim=True)
            uncond = torch.mean(uncond, dim=0, keepdim=True)

        if num_zero > 0:
            # Append num_zero extra tokens (zeros, or scaled noise) to both embeddings.
            if noise == 0:
                zero_tensor = torch.zeros((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device)
            else:
                zero_tensor = torch.rand((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device) * noise
            cond = torch.cat([cond, zero_tensor], dim=1)
            uncond = torch.cat([uncond, zero_tensor], dim=1)

        sigma_start = work_model.get_model_object("model_sampling").percent_to_sigma(start_at)
        sigma_end = work_model.get_model_object("model_sampling").percent_to_sigma(end_at)

        patch_kwargs = {
            "pulid": pulid_model,
            "weight": weight,
            "cond": cond,
            "uncond": uncond,
            "sigma_start": sigma_start,
            "sigma_end": sigma_end,
            "ortho": ortho,
            "ortho_v2": ortho_v2,
            "mask": attn_mask,
        }

        # ``number`` counts patched attention layers; the module key of each
        # layer is the odd number ``number * 2 + 1``.
        number = 0
        for block_id in [4, 5, 7, 8]:  # id of input_blocks that have cross attention
            block_indices = range(2) if block_id in [4, 5] else range(10)  # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("input", block_id, index))
                number += 1
        for block_id in range(6):  # id of output_blocks that have cross attention
            block_indices = range(2) if block_id in [3, 4, 5] else range(10)  # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("output", block_id, index))
                number += 1
        for index in range(10):
            patch_kwargs["module_key"] = str(number*2+1)
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
            number += 1

        return (work_model,)
class ApplyPulidAdvanced(ApplyPulid):
    """Variant of ``ApplyPulid`` exposing ``projection``, ``fidelity`` and ``noise`` inputs.

    Only the input description differs; everything else is inherited.
    """

    # INPUT_TYPES takes the class as its only argument, so it must be a
    # classmethod to be callable as ``ApplyPulidAdvanced.INPUT_TYPES()``.
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs."""
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "projection": (["ortho_v2", "ortho", "none"],),
                "fidelity": ("INT", {"default": 8, "min": 0, "max": 32, "step": 1 }),
                "noise": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }
# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "PulidModelLoader": PulidModelLoader,
    "PulidInsightFaceLoader": PulidInsightFaceLoader,
    "PulidEvaClipLoader": PulidEvaClipLoader,
    "ApplyPulid": ApplyPulid,
    "ApplyPulidAdvanced": ApplyPulidAdvanced,
}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PulidModelLoader": "Load PuLID Model",
    "PulidInsightFaceLoader": "Load InsightFace (PuLID)",
    "PulidEvaClipLoader": "Load Eva Clip (PuLID)",
    "ApplyPulid": "Apply PuLID",
    "ApplyPulidAdvanced": "Apply PuLID Advanced",
}