import os

import torch
import gradio as gr
from PIL import Image
import spaces

from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)

# Fix the seed so sampled ("Thinking") outputs are reproducible across runs.
torch.manual_seed(100)

MODEL_NAME = "Eric3200/OralGPT-7B-Preview"

# The processor bundles the chat template, tokenizer, and image preprocessing.
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

# torch_dtype="auto" keeps the checkpoint's dtype; device_map="auto" places the
# model on the available GPU(s), falling back to CPU.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
).eval()

# Prefer the tokenizer's pad token, falling back to EOS; None if neither exists.
try:
    PAD_ID = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
except Exception:
    PAD_ID = None


def to_qwen_content(user_msg):
    """
    Normalize one user message into Qwen chat "content" entries.

    Supported shapes:
    - plain text
    - (img_path, text) tuples
    - mixed lists (PIL.Image objects / text / local file-path strings)
    - dicts of the form {"files": [...], "text": "..."}
    """
    contents = []
    if isinstance(user_msg, tuple):
        img_path, text = user_msg
        if img_path:
            contents.append({"type": "image", "image": img_path})
        if text:
            contents.append({"type": "text", "text": str(text)})
    elif isinstance(user_msg, list):
        for item in user_msg:
            if isinstance(item, Image.Image):
                contents.append({"type": "image", "image": item})
            elif isinstance(item, str) and os.path.isfile(item):
                contents.append({"type": "image", "image": item})
            else:
                contents.append({"type": "text", "text": str(item)})
    elif isinstance(user_msg, dict):
        if user_msg.get("files"):
            for fp in user_msg["files"]:
                contents.append({"type": "image", "image": fp})
        if user_msg.get("text"):
            contents.append({"type": "text", "text": str(user_msg["text"])})
    else:
        contents.append({"type": "text", "text": str(user_msg)})
    return contents
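
# Illustrative mappings (sketch; "xray.png" is a hypothetical file name):
#   to_qwen_content("hello")
#       -> [{"type": "text", "text": "hello"}]
#   to_qwen_content({"files": ["xray.png"], "text": "Any caries?"})
#       -> [{"type": "image", "image": "xray.png"},
#           {"type": "text", "text": "Any caries?"}]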


def build_messages(history, latest_user_msg):
    """Convert tuple-style chat history plus the newest message into Qwen chat messages."""
    messages = []
    for u, a in history:
        messages.append({"role": "user", "content": to_qwen_content(u)})
        if a:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(a)}]})
    messages.append({"role": "user", "content": to_qwen_content(latest_user_msg)})
    return messages
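
# Illustrative result for one completed turn plus a new message (sketch):
#   build_messages([("hi", "hello!")], "thanks")
#       -> [{"role": "user", "content": [{"type": "text", "text": "hi"}]},
#           {"role": "assistant", "content": [{"type": "text", "text": "hello!"}]},
#           {"role": "user", "content": [{"type": "text", "text": "thanks"}]}]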


# ZeroGPU: allocate a GPU for this call, for at most 120 seconds.
@spaces.GPU(duration=120)
def qwen_infer(messages, max_new_tokens=512, do_sample=False, temperature=0.7, top_p=0.9, pad_token_id=None):
    """
    Flow:
    1) apply_chat_template -> text prompt
    2) collect the images referenced in `messages`
    3) pack text + images with processor(...)
    4) model.generate(...)
    5) decode only the newly generated tokens
    """
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Gather images in the order they appear in the conversation.
    images = []
    for m in messages:
        for c in m.get("content", []):
            if c.get("type") == "image":
                img_obj = c.get("image")
                if isinstance(img_obj, Image.Image):
                    images.append(img_obj)
                elif isinstance(img_obj, str) and os.path.isfile(img_obj):
                    images.append(Image.open(img_obj).convert("RGB"))

    inputs = processor(
        text=[text_prompt],
        images=images if images else None,
        padding=True,
        return_tensors="pt",
    )
    # With device_map="auto", model.device is already the GPU when one is
    # available (CPU otherwise), so no separate cuda check is needed.
    inputs = inputs.to(model.device)

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
    )
    # temperature / top_p only take effect when sampling is enabled.
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    if pad_token_id is not None:
        gen_kwargs.update(dict(pad_token_id=int(pad_token_id)))

    generated_ids = model.generate(**inputs, **gen_kwargs)

    # Drop the prompt tokens so only the model's reply is decoded.
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    outputs = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return outputs[0].strip() if outputs else "(empty response)"
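
# Minimal text-only call bypassing the Gradio UI below (sketch; the question
# string is just an example):
#   reply = qwen_infer(
#       messages=build_messages([], "What does dental caries look like?"),
#       max_new_tokens=128,
#       pad_token_id=PAD_ID,
#   )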


with gr.Blocks(title="OralGPT-7B-Preview — Gradio Demo (ZeroGPU)") as demo:
    gr.Markdown(
        """
        # 🦷 OralGPT-7B-Preview — Multimodal Chat Demo (ZeroGPU)
        Upload dentistry-related images and ask questions, or chat in plain text. "Thinking (sampling) mode" can be enabled on the right to broaden the output.
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                type="tuples",
            )

            with gr.Row():
                msg = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message or upload images (multiple allowed)…",
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                clear = gr.Button("🗑️ Clear", size="sm")
                submit = gr.Button("📤 Send", variant="primary", size="sm")

        with gr.Column(scale=1):
            enable_thinking = gr.Checkbox(label="Enable Thinking (sampling) mode", value=False)
            max_tokens = gr.Slider(64, 2048, value=512, step=32, label="max_new_tokens")

            gr.Markdown(
                """
                ### Examples
                - "Does this image show any signs of dental caries?"
                - "Are there any signs of periodontal disease?"
                - "Describe the suspicious lesion areas in the image."
                """
            )
            gr.Markdown("**Note: this model is for research reference only and is not intended for clinical diagnosis or treatment.**")

    def user_submit(message, history, enable_thinking, max_new_tokens):
        # MultimodalTextbox yields {"text": ..., "files": [...]}; flatten it into
        # the shapes that to_qwen_content() understands.
        if isinstance(message, dict) and message.get("files"):
            user_msg = []
            for fp in message["files"]:
                user_msg.append(fp)
            if message.get("text", ""):
                user_msg.append(message["text"])
        else:
            user_msg = message.get("text", "") if isinstance(message, dict) else message

        history = history + [(user_msg, None)]
        messages = build_messages(history[:-1], user_msg)

        do_sample = bool(enable_thinking)
        # Fixed sampling hyperparameters; they only apply when do_sample=True.
        temperature = 0.7
        top_p = 0.9

        try:
            answer = qwen_infer(
                messages=messages,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=PAD_ID,
            )
        except Exception as e:
            answer = f"Error: {str(e)}"

        history[-1] = (history[-1][0], answer if answer else "(empty response)")
        return "", history

    msg.submit(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    submit.click(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    clear.click(lambda: (None, []), outputs=[msg, chatbot])
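
# Note: share=True matters only when this script runs locally; Hugging Face
# Spaces ignores it because the app is already publicly hosted.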
if __name__ == "__main__":
    demo.launch(share=True)