import torch
from torch import nn
import torchvision.transforms as T
import torch.nn.functional as F

import os
import math

import folder_paths
import comfy.utils
import comfy.model_management

from insightface.app import FaceAnalysis
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

from comfy.ldm.modules.attention import optimized_attention
from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .encoders import IDEncoder

INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface")

MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid")
if "pulid" not in folder_paths.folder_names_and_paths:
    current_paths = [MODELS_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)

class PulidModel(nn.Module):
    def __init__(self, model):
        super().__init__()

        self.model = model
        self.image_proj_model = self.init_id_adapter()
        self.image_proj_model.load_state_dict(model["image_proj"])
        self.ip_layers = To_KV(model["ip_adapter"])

    def init_id_adapter(self):
        image_proj_model = IDEncoder()
        return image_proj_model

    def get_image_embeds(self, face_embed, clip_embeds):
        embeds = self.image_proj_model(face_embed, clip_embeds)
        return embeds

class To_KV(nn.Module):
    def __init__(self, state_dict):
        super().__init__()

        self.to_kvs = nn.ModuleDict()
        for key, value in state_dict.items():
            self.to_kvs[key.replace(".weight", "").replace(".", "_")] = nn.Linear(value.shape[1], value.shape[0], bias=False)
            self.to_kvs[key.replace(".weight", "").replace(".", "_")].weight.data = value

def tensor_to_image(tensor):
    image = tensor.mul(255).clamp(0, 255).byte().cpu()
    image = image[..., [2, 1, 0]].numpy()
    return image

def image_to_tensor(image):
    tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)
    tensor = tensor[..., [2, 1, 0]]
    return tensor

def tensor_to_size(source, dest_size):
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    source_size = source.shape[0]

    if source_size < dest_size:
        shape = [dest_size - source_size] + [1]*(source.dim()-1)
        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]

    return source

def set_model_patch_replace(model, patch_kwargs, key):
    to = model.model_options["transformer_options"].copy()
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    else:
        to["patches_replace"]["attn2"] = to["patches_replace"]["attn2"].copy()

    if key not in to["patches_replace"]["attn2"]:
        to["patches_replace"]["attn2"][key] = Attn2Replace(pulid_attention, **patch_kwargs)
        model.model_options["transformer_options"] = to
    else:
        to["patches_replace"]["attn2"][key].add(pulid_attention, **patch_kwargs)

class Attn2Replace:
    def __init__(self, callback=None, **kwargs):
        self.callback = [callback]
        self.kwargs = [kwargs]

    def add(self, callback, **kwargs):
        self.callback.append(callback)
        self.kwargs.append(kwargs)

        for key, value in kwargs.items():
            setattr(self, key, value)

    def __call__(self, q, k, v, extra_options):
        dtype = q.dtype
        out = optimized_attention(q, k, v, extra_options["n_heads"])
        sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9

        for i, callback in enumerate(self.callback):
            if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]:
                out = out + callback(out, q, k, v, extra_options, **self.kwargs[i])

        return out.to(dtype=dtype)
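# Note on the sigma gating above (illustrative sketch, not executed): each callback
# registered in Attn2Replace only fires while the current sampling sigma lies inside
# its [sigma_end, sigma_start] window. ApplyPulid derives that window from the node's
# start_at / end_at percentages, roughly:
#
#   ms = work_model.get_model_object("model_sampling")   # `ms` is shorthand for illustration
#   sigma_start = ms.percent_to_sigma(start_at)          # percent 0.0 -> very large sigma
#   sigma_end   = ms.percent_to_sigma(end_at)            # percent 1.0 -> sigma ~0
#
# so the defaults start_at=0.0 / end_at=1.0 keep the PuLID attention active for the
# whole sampling schedule; the 999999999.9 fallback above matches the "no sigma" case.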
def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond=None, uncond=None, weight=1.0, ortho=False, ortho_v2=False, mask=None, **kwargs):
    k_key = module_key + "_to_k_ip"
    v_key = module_key + "_to_v_ip"

    dtype = q.dtype
    seq_len = q.shape[1]
    cond_or_uncond = extra_options["cond_or_uncond"]
    b = q.shape[0]
    batch_prompt = b // len(cond_or_uncond)
    _, _, oh, ow = extra_options["original_shape"]

    #conds = torch.cat([uncond.repeat(batch_prompt, 1, 1), cond.repeat(batch_prompt, 1, 1)], dim=0)
    #zero_tensor = torch.zeros((conds.size(0), num_zero, conds.size(-1)), dtype=conds.dtype, device=conds.device)
    #conds = torch.cat([conds, zero_tensor], dim=1)
    #ip_k = pulid.ip_layers.to_kvs[k_key](conds)
    #ip_v = pulid.ip_layers.to_kvs[v_key](conds)

    k_cond = pulid.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1)
    k_uncond = pulid.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1)
    v_cond = pulid.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1)
    v_uncond = pulid.ip_layers.to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1)
    ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
    ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)

    out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])

    if ortho:
        out = out.to(dtype=torch.float32)
        out_ip = out_ip.to(dtype=torch.float32)
        projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out)
        orthogonal = out_ip - projection
        out_ip = weight * orthogonal
    elif ortho_v2:
        out = out.to(dtype=torch.float32)
        out_ip = out_ip.to(dtype=torch.float32)
        attn_map = q @ ip_k.transpose(-2, -1)
        attn_mean = attn_map.softmax(dim=-1).mean(dim=1, keepdim=True)
        attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
        projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out)
        orthogonal = out_ip + (attn_mean - 1) * projection
        out_ip = weight * orthogonal
    else:
        out_ip = out_ip * weight

    if mask is not None:
        mask_h = oh / math.sqrt(oh * ow / seq_len)
        mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
        mask_w = seq_len // mask_h

        mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
        mask = tensor_to_size(mask, batch_prompt)

        mask = mask.repeat(len(cond_or_uncond), 1, 1)
        mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])

        # covers cases where extreme aspect ratios can cause the mask to have a wrong size
        mask_len = mask_h * mask_w
        if mask_len < seq_len:
            pad_len = seq_len - mask_len
            pad1 = pad_len // 2
            pad2 = pad_len - pad1
            mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0)
        elif mask_len > seq_len:
            crop_start = (mask_len - seq_len) // 2
            mask = mask[:, crop_start:crop_start+seq_len, :]

        out_ip = out_ip * mask

    return out_ip.to(dtype=dtype)
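# Sketch of the two "ortho" modes used above (math as comments, not executed): both
# start from the projection of the identity attention output onto the base output,
#
#   proj = (sum(out * out_ip, dim=-2, keepdim=True) /
#           sum(out * out,    dim=-2, keepdim=True)) * out
#
#   ortho    ("style")    : out_ip = weight * (out_ip - proj)
#   ortho_v2 ("fidelity") : out_ip = weight * (out_ip + (attn_mean - 1) * proj)
#
# where attn_mean is the query-averaged softmax attention mass placed on the first
# 5 identity tokens. The plain mode simply scales out_ip by weight.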
def to_gray(img):
    x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    x = x.repeat(1, 3, 1, 1)
    return x

"""
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Nodes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

class PulidModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pulid_file": (folder_paths.get_filename_list("pulid"), )}}

    RETURN_TYPES = ("PULID",)
    FUNCTION = "load_model"
    CATEGORY = "pulid"

    def load_model(self, pulid_file):
        ckpt_path = folder_paths.get_full_path("pulid", pulid_file)

        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

        if ckpt_path.lower().endswith(".safetensors"):
            st_model = {"image_proj": {}, "ip_adapter": {}}
            for key in model.keys():
                if key.startswith("image_proj."):
                    st_model["image_proj"][key.replace("image_proj.", "")] = model[key]
                elif key.startswith("ip_adapter."):
                    st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
            model = st_model

        # Initialize the model here: loading takes longer, but it doesn't have to be
        # repeated every time a parameter changes in the apply node.
        model = PulidModel(model)

        return (model,)

class PulidInsightFaceLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "provider": (["CPU", "CUDA", "ROCM"], ),
            },
        }

    RETURN_TYPES = ("FACEANALYSIS",)
    FUNCTION = "load_insightface"
    CATEGORY = "pulid"

    def load_insightface(self, provider):
        model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l
        model.prepare(ctx_id=0, det_size=(640, 640))

        return (model,)

class PulidEvaClipLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
        }

    RETURN_TYPES = ("EVA_CLIP",)
    FUNCTION = "load_eva_clip"
    CATEGORY = "pulid"

    def load_eva_clip(self):
        from .eva_clip.factory import create_model_and_transforms

        model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)

        model = model.visual

        eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)
        eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)
        # model is an nn.Module, so store the normalization stats as attributes
        if not isinstance(eva_transform_mean, (list, tuple)):
            model.image_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            model.image_std = (eva_transform_std,) * 3

        return (model,)

class ApplyPulid:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "method": (["fidelity", "style", "neutral"],),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_pulid"
    CATEGORY = "pulid"

    def apply_pulid(self, model, pulid, eva_clip, face_analysis, image, weight, start_at, end_at, method=None, noise=0.0, fidelity=None, projection=None, attn_mask=None):
        work_model = model.clone()

        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32

        eva_clip.to(device, dtype=dtype)
        pulid_model = pulid.to(device, dtype=dtype)

        if attn_mask is not None:
            if attn_mask.dim() > 3:
                attn_mask = attn_mask.squeeze(-1)
            elif attn_mask.dim() < 3:
                attn_mask = attn_mask.unsqueeze(0)
            attn_mask = attn_mask.to(device, dtype=dtype)

        if method == "fidelity" or projection == "ortho_v2":
            num_zero = 8
            ortho = False
            ortho_v2 = True
        elif method == "style" or projection == "ortho":
            num_zero = 16
            ortho = True
            ortho_v2 = False
        else:
            num_zero = 0
            ortho = False
            ortho_v2 = False

        if fidelity is not None:
            num_zero = fidelity

        #face_analysis.det_model.input_size = (640,640)
        image = tensor_to_image(image)

        face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )

        face_helper.face_parse = None
        face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device)

        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        cond = []
        uncond = []
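        # Per-image embedding pipeline (loop below): InsightFace provides the ID
        # embedding; facexlib aligns and crops the face, and the BiSeNet parsing model
        # whitens everything labeled as background (bg_label); the resulting grayscale
        # face is encoded by EVA-CLIP; finally the concatenated features are projected
        # by the PuLID IDEncoder into the cond/uncond tokens consumed by pulid_attention.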
        for i in range(image.shape[0]):
            # get insightface embeddings
            iface_embeds = None
            for size in [(size, size) for size in range(640, 256, -64)]:
                face_analysis.det_model.input_size = size
                face = face_analysis.get(image[i])
                if face:
                    # use the largest detected face
                    face = sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse=True)[0]
                    iface_embeds = torch.from_numpy(face.embedding).unsqueeze(0).to(device, dtype=dtype)
                    break
            else:
                raise Exception('insightface: No face detected.')

            # get eva_clip embeddings
            face_helper.clean_all()
            face_helper.read_image(image[i])
            face_helper.get_face_landmarks_5(only_center_face=True)
            face_helper.align_warp_face()

            if len(face_helper.cropped_faces) == 0:
                raise Exception('facexlib: No face detected.')

            face = face_helper.cropped_faces[0]

            face = image_to_tensor(face).unsqueeze(0).permute(0,3,1,2).to(device)
            parsing_out = face_helper.face_parse(T.functional.normalize(face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            bg = sum(parsing_out == i for i in bg_label).bool()
            white_image = torch.ones_like(face)

            face_features_image = torch.where(bg, white_image, to_gray(face))
            # apparently MPS only supports NEAREST interpolation?
            face_features_image = T.functional.resize(face_features_image, eva_clip.image_size, T.InterpolationMode.BICUBIC if 'cuda' in device.type else T.InterpolationMode.NEAREST).to(device, dtype=dtype)
            face_features_image = T.functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)

            id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
            id_cond_vit = id_cond_vit.to(device, dtype=dtype)
            for idx in range(len(id_vit_hidden)):
                id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype)

            id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))

            # combine embeddings
            id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)
            if noise == 0:
                id_uncond = torch.zeros_like(id_cond)
            else:
                id_uncond = torch.rand_like(id_cond) * noise
            id_vit_hidden_uncond = []
            for idx in range(len(id_vit_hidden)):
                if noise == 0:
                    id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[idx]))
                else:
                    id_vit_hidden_uncond.append(torch.rand_like(id_vit_hidden[idx]) * noise)

            cond.append(pulid_model.get_image_embeds(id_cond, id_vit_hidden))
            uncond.append(pulid_model.get_image_embeds(id_uncond, id_vit_hidden_uncond))

        # average embeddings
        cond = torch.cat(cond).to(device, dtype=dtype)
        uncond = torch.cat(uncond).to(device, dtype=dtype)
        if cond.shape[0] > 1:
            cond = torch.mean(cond, dim=0, keepdim=True)
            uncond = torch.mean(uncond, dim=0, keepdim=True)

        if num_zero > 0:
            if noise == 0:
                zero_tensor = torch.zeros((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device)
            else:
                zero_tensor = torch.rand((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device) * noise
            cond = torch.cat([cond, zero_tensor], dim=1)
            uncond = torch.cat([uncond, zero_tensor], dim=1)

        sigma_start = work_model.get_model_object("model_sampling").percent_to_sigma(start_at)
        sigma_end = work_model.get_model_object("model_sampling").percent_to_sigma(end_at)

        patch_kwargs = {
            "pulid": pulid_model,
            "weight": weight,
            "cond": cond,
            "uncond": uncond,
            "sigma_start": sigma_start,
            "sigma_end": sigma_end,
            "ortho": ortho,
            "ortho_v2": ortho_v2,
            "mask": attn_mask,
        }

        number = 0
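        # Patch every cross-attention (attn2) layer of the SDXL UNet (loops below):
        # input blocks 4/5 have 2 transformer blocks each and 7/8 have 10; output
        # blocks 0/1/2 have 10 and 3/4/5 have 2; the middle block has 10. module_key
        # walks the odd indices (1, 3, 5, ...) so that the keys line up with the
        # "N.to_k_ip" / "N.to_v_ip" entries of the ip_adapter state dict loaded by To_KV.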
        for id in [4, 5, 7, 8]: # id of input_blocks that have cross attention
            block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
                number += 1
        for id in range(6): # id of output_blocks that have cross attention
            block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
            for index in block_indices:
                patch_kwargs["module_key"] = str(number*2+1)
                set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
                number += 1
        for index in range(10):
            patch_kwargs["module_key"] = str(number*2+1)
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
            number += 1

        return (work_model,)

class ApplyPulidAdvanced(ApplyPulid):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "pulid": ("PULID", ),
                "eva_clip": ("EVA_CLIP", ),
                "face_analysis": ("FACEANALYSIS", ),
                "image": ("IMAGE", ),
                "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
                "projection": (["ortho_v2", "ortho", "none"],),
                "fidelity": ("INT", {"default": 8, "min": 0, "max": 32, "step": 1 }),
                "noise": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1 }),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
            },
            "optional": {
                "attn_mask": ("MASK", ),
            },
        }

NODE_CLASS_MAPPINGS = {
    "PulidModelLoader": PulidModelLoader,
    "PulidInsightFaceLoader": PulidInsightFaceLoader,
    "PulidEvaClipLoader": PulidEvaClipLoader,
    "ApplyPulid": ApplyPulid,
    "ApplyPulidAdvanced": ApplyPulidAdvanced,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PulidModelLoader": "Load PuLID Model",
    "PulidInsightFaceLoader": "Load InsightFace (PuLID)",
    "PulidEvaClipLoader": "Load Eva Clip (PuLID)",
    "ApplyPulid": "Apply PuLID",
    "ApplyPulidAdvanced": "Apply PuLID Advanced",
}
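# Expected model locations (sketch; exact paths depend on the local ComfyUI setup):
#   <ComfyUI>/models/pulid/        -> PuLID checkpoints listed by PulidModelLoader
#   <ComfyUI>/models/insightface/  -> root passed to FaceAnalysis; insightface typically
#                                     resolves antelopev2 under <root>/models/antelopev2
# Typical graph wiring: PulidModelLoader + PulidInsightFaceLoader + PulidEvaClipLoader
# feed ApplyPulid (or ApplyPulidAdvanced), which returns the patched MODEL for the sampler.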