File size: 4,867 Bytes
5df3c06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

from model import load_model, predict_with_uncertainty

# Page-wide settings: browser tab title and icon, plus the wide layout.
# This is the first Streamlit command the script runs, which older
# Streamlit releases require for set_page_config.
_PAGE_CONFIG = dict(
    page_title="医学图像分析系统",
    page_icon="🏥",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# 初始化模型
@st.cache_resource
def load_models():
    # 加载分割模型
    seg_model = load_model()
    seg_model.eval()
    
    # 加载分析模型
    model_id = "google/medgemma-4b-it"
    analysis_model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token="HUGGINGFACE_TOKEN"
    )
    processor = AutoProcessor.from_pretrained(
        model_id,
        token="HUGGINGFACE_TOKEN"
    )
    
    return seg_model, analysis_model, processor

# --- Header -------------------------------------------------------------------
st.title("🏥 医学图像分析系统")
st.markdown("---")

# --- Models -------------------------------------------------------------------
# load_models() is wrapped in st.cache_resource, so after the first run the
# spinner only flashes while the cached objects are returned.
with st.spinner("正在加载模型..."):
    _loaded_models = load_models()
seg_model, analysis_model, processor = _loaded_models

def _segmentation_figure(image):
    """Segment *image* and return the mask and uncertainty map side by side.

    The image is resized to 224x224, converted with ``ToTensor`` and given a
    leading batch axis before being passed to ``predict_with_uncertainty``.

    Returns:
        An RGB ``PIL.Image``. The left half is white where the mean
        prediction exceeds 0.5 and black elsewhere. The right half is the
        uncertainty map, min-max normalised to [0, 1] and rendered with
        OpenCV's INFERNO colormap. Each half has the size of the prediction
        maps (presumably 224x224, the input size; confirm in model.py).
    """
    image_tensor = transforms.ToTensor()(image.resize((224, 224))).unsqueeze(0)

    preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)
    # NOTE(review): assumes both outputs are NumPy arrays holding one map,
    # possibly with singleton batch/channel axes. Squeeze them so PIL and
    # OpenCV always receive a 2-D array; this is a no-op if already 2-D.
    preds_mean = np.squeeze(preds_mean)
    preds_uncertainty = np.squeeze(preds_uncertainty)

    # Binary segmentation mask: 255 where the mean prediction is > 0.5.
    mask = (preds_mean > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask).convert("RGB")

    # Min-max normalise to [0, 1]. The epsilon avoids 0/0 on a constant map.
    lo, hi = preds_uncertainty.min(), preds_uncertainty.max()
    uncertainty = (preds_uncertainty - lo) / (hi - lo + 1e-8)
    heatmap_bgr = cv2.applyColorMap(
        (uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO
    )
    # OpenCV colormaps are BGR while PIL expects RGB. Without this swap the
    # red and blue channels are exchanged in the displayed heatmap.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    uncertainty_image = Image.fromarray(heatmap_rgb)

    combined = Image.new(
        "RGB", (mask_image.width + uncertainty_image.width, mask_image.height)
    )
    combined.paste(mask_image, (0, 0))
    combined.paste(uncertainty_image, (mask_image.width, 0))
    return combined


def _analysis_text(image, model, proc):
    """Ask the chat-style image-text *model* for a written analysis of *image*.

    Uses greedy decoding (``do_sample=False``) capped at 200 new tokens.

    Returns:
        Only the newly generated text (prompt removed, special tokens
        skipped).
    """
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "你是一名皮肤病专家,请使用中文分析图片."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "这是一张皮肤病的图片,帮我分析一下"},
                {"type": "image", "image": image}
            ]
        }
    ]

    inputs = proc.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)

    # generate() returns prompt + continuation; keep only the continuation.
    return proc.decode(generation[0][input_len:], skip_special_tokens=True)


# Two-column layout: upload and controls on the left, results on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("📤 上传图片")
    uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        # Show the uploaded image as-is (converted to RGB).
        image = Image.open(uploaded_file).convert("RGB")
        # NOTE(review): `use_column_width` is deprecated in recent Streamlit
        # releases in favour of `use_container_width`/`width`. Switch once the
        # minimum supported Streamlit version allows it.
        st.image(image, caption="原始图片", use_column_width=True)

        if st.button("开始分析"):
            with st.spinner("正在处理..."):
                combined = _segmentation_figure(image)
                analysis_text = _analysis_text(image, analysis_model, processor)

                # Render results in the right-hand column.
                with col2:
                    st.subheader("📊 分析结果")
                    st.image(combined, caption="分割结果", use_column_width=True)
                    st.markdown("### 📝 图像分析")
                    st.write(analysis_text)

# Footer rendered below the two-column layout: a divider, a heading and the
# numbered usage instructions.
_USAGE_STEPS = """

1. 在左侧上传一张医学图像

2. 点击"开始分析"按钮

3. 系统将自动进行图像分割和分析

4. 右侧将显示分割结果和分析报告

"""
for _footer_md in ("---", "### 使用说明", _USAGE_STEPS):
    st.markdown(_footer_md)