# FastVLM-7B / app.py
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
from typing import Optional
import tempfile
import os
import spaces
MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200
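# Note: -200 is a placeholder id spliced into input_ids where the image features are
# injected (LLaVA-style convention used by FastVLM's remote code; see caption_frame below).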
# Initialize model variables
tok = None
model = None
def load_model():
    global tok, model
    if tok is None or model is None:
        print("Loading FastVLM model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    return tok, model
def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract frames from video"""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0:
        cap.release()
        return []
    frames = []
    if sampling_method == "uniform":
        # Uniform sampling
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        # Take first N frames
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        # Take last N frames
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:  # middle
        # Take frames from the middle
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
    cap.release()
    return frames
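# The decorated function below runs in a Hugging Face ZeroGPU context: each call
# may hold a GPU for up to 60 seconds, and the model is loaded lazily inside it.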
@spaces.GPU(duration=60)
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate caption for a single frame"""
    # Load model on GPU
    tok, model = load_model()
    # Build chat with custom prompt
    messages = [
        {"role": "user", "content": f"<image>\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)
    # Tokenize the text around the image token
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    # Splice in the IMAGE token id
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)
    # Preprocess image
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)
    # Generate
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )
    caption = tok.decode(out[0], skip_special_tokens=True)
    # Extract only the generated part
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()
    return caption
def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress()
) -> tuple:
    """Process video and generate captions"""
    if not video_path:
        return "Please upload a video first.", None
    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)
    if not frames:
        return "Failed to extract frames from video.", None
    # Use brief one-sentence prompt for faster processing
    prompt = "Provide a brief one-sentence description of what's happening in this image."
    captions = []
    frame_previews = []
    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"Frame {i + 1}: {caption}")
        frame_previews.append(frame)
    progress(1.0, desc="Generating summary...")
    # Combine captions into a simple narrative
    full_caption = "\n".join(captions)
    # Generate overall summary if multiple frames
    if len(frames) > 1:
        video_summary = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
    else:
        video_summary = f"Video Analysis:\n\n{full_caption}"
    return video_summary, frame_previews
# Custom Apple-inspired Gradio theme
class AppleTheme(gr.themes.Base):
    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
            spacing_size=gr.themes.sizes.spacing_md,
            radius_size=gr.themes.sizes.radius_md,
            text_size=gr.themes.sizes.text_md,
            font=[
                gr.themes.GoogleFont("Inter"),
                "-apple-system",
                "BlinkMacSystemFont",
                "SF Pro Display",
                "SF Pro Text",
                "Helvetica Neue",
                "Helvetica",
                "Arial",
                "sans-serif"
            ],
            font_mono=[
                gr.themes.GoogleFont("SF Mono"),
                "ui-monospace",
                "Consolas",
                "monospace"
            ]
        )
        super().set(
            # Core colors
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_950",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_600",
            button_primary_text_color="white",
            button_primary_border_color="*primary_500",
            # Shadows
            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
            # Borders
            block_border_width="1px",
            block_border_color="*neutral_200",
            input_border_width="1px",
            input_border_color="*neutral_300",
            input_border_color_focus="*primary_500",
            # Text
            block_title_text_weight="600",
            block_label_text_weight="500",
            block_label_text_size="13px",
            block_label_text_color="*neutral_600",
            body_text_color="*neutral_900",
            # Spacing
            layout_gap="16px",
            block_padding="20px",
            # Specific components
            slider_color="*primary_500",
        )
# Create the Gradio interface with the custom theme
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                autoplay=True,
                loop=True
            )

    # Sidebar with chat interface
    with gr.Sidebar(width=400):
        gr.Markdown("## πŸ’¬ Video Analysis Chat")
        chatbot = gr.Chatbot(
            value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
            height=400,
            elem_classes=["chatbot"]
        )
        process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")
        with gr.Accordion("πŸ–ΌοΈ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(
                label="Extracted Frames",
                show_label=False,
                columns=2,
                rows=4,
                object_fit="contain",
                height="auto"
            )
    # Hidden parameters with default values
    num_frames = gr.State(value=8)
    sampling_method = gr.State(value="uniform")
    caption_mode = gr.State(value="Brief Summary")
    custom_prompt = gr.State(value="")
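    # Note: caption_mode and custom_prompt are passed to the streaming handler below
    # but are not currently used by it; they appear to be kept for future wiring.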
    # Upload handler
    def handle_upload(video, chat_history):
        if video:
            chat_history.append(["User", "Video uploaded"])
            chat_history.append(["Assistant", "Video loaded! Click 'Analyze Video' to generate captions."])
            return video, chat_history
        return None, chat_history

    video_display.upload(
        handle_upload,
        inputs=[video_display, chatbot],
        outputs=[video_display, chatbot]
    )
    # Modified process function to update chatbot with streaming
    def process_video_with_chat(video_path, num_frames, sampling_method, caption_mode, custom_prompt, chat_history, progress=gr.Progress()):
        if not video_path:
            chat_history.append(["Assistant", "Please upload a video first."])
            yield chat_history, None
            return

        chat_history.append(["User", "Analyzing video..."])
        yield chat_history, None

        # Extract frames
        progress(0, desc="Extracting frames...")
        frames = extract_frames(video_path, num_frames, sampling_method)
        if not frames:
            chat_history.append(["Assistant", "Failed to extract frames from video."])
            yield chat_history, None
            return

        # Start streaming response
        chat_history.append(["Assistant", ""])
        prompt = "Provide a brief one-sentence description of what's happening in this image."
        captions = []
        for i, frame in enumerate(frames):
            progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
            caption = caption_frame(frame, prompt)
            frame_caption = f"Frame {i + 1}: {caption}\n"
            captions.append(frame_caption)
            # Update the last message with accumulated captions
            current_text = "".join(captions)
            chat_history[-1] = ["Assistant", f"Analyzing {len(frames)} frames:\n\n{current_text}"]
            yield chat_history, frames[:i + 1]  # Also update frame gallery progressively

        progress(1.0, desc="Analysis complete!")
        # Final update with complete message
        full_caption = "".join(captions)
        final_message = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
        chat_history[-1] = ["Assistant", final_message]
        yield chat_history, frames
    # Process button with streaming
    process_btn.click(
        process_video_with_chat,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True
    )

demo.launch()