import os
import re
import shutil
import tempfile

import ffmpeg
import gradio as gr
import numpy as np
import requests
import torch
from decord import VideoReader
from fastapi import FastAPI
from PIL import Image

from llava.constants import DEFAULT_X_TOKEN
from llava.conversation import conv_templates, Conversation
from llava.serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css


# Deterministic decoding defaults. With do_sample=False generation is greedy,
# so temperature and top_p are inert; they are kept explicit for clarity.
GEN_KW = dict(
    do_sample=False,
    temperature=0.0,
    top_p=1.0,
    repetition_penalty=1.15,
    no_repeat_ngram_size=3,
    use_cache=False,
)


def _big_gpu():
    """True if the first CUDA device has at least 40 GB of memory."""
    try:
        return (torch.cuda.is_available() and
                torch.cuda.get_device_properties(0).total_memory / 1024**3 >= 40)
    except Exception:
        return False


MAX_NEW_TOKENS_SMALL = 128
MAX_NEW_TOKENS_BIG = 256
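
# Illustrative sketch only (nothing in this demo calls it verbatim): GEN_KW is
# meant to be splatted into a Hugging Face `generate` call, with the token
# budget picked from the constants above based on available GPU memory, e.g.
#
#   output_ids = model.generate(
#       input_ids,
#       max_new_tokens=MAX_NEW_TOKENS_BIG if _big_gpu() else MAX_NEW_TOKENS_SMALL,
#       **GEN_KW,
#   )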

def _uniform_indices(n_total, n_want):
    """Uniformly spaced integer indices into a sequence of length n_total."""
    if n_total <= 0 or n_want <= 0:
        return []
    return np.linspace(0, n_total - 1, n_want).round().astype(int).tolist()
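
# Example: spreading 4 sample points over a 100-frame clip.
#   _uniform_indices(100, 4)  ->  [0, 33, 66, 99]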

def sample_frames(video_path, n_frames=8):
    """Return (frames_numpy[N,H,W,3], timestamps_sec[N]) sampled uniformly."""
    vr = VideoReader(video_path)
    idx = _uniform_indices(len(vr), n_frames)
    frames = vr.get_batch(idx).asnumpy()
    fps = float(vr.get_avg_fps())
    ts = [i / fps for i in idx]
    return frames, ts


def mmss(s):
    """Format seconds as MM:SS; divmod avoids the 60-second rollover the old
    rounding produced (e.g. 119.7 s formatting as '01:60')."""
    m, ss = divmod(int(round(s)), 60)
    return f"{m:02d}:{ss:02d}"

def fetch_video_from_url(url, out_dir=None, max_seconds=None):
    """Download URL to a local mp4; optionally trim with ffmpeg to the first max_seconds."""
    if out_dir is None:
        out_dir = tempfile.mkdtemp()
    local = os.path.join(out_dir, "input.mp4")
    with requests.get(url, stream=True, timeout=30) as r:
        r.raise_for_status()
        with open(local, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                if chunk:
                    f.write(chunk)
    if (max_seconds is not None) and max_seconds > 0:
        # Stream-copy (no re-encode) the first max_seconds into a new file.
        trimmed = os.path.join(out_dir, "input_trimmed.mp4")
        (
            ffmpeg
            .input(local)
            .output(trimmed, t=max_seconds, c='copy', loglevel="error")
            .overwrite_output()
            .run()
        )
        return trimmed
    return local
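
# Usage sketch (the URL is hypothetical): grab the first 30 s of a remote clip.
#   path = fetch_video_from_url("https://example.com/clip.mp4", max_seconds=30)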

def keep_frame_lines(text, T):
    """Enforce 'Frame i: ...' lines; fill missing frames with placeholders."""
    lines = []
    for ln in text.splitlines():
        m = re.match(r"^Frame\s+(\d+)\s*:\s*(.+)$", ln.strip())
        if not m:
            continue
        i = int(m.group(1))
        body = " ".join(m.group(2).split()[:10])  # cap each line at 10 words
        if 1 <= i <= T:
            lines.append((i, f"Frame {i}: {body}"))
    have = {i for i, _ in lines}
    for i in range(1, T + 1):
        if i not in have:
            lines.append((i, f"Frame {i}: (no description)"))
    return "\n".join(t for _, t in sorted(lines))

def build_framewise_prompt(T):
    return (
        f"You will output exactly {T} plain lines, one per frame.\n"
        "Format strictly:\n"
        "Frame 1: <<=10 words>\n"
        "Frame 2: <<=10 words>\n"
        "...\n"
        "No brackets [], no JSON, no code blocks, no numbered list other than 'Frame i:'."
    )
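
# How these helpers compose (sketch; `describe_frames` is a hypothetical
# stand-in for a model call, not part of this demo):
#   T = 8
#   frames, ts = sample_frames(video_path, n_frames=T)
#   raw = describe_frames(frames, build_framewise_prompt(T))
#   clean = keep_frame_lines(raw, T)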

def save_image_to_local(image):
    # mkstemp avoids relying on the private tempfile._get_candidate_names() API.
    fd, filename = tempfile.mkstemp(dir='temp', suffix='.jpg')
    os.close(fd)
    Image.open(image).save(filename)
    return filename


def save_video_to_local(video_path):
    fd, filename = tempfile.mkstemp(dir='temp', suffix='.mp4')
    os.close(fd)
    shutil.copyfile(video_path, filename)
    return filename


def generate(image1, video, textbox_in, first_run, state, state_, images_tensor):
    # gr.State starts out as None, so initialize the conversation on first use.
    if type(state) is not Conversation:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = [[], []]

    flag = 1
    if not textbox_in:
        if len(state_.messages) > 0:
            # Regenerate path: reuse the last user turn as the prompt.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            # No instruction and no history to regenerate: return the UI unchanged.
            return (state, state_, state.to_gradio_chatbot(), first_run,
                    gr.update(value=None, interactive=True), images_tensor,
                    gr.update(interactive=True), gr.update(interactive=True))

    image1 = image1 if image1 else "none"
    video = video if video else "none"

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    image_processor = handler.image_processor
    video_processor = handler.video_processor

    if os.path.exists(image1) and not os.path.exists(video):
        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0].append(tensor)
        images_tensor[1].append('image')
    elif not os.path.exists(image1) and os.path.exists(video):
        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0].append(tensor)
        images_tensor[1].append('video')
    elif os.path.exists(image1) and os.path.exists(video):
        # Video first, then image, matching the token order in the prompt below.
        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0].append(tensor)
        images_tensor[1].append('video')

        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0].append(tensor)
        images_tensor[1].append('image')

    # Debug: track GPU memory after preprocessing.
    print(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())

    # Prepend the modality placeholder tokens the model expects.
    if os.path.exists(image1) and not os.path.exists(video):
        text_en_in = DEFAULT_X_TOKEN['IMAGE'] + '\n' + text_en_in
    elif not os.path.exists(image1) and os.path.exists(video):
        text_en_in = DEFAULT_X_TOKEN['VIDEO'] + '\n' + text_en_in
    elif os.path.exists(image1) and os.path.exists(video):
        text_en_in = DEFAULT_X_TOKEN['VIDEO'] + '\n' + text_en_in + '\n' + DEFAULT_X_TOKEN['IMAGE']

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out

    # Inline HTML so the chatbot shows the media alongside the user turn.
    show_images = ""
    if os.path.exists(image1):
        filename = save_image_to_local(image1)
        show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
    if os.path.exists(video):
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if flag:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)
    torch.cuda.empty_cache()
    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=image1 if os.path.exists(image1) else None, interactive=True),
            gr.update(value=video if os.path.exists(video) else None, interactive=True))


def regenerate(state, state_):
    # Drop the last assistant turn so `generate` can re-answer the last prompt;
    # guard the pops so an empty history does not raise IndexError.
    if len(state.messages) > 0:
        state.messages.pop(-1)
    if len(state_.messages) > 0:
        state_.messages.pop(-1)
    if len(state.messages) > 0:
        return (state, state_, state.to_gradio_chatbot(), False)
    return (state, state_, state.to_gradio_chatbot(), True)


def clear_history(state, state_):
    state = conv_templates[conv_mode].copy()
    state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, state_, state.to_gradio_chatbot(), [[], []])

conv_mode = "llava_v1"
model_path = 'LanguageBind/Video-LLaVA-7B'
device = 'cuda'
load_8bit = False
load_4bit = True
dtype = torch.float16
# Pass load_4bit through correctly (the original passed load_8bit for both flags).
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device)

os.makedirs("temp", exist_ok=True)

# Debug: GPU memory footprint right after model load.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())

app = FastAPI()

textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            image1 = gr.Image(label="Input Image", type="filepath")
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/extreme_ironing.jpg",
                        "What is unusual about this image?",
                    ],
                    [
                        f"{cur_dir}/examples/waterview.jpg",
                        "What are the things I should be cautious about when I visit here?",
                    ],
                    [
                        f"{cur_dir}/examples/desert.jpg",
                        "If there are factual errors in the questions, point them out; if not, proceed answering the question. What’s happening in the desert?",
                    ],
                ],
                inputs=[image1, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    with gr.Row():
        gr.Examples(
            examples=[
                [
                    f"{cur_dir}/examples/sample_img_8.png",
                    f"{cur_dir}/examples/sample_demo_8.mp4",
                    "Are the image and the video depicting the same place?",
                ],
                [
                    f"{cur_dir}/examples/sample_img_22.png",
                    f"{cur_dir}/examples/sample_demo_22.mp4",
                    "Are the instruments in the picture used in the video?",
                ],
                [
                    f"{cur_dir}/examples/sample_img_13.png",
                    f"{cur_dir}/examples/sample_demo_13.mp4",
                    "Does the flag in the image appear in the video?",
                ],
            ],
            inputs=[image1, video, textbox],
        )
        gr.Examples(
            examples=[
                [
                    f"{cur_dir}/examples/sample_demo_1.mp4",
                    "Why is this video funny?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_7.mp4",
                    "Create a short fairy tale with a moral lesson inspired by the video.",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_8.mp4",
                    "Where is this video taken from? What place/landmark is shown in the video?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_12.mp4",
                    "What does the woman use to split the logs and how does she do it?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_18.mp4",
                    "Describe the video in detail.",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_22.mp4",
                    "Describe the activity in the video.",
                ],
            ],
            inputs=[video, textbox],
        )
    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)

    submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor],
                     [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])

    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate, [image1, video, textbox, first_run, state, state_, images_tensor],
        [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])

    clear_btn.click(clear_history, [state, state_],
                    [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
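    # Note on wiring: `generate` returns an 8-tuple that maps positionally onto
    # [state, state_, chatbot, first_run, textbox, images_tensor, image1, video],
    # so each turn clears the textbox and refreshes both media inputs.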

demo.launch()