import os

# ---- Hub download settings (apply before any HF imports) ----
os.environ["HF_HUB_ENABLE_XET"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_ENABLE_RESUME"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image

# Use the compact HF-format LLaVA model
MODEL_ID = "xtuner/llava-phi-3-mini-hf"

# Device + dtype
if torch.cuda.is_available():
    TORCH_DTYPE = torch.float16
else:
    TORCH_DTYPE = torch.float32


def load_model():
    """
    Load the LLaVA model and its processor.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- Robustness: ensure processor carries vision attrs expected by LLaVA ----
    vcfg = getattr(model.config, "vision_config", None)
    if not hasattr(processor, "patch_size") or processor.patch_size is None:
        # CLIP-L/336 typically uses patch_size=14; default to 14 if missing
        processor.patch_size = getattr(vcfg, "patch_size", 14)
    if (
        not hasattr(processor, "vision_feature_select_strategy")
        or processor.vision_feature_select_strategy is None
    ):
        processor.vision_feature_select_strategy = getattr(
            model.config, "vision_feature_select_strategy", "default"
        )
    if (
        not hasattr(processor, "num_additional_image_tokens")
        or processor.num_additional_image_tokens is None
    ):
        # CLIP ViT uses a single CLS token
        processor.num_additional_image_tokens = 1

    return model, processor


# Load once at import
MODEL, PROCESSOR = load_model()


def answer_question(image: Image.Image, question: str) -> str:
    """
    Generate an answer about the uploaded image.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."

    try:
        # ---- Preferred: chat-template path (handles image + text cleanly) ----
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question.strip()},
            ],
        }]
        inputs = PROCESSOR.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            images=[image],
        )
    except Exception:
        # ---- Fallback: legacy prompt with <image> placeholder ----
        prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
        inputs = PROCESSOR(
            images=image,
            text=prompt,
            return_tensors="pt",
        )

    # Move all tensors to the model's device
    inputs = {k: (v.to(MODEL.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
    return text.strip()


def build_interface() -> gr.Interface:
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses **xtuner/llava-phi-3-mini-hf** (LLaVA in HF format) "
        "to perform visual question answering. Note: a GPU is recommended; "
        "CPU inference will be slow."
    )
    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering (LLaVA Phi-3 Mini)",
        description=description,
        flagging_mode="never",
    )


def main() -> None:
    iface = build_interface()
    iface.launch()


if __name__ == "__main__":
    main()
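
# ---------------------------------------------------------------------------
# Usage note (sketch, not part of the app logic; the filename below is
# illustrative, not prescribed by this script):
#
#     python app.py
#
# starts the local Gradio server. For slow CPU inference it may help to enable
# Gradio's request queue by replacing `iface.launch()` in main() with
# `iface.queue().launch()`; passing `share=True` to launch() additionally
# creates a temporary public link.
# ---------------------------------------------------------------------------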