import os
import types

try:
    # Keep `spaces` ahead of the torch import (ZeroGPU expects it to be imported first).
    import spaces
except ImportError:
    # Running outside Hugging Face Spaces: provide a no-op `spaces.GPU`.
    # A SimpleNamespace attribute is returned as the plain function. The former
    # `type('spaces', (), {'GPU': ...})()` instance turned GPU into a bound
    # method, so `@spaces.GPU` received two arguments and raised TypeError.
    def spaces_gpu(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU``; usable as ``@GPU`` or ``@GPU(...)``."""
        if func is None:
            return lambda wrapped: wrapped
        return func

    spaces = types.SimpleNamespace(GPU=spaces_gpu)

import gradio as gr
import torch
from torchvision.transforms import functional as F
from PIL import Image
import cv2
import numpy as np
from super_image import EdsrModel, ImageLoader
from safetensors.torch import load_file

# Location of the RIFE safetensors weights. Override with the RIFE_MODEL_PATH
# environment variable; the default is the original developer's local path.
RIFE_MODEL_PATH = os.environ.get(
    "RIFE_MODEL_PATH",
    "/Users/craigellenwood/Workspace/video_upscaler_rife_interpolator/rife_model_new/rife-flownet-4.13.2.safetensors",
)


def _select_device():
    """Return the CUDA device if one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _check_input_path(video_path):
    """Raise gr.Error unless ``video_path`` is a non-empty path to an existing file."""
    if not video_path:
        # Guard against an empty value (e.g. nothing uploaded) before touching the filesystem.
        raise gr.Error("No input video provided.")
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")


def _open_capture(video_path):
    """Open ``video_path`` with OpenCV, raising gr.Error if it cannot be opened."""
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        capture.release()
        raise gr.Error(f"Could not open video file {video_path}")
    return capture


def _open_writer(output_path, fps, frame_size):
    """Create an mp4v VideoWriter, raising gr.Error if OpenCV cannot open it."""
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
    if not writer.isOpened():
        writer.release()
        raise gr.Error(f"Could not create output video {output_path}")
    return writer


def _video_info(capture):
    """Return ``(fps, width, height, frame_count)`` as reported by the container."""
    fps = capture.get(cv2.CAP_PROP_FPS)
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # CAP_PROP_FRAME_COUNT is only an estimate for some containers (it may be
    # 0 or wrong), so it is used for progress display only, never to stop reading.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    return fps, width, height, frame_count


def _iter_frames(capture):
    """Yield BGR frames from ``capture`` until ``read()`` reports end of stream."""
    while True:
        ok, frame = capture.read()
        if not ok:
            return
        yield frame


def _frame_to_tensor(frame, device):
    """Convert an HxWx3 uint8 frame to a 1x3xHxW float tensor scaled to [0, 1] on ``device``."""
    return torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.


def _tensor_to_frame(tensor):
    """Convert a 3xHxW float tensor (expected roughly in [0, 1]) to a contiguous HxWx3 uint8 frame."""
    frame = tensor.cpu().numpy().transpose(1, 2, 0) * 255
    # Clip before casting: values outside [0, 255] would overflow uint8.
    return np.ascontiguousarray(np.clip(frame, 0, 255).astype(np.uint8))


@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscales a video using EDSR model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.

    Args:
        video_path: Path to the input video file.
        scale_factor: Upscaling factor; must be 2 or 4.
        progress: Gradio progress tracker.

    Returns:
        Path of the written ``upscaled_<scale>x_<basename>`` video, created in
        the current working directory.

    Raises:
        gr.Error: If the scale factor is invalid, the input is missing or
            unreadable, or the output video cannot be created.
    """
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    scale = int(scale_factor)
    # Validate cheaply before the (slow) model download/load.
    _check_input_path(video_path)

    # Load models inside the function for ZeroGPU compatibility.
    device = _select_device()
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
    model.to(device)
    model.eval()

    capture = _open_capture(video_path)
    writer = None
    try:
        fps, width, height, frame_count = _video_info(capture)
        output_path = f"upscaled_{scale}x_{os.path.basename(video_path)}"
        writer = _open_writer(output_path, fps, (width * scale, height * scale))
        frames = progress.tqdm(
            _iter_frames(capture),
            total=frame_count if frame_count > 0 else None,
            desc=f"Upscaling {scale}x",
        )
        # Inference only: no_grad avoids building an autograd graph per frame.
        with torch.no_grad():
            for frame in frames:
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = ImageLoader.load_image(pil_frame).to(device)
                preds = model(inputs)
                output_frame = ImageLoader.to_pil(preds.cpu())
                writer.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))
    finally:
        # Release both handles even if a frame fails part-way through.
        capture.release()
        if writer is not None:
            writer.release()
    return output_path


@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Interpolates a video using the RIFE model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.

    One synthesized frame is inserted between each pair of consecutive input
    frames. For N input frames the output therefore has 2N-1 frames, written at
    twice the input fps.

    Args:
        video_path: Path to the input video file.
        progress: Gradio progress tracker.

    Returns:
        Path of the written ``interpolated_<basename>`` video, created in the
        current working directory.

    Raises:
        gr.Error: If the input or the RIFE weights are missing, the input is
            unreadable, or the output video cannot be created.
    """
    _check_input_path(video_path)
    if not os.path.exists(RIFE_MODEL_PATH):
        raise gr.Error(f"RIFE weights not found at {RIFE_MODEL_PATH}; set RIFE_MODEL_PATH.")

    # Load the RIFE model
    device = _select_device()
    # NOTE(review): RIFEModel is not imported or defined in this file, so this
    # raises NameError until the RIFE implementation is imported.
    model = RIFEModel()
    model.load_state_dict(load_file(RIFE_MODEL_PATH))
    model.eval()
    model.to(device)

    capture = _open_capture(video_path)
    writer = None
    try:
        fps, width, height, frame_count = _video_info(capture)
        output_path = f"interpolated_{os.path.basename(video_path)}"
        writer = _open_writer(output_path, fps * 2, (width, height))
        frames = progress.tqdm(
            _iter_frames(capture),
            total=frame_count if frame_count > 0 else None,
            desc="Interpolating",
        )
        prev_tensor = None
        with torch.no_grad():
            for frame in frames:
                # Convert each frame once; it is reused as the next pair's first input.
                cur_tensor = _frame_to_tensor(frame, device)
                if prev_tensor is not None:
                    # Assumes inference() returns a batched (1, 3, H, W) tensor — TODO confirm
                    # against the RIFE implementation; [0] drops the batch axis.
                    middle = model.inference(prev_tensor, cur_tensor)[0]
                    writer.write(_tensor_to_frame(middle))
                writer.write(frame)
                prev_tensor = cur_tensor
    finally:
        # Release both handles even if inference fails part-way through.
        capture.release()
        if writer is not None:
            writer.release()
    return output_path


with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")

    with gr.Tab("Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
                upscale_button = gr.Button("Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")

    with gr.Tab("Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button("Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    upscale_button.click(
        fn=upscale_video,
        inputs=[video_input_upscale, scale_factor],
        outputs=video_output_upscale,
    )

    rife_button.click(
        fn=rife_interpolate_video,
        inputs=[video_input_rife],
        outputs=video_output_rife,
    )

if __name__ == "__main__":
    demo.launch(share=True)