high / app.py
VanNguyen1214's picture
Update app.py
4ddf5bc verified
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import insightface
from insightface.app import FaceAnalysis
import onnxruntime as ort
import os
import tempfile
import warnings
from gfpgan import GFPGANer
import argparse
warnings.filterwarnings("ignore")
class FaceHairSwapperGradio:
def __init__(self):
"""Khởi tạo models cho Gradio app"""
self.setup_models()
def setup_models(self):
"""Setup các models cần thiết"""
try:
print("🔄 Đang tải SegFormer model...")
self.processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
self.model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
print("✅ SegFormer model đã tải thành công!")
print("🔄 Đang tải InsightFace model...")
# Thử GPU trước, fallback CPU
providers = ['CPUExecutionProvider']
if torch.cuda.is_available():
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
try:
self.face_app = FaceAnalysis(providers=providers)
self.face_app.prepare(ctx_id=0, det_size=(640, 640))
print("✅ InsightFace đã tải thành công!")
except Exception as e:
print(f"⚠️ Lỗi tải InsightFace: {e}")
# Thử với providers đơn giản hơn
self.face_app = FaceAnalysis(providers=['CPUExecutionProvider'])
self.face_app.prepare(ctx_id=-1, det_size=(320, 320))
print("✅ InsightFace đã tải thành công (CPU mode)!")
# Load face swapper từ file local
swapper_path = './models/inswapper_128.onnx'
if os.path.exists(swapper_path):
try:
self.face_swapper = insightface.model_zoo.get_model(swapper_path, providers=providers)
print("✅ Face Swapper đã tải từ file local!")
except Exception as e:
print(f"⚠️ Lỗi tải face swapper từ file local: {e}")
# Thử tải từ model zoo
try:
self.face_swapper = insightface.model_zoo.get_model('inswapper_128.onnx', providers=providers)
print("✅ Face Swapper đã tải từ model zoo!")
except Exception as e2:
print(f"❌ Không thể tải face swapper: {e2}")
self.face_swapper = None
else:
print("⚠️ Không tìm thấy inswapper_128.onnx trong thư mục models, thử tải từ model zoo...")
try:
self.face_swapper = insightface.model_zoo.get_model('inswapper_128.onnx', providers=providers)
print("✅ Face Swapper đã tải từ model zoo!")
except Exception as e:
print(f"❌ Không thể tải face swapper: {e}")
self.face_swapper = None
# Load GFPGAN
print("🔄 Đang tải GFPGAN model...")
gfpgan_path = './models/GFPGANv1.4.pth'
if os.path.exists(gfpgan_path):
try:
self.gfpgan_enhancer = GFPGANer(
model_path=gfpgan_path,
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=None
)
print("✅ GFPGAN model đã tải thành công!")
except Exception as e:
print(f"⚠️ Lỗi tải GFPGAN: {e}")
self.gfpgan_enhancer = None
else:
print("⚠️ Không tìm thấy GFPGANv1.4.pth, bỏ qua GFPGAN enhancement")
self.gfpgan_enhancer = None
print("✅ Setup models hoàn tất!")
except Exception as e:
print(f"❌ Lỗi khi tải model: {e}")
import traceback
traceback.print_exc()
raise e
def enhance_face_gfpgan(self, image):
"""Tăng chất lượng khuôn mặt bằng GFPGAN"""
if self.gfpgan_enhancer is None:
return image
try:
# GFPGAN expects BGR format
if len(image.shape) == 3 and image.shape[2] == 3:
input_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
else:
input_img = image
# Enhance with GFPGAN
_, _, enhanced_img = self.gfpgan_enhancer.enhance(
input_img,
has_aligned=False,
only_center_face=False,
paste_back=True
)
# Convert back to RGB
if enhanced_img is not None:
enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
return enhanced_img
else:
return image
except Exception as e:
print(f"GFPGAN enhancement failed: {e}")
return image
def get_face_mask_insightface(self, image, expand_ratio=0.3):
"""Lấy mask khuôn mặt từ InsightFace với độ chính xác cao hơn"""
try:
image_np = np.array(image)
height, width = image_np.shape[:2]
face_mask = np.zeros((height, width), dtype=np.uint8)
faces = self.face_app.get(image_np)
if len(faces) > 0:
# Chọn khuôn mặt lớn nhất
face = max(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
# Sử dụng landmarks để tạo mask chính xác hơn
if hasattr(face, 'kps') and face.kps is not None:
try:
landmarks = face.kps.astype(int)
# Tạo convex hull từ landmarks
hull = cv2.convexHull(landmarks)
cv2.fillPoly(face_mask, [hull], 1)
# Mở rộng mask
if expand_ratio > 0:
kernel_size = max(3, int(min(width, height) * expand_ratio * 0.1))
if kernel_size % 2 == 0:
kernel_size += 1
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
face_mask = cv2.dilate(face_mask, kernel, iterations=1)
except Exception as e:
print(f"Landmark method failed: {e}, using bbox method")
# Fallback về bbox method
face_mask = self._create_face_mask_from_bbox(face.bbox, width, height, expand_ratio)
else:
# Fallback về bbox method
face_mask = self._create_face_mask_from_bbox(face.bbox, width, height, expand_ratio)
# Làm mượt mask
face_mask = cv2.GaussianBlur(face_mask.astype(np.float32), (15, 15), 0)
face_mask = (face_mask > 0.1).astype(np.uint8)
return face_mask
except Exception as e:
print(f"Error in get_face_mask_insightface: {e}")
# Return empty mask if all fails
return np.zeros((image.size[1], image.size[0]), dtype=np.uint8)
def _create_face_mask_from_bbox(self, bbox, width, height, expand_ratio):
"""Helper method to create face mask from bounding box"""
face_mask = np.zeros((height, width), dtype=np.uint8)
x1, y1, x2, y2 = bbox.astype(int)
center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
w, h = x2 - x1, y2 - y1
new_w = int(w * (1 + expand_ratio))
new_h = int(h * (1 + expand_ratio))
new_x1 = max(0, center_x - new_w // 2)
new_y1 = max(0, center_y - new_h // 2)
new_x2 = min(width, center_x + new_w // 2)
new_y2 = min(height, center_y + new_h // 2)
if new_x2 > new_x1 and new_y2 > new_y1:
mask_region = np.zeros((new_y2 - new_y1, new_x2 - new_x1), dtype=np.uint8)
center_region_x = (new_x2 - new_x1) // 2
center_region_y = (new_y2 - new_y1) // 2
if center_region_x > 0 and center_region_y > 0:
cv2.ellipse(mask_region,
(center_region_x, center_region_y),
(center_region_x, center_region_y),
0, 0, 360, 1, -1)
face_mask[new_y1:new_y2, new_x1:new_x2] = mask_region
return face_mask
def get_hair_mask(self, image):
"""Lấy mask tóc từ SegFormer"""
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits.cpu()
upsampled_logits = F.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
hair_mask = (pred_seg == 2).astype(np.uint8)
return hair_mask
def swap_face_insightface(self, source_image, target_image, enhance_result=True):
"""Hoán đổi khuôn mặt với InsightFace và tùy chọn enhance"""
if self.face_swapper is None:
print("⚠️ Face swapper không khả dụng, sử dụng blend method")
return target_image
try:
source_faces = self.face_app.get(source_image)
target_faces = self.face_app.get(target_image)
if len(source_faces) == 0 or len(target_faces) == 0:
print("⚠️ Không tìm thấy khuôn mặt trong một hoặc cả hai ảnh")
return target_image
# Chọn khuôn mặt lớn nhất
source_face = max(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
target_face = max(target_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
# Thực hiện face swap
result = self.face_swapper.get(target_image, target_face, source_face, paste_back=True)
# Enhance kết quả với GFPGAN nếu có
if enhance_result and self.gfpgan_enhancer is not None:
result = self.enhance_face_gfpgan(result)
return result
except Exception as e:
print(f"❌ Lỗi trong face swap: {e}")
return target_image
def blend_images_with_mask(self, base_image, overlay_image, mask, blur_kernel=5):
"""
Blend hai ảnh với mask được cải thiện. Đảm bảo mọi shape đều khớp nhau.
Args:
base_image: numpy.ndarray (H, W, 3)
overlay_image: numpy.ndarray (H, W, 3)
mask: numpy.ndarray (H, W)
blur_kernel: int (làm mượt mask)
Returns:
numpy.ndarray (H, W, 3)
"""
# Đảm bảo overlay_image cùng shape với base_image
if overlay_image.shape != base_image.shape:
overlay_image = cv2.resize(overlay_image, (base_image.shape[1], base_image.shape[0]))
# Đảm bảo mask cùng shape (chỉ cần 2 chiều height, width)
if mask.shape[:2] != base_image.shape[:2]:
mask = cv2.resize(mask, (base_image.shape[1], base_image.shape[0]))
mask_float = mask.astype(np.float32)
if blur_kernel > 0:
mask_float = cv2.GaussianBlur(mask_float, (blur_kernel, blur_kernel), 0)
mask_float = mask_float / mask_float.max() if mask_float.max() > 0 else mask_float
mask_3d = np.stack([mask_float] * 3, axis=-1)
result = base_image.astype(np.float32) * (1 - mask_3d) + overlay_image.astype(np.float32) * mask_3d
return result.astype(np.uint8)
def post_process_result(self, image):
"""Post-processing để cải thiện kết quả cuối cùng"""
# Color correction
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
l, a, b = cv2.split(lab)
# Áp dụng CLAHE cho kênh L
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
l = clahe.apply(l)
# Merge lại
lab = cv2.merge([l, a, b])
result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
return result
def process_swap(self, full_face_image, hair_head_image, enhance_faces=True, progress=gr.Progress()):
"""
Hàm chính để xử lý face/hair swap cho Gradio với GFPGAN enhancement
Args:
full_face_image: ảnh có khuôn mặt (từ Gradio)
hair_head_image: ảnh có tóc (từ Gradio)
enhance_faces: có sử dụng GFPGAN không
progress: Gradio progress bar
Returns:
PIL Image: kết quả
"""
if full_face_image is None or hair_head_image is None:
return None, "❌ Vui lòng tải lên cả hai ảnh!"
try:
progress(0.1, desc="🔄 Đang xử lý ảnh...")
# Convert to PIL nếu cần
if isinstance(full_face_image, np.ndarray):
full_face_image = Image.fromarray(full_face_image)
if isinstance(hair_head_image, np.ndarray):
hair_head_image = Image.fromarray(hair_head_image)
# Chuyển về RGB
if full_face_image.mode != 'RGB':
full_face_image = full_face_image.convert('RGB')
if hair_head_image.mode != 'RGB':
hair_head_image = hair_head_image.convert('RGB')
progress(0.2, desc="📏 Resize ảnh...")
# Resize về cùng kích thước
target_size = (512, 512)
full_face_resized = full_face_image.resize(target_size, Image.Resampling.LANCZOS)
hair_head_resized = hair_head_image.resize(target_size, Image.Resampling.LANCZOS)
# Convert to numpy
full_face_array = np.array(full_face_resized)
hair_head_array = np.array(hair_head_resized)
progress(0.3, desc="💇 Phân tích tóc...")
hair_mask = self.get_hair_mask(hair_head_resized)
progress(0.5, desc="👤 Phân tích khuôn mặt...")
face_mask = self.get_face_mask_insightface(full_face_resized, expand_ratio=0.2)
progress(0.7, desc="🔄 Thực hiện face swap...")
try:
face_swapped = self.swap_face_insightface(
full_face_array,
hair_head_array,
enhance_result=enhance_faces
)
except Exception as e:
print(f"Face swap thất bại: {e}, dùng blend method...")
face_swapped = self.blend_images_with_mask(
hair_head_array, full_face_array, face_mask, blur_kernel=7
)
progress(0.85, desc="🎨 Kết hợp tóc...")
final_result = self.blend_images_with_mask(
face_swapped, hair_head_array, hair_mask, blur_kernel=3
)
progress(0.95, desc="✨ Post-processing...")
final_result = self.post_process_result(final_result)
progress(1.0, desc="✅ Hoàn thành!")
result_image = Image.fromarray(final_result)
return result_image, "✅ Hoàn thành thành công!"
except Exception as e:
error_msg = f"❌ Lỗi xử lý: {str(e)}"
print(error_msg)
import traceback
traceback.print_exc()
return None, error_msg
# Khởi tạo model (chỉ một lần)
print("🚀 Đang khởi tạo Face Hair Swapper với GFPGAN...")
try:
swapper = FaceHairSwapperGradio()
print("✅ Khởi tạo thành công!")
except Exception as e:
print(f"❌ Lỗi khởi tạo: {e}")
swapper = None
def gradio_interface(full_face_img, hair_head_img, enhance_faces):
"""Interface function cho Gradio"""
if swapper is None:
return None, "❌ Model chưa được khởi tạo!"
return swapper.process_swap(full_face_img, hair_head_img, enhance_faces)
# Tạo Gradio Interface
def create_gradio_app():
"""Tạo Gradio app với enhanced features"""
with gr.Blocks(
title="Face & Hair Swap with GFPGAN",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1200px !important;
}
.image-container {
height: 400px !important;
}
"""
) as demo:
gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1>🎭 Face & Hair Swap với GFPGAN Enhancement</h1>
<p>Hoán đổi khuôn mặt và tóc giữa hai bức ảnh sử dụng AI với chất lượng cao</p>
</div>
""")
with gr.Row():
with gr.Column():
gr.HTML("<h3>📸 Ảnh có khuôn mặt</h3>")
full_face_input = gr.Image(
label="Tải ảnh có khuôn mặt",
type="pil",
height=400
)
with gr.Column():
gr.HTML("<h3>💇 Ảnh có tóc</h3>")
hair_head_input = gr.Image(
label="Tải ảnh có tóc",
type="pil",
height=400
)
with gr.Row():
enhance_checkbox = gr.Checkbox(
label="✨ Sử dụng GFPGAN để tăng chất lượng khuôn mặt",
value=True,
info="Tăng chất lượng và độ chi tiết của khuôn mặt trong kết quả"
)
with gr.Row():
process_btn = gr.Button(
"🚀 Bắt đầu hoán đổi",
variant="primary",
size="lg"
)
with gr.Row():
with gr.Column():
result_output = gr.Image(
label="Kết quả",
type="pil",
height=400
)
with gr.Column():
status_output = gr.Textbox(
label="Trạng thái",
lines=3,
interactive=False
)
# Event handlers
process_btn.click(
fn=gradio_interface,
inputs=[full_face_input, hair_head_input, enhance_checkbox],
outputs=[result_output, status_output],
api_name="swap_face_hair"
)
# Examples và hướng dẫn
gr.HTML("<h3>📋 Hướng dẫn sử dụng:</h3>")
gr.HTML("""
<ul>
<li><strong>Ảnh khuôn mặt:</strong> Tải ảnh có khuôn mặt rõ ràng, nhìn thẳng</li>
<li><strong>Ảnh tóc:</strong> Tải ảnh có kiểu tóc đẹp mong muốn</li>
<li><strong>GFPGAN Enhancement:</strong> Bật để tăng chất lượng khuôn mặt (khuyến nghị)</li>
<li><strong>Chất lượng tốt nhất:</strong> Ảnh 512x512px, ánh sáng đều, không bị mờ</li>
<li><strong>Models sử dụng:</strong> InsightFace + SegFormer + GFPGAN</li>
</ul>
""")
gr.HTML("<h3>🔧 Thông tin kỹ thuật:</h3>")
gr.HTML("""
<ul>
<li><strong>Face Swap:</strong> inswapper_128.onnx (InsightFace)</li>
<li><strong>Face Enhancement:</strong> GFPGANv1.4.pth</li>
<li><strong>Hair Segmentation:</strong> SegFormer</li>
<li><strong>Post-processing:</strong> CLAHE, Color Correction</li>
</ul>
""")
return demo
# Chạy app
if __name__ == "__main__":
demo = create_gradio_app()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
debug=True
)
# Cho Hugging Face Spaces
demo = create_gradio_app()