# Spaces: Running on Zero
try:
    import spaces
except ImportError:
    # Create a dummy decorator if spaces is not available
    def spaces_gpu(func=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when the package is missing.

        Works both as a bare decorator (``@spaces.GPU``) and in the
        parameterized form (``@spaces.GPU(duration=...)``); in either case
        the decorated function is returned unchanged.
        """
        if func is None:
            # Called with keyword options only: hand back a decorator.
            return lambda wrapped: wrapped
        return func

    # ``staticmethod`` is required here: a plain function stored in the class
    # dict would be bound on instance access, so ``spaces.GPU(func)`` would
    # receive the instance as its first argument and raise TypeError.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu)})()
import gradio as gr | |
import torch | |
from torchvision.transforms import functional as F | |
from PIL import Image | |
import os | |
import cv2 | |
import numpy as np | |
from super_image import EdsrModel, ImageLoader | |
from safetensors.torch import load_file | |
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscale a video frame by frame with the EDSR super-resolution model.

    NOTE(review): no ``@spaces.GPU`` decorator is applied to this function in
    the visible source, and the model is never moved to a GPU, so inference
    runs on whatever device the model loads on - confirm whether ZeroGPU
    execution is intended.

    Args:
        video_path: Path of the input video file.
        scale_factor: Upscaling factor; only 2 and 4 are accepted.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``upscaled_{scale}x_{input basename}``,
        relative to the current working directory.

    Raises:
        gr.Error: If the scale factor is invalid, no input was given, the
            input is missing or unreadable, or the output cannot be created.
    """
    # Validate the cheap things first so the model is not loaded for a
    # request that is going to fail anyway.
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    # Normalise e.g. 2.0 -> 2 so the frame size passed to OpenCV stays integral.
    scale_factor = int(scale_factor)

    if not video_path:
        # The handler receives None when no video has been uploaded.
        raise gr.Error("No input video provided.")
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        # Load the model inside the function for ZeroGPU compatibility
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale_factor)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_size = (width * scale_factor, height * scale_factor)
        output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, output_size)
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        description = f"Upscaling {scale_factor}x"
        progress((0, frame_count if frame_count > 0 else None), desc=description)

        frames_done = 0
        # Inference only: skip autograd bookkeeping for every frame.
        with torch.no_grad():
            # Read until the stream is exhausted instead of trusting the
            # container's reported frame count, which may be zero or wrong.
            while True:
                ret, frame = video_capture.read()
                if not ret:
                    break

                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = ImageLoader.load_image(pil_frame)
                preds = model(inputs)

                # NOTE(review): the tensor is converted directly rather than
                # through ``ImageLoader.to_pil``, which does not appear in
                # super_image's documented ImageLoader API - confirm.
                # Assumes ``preds`` is (1, 3, H, W) RGB in the 0..1 range,
                # mirroring what ``load_image`` produces - TODO confirm.
                upscaled = preds[0].detach().cpu().numpy().transpose(1, 2, 0)
                upscaled = np.clip(upscaled * 255.0, 0, 255).round().astype(np.uint8)
                video_writer.write(
                    cv2.cvtColor(np.ascontiguousarray(upscaled), cv2.COLOR_RGB2BGR)
                )

                frames_done += 1
                # Drop the total once it is exceeded so progress never passes 100%.
                total = frame_count if frame_count >= frames_done else None
                progress((frames_done, total), desc=description)
    finally:
        # Always release the OpenCV handles, including on errors.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Double a video's frame rate by inserting one RIFE-interpolated frame
    between every pair of consecutive frames.

    NOTE(review): no ``@spaces.GPU`` decorator is applied to this function in
    the visible source - confirm whether one is needed for ZeroGPU.
    NOTE(review): ``RIFEModel`` is neither defined nor imported in the visible
    part of this file - confirm where it comes from.

    Args:
        video_path: Path of the input video file.
        progress: Gradio progress tracker used to report per-frame progress.

    Returns:
        Path of the written video, ``interpolated_{input basename}``, relative
        to the current working directory. It is written at twice the input
        frame rate.

    Raises:
        gr.Error: If no input was given, the input or the model weights are
            missing, the input is unreadable, or the output cannot be created.
    """
    if not video_path:
        # The handler receives None when no video has been uploaded.
        raise gr.Error("No input video provided.")
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # The weights location can be overridden with RIFE_MODEL_PATH; the default
    # is the original hard-coded path, so existing setups keep working.
    weights_path = os.environ.get(
        "RIFE_MODEL_PATH",
        "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
    )
    if not os.path.exists(weights_path):
        raise gr.Error(f"RIFE model weights not found at {weights_path}")

    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the RIFE model
    model = RIFEModel()
    model.load_state_dict(load_file(weights_path))
    model.eval()
    if device.type == "cuda":
        model.cuda()

    video_capture = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        if not video_capture.isOpened():
            raise gr.Error(f"Could not open video file {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        output_path = f"interpolated_{os.path.basename(video_path)}"
        # Twice the frame rate: one synthetic frame is added per source pair.
        video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))
        if not video_writer.isOpened():
            raise gr.Error(f"Could not create output video file {output_path}")

        description = "Interpolating"
        progress((0, frame_count if frame_count > 0 else None), desc=description)

        prev_tensor = None
        frames_done = 0
        # Read until the stream is exhausted instead of trusting the
        # container's reported frame count, which may be zero or wrong.
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break

            # Preprocess: HWC uint8 -> 1xCxHxW float in 0..1 on the target device.
            frame_tensor = (
                torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.
            )

            if prev_tensor is not None:
                # Run inference
                with torch.no_grad():
                    middle = model.inference(prev_tensor, frame_tensor)[0]
                middle = middle.cpu().numpy().transpose(1, 2, 0) * 255
                # Clip before the cast: uint8 conversion would otherwise wrap
                # values slightly outside 0..255 around.
                middle = np.clip(middle, 0, 255).astype(np.uint8)
                video_writer.write(np.ascontiguousarray(middle))

            video_writer.write(frame)
            # Reuse this frame's tensor as the "previous" input next time
            # instead of converting and uploading the same frame twice.
            prev_tensor = frame_tensor

            frames_done += 1
            # Drop the total once it is exceeded so progress never passes 100%.
            total = frame_count if frame_count >= frames_done else None
            progress((frames_done, total), desc=description)
    finally:
        # Always release the OpenCV handles, including on errors.
        video_capture.release()
        if video_writer is not None:
            video_writer.release()

    return output_path
# Two-tab Gradio interface: one tab per processing function, each with an
# input column (controls) next to an output column (result video).
with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")

    # --- Upscaling tab ---
    with gr.Tab("Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
                upscale_button = gr.Button("Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")

    # --- Frame interpolation tab ---
    with gr.Tab("Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button("Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    # Wire each button to its handler: (fn, inputs, outputs).
    upscale_button.click(
        upscale_video,
        [video_input_upscale, scale_factor],
        video_output_upscale,
    )
    rife_button.click(
        rife_interpolate_video,
        [video_input_rife],
        video_output_rife,
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)