# savvy7007's picture
# Update app.py
# 34234e4 verified
# =========================
# app.py (production-ready, safer)
# =========================
import os
# Streamlit server tweaks (safe on HF Spaces / containers)
os.environ["STREAMLIT_SERVER_ENABLECORS"] = "false"
os.environ["STREAMLIT_SERVER_ENABLEWEBSOCKETCOMPRESSION"] = "false"
import streamlit as st
import numpy as np
import cv2
import tempfile
import traceback
from PIL import Image
import io
# -------------------------
# VERY EARLY: initialize session state
# -------------------------
# Seed every key the app may read before any widget is created.
# (Original note: this is meant to avoid the "SessionInfo before it was
# initialized" glitch seen on some boots.)
_SESSION_DEFAULTS = {
    "uploaded_image": None,
    "uploaded_video": None,
    "uploaded_target_image": None,
    "output_video": None,
    "output_image": None,
    "mode": "video",  # 'video' or 'image'
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    # Only fill in missing keys so values from earlier reruns survive.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# -------------------------
# GPU check (optional torch import)
# -------------------------
def _has_cuda():
try:
import torch
return torch.cuda.is_available()
except Exception:
# If torch isn't installed, just say no CUDA
return False
# -----------------------------------
# Page & Sidebar (controls for speed)
# -----------------------------------
st.set_page_config(page_title="Face Swapper", layout="centered")
st.title("🎭 Savvy Face Swapper")

# Mode selection; the lowercase form is what the rest of the script checks.
mode = st.radio("Select Mode:", ["Video", "Image"], horizontal=True)
st.session_state.mode = mode.lower()

_sidebar = st.sidebar
_sidebar.title("⚙️ Settings")

# Downscale to speed up detection & swapping
_resolution_options = ["Original", "720p", "480p"]
proc_res = _sidebar.selectbox(
    "Processing Resolution",
    _resolution_options,
    index=1,
    help="Frames are resized before detection/swap. Lower = faster.",
)

# Video-only controls: these names are read only by the video branch below.
if st.session_state.mode == "video":
    # Skip frames to hit a lower effective FPS
    _fps_options = ["Original", "24", "15"]
    fps_cap = _sidebar.selectbox(
        "Target FPS",
        _fps_options,
        index=0,
        help="Lower target FPS drops frames during processing for speed.",
    )
    # Keep the original output resolution even if we process smaller
    keep_original_res = _sidebar.checkbox(
        "Keep original output resolution",
        value=False,
        help="If enabled, processed frames are upscaled back to the input size.",
    )

# Limit faces per frame (helps speed on crowded scenes); used by both modes.
max_faces = _sidebar.slider(
    "Max faces per frame",
    min_value=1,
    max_value=8,
    value=4,
    help="At most this many faces will be swapped per frame.",
)
# -------------------------
# Model loading (cached)
# -------------------------
@st.cache_resource(show_spinner=True)
def load_models():
    """Load the InsightFace analyser and the inswapper model once (cached).

    CUDA is requested when ``_has_cuda()`` reports it, otherwise everything
    runs on CPU. Both loaders first try the ``providers=...`` keyword and
    fall back to a call without it when that raises ``TypeError``.

    Returns:
        tuple: ``(app, swapper, providers, ctx_id)`` where ``app`` is the
        prepared ``FaceAnalysis`` instance, ``swapper`` the inswapper model,
        ``providers`` the requested ONNX Runtime provider list and
        ``ctx_id`` is ``0`` for GPU or ``-1`` for CPU.

    Raises:
        RuntimeError: if the inswapper model cannot be loaded.
    """
    import insightface
    from insightface.app import FaceAnalysis

    # Desired providers for ORT
    wants_cuda = _has_cuda()
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if wants_cuda
        else ["CPUExecutionProvider"]
    )
    ctx_id = 0 if wants_cuda else -1

    # Face detector/landmarks. Pass the same provider list that is reported
    # to the user so the detector and the swapper agree.
    # NOTE(review): assumes recent insightface forwards extra FaceAnalysis
    # kwargs to the model loader; older releases raise TypeError instead.
    try:
        app = FaceAnalysis(name="buffalo_l", providers=providers)
    except TypeError:
        app = FaceAnalysis(name="buffalo_l")
    app.prepare(ctx_id=ctx_id, det_size=(640, 640))

    # Face swapper (inswapper_128)
    download_kwargs = {"download": True, "download_zip": False}
    try:
        try:
            swapper = insightface.model_zoo.get_model(
                "inswapper_128.onnx", providers=providers, **download_kwargs
            )
        except TypeError:
            # Fallback path: older insightface without providers kwarg
            swapper = insightface.model_zoo.get_model(
                "inswapper_128.onnx", **download_kwargs
            )
    except Exception as e:
        # Wrap failures from either attempt and keep the original cause.
        raise RuntimeError(f"Failed to load inswapper_128.onnx: {e}") from e

    if swapper is None:
        # Fail here rather than later with an AttributeError on swapper.get.
        raise RuntimeError(
            "Failed to load inswapper_128.onnx: the model loader returned no model"
        )
    return app, swapper, providers, ctx_id
# Initialize models (stop the script if loading fails)
with st.spinner("Loading models…"):
    try:
        _loaded = load_models()
    except Exception as e:
        st.error("❌ Model loading failed. See logs for details.")
        st.error(str(e))
        st.stop()
    app, swapper, providers, ctx_id = _loaded

# Summarise the chosen device and provider list for the user.
_device_label = "GPU (CUDA)" if ctx_id == 0 else "CPU"
_provider_label = ", ".join(providers)
st.caption(f"Device: {_device_label} • ORT Providers: {_provider_label}")
# -------------------------
# Helpers
# -------------------------
def _target_size_for_height(width, height, target_h):
if target_h <= 0 or height == 0:
return width, height
scale = target_h / float(height)
new_w = max(1, int(round(width * scale)))
new_h = max(1, int(round(height * scale)))
return new_w, new_h
def _get_proc_size_choice(orig_w, orig_h, choice):
if choice == "720p":
return _target_size_for_height(orig_w, orig_h, 720)
if choice == "480p":
return _target_size_for_height(orig_w, orig_h, 480)
return orig_w, orig_h
def _parse_fps_cap(original_fps, cap_choice):
# Handle bad/zero FPS from container decoders
if not original_fps or original_fps <= 0:
original_fps = 25.0
if cap_choice == "Original":
return max(1.0, original_fps), 1 # write_fps, frame_step
try:
tgt = float(cap_choice)
tgt = max(1.0, tgt)
step = max(1, int(round(original_fps / tgt)))
write_fps = max(1.0, original_fps / step)
return write_fps, step
except Exception:
return max(1.0, original_fps), 1
def _safe_imdecode(file_bytes):
arr = np.frombuffer(file_bytes, np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
return img
def _cv2_to_pil(image):
    """Return a PIL image built from an OpenCV BGR array (channels swapped to RGB)."""
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def _pil_to_cv2(image):
    """Convert a PIL image to an OpenCV BGR array.

    The image is first converted to mode ``"RGB"`` so that other modes
    (for example RGBA, palette or greyscale) do not break the 3-channel
    RGB-to-BGR conversion.
    """
    rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
# -------------------------------------
# Core: face swap functions
# -------------------------------------
def swap_faces_in_image(
    source_image_bgr: np.ndarray,
    target_image_bgr: np.ndarray,
    proc_res: str,
    max_faces: int
):
    """Swap the largest source face onto the faces of a target image.

    Args:
        source_image_bgr: Image providing the face to paste (passed to the
            face analyser as-is).
        target_image_bgr: Image whose faces are replaced.
        proc_res: Processing-resolution choice handed to
            ``_get_proc_size_choice``.
        max_faces: Maximum number of target faces (largest first) to swap.

    Returns:
        ``None`` when the source image cannot be analysed or has no face.
        Otherwise a PIL image: the swapped result at the original target
        size, or the unmodified target when no target face is found or
        processing fails.
    """

    def _bbox_area(face):
        x1, y1, x2, y2 = face.bbox[0], face.bbox[1], face.bbox[2], face.bbox[3]
        return (x2 - x1) * (y2 - y1)

    # --- Source face -----------------------------------------------------
    try:
        source_faces = app.get(source_image_bgr)
    except Exception as e:
        st.error(f"❌ FaceAnalysis failed on source image: {e}")
        return None
    if not source_faces:
        st.error("❌ No face detected in the source image.")
        return None

    # The largest detection is used as the identity to paste.
    source_face = max(source_faces, key=lambda f: max(1, int(_bbox_area(f))))

    # --- Working copy at processing resolution ----------------------------
    orig_h, orig_w = target_image_bgr.shape[:2]
    proc_w, proc_h = _get_proc_size_choice(orig_w, orig_h, proc_res)
    resized = (proc_w, proc_h) != (orig_w, orig_h)
    if resized:
        target_image_proc = cv2.resize(
            target_image_bgr, (proc_w, proc_h), interpolation=cv2.INTER_AREA
        )
    else:
        target_image_proc = target_image_bgr.copy()

    try:
        # --- Target faces -------------------------------------------------
        try:
            target_faces = app.get(target_image_proc)
        except Exception as det_e:
            st.error(f"[ERROR] Detection failed on target image: {det_e}")
            target_faces = []
        if not target_faces:
            st.warning("⚠️ No faces detected in the target image.")
            return _cv2_to_pil(target_image_bgr)

        # Keep only the N largest faces.
        largest_first = sorted(target_faces, key=_bbox_area, reverse=True)
        target_faces = largest_first[:max_faces]

        # --- Swap ---------------------------------------------------------
        result_image = target_image_proc.copy()
        for tface in target_faces:
            try:
                result_image = swapper.get(result_image, tface, source_face, paste_back=True)
            except Exception as swap_e:
                # One failed face should not prevent the others.
                st.error(f"Face swap error: {swap_e}")

        # Return to the caller's original resolution.
        if resized:
            result_image = cv2.resize(
                result_image, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC
            )
        return _cv2_to_pil(result_image)
    except Exception as e:
        st.error(f"❌ Error processing image: {e}")
        traceback.print_exc()
        return _cv2_to_pil(target_image_bgr)
def swap_faces_in_video(
    image_bgr: np.ndarray,
    video_path: str,
    proc_res: str,
    fps_cap: str,
    keep_original_res: bool,
    max_faces: int,
    progress
):
    """Swap the largest face of ``image_bgr`` onto faces in a video file.

    Args:
        image_bgr: Source image providing the face to paste.
        video_path: Path of the input video opened with ``cv2.VideoCapture``.
        proc_res: Processing-resolution choice for ``_get_proc_size_choice``.
        fps_cap: Target-FPS choice for ``_parse_fps_cap``.
        keep_original_res: When true, processed frames are resized back to
            the input size before being written.
        max_faces: Maximum number of faces (largest first) swapped per frame.
        progress: Object whose ``progress(value)`` method is called with a
            float in ``[0, 1]``.

    Returns:
        Path of the written ``.mp4`` temp file (the caller owns and should
        delete it), or ``None`` when the source image has no usable face,
        the video or writer cannot be opened, or no frame was written. In
        the failure cases the temp output file is removed.
    """

    def _bbox_area(face):
        return (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])

    # Validate source image
    try:
        source_faces = app.get(image_bgr)
    except Exception as e:
        st.error(f"❌ FaceAnalysis failed on source image: {e}")
        return None
    if not source_faces:
        st.error("❌ No face detected in the source image.")
        return None

    # Use the largest detected face
    source_face = max(source_faces, key=lambda f: max(1, int(_bbox_area(f))))

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        st.error("❌ Could not open the uploaded video. Try re-encoding to MP4/H.264.")
        return None

    # Read properties
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = float(cap.get(cv2.CAP_PROP_FPS))
    if orig_fps <= 0 or np.isnan(orig_fps):
        orig_fps = 25.0

    # Decide processing size & FPS behavior
    proc_w, proc_h = _get_proc_size_choice(orig_w, orig_h, proc_res)
    write_fps, frame_step = _parse_fps_cap(orig_fps, fps_cap)
    resize_for_proc = (proc_w, proc_h) != (orig_w, orig_h)
    restore_size = keep_original_res and resize_for_proc
    out_w, out_h = (orig_w, orig_h) if keep_original_res else (proc_w, proc_h)

    # Prepare output writer. The temp file is created with delete=False, so
    # every failure path below must remove it explicitly.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_out:
        output_path = tmp_out.name

    def _discard_output():
        try:
            os.remove(output_path)
        except OSError:
            pass

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, write_fps, (out_w, out_h))
    if not out.isOpened():
        cap.release()
        out.release()
        _discard_output()
        st.error(
            "❌ Failed to open VideoWriter. "
            "Try setting Processing Resolution to 480p or Target FPS to 24."
        )
        return None

    st.info(
        f"Input: {orig_w}×{orig_h} @ {orig_fps:.2f} fps | "
        f"Processing: {proc_w}×{proc_h} | Writing: {out_w}×{out_h} @ {write_fps:.2f} fps | "
        f"Frame step: {frame_step} (1 = process every frame) | "
        f"Max faces/frame: {max_faces}"
    )

    # Process loop
    read_idx = 0
    processed_frames = 0
    finished = False
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # FPS cap by skipping frames
            if frame_step > 1 and (read_idx % frame_step != 0):
                read_idx += 1
                if frame_count > 0:
                    progress.progress(min(1.0, read_idx / frame_count))
                continue

            # Resize for processing
            if resize_for_proc:
                proc_frame = cv2.resize(frame, (proc_w, proc_h), interpolation=cv2.INTER_AREA)
            else:
                proc_frame = frame

            try:
                # Detect faces on processed frame
                try:
                    target_faces = app.get(proc_frame)
                except Exception as det_e:
                    print(f"[WARN] Detection failed on frame {read_idx}: {det_e}")
                    target_faces = []

                if target_faces:
                    # Limit faces to the largest N for speed
                    target_faces = sorted(
                        target_faces, key=_bbox_area, reverse=True
                    )[:max_faces]

                # Swap into a working buffer
                result_frame = proc_frame.copy()
                for tface in target_faces:
                    try:
                        result_frame = swapper.get(result_frame, tface, source_face, paste_back=True)
                    except Exception as swap_e:
                        print(f"[WARN] Face swap failed on frame {read_idx}: {swap_e}")
                        continue

                # Upscale back to original if requested
                if restore_size:
                    result_frame = cv2.resize(result_frame, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
                out.write(result_frame)
            except Exception as e:
                # Log & write the unswapped frame so the video keeps its length
                print(f"[WARN] Frame {read_idx} failed: {e}")
                traceback.print_exc()
                fallback = proc_frame
                if restore_size:
                    fallback = cv2.resize(proc_frame, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
                out.write(fallback)

            read_idx += 1
            processed_frames += 1

            # Update progress
            if frame_count > 0:
                progress.progress(min(1.0, read_idx / frame_count))
            elif processed_frames % 30 == 0:
                # Fallback progress for unknown frame counts
                progress.progress(min(1.0, (processed_frames % 300) / 300.0))
        finished = True
    except Exception as e:
        st.error(f"❌ Error during video processing: {e}")
        traceback.print_exc()
    finally:
        cap.release()
        out.release()

    if processed_frames == 0:
        # Nothing was written: do not hand back an empty file as a success.
        _discard_output()
        st.error("❌ No frames could be processed from the video. Try re-encoding to MP4/H.264.")
        return None

    if finished:
        # The reported frame count can overestimate; make the bar complete.
        progress.progress(1.0)
    return output_path
# -------------------------
# UI: Uploads & Preview
# -------------------------
st.write("Upload a **source face image** and a **target**, preview them, tweak options, then start swapping.")

_image_types = ["jpg", "jpeg", "png"]
_video_types = ["mp4", "mov", "mkv", "avi"]
_video_mode = st.session_state.mode == "video"

image_file = st.file_uploader("Upload Source Image", type=_image_types)
# The target uploader depends on the selected mode.
if _video_mode:
    target_file = st.file_uploader("Upload Target Video", type=_video_types)
else:
    target_file = st.file_uploader("Upload Target Image", type=_image_types)

# Previews
if image_file:
    st.subheader("📷 Source Image Preview")
    st.image(image_file, caption="Source Image", use_column_width=True)

if target_file and _video_mode:
    st.subheader("🎬 Target Video Preview")
    st.video(target_file)
elif target_file:
    st.subheader("🖼️ Target Image Preview")
    st.image(target_file, caption="Target Image", use_column_width=True)
# -------------------------
# Run button
# -------------------------
def _discard_temp_file(path):
    """Best-effort removal of a temp file; a missing file is not an error."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


if st.button("🚀 Start Face Swap"):
    if not image_file or not target_file:
        st.error("⚠️ Please upload both a source image and a target.")
    else:
        # Read source image
        try:
            source_image = _safe_imdecode(image_file.getvalue())
        except Exception as e:
            st.error(f"❌ Failed to read the source image bytes: {e}")
            st.stop()
        if source_image is None:
            st.error("❌ Failed to decode source image. Please use a valid JPG/PNG.")
            st.stop()

        if st.session_state.mode == "video":
            # Process video
            tmp_video_path = None
            output_path = None
            try:
                try:
                    # Persist temp video for OpenCV
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
                        tmp_video_path = tmp_video.name
                        tmp_video.write(target_file.getvalue())
                except Exception as e:
                    st.error(f"❌ Failed to save the uploaded video to a temp file: {e}")
                    st.stop()

                with st.spinner("Processing video… This can take a while ⏳"):
                    progress_bar = st.progress(0)
                    output_path = swap_faces_in_video(
                        source_image,
                        tmp_video_path,
                        proc_res=proc_res,
                        fps_cap=fps_cap,
                        keep_original_res=keep_original_res,
                        max_faces=max_faces,
                        progress=progress_bar
                    )
            finally:
                # Cleanup temp input video on every path, including st.stop().
                _discard_temp_file(tmp_video_path)

            if output_path:
                # Read the result once and delete the temp file so repeated
                # runs do not fill the temp directory.
                output_bytes = None
                try:
                    with open(output_path, "rb") as f:
                        output_bytes = f.read()
                except OSError as e:
                    st.warning(f"⚠️ Could not open the output file for download: {e}")
                finally:
                    _discard_temp_file(output_path)

                if output_bytes:
                    st.success("✅ Face swapping completed!")
                    st.subheader("📺 Output Video Preview")
                    st.video(output_bytes)
                    # Download button
                    st.download_button(
                        label="⬇️ Download Processed Video",
                        data=output_bytes,
                        file_name="output_swapped_video.mp4",
                        mime="video/mp4"
                    )
                elif output_bytes is not None:
                    st.error("❌ The processed video is empty.")
        else:
            # Process image
            try:
                target_image = _safe_imdecode(target_file.getvalue())
            except Exception as e:
                st.error(f"❌ Failed to read the target image bytes: {e}")
                st.stop()
            if target_image is None:
                st.error("❌ Failed to decode target image. Please use a valid JPG/PNG.")
                st.stop()

            with st.spinner("Processing image…"):
                result_image = swap_faces_in_image(
                    source_image,
                    target_image,
                    proc_res=proc_res,
                    max_faces=max_faces
                )

            # swap_faces_in_image returns None on failure; test that
            # explicitly rather than relying on object truthiness.
            if result_image is not None:
                st.success("✅ Face swapping completed!")
                st.subheader("🖼️ Output Image Preview")
                st.image(result_image, caption="Result Image", use_column_width=True)
                # Download button
                buf = io.BytesIO()
                result_image.save(buf, format="JPEG")
                st.download_button(
                    label="⬇️ Download Processed Image",
                    data=buf.getvalue(),
                    file_name="output_swapped_image.jpg",
                    mime="image/jpeg"
                )
# -------------
# Diagnostics
# -------------
# Static troubleshooting tips shown in a collapsible section.
_DIAGNOSTIC_TIPS = (
    "- If you see **SessionInfo** errors: this app initializes `st.session_state` early and defers heavy loads via "
    "`@st.cache_resource`. If errors persist, restart the Space/Runtime.\n"
    "- If output is jumpy/stutters: lower **Target FPS** or choose **480p** processing.\n"
    "- If video fails to open: re-encode your input to **MP4 (H.264, AAC)**.\n"
    "- If VideoWriter fails: try **480p** and **Target FPS 24**."
)
with st.expander("🩺 Diagnostics"):
    st.write(_DIAGNOSTIC_TIPS)