import os
# ---- Hub download settings (apply before any HF imports) ----
os.environ["HF_HUB_ENABLE_XET"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_ENABLE_RESUME"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
# Use the compact HF-format LLaVA model
MODEL_ID = "xtuner/llava-phi-3-mini-hf"
# Device + dtype
if torch.cuda.is_available():
    TORCH_DTYPE = torch.float16
else:
    TORCH_DTYPE = torch.float32


def load_model():
    """
    Load the LLaVA model and its processor.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- Robustness: ensure processor carries vision attrs expected by LLaVA ----
    vcfg = getattr(model.config, "vision_config", None)
    if not hasattr(processor, "patch_size") or processor.patch_size is None:
        # CLIP-L/336 typically uses patch_size=14; default to 14 if missing
        processor.patch_size = getattr(vcfg, "patch_size", 14)
    if (
        not hasattr(processor, "vision_feature_select_strategy")
        or processor.vision_feature_select_strategy is None
    ):
        processor.vision_feature_select_strategy = getattr(
            model.config, "vision_feature_select_strategy", "default"
        )
    if (
        not hasattr(processor, "num_additional_image_tokens")
        or processor.num_additional_image_tokens is None
    ):
        # CLIP ViT uses a single CLS token
        processor.num_additional_image_tokens = 1

    return model, processor


# Load once at import
MODEL, PROCESSOR = load_model()
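# Optional startup report (a minimal sketch, not part of the app's logic):
# logs where the weights landed and which dtype is in use, which helps when
# checking CPU-vs-GPU behaviour on a Space.
print(f"Loaded {MODEL_ID} on device={MODEL.device} with dtype={TORCH_DTYPE}")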


def answer_question(image: Image.Image, question: str) -> str:
    """
    Generate an answer about the uploaded image.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."

    try:
        # ---- Preferred: chat-template path (handles image + text cleanly) ----
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question.strip()},
            ],
        }]
        inputs = PROCESSOR.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            images=[image],
        )
    except Exception:
        # ---- Fallback: legacy prompt with <image> placeholder ----
        prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
        inputs = PROCESSOR(
            images=image,
            text=prompt,
            return_tensors="pt",
        )

    # Move all tensors to the model's device
    inputs = {k: (v.to(MODEL.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens so the echoed prompt is not
    # returned alongside the answer.
    prompt_len = inputs["input_ids"].shape[1]
    text = PROCESSOR.batch_decode(
        generated_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
    return text.strip()
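
# Illustrative programmatic use of answer_question (the image path is
# hypothetical, not part of the app):
#     img = Image.open("example.jpg")
#     print(answer_question(img, "What is shown in this picture?"))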


def build_interface() -> gr.Interface:
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses **xtuner/llava-phi-3-mini-hf** (LLaVA in HF format) "
        "to perform visual question answering. Note: a GPU is recommended; "
        "CPU inference will be slow."
    )
    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering (LLaVA Phi-3 Mini)",
        description=description,
        flagging_mode="never",
    )


def main() -> None:
    iface = build_interface()
    iface.launch()


if __name__ == "__main__":
    main()
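
# Note: on Hugging Face Spaces the default launch() settings are usually
# sufficient. If slow CPU inference causes request pile-ups, one option
# (a sketch, not used above) is to enable Gradio's request queue:
#     build_interface().queue().launch()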