yusir4200 committed
Commit 5df3c06 · verified · 1 Parent(s): 06d7ddb

Upload 21 files

.gitignore ADDED
@@ -0,0 +1,9 @@
+ # Ignore virtual environment
+ venv/
+ __pycache__/
+
+ # Ignore model weights (GitHub has a 100MB limit)
+ *.pth
+
+ # Ignore system files
+ .DS_Store
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/image.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="Python 3.9 (look)" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,77 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ourVersions">
+         <value>
+           <list size="3">
+             <item index="0" class="java.lang.String" itemvalue="2.7" />
+             <item index="1" class="java.lang.String" itemvalue="3.13" />
+             <item index="2" class="java.lang.String" itemvalue="3.9" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="53">
+             <item index="0" class="java.lang.String" itemvalue="tensorboard" />
+             <item index="1" class="java.lang.String" itemvalue="timm" />
+             <item index="2" class="java.lang.String" itemvalue="absl-py" />
+             <item index="3" class="java.lang.String" itemvalue="shapely" />
+             <item index="4" class="java.lang.String" itemvalue="joblib" />
+             <item index="5" class="java.lang.String" itemvalue="threadpoolctl" />
+             <item index="6" class="java.lang.String" itemvalue="huggingface-hub" />
+             <item index="7" class="java.lang.String" itemvalue="PyYAML" />
+             <item index="8" class="java.lang.String" itemvalue="setuptools" />
+             <item index="9" class="java.lang.String" itemvalue="fsspec" />
+             <item index="10" class="java.lang.String" itemvalue="filelock" />
+             <item index="11" class="java.lang.String" itemvalue="qudida" />
+             <item index="12" class="java.lang.String" itemvalue="pip" />
+             <item index="13" class="java.lang.String" itemvalue="certifi" />
+             <item index="14" class="java.lang.String" itemvalue="warmup-scheduler" />
+             <item index="15" class="java.lang.String" itemvalue="portalocker" />
+             <item index="16" class="java.lang.String" itemvalue="nibabel" />
+             <item index="17" class="java.lang.String" itemvalue="h5py" />
+             <item index="18" class="java.lang.String" itemvalue="loguru" />
+             <item index="19" class="java.lang.String" itemvalue="contourpy" />
+             <item index="20" class="java.lang.String" itemvalue="fonttools" />
+             <item index="21" class="java.lang.String" itemvalue="imageio" />
+             <item index="22" class="java.lang.String" itemvalue="fairscale" />
+             <item index="23" class="java.lang.String" itemvalue="matplotlib" />
+             <item index="24" class="java.lang.String" itemvalue="charset-normalizer" />
+             <item index="25" class="java.lang.String" itemvalue="MedPy" />
+             <item index="26" class="java.lang.String" itemvalue="SimpleITK" />
+             <item index="27" class="java.lang.String" itemvalue="idna" />
+             <item index="28" class="java.lang.String" itemvalue="scikit-image" />
+             <item index="29" class="java.lang.String" itemvalue="yapf" />
+             <item index="30" class="java.lang.String" itemvalue="numpy" />
+             <item index="31" class="java.lang.String" itemvalue="requests" />
+             <item index="32" class="java.lang.String" itemvalue="opencv-python-headless" />
+             <item index="33" class="java.lang.String" itemvalue="seaborn" />
+             <item index="34" class="java.lang.String" itemvalue="pthflops" />
+             <item index="35" class="java.lang.String" itemvalue="PyWavelets" />
+             <item index="36" class="java.lang.String" itemvalue="zipp" />
+             <item index="37" class="java.lang.String" itemvalue="lazy_loader" />
+             <item index="38" class="java.lang.String" itemvalue="scipy" />
+             <item index="39" class="java.lang.String" itemvalue="opencv-python" />
+             <item index="40" class="java.lang.String" itemvalue="wheel" />
+             <item index="41" class="java.lang.String" itemvalue="packaging" />
+             <item index="42" class="java.lang.String" itemvalue="addict" />
+             <item index="43" class="java.lang.String" itemvalue="albumentations" />
+             <item index="44" class="java.lang.String" itemvalue="termcolor" />
+             <item index="45" class="java.lang.String" itemvalue="importlib-resources" />
+             <item index="46" class="java.lang.String" itemvalue="typing_extensions" />
+             <item index="47" class="java.lang.String" itemvalue="pytz" />
+             <item index="48" class="java.lang.String" itemvalue="einops" />
+             <item index="49" class="java.lang.String" itemvalue="Pillow" />
+             <item index="50" class="java.lang.String" itemvalue="torch" />
+             <item index="51" class="java.lang.String" itemvalue="torchvision" />
+             <item index="52" class="java.lang.String" itemvalue="torchsummary" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.9 (look)" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (look)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/image.iml" filepath="$PROJECT_DIR$/.idea/image.iml" />
+     </modules>
+   </component>
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,49 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="AutoImportSettings">
+     <option name="autoReloadType" value="SELECTIVE" />
+   </component>
+   <component name="ChangeListManager">
+     <list default="true" id="2992db4b-be05-4abd-b057-113b2d800a0e" name="Changes" comment="" />
+     <option name="SHOW_DIALOG" value="false" />
+     <option name="HIGHLIGHT_CONFLICTS" value="true" />
+     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+     <option name="LAST_RESOLUTION" value="IGNORE" />
+   </component>
+   <component name="MarkdownSettingsMigration">
+     <option name="stateVersion" value="1" />
+   </component>
+   <component name="ProjectColorInfo"><![CDATA[{
+     "associatedIndex": 1
+   }]]></component>
+   <component name="ProjectId" id="2y2VrJeUMs6MgMG5dRwdH89cBQE" />
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent"><![CDATA[{
+     "keyToString": {
+       "RunOnceActivity.OpenProjectViewOnStart": "true",
+       "RunOnceActivity.ShowReadmeOnStart": "true",
+       "last_opened_file_path": "E:/挑战杯/image"
+     }
+   }]]></component>
+   <component name="SharedIndexes">
+     <attachedChunks>
+       <set>
+         <option value="bundled-python-sdk-09665e90c3a7-d3b881c8e49f-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-233.15026.15" />
+       </set>
+     </attachedChunks>
+   </component>
+   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+   <component name="TaskManager">
+     <task active="true" id="Default" summary="Default task">
+       <changelist id="2992db4b-be05-4abd-b057-113b2d800a0e" name="Changes" comment="" />
+       <created>1749032852317</created>
+       <option name="number" value="Default" />
+       <option name="presentableId" value="Default" />
+       <updated>1749032852317</updated>
+     </task>
+     <servers />
+   </component>
+ </project>
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ # Use the official Python image as a base
+ FROM python:3.9-slim
+
+ # Install system dependencies for OpenCV and wget
+ RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 wget && rm -rf /var/lib/apt/lists/*
+
+ # Set the working directory inside the container
+ WORKDIR /app
+
+ # Copy and install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Download model weights from Hugging Face
+ RUN wget -O best_swin_upernet_main.pth "https://huggingface.co/samyakshrestha/swin-medical-segmentation/resolve/main/best_swin_upernet_main.pth?download=true"
+
+ # Copy the rest of the app code
+ COPY . .
+
+ # Expose the port
+ EXPOSE 8000
+
+ # Run FastAPI with uvicorn
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,12 +1,89 @@
- ---
- title: Fenge1
- emoji: 👀
- colorFrom: yellow
- colorTo: pink
- sdk: gradio
- sdk_version: 5.32.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Skin Lesion Segmentation API with Uncertainty Estimation
+
+ This repository contains a production-ready FastAPI service for performing semantic segmentation of skin lesions, trained on the ISIC 2018 dataset. The segmentation model uses a Swin-Tiny transformer encoder with a custom UPerNet decoder, and incorporates Monte Carlo (MC) Dropout at inference time to estimate predictive uncertainty.
+
+ ## Features
+
+ - RESTful FastAPI endpoint for real-time image segmentation
+ - Swin-Tiny + UPerNet architecture for accurate lesion boundary detection
+ - Monte Carlo Dropout-based uncertainty quantification
+ - Heatmap visualization of model uncertainty using OpenCV colormaps
+ - Side-by-side output of the predicted segmentation mask and uncertainty map
+ - Clean, modular codebase suitable for further extension and deployment
+
+ ## Inference Workflow
+
+ 1. The API accepts a dermoscopic image uploaded via a POST request to the `/predict/` endpoint.
+ 2. The image is resized, normalized, and passed through the segmentation model.
+ 3. The model performs multiple stochastic forward passes using MC Dropout.
+ 4. The mean prediction is binarized to produce a segmentation mask.
+ 5. The standard deviation across predictions is visualized as a heatmap to express uncertainty.
+ 6. The API returns both the binary mask and the uncertainty heatmap as a combined PNG image.
+
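For reference, a minimal Python client for this workflow might look like the sketch below. It assumes the service is running on localhost port 8000 (as in the Docker setup later in this README) and uses a hypothetical `lesion.jpg` as the input image.

```python
# Minimal client sketch: upload an image to /predict/ and save the
# combined mask + uncertainty PNG. "lesion.jpg" is a placeholder path.
import requests

with open("lesion.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/predict/",
        files={"file": ("lesion.jpg", f, "image/jpeg")},
    )
response.raise_for_status()

with open("prediction.png", "wb") as out:
    out.write(response.content)  # side-by-side mask and uncertainty heatmap
```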
+ ## File Structure
+
+ | File               | Description                                            |
+ |--------------------|--------------------------------------------------------|
+ | `main.py`          | FastAPI application and endpoint logic                 |
+ | `model.py`         | Model architecture and MC Dropout inference utilities  |
+ | `requirements.txt` | Required Python packages                               |
+ | `.gitignore`       | Specifies ignored files (e.g., `venv/`, model weights) |
+
+ ## Requirements
+
+ - Python 3.9+
+ - PyTorch
+ - FastAPI
+ - OpenCV
+ - Pillow
+ - timm
+ - torchvision
+ - uvicorn
+
+ Install dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Model Weights
+
+ To run the API, you will need to download the pretrained model weights and place the file in the root directory of this repository.
+
+ **Download link:**
+
+ [best_swin_upernet_main.pth](https://huggingface.co/samyakshrestha/swin-medical-segmentation/resolve/main/best_swin_upernet_main.pth?download=true)
+
+ After downloading, ensure the file is named exactly `best_swin_upernet_main.pth` and saved in the same directory as `main.py` and `model.py`.
+
+ This weight file contains the Swin-Tiny UPerNet model trained on the ISIC 2018 skin lesion segmentation dataset.
+
+ ## Dataset
+
+ This project uses the ISIC 2018 Challenge Dataset for training, which contains dermoscopic images of skin lesions annotated for binary segmentation.
+
+ Dataset access: https://challenge.isic-archive.com/data/
+
+ ## Docker Support
+
+ This project includes a `Dockerfile` for containerized deployment. Run the API with Docker:
+
+ ### Build the Docker image
+ ```bash
+ docker build -t medical-segmentation-api .
+ ```
+
+ ### Run the container
+ ```bash
+ docker run -p 8000:8000 medical-segmentation-api
+ ```
+
+ The API will then be available at:
+ http://localhost:8000/predict/
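Assuming one of the sample images from this repo (e.g. `data/iamge1.jpg`), a quick smoke test of the running container could be:

```bash
# Upload a sample image and save the returned mask + uncertainty PNG
curl -X POST -F "file=@data/iamge1.jpg" http://localhost:8000/predict/ -o prediction.png
```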
+
+ ## Author
+
+ Samyak Shrestha
__pycache__/model.cpython-310.pyc ADDED
Binary file (4.8 kB).
 
__pycache__/model.cpython-39.pyc ADDED
Binary file (4.82 kB).
 
aa.py ADDED
@@ -0,0 +1,9 @@
+ import requests
+
+ url = "http://127.0.0.1:4011/predict/"
+ files = {'file': open(r'D:\zkwg\image\data\iamge1.jpg', 'rb')}  # Replace with your image path
+ response = requests.post(url, files=files)
+
+ # Save the returned prediction result
+ with open(r'D:\zkwg\image\data\prediction_result.png', 'wb') as f:
+     f.write(response.content)
app.py ADDED
@@ -0,0 +1,128 @@
+ import streamlit as st
+ import torch
+ from PIL import Image
+ import numpy as np
+ import cv2
+ from model import load_model, predict_with_uncertainty
+ import torchvision.transforms as transforms
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from io import BytesIO
+
+ # Page configuration
+ st.set_page_config(
+     page_title="Medical Image Analysis System",
+     page_icon="🏥",
+     layout="wide"
+ )
+
+ # Initialize the models
+ @st.cache_resource
+ def load_models():
+     # Load the segmentation model
+     seg_model = load_model()
+     seg_model.eval()
+
+     # Load the analysis model
+     model_id = "google/medgemma-4b-it"
+     analysis_model = AutoModelForImageTextToText.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         token="HUGGINGFACE_TOKEN"
+     )
+     processor = AutoProcessor.from_pretrained(
+         model_id,
+         token="HUGGINGFACE_TOKEN"
+     )
+
+     return seg_model, analysis_model, processor
+
+ # Page title
+ st.title("🏥 Medical Image Analysis System")
+ st.markdown("---")
+
+ # Load the models
+ with st.spinner("Loading models..."):
+     seg_model, analysis_model, processor = load_models()
+
+ # Two-column layout
+ col1, col2 = st.columns(2)
+
+ with col1:
+     st.subheader("📤 Upload Image")
+     uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
+
+     if uploaded_file is not None:
+         # Show the original image
+         image = Image.open(uploaded_file).convert("RGB")
+         st.image(image, caption="Original image", use_column_width=True)
+
+         # Process the image
+         if st.button("Start Analysis"):
+             with st.spinner("Processing..."):
+                 # Image segmentation
+                 image_resized = image.resize((224, 224))
+                 transform = transforms.ToTensor()
+                 image_tensor = transform(image_resized).unsqueeze(0)
+
+                 # Run segmentation
+                 preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)
+
+                 # Generate the segmentation mask
+                 pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
+                 mask_image = Image.fromarray(pred_binary).convert("L")
+
+                 # Process the uncertainty map
+                 uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
+                 uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
+                 uncertainty_image = Image.fromarray(cv2.cvtColor(uncertainty_colormap, cv2.COLOR_BGR2RGB))  # OpenCV returns BGR; convert to RGB for PIL
+
+                 # Combine the images side by side
+                 combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height))
+                 combined.paste(mask_image.convert("RGB"), (0, 0))
+                 combined.paste(uncertainty_image, (mask_image.width, 0))
+
+                 # Image analysis
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": [{"type": "text", "text": "You are a dermatology expert; please analyze the image in Chinese."}]
+                     },
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "text", "text": "This is an image of a skin condition; please analyze it."},
+                             {"type": "image", "image": image}
+                         ]
+                     }
+                 ]
+
+                 inputs = processor.apply_chat_template(
+                     messages, add_generation_prompt=True, tokenize=True,
+                     return_dict=True, return_tensors="pt"
+                 ).to(analysis_model.device, dtype=torch.bfloat16)
+
+                 input_len = inputs["input_ids"].shape[-1]
+
+                 with torch.inference_mode():
+                     generation = analysis_model.generate(**inputs, max_new_tokens=200, do_sample=False)
+                     generation = generation[0][input_len:]
+
+                 analysis_text = processor.decode(generation, skip_special_tokens=True)
+
+                 # Display the results
+                 with col2:
+                     st.subheader("📊 Analysis Results")
+                     st.image(combined, caption="Segmentation result", use_column_width=True)
+                     st.markdown("### 📝 Image Analysis")
+                     st.write(analysis_text)
+
+ # Footer
+ st.markdown("---")
+ st.markdown("### Instructions")
+ st.markdown("""
+ 1. Upload a medical image on the left
+ 2. Click the "Start Analysis" button
+ 3. The system will automatically segment and analyze the image
+ 4. The segmentation result and analysis report will appear on the right
+ """)
data/iamge1.jpg ADDED
data/iamge2.jpg ADDED
data/iamge3.png ADDED
data/prediction_result.png ADDED
main.py ADDED
@@ -0,0 +1,60 @@
+ import logging
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.responses import Response
+ import uvicorn
+ from model import load_model, predict_with_uncertainty
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from io import BytesIO
+ import numpy as np
+ import cv2
+ import torch
+
+ app = FastAPI()
+
+ # Load the model when the API starts
+ model = load_model()
+ model.eval()
+
+ def convert_to_image(array, colormap=None):
+     array = (array * 255).astype(np.uint8)  # Normalize to 0–255 range
+     if colormap is not None:
+         array = cv2.applyColorMap(array, colormap)
+     return Image.fromarray(array)
+
+ @app.post("/predict/")
+ async def predict_mask(file: UploadFile = File(...)):
+     # Read and preprocess the image
+     image = Image.open(BytesIO(await file.read())).convert("RGB")
+     image = image.resize((224, 224))
+     transform = transforms.ToTensor()
+     image_tensor = transform(image).unsqueeze(0)
+
+     # Perform MC Dropout inference
+     preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)
+
+     # Binary mask (0 or 255)
+     pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
+     mask_image = Image.fromarray(pred_binary).convert("L")
+
+     # Normalize and apply a colormap to the uncertainty map
+     uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
+     uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
+     uncertainty_image = Image.fromarray(cv2.cvtColor(uncertainty_colormap, cv2.COLOR_BGR2RGB))  # OpenCV returns BGR; convert to RGB for PIL
+
+     # Combine side by side
+     combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height))
+     combined.paste(mask_image.convert("RGB"), (0, 0))
+     combined.paste(uncertainty_image, (mask_image.width, 0))
+
+     # Save to a buffer and return
+     img_io = BytesIO()
+     combined.save(img_io, format="PNG")
+     img_io.seek(0)
+
+     return Response(content=img_io.getvalue(), media_type="image/png")
+
+ if __name__ == "__main__":
+     # Direct-launch mode, suitable for any operating system
+     logging.info("Starting MCP client API service...")
+     uvicorn.run(app, host="0.0.0.0", port=4011)
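
Note that a direct launch (`python main.py`) binds port 4011, whereas the Dockerfile runs `uvicorn main:app` on port 8000; `aa.py` above targets the direct-launch port. A quick local test against it might look like:

```bash
# Query the directly-launched service on port 4011 with a sample image from this repo
curl -X POST -F "file=@data/iamge1.jpg" http://127.0.0.1:4011/predict/ -o prediction_result.png
```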
model.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import timm
+ import numpy as np
+ import cv2
+
+ # -------------------------------
+ # Define Pyramid Pooling Module (with GroupNorm)
+ # -------------------------------
+ class PyramidPoolingModule(nn.Module):
+     def __init__(self, in_channels, pool_sizes=[1, 2, 3, 6]):
+         super().__init__()
+         self.pool_layers = nn.ModuleList([
+             nn.Sequential(
+                 nn.AdaptiveAvgPool2d(pool_size),
+                 nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False),
+                 nn.GroupNorm(num_groups=8, num_channels=in_channels // 4),
+                 nn.ReLU(inplace=True)
+             ) for pool_size in pool_sizes
+         ])
+         total_channels = in_channels + len(pool_sizes) * (in_channels // 4)
+         self.conv = nn.Conv2d(total_channels, in_channels, kernel_size=1, bias=False)
+
+     def forward(self, x):
+         pooled_features = [x]
+         for layer in self.pool_layers:
+             pooled = layer(x)
+             pooled = F.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=False)
+             pooled_features.append(pooled)
+         x = torch.cat(pooled_features, dim=1)
+         x = self.conv(x)
+         return x
+
+ # -------------------------------
+ # Define UPerNet Decoder (with Dropout)
+ # -------------------------------
+ class UPerNetDecoder(nn.Module):
+     def __init__(self, encoder_channels, num_classes=1, dropout_rate=0.1):
+         super().__init__()
+         self.ppm = PyramidPoolingModule(encoder_channels[-1])
+         self.lateral_conv2 = nn.Conv2d(encoder_channels[2], encoder_channels[-1], kernel_size=1)
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(encoder_channels[-1], encoder_channels[2], kernel_size=1),
+             nn.Dropout2d(p=dropout_rate)
+         )
+         self.lateral_conv1 = nn.Conv2d(encoder_channels[1], encoder_channels[2], kernel_size=1)
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(encoder_channels[2], encoder_channels[1], kernel_size=1),
+             nn.Dropout2d(p=dropout_rate)
+         )
+         self.lateral_conv0 = nn.Conv2d(encoder_channels[0], encoder_channels[1], kernel_size=1)
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(encoder_channels[1], encoder_channels[0], kernel_size=1),
+             nn.Dropout2d(p=dropout_rate)
+         )
+         self.segmentation_head = nn.Conv2d(encoder_channels[0], num_classes, kernel_size=1)
+
+     def forward(self, features):
+         f0, f1, f2, f3 = features
+         x3 = self.ppm(f3)
+         x3_up = F.interpolate(x3, size=f2.shape[2:], mode="bilinear", align_corners=False)
+         fuse2 = x3_up + self.lateral_conv2(f2)
+         fuse2 = self.conv3(fuse2)
+         fuse2_up = F.interpolate(fuse2, size=f1.shape[2:], mode="bilinear", align_corners=False)
+         fuse1 = fuse2_up + self.lateral_conv1(f1)
+         fuse1 = self.conv2(fuse1)
+         fuse1_up = F.interpolate(fuse1, size=f0.shape[2:], mode="bilinear", align_corners=False)
+         fuse0 = fuse1_up + self.lateral_conv0(f0)
+         fuse0 = self.conv1(fuse0)
+         x_out = F.interpolate(fuse0, size=(224, 224), mode="bilinear", align_corners=False)
+         output = self.segmentation_head(x_out)
+         return output
+
+ # -------------------------------
+ # Define Swin-Tiny UPerNet Model
+ # -------------------------------
+ class SwinTinyUPerNet(nn.Module):
+     def __init__(self, num_classes=1, dropout_rate=0.1):
+         super().__init__()
+         self.encoder = timm.create_model(
+             "swin_tiny_patch4_window7_224.ms_in22k_ft_in1k",
+             pretrained=True,
+             features_only=True
+         )
+         encoder_channels = self.encoder.feature_info.channels()
+         self.decoder = UPerNetDecoder(encoder_channels, num_classes, dropout_rate=dropout_rate)
+
+     def forward(self, x):
+         features = self.encoder(x)
+         features = [f.permute(0, 3, 1, 2) if f.dim() == 4 else f for f in features]  # NHWC to NCHW
+         output = self.decoder(features)
+         return F.interpolate(output, size=(224, 224), mode="bilinear", align_corners=False)
+
+ # -------------------------------
+ # Load the Model
+ # -------------------------------
+ def load_model():
+     model = SwinTinyUPerNet(num_classes=1)
+     model.load_state_dict(torch.load("best_swin_upernet_main.pth", map_location=torch.device("cpu")), strict=False)
+     model.eval()
+     return model
+
+ # -------------------------------
+ # Enable Dropout at Inference Time
+ # -------------------------------
+ def enable_dropout(m):
+     if isinstance(m, nn.Dropout) or isinstance(m, nn.Dropout2d):
+         m.train()
+
+ # -------------------------------
+ # Perform Inference with MC Dropout
+ # -------------------------------
+ def predict_with_uncertainty(image_tensor, num_samples=10):
+     model = load_model()
+     model.apply(enable_dropout)
+     preds_list = []
+
+     with torch.no_grad():
+         for _ in range(num_samples):
+             preds = torch.sigmoid(model(image_tensor))
+             preds_list.append(preds)
+
+     preds_array = torch.stack(preds_list, dim=0)
+     preds_mean = preds_array.mean(dim=0).squeeze().cpu().numpy()
+     preds_uncertainty = preds_array.std(dim=0).squeeze().cpu().numpy()
+
+     # Normalize the uncertainty map
+     preds_uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
+     return preds_mean, preds_uncertainty
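
As a quick sanity check of what `predict_with_uncertainty` returns, a minimal sketch (assuming `best_swin_upernet_main.pth` is present so `load_model()` succeeds):

```python
# Shape check for the MC Dropout utilities above: both outputs should be
# 224x224 numpy arrays for a single 224x224 RGB input.
# Assumes best_swin_upernet_main.pth is in the working directory.
import torch
from model import predict_with_uncertainty

dummy = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed image tensor
mean, uncertainty = predict_with_uncertainty(dummy, num_samples=5)
print(mean.shape, uncertainty.shape)  # expected: (224, 224) (224, 224)
```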
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ streamlit==1.32.0
+ torch==2.2.0
+ torchvision==0.17.0
+ Pillow==10.2.0
+ numpy==1.26.4
+ opencv-python==4.9.0.80
+ fastapi==0.110.0
+ uvicorn==0.27.1
+ python-multipart==0.0.9
+ huggingface-hub>=0.30.0
+ accelerate>=0.27.0
+ sentencepiece==0.2.0
+ protobuf==4.25.3
+ einops==0.7.0
+ safetensors==0.4.2
+ timm  # imported by model.py
+ transformers  # imported by app.py