# Spaces: Running on Zero  (web-page status banner captured during extraction; not part of the module)
import gc | |
import cv2 | |
import insightface | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from pulid.utils import img2tensor, tensor2img | |
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline | |
from facexlib.parsing import init_parsing_model | |
from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
from huggingface_hub import hf_hub_download, snapshot_download | |
from insightface.app import FaceAnalysis | |
from safetensors.torch import load_file | |
from torchvision.transforms import InterpolationMode | |
from torchvision.transforms.functional import normalize, resize | |
from eva_clip import create_model_and_transforms | |
from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD | |
from pulid.encoders_transformer import IDFormer | |
from pulid.utils import is_torch2_available, sample_dpmpp_2m, sample_dpmpp_sde | |
if is_torch2_available(): | |
from pulid.attention_processor import AttnProcessor2_0 as AttnProcessor | |
from pulid.attention_processor import IDAttnProcessor2_0 as IDAttnProcessor | |
else: | |
from pulid.attention_processor import AttnProcessor, IDAttnProcessor | |
class PuLIDEncoder:
    """Extract identity-conditioning features from face images.

    Bundles the preprocessing models needed for identity encoding:

    * a facexlib ``FaceRestoreHelper`` (``retinaface_resnet50`` detector) with
      a ``bisenet`` parsing model for face alignment and parsing,
    * the visual tower of ``EVA02-CLIP-L-14-336``,
    * insightface ``FaceAnalysis`` (``antelopev2``) plus the ``glintr100``
      recognition model used as a fallback embedder.

    The ID adapter itself is not applied here; ``get_id_embedding`` returns
    the raw conditioning tensors.
    """

    def __init__(self, device):
        """Load all preprocessing models.

        Args:
            device: device handed to facexlib and used for ``.to()`` on the
                CLIP visual tower and on tensors built in
                ``get_id_embedding``.

        Side effects:
            Downloads ``DIAMONIK7777/antelopev2`` into ``models/antelopev2``,
            runs ``gc.collect()`` and ``torch.cuda.empty_cache()``.
        """
        super().__init__()
        self.device = device

        # Face alignment and parsing.
        self.face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            device=self.device,
        )
        # NOTE(review): the reference is cleared before the bisenet model is
        # created; presumably this lets any parser the helper built by default
        # be released first -- confirm against facexlib.
        self.face_helper.face_parse = None
        self.face_helper.face_parse = init_parsing_model(
            model_name="bisenet", device=self.device
        )

        # CLIP-ViT backbone: only the visual tower is kept.
        clip_model, _, _ = create_model_and_transforms(
            "EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True
        )
        self.clip_vision_model = clip_model.visual.to(self.device)

        # Normalisation statistics: use the model's own when present, else
        # the OpenAI dataset defaults; scalars are broadcast to 3 channels.
        eva_transform_mean = getattr(
            self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN
        )
        eva_transform_std = getattr(
            self.clip_vision_model, "image_std", OPENAI_DATASET_STD
        )
        if not isinstance(eva_transform_mean, (list, tuple)):
            eva_transform_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            eva_transform_std = (eva_transform_std,) * 3
        self.eva_transform_mean = eva_transform_mean
        self.eva_transform_std = eva_transform_std

        # antelopev2 (insightface), configured with the CPU execution provider.
        snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
        self.app = FaceAnalysis(
            name="antelopev2",
            root=".",
            providers=["CPUExecutionProvider"],
        )
        self.app.prepare(ctx_id=0, det_size=(640, 640))
        self.handler_ante = insightface.model_zoo.get_model(
            "models/antelopev2/glintr100.onnx"
        )
        self.handler_ante.prepare(ctx_id=0)

        gc.collect()
        torch.cuda.empty_cache()

        # Debug images (face crops and parsed faces) accumulated by
        # get_id_embedding; this class never clears the list itself.
        self.debug_img_list = []

    @staticmethod
    def _bbox_area(face):
        """Return the area of ``face["bbox"]`` given as (x1, y1, x2, y2)."""
        x1, y1, x2, y2 = face["bbox"][:4]
        return (x2 - x1) * (y2 - y1)

    @staticmethod
    def _crop_bbox(image, bbox):
        """Crop ``image`` (rows, cols, ...) to ``bbox`` = (x1, y1, x2, y2).

        Coordinates are truncated to int and clamped at 0: a detector box that
        extends past the top/left edge has negative coordinates, which would
        otherwise be read as "count from the end" by slicing and yield an
        empty or wrong crop. Values beyond the far edges are handled by
        normal slice semantics.
        """
        left, top, right, bottom = (max(0, int(v)) for v in bbox[:4])
        return image[top:bottom, left:right]

    def to_gray(self, img):
        """Convert a 4-D batch with channels on dim 1 to 3-channel grayscale.

        Channels 0, 1, 2 are weighted 0.299 / 0.587 / 0.114 and the resulting
        single channel is repeated three times along dim 1.
        """
        x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
        x = x.repeat(1, 3, 1, 1)
        return x

    def get_id_embedding(self, image_list):
        """Compute identity-conditioning features for one or more images.

        Args:
            image_list: non-empty sequence of numpy RGB images, range [0, 255].

        Returns:
            Tuple ``(id_cond, id_vit_hidden, id_uncond, id_vit_hidden_uncond)``:

            * ``id_cond``: per-image ``cat([face embedding, L2-normalised CLIP
              embedding], dim=-1)``, stacked over images on dim 1.
            * ``id_vit_hidden``: list with one tensor per CLIP hidden layer,
              concatenated over images on dim 1.
            * ``id_uncond``: zeros shaped like a single image's ``id_cond``.
            * ``id_vit_hidden_uncond``: zeros shaped like the first image's
              hidden states.

        Raises:
            ValueError: if ``image_list`` is empty.
            RuntimeError: if facexlib cannot align a face in an image.

        Side effects:
            Appends debug images to ``self.debug_img_list``.
        """
        # len() rather than truthiness so array-like inputs are accepted too.
        if len(image_list) == 0:
            raise ValueError("image_list must contain at least one image")

        id_cond_list = []
        id_vit_hidden_list = []
        for image in image_list:
            self.face_helper.clean_all()
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # antelopev2 embedding of the largest detected face, if any.
            faces = self.app.get(image_bgr)
            if len(faces) > 0:
                largest_face = sorted(faces, key=self._bbox_area)[-1]
                id_ante_embedding = largest_face["embedding"]
                self.debug_img_list.append(
                    self._crop_bbox(image, largest_face["bbox"])
                )
            else:
                id_ante_embedding = None

            # Detect and align the face with facexlib.
            self.face_helper.read_image(image_bgr)
            self.face_helper.get_face_landmarks_5(only_center_face=True)
            self.face_helper.align_warp_face()
            if len(self.face_helper.cropped_faces) == 0:
                raise RuntimeError("facexlib align face fail")
            align_face = self.face_helper.cropped_faces[0]

            # Fallback when insightface did not detect a face: embed the
            # aligned crop directly with the recognition model.
            if id_ante_embedding is None:
                print(
                    "fail to detect face using insightface, extract embedding on align face"
                )
                id_ante_embedding = self.handler_ante.get_feat(align_face)

            id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device)
            if id_ante_embedding.ndim == 1:
                id_ante_embedding = id_ante_embedding.unsqueeze(0)

            # Face parsing on the aligned crop.
            face_tensor = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
            face_tensor = face_tensor.to(self.device)
            parsing_out = self.face_helper.face_parse(
                normalize(face_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            )[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            # Parsing labels that are painted white below.
            bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
            bg = sum(parsing_out == i for i in bg_label).bool()
            white_image = torch.ones_like(face_tensor)
            # Only keep the face features: grayscale face on white elsewhere.
            face_features_image = torch.where(
                bg, white_image, self.to_gray(face_tensor)
            )
            self.debug_img_list.append(
                tensor2img(face_features_image, rgb2bgr=False)
            )

            # Resize and normalise before sending to the EVA-CLIP ViT.
            face_features_image = resize(
                face_features_image,
                self.clip_vision_model.image_size,
                InterpolationMode.BICUBIC,
            )
            face_features_image = normalize(
                face_features_image, self.eva_transform_mean, self.eva_transform_std
            )
            id_cond_vit, id_vit_hidden = self.clip_vision_model(
                face_features_image,
                return_all_features=False,
                return_hidden=True,
                shuffle=False,
            )
            id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
            id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

            id_cond_list.append(torch.cat([id_ante_embedding, id_cond_vit], dim=-1))
            id_vit_hidden_list.append(id_vit_hidden)

        # Unconditional counterparts are shaped like a single image's features.
        id_uncond = torch.zeros_like(id_cond_list[0])
        id_vit_hidden_uncond = [
            torch.zeros_like(hidden) for hidden in id_vit_hidden_list[0]
        ]

        id_cond = torch.stack(id_cond_list, dim=1)
        # One concatenation per layer across all images; this avoids mutating
        # the first image's hidden-state list and repeated re-concatenation.
        id_vit_hidden = [
            torch.cat(per_layer, dim=1) for per_layer in zip(*id_vit_hidden_list)
        ]

        return id_cond, id_vit_hidden, id_uncond, id_vit_hidden_uncond