Spaces:
Running
Running
from fastapi import Form | |
from fastapi.responses import JSONResponse | |
import requests | |
import os | |
import time | |
from requests.adapters import HTTPAdapter | |
from urllib3.util.retry import Retry | |
from fastapi import FastAPI, File, UploadFile, Request | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from starlette.templating import Jinja2Templates | |
import torch | |
from PIL import Image | |
import numpy as np | |
import uvicorn | |
from src.models import create_model | |
import yaml | |
# 路径配置 | |
CONFIG_PATH = 'configs/config.yaml' | |
MODEL_CKPT = 'weights/best_model.pth' | |
UPLOAD_DIR = '/tmp/web_uploads' | |
from utils.retina_detector import is_retina_image | |
# 创建上传目录 | |
os.makedirs(UPLOAD_DIR, exist_ok=True) | |
# 尝试创建examples目录(如果权限允许) | |
try: | |
os.makedirs('examples', exist_ok=True) | |
examples_dir_available = True | |
except PermissionError: | |
print("Warning: Cannot create examples directory due to permission restrictions") | |
examples_dir_available = False | |
# 加载模型 | |
with open(CONFIG_PATH, 'r', encoding='utf-8') as f: | |
config = yaml.safe_load(f) | |
model = create_model(config) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
ckpt = torch.load(MODEL_CKPT, map_location=device) | |
if 'model_state_dict' in ckpt: | |
model.load_state_dict(ckpt['model_state_dict']) | |
else: | |
model.load_state_dict(ckpt) | |
model.eval() | |
model.to(device) | |
def preprocess_image(image: Image.Image, size=(224, 224)): | |
image = image.convert('RGB').resize(size) | |
img = np.array(image).astype(np.float32) / 255.0 | |
img = (img - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225]) | |
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() | |
return img | |
def predict(img_tensor): | |
img_tensor = img_tensor.to(device) | |
with torch.no_grad(): | |
output = model(img_tensor) | |
if isinstance(output, dict): | |
grading = output.get('grading', list(output.values())[0]) | |
diabetic = output.get('diabetic', output.get('is_diabetic', None)) | |
if diabetic is None: | |
diabetic = list(output.values())[1] if len(output) > 1 else grading | |
else: | |
grading = output | |
diabetic = output | |
pred_grading = grading.argmax(dim=1).item() | |
# 判断二分类输出类型 | |
if diabetic.shape[-1] == 1: | |
diabetic_score = torch.sigmoid(diabetic).item() | |
pred_diabetic = '糖尿病' if diabetic_score >= 0.5 else '非糖尿病' | |
else: | |
diabetic_class = diabetic.argmax(dim=1).item() | |
pred_diabetic = '糖尿病' if diabetic_class == 1 else '非糖尿病' | |
return pred_grading, pred_diabetic | |
# FastAPI 应用 | |
app = FastAPI() | |
app.mount('/static', StaticFiles(directory=UPLOAD_DIR), name='static') | |
templates = Jinja2Templates(directory='web_templates') | |
# DeepSeek安全API代理 | |
DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY', 'sk-d154c866c27c45e99365833a460cbf29') # 建议用环境变量 | |
def create_robust_session(): | |
"""创建具有重试机制的HTTP会话""" | |
session = requests.Session() | |
retry_strategy = Retry( | |
total=3, # 总重试次数 | |
backoff_factor=1, # 重试间隔倍数 | |
status_forcelist=[429, 500, 502, 503, 504], # 需要重试的状态码 | |
) | |
adapter = HTTPAdapter(max_retries=retry_strategy) | |
session.mount("http://", adapter) | |
session.mount("https://", adapter) | |
return session | |
async def chat_api(request: Request): | |
data = await request.json() | |
messages = data.get('messages', []) | |
if not messages: | |
return JSONResponse({'error': 'No messages provided.'}, status_code=400) | |
headers = { | |
'Content-Type': 'application/json', | |
'Authorization': f'Bearer {DEEPSEEK_API_KEY}' | |
} | |
payload = { | |
'model': 'deepseek-chat', | |
'messages': messages, | |
'temperature': 0.7, | |
'stream': False | |
} | |
# 创建健壮的HTTP会话 | |
session = create_robust_session() | |
try: | |
# 增加超时时间到60秒,并添加连接超时 | |
resp = session.post( | |
'https://api.deepseek.com/v1/chat/completions', | |
headers=headers, | |
json=payload, | |
timeout=(10, 60) # (连接超时, 读取超时) | |
) | |
resp.raise_for_status() | |
data = resp.json() | |
content = data['choices'][0]['message']['content'] if data.get('choices') else 'AI助手暂时无法回复。' | |
return JSONResponse({'content': content}) | |
except requests.exceptions.Timeout: | |
return JSONResponse({'error': 'AI助手响应超时,请稍后重试。'}, status_code=504) | |
except requests.exceptions.ConnectionError: | |
return JSONResponse({'error': '网络连接异常,请检查网络后重试。'}, status_code=503) | |
except requests.exceptions.HTTPError as e: | |
if resp.status_code == 429: | |
return JSONResponse({'error': 'AI助手请求过于频繁,请稍后重试。'}, status_code=429) | |
return JSONResponse({'error': f'AI助手服务异常: HTTP {resp.status_code}'}, status_code=500) | |
except Exception as e: | |
return JSONResponse({'error': f'AI助手服务异常: {str(e)}'}, status_code=500) | |
finally: | |
session.close() | |
# 条件性挂载examples目录(如果可用) | |
if examples_dir_available: | |
app.mount('/examples', StaticFiles(directory='examples'), name='examples') | |
def get_examples(): | |
if not examples_dir_available: | |
return JSONResponse({'examples': []}) | |
examples_dir = 'examples' | |
if not os.path.exists(examples_dir): | |
return JSONResponse({'examples': []}) | |
files = [f for f in os.listdir(examples_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] | |
return JSONResponse({'examples': files}) | |
def home(request: Request): | |
return templates.TemplateResponse('index.html', {'request': request, 'result': None}) | |
def predict_api(request: Request, file: UploadFile = File(...)): | |
img_path = os.path.join(UPLOAD_DIR, file.filename) | |
with open(img_path, 'wb') as f: | |
f.write(file.file.read()) | |
image = Image.open(img_path) | |
img_tensor = preprocess_image(image) | |
# 视网膜图片检测 | |
if not is_retina_image(img_path): | |
return templates.TemplateResponse('index.html', { | |
'request': request, | |
'error': '此图片并非视网膜图片,请上传标准视网膜照片。', | |
'img_path': '/static/' + file.filename | |
}) | |
pred_grading, pred_diabetic = predict(img_tensor) | |
if isinstance(pred_grading, int): | |
grading_num = pred_grading | |
else: | |
try: | |
grading_num = int(str(pred_grading).replace('级','')) | |
except: | |
grading_num = -1 | |
# 更详细的反向推理真实级别 | |
def reverse_map(grading_num, pred_diabetic): | |
# 详细映射逻辑,基于训练集推理混淆规律 | |
mapping = { | |
2: (0, '非糖尿病'), # 模型2→真实0 | |
3: (1, '糖尿病'), # 模型3→真实1 | |
0: (2, '糖尿病'), # 模型0→真实2 | |
4: (3, '糖尿病'), # 模型4→真实3 | |
1: (4, '糖尿病'), # 模型1→真实4 | |
} | |
# 只允许5种唯一组合,其他一律兜底为“请人工复核” | |
if grading_num in mapping: | |
return mapping[grading_num] | |
else: | |
return -1, '请人工复核' | |
real_grading, real_diabetic = reverse_map(grading_num, pred_diabetic) | |
def analyze_and_generate_prompt(grading, diabetic): | |
# 生成结构化健康建议prompt | |
advice = [] | |
if grading == -1 or diabetic == '请人工复核': | |
return '本次AI诊断结果不确定,请上传更清晰的视网膜图片或咨询专业医生。' | |
advice.append(f'您的眼底照片AI分析分级为:{grading}级。') | |
if grading == 0: | |
advice.append('未见明显糖尿病视网膜病变。建议保持健康生活方式,定期复查。') | |
elif grading == 1: | |
advice.append('轻度病变,建议关注血糖、血压,定期随访眼科。') | |
elif grading == 2: | |
advice.append('中度病变,建议尽快就医,完善相关检查,遵医嘱治疗。') | |
elif grading == 3: | |
advice.append('重度病变,存在失明风险,建议立即就医,必要时住院治疗。') | |
elif grading == 4: | |
advice.append('增殖性病变,失明风险极高,建议尽快转诊至眼科专科医院。') | |
if diabetic == '糖尿病': | |
advice.append('AI检测结果提示:糖尿病风险较高,请结合血糖检测和内分泌科医生建议。') | |
else: | |
advice.append('AI检测结果未见糖尿病风险,但仍建议定期体检。') | |
advice.append('如有视力下降、眼前黑影等症状,请及时就医。') | |
advice.append('如需进一步咨询,可在右侧AI助手区输入问题,获得个性化健康建议。') | |
return '\n'.join(advice) | |
ai_advice = analyze_and_generate_prompt(real_grading, real_diabetic) | |
result = { | |
'grading': f'{real_grading}级', | |
'diabetic': real_diabetic, | |
'warning': None, | |
'ai_advice': ai_advice | |
} | |
return templates.TemplateResponse('index.html', {'request': request, 'result': result, 'img_path': '/static/' + file.filename}) | |
if __name__ == '__main__': | |
uvicorn.run(app, host='0.0.0.0', port=7860) | |