Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import mediapipe as mp | |
import cv2 | |
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair") | |
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair") | |
# Hàm lấy mặt (bao mặt và trán) mà không bao gồm tóc | |
def get_facemesh_mask(image): | |
image_np = np.array(image) | |
height, width, _ = image_np.shape | |
face_mask = np.zeros((height, width), dtype=np.uint8) | |
mp_face_mesh = mp.solutions.face_mesh | |
with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5) as face_mesh: | |
results = face_mesh.process(image_np) | |
if results.multi_face_landmarks: | |
for face_landmarks in results.multi_face_landmarks: | |
points = [] | |
for lm in face_landmarks.landmark: | |
x, y = int(lm.x * width), int(lm.y * height) | |
points.append([x, y]) | |
points = np.array(points, np.int32) | |
if len(points) > 0: | |
hull = cv2.convexHull(points) | |
cv2.fillConvexPoly(face_mask, hull, 1) | |
return face_mask | |
# Hàm mở rộng vùng trán từ mask mặt đầu vào | |
def expand_forehead_mask(face_mask, expand_percent=0.2): | |
ys, xs = np.where(face_mask > 0) | |
if len(ys) == 0: | |
return face_mask # không tìm thấy mặt | |
min_y, max_y = ys.min(), ys.max() | |
height = max_y - min_y | |
expand = int(height * expand_percent) | |
expanded_min_y = max(min_y - expand, 0) | |
expanded_mask = np.zeros_like(face_mask) | |
src_start = min_y | |
src_end = max_y | |
dst_start = expanded_min_y | |
dst_end = expanded_min_y + (src_end - src_start) | |
if dst_end > face_mask.shape[0]: | |
overlap = dst_end - face_mask.shape[0] | |
dst_end = face_mask.shape[0] | |
src_end -= overlap | |
expanded_mask[dst_start:dst_end, :] = face_mask[src_start:src_end, :] | |
return expanded_mask | |
# Hàm chính: kết hợp mặt và trán mở rộng, không bao gồm tóc, lưu mask vào biến face_forehead_mask | |
face_forehead_mask = None | |
def extract_face_and_forehead_no_hair(image): | |
image = image.convert("RGB") | |
# SegFormer hair mask | |
inputs = processor(images=image, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
logits = outputs.logits.cpu() | |
upsampled_logits = F.interpolate( | |
logits, | |
size=image.size[::-1], | |
mode="bilinear", | |
align_corners=False, | |
) | |
pred_seg = upsampled_logits.argmax(dim=1)[0].numpy() | |
hair_mask = (pred_seg == 2).astype(np.uint8) # tóc | |
# Face mesh mask (bao trọn mặt, trán, không cổ) | |
face_mesh_mask = get_facemesh_mask(image) | |
# Expand lên trên 20% chiều cao mặt (ăn gian trán) | |
expanded_face_mask = expand_forehead_mask(face_mesh_mask, expand_percent=0.2) | |
# Vùng trán mở rộng chỉ lấy phần không trùng vùng mặt gốc và không trùng tóc | |
expanded_only_forehead = cv2.bitwise_and(expanded_face_mask, 1 - face_mesh_mask) | |
expanded_only_forehead = cv2.bitwise_and(expanded_only_forehead, 1 - hair_mask) | |
# Kết hợp: tóc + mặt mediapipe (gốc) + vùng trán mở rộng (phía trên mặt gốc, không trùng tóc, không trùng mặt gốc) | |
combined_mask = ((face_mesh_mask + expanded_only_forehead) > 0).astype(np.uint8) | |
# Làm mượt mask | |
combined_mask = cv2.GaussianBlur(combined_mask.astype(np.float32), (3, 3), 0) | |
combined_mask = (combined_mask > 0.5).astype(np.uint8) | |
np_image = np.array(image) | |
alpha = (combined_mask * 255).astype(np.uint8) | |
rgba_image = np.dstack([np_image, alpha]) | |
return Image.fromarray(rgba_image) | |
iface = gr.Interface( | |
fn=extract_face_and_forehead_no_hair, | |
inputs=gr.Image(type="pil"), | |
outputs=gr.Image(type="numpy", label="Hair, Face & Full Forehead PNG"), | |
live=False, | |
title="Tách tóc, khuôn mặt và toàn bộ trán (ăn gian, không lấy cổ, không chồng lên tóc)", | |
description="Upload ảnh chân dung để nhận file PNG gồm tóc, mặt và trán mở rộng (không lấy cổ, không bị thiếu trán, không lấy vùng tóc vào vùng trán)." | |
) | |
iface.launch() |