# tohid4n's picture
# Update app.py
# 96b242d verified
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from pydub import AudioSegment
from PIL import Image
import numpy as np
import os
import tempfile
import uuid
import time
from concurrent.futures import ThreadPoolExecutor
from moviepy import VideoFileClip, vfx, concatenate_videoclips, ImageSequenceClip
# Let PyTorch trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision("medium")

# Prefer the GPU when one is available; everything below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_segmentation_model(repo_id):
    """Download the segmentation checkpoint ``repo_id`` and move it to ``device``."""
    model = AutoModelForImageSegmentation.from_pretrained(repo_id, trust_remote_code=True)
    model.to(device)
    return model


# Two BiRefNet variants are kept resident: the full model and the lite one.
birefnet = _load_segmentation_model("ZhengPeng7/BiRefNet")
birefnet_lite = _load_segmentation_model("ZhengPeng7/BiRefNet_lite")

# Preprocessing applied to every frame before it is fed to either model:
# fixed 768x768 resize, tensor conversion, then per-channel normalisation
# (the constants match the widely used ImageNet mean / std).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((768, 768)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)
def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
    """Replace the background of a single video frame.

    Args:
        frame: Frame array accepted by ``Image.fromarray``.
        bg_type: ``"Color"``, ``"Image"`` or ``"Video"``. Any other value
            leaves the frame untouched.
        bg: Background handed to ``process`` when ``bg_type`` is ``"Image"``.
        fast_mode: Forwarded to ``process`` to pick the model variant.
        bg_frame_index: Position in ``background_frames`` to use when
            ``bg_type`` is ``"Video"``.
        background_frames: Sequence of background frame arrays; only read
            when ``bg_type`` is ``"Video"``.
        color: Background colour handed to ``process`` when ``bg_type`` is
            ``"Color"``.

    Returns:
        A ``(frame_array, bg_frame_index)`` tuple. The index is advanced by
        one after a background video frame has been consumed. If anything
        fails, the error is printed and the *original* frame is returned so a
        single bad frame does not abort the whole video.
    """
    try:
        pil_image = Image.fromarray(frame)
        if bg_type == "Color":
            processed_image = process(pil_image, color, fast_mode)
        elif bg_type == "Image":
            processed_image = process(pil_image, bg, fast_mode)
        elif bg_type == "Video":
            # The background may yield a frame or two fewer than the input
            # (duration rounding). Clamp to the last available frame instead
            # of raising IndexError and silently emitting an unprocessed frame.
            last_index = len(background_frames) - 1
            background_frame = background_frames[min(bg_frame_index, last_index)]
            bg_frame_index += 1
            background_image = Image.fromarray(background_frame)
            processed_image = process(pil_image, background_image, fast_mode)
        else:
            # Unknown background type: pass the frame through unchanged.
            processed_image = pil_image
        return np.array(processed_image), bg_frame_index
    except Exception as e:
        # Best-effort per-frame processing: report and fall back to the input.
        print(f"Error processing frame: {e}")
        return frame, bg_frame_index
def _retime_to_duration(clip, target_duration):
    """Return ``clip`` re-timed so that it lasts ``target_duration`` seconds."""
    if hasattr(vfx, "MultiplySpeed"):
        # MoviePy 2.x: effects are classes applied through ``with_effects``.
        return clip.with_effects([vfx.MultiplySpeed(final_duration=target_duration)])
    # MoviePy 1.x: function-style effect applied through ``fx``.
    return clip.fx(vfx.speedx, final_duration=target_duration)


@spaces.GPU
def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True, max_workers=10):
    """Replace the background of a video, streaming progress to the UI.

    This is a generator wired to two Gradio outputs (streaming image, final
    video). It first shows the streaming image, then yields every processed
    frame as it completes, and finally yields the last frame together with
    the path of the rendered ``.mp4``.

    Args:
        vid: Path of the input video.
        bg_type: ``"Color"``, ``"Image"`` or ``"Video"``.
        bg_image: Background image (used when ``bg_type`` is ``"Image"``).
        bg_video: Path of the background video (used when ``bg_type`` is
            ``"Video"``).
        color: Background colour (used when ``bg_type`` is ``"Color"``).
        fps: Output frame rate; ``0`` keeps the input video's frame rate.
        video_handling: What to do when the background video is shorter than
            the input: ``"slow_down"`` stretches it to the input duration,
            anything else loops it.
        fast_mode: Forwarded to ``process_frame`` to select the model variant.
        max_workers: Number of threads used to process frames.

    Yields:
        ``(stream_image_value, out_video_value)`` tuples.

    Raises:
        gr.Error: If anything goes wrong; the message carries the cause.
    """
    video = None
    background_source = None
    try:
        if not vid:
            raise ValueError("No input video was provided.")
        if bg_type == "Image" and not bg_image:
            raise ValueError("A background image is required for the 'Image' background type.")
        if bg_type == "Video" and not bg_video:
            raise ValueError("A background video is required for the 'Video' background type.")

        video = VideoFileClip(vid)
        if fps == 0:
            fps = video.fps
        audio = video.audio
        frames = list(video.iter_frames(fps=fps))

        # Switch the UI to the streaming preview while frames are processed.
        yield gr.update(visible=True), gr.update(visible=False)

        background_frames = None
        if bg_type == "Video":
            background_source = VideoFileClip(bg_video)
            background_video = background_source
            if background_video.duration < video.duration:
                if video_handling == "slow_down":
                    # Stretch the background so it spans the whole input.
                    # (A speed factor of video/background would be > 1 and
                    # make the background even shorter.)
                    background_video = _retime_to_duration(background_video, video.duration)
                else:  # video_handling == "loop"
                    repeats = int(video.duration / background_video.duration + 1)
                    background_video = concatenate_videoclips([background_video] * repeats)
            background_frames = list(background_video.iter_frames(fps=fps))

        processed_frames = []
        with ThreadPoolExecutor(max_workers=max(1, int(max_workers))) as executor:
            # Each frame is paired with the background frame at the same index.
            futures = [
                executor.submit(process_frame, frame, bg_type, bg_image, fast_mode, index, background_frames, color)
                for index, frame in enumerate(frames)
            ]
            # Consume in submission order so the output keeps the frame order.
            for future in futures:
                result, _ = future.result()
                processed_frames.append(result)
                yield result, None

        processed_video = ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.with_audio(audio)

        # Reserve a path, then let MoviePy write to it once the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_filepath = temp_file.name
        processed_video.write_videofile(temp_filepath, codec="libx264")

        yield gr.update(visible=False), gr.update(visible=True)
        yield processed_frames[-1], temp_filepath
    except Exception as e:
        print(f"Error: {e}")
        # Restore the default layout, then surface the failure through Gradio
        # instead of pushing an error string into the video component.
        yield gr.update(visible=False), gr.update(visible=True)
        raise gr.Error(f"Error processing video: {e}") from e
    finally:
        # Release the file readers held by the source clips.
        for clip in (background_source, video):
            if clip is not None:
                clip.close()
def process(image, bg, fast_mode=False):
    """Composite the segmented foreground of ``image`` over a new background.

    Args:
        image: PIL image to segment.
        bg: The replacement background. A string starting with ``#`` is read
            as a ``#RRGGBB`` colour, a PIL image is used directly, and anything
            else is passed to ``Image.open``.
        fast_mode: Use ``birefnet_lite`` when true, ``birefnet`` otherwise.

    Returns:
        The composited PIL image, sized like ``image``.
    """
    target_size = image.size
    segmenter = birefnet_lite if fast_mode else birefnet
    batch = transform_image(image).unsqueeze(0).to(device)

    # Inference only: the last model output, squashed by a sigmoid, is the mask.
    with torch.no_grad():
        predictions = segmenter(batch)[-1].sigmoid().cpu()
    alpha = predictions[0].squeeze()
    alpha_mask = transforms.ToPILImage()(alpha).resize(target_size)

    if isinstance(bg, Image.Image):
        backdrop = bg.convert("RGBA").resize(target_size)
    elif isinstance(bg, str) and bg.startswith("#"):
        # Two hex digits per channel, starting right after the '#'.
        red, green, blue = (int(bg[pos:pos + 2], 16) for pos in (1, 3, 5))
        backdrop = Image.new("RGBA", target_size, (red, green, blue, 255))
    else:
        backdrop = Image.open(bg).convert("RGBA").resize(target_size)

    return Image.composite(image, backdrop, alpha_mask)
# Dark-mode colour overrides layered on top of the stock Soft theme.
_DARK_MODE_OVERRIDES = {
    "body_background_fill_dark": "#0f0f23",
    "block_background_fill_dark": "#1a1b2e",
    "block_border_color_dark": "#16213e",
    "input_background_fill_dark": "#16213e",
    "button_primary_background_fill_dark": "#6366f1",
    "button_primary_background_fill_hover_dark": "#4f46e5",
    "button_secondary_background_fill_dark": "#374151",
    "button_secondary_background_fill_hover_dark": "#4b5563",
}

# Custom theme used by the Blocks app below.
_base_theme = gr.themes.Soft(primary_hue="purple", secondary_hue="blue", neutral_hue="slate")
dark_theme = _base_theme.set(**_DARK_MODE_OVERRIDES)
# NOTE(review): the container nesting below is reconstructed from the section
# comments — confirm it against the intended layout.
with gr.Blocks(theme=dark_theme) as demo:
    # Top row: input, live preview (hidden until processing starts), result.
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")

    # Settings panels placed below the videos.
    with gr.Column():
        # Background type and output frame rate.
        with gr.Row():
            bg_type = gr.Radio(["Color", "Image", "Video"], label="Background Type", value="Color", interactive=True)
            fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 = original fps)", interactive=True)

        # Background inputs; only the one matching the selected type is shown.
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", visible=True, interactive=True)
        bg_image = gr.Image(label="Background Image", type="filepath", visible=False, interactive=True)
        bg_video = gr.Video(label="Background Video", visible=False, interactive=True)

        # Extra option that only applies to video backgrounds.
        with gr.Row(visible=False) as video_handling_options:
            video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)

        # Processing options.
        with gr.Row():
            fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True)
            max_workers_slider = gr.Slider(minimum=1, maximum=32, step=1, value=10, label="Max Workers", info="Parallel processing threads", interactive=True)

        submit_button = gr.Button("🚀 Change Background", interactive=True, variant="primary", size="lg")

    def update_visibility(bg_type):
        """Return visibility updates for the colour, image, video and handling inputs."""
        # Flags are ordered like the outputs of the ``change`` binding below.
        flags_by_type = {
            "Color": (True, False, False, False),
            "Image": (False, True, False, False),
            "Video": (False, False, True, True),
        }
        flags = flags_by_type.get(bg_type, (False, False, False, False))
        return tuple(gr.update(visible=flag) for flag in flags)

    bg_type.change(
        update_visibility,
        inputs=bg_type,
        outputs=[color_picker, bg_image, bg_video, video_handling_options],
    )

    # One example per background type; results are cached eagerly.
    example_rows = [
        ["rickroll-2sec.mp4", "Video", None, "background.mp4"],
        ["rickroll-2sec.mp4", "Image", "images.webp", None],
        ["rickroll-2sec.mp4", "Color", None, None],
    ]
    examples = gr.Examples(
        example_rows,
        inputs=[in_video, bg_type, bg_image, bg_video],
        outputs=[stream_image, out_video],
        fn=fn,
        cache_examples=True,
        cache_mode="eager",
    )

    # Run the background replacement with every setting from the panels above.
    submit_button.click(
        fn,
        inputs=[
            in_video,
            bg_type,
            bg_image,
            bg_video,
            color_picker,
            fps_slider,
            video_handling_radio,
            fast_mode_checkbox,
            max_workers_slider,
        ],
        outputs=[stream_image, out_video],
    )
# Launch the app only when this file is executed directly (not on import).
# show_error=True asks Gradio to display errors in the UI.
if __name__ == "__main__":
    demo.launch(show_error=True)