"""Gradio Space demo for DiffuEraser video inpainting.

At import time this script downloads the model weights from the Hugging Face
Hub, builds the ProPainter and DiffuEraser pipelines once, and then serves a
Gradio UI that inpaints the masked region of an uploaded video.
"""
import spaces
import torch
import os
import time
import datetime
from moviepy.editor import VideoFileClip
import gradio as gr

# Download Weights
from huggingface_hub import snapshot_download

# List of subdirectories to create inside "weights"
subfolders = [
    "diffuEraser",
    "stable-diffusion-v1-5",
    "PCM_Weights",
    "propainter",
    "sd-vae-ft-mse",
]

# Create directories
for subfolder in subfolders:
    os.makedirs(os.path.join("weights", subfolder), exist_ok=True)

snapshot_download(repo_id="lixiaowen/diffuEraser", local_dir="./weights/diffuEraser")
snapshot_download(repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5", local_dir="./weights/stable-diffusion-v1-5")
snapshot_download(repo_id="wangfuyun/PCM_Weights", local_dir="./weights/PCM_Weights")
snapshot_download(repo_id="camenduru/ProPainter", local_dir="./weights/propainter")
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./weights/sd-vae-ft-mse")

# Import model classes (after the weights exist on disk)
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device

base_model_path = "weights/stable-diffusion-v1-5"
vae_path = "weights/sd-vae-ft-mse"
diffueraser_path = "weights/diffuEraser"
propainter_model_dir = "weights/propainter"

# Frame rate assumed only when moviepy cannot report the clip's real fps.
FALLBACK_FPS = 30

# Model setup: both pipelines are built once and shared by every request.
device = get_device()
ckpt = "2-Step"
video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
propainter = Propainter(propainter_model_dir, device=device)


# Helper function to trim videos
def trim_video(input_path, output_path, max_duration=5):
    """Write the first ``max_duration`` seconds of ``input_path`` to ``output_path``.

    The output is encoded as H.264 video with AAC audio. Videos shorter than
    ``max_duration`` are copied in full.

    Args:
        input_path: Path of the video to read.
        output_path: Path the trimmed video is written to.
        max_duration: Maximum length in seconds to keep (default 5).
    """
    # Context manager + finally guarantee the ffmpeg readers are released even
    # when subclip() or write_videofile() raises.
    with VideoFileClip(input_path) as clip:
        trimmed_clip = clip.subclip(0, min(max_duration, clip.duration))
        try:
            trimmed_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
        finally:
            trimmed_clip.close()


@spaces.GPU(duration=100)
def infer(input_video, input_mask):
    """Inpaint the masked region of ``input_video`` and return the result path.

    Both inputs are trimmed to their first seconds via ``trim_video``. ProPainter
    then writes a prior video, and DiffuEraser uses it to produce the final output.

    Args:
        input_video: File path of the source video (the ``gr.Video`` value).
        input_mask: File path of the mask video (the ``gr.Video`` value).

    Returns:
        Path of the DiffuEraser output video inside ``results/``.

    Raises:
        gr.Error: If the video or the mask was not provided.
    """
    # Fail with a readable UI message instead of a moviepy traceback.
    if input_video is None or input_mask is None:
        raise gr.Error("Please upload both an input video and a mask video.")

    # Setup paths and parameters
    save_path = "results"
    mask_dilation_iter = 8
    max_img_size = 960
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50
    os.makedirs(save_path, exist_ok=True)

    # Timestamp for unique filenames
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    trimmed_video_path = os.path.join(save_path, f"trimmed_video_{timestamp}.mp4")
    trimmed_mask_path = os.path.join(save_path, f"trimmed_mask_{timestamp}.mp4")
    priori_path = os.path.join(save_path, f"priori_{timestamp}.mp4")
    output_path = os.path.join(save_path, f"diffueraser_result_{timestamp}.mp4")

    # Trim input videos
    trim_video(input_video, trimmed_video_path)
    trim_video(input_mask, trimmed_mask_path)

    # Number of frames to process, from the trimmed clip's real frame rate.
    # (A hard-coded 30 fps under-counts frames for higher-fps uploads.)
    with VideoFileClip(trimmed_video_path) as clip:
        fps = clip.fps or FALLBACK_FPS
        video_length = int(clip.duration * fps)

    # Run models
    start_time = time.perf_counter()
    try:
        # ProPainter (priori)
        propainter.forward(trimmed_video_path, trimmed_mask_path, priori_path,
                           video_length=video_length, ref_stride=ref_stride,
                           neighbor_length=neighbor_length, subvideo_length=subvideo_length,
                           mask_dilation=mask_dilation_iter)

        # DiffuEraser
        # NOTE(review): presumably None selects DiffuEraser's built-in default — confirm.
        guidance_scale = None
        video_inpainting_sd.forward(trimmed_video_path, trimmed_mask_path, priori_path, output_path,
                                    max_img_size=max_img_size, video_length=video_length,
                                    mask_dilation_iter=mask_dilation_iter,
                                    guidance_scale=guidance_scale)
    finally:
        # Release cached GPU memory even when inference fails.
        torch.cuda.empty_cache()

    end_time = time.perf_counter()
    print(f"DiffuEraser inference time: {end_time - start_time:.2f} seconds")
    return output_path


# Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
        gr.HTML("""
Duplicate this Space
""")
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_result = gr.Video(label="Result")
        gr.Examples(
            examples=[
                ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
            ],
            inputs=[input_video, input_mask]
        )
    submit_btn.click(fn=infer, inputs=[input_video, input_mask], outputs=[video_result])

demo.queue().launch(show_api=True, show_error=True, ssr_mode=False)