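"""Virtual Hairstyle Try-On (Gradio + MediaPipe FaceMesh).

Loads a CSV of hairstyle metadata plus transparent PNG overlays, detects the
user's face with MediaPipe FaceMesh, and composites the selected hairstyle
onto an uploaded photo or a live webcam frame.
"""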
from __future__ import annotations

import os
import math
import random
import uuid
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple

import gradio as gr
from PIL import Image, ImageOps
import pandas as pd
import numpy as np
import mediapipe as mp
# ----------------------------
# Globals & configuration
# ----------------------------
DATASET_PATH = os.getenv("HAIRSTYLE_DATASET", "data/enhanced_full_hairstyle_dataset.csv")
HAIRSTYLE_FOLDER = os.getenv("HAIRSTYLE_FOLDER", "hairstyles")
RESULTS_DIR = os.getenv("RESULTS_DIR", "generated_results")
os.makedirs(RESULTS_DIR, exist_ok=True)
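# The three paths above can be overridden via environment variables, e.g.
# (illustrative invocation; the script filename is an assumption):
#   HAIRSTYLE_DATASET=data/styles.csv HAIRSTYLE_FOLDER=pngs python app.py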
# Tune these if your images tend to sit too high/low by default
DEFAULT_VERT_OFFSET_PCT = -0.25  # relative to style_forehead_height
DEFAULT_HORIZ_OFFSET_PX = 0

# MediaPipe indices used
LM_LEFT_EYE_OUTER = 33
LM_RIGHT_EYE_OUTER = 263
LM_FOREHEAD_TOP = 10
LM_FOREHEAD_LEFT = 103
LM_FOREHEAD_RIGHT = 332

# Initialize MediaPipe FaceMesh once (safer with concurrency=1 in Gradio queue)
mp_face_mesh = mp.solutions.face_mesh
FACE_MESH = mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)
@dataclass
class Style:
    """Metadata and pre-loaded overlay image for one hairstyle."""
    name: str
    gender: str
    img_path: str
    img_rgba: Optional[Image.Image]
    style_forehead_w: int
    style_forehead_h: int
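# Expected CSV columns per style: name, gender, forehead_width_px,
# forehead_height_px, image_file. Illustrative row (hypothetical values):
#   Short Fade,Male,210,80,short_fade.png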
def _safe_read_dataset(path: str) -> pd.DataFrame:
    """Read the hairstyle CSV; return an empty frame with the expected columns if the file is missing."""
    if not os.path.exists(path):
        # Create an empty frame with expected columns to avoid crashes
        cols = ["name", "gender", "forehead_width_px", "forehead_height_px", "image_file"]
        return pd.DataFrame(columns=cols)
    df = pd.read_csv(path)
    # Normalize columns and fill NaNs
    for col in ["name", "gender", "image_file"]:
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].fillna("").astype(str)
    for col in ["forehead_width_px", "forehead_height_px"]:
        if col not in df.columns:
            df[col] = 0
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
    return df
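# Hairstyle overlays are transparent PNGs living in HAIRSTYLE_FOLDER; rows whose
# image_file is empty or missing on disk are skipped rather than raising.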
def _load_styles(df: pd.DataFrame) -> List[Style]:
    """Build Style objects from the dataset, loading each overlay as RGBA."""
    styles: List[Style] = []
    if not os.path.exists(HAIRSTYLE_FOLDER):
        return styles
    for _, row in df.iterrows():
        img_file = str(row.get("image_file", "")).strip()
        if not img_file:
            continue
        path = os.path.join(HAIRSTYLE_FOLDER, img_file)
        if not os.path.exists(path):
            continue
        try:
            img = Image.open(path).convert("RGBA")
        except Exception:
            img = None
        styles.append(
            Style(
                name=str(row.get("name", "Style")).strip() or "Style",
                gender=str(row.get("gender", "All")).strip() or "All",
                img_path=path,
                img_rgba=img,
                style_forehead_w=int(row.get("forehead_width_px", 0) or 0),
                style_forehead_h=int(row.get("forehead_height_px", 0) or 0),
            )
        )
    return styles
def _to_rgb(image: Image.Image) -> Image.Image:
    return image.convert("RGB") if image.mode != "RGB" else image


def get_face_landmarks(img_rgb: Image.Image):
    """Return MediaPipe face landmarks for a PIL RGB image or None."""
    np_img = np.array(img_rgb)
    results = FACE_MESH.process(np_img)
    if results.multi_face_landmarks:
        return results.multi_face_landmarks[0]
    return None
def _rotation_angle_rad(landmarks, w: int, h: int) -> float:
    """Estimate roll angle using outer eye corners."""
    left = landmarks.landmark[LM_LEFT_EYE_OUTER]
    right = landmarks.landmark[LM_RIGHT_EYE_OUTER]
    x1, y1 = left.x * w, left.y * h
    x2, y2 = right.x * w, right.y * h
    # angle of the line from left to right; positive means head tilted CCW
    angle = math.atan2(y2 - y1, x2 - x1)
    return angle
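# The forehead is measured between landmarks 103 and 332 (width in pixels) and
# anchored at landmark 10, the top-center forehead point used for placement below.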
def _compute_forehead_metrics(landmarks, w: int, h: int) -> Tuple[int, Tuple[int, int]]:
    left = landmarks.landmark[LM_FOREHEAD_LEFT]
    right = landmarks.landmark[LM_FOREHEAD_RIGHT]
    top = landmarks.landmark[LM_FOREHEAD_TOP]
    forehead_width_px = int(abs((right.x - left.x) * w))
    top_x = int(top.x * w)
    top_y = int(top.y * h)
    return forehead_width_px, (top_x, top_y)
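# Alpha-composite the overlay onto the base image; negative coordinates are fine,
# so the hair can extend partially off-canvas.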
def _paste_rgba(base: Image.Image, overlay: Image.Image, pos: Tuple[int, int]) -> Image.Image:
    canvas = base.copy().convert("RGBA")
    tmp = Image.new("RGBA", canvas.size, (0, 0, 0, 0))
    x, y = pos
    tmp.paste(overlay, (x, y), overlay)
    return Image.alpha_composite(canvas, tmp)
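# Core pipeline: pick the source frame, detect the face, estimate head roll from
# the eye corners, scale the overlay so its annotated forehead width matches the
# detected one, rotate it to match the roll, then place it above landmark 10 and
# alpha-composite, honoring the user's scale/offset/opacity tweaks.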
def apply_hairstyle_impl(
    upload_img: Optional[Image.Image],
    webcam_img: Optional[Image.Image],
    input_source: str,
    style_index: Optional[int],
    scale_tweak: float,
    vert_offset: int,
    horiz_offset: int,
    opacity: float,
) -> Tuple[Optional[Image.Image], str]:
    user_img = upload_img if input_source == "Upload" else webcam_img
    if user_img is None:
        return None, "❌ No image from selected source."
    # gr.Number delivers floats, so normalize the index before using it
    if style_index is None or not (0 <= int(style_index) < len(STYLES)):
        return _to_rgb(user_img), "ℹ️ Select a hairstyle from the gallery."
    style = STYLES[int(style_index)]
    if style.img_rgba is None:
        return _to_rgb(user_img), f"⚠️ Could not load image for: {style.name}"
    try:
        img_rgb = _to_rgb(user_img)
        w, h = img_rgb.size
        lms = get_face_landmarks(img_rgb)
        if lms is None:
            return img_rgb, "⚠️ No face detected. Showing original image. Try a clearer, front‑facing photo."
        # Compute rotation and size
        angle_rad = _rotation_angle_rad(lms, w, h)
        forehead_w_px, (top_x, top_y) = _compute_forehead_metrics(lms, w, h)
        style_fw = max(style.style_forehead_w, 1)
        style_fh = max(style.style_forehead_h, 1)
        scale_ratio = (forehead_w_px / style_fw) * float(scale_tweak)
        new_w = max(int(style.img_rgba.width * scale_ratio), 1)
        new_h = max(int(style.img_rgba.height * scale_ratio), 1)
        # Rotate hair to match head roll
        hair = style.img_rgba.resize((new_w, new_h), resample=Image.LANCZOS)
        angle_deg = math.degrees(angle_rad)
        hair = hair.rotate(angle=-angle_deg, expand=True, resample=Image.BICUBIC)
        # Compute placement
        attach_y = top_y - int(style_fh * scale_ratio)
        attach_y += int(DEFAULT_VERT_OFFSET_PCT * style_fh * scale_ratio)
        attach_y += int(vert_offset)
        attach_x = top_x - hair.width // 2 + int(horiz_offset) + int(DEFAULT_HORIZ_OFFSET_PX)
        # Clamp within canvas (x can be <0 to allow partial paste, but we clamp y >= 0)
        attach_y = max(0, attach_y)
        # Optional opacity tweak
        if 0 <= opacity < 1:
            a = hair.split()[-1]
            a = ImageOps.autocontrast(a)
            a = a.point(lambda px: int(px * opacity))
            hair = Image.merge("RGBA", (*hair.split()[:3], a))
        composed = _paste_rgba(img_rgb, hair, (attach_x, attach_y)).convert("RGB")
        return composed, "✅ Success! Tip: fine‑tune scale/offsets if needed."
    except Exception as e:
        return _to_rgb(user_img), f"❌ Error: {str(e)}"
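# Quick local smoke test without the UI (hypothetical file name):
#   preview, msg = apply_hairstyle_impl(Image.open("selfie.jpg"), None, "Upload", 0, 1.0, 0, 0, 1.0)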
# ----------------------------
# Load data once
# ----------------------------
DATASET_DF = _safe_read_dataset(DATASET_PATH)
STYLES: List[Style] = _load_styles(DATASET_DF)

# Precompute gallery data (image + caption)
GALLERY_ITEMS: List[Tuple[Image.Image, str]] = []
for s in STYLES:
    if s.img_rgba is not None:
        thumb = s.img_rgba.copy()
        GALLERY_ITEMS.append((thumb, s.name))

# ----------------------------
# Gradio helpers
# ----------------------------
def update_gallery(gender: str):
    """Filter styles by gender; return (gallery items, indices into STYLES aligned with gallery order)."""
    filtered = []
    indices = []
    for i, s in enumerate(STYLES):
        if s.img_rgba is None:
            continue
        if gender != "All" and s.gender.lower() != gender.lower():
            continue
        filtered.append((s.img_rgba, s.name))
        indices.append(i)
    return filtered, indices
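# The gallery only knows positions within the currently filtered view, so the
# filtered_indices State maps a clicked position back to its index in STYLES.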
def select_hairstyle(evt: gr.SelectData, filtered_inds: List[int]):
    if filtered_inds and 0 <= evt.index < len(filtered_inds):
        return int(filtered_inds[evt.index])
    return None


def update_source(source: str):
    return gr.update(visible=source == "Upload"), gr.update(visible=source == "Webcam")


def on_apply(upload_img, webcam_img, input_source, selected_index, scale_tweak, vert_offset, horiz_offset, opacity):
    img, msg = apply_hairstyle_impl(
        upload_img, webcam_img, input_source, selected_index, scale_tweak, vert_offset, horiz_offset, opacity
    )
    return img, msg


def on_random(filtered_indices: List[int]):
    if not filtered_indices:
        return None, "ℹ️ No styles available for current filter."
    return int(random.choice(filtered_indices)), "🎲 Random style selected!"


def on_save(result_img):
    if result_img is None:
        return None, "⚠️ Generate a preview first."
    # The preview Image component may hand back a NumPy array rather than a PIL image
    if isinstance(result_img, np.ndarray):
        result_img = Image.fromarray(result_img)
    file_path = os.path.join(RESULTS_DIR, f"hairstyle_{uuid.uuid4().hex}.png")
    result_img.save(file_path, format="PNG")
    return file_path, "💾 Saved! Use the button below to download."
# ----------------------------
# UI
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=".small-hint{font-size:12px;opacity:.8}") as demo:
    gr.Markdown("## 💇 Virtual Hairstyle Try‑On")
    gr.Markdown(
        "Upload a front‑facing photo or use your webcam. Click a hairstyle to select it, then fine‑tune using the controls."
    )
    status = gr.Textbox(label="Status", interactive=False)
    filtered_indices = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            input_source = gr.Radio(["Upload", "Webcam"], value="Upload", label="Input Source")
            upload_col = gr.Column(visible=True)
            with upload_col:
                upload_img = gr.Image(sources=["upload"], type="pil", label="📷 Upload Your Photo (front‑facing)")
            webcam_col = gr.Column(visible=False)
            with webcam_col:
                webcam_img = gr.Image(sources=["webcam"], type="pil", label="📹 Live Webcam", streaming=True)
            gender_filter = gr.Dropdown(choices=["All", "Male", "Female"], value="All", label="🎭 Filter by Gender")
            hairstyle_gallery = gr.Gallery(
                label="🎨 Available Hairstyles (click to select)", columns=4, height=380, object_fit="contain"
            )
            selected_index = gr.Number(value=None, visible=False)
            selected_label = gr.Markdown("*No style selected*", elem_classes=["small-hint"])
            random_btn = gr.Button("🎲 Random Style")
        with gr.Column(scale=2):
            result_output = gr.Image(label="🔍 Preview Result", height=520)
            with gr.Row():
                scale_tweak = gr.Slider(0.7, 1.4, value=1.0, step=0.01, label="Scale tweak")
                opacity = gr.Slider(0.6, 1.0, value=1.0, step=0.01, label="Opacity")
            with gr.Row():
                vert_offset = gr.Slider(-150, 150, value=0, step=1, label="Vertical offset (px)")
                horiz_offset = gr.Slider(-150, 150, value=0, step=1, label="Horizontal offset (px)")
            with gr.Row():
                apply_btn = gr.Button("✨ Apply Hairstyle", variant="primary")
                save_btn = gr.Button("💾 Save Preview")
                dl = gr.DownloadButton("⬇️ Download PNG")
    # Visibility switching
    input_source.change(update_source, inputs=input_source, outputs=[upload_col, webcam_col])

    # Gallery filtering / selection
    def _update_label(i):
        if i is None or not isinstance(i, (int, float)):
            return "*No style selected*"
        idx = int(i)
        if 0 <= idx < len(STYLES):
            return f"**Selected:** {STYLES[idx].name}"
        return "*No style selected*"

    gender_filter.change(update_gallery, inputs=gender_filter, outputs=[hairstyle_gallery, filtered_indices])
    hairstyle_gallery.select(select_hairstyle, inputs=filtered_indices, outputs=selected_index)
    selected_index.change(_update_label, inputs=selected_index, outputs=selected_label)
    random_btn.click(on_random, inputs=filtered_indices, outputs=[selected_index, status])

    # Apply + live preview
    apply_inputs = [upload_img, webcam_img, input_source, selected_index, scale_tweak, vert_offset, horiz_offset, opacity]
    apply_btn.click(on_apply, inputs=apply_inputs, outputs=[result_output, status])
    # Live webcam auto-apply (gives a smooth preview). Keep concurrency=1 for FaceMesh safety.
    webcam_img.change(on_apply, inputs=apply_inputs, outputs=[result_output, status], every=0.6)

    # Save & download
    def _save_and_link(img):
        path, msg = on_save(img)
        # Update download component with the new file
        return msg, gr.update(value=path)

    save_btn.click(_save_and_link, inputs=[result_output], outputs=[status, dl])

    # Initial gallery
    demo.load(update_gallery, inputs=gender_filter, outputs=[hairstyle_gallery, filtered_indices])

# Limit concurrency to avoid MediaPipe thread issues, enable queue for responsiveness
if __name__ == "__main__":
    # Gradio 4.x: queue concurrency is configured via default_concurrency_limit
    demo.queue(default_concurrency_limit=1)
    demo.launch()