ghep_image / app.py
VanNguyen1214's picture
Update app.py
ac4af0c verified
import gradio as gr
from overlay import overlay_source
from detect_face import predict, NUM_CLASSES
import os
from pathlib import Path
BASE_DIR = Path(__file__).parent # thư mục chứa app.py
FOLDER = BASE_DIR / "example_wigs"
# --- Hàm load ảnh từ folder ---
def load_images_from_folder(folder_path: str) -> list[str]:
"""
Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path.
"""
supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
if not os.path.isdir(folder_path):
print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
return []
files = [
os.path.join(folder_path, fn)
for fn in os.listdir(folder_path)
if os.path.splitext(fn)[1].lower() in supported
]
if not files:
print(f"Không tìm thấy hình trong: {folder_path}")
return files
# --- Handler khi click thumbnail của Gallery ---
def on_gallery_select(evt: gr.SelectData):
"""
Xử lý khi click vào ảnh trong gallery - tối ưu và robust.
"""
val = evt.value
if isinstance(val, dict):
img = val.get("image")
if isinstance(img, str): return img
if isinstance(img, dict):
path = img.get("path") or img.get("url")
if isinstance(path, str): return path
for v in img.values():
if isinstance(v, str) and os.path.isfile(v):
return v
for v in val.values():
if isinstance(v, str) and os.path.isfile(v):
return v
raise ValueError(f"Không trích được filepath từ dict: {val}")
if isinstance(val, str):
return val
raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")
# --- Hàm xác định folder dựa trên phân lớp ---
def infer_folder(image) -> str:
cls = predict(image)["predicted_class"]
folder = str(FOLDER / cls)
return folder
# --- Hàm gộp: phân loại + load ảnh ---
def handle_bg_change(image):
"""
Khi thay đổi background:
1. Phân loại khuôn mặt
2. Load ảnh từ folder tương ứng
"""
if image is None:
return "", []
try:
folder = infer_folder(image)
images = load_images_from_folder(folder)
return folder, images
except Exception as e:
print(f"Lỗi xử lý ảnh: {e}")
return "", []
# --- Xây dựng giao diện Gradio ---
def build_demo():
with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo:
gr.Markdown("Upload Background & Source, click **Run** để ghép tóc.")
with gr.Row():
bg = gr.Image(type="pil", label="Background", height=500)
src = gr.Image(type="pil", label="Source", height=500, interactive=False)
out = gr.Image(label="Result", height=500, interactive=False)
folder_path_box = gr.Textbox(label="Folder path", visible=False)
with gr.Row():
gallery = gr.Gallery(
label="Recommend For You",
height=300,
value=[],
type="filepath",
interactive=False,
scale = 6,
columns=5,
object_fit="cover",
allow_preview=True
)
btn = gr.Button("🔄 Run", variant="primary",scale = 1)
# Chạy ghép tóc
btn.click(fn=overlay_source, inputs=[bg, src], outputs=[out])
# Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý
bg.change(
fn=handle_bg_change,
inputs=[bg],
outputs=[folder_path_box, gallery],
show_progress=True
)
# Nút tải lại ảnh thủ công (backup)
# Khi chọn ảnh trong gallery, cập nhật vào khung Source
gallery.select(
fn=on_gallery_select,
outputs=[src]
)
return demo
if __name__ == "__main__":
build_demo().launch()