Captain / app.py
mrbeliever's picture
Update app.py
f07f9e1 verified
raw
history blame
2.71 kB
from typing import Any
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
# Baseline generation settings: sampling disabled, at most 256 new tokens.
DEFAULT_PARAMS = dict(do_sample=False, max_new_tokens=256)

# Fixed captioning prompt. The sentences are joined with single spaces, so the
# resulting string has no leading, trailing, or doubled whitespace.
DEFAULT_QUERY = " ".join(
    (
        "Provide a factual description of this image in up to two paragraphs.",
        "Include details on objects, background, scenery, interactions, gestures, poses, and any visible text content.",
        "Specify the number of repeated objects.",
        "Describe the dominant colors, color contrasts, textures, and materials.",
        "Mention the composition, including the arrangement of elements and focus points.",
        "Note the camera angle or perspective, and provide any identifiable contextual information.",
        "Include details on the style, lighting, and shadows.",
        "Avoid subjective interpretations or speculation.",
    )
)

# Floating-point dtype shared by the model weights and the image tensors.
DTYPE = torch.bfloat16

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The tokenizer is loaded from the Vicuna checkpoint rather than from the
# CogVLM repository that supplies the model weights below.
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

# Load the CogVLM chat checkpoint in DTYPE and place it on DEVICE in one
# expression. trust_remote_code=True allows the modelling code shipped in the
# checkpoint repository to be executed.
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=DTYPE,
).to(device=DEVICE)
@spaces.GPU
@torch.no_grad()
def generate_caption(
    image: Image.Image,
    params: dict[str, Any] = DEFAULT_PARAMS,
) -> str:
    """Describe ``image`` with the module-level model and return the caption.

    The prompt is always ``DEFAULT_QUERY`` and the conversation history is
    empty, so every call is an independent single-turn request.

    Args:
        image: PIL image to caption. At runtime this may be ``None`` when the
            UI button is pressed without an image; that case is rejected.
        params: Extra keyword arguments forwarded to ``model.generate``. The
            default is the shared module-level dict; it is only read here,
            never mutated.

    Returns:
        The decoded continuation with the "This image showcases" boilerplate
        and the trailing ``</s>`` marker removed, surrounding whitespace
        stripped, and the first character upper-cased.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail early with a readable message instead of an opaque error from
        # inside the model's input builder.
        raise gr.Error("Please provide an image before generating a caption.")

    features = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=DEFAULT_QUERY,
        history=[],
        images=[image],
    )
    # unsqueeze(0) adds a leading dimension of size 1 to each tensor before it
    # is moved to DEVICE; the image tensor is additionally cast to DTYPE.
    inputs = {
        "input_ids": features["input_ids"].unsqueeze(0).to(device=DEVICE),
        "token_type_ids": features["token_type_ids"].unsqueeze(0).to(device=DEVICE),
        "attention_mask": features["attention_mask"].unsqueeze(0).to(device=DEVICE),
        "images": [[features["images"][0].to(device=DEVICE, dtype=DTYPE)]],
    }

    outputs = model.generate(**inputs, **params)
    # Drop the first len(input_ids) positions so only the tokens after the
    # prompt are decoded (presumably generate() echoes the prompt first).
    outputs = outputs[:, inputs["input_ids"].shape[1] :]

    result = tokenizer.decode(outputs[0])
    result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip()

    # Upper-case only the first character. str.capitalize() would also
    # lower-case the rest of the text, flattening proper nouns, acronyms,
    # later sentence starts, and any quoted visible text in the caption.
    return result[:1].upper() + result[1:]
# Two-column layout: image upload and trigger button on the left, generated
# caption on the right. (Nesting reconstructed from the component order.)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            run_button = gr.Button("Generate Caption")
        with gr.Column():
            output_caption = gr.Textbox(
                show_copy_button=True,
                label="Generated Caption",
            )

    # Event wiring stays inside the Blocks context. Only the image is passed,
    # so generate_caption runs with its default generation params.
    run_button.click(generate_caption, [input_image], output_caption)

# Serve the app without creating a public share link.
demo.launch(share=False)