"""Gradio app that captions sampled video frames with Apple's FastVLM model.

Pipeline: sample frames from an uploaded video with OpenCV, caption each
frame independently with FastVLM, and stream the per-frame captions into a
chat panel while progressively filling a gallery with the analysed frames.
"""

import os
import tempfile
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

MID = "apple/FastVLM-7B"

# Sentinel id spliced into ``input_ids`` at the position of the image.
IMAGE_TOKEN_INDEX = -200

# Textual placeholder that marks where the image goes inside the rendered
# chat template. It is split out of the text and replaced by
# IMAGE_TOKEN_INDEX before generation.
IMAGE_PLACEHOLDER = "<image>"

# Single prompt used for every frame (kept short for faster processing).
DEFAULT_PROMPT = (
    "Provide a brief one-sentence description of what's happening in this image."
)

# Lazily initialised model state (populated by ``load_model``).
tok = None
model = None


def load_model():
    """Load the FastVLM tokenizer and model once and cache them in globals.

    Returns:
        The ``(tokenizer, model)`` pair. Subsequent calls return the cached
        objects without reloading.
    """
    global tok, model
    if tok is None or model is None:
        print("Loading FastVLM model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    return tok, model


def _select_frame_indices(
    total_frames: int, num_frames: int, sampling_method: str
) -> List[int]:
    """Return the frame indices to read for the given sampling strategy.

    The number of indices is capped at ``total_frames`` so that no frame is
    ever selected twice (uniform sampling would otherwise repeat indices
    when more frames are requested than the video contains).
    """
    count = min(num_frames, total_frames)
    if sampling_method == "uniform":
        return np.linspace(0, total_frames - 1, count, dtype=int).tolist()
    if sampling_method == "first":
        return list(range(count))
    if sampling_method == "last":
        return list(range(total_frames - count, total_frames))
    # Any other value is treated as "middle": a contiguous centred window.
    start = (total_frames - count) // 2
    return list(range(start, start + count))


def extract_frames(
    video_path: str, num_frames: int = 8, sampling_method: str = "uniform"
) -> List[Image.Image]:
    """Extract frames from a video as RGB PIL images.

    Args:
        video_path: Path of the video file to read.
        num_frames: Maximum number of frames to extract.
        sampling_method: ``"uniform"``, ``"first"``, ``"last"``; any other
            value selects a contiguous window from the middle of the video.

    Returns:
        The decoded frames in index order. The list is empty when the video
        cannot be opened, reports no frames, or ``num_frames`` is not
        positive. Frames that fail to decode are skipped, so the list may be
        shorter than requested.
    """
    if num_frames <= 0:
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []

        frames = []
        for idx in _select_frame_indices(total_frames, num_frames, sampling_method):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
        return frames
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()


@spaces.GPU(duration=60)
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate a caption for a single frame.

    Args:
        image: The frame to describe.
        prompt: Instruction text placed after the image in the user turn.

    Returns:
        The decoded model output with the echoed prompt (if any) removed and
        surrounding whitespace stripped.
    """
    # Load (or fetch the cached) model on GPU.
    tok, model = load_model()

    # Build the chat with the image placeholder followed by the prompt.
    messages = [
        {"role": "user", "content": f"{IMAGE_PLACEHOLDER}\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split(IMAGE_PLACEHOLDER, 1)

    # Tokenize the text on either side of the image placeholder.
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice in the IMAGE token id between the two text segments.
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    # Preprocess the image with the model's own vision-tower processor.
    px = model.get_vision_tower().image_processor(
        images=image, return_tensors="pt"
    )["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    # Generate.
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )

    caption = tok.decode(out[0], skip_special_tokens=True)

    # Keep only the generated part when the prompt is echoed in the output.
    if prompt in caption:
        caption = caption.split(prompt)[-1]
    return caption.strip()


def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress(),
) -> tuple:
    """Process a video and generate per-frame captions (non-streaming).

    Args:
        video_path: Path of the uploaded video; falsy when nothing is uploaded.
        num_frames: Number of frames to sample.
        sampling_method: Sampling strategy forwarded to ``extract_frames``.
        caption_mode: Accepted for interface compatibility; currently unused.
        custom_prompt: Accepted for interface compatibility; currently unused
            (the brief one-sentence prompt is always used).
        progress: Gradio progress tracker.

    Returns:
        ``(summary_text, frames)``; ``frames`` is ``None`` when the video is
        missing or no frame could be extracted.
    """
    if not video_path:
        return "Please upload a video first.", None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        return "Failed to extract frames from video.", None

    total = len(frames)
    captions = []
    frame_previews = []

    for i, frame in enumerate(frames):
        progress((i + 1) / (total + 1), desc=f"Analyzing frame {i + 1}/{total}...")
        caption = caption_frame(frame, DEFAULT_PROMPT)
        captions.append(f"Frame {i + 1}: {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")

    # Combine captions into a simple narrative.
    full_caption = "\n".join(captions)

    if total > 1:
        video_summary = f"Analyzed {total} frames:\n\n{full_caption}"
    else:
        video_summary = f"Video Analysis:\n\n{full_caption}"

    return video_summary, frame_previews


class AppleTheme(gr.themes.Base):
    """Custom Apple-inspired Gradio theme."""

    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
            spacing_size=gr.themes.sizes.spacing_md,
            radius_size=gr.themes.sizes.radius_md,
            text_size=gr.themes.sizes.text_md,
            font=[
                gr.themes.GoogleFont("Inter"),
                "-apple-system",
                "BlinkMacSystemFont",
                "SF Pro Display",
                "SF Pro Text",
                "Helvetica Neue",
                "Helvetica",
                "Arial",
                "sans-serif",
            ],
            font_mono=[
                # SF Mono is not hosted on Google Fonts, so reference it as a
                # locally installed font instead of requesting it remotely.
                "SF Mono",
                "ui-monospace",
                "Consolas",
                "monospace",
            ],
        )
        super().set(
            # Core colors
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_950",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_600",
            button_primary_text_color="white",
            button_primary_border_color="*primary_500",
            # Shadows
            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
            # Borders
            block_border_width="1px",
            block_border_color="*neutral_200",
            input_border_width="1px",
            input_border_color="*neutral_300",
            input_border_color_focus="*primary_500",
            # Text
            block_title_text_weight="600",
            block_label_text_weight="500",
            block_label_text_size="13px",
            block_label_text_color="*neutral_600",
            body_text_color="*neutral_900",
            # Spacing
            layout_gap="16px",
            block_padding="20px",
            # Specific components
            slider_color="*primary_500",
        )


# NOTE(review): chat entries below are two-item lists whose first item is a
# speaker label ("User"/"Assistant"). This is reproduced from the original
# code; confirm it renders as intended for the Chatbot message format in the
# installed Gradio version.
def handle_upload(video, chat_history):
    """Acknowledge a freshly uploaded video in the chat.

    Returns:
        ``(video, chat_history)`` with two acknowledgement entries appended
        when a video is present, otherwise ``(None, chat_history)``.
    """
    # Work on a copy so the caller's list is never mutated in place.
    chat_history = list(chat_history or [])
    if video:
        chat_history.append(["User", "Video uploaded"])
        chat_history.append(
            ["Assistant", "Video loaded! Click 'Analyze Video' to generate captions."]
        )
        return video, chat_history
    return None, chat_history


def process_video_with_chat(
    video_path,
    num_frames,
    sampling_method,
    caption_mode,
    custom_prompt,
    chat_history,
    progress=gr.Progress(),
):
    """Caption a video frame by frame, streaming results into the chat.

    Generator used as the click handler. Each yield is a
    ``(chat_history, gallery_frames)`` pair; the last chat entry is rewritten
    with the accumulated captions after every analysed frame, and the gallery
    receives the frames analysed so far.

    ``caption_mode`` and ``custom_prompt`` are accepted to match the wired
    inputs but are currently unused.
    """
    # Work on a copy so the caller's list is never mutated in place.
    chat_history = list(chat_history or [])

    if not video_path:
        chat_history.append(["Assistant", "Please upload a video first."])
        yield chat_history, None
        return

    chat_history.append(["User", "Analyzing video..."])
    yield chat_history, None

    # Extract frames.
    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        chat_history.append(["Assistant", "Failed to extract frames from video."])
        yield chat_history, None
        return

    # Start the streaming response with an empty assistant entry.
    chat_history.append(["Assistant", ""])
    total = len(frames)
    captions = []

    for i, frame in enumerate(frames):
        progress((i + 1) / (total + 1), desc=f"Analyzing frame {i + 1}/{total}...")
        caption = caption_frame(frame, DEFAULT_PROMPT)
        captions.append(f"Frame {i + 1}: {caption}\n")

        # Update the last message with the captions accumulated so far and
        # grow the frame gallery progressively.
        current_text = "".join(captions)
        chat_history[-1] = [
            "Assistant",
            f"Analyzing {total} frames:\n\n{current_text}",
        ]
        yield chat_history, frames[: i + 1]

    progress(1.0, desc="Analysis complete!")

    # Final update with the complete message.
    full_caption = "".join(captions)
    chat_history[-1] = ["Assistant", f"Analyzed {total} frames:\n\n{full_caption}"]
    yield chat_history, frames


# Create the Gradio interface with the custom theme.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source — confirm against the intended layout.
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                autoplay=True,
                loop=True,
            )

        # Sidebar with chat interface
        with gr.Sidebar(width=400):
            gr.Markdown("## 💬 Video Analysis Chat")

            chatbot = gr.Chatbot(
                value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
                height=400,
                elem_classes=["chatbot"],
            )

            process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

            with gr.Accordion("🖼️ Analyzed Frames", open=False):
                frame_gallery = gr.Gallery(
                    label="Extracted Frames",
                    show_label=False,
                    columns=2,
                    rows=4,
                    object_fit="contain",
                    height="auto",
                )

    # Hidden parameters with default values
    num_frames = gr.State(value=8)
    sampling_method = gr.State(value="uniform")
    caption_mode = gr.State(value="Brief Summary")
    custom_prompt = gr.State(value="")

    # Upload handler
    video_display.upload(
        handle_upload,
        inputs=[video_display, chatbot],
        outputs=[video_display, chatbot],
    )

    # Process button with streaming
    process_btn.click(
        process_video_with_chat,
        inputs=[
            video_display,
            num_frames,
            sampling_method,
            caption_mode,
            custom_prompt,
            chatbot,
        ],
        outputs=[chatbot, frame_gallery],
        show_progress=True,
    )

demo.launch()