import os

# ---- Hub download settings (apply before any HF imports) ----
os.environ["HF_HUB_ENABLE_XET"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"
# Fast downloads; requires the hf_transfer package to be installed
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_ENABLE_RESUME"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
# Use the compact HF-format LLaVA model
MODEL_ID = "xtuner/llava-phi-3-mini-hf"

# Device + dtype
if torch.cuda.is_available():
    TORCH_DTYPE = torch.float16
else:
    TORCH_DTYPE = torch.float32
def load_model():
    """
    Load the LLaVA model and its processor.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- Robustness: ensure processor carries vision attrs expected by LLaVA ----
    vcfg = getattr(model.config, "vision_config", None)
    if not hasattr(processor, "patch_size") or processor.patch_size is None:
        # CLIP-L/336 typically uses patch_size=14; default to 14 if missing
        processor.patch_size = getattr(vcfg, "patch_size", 14)
    if (
        not hasattr(processor, "vision_feature_select_strategy")
        or processor.vision_feature_select_strategy is None
    ):
        processor.vision_feature_select_strategy = getattr(
            model.config, "vision_feature_select_strategy", "default"
        )
    if (
        not hasattr(processor, "num_additional_image_tokens")
        or processor.num_additional_image_tokens is None
    ):
        # CLIP ViT uses a single CLS token
        processor.num_additional_image_tokens = 1

    return model, processor


# Load once at import
MODEL, PROCESSOR = load_model()
def answer_question(image: Image.Image, question: str) -> str:
    """
    Generate an answer about the uploaded image.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."

    try:
        # ---- Preferred: chat-template path (handles image + text cleanly) ----
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question.strip()},
            ],
        }]
        inputs = PROCESSOR.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            images=[image],
        )
    except Exception:
        # ---- Fallback: legacy prompt with <image> placeholder ----
        prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
        inputs = PROCESSOR(
            images=image,
            text=prompt,
            return_tensors="pt",
        )

    # Move all tensors to the model's device
    inputs = {k: (v.to(MODEL.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    prompt_len = inputs["input_ids"].shape[1]
    text = PROCESSOR.batch_decode(
        generated_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
    return text.strip()
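

# A minimal local smoke test (a sketch, not wired into the Space UI). It assumes
# an image file named "example.jpg" sits next to this script; adjust the path
# and the question as needed:
#
#     from PIL import Image
#     img = Image.open("example.jpg").convert("RGB")
#     print(answer_question(img, "What is shown in this picture?"))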
def build_interface() -> gr.Interface:
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses **xtuner/llava-phi-3-mini-hf** (LLaVA in HF format) "
        "to perform visual question answering. Note: a GPU is recommended; "
        "CPU inference will be slow."
    )
    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering (LLaVA Phi-3 Mini)",
        description=description,
        flagging_mode="never",
    )
def main() -> None:
    iface = build_interface()
    iface.launch()


if __name__ == "__main__":
    main()
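
# To try the app locally (assuming this file is saved as app.py and the
# dependencies gradio, torch, transformers, and pillow are installed):
#
#     python app.py
#
# then open the local URL that Gradio prints in a browser.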