|
import logging
import threading
from io import BytesIO

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image

from model import load_model, predict_with_uncertainty
|
|
|
# ASGI application; the route below is registered on it via @app.post.
app = FastAPI()


# Load the model once at import time so it is not reloaded per request.
# NOTE(review): predict_mask calls predict_with_uncertainty(image_tensor) without
# passing this object -- confirm predict_with_uncertainty actually uses it;
# otherwise this load may be redundant.
model = load_model()

# Put the model in evaluation mode (presumably a torch nn.Module -- confirm).
# NOTE(review): if predict_with_uncertainty estimates uncertainty via MC dropout,
# it must re-enable dropout itself, since eval() would disable it -- verify.
model.eval()
|
|
|
def convert_to_image(array, colormap=None):
    """Convert a float array scaled to [0, 1] into a PIL image.

    Args:
        array: NumPy array of values expected in [0, 1]. Values outside that
            range are clipped so they cannot wrap around when cast to uint8.
        colormap: Optional OpenCV colormap id (e.g. ``cv2.COLORMAP_INFERNO``).
            When given, the 8-bit array is colorized before conversion.

    Returns:
        A PIL image built from the 8-bit array. When ``colormap`` is given,
        the colorized pixels are converted from OpenCV's BGR order to RGB.
    """
    # Clip first: e.g. 1.01 * 255 = 257.55 would otherwise wrap to 1 in uint8.
    scaled = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
    if colormap is not None:
        # applyColorMap emits BGR channels, but PIL treats 3-channel arrays as RGB.
        colored_bgr = cv2.applyColorMap(scaled, colormap)
        scaled = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(scaled)
|
|
|
# Resolution the upload is resized to before inference; both output panels share it.
_INPUT_SIZE = (224, 224)
# Mean-prediction threshold above which a pixel is drawn white in the mask panel.
_MASK_THRESHOLD = 0.5
# PIL -> float tensor transform, built once instead of on every request.
_TO_TENSOR = transforms.ToTensor()
# Serialises calls into predict_with_uncertainty now that requests run in worker
# threads. Its internals (and any shared model state it may mutate) are not
# visible from here, so keep inference one-at-a-time as it was on the event loop.
_INFERENCE_LOCK = threading.Lock()


def _decode_upload(raw: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated files also fail here.
        return Image.open(BytesIO(raw)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err


def _as_2d(array) -> np.ndarray:
    """Return *array* as a 2-D NumPy map, dropping singleton axes.

    Raises:
        ValueError: if the map is not 2-D after squeezing.
    """
    squeezed = np.squeeze(np.asarray(array))
    if squeezed.ndim != 2:
        raise ValueError(
            f"expected a 2-D prediction map, got shape {np.shape(array)}"
        )
    return squeezed


def _render_mask(mean_map: np.ndarray) -> Image.Image:
    """Render the thresholded mean prediction as a black/white RGB image."""
    binary = np.where(mean_map > _MASK_THRESHOLD, 255, 0).astype(np.uint8)
    return Image.fromarray(binary).convert("RGB")


def _render_uncertainty(uncertainty_map: np.ndarray) -> Image.Image:
    """Min-max normalise the uncertainty map and render it with INFERNO as RGB."""
    lo, hi = uncertainty_map.min(), uncertainty_map.max()
    # The epsilon keeps a constant map from dividing by zero (it renders as 0).
    normalised = (uncertainty_map - lo) / (hi - lo + 1e-8)
    heat_bgr = cv2.applyColorMap(
        (normalised * 255).astype(np.uint8), cv2.COLORMAP_INFERNO
    )
    # applyColorMap returns BGR; PIL interprets 3-channel arrays as RGB.
    return Image.fromarray(cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB))


def _predict_png(raw: bytes) -> bytes:
    """Decode, run inference and render the mask/uncertainty panels as PNG bytes.

    Synchronous and CPU-bound; meant to run in a worker thread.
    """
    image = _decode_upload(raw).resize(_INPUT_SIZE)
    image_tensor = _TO_TENSOR(image).unsqueeze(0)  # add a batch axis

    # no_grad is thread-local, so it must be entered in this worker thread.
    # NOTE(review): the module-level `model` is not passed here; confirm that
    # predict_with_uncertainty uses the intended model instance.
    with _INFERENCE_LOCK, torch.no_grad():
        preds_mean, preds_uncertainty = predict_with_uncertainty(image_tensor)

    mask_panel = _render_mask(_as_2d(preds_mean))
    heat_panel = _render_uncertainty(_as_2d(preds_uncertainty))

    # Mask on the left, uncertainty heat map on the right.
    combined = Image.new(
        "RGB",
        (
            mask_panel.width + heat_panel.width,
            max(mask_panel.height, heat_panel.height),
        ),
    )
    combined.paste(mask_panel, (0, 0))
    combined.paste(heat_panel, (mask_panel.width, 0))

    buffer = BytesIO()
    combined.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/predict/")
async def predict_mask(file: UploadFile = File(...)):
    """Segment an uploaded image and return a side-by-side PNG.

    The left panel is the binary mask (mean prediction > 0.5); the right panel
    is the per-pixel uncertainty, min-max normalised and colored with INFERNO.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    raw = await file.read()
    # Decoding, inference and rendering are CPU-bound; running them on the
    # event loop would stall every other connection for the whole request.
    png_bytes = await run_in_threadpool(_predict_png, raw)
    return Response(content=png_bytes, media_type="image/png")
|
|
|
if __name__ == "__main__":
    # The root logger defaults to WARNING with no handler configured, so the
    # INFO startup message below would otherwise be silently dropped.
    logging.basicConfig(level=logging.INFO)
    # Log message translates to: "Starting MCP client API service..."
    logging.info("启动MCP客户端API服务...")
    # Listen on all interfaces, port 4011.
    uvicorn.run(app, host="0.0.0.0", port=4011)