# filename: app.py
import os

import torch
import gradio as gr
from PIL import Image
import spaces
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,  # ✅ the correct multimodal class for this checkpoint
)

# ========================
# Basic setup
# ========================
torch.manual_seed(100)

MODEL_NAME = "Eric3200/OralGPT-7B-Preview"

# ========================
# Model and processor
# ========================
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype="auto",  # or torch.bfloat16 / torch.float16
    device_map="auto",
    # attn_implementation="flash_attention_2",  # enable if the environment supports it (faster, less VRAM)
).eval()

# Fallback for pad_token
try:
    PAD_ID = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
except Exception:
    PAD_ID = None


# ========================
# Convert user input into Qwen/OralGPT-style content
# ========================
def to_qwen_content(user_msg):
    """
    Supported input shapes:
    - plain text
    - (img_path, text) tuple
    - mixed list (PIL.Image / text / local path strings)
    - dict {files: [...], text: "..."}
    """
    contents = []

    if isinstance(user_msg, tuple):
        img_path, text = user_msg
        if img_path:
            contents.append({"type": "image", "image": img_path})
        if text:
            contents.append({"type": "text", "text": str(text)})

    elif isinstance(user_msg, list):
        for item in user_msg:
            if isinstance(item, Image.Image):
                contents.append({"type": "image", "image": item})
            elif isinstance(item, str) and os.path.isfile(item):
                contents.append({"type": "image", "image": item})
            else:
                contents.append({"type": "text", "text": str(item)})

    elif isinstance(user_msg, dict):
        if user_msg.get("files"):
            for fp in user_msg["files"]:
                contents.append({"type": "image", "image": fp})
        if user_msg.get("text"):
            contents.append({"type": "text", "text": str(user_msg["text"])})

    else:
        contents.append({"type": "text", "text": str(user_msg)})

    return contents


def build_messages(history, latest_user_msg):
    messages = []
    for u, a in history:
        messages.append({"role": "user", "content": to_qwen_content(u)})
        if a:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(a)}]})
    messages.append({"role": "user", "content": to_qwen_content(latest_user_msg)})
    return messages


# ========================
# ZeroGPU inference (the @spaces.GPU decorator is essential)
# ========================
@spaces.GPU(duration=120)
def qwen_infer(messages, max_new_tokens=512, do_sample=False, temperature=0.7, top_p=0.9, pad_token_id=None):
    """
    Pipeline:
    1) apply_chat_template → prompt
    2) collect the images referenced in messages
    3) pack everything with processor(...)
    4) model.generate(...)
    5) decode only the newly generated tokens
    """
    # 1) Text template
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 2) Collect images (this demo does not handle video)
    images = []
    for m in messages:
        for c in m.get("content", []):
            if c.get("type") == "image":
                img_obj = c.get("image")
                if isinstance(img_obj, Image.Image):
                    images.append(img_obj)
                elif isinstance(img_obj, str) and os.path.isfile(img_obj):
                    images.append(Image.open(img_obj).convert("RGB"))

    # 3) Pack the inputs
    inputs = processor(
        text=[text_prompt],
        images=images if images else None,
        padding=True,
        return_tensors="pt",
    )

    # Move inputs to the GPU / the model's device
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    else:
        inputs = inputs.to(model.device)

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
    )
    if do_sample:
        gen_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p)))
    if pad_token_id is not None:
        gen_kwargs.update(dict(pad_token_id=int(pad_token_id)))

    # 4) Generate
    generated_ids = model.generate(**inputs, **gen_kwargs)

    # 5) Keep only the newly generated tokens, then decode
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    outputs = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return outputs[0].strip() if outputs else "(empty response)"


# ========================
# Gradio UI
# ========================
with gr.Blocks(title="OralGPT-7B-Preview — Gradio Demo (ZeroGPU)") as demo:
    gr.Markdown(
        """
        # 🦷 OralGPT-7B-Preview — Multimodal Chat Demo (ZeroGPU)
        Upload dental images and ask questions, or chat in plain text. Enable "Thinking (sampling) mode" on the right for more exploratory output.
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500, show_label=False, container=True, type="tuples"
            )
            with gr.Row():
                msg = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message or upload images (multiple allowed)…",
                    show_label=False,
                    container=False,
                )
            with gr.Row():
                clear = gr.Button("🗑️ Clear", size="sm")
                submit = gr.Button("📤 Send", variant="primary", size="sm")

        with gr.Column(scale=1):
            enable_thinking = gr.Checkbox(label="Enable Thinking (sampling) mode", value=False)
            max_tokens = gr.Slider(64, 2048, value=512, step=32, label="max_new_tokens")
            gr.Markdown(
                """
                ### Examples
                - "Does this image show any signs of dental caries?"
                - "Is there any sign of periodontal disease?"
                - "Describe the suspicious lesion areas in the image."
                """
            )
            gr.Markdown("**Note: this model is for research reference only and is not intended for clinical diagnosis or treatment.**")

    # Event handlers
    def user_submit(message, history, enable_thinking, max_new_tokens):
        # Assemble the user message for this turn
        if isinstance(message, dict) and message.get("files"):
            user_msg = []
            for fp in message["files"]:
                user_msg.append(fp)
            if message.get("text", ""):
                user_msg.append(message["text"])
        else:
            user_msg = message.get("text", "") if isinstance(message, dict) else message

        history = history + [(user_msg, None)]
        messages = build_messages(history[:-1], user_msg)

        do_sample = bool(enable_thinking)
        # generate() needs numeric values either way; pass the defaults
        temperature = 0.7
        top_p = 0.9

        try:
            answer = qwen_infer(
                messages=messages,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=PAD_ID,
            )
        except Exception as e:
            answer = f"Error: {str(e)}"

        history[-1] = (history[-1][0], answer if answer else "(empty response)")
        return "", history

    msg.submit(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    submit.click(user_submit, inputs=[msg, chatbot, enable_thinking, max_tokens], outputs=[msg, chatbot])
    clear.click(lambda: (None, []), outputs=[msg, chatbot])

if __name__ == "__main__":
    demo.launch(share=True)
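
# Running outside a Hugging Face Space (assumptions, not part of the original demo):
# the script is written for a ZeroGPU Space, so the @spaces.GPU decorator and the
# `spaces` package are expected; locally you would still need the dependencies below
# and a CUDA GPU with enough memory for the 7B model.
#   pip install torch transformers gradio spaces pillow
#   python app.py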