import os
from typing import Tuple

import gradio as gr
from PIL import Image

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    VisionEncoderDecoderModel,
    TrOCRProcessor,
)
from huggingface_hub import login

# Optional: login via repo secret HF_TOKEN in Spaces
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
    except Exception:
        pass
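
# The token is only needed if either model repository is gated or private; the default
# public checkpoints below load without authentication.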

TITLE = "Picture to Problem Solver"
DESCRIPTION = (
    "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question.\n\n"
    "⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
)

# ---------------------------
# Load OCR (TrOCR)
# ---------------------------
OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "microsoft/trocr-base-printed")
ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID)
ocr_model.eval()
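# Note: trocr-base-printed targets printed, line-level text; for handwriting, point
# OCR_MODEL_ID at a different checkpoint (e.g. microsoft/trocr-base-handwritten) via the
# environment variable, no code change needed.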

# ---------------------------
# Load MobileLLM
# ---------------------------
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")

device = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer bfloat16 on GPUs that support it; float32 is the safe default everywhere else.
dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
# Ensure pad token exists to prevent warnings during generation
if llm_tokenizer.pad_token_id is None and llm_tokenizer.eos_token_id is not None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    torch_dtype=dtype,  # 'torch_dtype' for compatibility with older transformers; recent releases also accept 'dtype'
    low_cpu_mem_usage=True,
    device_map="auto" if device == "cuda" else None,
)
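# device_map="auto" relies on the accelerate package being installed (include it in requirements.txt).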
llm_model.eval()
if device == "cpu":
    llm_model.to(device)

# Defensive fallback: nearly every causal-LM tokenizer already defines an EOS token,
# but register one if it is missing so generation can terminate cleanly.
eos_token_id = llm_tokenizer.eos_token_id
if eos_token_id is None:
    llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
    llm_model.resize_token_embeddings(len(llm_tokenizer))
    eos_token_id = llm_tokenizer.eos_token_id

SYSTEM_INSTRUCTION = (
    "You are a precise, step-by-step technical assistant. "
    "You excel at math, programming (Python, C++), and scientific reasoning. "
    "Be concise, show steps when helpful, and avoid hallucinations."
)

USER_PROMPT_TEMPLATE = (
    "Extracted text from the image:\n"
    "-----------------------------\n"
    "{ocr_text}\n"
    "-----------------------------\n"
    "{question_hint}"
)

def build_prompt(ocr_text: str, user_question: str) -> str:
    if user_question and user_question.strip():
        q = f"User question: {user_question.strip()}"
    else:
        q = "Please summarize the key information and explain any math/code/science content."
    return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
        ocr_text=(ocr_text or "").strip() or "(no text detected)",
        question_hint=q,
    )
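
# Sketch (assumption, not verified against this checkpoint): if the MobileLLM tokenizer
# ships a chat template, the same content could be formatted with it instead of plain
# string concatenation:
#   messages = [
#       {"role": "system", "content": SYSTEM_INSTRUCTION},
#       {"role": "user", "content": USER_PROMPT_TEMPLATE.format(ocr_text=..., question_hint=...)},
#   ]
#   prompt = llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)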

@torch.inference_mode()
def run_pipeline(
    image: Image.Image,
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> Tuple[str, str]:
    if image is None:
        return "", "Please upload an image."

    # --- OCR ---
    try:
        rgb_image = image.convert("RGB")  # uploads may be RGBA or grayscale; TrOCR expects 3 channels
        pixel_values = ocr_processor(images=rgb_image, return_tensors="pt").pixel_values
        ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
        extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()
    except Exception as e:
        return "", f"OCR failed: {e}"

    # --- Build prompt ---
    prompt = build_prompt(extracted_text, question)

    # --- LLM Inference ---
    try:
        inputs = llm_tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(llm_model.device) for k, v in inputs.items()}

        do_sample = temperature > 0
        generation_kwargs = dict(
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            eos_token_id=eos_token_id,
            pad_token_id=llm_tokenizer.pad_token_id if llm_tokenizer.pad_token_id is not None else eos_token_id,
        )
        if do_sample:
            # Sampling parameters are only meaningful (and warning-free) when do_sample is True.
            generation_kwargs["temperature"] = max(0.05, min(temperature, 1.5))
            generation_kwargs["top_p"] = max(0.1, min(top_p, 1.0))

        output_ids = llm_model.generate(**inputs, **generation_kwargs)
        # Decode only the newly generated tokens; re-decoding the full sequence and string-stripping
        # the prompt is fragile because tokenization round-trips are not always exact.
        prompt_len = inputs["input_ids"].shape[1]
        gen_text = llm_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        gen_text = f"LLM inference failed: {e}"

    return extracted_text, gen_text
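
# Quick smoke test without the UI (hypothetical file name):
#   text, answer = run_pipeline(Image.open("sample.png"), "Summarize the key numbers.")
#   print(text, answer, sep="\n")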

def demo_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                question = gr.Textbox(
                    label="Ask a question about the image (optional)",
                    placeholder="e.g., Summarize, extract key numbers, explain this formula, convert code to Python...",
                )
                with gr.Accordion("Generation settings (advanced)", open=False):
                    max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")

                run_btn = gr.Button("Run")

            with gr.Column(scale=1):
                ocr_out = gr.Textbox(label="Extracted Text (OCR)", lines=8)
                llm_out = gr.Markdown(label="AI Answer", elem_id="ai-answer")

        run_btn.click(
            run_pipeline,
            inputs=[image_input, question, max_new_tokens, temperature, top_p],
            outputs=[ocr_out, llm_out],
        )

        gr.Markdown(
            "—\n**Licensing reminder:** facebook/MobileLLM-R1-950M is typically released for non-commercial research use. "
            "Review the model card before production use."
        )

    return demo

if __name__ == "__main__":
    demo = demo_ui()
    demo.launch()
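    # On Spaces the launch() defaults are fine; for a LAN-reachable local run, something like
    # demo.launch(server_name="0.0.0.0", server_port=7860) is a common alternative.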