import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from pydub import AudioSegment
from PIL import Image, ImageColor
import numpy as np
import os
import re
import tempfile
import uuid
import time
from concurrent.futures import ThreadPoolExecutor
from moviepy import VideoFileClip, vfx, concatenate_videoclips, ImageSequenceClip

torch.set_float32_matmul_precision("medium")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load both BiRefNet models: the full model and the lite model used by "Fast Mode".
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to(device)
birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
birefnet_lite.to(device)

# Segmentation preprocessing: resize to 768x768, convert to a tensor and
# normalize with the ImageNet mean/std.
transform_image = transforms.Compose([
    transforms.Resize((768, 768)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
    """Replace the background of a single video frame.

    Args:
        frame: Frame as a numpy array (as produced by moviepy's ``iter_frames``).
        bg_type: "Color", "Image" or "Video"; any other value returns the frame unchanged.
        bg: Background image (PIL image or file path), used when bg_type == "Image".
        fast_mode: Use BiRefNet_lite instead of BiRefNet.
        bg_frame_index: Index into ``background_frames`` for this frame.
        background_frames: List of background frames, used when bg_type == "Video".
        color: Background color string, used when bg_type == "Color".

    Returns:
        ``(processed_frame, next_bg_frame_index)``. On any error the original frame is
        returned unchanged, so one bad frame does not abort the whole video.
    """
    try:
        pil_image = Image.fromarray(frame)
        if bg_type == "Color":
            processed_image = process(pil_image, color, fast_mode)
        elif bg_type == "Image":
            processed_image = process(pil_image, bg, fast_mode)
        elif bg_type == "Video":
            # Rounding in iter_frames can leave the background a frame or two shorter
            # than the input; reuse its last frame instead of failing.
            index = min(bg_frame_index, len(background_frames) - 1)
            background_frame = background_frames[index]
            bg_frame_index += 1
            background_image = Image.fromarray(background_frame)
            processed_image = process(pil_image, background_image, fast_mode)
        else:
            processed_image = pil_image  # No background selected: keep the original frame.
        return np.array(processed_image), bg_frame_index
    except Exception as e:
        print(f"Error processing frame: {e}")
        return frame, bg_frame_index


def _validate_inputs(vid, bg_type, bg_image, bg_video):
    """Raise gr.Error for missing inputs that would otherwise fail on every frame."""
    if not vid:
        raise gr.Error("Please upload an input video.")
    if bg_type == "Image" and not bg_image:
        raise gr.Error("Please provide a background image.")
    if bg_type == "Video" and not bg_video:
        raise gr.Error("Please provide a background video.")


def _load_background_frames(bg_video, target_duration, fps, video_handling, opened_clips):
    """Decode the background video at ``fps``, extended to cover ``target_duration`` seconds.

    A background shorter than the target is either slowed down to last exactly
    ``target_duration`` ("slow_down") or repeated until it is long enough (any other value).
    Every clip opened here is appended to ``opened_clips`` so the caller can close it.
    """
    background_video = VideoFileClip(bg_video)
    opened_clips.append(background_video)
    if background_video.duration < target_duration:
        if video_handling == "slow_down":
            # MultiplySpeed(final_duration=...) rescales playback speed so the clip lasts
            # exactly target_duration (i.e. slows it down when it is shorter).
            background_video = background_video.with_effects(
                [vfx.MultiplySpeed(final_duration=target_duration)]
            )
        else:  # video_handling == "loop"
            repeats = int(target_duration / background_video.duration) + 1
            background_video = concatenate_videoclips([background_video] * repeats)
    frames = list(background_video.iter_frames(fps=fps))
    if not frames:
        raise gr.Error("The background video contains no frames.")
    return frames


@spaces.GPU
def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0,
       video_handling="slow_down", fast_mode=True, max_workers=10):
    """Replace the background of every frame of ``vid`` and stream progress to the UI.

    Yields pairs for the ``(stream_image, out_video)`` outputs:
    1. a visibility toggle that shows the streaming preview;
    2. each processed frame as it completes;
    3. a toggle back;
    4. the last processed frame together with the path of the rendered MP4.

    Args:
        vid: Path to the input video.
        bg_type: "Color", "Image" or "Video".
        bg_image: Path to the background image (used when bg_type == "Image").
        bg_video: Path to the background video (used when bg_type == "Video").
        color: Background color string (used when bg_type == "Color").
        fps: Output frame rate; 0 keeps the input video's frame rate.
        video_handling: How a background video shorter than the input is extended:
            "slow_down" stretches it to the input's duration, "loop" repeats it.
        fast_mode: Use BiRefNet_lite instead of the full BiRefNet.
        max_workers: Number of threads used to process frames.

    Raises:
        gr.Error: On missing inputs or any processing failure, so the message is shown in the UI.
    """
    start_time = time.time()
    opened_clips = []  # Every clip opened below; all are closed in `finally`.
    try:
        _validate_inputs(vid, bg_type, bg_image, bg_video)

        video = VideoFileClip(vid)
        opened_clips.append(video)
        if not fps:
            fps = video.fps
        frames = list(video.iter_frames(fps=fps))
        if not frames:
            raise gr.Error("The input video contains no frames.")

        yield gr.update(visible=True), gr.update(visible=False)

        background = bg_image
        background_frames = None
        if bg_type == "Image":
            # Decode the background image once instead of re-opening the file per frame.
            with Image.open(bg_image) as img:
                background = img.convert("RGBA")
        elif bg_type == "Video":
            background_frames = _load_background_frames(
                bg_video, video.duration, fps, video_handling, opened_clips
            )

        processed_frames = []
        with ThreadPoolExecutor(max_workers=max(1, int(max_workers))) as executor:
            futures = [
                executor.submit(process_frame, frame, bg_type, background, fast_mode,
                                i, background_frames, color)
                for i, frame in enumerate(frames)
            ]
            # Consume results in submission order so the output keeps the input frame order.
            for future in futures:
                result, _ = future.result()
                processed_frames.append(result)
                yield result, None

        processed_video = ImageSequenceClip(processed_frames, fps=fps).with_audio(video.audio)
        opened_clips.append(processed_video)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_filepath = temp_file.name
        processed_video.write_videofile(temp_filepath, codec="libx264")

        print(f"Processed {len(processed_frames)} frames in {time.time() - start_time:.2f}s")
        yield gr.update(visible=False), gr.update(visible=True)
        yield processed_frames[-1], temp_filepath
    except gr.Error:
        yield gr.update(visible=False), gr.update(visible=True)
        raise
    except Exception as e:
        print(f"Error: {e}")
        yield gr.update(visible=False), gr.update(visible=True)
        raise gr.Error(f"Error processing video: {e}") from e
    finally:
        for clip in opened_clips:
            clip.close()


def _is_color(value):
    """Return True if ``value`` is a color string rather than an image path or image."""
    return isinstance(value, str) and value.strip().lower().startswith(("#", "rgb(", "rgba("))


def _parse_color(color):
    """Convert "#RGB"/"#RRGGBB"/"#RRGGBBAA" or "rgb(...)"/"rgba(...)" to an (r, g, b) int tuple.

    Raises:
        ValueError: If the string cannot be parsed as a color.
    """
    color = color.strip()
    if color.startswith("#"):
        return ImageColor.getrgb(color)[:3]
    # rgb()/rgba() strings may carry float components, e.g. "rgba(255.0, 0, 127.5, 1)".
    components = re.findall(r"\d+(?:\.\d+)?", color)[:3]
    if len(components) != 3:
        raise ValueError(f"Unrecognized color: {color!r}")
    return tuple(max(0, min(255, round(float(c)))) for c in components)


def process(image, bg, fast_mode=False):
    """Composite the foreground of ``image`` over ``bg`` using a BiRefNet segmentation mask.

    Args:
        image: Foreground PIL image.
        bg: A color string ("#RRGGBB", "#RGB", "rgb(...)", "rgba(...)"), a PIL image,
            or a path to an image file.
        fast_mode: Use BiRefNet_lite instead of BiRefNet.

    Returns:
        An RGB PIL image the size of ``image``. RGB is used so every output frame has
        the same channel count as frames that ``process_frame`` passes through unchanged.
    """
    image_size = image.size
    # The normalization expects 3 channels, so feed the model an RGB view of the image.
    input_images = transform_image(image.convert("RGB")).unsqueeze(0).to(device)
    model = birefnet_lite if fast_mode else birefnet
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)

    if _is_color(bg):
        background = Image.new("RGBA", image_size, _parse_color(bg) + (255,))
    elif isinstance(bg, Image.Image):
        background = bg.convert("RGBA").resize(image_size)
    else:
        with Image.open(bg) as bg_file:
            background = bg_file.convert("RGBA").resize(image_size)

    return Image.composite(image, background, mask).convert("RGB")


# Create custom dark theme
dark_theme = gr.themes.Soft(
    primary_hue="purple",
    secondary_hue="blue",
    neutral_hue="slate",
).set(
    body_background_fill_dark="#0f0f23",
    block_background_fill_dark="#1a1b2e",
    block_border_color_dark="#16213e",
    input_background_fill_dark="#16213e",
    button_primary_background_fill_dark="#6366f1",
    button_primary_background_fill_hover_dark="#4f46e5",
    button_secondary_background_fill_dark="#374151",
    button_secondary_background_fill_hover_dark="#4b5563",
)

with gr.Blocks(theme=dark_theme) as demo:
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")

    # Settings panels aligned below input video
    with gr.Column():
        # Row 1: Background Type and FPS (two smaller panels)
        with gr.Row():
            bg_type = gr.Radio(["Color", "Image", "Video"], label="Background Type", value="Color", interactive=True)
            fps_slider = gr.Slider(
                minimum=0,
                maximum=60,
                step=1,
                value=0,
                label="Output FPS (0 = original fps)",
                interactive=True,
            )

        # Row 2: Dynamic background options (full width)
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", visible=True, interactive=True)
        bg_image = gr.Image(label="Background Image", type="filepath", visible=False, interactive=True)
        bg_video = gr.Video(label="Background Video", visible=False, interactive=True)

        # Row 3: Video handling options (only visible when Video is selected)
        with gr.Row(visible=False) as video_handling_options:
            video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)

        # Row 4: Processing options (two smaller panels)
        with gr.Row():
            fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True)
            max_workers_slider = gr.Slider(
                minimum=1,
                maximum=32,
                step=1,
                value=10,
                label="Max Workers",
                info="Parallel processing threads",
                interactive=True,
            )

        # Styled button with rocket emoji and rounded corners
        submit_button = gr.Button(
            "🚀 Change Background",
            interactive=True,
            variant="primary",
            size="lg",
        )

    def update_visibility(bg_type):
        """Show only the input for the chosen background type (plus handling options for Video)."""
        return (
            gr.update(visible=bg_type == "Color"),  # color_picker
            gr.update(visible=bg_type == "Image"),  # bg_image
            gr.update(visible=bg_type == "Video"),  # bg_video
            gr.update(visible=bg_type == "Video"),  # video_handling_options
        )

    bg_type.change(update_visibility, inputs=bg_type, outputs=[color_picker, bg_image, bg_video, video_handling_options])

    examples = gr.Examples(
        [
            ["rickroll-2sec.mp4", "Video", None, "background.mp4"],
            ["rickroll-2sec.mp4", "Image", "images.webp", None],
            ["rickroll-2sec.mp4", "Color", None, None],
        ],
        inputs=[in_video, bg_type, bg_image, bg_video],
        outputs=[stream_image, out_video],
        fn=fn,
        cache_examples=True,
        cache_mode="eager",
    )

    submit_button.click(
        fn,
        inputs=[in_video, bg_type, bg_image, bg_video, color_picker, fps_slider, video_handling_radio, fast_mode_checkbox, max_workers_slider],
        outputs=[stream_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)