import gradio as gr
import spaces
import torch
import os
import tempfile
import shutil
import imageio
import logging
from pathlib import Path
import numpy as np
import random

# Import from our modules
from model_loader import ModelLoader, MODELS_ROOT_DIR
from video_processor import VideoProcessor
from config import CAMERA_TRANSFORMATIONS, TEST_DATA_DIR

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model loader instance
model_loader = ModelLoader()
video_processor = None

# Constants
MAX_SEED = np.iinfo(np.int32).max

# Check if running in demo mode
IS_DEMO = os.environ.get("IS_DEMO", "").lower() in ["true", "1", "yes", "on"]
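# Illustrative example: demo mode can be enabled from the environment, e.g.
#   IS_DEMO=true python app.py
# Any of "true", "1", "yes" or "on" (case-insensitive) is accepted.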
# Set limits based on demo mode
MAX_INFERENCE_STEPS = 25 if IS_DEMO else 50
MAX_FRAMES = 49 if IS_DEMO else 81
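# Both frame limits follow the 16n + 1 pattern the frame slider below expects
# (17, 33, 49, 65, 81). A minimal sanity-check sketch, shown for illustration
# only (this helper is an assumption and is not wired into the app):
def is_valid_frame_count(n: int) -> bool:
    """Return True if n matches the 16n + 1 pattern, e.g. 17, 49 or 81."""
    return n >= 17 and (n - 1) % 16 == 0
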
def init_video_processor():
    """Initialize video processor"""
    global video_processor
    if model_loader.is_loaded and video_processor is None:
        video_processor = VideoProcessor(model_loader.pipe)
    return video_processor is not None

def extract_frames_from_video(video_path, output_dir, max_frames=81):
    """Extract frames from a video, padding with the last frame so that
    exactly max_frames frames are written to output_dir as PNGs."""
    os.makedirs(output_dir, exist_ok=True)

    reader = imageio.get_reader(video_path)
    fps = reader.get_meta_data()['fps']
    frames = [frame for frame in reader]
    reader.close()

    # If we have fewer than required frames, repeat the last frame
    if len(frames) < max_frames:
        logger.info(f"Video has {len(frames)} frames, padding to {max_frames} frames")
        last_frame = frames[-1]
        while len(frames) < max_frames:
            frames.append(last_frame)

    # Save at most max_frames frames, with zero-padded filenames
    frames = frames[:max_frames]
    for i, frame in enumerate(frames):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.png")
        imageio.imwrite(frame_path, frame)

    return len(frames), fps
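# Example usage (hypothetical paths, shown for illustration):
#   n_frames, src_fps = extract_frames_from_video("clip.mp4", "/tmp/frames", max_frames=81)
#   # -> writes /tmp/frames/frame_0000.png ... frame_0080.png and returns (81, src_fps)
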
# Generation takes at least 5 minutes, so we cannot use ZeroGPU
# @spaces.GPU(duration=300)
def generate_recammaster_video(
    video_file,
    text_prompt,
    camera_type,
    num_frames,
    resolution,
    seed,
    randomize_seed,
    num_inference_steps,
    cfg_scale,
    progress=gr.Progress()
):
    """Main function to generate video with ReCamMaster"""
    if not model_loader.is_loaded:
        return None, "Error: Models not loaded! Please load models first.", seed

    if not init_video_processor():
        return None, "Error: Failed to initialize video processor.", seed

    if video_file is None:
        return None, "Please upload a video file.", seed

    try:
        # Create a temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            progress(0.1, desc="Processing input video...")

            # Copy the uploaded video into the temp directory
            input_video_path = os.path.join(temp_dir, "input.mp4")
            shutil.copy(video_file, input_video_path)

            # Parse resolution (e.g. "832x480" -> width=832, height=480)
            width, height = map(int, resolution.split('x'))

            # Handle seed
            current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
            logger.info(f"Using seed: {current_seed}")

            # Extract frames
            progress(0.2, desc="Extracting video frames...")
            extracted_frames, fps = extract_frames_from_video(
                input_video_path,
                os.path.join(temp_dir, "frames"),
                max_frames=num_frames
            )
            logger.info(f"Extracted {extracted_frames} frames at {fps} fps")

            # Process with ReCamMaster
            progress(0.3, desc="Processing with ReCamMaster...")
            output_video = video_processor.process_video(
                input_video_path,
                text_prompt,
                camera_type,
                num_frames=num_frames,
                height=height,
                width=width,
                seed=current_seed,
                num_inference_steps=num_inference_steps,
                cfg_scale=cfg_scale
            )

            # Save the output video
            progress(0.9, desc="Saving output video...")
            output_path = os.path.join(temp_dir, "output.mp4")
            from diffsynth import save_video
            save_video(output_video, output_path, fps=30, quality=5)

            # Copy to a persistent location that outlives the temp directory
            # (use a context manager so the temp file handle is not leaked)
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
                final_output_path = tmp.name
            shutil.copy(output_path, final_output_path)

            progress(1.0, desc="Done!")

            transformation_name = CAMERA_TRANSFORMATIONS.get(str(camera_type), "Unknown")
            status_msg = f"Successfully generated video with '{transformation_name}' camera movement! (Seed: {current_seed})"

            return final_output_path, status_msg, current_seed

    except Exception as e:
        logger.error(f"Error generating video: {str(e)}")
        return None, f"Error: {str(e)}", seed
# Create the Gradio interface
with gr.Blocks(title="ReCamMaster") as demo:
    demo_notice = "ℹ️ Due to the long generation times (~10 min for 50 steps of 81 frames), this space has been artificially limited to 25 steps, and [should be duplicated](https://huggingface.co/spaces/jbilcke-hf/ReCamMaster?duplicate=true) to your own account for the best experience (please select at least an Nvidia L40S)." if IS_DEMO else ""

    gr.Markdown(f"""
    # ReCamMaster 🎥

    This is a demo of [ReCamMaster](https://jianhongbai.github.io/ReCamMaster/), an amazing model that allows you to reshoot any video!

    {demo_notice}
    """)

    with gr.Row():
        with gr.Column():
            # Video input section
            with gr.Group():
                gr.Markdown("### 1. Upload a video (about ~3 sec long)")
                video_input = gr.Video(label="Video file or webcam clip")
                text_prompt = gr.Textbox(
                    label="Describe the scene",
                    placeholder="A person walking in the street",
                    value="A dynamic scene"
                )

            # Camera selection
            with gr.Group():
                gr.Markdown("### 2. Decide how to reshoot the scene")
                camera_type = gr.Radio(
                    choices=[(v, k) for k, v in CAMERA_TRANSFORMATIONS.items()],
                    label="New camera angle and movement",
                    value="1"
                )

            # Video settings
            with gr.Group():
                gr.Markdown("### 3. (Optional) Tweak some settings")
                num_frames = gr.Slider(
                    minimum=17,
                    maximum=81,  # MAX_FRAMES
                    value=81,  # MAX_FRAMES
                    step=16,
                    label="Number of Frames",
                    info=f"Must be 16n+1 (17, 33, 49{', 65, 81' if not IS_DEMO else ''})",
                    # Hidden for now: generation currently has a bug unless 81 frames are used
                    visible=False,
                )
                resolution = gr.Dropdown(
                    choices=["832x480", "480x480", "480x832", "576x320", "320x576"],
                    value="832x480",
                    label="Resolution",
                    info="Output video resolution",
                    # Hidden for now: the rest of the code doesn't yet support changing it
                    visible=False
                )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                        interactive=True
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                        interactive=True
                    )
                num_inference_steps = gr.Slider(
                    minimum=10,
                    maximum=MAX_INFERENCE_STEPS,
                    value=min(30, MAX_INFERENCE_STEPS),
                    step=1,
                    label="Inference Steps",
                    info=f"50 steps are recommended but slower{' (demo is limited to 25, duplicate to remove the limit)' if IS_DEMO else ''}"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0,
                    maximum=8.0,
                    value=5.0,
                    step=0.5,
                    label="CFG Scale",
                    info="Controls adherence to the prompt"
                )

            # Generate button
            generate_btn = gr.Button("Generate (will take 6~10 min)", variant="primary")

        with gr.Column():
            # Output section
            output_video = gr.Video(label="Modified video")
            status_output = gr.Textbox(label="Status", interactive=False)

    # Event handlers
    generate_btn.click(
        fn=generate_recammaster_video,
        inputs=[video_input, text_prompt, camera_type, num_frames, resolution, seed, randomize_seed, num_inference_steps, cfg_scale],
        outputs=[output_video, status_output, seed]
    )

if __name__ == "__main__":
    model_loader.load_models()
    demo.launch(share=True)