|
import streamlit as st
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
import cv2
|
|
from model import load_model, predict_with_uncertainty
|
|
import torchvision.transforms as transforms
|
|
from transformers import AutoProcessor, AutoModelForImageTextToText
|
|
from io import BytesIO
|
|
|
|
|
|
# Browser-tab title, icon and layout for the app, applied before any other
# UI element in this script is drawn.
PAGE_SETTINGS = {
    "page_title": "医学图像分析系统",
    "page_icon": "🏥",
    "layout": "wide",
}
st.set_page_config(**PAGE_SETTINGS)
|
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the segmentation model and the MedGemma model/processor.

    Decorated with ``st.cache_resource`` so the (expensive) loading is not
    repeated on every Streamlit script rerun.

    The Hugging Face access token is read from the ``HUGGINGFACE_TOKEN``
    environment variable instead of being passed as a hard-coded literal.
    When the variable is unset or empty, ``None`` is passed so the library
    falls back to locally cached credentials (e.g. ``huggingface-cli login``).

    Returns:
        tuple: ``(seg_model, analysis_model, processor)`` where ``seg_model``
        is the object returned by ``load_model()`` switched to eval mode,
        ``analysis_model`` is the ``google/medgemma-4b-it`` image-text model
        loaded in bfloat16, and ``processor`` is its matching processor.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import os

    # Segmentation model (project-local loader); eval() disables
    # training-only behaviour on the module itself.
    seg_model = load_model()
    seg_model.eval()

    model_id = "google/medgemma-4b-it"

    # Empty string is treated the same as "not set".
    hf_token = os.environ.get("HUGGINGFACE_TOKEN") or None

    analysis_model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # let the library decide device placement
        token=hf_token,
    )
    processor = AutoProcessor.from_pretrained(
        model_id,
        token=hf_token,
    )

    return seg_model, analysis_model, processor
|
|
|
|
|
|
# Page header followed by a horizontal rule.
st.title("🏥 医学图像分析系统")
st.markdown("---")

# Fetch the models while a spinner is shown, then unpack them into the
# module-level names used by the rest of the script.
with st.spinner("正在加载模型..."):
    loaded_models = load_models()
seg_model, analysis_model, processor = loaded_models
|
|
|
|
|
|
# Two-column layout: upload and controls on the left, results on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("📤 上传图片")
    uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        # Normalise every upload to 3-channel RGB so grayscale / RGBA /
        # palette images all take the same path below.
        image = Image.open(uploaded_file).convert("RGB")
        # NOTE(review): `use_column_width` is deprecated in newer Streamlit
        # releases; kept as-is for compatibility with the installed version.
        st.image(image, caption="原始图片", use_column_width=True)

        if st.button("开始分析"):
            with st.spinner("正在处理..."):
                # ---- Segmentation with uncertainty ----
                # Resize to 224x224 and add a leading batch dimension.
                image_resized = image.resize((224, 224))
                transform = transforms.ToTensor()
                image_tensor = transform(image_resized).unsqueeze(0)

                # NOTE(review): the helper is called without `seg_model`;
                # presumably it manages its own model instance - confirm.
                # The code below assumes both results are 2-D NumPy arrays.
                preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)

                # Threshold the mean prediction at 0.5 -> 0/255 binary mask.
                pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
                mask_image = Image.fromarray(pred_binary).convert("L")

                # Min-max normalise the uncertainty to [0, 1]; the epsilon
                # avoids division by zero when the map is constant.
                u_min = preds_uncertainty.min()
                u_max = preds_uncertainty.max()
                uncertainty = (preds_uncertainty - u_min) / (u_max - u_min + 1e-8)

                # OpenCV colour maps are returned in BGR channel order, while
                # PIL interprets a 3-channel array as RGB. Convert explicitly,
                # otherwise red and blue are swapped in the displayed heat map.
                uncertainty_bgr = cv2.applyColorMap(
                    (uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO
                )
                uncertainty_rgb = cv2.cvtColor(uncertainty_bgr, cv2.COLOR_BGR2RGB)
                uncertainty_image = Image.fromarray(uncertainty_rgb)

                # Side-by-side canvas: mask on the left, uncertainty on the right.
                combined = Image.new(
                    "RGB",
                    (mask_image.width + uncertainty_image.width, mask_image.height),
                )
                combined.paste(mask_image.convert("RGB"), (0, 0))
                combined.paste(uncertainty_image, (mask_image.width, 0))

                # ---- Text analysis of the original (full-size) image ----
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": "你是一名皮肤病专家,请使用中文分析图片."}],
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "这是一张皮肤病的图片,帮我分析一下"},
                            {"type": "image", "image": image},
                        ],
                    },
                ]

                inputs = processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                ).to(analysis_model.device, dtype=torch.bfloat16)

                # Prompt length, used to strip the prompt tokens from the output.
                input_len = inputs["input_ids"].shape[-1]

                with torch.inference_mode():
                    generation = analysis_model.generate(
                        **inputs, max_new_tokens=200, do_sample=False
                    )
                    # Keep only the newly generated tokens of the first sequence.
                    generation = generation[0][input_len:]

                analysis_text = processor.decode(generation, skip_special_tokens=True)

            # Results go to the right-hand column. This stays inside the
            # button branch because `combined` and `analysis_text` only
            # exist after an analysis has run.
            with col2:
                st.subheader("📊 分析结果")
                st.image(combined, caption="分割结果", use_column_width=True)
                st.markdown("### 📝 图像分析")
                st.write(analysis_text)
|
|
|
|
|
|
# Footer: a horizontal rule, a heading, and the numbered usage steps,
# rendered in that order as three separate markdown elements.
USAGE_STEPS = """
1. 在左侧上传一张医学图像
2. 点击"开始分析"按钮
3. 系统将自动进行图像分割和分析
4. 右侧将显示分割结果和分析报告
"""

for footer_section in ("---", "### 使用说明", USAGE_STEPS):
    st.markdown(footer_section)