Spaces: Running on Zero (file size: 5,448 Bytes)
try:
    import spaces
except ImportError:
    # Fall back to a no-op decorator when the `spaces` package is not
    # available (e.g. when running outside Hugging Face ZeroGPU).
    def spaces_gpu(func):
        return func
    # staticmethod keeps spaces.GPU callable as a plain decorator; a bare
    # function in the class dict would be bound as a method and receive the
    # instance as its first argument.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu)})()
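# On ZeroGPU Spaces, a GPU is attached only for the duration of a call to a
# @spaces.GPU-decorated function, which is why the models below are loaded
# inside the functions rather than once at import time.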
import gradio as gr
import torch
from PIL import Image
import os
import cv2
import numpy as np
from super_image import EdsrModel, ImageLoader
from safetensors.torch import load_file
@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscales a video frame by frame with the EDSR super-resolution model.
    Decorated with @spaces.GPU so it runs on ZeroGPU hardware.
    """
    # Load the model inside the function for ZeroGPU compatibility.
    if scale_factor == 2:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)
    elif scale_factor == 4:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=4)
    else:
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")

    # Use the GPU when one is available (ZeroGPU attaches one for this call).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")
    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    output_width = width * scale_factor
    output_height = height * scale_factor
    output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))

    for _ in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
        ret, frame = video_capture.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; the model expects RGB.
        pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        inputs = ImageLoader.load_image(pil_frame).to(device)
        with torch.no_grad():  # inference only; keeps memory flat across frames
            preds = model(inputs)
        # Convert the (1, C, H, W) float prediction in [0, 1] back to a
        # uint8 RGB frame before handing it to the BGR video writer.
        output_frame = (preds.data.squeeze(0).permute(1, 2, 0).cpu().numpy()
                        * 255.0).clip(0, 255).astype(np.uint8)
        video_writer.write(cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR))

    video_capture.release()
    video_writer.release()
    return output_path
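# Minimal local smoke test for upscale_video (hypothetical file name, assuming
# a short sample clip exists next to this script):
#   print(upscale_video("sample.mp4", scale_factor=2))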
@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Doubles a video's frame rate by inserting one RIFE-interpolated frame
    between every pair of consecutive frames.
    Decorated with @spaces.GPU so it runs on ZeroGPU hardware.
    """
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Load the RIFE model. RIFEModel is assumed to be defined in a local
    # module bundled with the Space; it is not part of a published package.
    from rife_model import RIFEModel  # assumed local module
    model = RIFEModel()
    # Load weights from a path relative to the repository root rather than
    # a machine-specific absolute path.
    model.load_state_dict(load_file("rife_model_new/rife-flownet-4.13.2.safetensors"))
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    output_path = f"interpolated_{os.path.basename(video_path)}"
    # Write at double the source frame rate: each source frame is followed by
    # one interpolated frame.
    video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))

    prev_frame = None
    for _ in progress.tqdm(range(frame_count), desc="Interpolating"):
        ret, frame = video_capture.read()
        if not ret:
            break
        if prev_frame is not None:
            # HWC uint8 -> NCHW float in [0, 1] on the inference device.
            img0 = torch.from_numpy(prev_frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.
            img1 = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.
            with torch.no_grad():
                # model.inference is assumed to return the midpoint frame
                # between img0 and img1 as an NCHW tensor in [0, 1].
                interpolated = model.inference(img0, img1)[0].cpu().numpy().transpose(1, 2, 0) * 255
            # Clip before the uint8 cast to avoid wraparound on out-of-range values.
            video_writer.write(interpolated.clip(0, 255).astype(np.uint8))
        video_writer.write(frame)
        prev_frame = frame

    video_capture.release()
    video_writer.release()
    return output_path
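# Minimal local smoke test for rife_interpolate_video (hypothetical file name;
# requires the RIFE weights and local rife_model module described above):
#   print(rife_interpolate_video("sample.mp4"))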
with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")
    with gr.Tab("Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
                upscale_button = gr.Button("Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")
    with gr.Tab("Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button("Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    upscale_button.click(
        fn=upscale_video,
        inputs=[video_input_upscale, scale_factor],
        outputs=video_output_upscale
    )
    rife_button.click(
        fn=rife_interpolate_video,
        inputs=[video_input_rife],
        outputs=video_output_rife
    )
if __name__ == "__main__":
    demo.launch()  # share=True is unnecessary (and ignored) on Hugging Face Spaces