import os

# ---- Hub download settings (apply before any HF imports) ----
os.environ["HF_HUB_ENABLE_XET"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_HUB_ENABLE_RESUME"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image

# Use the compact HF-format LLaVA model
MODEL_ID = "xtuner/llava-phi-3-mini-hf"

# Device + dtype
if torch.cuda.is_available():
    TORCH_DTYPE = torch.float16
else:
    TORCH_DTYPE = torch.float32


def load_model():
    """
    Load the LLaVA model and its processor.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
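        # device_map="auto" requires the `accelerate` package to be installed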
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- Robustness: ensure processor carries vision attrs expected by LLaVA ----
    vcfg = getattr(model.config, "vision_config", None)

    if not hasattr(processor, "patch_size") or processor.patch_size is None:
        # CLIP-L/336 typically uses patch_size=14; default to 14 if missing
        processor.patch_size = getattr(vcfg, "patch_size", 14)

    if (
        not hasattr(processor, "vision_feature_select_strategy")
        or processor.vision_feature_select_strategy is None
    ):
        processor.vision_feature_select_strategy = getattr(
            model.config, "vision_feature_select_strategy", "default"
        )

    if (
        not hasattr(processor, "num_additional_image_tokens")
        or processor.num_additional_image_tokens is None
    ):
        # CLIP ViT uses a single CLS token
        processor.num_additional_image_tokens = 1

    return model, processor


# Load once at import
MODEL, PROCESSOR = load_model()


def answer_question(image: Image.Image, question: str) -> str:
    """
    Generate an answer about the uploaded image.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."

    try:
        # ---- Preferred: chat-template path (handles image + text cleanly) ----
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question.strip()},
            ],
        }]

        inputs = PROCESSOR.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            images=[image],
        )
    except Exception:
        # ---- Fallback: legacy prompt with <image> placeholder ----
        prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
        inputs = PROCESSOR(
            images=image,
            text=prompt,
            return_tensors="pt",
        )

    # Move all tensors to the model's device
    inputs = {k: (v.to(MODEL.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens so the echoed prompt is not
    # returned as part of the answer.
    prompt_len = inputs["input_ids"].shape[1]
    text = PROCESSOR.batch_decode(
        generated_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]

    return text.strip()


def build_interface() -> gr.Interface:
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses **xtuner/llava-phi-3-mini-hf** (LLaVA in HF format) "
        "to perform visual question answering. Note: a GPU is recommended; "
        "CPU inference will be slow."
    )
    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering (LLaVA Phi-3 Mini)",
        description=description,
        flagging_mode="never",
    )


def main() -> None:
    iface = build_interface()
    iface.launch()


if __name__ == "__main__":
    main()
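

# Minimal sketch of calling the model without the Gradio UI (assumes a local
# image at the hypothetical path "example.jpg"):
#
#     from PIL import Image
#     img = Image.open("example.jpg").convert("RGB")
#     print(answer_question(img, "What objects are visible in this image?"))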