Face-Aging / app.py
Robys01's picture
Fixed gr.Error bug not showing error message.
b754309
raw
history blame
7.66 kB
import os
import time
import torch
from models import UNet
from test_functions import process_image
from PIL import Image
import gradio as gr
from gradio_client import Client, handle_file
from huggingface_hub import hf_hub_download
import tempfile
from dotenv import load_dotenv
load_dotenv()
from s3 import imagine
# Location of the downloaded UNet checkpoint. The original setup notes that the
# Dockerfile creates this directory; it is (re)created here in case it is missing.
MODEL_DIR = "/tmp/model"
MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")
os.makedirs(MODEL_DIR, exist_ok=True)
def download_model():
    """Download ``best_unet_model.pth`` from the ``Robys01/face-aging`` Hub repo into ``MODEL_DIR``.

    The value of the ``HUGGINGFACE_HUB_CACHE`` environment variable (or ``None``
    when it is unset) is forwarded as ``cache_dir``. The start time and the
    resulting file path are printed to stdout.
    """
    started_at = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"Starting model download at {started_at}")
    downloaded_to = hf_hub_download(
        repo_id="Robys01/face-aging",
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=os.environ.get("HUGGINGFACE_HUB_CACHE"),
    )
    print(f"Model downloaded to {downloaded_to}")
# Fetch the checkpoint only when it is not already present at MODEL_PATH.
if not os.path.exists(MODEL_PATH):
    download_model()

# Build the UNet, restore its weights onto the CPU, then switch to inference mode.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the downloaded file. If the checkpoint is a plain tensor
# state_dict, weights_only=True would be safer — confirm before changing.
model = UNet()
model.load_state_dict(
    torch.load(
        MODEL_PATH,
        map_location="cpu",
        weights_only=False,
    )
)
model.eval()
def age_image(image_path: str, source_age: int, target_age: int) -> Image.Image:
    """Return an aged version of the face in the uploaded image.

    Args:
        image_path: Filesystem path of the uploaded image. It may be empty or
            ``None`` when the user submitted without uploading anything.
        source_age: Approximate current age of the person in the image.
        target_age: Age at which the face should be rendered.

    Returns:
        The image produced by ``process_image`` for ``target_age``.

    Raises:
        gr.Error: With a user-facing message when no image was supplied, no
            face was detected, or processing failed for any other reason.
    """
    # Guard before the try block so this message is not re-wrapped as "Unexpected error".
    if not image_path:
        raise gr.Error("Please upload an image.")
    try:
        # Context manager guarantees the underlying file handle is closed.
        with Image.open(image_path) as src:
            image = src
            if image.mode not in ["RGB", "L"]:
                print(f"Converting image from {image.mode} to RGB")
                image = image.convert("RGB")
            processed_image = process_image(model, image, source_age, target_age)
        # NOTE(review): `imagine` comes from the project's `s3` module and its
        # contract is not visible here; if it raises, the user sees an
        # "Unexpected error" even though aging itself succeeded.
        imagine(image_path, source_age)
        return processed_image
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except ValueError as e:
        if "No faces detected" in str(e):
            raise gr.Error("No faces detected in the image. Please upload an image with a clear, visible face.") from e
        raise gr.Error(f"Error processing image: {str(e)}") from e
    except Exception as e:
        raise gr.Error(f"Unexpected error: {str(e)}") from e
def age_video(image_path: str, source_age: int, target_age: int, duration: int, fps: int) -> str:
    """Render a morph video from the uploaded face to its aged version.

    The original photo and the output of ``age_image`` are written to temporary
    JPEG files and sent to the ``Robys01/Face-Morphing`` Gradio app through
    ``gradio_client``. The temporary files are deleted once the call finishes.

    Args:
        image_path: Path of the uploaded image. It may be empty or ``None``
            when nothing was uploaded.
        source_age: Approximate current age of the person.
        target_age: Age used for the aged (final) image.
        duration: Video duration in seconds, forwarded to the morphing service.
        fps: Frames per second, forwarded to the morphing service.

    Returns:
        Local path of the video file reported by the morphing service.

    Raises:
        gr.Error: For a missing upload, an aging failure, a service error, or
            a service response that does not point at an existing video file.
    """
    if not image_path:
        raise gr.Error("Please upload an image.")

    temp_paths = []

    def _save_temp_jpeg(img):
        """Write *img* to a new temporary JPEG, register it for cleanup, and return its path."""
        # JPEG cannot store alpha/palette modes (e.g. RGBA or P from PNG uploads).
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        fd, path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        temp_paths.append(path)  # register before saving so a failed save is still cleaned up
        img.save(path)
        return path

    try:
        with Image.open(image_path) as src:
            orig_path = _save_temp_jpeg(src)
        # age_image already calls imagine(image_path, source_age), so it is not repeated here.
        aged_path = _save_temp_jpeg(age_image(image_path, source_age, target_age))

        client = Client("Robys01/Face-Morphing")
        try:
            result = client.predict(
                image_files=[handle_file(orig_path), handle_file(aged_path)],
                duration=duration,
                fps=fps,
                method="Dlib",
                align_resize=False,
                order_images=False,
                guideline=False,
                api_name="/predict"
            )
        except Exception as e:
            raise gr.Error(f"Error during video generation: {e}") from e

        # The service is expected to answer with a (data, message) tuple whose
        # data dict carries the video path under 'video'.
        video_path = None
        if isinstance(result, tuple):
            data, msg = result
            video_path = data.get('video') if isinstance(data, dict) else None
            print(f"Response message: {msg}")
        if not video_path or not os.path.exists(video_path):
            raise gr.Error(f"Video file not found: {video_path}")
        return video_path
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error in video generation: {str(e)}") from e
    finally:
        # Best-effort cleanup: the inputs have been uploaded by now, and a
        # leftover temp file must not mask the real result or error.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass
def age_timelapse(image_path: str, source_age: int) -> str:
    """Render a timelapse video that morphs the face through several ages.

    Frames are produced for the milestone ages 10, 20, 30, 50 and 70, skipping
    any milestone within 4 years of ``source_age``. The uploaded photo is used
    for ``source_age``. All frames, sorted by age, are written to temporary
    JPEG files and sent to the ``Robys01/Face-Morphing`` Gradio app through
    ``gradio_client``. The temporary files are deleted afterwards.

    Args:
        image_path: Path of the uploaded image. It may be empty or ``None``
            when nothing was uploaded.
        source_age: Approximate current age of the person.

    Returns:
        Local path of the video file reported by the morphing service.

    Raises:
        gr.Error: For a missing upload, an aging failure, a service error, or
            a service response that does not point at an existing video file.
    """
    if not image_path:
        raise gr.Error("Please upload an image.")

    temp_paths = []

    def _save_temp_jpeg(img):
        """Write *img* to a new temporary JPEG, register it for cleanup, and return its path."""
        # JPEG cannot store alpha/palette modes (e.g. RGBA or P from PNG uploads).
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        fd, path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        temp_paths.append(path)  # register before saving so a failed save is still cleaned up
        img.save(path)
        return path

    try:
        with Image.open(image_path) as src:
            source_frame = _save_temp_jpeg(src)

        target_ages = [10, 20, 30, 50, 70]
        # Drop milestones too close to the source age, add the source age itself, sort ascending.
        ages = sorted({age for age in target_ages if abs(age - source_age) >= 4} | {source_age})

        # age_image calls imagine(image_path, source_age) itself, so it is not repeated here.
        temp_handles = []
        for age in ages:
            if age == source_age:
                frame_path = source_frame
            else:
                frame_path = _save_temp_jpeg(age_image(image_path, source_age, age))
            temp_handles.append(handle_file(frame_path))

        client = Client("Robys01/Face-Morphing")
        try:
            result = client.predict(
                image_files=temp_handles,
                duration=3,
                fps=20,
                method="Dlib",
                align_resize=False,
                order_images=False,
                guideline=False,
                api_name="/predict"
            )
        except Exception as e:
            raise gr.Error(f"Error generating timelapse video: {e}") from e

        # The service is expected to answer with a (data, message) tuple whose
        # data dict carries the video path under 'video'.
        video_path = None
        if isinstance(result, tuple):
            data, msg = result
            video_path = data.get('video') if isinstance(data, dict) else None
            print(f"Response message: {msg}")
        if not video_path or not os.path.exists(video_path):
            raise gr.Error(f"Timelapse video not found: {video_path}")
        return video_path
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error in timelapse generation: {str(e)}") from e
    finally:
        # Best-effort cleanup of every frame written above.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass
# "Face Aging" tab: one uploaded photo plus current/target ages -> one aged still image.
demo_age_image = gr.Interface(
    fn=age_image,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Slider(
            minimum=10,
            maximum=90,
            value=20,
            step=1,
            label="Current age",
            info="Choose the current age",
        ),
        gr.Slider(
            minimum=10,
            maximum=90,
            value=70,
            step=1,
            label="Target age",
            info="Choose the desired age",
        ),
    ],
    outputs=gr.Image(label="Aged Image", type="pil"),
    # Each row: [image path, current age, target age].
    examples=[
        ["examples/girl.jpg", 14, 50],
        ["examples/man.jpg", 45, 70],
        ["examples/man.jpg", 45, 20],
        ["examples/trump.jpg", 74, 30],
    ],
    cache_examples=True,
    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face.",
)
# "Aging Video" tab: morphs the uploaded face into its aged version as an mp4.
demo_age_video = gr.Interface(
    fn=age_video,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Slider(minimum=10, maximum=90, value=20, step=1, label="Current age", info="Choose the current age"),
        gr.Slider(minimum=10, maximum=90, value=70, step=1, label="Target age", info="Choose the desired age"),
        gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Duration (seconds)"),
        gr.Slider(minimum=2, maximum=60, step=1, value=30, label="Frames per second (fps)"),
    ],
    outputs=gr.Video(label="Aged Video", format="mp4"),
    # Each row: [image path, current age, target age, duration, fps].
    examples=[
        ["examples/girl.jpg", 14, 50, 3, 30],
        ["examples/man.jpg", 45, 70, 3, 30],
        ["examples/man.jpg", 45, 20, 3, 30],
    ],
    cache_examples=True,
    description="Generate a video of the aging process.",
)
# "Aging Timelapse" tab: morphs through several milestone ages around the current age.
demo_age_timelapse = gr.Interface(
    fn=age_timelapse,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Slider(minimum=10, maximum=90, value=20, step=1, label="Current age"),
    ],
    outputs=[
        gr.Video(label="Aging Timelapse", format="mp4"),
    ],
    # Each row: [image path, current age].
    examples=[
        ["examples/girl.jpg", 14],
        ["examples/man.jpg", 45],
    ],
    cache_examples=True,
    description="Generate a timelapse video showing the aging process at different ages.",
)
if __name__ == "__main__":
    # Present the three demos as tabs of a single app.
    iface = gr.TabbedInterface(
        [demo_age_image, demo_age_video, demo_age_timelapse],
        tab_names=["Face Aging", "Aging Video", "Aging Timelapse"],
        title="Face Aging Demo",
    )
    # Enable Gradio's request queue; queue() configures this same Blocks instance.
    iface.queue()
    # Bind to all network interfaces on port 7000.
    iface.launch(server_name="0.0.0.0", server_port=7000)