# filename: app.py
import os
import torch
import gradio as gr
from PIL import Image
import spaces

from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,  # ✅ the correct multimodal class for this checkpoint
)

# ========================
# Basic settings
# ========================
torch.manual_seed(100)
MODEL_NAME = "Eric3200/OralGPT-7B-Preview"

# ========================
# Model and processor
# ========================
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype="auto",         # or torch.bfloat16 / torch.float16
    device_map="auto",
    # attn_implementation="flash_attention_2",  # enable if the environment supports it (faster, lower VRAM)
).eval()

# pad_token fallback
try:
    PAD_ID = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
except Exception:
    PAD_ID = None
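
# Passing an explicit pad_token_id to generate() (see qwen_infer below)
# avoids the "Setting `pad_token_id` to `eos_token_id`" warning emitted
# when a tokenizer defines no pad token.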

# ========================
# Convert inputs into Qwen/OralGPT-style content lists
# ========================
def to_qwen_content(user_msg):
    """
    Supported inputs:
      - plain text
      - (img_path, text) tuple
      - mixed list (PIL.Image / text / local path strings)
      - dict {files: [...], text: "..."}
    """
    contents = []
    if isinstance(user_msg, tuple):
        img_path, text = user_msg
        if img_path:
            contents.append({"type": "image", "image": img_path})
        if text:
            contents.append({"type": "text", "text": str(text)})
    elif isinstance(user_msg, list):
        for item in user_msg:
            if isinstance(item, Image.Image):
                contents.append({"type": "image", "image": item})
            elif isinstance(item, str) and os.path.isfile(item):
                contents.append({"type": "image", "image": item})
            else:
                contents.append({"type": "text", "text": str(item)})
    elif isinstance(user_msg, dict):
        if user_msg.get("files"):
            for fp in user_msg["files"]:
                contents.append({"type": "image", "image": fp})
        if user_msg.get("text"):
            contents.append({"type": "text", "text": str(user_msg["text"])})
    else:
        contents.append({"type": "text", "text": str(user_msg)})
    return contents
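
# Illustrative calls (the file name is hypothetical):
#   to_qwen_content(("smile.jpg", "Any caries?"))
#     -> [{"type": "image", "image": "smile.jpg"},
#         {"type": "text", "text": "Any caries?"}]
#   to_qwen_content("Hello")
#     -> [{"type": "text", "text": "Hello"}]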

def build_messages(history, latest_user_msg):
    messages = []
    for u, a in history:
        messages.append({"role": "user", "content": to_qwen_content(u)})
        if a:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(a)}]})
    messages.append({"role": "user", "content": to_qwen_content(latest_user_msg)})
    return messages
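
# Example with a hypothetical one-turn history:
#   build_messages([("Hi", "Hello!")], "Thanks") ->
#   [{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
#    {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
#    {"role": "user", "content": [{"type": "text", "text": "Thanks"}]}]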

# ========================
# ZeroGPU inference (the @spaces.GPU decorator is the key part)
# ========================
@spaces.GPU(duration=120)
def qwen_infer(messages, max_new_tokens=512, do_sample=False, temperature=0.7, top_p=0.9, pad_token_id=None):
    """
    Pipeline:
      1) apply_chat_template → prompt
      2) collect the images referenced in messages
      3) pack inputs with processor(...)
      4) model.generate(...)
      5) decode only the newly generated tokens
    """
    # 1) Render the chat template into a text prompt
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 2) Collect images (this demo does not handle video)
    images = []
    for m in messages:
        for c in m.get("content", []):
            if c.get("type") == "image":
                img_obj = c.get("image")
                if isinstance(img_obj, Image.Image):
                    images.append(img_obj)
                elif isinstance(img_obj, str) and os.path.isfile(img_obj):
                    images.append(Image.open(img_obj).convert("RGB"))

    # 3) Pack the inputs
    inputs = processor(
        text=[text_prompt],
        images=images if images else None,
        padding=True,
        return_tensors="pt",
    )
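
    # For Qwen2.5-VL the returned batch typically contains input_ids and
    # attention_mask, plus pixel_values and image_grid_thw when images
    # are present.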

    # Move tensors to the GPU (or wherever the model was placed)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    else:
        inputs = inputs.to(model.device)

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
    )
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    if pad_token_id is not None:
        gen_kwargs.update(dict(pad_token_id=int(pad_token_id)))

    # 4) Generate
    generated_ids = model.generate(**inputs, **gen_kwargs)

    # 5) Keep only the newly generated tokens, then decode
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    outputs = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return outputs[0].strip() if outputs else "(empty response)"
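
# Minimal sketch of a direct call (text-only; an image turn would add
# {"type": "image", "image": <path or PIL.Image>} items to the content list):
#   qwen_infer(
#       [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
#       max_new_tokens=32,
#   )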

# ========================
# Gradio UI
# ========================
with gr.Blocks(title="OralGPT-7B-Preview — Gradio Demo (ZeroGPU)") as demo:
    gr.Markdown(
        """
        # 🦷 OralGPT-7B-Preview — Multimodal Chat Demo (ZeroGPU)
        Upload dental images and ask questions, or chat in plain text.
        Enable "Thinking (sampling) mode" on the right for more exploratory output.
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                type="tuples"
            )
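
            # With type="tuples", history is a list of (user, assistant)
            # pairs, the same structure that build_messages() consumes.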

            with gr.Row():
                msg = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message or upload one or more images…",
                    show_label=False,
                    container=False
                )

            with gr.Row():
                clear = gr.Button("🗑️ Clear", size="sm")
                submit = gr.Button("📤 Send", variant="primary", size="sm")

        with gr.Column(scale=1):
            enable_thinking = gr.Checkbox(label="Enable Thinking (sampling) mode", value=False)
            max_tokens = gr.Slider(64, 2048, value=512, step=32, label="max_new_tokens")

            gr.Markdown(
                """
                ### Examples
                - "Is there any sign of caries in this image?"
                - "Is there any sign of periodontal disease?"
                - "Describe the suspicious lesion areas in the image."
                """
            )
            gr.Markdown("**Note: this model is for research reference only and must not be used for clinical diagnosis or treatment.**")

    # Event wiring
    def user_submit(message, history, enable_thinking, max_new_tokens):
        # Assemble this turn's user message
        if isinstance(message, dict) and message.get("files"):
            user_msg = list(message["files"])
            if message.get("text", ""):
                user_msg.append(message["text"])
        else:
            user_msg = message.get("text", "") if isinstance(message, dict) else message

        history = history + [(user_msg, None)]
        messages = build_messages(history[:-1], user_msg)

        do_sample = bool(enable_thinking)
        # generate() needs numeric values either way; these defaults only
        # take effect when do_sample is True (see qwen_infer)
        temperature = 0.7
        top_p = 0.9

        try:
            answer = qwen_infer(
                messages=messages,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=PAD_ID,
            )
        except Exception as e:
            answer = f"Error: {str(e)}"

        history[-1] = (history[-1][0], answer if answer else "(empty response)")
        return "", history

    msg.submit(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    submit.click(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    clear.click(lambda: (None, []), outputs=[msg, chatbot])

if __name__ == "__main__":
    demo.launch(share=True)