# Author: inoculatemedia
# Commit 308d965 ("Update app.py", verified)
try:
    import spaces
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package is not installed, so
    # provide a stand-in whose `GPU` decorator leaves functions unchanged.
    import types

    def spaces_gpu(func=None, **_kwargs):
        """No-op substitute for ``spaces.GPU``.

        Works both bare (``@spaces.GPU``) and called with arguments
        (``@spaces.GPU(duration=60)``); keyword arguments are ignored.

        Args:
            func: The function being decorated, or None when used as a
                decorator factory.

        Returns:
            ``func`` unchanged, or an identity decorator when ``func`` is None.
        """
        if func is None:
            # Called as a decorator factory: return an identity decorator.
            return lambda f: f
        return func

    # A SimpleNamespace attribute is not bound as a method on access. An
    # instance of `type('spaces', (), {'GPU': spaces_gpu})` turns `spaces.GPU`
    # into a bound method, so `@spaces.GPU` would pass two arguments to a
    # one-argument function and raise TypeError at import time.
    spaces = types.SimpleNamespace(GPU=spaces_gpu)
import gradio as gr
import torch
from torchvision.transforms import functional as F
from PIL import Image
import os
import cv2
import numpy as np
from super_image import EdsrModel, ImageLoader
from safetensors.torch import load_file
@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscale a video frame by frame with the EDSR super-resolution model.

    Decorated with ``@spaces.GPU`` so it can run on Hugging Face ZeroGPU. The
    model is loaded inside the call (for ZeroGPU compatibility) and placed on
    CUDA when available, otherwise on the CPU.

    Args:
        video_path: Path to the input video (the value produced by ``gr.Video``).
        scale_factor: Upscaling factor, 2 or 4.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path of the upscaled ``.mp4`` video (``mp4v`` codec) written to the
        current working directory.

    Raises:
        gr.Error: On an invalid scale factor, a missing/unreadable input, an
            undeterminable frame rate, or an output file that cannot be created.
    """
    # Validate cheap preconditions before fetching/loading model weights.
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    # gr.Video yields None when nothing was uploaded; os.path.exists(None)
    # would raise TypeError instead of a user-facing error.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")
    scale = int(scale_factor)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
    model.to(device)
    model.eval()

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    video_writer = None
    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # The reported frame count can be missing or inaccurate for some
        # containers, so it only sizes the progress bar; reading stops when
        # the capture runs out of frames.
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0:
            raise gr.Error(f"Could not determine the frame rate of {video_path}")

        # mp4v is an MPEG-4 codec, so always write an .mp4 container even when
        # the input uses another extension.
        stem = os.path.splitext(os.path.basename(video_path))[0]
        output_path = f"upscaled_{scale}x_{stem}.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            output_path, fourcc, fps, (width * scale, height * scale)
        )
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video {output_path}")

        def read_frames():
            """Yield BGR frames until the capture is exhausted."""
            while True:
                ok, frame = video_capture.read()
                if not ok:
                    return
                yield frame

        frames = progress.tqdm(
            read_frames(),
            desc=f"Upscaling {scale}x",
            total=frame_count if frame_count > 0 else None,
        )
        # Inference only: skip autograd bookkeeping for every frame.
        with torch.no_grad():
            for frame in frames:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # 1 x 3 x H x W float tensor in [0, 1] on the model's device.
                # NOTE(review): presumably matches super_image's
                # ImageLoader.load_image preprocessing; converting explicitly
                # avoids a PIL round trip and the `ImageLoader.to_pil` helper,
                # which may not exist in every super_image version.
                inputs = (
                    torch.from_numpy(rgb)
                    .permute(2, 0, 1)
                    .unsqueeze(0)
                    .float()
                    .div(255.0)
                    .to(device)
                )
                preds = model(inputs)
                # Clamp before the uint8 cast so values slightly outside [0, 1]
                # saturate instead of wrapping around.
                out_rgb = (
                    preds[0]
                    .clamp(0.0, 1.0)
                    .mul(255.0)
                    .round()
                    .to(torch.uint8)
                    .permute(1, 2, 0)
                    .contiguous()
                    .cpu()
                    .numpy()
                )
                video_writer.write(cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Release handles even if decoding, inference or writing fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Double a video's frame rate by inserting one RIFE-interpolated frame
    between every pair of consecutive frames.

    Decorated with ``@spaces.GPU`` so it can run on Hugging Face ZeroGPU. The
    weights are read from the ``RIFE_MODEL_PATH`` environment variable,
    falling back to the path the app was originally developed with.

    Args:
        video_path: Path to the input video (the value produced by ``gr.Video``).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path of the interpolated ``.mp4`` video (``mp4v`` codec, written at
        twice the input frame rate) in the current working directory.

    Raises:
        gr.Error: On a missing/unreadable input, missing weights, an
            undeterminable frame rate, or an output file that cannot be created.
    """
    # gr.Video yields None when nothing was uploaded; os.path.exists(None)
    # would raise TypeError instead of a user-facing error.
    if not video_path or not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # The default is an absolute path on the author's machine; RIFE_MODEL_PATH
    # lets other deployments point at their own copy of the weights.
    model_path = os.environ.get(
        "RIFE_MODEL_PATH",
        "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
    )
    if not os.path.exists(model_path):
        raise gr.Error(
            f"RIFE model weights not found at {model_path}; set RIFE_MODEL_PATH."
        )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # NOTE(review): RIFEModel is neither imported nor defined in the visible
    # source; it must be provided elsewhere or this line raises NameError.
    model = RIFEModel()
    model.load_state_dict(load_file(model_path))
    model.eval()
    if use_cuda:
        model.cuda()

    # NOTE(review): RIFE flow networks typically need spatial sizes divisible
    # by their internal downsampling; Practical-RIFE's inference script pads
    # to a multiple of 128. Assumed to apply to RIFEModel too -- confirm.
    pad_multiple = 128

    def to_model_input(frame):
        """Convert an H x W x 3 uint8 frame to a zero-padded 1 x 3 x H' x W' float tensor in [0, 1]."""
        h, w = frame.shape[:2]
        padded_h = ((h - 1) // pad_multiple + 1) * pad_multiple
        padded_w = ((w - 1) // pad_multiple + 1) * pad_multiple
        tensor = (
            torch.from_numpy(frame)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
            .div(255.0)
            .to(device)
        )
        # Pad right/bottom only, so cropping [:h, :w] recovers the frame region.
        # (`F` in this file is torchvision's functional module, hence the full path.)
        return torch.nn.functional.pad(tensor, (0, padded_w - w, 0, padded_h - h))

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    video_writer = None
    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # The reported frame count can be missing or inaccurate for some
        # containers, so it only sizes the progress bar.
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0:
            raise gr.Error(f"Could not determine the frame rate of {video_path}")

        # mp4v is an MPEG-4 codec, so always write an .mp4 container.
        stem = os.path.splitext(os.path.basename(video_path))[0]
        output_path = f"interpolated_{stem}.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # One new frame per original pair -> twice the frame rate keeps duration.
        video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video {output_path}")

        def read_frames():
            """Yield BGR frames until the capture is exhausted."""
            while True:
                ok, frame = video_capture.read()
                if not ok:
                    return
                yield frame

        frames = progress.tqdm(
            read_frames(),
            desc="Interpolating",
            total=frame_count if frame_count > 0 else None,
        )
        with torch.no_grad():
            prev_input = None
            for frame in frames:
                h, w = frame.shape[:2]
                # Frames stay in OpenCV's BGR order end to end, as before.
                cur_input = to_model_input(frame)
                if prev_input is not None:
                    # inference(...)[0] is C x H' x W'; crop away the padding.
                    mid = model.inference(prev_input, cur_input)[0][:, :h, :w]
                    # Clamp and round before the uint8 cast so values outside
                    # [0, 1] saturate instead of wrapping around.
                    mid_frame = (
                        mid.clamp(0.0, 1.0)
                        .mul(255.0)
                        .round()
                        .to(torch.uint8)
                        .permute(1, 2, 0)
                        .contiguous()
                        .cpu()
                        .numpy()
                    )
                    video_writer.write(mid_frame)
                video_writer.write(frame)
                # Reuse this frame's tensor as the next pair's first input.
                prev_input = cur_input
    finally:
        # Release handles even if decoding, inference or writing fails.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
# Gradio UI: one tab per tool. Each tab sends an input video (plus options)
# to `upscale_video` or `rife_interpolate_video` and shows the returned file.
with gr.Blocks() as demo:
    gr.Markdown(value="# Video Upscaler and Frame Interpolator")

    with gr.Tab("Upscale"):
        with gr.Row():
            # Left column: source video, scale choice and trigger button.
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio(choices=[2, 4], value=2, label="Scale Factor")
                upscale_button = gr.Button(value="Upscale Video")
            # Right column: the upscaled result.
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")

    with gr.Tab("Interpolate"):
        with gr.Row():
            # Left column: source video and trigger button.
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button(value="Interpolate Frames")
            # Right column: the interpolated result.
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    # Event wiring (positional order: fn, inputs, outputs).
    upscale_button.click(
        upscale_video,
        [video_input_upscale, scale_factor],
        video_output_upscale,
    )
    rife_button.click(
        rife_interpolate_video,
        [video_input_rife],
        video_output_rife,
    )

if __name__ == "__main__":
    # share=True also requests a public share link when run outside Spaces.
    demo.launch(share=True)