import os
import time
import torch
from models import UNet
from test_functions import process_image
from PIL import Image
import gradio as gr
from gradio_client import Client, handle_file
from huggingface_hub import hf_hub_download
import tempfile
from dotenv import load_dotenv

load_dotenv()

from s3 import imagine

# Model download & caching directory (created in Dockerfile)
MODEL_DIR = "/tmp/model"
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")

def download_model():
    print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
    path = hf_hub_download(
        repo_id="Robys01/face-aging",
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=os.environ.get("HUGGINGFACE_HUB_CACHE"),
    )
    print(f"Model downloaded to {path}")


if not os.path.exists(MODEL_PATH):
    download_model()

# Load model
model = UNet()
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False))
model.eval()
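

# Run the UNet on a single input image and return the aged face as a PIL image.
# The imagine() call from the local s3 helper presumably logs/uploads the request
# image; that purpose is an assumption based on the import, not documented here.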
def age_image(image_path: str, source_age: int, target_age: int) -> Image.Image:
    try:
        image = Image.open(image_path)
        if image.mode not in ["RGB", "L"]:
            print(f"Converting image from {image.mode} to RGB")
            image = image.convert("RGB")
        processed_image = process_image(model, image, source_age, target_age)
        imagine(image_path, source_age)
        return processed_image
    except ValueError as e:
        if "No faces detected" in str(e):
            raise gr.Error("No faces detected in the image. Please upload an image with a clear, visible face.")
        else:
            raise gr.Error(f"Error processing image: {str(e)}")
    except Exception as e:
        raise gr.Error(f"Unexpected error: {str(e)}")
def age_video(image_path: str, source_age: int, target_age: int, duration: int, fps: int) -> str:
    try:
        # Convert to RGB so the original frame can be saved as JPEG (RGBA inputs would fail otherwise)
        image = Image.open(image_path).convert("RGB")
        orig_tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        orig_path = orig_tmp.name
        image.save(orig_path)
        orig_tmp.close()

        aged_img = age_image(image_path, source_age, target_age)
        aged_tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        aged_path = aged_tmp.name
        aged_img.save(aged_path)
        aged_tmp.close()

        imagine(image_path, source_age)

        client = Client("Robys01/Face-Morphing")
        try:
            result = client.predict(
                image_files=[handle_file(orig_path), handle_file(aged_path)],
                duration=duration,
                fps=fps,
                method="Dlib",
                align_resize=False,
                order_images=False,
                guideline=False,
                api_name="/predict"
            )
        except Exception as e:
            raise gr.Error(f"Error during video generation: {e}")

        # Unpack response for video path
        video_path = None
        # handle (data, msg) tuple
        if isinstance(result, tuple):
            data, msg = result
            video_path = data.get('video') if isinstance(data, dict) else None
            print(f"Response message: {msg}")

        if not video_path or not os.path.exists(video_path):
            raise gr.Error(f"Video file not found: {video_path}")
        return video_path
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error in video generation: {str(e)}")
def age_timelapse(image_path: str, source_age: int) -> str:
    try:
        # Convert to RGB so the source frame can be saved as JPEG
        image = Image.open(image_path).convert("RGB")
        target_ages = [10, 20, 30, 50, 70]
        # Filter out ages too close to source
        filtered = [age for age in target_ages if abs(age - source_age) >= 4]
        # Combine with source and sort
        ages = sorted(set(filtered + [source_age]))

        temp_handles = []
        for age in ages:
            if age == source_age:
                img = image
            else:
                img = age_image(image_path, source_age, age)
            tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
            path = tmp.name
            img.save(path)
            tmp.close()
            temp_handles.append(handle_file(path))

        imagine(image_path, source_age)

        client = Client("Robys01/Face-Morphing")
        try:
            result = client.predict(
                image_files=temp_handles,
                duration=3,
                fps=20,
                method="Dlib",
                align_resize=False,
                order_images=False,
                guideline=False,
                api_name="/predict"
            )
        except Exception as e:
            raise gr.Error(f"Error generating timelapse video: {e}")

        video_path = None
        if isinstance(result, tuple):
            data, msg = result
            video_path = data.get('video') if isinstance(data, dict) else None
            print(f"Response message: {msg}")

        if not video_path or not os.path.exists(video_path):
            raise gr.Error(f"Timelapse video not found: {video_path}")
        return video_path
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Unexpected error in timelapse generation: {str(e)}")
demo_age_image = gr.Interface(
    fn=age_image,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
        gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age")
    ],
    outputs=gr.Image(type="pil", label="Aged Image"),
    examples=[
        ["examples/girl.jpg", 14, 50],
        ["examples/man.jpg", 45, 70],
        ["examples/man.jpg", 45, 20],
        ["examples/trump.jpg", 74, 30],
    ],
    cache_examples=True,
    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face."
)

demo_age_video = gr.Interface(
    fn=age_video,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
        gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age"),
        gr.Slider(label="Duration (seconds)", minimum=1, maximum=10, step=1, value=3),
        gr.Slider(label="Frames per second (fps)", minimum=2, maximum=60, step=1, value=30),
    ],
    outputs=gr.Video(label="Aged Video", format="mp4"),
    examples=[
        ["examples/girl.jpg", 14, 50, 3, 30],
        ["examples/man.jpg", 45, 70, 3, 30],
        ["examples/man.jpg", 45, 20, 3, 30],
    ],
    cache_examples=True,
    description="Generate a video of the aging process."
)

demo_age_timelapse = gr.Interface(
    fn=age_timelapse,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Slider(10, 90, value=20, step=1, label="Current age")
    ],
    outputs=[gr.Video(label="Aging Timelapse", format="mp4")],
    examples=[
        ["examples/girl.jpg", 14],
        ["examples/man.jpg", 45],
    ],
    cache_examples=True,
    description="Generate a timelapse video showing the aging process at different ages."
)
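

# Launch the tabbed app; queue() enables request queuing so concurrent users are served in order.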
if __name__ == "__main__": | |
iface = gr.TabbedInterface( | |
[demo_age_image, demo_age_video, demo_age_timelapse], | |
tab_names=["Face Aging", "Aging Video", "Aging Timelapse"], | |
title="Face Aging Demo", | |
).queue() | |
iface.launch(server_name="0.0.0.0", server_port=7000) |
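
# Hypothetical client-side usage sketch (not part of this app): once the server is
# running, the image endpoint can be called with gradio_client. The api_name below
# is an assumption; Gradio assigns endpoint names automatically, so check
# Client.view_api() for the actual value.
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7000/")
#   aged_image_path = client.predict(
#       handle_file("examples/man.jpg"), 45, 70, api_name="/predict"
#   )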