fenge1 / app.py
yusir4200's picture
Upload 21 files
5df3c06 verified
import streamlit as st
import torch
from PIL import Image
import numpy as np
import cv2
from model import load_model, predict_with_uncertainty
import torchvision.transforms as transforms
from transformers import AutoProcessor, AutoModelForImageTextToText
from io import BytesIO
# 设置页面配置
st.set_page_config(
page_title="医学图像分析系统",
page_icon="🏥",
layout="wide"
)
# 初始化模型
@st.cache_resource
def load_models():
# 加载分割模型
seg_model = load_model()
seg_model.eval()
# 加载分析模型
model_id = "google/medgemma-4b-it"
analysis_model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
token="HUGGINGFACE_TOKEN"
)
processor = AutoProcessor.from_pretrained(
model_id,
token="HUGGINGFACE_TOKEN"
)
return seg_model, analysis_model, processor
# 页面标题
st.title("🏥 医学图像分析系统")
st.markdown("---")
# 加载模型
with st.spinner("正在加载模型..."):
seg_model, analysis_model, processor = load_models()
# 创建两列布局
col1, col2 = st.columns(2)
with col1:
st.subheader("📤 上传图片")
uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
# 显示原始图片
image = Image.open(uploaded_file).convert("RGB")
st.image(image, caption="原始图片", use_column_width=True)
# 处理图片
if st.button("开始分析"):
with st.spinner("正在处理..."):
# 图像分割
image_resized = image.resize((224, 224))
transform = transforms.ToTensor()
image_tensor = transform(image_resized).unsqueeze(0)
# 执行分割
preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)
# 生成分割掩码
pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
mask_image = Image.fromarray(pred_binary).convert("L")
# 处理不确定性图
uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
uncertainty_image = Image.fromarray(uncertainty_colormap).convert("RGB")
# 合并图片
combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height))
combined.paste(mask_image.convert("RGB"), (0, 0))
combined.paste(uncertainty_image, (mask_image.width, 0))
# 图像分析
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "你是一名皮肤病专家,请使用中文分析图片."}]
},
{
"role": "user",
"content": [
{"type": "text", "text": "这是一张皮肤病的图片,帮我分析一下"},
{"type": "image", "image": image}
]
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(analysis_model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = analysis_model.generate(**inputs, max_new_tokens=200, do_sample=False)
generation = generation[0][input_len:]
analysis_text = processor.decode(generation, skip_special_tokens=True)
# 显示结果
with col2:
st.subheader("📊 分析结果")
st.image(combined, caption="分割结果", use_column_width=True)
st.markdown("### 📝 图像分析")
st.write(analysis_text)
# 添加页脚
st.markdown("---")
st.markdown("### 使用说明")
st.markdown("""
1. 在左侧上传一张医学图像
2. 点击"开始分析"按钮
3. 系统将自动进行图像分割和分析
4. 右侧将显示分割结果和分析报告
""")