import streamlit as st import torch from PIL import Image import numpy as np import cv2 from model import load_model, predict_with_uncertainty import torchvision.transforms as transforms from transformers import AutoProcessor, AutoModelForImageTextToText from io import BytesIO # 设置页面配置 st.set_page_config( page_title="医学图像分析系统", page_icon="🏥", layout="wide" ) # 初始化模型 @st.cache_resource def load_models(): # 加载分割模型 seg_model = load_model() seg_model.eval() # 加载分析模型 model_id = "google/medgemma-4b-it" analysis_model = AutoModelForImageTextToText.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", token="HUGGINGFACE_TOKEN" ) processor = AutoProcessor.from_pretrained( model_id, token="HUGGINGFACE_TOKEN" ) return seg_model, analysis_model, processor # 页面标题 st.title("🏥 医学图像分析系统") st.markdown("---") # 加载模型 with st.spinner("正在加载模型..."): seg_model, analysis_model, processor = load_models() # 创建两列布局 col1, col2 = st.columns(2) with col1: st.subheader("📤 上传图片") uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"]) if uploaded_file is not None: # 显示原始图片 image = Image.open(uploaded_file).convert("RGB") st.image(image, caption="原始图片", use_column_width=True) # 处理图片 if st.button("开始分析"): with st.spinner("正在处理..."): # 图像分割 image_resized = image.resize((224, 224)) transform = transforms.ToTensor() image_tensor = transform(image_resized).unsqueeze(0) # 执行分割 preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor) # 生成分割掩码 pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255 mask_image = Image.fromarray(pred_binary).convert("L") # 处理不确定性图 uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8) uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO) uncertainty_image = Image.fromarray(uncertainty_colormap).convert("RGB") # 合并图片 combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height)) combined.paste(mask_image.convert("RGB"), (0, 0)) combined.paste(uncertainty_image, (mask_image.width, 0)) # 图像分析 messages = [ { "role": "system", "content": [{"type": "text", "text": "你是一名皮肤病专家,请使用中文分析图片."}] }, { "role": "user", "content": [ {"type": "text", "text": "这是一张皮肤病的图片,帮我分析一下"}, {"type": "image", "image": image} ] } ] inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(analysis_model.device, dtype=torch.bfloat16) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): generation = analysis_model.generate(**inputs, max_new_tokens=200, do_sample=False) generation = generation[0][input_len:] analysis_text = processor.decode(generation, skip_special_tokens=True) # 显示结果 with col2: st.subheader("📊 分析结果") st.image(combined, caption="分割结果", use_column_width=True) st.markdown("### 📝 图像分析") st.write(analysis_text) # 添加页脚 st.markdown("---") st.markdown("### 使用说明") st.markdown(""" 1. 在左侧上传一张医学图像 2. 点击"开始分析"按钮 3. 系统将自动进行图像分割和分析 4. 右侧将显示分割结果和分析报告 """)