# filename: app.py
import os
import torch
import gradio as gr
from PIL import Image
import spaces
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,  # the correct multimodal class for Qwen2.5-VL checkpoints
)
# ========================
# Basic setup
# ========================
torch.manual_seed(100)
MODEL_NAME = "Eric3200/OralGPT-7B-Preview"
# ========================
# Model and processor
# ========================
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype="auto",  # or torch.bfloat16 / torch.float16
    device_map="auto",
    # attn_implementation="flash_attention_2",  # enable if the environment supports it (better speed / memory)
).eval()
# Fallback pad token: prefer pad_token_id, fall back to eos_token_id
try:
    PAD_ID = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
except Exception:
    PAD_ID = None
# ========================
# Convert user input into Qwen/OralGPT-style content
# ========================
def to_qwen_content(user_msg):
    """
    Supported inputs:
    - plain text
    - (img_path, text) tuple
    - mixed list (PIL.Image / text / local path string)
    - dict {files: [...], text: "..."}
    """
    contents = []
    if isinstance(user_msg, tuple):
        img_path, text = user_msg
        if img_path:
            contents.append({"type": "image", "image": img_path})
        if text:
            contents.append({"type": "text", "text": str(text)})
    elif isinstance(user_msg, list):
        for item in user_msg:
            if isinstance(item, Image.Image):
                contents.append({"type": "image", "image": item})
            elif isinstance(item, str) and os.path.isfile(item):
                contents.append({"type": "image", "image": item})
            else:
                contents.append({"type": "text", "text": str(item)})
    elif isinstance(user_msg, dict):
        if user_msg.get("files"):
            for fp in user_msg["files"]:
                contents.append({"type": "image", "image": fp})
        if user_msg.get("text"):
            contents.append({"type": "text", "text": str(user_msg["text"])})
    else:
        contents.append({"type": "text", "text": str(user_msg)})
    return contents
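# Illustrative mapping (comment only, not executed; "tooth.png" is a hypothetical local file):
#   to_qwen_content(("tooth.png", "Is there caries in this image?"))
#   -> [{"type": "image", "image": "tooth.png"},
#       {"type": "text", "text": "Is there caries in this image?"}]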
def build_messages(history, latest_user_msg):
    messages = []
    for u, a in history:
        messages.append({"role": "user", "content": to_qwen_content(u)})
        if a:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(a)}]})
    messages.append({"role": "user", "content": to_qwen_content(latest_user_msg)})
    return messages
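# Illustrative shape of the resulting messages list (comment only; the texts and "scan.png" are made-up examples):
#   [{"role": "user", "content": [{"type": "text", "text": "Hello"}]},
#    {"role": "assistant", "content": [{"type": "text", "text": "Hi, how can I help?"}]},
#    {"role": "user", "content": [{"type": "image", "image": "scan.png"},
#                                 {"type": "text", "text": "Any lesion here?"}]}]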
# ========================
# ZeroGPU inference (key point: the @spaces.GPU decorator)
# ========================
@spaces.GPU(duration=120)
def qwen_infer(messages, max_new_tokens=512, do_sample=False, temperature=0.7, top_p=0.9, pad_token_id=None):
    """
    Pipeline:
    1) apply_chat_template -> prompt
    2) collect the images referenced in messages
    3) pack everything with processor(...)
    4) model.generate(...)
    5) decode only the newly generated tokens
    """
    # 1) Text prompt from the chat template
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # 2) Collect images (this demo does not handle video)
    images = []
    for m in messages:
        for c in m.get("content", []):
            if c.get("type") == "image":
                img_obj = c.get("image")
                if isinstance(img_obj, Image.Image):
                    images.append(img_obj)
                elif isinstance(img_obj, str) and os.path.isfile(img_obj):
                    images.append(Image.open(img_obj).convert("RGB"))
    # 3) Pack the inputs
    inputs = processor(
        text=[text_prompt],
        images=images if images else None,
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the GPU / model device
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    else:
        inputs = inputs.to(model.device)
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
    )
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    if pad_token_id is not None:
        gen_kwargs.update(dict(pad_token_id=int(pad_token_id)))
    # 4) Generate
    generated_ids = model.generate(**inputs, **gen_kwargs)
    # 5) Keep only the newly generated tokens and decode them
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    outputs = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return outputs[0].strip() if outputs else "(empty response)"
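# Minimal standalone usage sketch (comment only, not executed; the image path and question are hypothetical):
#   msgs = build_messages(history=[], latest_user_msg=("panoramic_xray.png", "Describe any suspicious regions."))
#   print(qwen_infer(msgs, max_new_tokens=256, pad_token_id=PAD_ID))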
# ========================
# Gradio UI
# ========================
with gr.Blocks(title="OralGPT-7B-Preview — Gradio Demo (ZeroGPU)") as demo:
    gr.Markdown(
        """
# 🦷 OralGPT-7B-Preview — Multimodal Chat Demo (ZeroGPU)
Upload dental images and ask questions, or chat in plain text. Enable "Thinking (sampling) mode" on the right for more exploratory output.
"""
    )
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                type="tuples"
            )
            with gr.Row():
                msg = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message or upload image(s)…",
                    show_label=False,
                    container=False
                )
            with gr.Row():
                clear = gr.Button("🗑️ Clear", size="sm")
                submit = gr.Button("📤 Send", variant="primary", size="sm")
        with gr.Column(scale=1):
            enable_thinking = gr.Checkbox(label="Enable Thinking (sampling) mode", value=False)
            max_tokens = gr.Slider(64, 2048, value=512, step=32, label="max_new_tokens")
            gr.Markdown(
                """
### Examples
- "Is there any sign of caries in this image?"
- "Is there any sign of periodontal disease?"
- "Describe the suspicious lesion areas in the image."
"""
            )
            gr.Markdown("**Note: this model is for research reference only, not for clinical diagnosis or treatment.**")
    # Event logic
    def user_submit(message, history, enable_thinking, max_new_tokens):
        # Assemble this turn's user message
        if isinstance(message, dict) and message.get("files"):
            user_msg = []
            for fp in message["files"]:
                user_msg.append(fp)
            if message.get("text", ""):
                user_msg.append(message["text"])
        else:
            user_msg = message.get("text", "") if isinstance(message, dict) else message
        history = history + [(user_msg, None)]
        messages = build_messages(history[:-1], user_msg)
        do_sample = bool(enable_thinking)
        # generate() needs numeric values; use fixed defaults here
        temperature = 0.7
        top_p = 0.9
        try:
            answer = qwen_infer(
                messages=messages,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=PAD_ID,
            )
        except Exception as e:
            answer = f"Error: {str(e)}"
        history[-1] = (history[-1][0], answer if answer else "(empty response)")
        return "", history
    msg.submit(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    submit.click(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    clear.click(lambda: (None, []), outputs=[msg, chatbot])
if __name__ == "__main__":
    demo.launch(share=True)