# Streamlit app: swap the face from a source image onto the faces found in a
# target image or video.
import os

# Streamlit server options, set through the environment before Streamlit is
# imported (order preserved from the original file).
os.environ["STREAMLIT_SERVER_ENABLECORS"] = "false"
os.environ["STREAMLIT_SERVER_ENABLEWEBSOCKETCOMPRESSION"] = "false"

import streamlit as st  # noqa: E402  (must follow the environment setup above)
import numpy as np  # noqa: E402
import cv2  # noqa: E402
import tempfile  # noqa: E402
import traceback  # noqa: E402
from PIL import Image  # noqa: E402
import io  # noqa: E402

# --- Session state defaults ---
# Seed the session-state keys with defaults on the first run of a session;
# values that already exist are left untouched on reruns.
for key, default in {
    "uploaded_image": None,
    "uploaded_video": None,
    "uploaded_target_image": None,
    "output_video": None,
    "output_image": None,
    "mode": "video",
}.items():
    if key not in st.session_state:
        st.session_state[key] = default

# --- Hardware detection ---
def _has_cuda():
    """Return True when PyTorch is importable and reports CUDA as available.

    Returns:
        bool: ``False`` on any failure (PyTorch missing, import error, error
        raised by the CUDA query), so callers can fall back to CPU.
    """
    try:
        import torch

        return bool(torch.cuda.is_available())
    except Exception:
        # Deliberately broad: this is a best-effort probe and any problem
        # simply means "no GPU" for the rest of the app.
        return False

# --- Page setup and sidebar settings ---
st.set_page_config(page_title="Face Swapper", layout="centered")
st.title("🎭 Savvy Face Swapper")

# The lowercase mode ("video" / "image") selects which widgets and which
# processing branch the rest of the script uses.
mode = st.radio("Select Mode:", ["Video", "Image"], horizontal=True)
st.session_state.mode = mode.lower()

st.sidebar.title("⚙️ Settings")

# Resolution at which frames/images are processed.
proc_res = st.sidebar.selectbox(
    "Processing Resolution",
    ["Original", "720p", "480p"],
    index=1,
    help="Frames are resized before detection/swap. Lower = faster."
)

# Video-only options. Defaults are assigned first so both names are always
# bound, even in image mode where the widgets below are not rendered.
fps_cap = "Original"
keep_original_res = False

if st.session_state.mode == "video":
    fps_cap = st.sidebar.selectbox(
        "Target FPS",
        ["Original", "24", "15"],
        index=0,
        help="Lower target FPS drops frames during processing for speed."
    )

    keep_original_res = st.sidebar.checkbox(
        "Keep original output resolution",
        value=False,
        help="If enabled, processed frames are upscaled back to the input size."
    )

# Used by both the image and the video pipeline.
max_faces = st.sidebar.slider(
    "Max faces per frame", min_value=1, max_value=8, value=4,
    help="At most this many faces will be swapped per frame."
)

# --- Model loading ---
@st.cache_resource(show_spinner=True)
def load_models():
    """Load the InsightFace face analyser and the inswapper model.

    The execution providers are chosen from ``_has_cuda()``: CUDA with a CPU
    fallback when a GPU is reported, CPU only otherwise. The same provider
    list is offered to both models. InsightFace versions whose constructors
    reject the ``providers`` keyword raise ``TypeError``; in that case the
    call is retried without it.

    Returns:
        tuple: ``(app, swapper, providers, ctx_id)`` where ``app`` is the
        prepared ``FaceAnalysis`` instance, ``swapper`` the inswapper model,
        ``providers`` the requested provider names and ``ctx_id`` is ``0``
        for GPU or ``-1`` for CPU.

    Raises:
        RuntimeError: If the inswapper model cannot be loaded, including the
            case where the model zoo returns no model.
    """
    import insightface
    from insightface.app import FaceAnalysis

    wants_cuda = _has_cuda()
    if wants_cuda:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    ctx_id = 0 if wants_cuda else -1

    # Face detector / analyser. Pass the provider list so it matches what the
    # app reports to the user; fall back for versions without that keyword.
    try:
        app = FaceAnalysis(name="buffalo_l", providers=providers)
    except TypeError:
        app = FaceAnalysis(name="buffalo_l")
    app.prepare(ctx_id=ctx_id, det_size=(640, 640))

    # Swapper model. The retry without ``providers`` lives inside the outer
    # ``try`` so that a failure of the retry is reported the same way.
    try:
        try:
            swapper = insightface.model_zoo.get_model(
                "inswapper_128.onnx",
                download=True,
                download_zip=False,
                providers=providers
            )
        except TypeError:
            swapper = insightface.model_zoo.get_model(
                "inswapper_128.onnx",
                download=True,
                download_zip=False
            )
    except Exception as e:
        raise RuntimeError(f"Failed to load inswapper_128.onnx: {e}") from e

    if swapper is None:
        # Fail here rather than with an AttributeError on the first swap.
        raise RuntimeError("Failed to load inswapper_128.onnx: no model was returned")

    return app, swapper, providers, ctx_id

# Load the models up front; the script stops here if loading fails.
with st.spinner("Loading models…"):
    try:
        app, swapper, providers, ctx_id = load_models()
    except Exception as e:
        # The message below points the user at the logs, so make sure the
        # full traceback actually ends up there.
        traceback.print_exc()
        st.error("❌ Model loading failed. See logs for details.")
        st.error(str(e))
        st.stop()

# Show which device and provider list were selected by load_models().
st.caption(
    f"Device: {'GPU (CUDA)' if ctx_id == 0 else 'CPU'} • ORT Providers: {', '.join(providers)}"
)

# --- Sizing, FPS and image-conversion helpers ---
def _target_size_for_height(width, height, target_h):
    """Scale ``(width, height)`` so the height becomes ``target_h``.

    The aspect ratio is preserved and both results are rounded to integers
    of at least 1.

    Args:
        width: Original width in pixels.
        height: Original height in pixels.
        target_h: Desired height in pixels.

    Returns:
        tuple: ``(new_width, new_height)``; the inputs are returned unchanged
        when ``target_h`` is not positive or ``height`` is 0.
    """
    if target_h <= 0 or height == 0:
        return width, height
    scale = target_h / float(height)
    new_w = max(1, int(round(width * scale)))
    new_h = max(1, int(round(height * scale)))
    return new_w, new_h

# Maps the sidebar resolution label to concrete processing dimensions.
def _get_proc_size_choice(orig_w, orig_h, choice):
    """Translate a resolution label into processing dimensions.

    Args:
        orig_w: Input width in pixels.
        orig_h: Input height in pixels.
        choice: ``"720p"``, ``"480p"`` or any other value (e.g.
            ``"Original"``) meaning "keep the input size".

    Returns:
        tuple: ``(width, height)`` to process at. Inputs that are already at
        or below the requested height are returned unchanged: the presets
        exist to make processing cheaper, and enlarging small inputs would
        do the opposite.
    """
    preset_heights = {"720p": 720, "480p": 480}
    target_h = preset_heights.get(choice)
    if target_h is None or orig_h <= target_h:
        return orig_w, orig_h
    return _target_size_for_height(orig_w, orig_h, target_h)

# Derives the output frame rate and frame-skipping step from the FPS setting.
def _parse_fps_cap(original_fps, cap_choice):
    """Work out the writer frame rate and the frame step for an FPS cap.

    Args:
        original_fps: Frame rate of the input. Missing, zero, negative or
            NaN values are replaced by 25.0.
        cap_choice: ``"Original"`` or a numeric string such as ``"24"``.

    Returns:
        tuple: ``(write_fps, step)``. ``step`` is how many input frames are
        consumed per written frame (1 = every frame) and ``write_fps`` is the
        matching output rate, never below 1.0. An unparseable ``cap_choice``
        behaves like ``"Original"``.
    """
    # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
    if not original_fps or not original_fps > 0:
        original_fps = 25.0
    uncapped = (max(1.0, original_fps), 1)

    if cap_choice == "Original":
        return uncapped

    try:
        target_fps = max(1.0, float(cap_choice))
        step = max(1, int(round(original_fps / target_fps)))
    except (TypeError, ValueError, OverflowError):
        return uncapped

    return max(1.0, original_fps / step), step

# Decodes uploaded image bytes into an OpenCV image.
def _safe_imdecode(file_bytes):
    """Decode encoded image bytes into an OpenCV colour image.

    Args:
        file_bytes: Raw bytes of an encoded image file.

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with
        ``cv2.IMREAD_COLOR``, or ``None`` when the input is empty or cannot
        be decoded.
    """
    if not file_bytes:
        # An empty buffer makes cv2.imdecode raise instead of returning None.
        return None
    try:
        buffer = np.frombuffer(file_bytes, np.uint8)
        return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    except (cv2.error, TypeError, ValueError):
        return None

# OpenCV (BGR) to PIL (RGB) conversion.
def _cv2_to_pil(image):
    """Convert an OpenCV BGR image into a PIL RGB image."""
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb)

# PIL to OpenCV (BGR) conversion.
def _pil_to_cv2(image):
    """Convert a PIL image into an OpenCV BGR image.

    The image is converted to RGB first so that other PIL modes (for example
    RGBA, greyscale or palette images) do not break the RGB-to-BGR step.
    """
    rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

# --- Face-swap routines ---
def swap_faces_in_image(
    source_image_bgr: np.ndarray,
    target_image_bgr: np.ndarray,
    proc_res: str,
    max_faces: int
):
    """Swap the largest source face onto the faces of a target image.

    Detection and swapping run on a copy of the target resized according to
    ``proc_res``; the result is resized back to the target's original size.
    Uses the module-level ``app`` (detector) and ``swapper`` models and
    reports problems through Streamlit messages.

    Args:
        source_image_bgr: BGR image providing the face to paste.
        target_image_bgr: BGR image whose faces are replaced.
        proc_res: Processing resolution label understood by
            ``_get_proc_size_choice``.
        max_faces: Upper bound on target faces to swap, largest first.

    Returns:
        A PIL image with the swapped faces; the unmodified target as a PIL
        image when no target face is found or processing fails; ``None``
        when no usable face can be found in the source image.
    """

    def face_area(face):
        # Bounding box is (x1, y1, x2, y2).
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    def resized(image, size):
        # INTER_AREA when shrinking, INTER_CUBIC when enlarging.
        height, width = image.shape[:2]
        if (width, height) == size:
            return image
        shrinking = size[0] * size[1] < width * height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
        return cv2.resize(image, size, interpolation=interpolation)

    try:
        source_faces = app.get(source_image_bgr)
    except Exception as e:
        st.error(f"❌ FaceAnalysis failed on source image: {e}")
        return None

    if not source_faces:
        st.error("❌ No face detected in the source image.")
        return None

    # The biggest face in the source image is the one that gets pasted.
    source_face = max(source_faces, key=face_area)

    orig_h, orig_w = target_image_bgr.shape[:2]
    proc_w, proc_h = _get_proc_size_choice(orig_w, orig_h, proc_res)

    try:
        working = resized(target_image_bgr, (proc_w, proc_h))

        try:
            target_faces = app.get(working)
        except Exception as det_e:
            st.error(f"[ERROR] Detection failed on target image: {det_e}")
            target_faces = []

        if not target_faces:
            st.warning("⚠️ No faces detected in the target image.")
            return _cv2_to_pil(target_image_bgr)

        # Largest faces first, capped at max_faces.
        target_faces = sorted(target_faces, key=face_area, reverse=True)[:max_faces]

        result_image = working.copy()
        for tface in target_faces:
            try:
                result_image = swapper.get(result_image, tface, source_face, paste_back=True)
            except Exception as swap_e:
                # One failed face should not abort the remaining ones.
                st.error(f"Face swap error: {swap_e}")

        result_image = resized(result_image, (orig_w, orig_h))
        return _cv2_to_pil(result_image)

    except Exception as e:
        st.error(f"❌ Error processing image: {e}")
        traceback.print_exc()
        return _cv2_to_pil(target_image_bgr)

# Video counterpart of the image swap: processes the target frame by frame.
def swap_faces_in_video(
    image_bgr: np.ndarray,
    video_path: str,
    proc_res: str,
    fps_cap: str,
    keep_original_res: bool,
    max_faces: int,
    progress
):
    """Swap the largest face of ``image_bgr`` onto the faces in a video.

    Frames are read with OpenCV, optionally skipped to honour ``fps_cap``,
    resized to the processing resolution, swapped, and written to a new
    temporary ``.mp4`` file. Only video frames are written; nothing here
    copies an audio track. Uses the module-level ``app`` and ``swapper``
    models and reports problems through Streamlit messages.

    Args:
        image_bgr: BGR image providing the face to paste.
        video_path: Path of the input video file.
        proc_res: Processing resolution label understood by
            ``_get_proc_size_choice``.
        fps_cap: FPS label understood by ``_parse_fps_cap``.
        keep_original_res: Write frames at the input size instead of the
            processing size.
        max_faces: Upper bound on faces swapped per frame, largest first.
        progress: Object with a ``progress(fraction)`` method.

    Returns:
        str | None: Path of the written video (the caller owns the file and
        should delete it), or ``None`` when the source face, the input video
        or the writer is unusable, or when no frame could be written.
    """

    def face_area(face):
        # Bounding box is (x1, y1, x2, y2).
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    def fit(frame, size):
        # Resize only when needed; INTER_AREA to shrink, INTER_CUBIC to grow.
        height, width = frame.shape[:2]
        if (width, height) == size:
            return frame
        shrinking = size[0] * size[1] < width * height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
        return cv2.resize(frame, size, interpolation=interpolation)

    def discard(path):
        # Best-effort removal of a temp file that will not be returned.
        try:
            os.remove(path)
        except OSError:
            pass

    try:
        source_faces = app.get(image_bgr)
    except Exception as e:
        st.error(f"❌ FaceAnalysis failed on source image: {e}")
        return None

    if not source_faces:
        st.error("❌ No face detected in the source image.")
        return None

    # The biggest face in the source image is the one that gets pasted.
    source_face = max(source_faces, key=face_area)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        st.error("❌ Could not open the uploaded video. Try re-encoding to MP4/H.264.")
        return None

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = float(cap.get(cv2.CAP_PROP_FPS))
    if orig_fps <= 0 or np.isnan(orig_fps):
        orig_fps = 25.0

    proc_size = _get_proc_size_choice(orig_w, orig_h, proc_res)
    proc_w, proc_h = proc_size
    write_fps, frame_step = _parse_fps_cap(orig_fps, fps_cap)
    out_size = (orig_w, orig_h) if keep_original_res else (proc_w, proc_h)
    out_w, out_h = out_size

    # Reserve a temp path; the VideoWriter reopens it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_out:
        output_path = tmp_out.name

    # NOTE(review): "mp4v" output may not play in every browser-based
    # preview — confirm, and consider re-encoding to H.264 if it does not.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, write_fps, (out_w, out_h))
    if not out.isOpened():
        cap.release()
        out.release()
        discard(output_path)
        st.error(
            "❌ Failed to open VideoWriter. "
            "Try setting Processing Resolution to 480p or Target FPS to 24."
        )
        return None

    st.info(
        f"Input: {orig_w}×{orig_h} @ {orig_fps:.2f} fps | "
        f"Processing: {proc_w}×{proc_h} | Writing: {out_w}×{out_h} @ {write_fps:.2f} fps | "
        f"Frame step: {frame_step} (1 = process every frame) | "
        f"Max faces/frame: {max_faces}"
    )

    read_idx = 0
    processed_frames = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Only every ``frame_step``-th frame is processed and written.
            keep_frame = read_idx % frame_step == 0
            if keep_frame:
                proc_frame = fit(frame, (proc_w, proc_h))

                try:
                    try:
                        target_faces = app.get(proc_frame)
                    except Exception as det_e:
                        print(f"[WARN] Detection failed on frame {read_idx}: {det_e}")
                        target_faces = []

                    # Largest faces first, capped at max_faces.
                    target_faces = sorted(
                        target_faces, key=face_area, reverse=True
                    )[:max_faces]

                    result_frame = proc_frame.copy()
                    for tface in target_faces:
                        try:
                            result_frame = swapper.get(
                                result_frame, tface, source_face, paste_back=True
                            )
                        except Exception as swap_e:
                            print(f"[WARN] Face swap failed on frame {read_idx}: {swap_e}")

                    # Frames are always brought to the writer's size before
                    # being handed to it.
                    out.write(fit(result_frame, (out_w, out_h)))

                except Exception as e:
                    # Keep the video continuous: write the unswapped frame.
                    print(f"[WARN] Frame {read_idx} failed: {e}")
                    traceback.print_exc()
                    out.write(fit(proc_frame, (out_w, out_h)))

                processed_frames += 1

            read_idx += 1

            if frame_count > 0:
                progress.progress(min(1.0, read_idx / frame_count))
            elif keep_frame and processed_frames % 30 == 0:
                # Unknown length: cycle the bar so the user sees activity.
                progress.progress(min(1.0, (processed_frames % 300) / 300.0))

        progress.progress(1.0)

    except Exception as e:
        st.error(f"❌ Error during video processing: {e}")
        traceback.print_exc()
    finally:
        cap.release()
        out.release()

    if processed_frames == 0:
        # Nothing was written, so the file is not a usable video.
        discard(output_path)
        st.error("❌ No frames could be read from the uploaded video.")
        return None

    return output_path

# --- Uploads and previews ---
st.write("Upload a **source face image** and a **target**, preview them, tweak options, then start swapping.")

image_file = st.file_uploader("Upload Source Image", type=["jpg", "jpeg", "png"])

# The target uploader depends on the selected mode.
if st.session_state.mode == "video":
    target_file = st.file_uploader("Upload Target Video", type=["mp4", "mov", "mkv", "avi"])
else:
    target_file = st.file_uploader("Upload Target Image", type=["jpg", "jpeg", "png"])

# Previews of whatever has been uploaded so far.
if image_file:
    st.subheader("📷 Source Image Preview")
    st.image(image_file, caption="Source Image", use_column_width=True)

if target_file:
    if st.session_state.mode == "video":
        st.subheader("🎬 Target Video Preview")
        st.video(target_file)
    else:
        st.subheader("🖼️ Target Image Preview")
        st.image(target_file, caption="Target Image", use_column_width=True)

# --- Run the swap ---
if st.button("🚀 Start Face Swap"):
    if not image_file or not target_file:
        st.error("⚠️ Please upload both a source image and a target.")
    else:
        # Decode the source image shared by both modes.
        source_image = None
        try:
            source_image = _safe_imdecode(image_file.getvalue())
        except Exception as e:
            st.error(f"❌ Failed to read the source image bytes: {e}")
            st.stop()
        if source_image is None:
            st.error("❌ Failed to decode source image. Please use a valid JPG/PNG.")
            st.stop()

        if st.session_state.mode == "video":
            # OpenCV reads from a path, so the upload is copied to a temp
            # file that keeps the upload's own extension (mov/mkv/avi/mp4).
            suffix = os.path.splitext(target_file.name)[1] or ".mp4"
            tmp_video_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_video:
                    tmp_video_path = tmp_video.name
                    tmp_video.write(target_file.getvalue())
            except Exception as e:
                if tmp_video_path:
                    try:
                        os.remove(tmp_video_path)
                    except OSError:
                        pass
                st.error(f"❌ Failed to save the uploaded video to a temp file: {e}")
                st.stop()

            output_path = None
            try:
                with st.spinner("Processing video… This can take a while ⏳"):
                    progress_bar = st.progress(0)
                    output_path = swap_faces_in_video(
                        source_image,
                        tmp_video_path,
                        proc_res=proc_res,
                        fps_cap=fps_cap,
                        keep_original_res=keep_original_res,
                        max_faces=max_faces,
                        progress=progress_bar
                    )

                if output_path:
                    # Read the result once so the preview and the download
                    # share the same bytes and the temp file can be removed.
                    try:
                        with open(output_path, "rb") as f:
                            output_bytes = f.read()
                    except OSError as e:
                        st.warning(f"⚠️ Could not open the output file for download: {e}")
                    else:
                        st.success("✅ Face swapping completed!")
                        st.subheader("📺 Output Video Preview")
                        st.video(output_bytes)
                        st.download_button(
                            label="⬇️ Download Processed Video",
                            data=output_bytes,
                            file_name="output_swapped_video.mp4",
                            mime="video/mp4"
                        )
            finally:
                # Remove both temp files whatever happened above.
                for temp_path in (tmp_video_path, output_path):
                    if temp_path:
                        try:
                            os.remove(temp_path)
                        except OSError:
                            pass

        else:
            target_image = None
            try:
                target_image = _safe_imdecode(target_file.getvalue())
            except Exception as e:
                st.error(f"❌ Failed to read the target image bytes: {e}")
                st.stop()
            if target_image is None:
                st.error("❌ Failed to decode target image. Please use a valid JPG/PNG.")
                st.stop()

            with st.spinner("Processing image…"):
                result_image = swap_faces_in_image(
                    source_image,
                    target_image,
                    proc_res=proc_res,
                    max_faces=max_faces
                )

            # Explicit None check: the function returns None on failure.
            if result_image is not None:
                st.success("✅ Face swapping completed!")
                st.subheader("🖼️ Output Image Preview")
                st.image(result_image, caption="Result Image", use_column_width=True)

                # Encode the result as JPEG in memory for the download.
                buf = io.BytesIO()
                result_image.save(buf, format="JPEG")
                byte_im = buf.getvalue()

                st.download_button(
                    label="⬇️ Download Processed Image",
                    data=byte_im,
                    file_name="output_swapped_image.jpg",
                    mime="image/jpeg"
                )

# --- Troubleshooting notes ---
# Collapsible troubleshooting tips shown at the bottom of the page.
with st.expander("🩺 Diagnostics"):
    st.write(
        "- If you see **SessionInfo** errors: this app initializes `st.session_state` early and defers heavy loads via "
        "`@st.cache_resource`. If errors persist, restart the Space/Runtime.\n"
        "- If output is jumpy/stutters: lower **Target FPS** or choose **480p** processing.\n"
        "- If video fails to open: re-encode your input to **MP4 (H.264, AAC)**.\n"
        "- If VideoWriter fails: try **480p** and **Target FPS 24**."
    )