Upload 21 files
Browse files- .gitignore +9 -0
- .idea/.gitignore +3 -0
- .idea/image.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +77 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/workspace.xml +49 -0
- Dockerfile +24 -0
- README.md +89 -12
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/model.cpython-39.pyc +0 -0
- aa.py +9 -0
- app.py +128 -0
- data/iamge1.jpg +0 -0
- data/iamge2.jpg +0 -0
- data/iamge3.png +0 -0
- data/prediction_result.png +0 -0
- main.py +60 -0
- model.py +130 -0
- requirements.txt +15 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore virtual environment
|
2 |
+
venv/
|
3 |
+
__pycache__/
|
4 |
+
|
5 |
+
# Ignore model weights (GitHub has a 100MB limit)
|
6 |
+
*.pth
|
7 |
+
|
8 |
+
# Ignore system files
|
9 |
+
.DS_Store
|
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# 默认忽略的文件
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
.idea/image.iml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="jdk" jdkName="Python 3.9 (look)" jdkType="Python SDK" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
+
<option name="ourVersions">
|
6 |
+
<value>
|
7 |
+
<list size="3">
|
8 |
+
<item index="0" class="java.lang.String" itemvalue="2.7" />
|
9 |
+
<item index="1" class="java.lang.String" itemvalue="3.13" />
|
10 |
+
<item index="2" class="java.lang.String" itemvalue="3.9" />
|
11 |
+
</list>
|
12 |
+
</value>
|
13 |
+
</option>
|
14 |
+
</inspection_tool>
|
15 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
16 |
+
<option name="ignoredPackages">
|
17 |
+
<value>
|
18 |
+
<list size="53">
|
19 |
+
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
|
20 |
+
<item index="1" class="java.lang.String" itemvalue="timm" />
|
21 |
+
<item index="2" class="java.lang.String" itemvalue="absl-py" />
|
22 |
+
<item index="3" class="java.lang.String" itemvalue="shapely" />
|
23 |
+
<item index="4" class="java.lang.String" itemvalue="joblib" />
|
24 |
+
<item index="5" class="java.lang.String" itemvalue="threadpoolctl" />
|
25 |
+
<item index="6" class="java.lang.String" itemvalue="huggingface-hub" />
|
26 |
+
<item index="7" class="java.lang.String" itemvalue="PyYAML" />
|
27 |
+
<item index="8" class="java.lang.String" itemvalue="setuptools" />
|
28 |
+
<item index="9" class="java.lang.String" itemvalue="fsspec" />
|
29 |
+
<item index="10" class="java.lang.String" itemvalue="filelock" />
|
30 |
+
<item index="11" class="java.lang.String" itemvalue="qudida" />
|
31 |
+
<item index="12" class="java.lang.String" itemvalue="pip" />
|
32 |
+
<item index="13" class="java.lang.String" itemvalue="certifi" />
|
33 |
+
<item index="14" class="java.lang.String" itemvalue="warmup-scheduler" />
|
34 |
+
<item index="15" class="java.lang.String" itemvalue="portalocker" />
|
35 |
+
<item index="16" class="java.lang.String" itemvalue="nibabel" />
|
36 |
+
<item index="17" class="java.lang.String" itemvalue="h5py" />
|
37 |
+
<item index="18" class="java.lang.String" itemvalue="loguru" />
|
38 |
+
<item index="19" class="java.lang.String" itemvalue="contourpy" />
|
39 |
+
<item index="20" class="java.lang.String" itemvalue="fonttools" />
|
40 |
+
<item index="21" class="java.lang.String" itemvalue="imageio" />
|
41 |
+
<item index="22" class="java.lang.String" itemvalue="fairscale" />
|
42 |
+
<item index="23" class="java.lang.String" itemvalue="matplotlib" />
|
43 |
+
<item index="24" class="java.lang.String" itemvalue="charset-normalizer" />
|
44 |
+
<item index="25" class="java.lang.String" itemvalue="MedPy" />
|
45 |
+
<item index="26" class="java.lang.String" itemvalue="SimpleITK" />
|
46 |
+
<item index="27" class="java.lang.String" itemvalue="idna" />
|
47 |
+
<item index="28" class="java.lang.String" itemvalue="scikit-image" />
|
48 |
+
<item index="29" class="java.lang.String" itemvalue="yapf" />
|
49 |
+
<item index="30" class="java.lang.String" itemvalue="numpy" />
|
50 |
+
<item index="31" class="java.lang.String" itemvalue="requests" />
|
51 |
+
<item index="32" class="java.lang.String" itemvalue="opencv-python-headless" />
|
52 |
+
<item index="33" class="java.lang.String" itemvalue="seaborn" />
|
53 |
+
<item index="34" class="java.lang.String" itemvalue="pthflops" />
|
54 |
+
<item index="35" class="java.lang.String" itemvalue="PyWavelets" />
|
55 |
+
<item index="36" class="java.lang.String" itemvalue="zipp" />
|
56 |
+
<item index="37" class="java.lang.String" itemvalue="lazy_loader" />
|
57 |
+
<item index="38" class="java.lang.String" itemvalue="scipy" />
|
58 |
+
<item index="39" class="java.lang.String" itemvalue="opencv-python" />
|
59 |
+
<item index="40" class="java.lang.String" itemvalue="wheel" />
|
60 |
+
<item index="41" class="java.lang.String" itemvalue="packaging" />
|
61 |
+
<item index="42" class="java.lang.String" itemvalue="addict" />
|
62 |
+
<item index="43" class="java.lang.String" itemvalue="albumentations" />
|
63 |
+
<item index="44" class="java.lang.String" itemvalue="termcolor" />
|
64 |
+
<item index="45" class="java.lang.String" itemvalue="importlib-resources" />
|
65 |
+
<item index="46" class="java.lang.String" itemvalue="typing_extensions" />
|
66 |
+
<item index="47" class="java.lang.String" itemvalue="pytz" />
|
67 |
+
<item index="48" class="java.lang.String" itemvalue="einops" />
|
68 |
+
<item index="49" class="java.lang.String" itemvalue="Pillow" />
|
69 |
+
<item index="50" class="java.lang.String" itemvalue="torch" />
|
70 |
+
<item index="51" class="java.lang.String" itemvalue="torchvision" />
|
71 |
+
<item index="52" class="java.lang.String" itemvalue="torchsummary" />
|
72 |
+
</list>
|
73 |
+
</value>
|
74 |
+
</option>
|
75 |
+
</inspection_tool>
|
76 |
+
</profile>
|
77 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="Black">
|
4 |
+
<option name="sdkName" value="Python 3.9 (look)" />
|
5 |
+
</component>
|
6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (look)" project-jdk-type="Python SDK" />
|
7 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/image.iml" filepath="$PROJECT_DIR$/.idea/image.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/workspace.xml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="AutoImportSettings">
|
4 |
+
<option name="autoReloadType" value="SELECTIVE" />
|
5 |
+
</component>
|
6 |
+
<component name="ChangeListManager">
|
7 |
+
<list default="true" id="2992db4b-be05-4abd-b057-113b2d800a0e" name="更改" comment="" />
|
8 |
+
<option name="SHOW_DIALOG" value="false" />
|
9 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
10 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
11 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
12 |
+
</component>
|
13 |
+
<component name="MarkdownSettingsMigration">
|
14 |
+
<option name="stateVersion" value="1" />
|
15 |
+
</component>
|
16 |
+
<component name="ProjectColorInfo"><![CDATA[{
|
17 |
+
"associatedIndex": 1
|
18 |
+
}]]></component>
|
19 |
+
<component name="ProjectId" id="2y2VrJeUMs6MgMG5dRwdH89cBQE" />
|
20 |
+
<component name="ProjectViewState">
|
21 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
22 |
+
<option name="showLibraryContents" value="true" />
|
23 |
+
</component>
|
24 |
+
<component name="PropertiesComponent"><![CDATA[{
|
25 |
+
"keyToString": {
|
26 |
+
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
27 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
28 |
+
"last_opened_file_path": "E:/挑战杯/image"
|
29 |
+
}
|
30 |
+
}]]></component>
|
31 |
+
<component name="SharedIndexes">
|
32 |
+
<attachedChunks>
|
33 |
+
<set>
|
34 |
+
<option value="bundled-python-sdk-09665e90c3a7-d3b881c8e49f-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-233.15026.15" />
|
35 |
+
</set>
|
36 |
+
</attachedChunks>
|
37 |
+
</component>
|
38 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
|
39 |
+
<component name="TaskManager">
|
40 |
+
<task active="true" id="Default" summary="默认任务">
|
41 |
+
<changelist id="2992db4b-be05-4abd-b057-113b2d800a0e" name="更改" comment="" />
|
42 |
+
<created>1749032852317</created>
|
43 |
+
<option name="number" value="Default" />
|
44 |
+
<option name="presentableId" value="Default" />
|
45 |
+
<updated>1749032852317</updated>
|
46 |
+
</task>
|
47 |
+
<servers />
|
48 |
+
</component>
|
49 |
+
</project>
|
Dockerfile
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use the official Python image as a base
FROM python:3.9-slim

# Install system dependencies for OpenCV and wget.
# libgl1 is used instead of libgl1-mesa-glx: the latter is only a transitional
# package and is no longer available on newer Debian releases, which makes the
# apt-get install step fail on current python:3.9-slim images.
RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 wget && rm -rf /var/lib/apt/lists/*

# Set working directory inside the container
WORKDIR /app

# Copy and install dependencies (done before copying the code so this layer is cached)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Download model weights from Hugging Face
RUN wget -O best_swin_upernet_main.pth "https://huggingface.co/samyakshrestha/swin-medical-segmentation/resolve/main/best_swin_upernet_main.pth?download=true"

# Copy the rest of the app code
COPY . .

# Expose the port
EXPOSE 8000

# Run FastAPI with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
@@ -1,12 +1,89 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Skin Lesion Segmentation API with Uncertainty Estimation
|
2 |
+
|
3 |
+
This repository contains a production-ready FastAPI service for performing semantic segmentation of skin lesions, trained on the ISIC 2018 dataset. The segmentation model uses a Swin-Tiny transformer encoder with a custom UPerNet decoder, and incorporates Monte Carlo (MC) Dropout at inference time to estimate predictive uncertainty.
|
4 |
+
|
5 |
+
## Features
|
6 |
+
|
7 |
+
- RESTful FastAPI endpoint for real-time image segmentation
|
8 |
+
- Swin-Tiny + UPerNet architecture for accurate lesion boundary detection
|
9 |
+
- Monte Carlo Dropout-based uncertainty quantification
|
10 |
+
- Heatmap visualization of model uncertainty using OpenCV colormaps
|
11 |
+
- Side-by-side output of the predicted segmentation mask and uncertainty map
|
12 |
+
- Clean, modular codebase suitable for further extension and deployment
|
13 |
+
|
14 |
+
## Inference Workflow
|
15 |
+
|
16 |
+
1. The API accepts a dermoscopic image uploaded via a POST request to the `/predict/` endpoint.
|
17 |
+
2. The image is resized, normalized, and passed through the segmentation model.
|
18 |
+
3. The model performs multiple stochastic forward passes using MC Dropout.
|
19 |
+
4. The mean prediction is binarized to produce a segmentation mask.
|
20 |
+
5. The standard deviation across predictions is visualized as a heatmap to express uncertainty.
|
21 |
+
6. The API returns both the binary mask and uncertainty heatmap as a combined PNG image.
|
22 |
+
|
23 |
+
## File Structure
|
24 |
+
|
25 |
+
| File | Description |
|
26 |
+
|-------------------|--------------------------------------------------------------|
|
27 |
+
| `main.py` | FastAPI application and endpoint logic |
|
28 |
+
| `model.py` | Model architecture and MC Dropout inference utilities |
|
29 |
+
| `requirements.txt`| Required Python packages |
|
30 |
+
| `.gitignore` | Specifies ignored files (e.g., `venv/`, model weights) |
|
31 |
+
|
32 |
+
## Requirements
|
33 |
+
|
34 |
+
- Python 3.9+
|
35 |
+
- PyTorch
|
36 |
+
- FastAPI
|
37 |
+
- OpenCV
|
38 |
+
- Pillow
|
39 |
+
- timm
|
40 |
+
- torchvision
|
41 |
+
- uvicorn
|
42 |
+
|
43 |
+
Install dependencies:
|
44 |
+
|
45 |
+
```bash
|
46 |
+
pip install -r requirements.txt
|
47 |
+
```
|
48 |
+
|
49 |
+
|
50 |
+
## Model Weights
|
51 |
+
|
52 |
+
To run the API, you will need to download the pretrained model weights and place the file in the root directory of this repository.
|
53 |
+
|
54 |
+
**Download link:**
|
55 |
+
|
56 |
+
[best_swin_upernet_main.pth](https://huggingface.co/samyakshrestha/swin-medical-segmentation/resolve/main/best_swin_upernet_main.pth?download=true)
|
57 |
+
|
58 |
+
After downloading, ensure the file is named exactly "best_swin_upernet_main.pth" and saved in the same directory as `main.py` and `model.py`.
|
59 |
+
|
60 |
+
This weight file contains the Swin-Tiny UPerNet model trained on the ISIC 2018 skin lesion segmentation dataset.
|
61 |
+
|
62 |
+
|
63 |
+
## Dataset
|
64 |
+
|
65 |
+
This project uses the ISIC 2018 Challenge Dataset for training, which contains dermoscopic images of skin lesions annotated for binary segmentation.
|
66 |
+
|
67 |
+
Dataset access: https://challenge.isic-archive.com/data/
|
68 |
+
|
69 |
+
## Docker Support
|
70 |
+
|
71 |
+
This project includes a `Dockerfile` for containerized deployment. Run the API with Docker:
|
72 |
+
|
73 |
+
### Build the Docker image
|
74 |
+
```bash
|
75 |
+
docker build -t medical-segmentation-api .
|
76 |
+
```
|
77 |
+
|
78 |
+
### Run the container
|
79 |
+
```bash
|
80 |
+
docker run -p 8000:8000 medical-segmentation-api
|
81 |
+
```
|
82 |
+
|
83 |
+
|
84 |
+
The API will then be available at:
|
85 |
+
http://localhost:8000/predict/
|
86 |
+
|
87 |
+
## Author
|
88 |
+
|
89 |
+
Samyak Shrestha
|
__pycache__/model.cpython-310.pyc
ADDED
Binary file (4.8 kB). View file
|
|
__pycache__/model.cpython-39.pyc
ADDED
Binary file (4.82 kB). View file
|
|
aa.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
# Manual smoke test: upload one image to the local /predict/ endpoint and
# save the PNG that comes back.
url = "http://127.0.0.1:4011/predict/"

# Raw strings keep the Windows backslashes literal; sequences such as "\z" or
# "\i" are invalid escapes in a normal string literal.
image_path = r'D:\zkwg\image\data\iamge1.jpg'  # replace with the path to your image
result_path = r'D:\zkwg\image\data\prediction_result.png'

# The with-block closes the upload handle once the request has been sent.
with open(image_path, 'rb') as image_file:
    files = {'file': image_file}
    response = requests.post(url, files=files)

# Fail loudly on an HTTP error instead of writing an error body to a .png file.
response.raise_for_status()

# Save the returned prediction result.
with open(result_path, 'wb') as f:
    f.write(response.content)
|
app.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from model import load_model, predict_with_uncertainty
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from transformers import AutoProcessor, AutoModelForImageTextToText
|
9 |
+
from io import BytesIO
|
10 |
+
|
11 |
+
# Configure the Streamlit page (title, icon, wide layout)
|
12 |
+
st.set_page_config(
|
13 |
+
page_title="医学图像分析系统",
|
14 |
+
page_icon="🏥",
|
15 |
+
layout="wide"
|
16 |
+
)
|
17 |
+
|
18 |
+
# Initialise the models
@st.cache_resource
def load_models():
    """Load the segmentation model and the MedGemma analysis model.

    Decorated with ``st.cache_resource`` so the models are created once and
    reused instead of being rebuilt on every Streamlit rerun.

    The Hugging Face access token is read from the ``HUGGINGFACE_TOKEN``
    environment variable. When the variable is unset, ``None`` is passed and
    the Hugging Face libraries fall back to their own credential lookup.

    Returns:
        tuple: ``(seg_model, analysis_model, processor)``.
    """
    import os  # local import so this block does not depend on file-level imports

    # Never hard-code credentials; the literal string "HUGGINGFACE_TOKEN" is
    # not a valid token and would be rejected when fetching the model.
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")

    # Segmentation model
    seg_model = load_model()
    seg_model.eval()

    # Image-text analysis model
    model_id = "google/medgemma-4b-it"
    analysis_model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
    )
    processor = AutoProcessor.from_pretrained(
        model_id,
        token=hf_token,
    )

    return seg_model, analysis_model, processor
|
39 |
+
|
40 |
+
# NOTE(review): the indentation of this script was lost in the source this was
# recovered from; the nesting below is the most plausible reconstruction
# (upload + button in the left column, results rendered in the right column).

# Page title
st.title("🏥 医学图像分析系统")
st.markdown("---")

# Load the models
with st.spinner("正在加载模型..."):
    seg_model, analysis_model, processor = load_models()

# Two-column layout: upload on the left, results on the right
col1, col2 = st.columns(2)

with col1:
    st.subheader("📤 上传图片")
    uploaded_file = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        # Show the original image
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="原始图片", use_column_width=True)

        # Run the analysis on demand
        if st.button("开始分析"):
            with st.spinner("正在处理..."):
                # Segmentation input: 224x224 tensor with a leading batch axis
                image_resized = image.resize((224, 224))
                transform = transforms.ToTensor()
                image_tensor = transform(image_resized).unsqueeze(0)

                # MC Dropout segmentation
                preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)

                # Binary mask (0 or 255)
                pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
                mask_image = Image.fromarray(pred_binary).convert("L")

                # Min-max normalise the uncertainty map and colour it
                uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
                uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
                # OpenCV returns BGR channel order while PIL expects RGB;
                # convert so the heatmap colours are not channel-swapped.
                uncertainty_image = Image.fromarray(cv2.cvtColor(uncertainty_colormap, cv2.COLOR_BGR2RGB))

                # Place mask and heatmap side by side
                combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height))
                combined.paste(mask_image.convert("RGB"), (0, 0))
                combined.paste(uncertainty_image, (mask_image.width, 0))

                # Text analysis of the original image
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": "你是一名皮肤病专家,请使用中文分析图片."}]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "这是一张皮肤病的图片,帮我分析一下"},
                            {"type": "image", "image": image}
                        ]
                    }
                ]

                inputs = processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=True,
                    return_dict=True, return_tensors="pt"
                ).to(analysis_model.device, dtype=torch.bfloat16)

                # Remember the prompt length so only newly generated tokens are decoded
                input_len = inputs["input_ids"].shape[-1]

                with torch.inference_mode():
                    generation = analysis_model.generate(**inputs, max_new_tokens=200, do_sample=False)
                    generation = generation[0][input_len:]

                analysis_text = processor.decode(generation, skip_special_tokens=True)

                # Show the results
                with col2:
                    st.subheader("📊 分析结果")
                    st.image(combined, caption="分割结果", use_column_width=True)
                    st.markdown("### 📝 图像分析")
                    st.write(analysis_text)

# Footer with usage instructions
st.markdown("---")
st.markdown("### 使用说明")
st.markdown("""
1. 在左侧上传一张医学图像
2. 点击"开始分析"按钮
3. 系统将自动进行图像分割和分析
4. 右侧将显示分割结果和分析报告
""")
|
data/iamge1.jpg
ADDED
![]() |
data/iamge2.jpg
ADDED
![]() |
data/iamge3.png
ADDED
![]() |
data/prediction_result.png
ADDED
![]() |
main.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from fastapi import FastAPI, UploadFile, File
|
3 |
+
from fastapi.responses import Response
|
4 |
+
import uvicorn
|
5 |
+
from model import load_model, predict_with_uncertainty
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from PIL import Image
|
8 |
+
from io import BytesIO
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import torch
|
12 |
+
|
13 |
+
app = FastAPI()  # application object; handed to uvicorn.run() in the __main__ guard of this file
|
14 |
+
|
15 |
+
# Load the model when the API starts
|
16 |
+
model = load_model()
|
17 |
+
model.eval()  # load_model() in model.py already calls eval(); this repeat is harmless
|
18 |
+
|
19 |
+
def convert_to_image(array, colormap=None):
    """Convert a float array with values in [0, 1] to a PIL image.

    Args:
        array: NumPy array of floats; values are clipped to [0, 1] before
            scaling so out-of-range inputs cannot wrap around in the uint8 cast.
        colormap: Optional OpenCV colormap constant (e.g.
            ``cv2.COLORMAP_INFERNO``). When given, the array is colour-mapped.

    Returns:
        PIL.Image.Image: the scaled array as an image when no colormap is
        given, otherwise the colour-mapped image in RGB channel order.
    """
    scaled = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)  # map to the 0-255 range
    if colormap is None:
        return Image.fromarray(scaled)

    # cv2.applyColorMap produces BGR; PIL interprets 3-channel arrays as RGB.
    colored_bgr = cv2.applyColorMap(scaled, colormap)
    return Image.fromarray(cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB))
|
24 |
+
|
25 |
+
@app.post("/predict/")
async def predict_mask(file: UploadFile = File(...)):
    """Segment an uploaded image and return mask + uncertainty heatmap as a PNG.

    The upload is decoded as RGB, resized to 224x224 and passed to
    ``predict_with_uncertainty``. The response image shows the binary mask
    (threshold 0.5) on the left and the colour-mapped uncertainty on the right.

    Args:
        file: The uploaded image file.

    Returns:
        Response: PNG bytes with media type ``image/png``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    from fastapi import HTTPException  # local import; fastapi is already a file dependency

    # Read and preprocess the image
    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
    except OSError as exc:
        # PIL signals unreadable/unknown image data with OSError subclasses.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc
    image = image.resize((224, 224))
    transform = transforms.ToTensor()
    image_tensor = transform(image).unsqueeze(0)  # add the batch axis

    # Perform MC Dropout inference
    preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)

    # Binary mask (0 or 255)
    pred_binary = (preds_mean > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(pred_binary).convert("L")

    # Normalize and apply colormap to uncertainty
    uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (preds_uncertainty.max() - preds_uncertainty.min() + 1e-8)
    uncertainty_colormap = cv2.applyColorMap((uncertainty * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
    # OpenCV returns BGR; convert to RGB so PIL renders the intended colours.
    uncertainty_image = Image.fromarray(cv2.cvtColor(uncertainty_colormap, cv2.COLOR_BGR2RGB))

    # Combine side by side
    combined = Image.new("RGB", (mask_image.width + uncertainty_image.width, mask_image.height))
    combined.paste(mask_image.convert("RGB"), (0, 0))
    combined.paste(uncertainty_image, (mask_image.width, 0))

    # Encode to PNG in memory; getvalue() returns the whole buffer regardless
    # of the stream position, so no seek is needed.
    img_io = BytesIO()
    combined.save(img_io, format="PNG")

    return Response(content=img_io.getvalue(), media_type="image/png")
|
56 |
+
|
57 |
+
if __name__ == "__main__":
    # Direct-launch mode, usable on any operating system.
    # The root logger defaults to WARNING, so without this call the INFO
    # start-up message below would never be emitted.
    logging.basicConfig(level=logging.INFO)
    logging.info("启动MCP客户端API服务...")
    uvicorn.run(app, host="0.0.0.0", port=4011)
|
model.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import timm
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
# -------------------------------
|
9 |
+
# Define Pyramid Pooling Module (with GroupNorm)
|
10 |
+
# -------------------------------
|
11 |
+
class PyramidPoolingModule(nn.Module):
    """Pyramid pooling over a feature map, with GroupNorm in each branch.

    Each branch adaptively average-pools the input to ``pool_size``, reduces
    it to ``in_channels // 4`` channels with a 1x1 convolution, applies
    GroupNorm (8 groups) and ReLU. Branch outputs are resized back to the
    input's spatial size, concatenated with the input along the channel axis
    and projected back to ``in_channels`` channels by a final 1x1 convolution.

    Args:
        in_channels: Channel count of the input feature map.
            ``in_channels // 4`` must be divisible by 8 for GroupNorm.
        pool_sizes: Output sizes of the adaptive pooling branches.
    """

    def __init__(self, in_channels, pool_sizes=(1, 2, 3, 6)):
        super().__init__()
        branch_channels = in_channels // 4
        self.pool_layers = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, bias=False),
                nn.GroupNorm(num_groups=8, num_channels=branch_channels),
                nn.ReLU(inplace=True),
            )
            for pool_size in pool_sizes
        ])
        # Input channels plus one reduced branch per pool size.
        total_channels = in_channels + len(pool_sizes) * branch_channels
        self.conv = nn.Conv2d(total_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the pooled-and-fused feature map; same shape as ``x``."""
        spatial_size = x.shape[2:]
        pooled_features = [x]
        pooled_features.extend(
            F.interpolate(layer(x), size=spatial_size, mode='bilinear', align_corners=False)
            for layer in self.pool_layers
        )
        return self.conv(torch.cat(pooled_features, dim=1))
|
34 |
+
|
35 |
+
# -------------------------------
|
36 |
+
# Define UPerNet Decoder (With Dropout)
|
37 |
+
# -------------------------------
|
38 |
+
class UPerNetDecoder(nn.Module):
    """Top-down decoder that fuses four encoder feature maps into one mask.

    The last feature map goes through a ``PyramidPoolingModule``; the result
    is then repeatedly upsampled, added to a 1x1-projected lateral feature
    map, and passed through a 1x1 convolution followed by ``Dropout2d``.
    The fused map is resized to ``output_size`` and a 1x1 segmentation head
    produces ``num_classes`` output channels.

    Args:
        encoder_channels: Channel counts of the four encoder feature maps, in
            the same order as the ``features`` passed to ``forward``.
        num_classes: Number of output channels of the segmentation head.
        dropout_rate: Drop probability of the ``Dropout2d`` layers.
        output_size: ``(height, width)`` the fused features are resized to
            before the segmentation head. Defaults to ``(224, 224)``.
    """

    def __init__(self, encoder_channels, num_classes=1, dropout_rate=0.1, output_size=(224, 224)):
        super().__init__()
        c0, c1, c2, c3 = encoder_channels[0], encoder_channels[1], encoder_channels[2], encoder_channels[-1]
        self.output_size = tuple(output_size)

        self.ppm = PyramidPoolingModule(c3)

        self.lateral_conv2 = nn.Conv2d(c2, c3, kernel_size=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(c3, c2, kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        self.lateral_conv1 = nn.Conv2d(c1, c2, kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c2, c1, kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        self.lateral_conv0 = nn.Conv2d(c0, c1, kernel_size=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(c1, c0, kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        self.segmentation_head = nn.Conv2d(c0, num_classes, kernel_size=1)

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return F.interpolate(tensor, size=reference.shape[2:], mode="bilinear", align_corners=False)

    def forward(self, features):
        """Decode four channels-first feature maps into segmentation logits.

        Args:
            features: Sequence ``(f0, f1, f2, f3)`` of feature maps whose
                channel counts match ``encoder_channels``.

        Returns:
            Tensor with ``num_classes`` channels and spatial size ``output_size``.
        """
        f0, f1, f2, f3 = features

        fuse2 = self.conv3(self._resize_like(self.ppm(f3), f2) + self.lateral_conv2(f2))
        fuse1 = self.conv2(self._resize_like(fuse2, f1) + self.lateral_conv1(f1))
        fuse0 = self.conv1(self._resize_like(fuse1, f0) + self.lateral_conv0(f0))

        x_out = F.interpolate(fuse0, size=self.output_size, mode="bilinear", align_corners=False)
        return self.segmentation_head(x_out)
|
74 |
+
|
75 |
+
# -------------------------------
|
76 |
+
# Define Swin-Tiny UPerNet Model
|
77 |
+
# -------------------------------
|
78 |
+
class SwinTinyUPerNet(nn.Module):
    """Swin-Tiny encoder (from timm) combined with a ``UPerNetDecoder``.

    Args:
        num_classes: Number of output channels of the segmentation head.
        dropout_rate: Drop probability forwarded to the decoder.
        pretrained: Passed to ``timm.create_model``. The default ``True``
            keeps the original behaviour; pass ``False`` to skip fetching
            pretrained encoder weights when a full checkpoint is loaded
            right after construction.
    """

    def __init__(self, num_classes=1, dropout_rate=0.1, pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(
            "swin_tiny_patch4_window7_224.ms_in22k_ft_in1k",
            pretrained=pretrained,
            features_only=True,
        )
        encoder_channels = self.encoder.feature_info.channels()
        self.decoder = UPerNetDecoder(encoder_channels, num_classes, dropout_rate=dropout_rate)

    def forward(self, x):
        """Return segmentation logits resized to 224x224 for input batch ``x``."""
        features = self.encoder(x)
        # Move the last axis of every 4-D feature map to position 1 so the
        # decoder's Conv2d layers see channels-first tensors.
        # NOTE(review): assumes the encoder emits channels-last maps - confirm
        # against the installed timm version.
        features = [f.permute(0, 3, 1, 2) if f.dim() == 4 else f for f in features]
        output = self.decoder(features)
        return F.interpolate(output, size=(224, 224), mode="bilinear", align_corners=False)
|
94 |
+
|
95 |
+
# -------------------------------
|
96 |
+
# Load the Model
|
97 |
+
# -------------------------------
|
98 |
+
def load_model():
    """Build ``SwinTinyUPerNet`` and load weights from the local checkpoint.

    The checkpoint ``best_swin_upernet_main.pth`` is read from the current
    working directory onto the CPU. Loading stays non-strict, but any missing
    or unexpected keys are reported with a warning rather than ignored, since
    a mismatch means part of the network runs with weights that did not come
    from the checkpoint.

    Returns:
        SwinTinyUPerNet: the model in eval mode.
    """
    import warnings  # local import so this block does not depend on file-level imports

    model = SwinTinyUPerNet(num_classes=1)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is the safe way to read a checkpoint downloaded from the internet.
    state_dict = torch.load(
        "best_swin_upernet_main.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )

    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint does not fully match the model: "
            f"{len(load_result.missing_keys)} missing key(s), "
            f"{len(load_result.unexpected_keys)} unexpected key(s).",
            stacklevel=2,
        )

    model.eval()
    return model
|
103 |
+
|
104 |
+
# -------------------------------
|
105 |
+
# Enable Dropout at Inference Time
|
106 |
+
# -------------------------------
|
107 |
+
def enable_dropout(m):
    """Put ``m`` into train mode if it is a ``Dropout`` or ``Dropout2d`` layer.

    Intended for ``model.apply(enable_dropout)`` on a model that is otherwise
    in eval mode, so dropout stays stochastic during inference. Modules of
    any other type are left untouched.

    Args:
        m: Any ``nn.Module``.
    """
    if isinstance(m, (nn.Dropout, nn.Dropout2d)):
        m.train()
|
110 |
+
|
111 |
+
# -------------------------------
|
112 |
+
# Perform Inference with MC Dropout
|
113 |
+
# -------------------------------
|
114 |
+
def predict_with_uncertainty(image_tensor, num_samples=10, model=None):
    """Run MC Dropout inference and return mean prediction and uncertainty.

    The model is evaluated ``num_samples`` times with its ``Dropout`` /
    ``Dropout2d`` layers in train mode. The sigmoid outputs are averaged into
    the mean prediction, and their standard deviation, min-max normalised to
    [0, 1], is the uncertainty map.

    Args:
        image_tensor: Input batch passed straight to the model.
        num_samples: Number of stochastic forward passes (at least 1).
        model: Optional model to use. When omitted, the model from
            ``load_model()`` is loaded on first use and cached on this
            function, so later calls do not rebuild the network or re-read
            the checkpoint.

    Returns:
        tuple: ``(preds_mean, preds_uncertainty)`` as NumPy arrays with
        singleton axes squeezed out.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    if model is None:
        model = getattr(predict_with_uncertainty, "_cached_model", None)
        if model is None:
            model = load_model()
            predict_with_uncertainty._cached_model = model

    # Enable only the dropout layers that are currently in eval mode, and
    # remember them so the model's state can be restored afterwards.
    dropout_layers = [
        layer for layer in model.modules()
        if isinstance(layer, (nn.Dropout, nn.Dropout2d)) and not layer.training
    ]
    for layer in dropout_layers:
        layer.train()

    try:
        with torch.no_grad():
            preds_array = torch.stack(
                [torch.sigmoid(model(image_tensor)) for _ in range(num_samples)],
                dim=0,
            )
    finally:
        for layer in dropout_layers:
            layer.eval()

    preds_mean = preds_array.mean(dim=0).squeeze().cpu().numpy()
    if num_samples > 1:
        preds_uncertainty = preds_array.std(dim=0).squeeze().cpu().numpy()
    else:
        # The sample standard deviation of a single pass is undefined (NaN);
        # report zero uncertainty instead.
        preds_uncertainty = np.zeros_like(preds_mean)

    # Normalize uncertainty map
    spread = preds_uncertainty.max() - preds_uncertainty.min()
    preds_uncertainty = (preds_uncertainty - preds_uncertainty.min()) / (spread + 1e-8)
    return preds_mean, preds_uncertainty
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.32.0
|
2 |
+
torch==2.2.0
|
3 |
+
torchvision==0.17.0
|
4 |
+
Pillow==10.2.0
|
5 |
+
numpy==1.26.4
|
6 |
+
opencv-python==4.9.0.80
|
7 |
+
fastapi==0.110.0
|
8 |
+
uvicorn==0.27.1
|
9 |
+
python-multipart==0.0.9
|
10 |
+
huggingface-hub>=0.30.0
|
11 |
+
accelerate>=0.27.0
|
12 |
+
sentencepiece==0.2.0
|
13 |
+
protobuf==4.25.3
|
14 |
+
einops==0.7.0
|
15 |
+
safetensors==0.4.2
|