Spaces:
Running
Running
""" | |
SarcoAdvisor FastAPI主应用 | |
肌少症风险评估和个性化建议系统 | |
""" | |
import logging | |
import time | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, HTTPException, Request, BackgroundTasks | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
# 导入自定义模块 | |
from schemas.user_input import ( | |
UserInput, ScreeningRequest, AdvisoryRequest, | |
ScreeningResponse, AdvisoryResponse, ErrorResponse | |
) | |
from models.screening_models import screening_service | |
from models.advisory_models import advisory_service | |
from utils.model_loader import model_manager | |
# Logging: INFO level, each record prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# 应用生命周期管理 | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models and DiCE explainers, then log shutdown.

    Startup loads all models via ``model_manager.load_all_models()`` and
    initialises DiCE via ``advisory_service.initialize_dice()``. Any failure is
    logged with its traceback and re-raised, which aborts application startup.

    The ``@asynccontextmanager`` decorator (already imported by this module)
    makes this a proper lifespan context manager. Starlette only accepts a bare
    async-generator lifespan through a deprecated compatibility path.

    Args:
        app: The FastAPI application (unused).
    """
    logger.info("🚀 启动SarcoAdvisor Web应用...")
    try:
        logger.info("📊 加载机器学习模型...")
        model_manager.load_all_models()
        logger.info("✅ 模型加载完成")

        logger.info("🚀 初始化DiCE解释器 (终极精度模式 - 追求最高质量,无任何时间限制)...")
        advisory_service.initialize_dice()
        logger.info("✅ DiCE解释器初始化完成 - 已启用终极精度模式 (500个候选,25倍多样性权重)")
        logger.info("🎯 SarcoAdvisor服务就绪!")
    except Exception as e:
        # logger.exception keeps the traceback; the bare message is rarely enough.
        logger.exception(f"❌ 应用启动失败: {str(e)}")
        raise

    yield

    # Shutdown: nothing to release beyond logging.
    logger.info("🔄 关闭SarcoAdvisor服务...")
# The FastAPI application. API docs are served at /docs and /redoc;
# `lifespan` loads the models and DiCE explainers at startup.
app = FastAPI(
    lifespan=lifespan,
    title="SarcoAdvisor API",
    description="肌少症风险评估和个性化建议系统",
    version="1.0.0",
    redoc_url="/redoc",
    docs_url="/docs",
)
# 确保模型在应用启动时加载(备用方案) | |
async def startup_event():
    """Fallback initialiser: load models / DiCE if they are still missing.

    Loads all models when ``model_manager.advisory_models`` is empty and
    initialises DiCE when ``advisory_service.dice_explainers`` is absent or
    empty, then logs which models are loaded. Errors are logged and swallowed
    on purpose so the application keeps running.

    NOTE(review): no registration of this coroutine is visible here. Also,
    FastAPI does not call ``on_event("startup")`` handlers when ``lifespan=``
    is passed to the app, so confirm this fallback is ever invoked.
    """
    try:
        models_missing = not model_manager.advisory_models
        if models_missing:
            logger.warning("⚠️ 检测到模型未加载,强制加载...")
            model_manager.load_all_models()
            logger.info("✅ 备用模型加载完成")

        # getattr with a default covers both "attribute absent" and "empty".
        dice_missing = not getattr(advisory_service, 'dice_explainers', None)
        if dice_missing:
            logger.warning("⚠️ 检测到DiCE未初始化,强制初始化...")
            advisory_service.initialize_dice()
            logger.info("✅ 备用DiCE初始化完成")

        screening_names = list(model_manager.screening_models.keys())
        advisory_names = list(model_manager.advisory_models.keys())
        logger.info(f"🎯 模型状态检查: 筛查模型={screening_names}, 建议模型={advisory_names}")
    except Exception as e:
        # Deliberately best-effort: do not re-raise.
        logger.error(f"❌ 备用启动过程失败: {str(e)}")
# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wide open for development; restrict `allow_origins` to the
# known front-end domains in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Static assets (CSS/JS/images) from the local "static" directory at /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 templates for the HTML pages, loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")
# 全局异常处理 | |
async def global_exception_handler(request: Request, exc: Exception):
    """Convert an unhandled exception into a 500 JSON ``ErrorResponse``.

    Presumably registered via ``app.exception_handler(Exception)`` (the
    registration is not visible here).

    Args:
        request: The request being processed (unused).
        exc: The exception that was not handled elsewhere.

    Returns:
        JSONResponse with status 500 whose body is ``ErrorResponse`` with
        ``error="Internal Server Error"``, ``detail=str(exc)`` and
        ``timestamp`` (epoch seconds, as a string).

    NOTE(review): ``detail`` echoes the raw exception text to clients, which
    can leak internals; consider a generic message in production.
    """
    # Pass the exception itself as exc_info so the traceback reaches the log;
    # the message alone is rarely enough to debug a 500.
    logger.error(f"全局异常: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal Server Error",
            detail=str(exc),
            timestamp=str(time.time())
        ).model_dump()
    )
# 根路径 - 返回快速评估页面 | |
async def home(request: Request):
    """Render the quick-assessment page (``quick_assessment.html``) as the landing page."""
    page_context = {"request": request}
    return templates.TemplateResponse("quick_assessment.html", page_context)
# 原完整评估页面 | |
async def full_assessment_page(request: Request):
    """Render the original full-assessment page (``index.html``)."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
# 统一评估页面 (新的双模型评估) | |
async def unified_assessment_page(request: Request):
    """Render the unified page (``unified_assessment.html``) that shows SarcoI and SarcoII advice together."""
    page_context = {"request": request}
    return templates.TemplateResponse("unified_assessment.html", page_context)
# 字体测试页面 | |
async def font_test_page(request: Request):
    """Render the font/style test page (``font_test.html``)."""
    page_context = {"request": request}
    return templates.TemplateResponse("font_test.html", page_context)
# 健康检查接口 | |
async def health_check():
    """Report which model components are loaded.

    Returns:
        A dict whose ``status`` is ``"healthy"`` when both the screening and
        advisory models are loaded and ``"partial"`` otherwise. DiCE readiness
        is reported under ``services`` but does not affect ``status``. On an
        unexpected error, returns a 503 JSONResponse with
        ``{"status": "unhealthy", "error": ...}``.
    """
    try:
        screening_ready = bool(model_manager.screening_models)
        advisory_ready = bool(model_manager.advisory_models)
        # `dice_explainers` may not exist until DiCE is initialised (startup code
        # checks for it defensively). A missing attribute should read as
        # "not ready" rather than turn the whole probe into a 503.
        dice_ready = bool(getattr(advisory_service, "dice_explainers", None))
        return {
            "status": "healthy" if (screening_ready and advisory_ready) else "partial",
            "timestamp": time.time(),
            "services": {
                "screening_models": screening_ready,
                "advisory_models": advisory_ready,
                "dice_explainers": dice_ready
            },
            "version": "1.0.0"
        }
    except Exception as e:
        logger.exception(f"健康检查失败: {str(e)}")
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(e)}
        )
# 诊断接口 - 用于调试性能问题 | |
async def diagnostic_check(request: dict):
    """Time each stage of the assessment pipeline to locate bottlenecks.

    Stages timed:
      1. input validation (``UserInput(**request)``);
      2. screening for both SarcoI and SarcoII;
      3. plain advisory-model predictions for both models (DiCE excluded).

    Durations use ``time.perf_counter()``, a monotonic high-resolution clock.
    ``time.time()`` is wall-clock and can jump, for example on NTP adjustment.

    Args:
        request: Raw JSON body, expanded into ``UserInput``.

    Returns:
        On success, ``{"status": "success", "timings": {...}, "note": ...}``
        with timings formatted as ``"0.123s"``. On failure of a stage,
        ``{"error": ..., "step": ...}``; on any other failure,
        ``{"error": ..., "total_time": ...}``. Errors are returned in the body,
        never raised.
    """
    start_time = time.perf_counter()
    try:
        logger.info("开始诊断检查...")

        # 1. Input validation.
        validation_start = time.perf_counter()
        try:
            user_input = UserInput(**request)
            validation_time = time.perf_counter() - validation_start
        except Exception as e:
            return {"error": f"数据验证失败: {str(e)}", "step": "validation"}

        # 2. Screening models; only the elapsed time matters here.
        screening_start = time.perf_counter()
        try:
            await screening_service.screening_assessment(user_input, ['sarcoI', 'sarcoII'])
            screening_time = time.perf_counter() - screening_start
        except Exception as e:
            return {"error": f"筛查模型失败: {str(e)}", "step": "screening"}

        # 3. Plain advisory-model predictions (no DiCE); results are discarded.
        advisory_start = time.perf_counter()
        try:
            user_dict = user_input.model_dump()
            model_manager.predict_advisory(user_dict, 'sarcoI')
            model_manager.predict_advisory(user_dict, 'sarcoII')
            advisory_time = time.perf_counter() - advisory_start
        except Exception as e:
            return {"error": f"建议模型预测失败: {str(e)}", "step": "advisory_prediction"}

        total_time = time.perf_counter() - start_time
        return {
            "status": "success",
            "timings": {
                "validation": f"{validation_time:.3f}s",
                "screening": f"{screening_time:.3f}s",
                "advisory_prediction": f"{advisory_time:.3f}s",
                "total": f"{total_time:.3f}s"
            },
            "note": "DiCE analysis not included in this diagnostic"
        }
    except Exception as e:
        total_time = time.perf_counter() - start_time
        logger.exception(f"诊断检查失败: {str(e)}")
        return {
            "error": str(e),
            "total_time": f"{total_time:.3f}s"
        }
# 筛查接口 | |
async def screening_assessment(request: ScreeningRequest):
    """Run the high-recall screening models for ``request.user_data``.

    Delegates to ``screening_service.screening_assessment`` with the models
    named in ``request.models`` and returns its result unchanged.

    Raises:
        HTTPException: status 500 whose detail starts with ``筛查评估失败:``
            if anything (including the completion log) fails.
    """
    try:
        logger.info(f"收到筛查请求: 模型={request.models}")
        outcome = await screening_service.screening_assessment(
            user_data=request.user_data,
            models=request.models,
        )
        logger.info(f"筛查完成: 综合风险={outcome.overall_risk}, 耗时={outcome.processing_time:.2f}s")
        return outcome
    except Exception as err:
        failure = f"筛查评估失败: {str(err)}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
# 建议生成接口 | |
async def advisory_recommendations(request: AdvisoryRequest):
    """Generate personalised DiCE counterfactual recommendations.

    Forwards ``user_data``, ``risk_types``, ``num_recommendations`` and
    ``language`` from the request to
    ``advisory_service.generate_recommendations`` and returns its result.
    DiCE runs in the "ultimate precision" configuration set up at startup,
    so responses can be slow.

    Raises:
        HTTPException: status 500 whose detail starts with ``建议生成失败:``
            if generation (or the completion log) fails.
    """
    try:
        logger.info(
            f"收到建议请求: 风险类型={request.risk_types}, 建议数量={request.num_recommendations}"
        )
        advice = await advisory_service.generate_recommendations(
            user_data=request.user_data,
            risk_types=request.risk_types,
            num_recommendations=request.num_recommendations,
            language=request.language,
        )
        summary = (
            f"建议生成完成: SarcoI={len(advice.sarcoI_recommendations)}, "
            f"SarcoII={len(advice.sarcoII_recommendations)}, "
            f"fallback={advice.fallback_used}, 耗时={advice.processing_time:.2f}s"
        )
        logger.info(summary)
        return advice
    except Exception as err:
        failure = f"建议生成失败: {str(err)}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
# 完整评估接口 (筛查 + 建议) | |
async def full_assessment(request: dict, background_tasks: BackgroundTasks):
    """Full assessment: screening for SarcoI/SarcoII, then optional DiCE advice.

    Pipeline:
      1. Build ``UserInput(**request)`` and read ``language`` (default ``'zh'``).
      2. Screen with both models via ``screening_assessment``.
      3. Pick at most one model for DiCE recommendations
         (``_select_advisory_risk_types``) and, if one is chosen, call
         ``advisory_recommendations`` asking for 5 recommendations.
      4. Compute a best-effort comprehensive risk
         (``_compute_comprehensive_risk``, ``None`` on failure).

    Args:
        request: Raw JSON body, expanded into ``UserInput``; may also carry a
            ``language`` key.
        background_tasks: Injected by FastAPI; not used here.

    Returns:
        dict with ``screening``, ``advisory`` (or ``None``),
        ``comprehensive_risk``, ``risk_explanation``, ``needs_advisory``,
        ``total_processing_time`` and ``timestamp``.

    Raises:
        HTTPException: status 500 wrapping any failure, including errors the
            nested endpoints already raised as HTTPException.
    """
    try:
        start_time = time.time()
        logger.info("开始完整评估流程")

        user_data = UserInput(**request)
        language = request.get('language', 'zh')  # default: Chinese
        logger.info(f"完整评估语言设置: {language}")

        # Step 1: screening with both models.
        screening_request = ScreeningRequest(
            user_data=user_data,
            models=["sarcoI", "sarcoII"]
        )
        screening_result = await screening_assessment(screening_request)

        # Step 2: is there complete activity data (for maintenance advice)?
        # Check the raw request body. `name in <pydantic model>` iterates
        # (name, value) tuples, so it is always False.
        # TODO confirm clients send these exact keys at the top level.
        activity_fields = ['PAQ605', 'PAQ620', 'PAQ635', 'PAQ650', 'PAQ665', 'PAD680']
        has_complete_data = all(request.get(field) is not None for field in activity_fields)

        risk_types = _select_advisory_risk_types(screening_result, has_complete_data)

        advisory_result = None
        if risk_types:
            logger.info(f"为以下模型生成建议: {risk_types}")
            advisory_request = AdvisoryRequest(
                user_data=user_data,
                risk_types=risk_types,
                num_recommendations=5,  # more candidates -> more diverse advice
                language=language
            )
            advisory_result = await advisory_recommendations(advisory_request)
        else:
            logger.info("不满足建议生成条件,跳过DiCE建议")

        comprehensive_risk = _compute_comprehensive_risk(screening_result)

        risk_explanation = screening_service.get_risk_explanation(screening_result)
        total_time = time.time() - start_time

        response = {
            "screening": screening_result.model_dump(),
            "advisory": advisory_result.model_dump() if advisory_result else None,
            "comprehensive_risk": comprehensive_risk,
            "risk_explanation": risk_explanation,
            "needs_advisory": bool(advisory_result),
            "total_processing_time": total_time,
            "timestamp": time.time()
        }
        logger.info(f"完整评估完成: 风险={screening_result.overall_risk}, "
                    f"生成建议={'是' if advisory_result else '否'}, 总耗时={total_time:.2f}s")
        return response
    except Exception as e:
        logger.exception(f"完整评估失败: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"完整评估失败: {str(e)}"
        ) from e


# Helpers for full_assessment. They are defined after the endpoint so that a
# route decorator sitting directly above `full_assessment` keeps decorating it.

_ELEVATED_RISK_LEVELS = ("medium", "high")


def _is_elevated_risk(level) -> bool:
    """Return True if ``level`` is a "medium" or "high" risk level.

    Accepts a plain string or an Enum member. Members are compared by
    ``.value``, so this works whether or not the enum subclasses ``str``.
    Falsy values (``None``, ``""``) are never elevated.
    """
    if not level:
        return False
    return getattr(level, "value", level) in _ELEVATED_RISK_LEVELS


def _select_advisory_risk_types(screening_result, has_complete_data: bool) -> list:
    """Choose the model (at most one) to generate DiCE recommendations for.

    SarcoII is the more severe form, so it takes precedence. Order:
      1. SarcoII flagged: by its advisory result when one exists, otherwise by
         its screening result;
      2. SarcoI flagged, same rule;
      3. SarcoII screening risk elevated (covers a "low" advisory verdict);
      4. SarcoI screening risk elevated;
      5. with complete activity data, maintenance advice for the model with
         the higher screening probability;
      6. otherwise nothing.

    Returns:
        ``["sarcoII"]``, ``["sarcoI"]`` or ``[]``.
    """
    sr = screening_result
    logger.info(f"🔍 调试风险判断 - SarcoII筛查风险: {sr.sarcoII_risk}")
    logger.info(f"🔍 调试风险判断 - SarcoII建议风险: {getattr(sr, 'sarcoII_advisory_risk', 'None')}")
    logger.info(f"🔍 调试风险判断 - SarcoI筛查风险: {getattr(sr, 'sarcoI_risk', 'None')}")
    logger.info(f"🔍 调试风险判断 - SarcoI建议风险: {getattr(sr, 'sarcoI_advisory_risk', 'None')}")

    # SarcoII: the advisory verdict wins when present, else screening.
    sarcoII_advisory_risk = getattr(sr, 'sarcoII_advisory_risk', None)
    if sarcoII_advisory_risk:
        sarcoII_has_risk = _is_elevated_risk(sarcoII_advisory_risk)
        if sarcoII_has_risk:
            logger.info(f"SarcoII建议模型显示风险: {sarcoII_advisory_risk}")
    else:
        sarcoII_has_risk = _is_elevated_risk(sr.sarcoII_risk)
        if sarcoII_has_risk:
            logger.info(f"SarcoII筛查模型显示风险: {sr.sarcoII_risk}")

    # SarcoI: same rule.
    sarcoI_advisory_risk = getattr(sr, 'sarcoI_advisory_risk', None)
    sarcoI_screening_risk = getattr(sr, 'sarcoI_risk', None)
    if sarcoI_advisory_risk:
        sarcoI_has_risk = _is_elevated_risk(sarcoI_advisory_risk)
        if sarcoI_has_risk:
            logger.info(f"SarcoI建议模型显示风险: {sarcoI_advisory_risk}")
    else:
        sarcoI_has_risk = _is_elevated_risk(sarcoI_screening_risk)
        if sarcoI_has_risk:
            logger.info(f"SarcoI筛查模型显示风险: {sarcoI_screening_risk}")

    logger.info(f"🔍 风险判断结果 - SarcoII有风险: {sarcoII_has_risk}, SarcoI有风险: {sarcoI_has_risk}")

    if sarcoII_has_risk:
        logger.info("🎯 检测到SarcoII风险,生成SarcoII建议(SarcoII高危意味着SarcoI也高危)")
        return ["sarcoII"]
    if sarcoI_has_risk:
        logger.info("🎯 检测到SarcoI风险,生成SarcoI建议")
        return ["sarcoI"]
    # A "low" advisory verdict must not hide an elevated screening result.
    if _is_elevated_risk(sr.sarcoII_risk):
        logger.info(f"🎯 SarcoII筛查模型显示{sr.sarcoII_risk}风险,生成建议")
        return ["sarcoII"]
    if _is_elevated_risk(sarcoI_screening_risk):
        logger.info(f"🎯 SarcoI筛查模型显示{sarcoI_screening_risk}风险,生成建议")
        return ["sarcoI"]
    if has_complete_data:
        # All low risk: maintenance advice for the likelier condition.
        sarcoI_prob = getattr(sr, 'sarcoI_probability', 0)
        if sr.sarcoII_probability > sarcoI_prob:
            logger.info(f"🔄 低风险但数据完整,为SarcoII生成维持性建议(概率: {sr.sarcoII_probability:.3f})")
            return ["sarcoII"]
        logger.info(f"🔄 低风险但数据完整,为SarcoI生成维持性建议(概率: {sarcoI_prob:.3f})")
        return ["sarcoI"]
    logger.info("⚠️ 无风险且数据不完整,不生成建议")
    return []


def _compute_comprehensive_risk(screening_result):
    """Combine screening and advisory outputs via ``model_manager.get_comprehensive_risk``.

    Advisory entries are passed as ``None`` when the corresponding advisory
    probability is ``None``. Best effort: any error is logged and ``None`` is
    returned instead of propagating.
    """
    try:
        sarcoI_screening = {
            'probability': screening_result.sarcoI_probability,
            'risk_level': screening_result.sarcoI_risk.value
        }
        sarcoI_advisory = None
        if screening_result.sarcoI_advisory_probability is not None:
            sarcoI_advisory = {
                'probability': screening_result.sarcoI_advisory_probability,
                'risk_level': screening_result.sarcoI_advisory_risk.value
            }
        sarcoII_screening = {
            'probability': screening_result.sarcoII_probability,
            'risk_level': screening_result.sarcoII_risk.value
        }
        sarcoII_advisory = None
        if screening_result.sarcoII_advisory_probability is not None:
            sarcoII_advisory = {
                'probability': screening_result.sarcoII_advisory_probability,
                'risk_level': screening_result.sarcoII_advisory_risk.value
            }
        comprehensive_risk = model_manager.get_comprehensive_risk(
            sarcoI_screening_result=sarcoI_screening,
            sarcoI_advisory_result=sarcoI_advisory,
            sarcoII_screening_result=sarcoII_screening,
            sarcoII_advisory_result=sarcoII_advisory
        )
        logger.info(f"综合风险评估完成: {comprehensive_risk}")
        return comprehensive_risk
    except Exception as e:
        logger.exception(f"综合风险计算失败: {str(e)}")
        return None
# 模型信息接口 | |
async def get_model_info():
    """Return metadata about the screening and advisory models.

    Decision thresholds come from ``model_manager.thresholds``
    (``thresholds[risk_type][stage]``) with hard-coded fallbacks. Every other
    value is a literal.

    NOTE(review): the ``model_path`` entries expose absolute paths from a
    developer machine (including a user name) to API clients. They also sit
    inside ``performance``. Consider returning bare file names instead.

    Raises:
        HTTPException: status 500 with detail ``获取模型信息失败`` on failure.
    """
    try:
        thresholds = model_manager.thresholds

        def _threshold(risk_type, stage, default):
            return thresholds.get(risk_type, {}).get(stage, default)

        screening_models = {
            "sarcoI": {
                "type": "RandomForest",
                "purpose": "高召回率筛查",
                "threshold": _threshold('sarcoI', 'screening', 0.5),
                "performance": {
                    "recall": 0.9114,
                    "precision": 0.4305,
                    "model_path": "/Users/ning/Desktop/idea/代码forSarcoAdvisor/3.建模/SarcoI_results/randomforest_model.pkl",
                    "features": 3,
                },
            },
            "sarcoII": {
                "type": "CatBoost",
                "purpose": "高召回率筛查",
                "threshold": _threshold('sarcoII', 'screening', 0.5),
                "performance": {
                    "precision": 0.2548,
                    "recall": 0.8983,
                    "model_path": "/Users/ning/Desktop/idea/代码forSarcoAdvisor/3.建模/SarcoII_results/catboost_model.cbm",
                    "features": 4,
                },
            },
        }
        advisory_models = {
            "sarcoI": {
                "type": "CatBoost",
                "purpose": "高精确率建议 + DiCE",
                "threshold": _threshold('sarcoI', 'advisory', 0.36),
                "dice_features": 5,
            },
            "sarcoII": {
                "type": "RandomForest",
                "purpose": "高精确率建议 + DiCE",
                "threshold": _threshold('sarcoII', 'advisory', 0.52),
                "dice_features": 6,
            },
        }
        feature_names = {
            "sarcoI": ["body_mass_index", "race_ethnicity", "WWI", "age_years",
                       "Activity_Sedentary_Ratio", "Total_Moderate_Minutes_week", "Vigorous_MET_Ratio"],
            "sarcoII": ["body_mass_index", "race_ethnicity", "age_years",
                        "Activity_Sedentary_Ratio", "Activity_Diversity_Index", "WWI",
                        "Vigorous_MET_Ratio", "sedentary_minutes"],
        }
        return {
            "screening_models": screening_models,
            "advisory_models": advisory_models,
            "features": feature_names,
        }
    except Exception as err:
        logger.error(f"获取模型信息失败: {str(err)}")
        raise HTTPException(status_code=500, detail="获取模型信息失败")
# 快速评估接口 (基于4个共同特征) | |
async def quick_assessment(user_data: UserInput):
    """Quick, screening-only assessment.

    Runs only the SarcoI and SarcoII screening models
    (``include_advisory=False``): no advisory model and no DiCE. It relies on
    the four features shared by all models: age_years, race_ethnicity,
    body_mass_index and WWI.

    Args:
        user_data: Validated user input.

    Returns:
        dict with ``screening``, ``advisory`` (always ``None``),
        ``risk_explanation``, ``needs_advisory`` (always ``False``),
        ``assessment_type`` (``"quick"``), ``common_features_used``,
        ``total_processing_time`` and ``timestamp``.

    Raises:
        HTTPException: status 500 wrapping any failure.
    """
    try:
        start_time = time.time()
        logger.info("开始快速评估流程")

        # Screening only; the advisory models are explicitly skipped.
        screening_result = await screening_service.screening_assessment(
            user_data=user_data,
            models=["sarcoI", "sarcoII"],
            include_advisory=False
        )
        risk_explanation = screening_service.get_risk_explanation(screening_result)
        screening_payload = screening_result.model_dump()

        total_time = time.time() - start_time
        result = {
            "screening": screening_payload,
            "advisory": None,
            "risk_explanation": risk_explanation,
            "needs_advisory": False,
            "assessment_type": "quick",
            "common_features_used": ["age_years", "race_ethnicity", "body_mass_index", "WWI"],
            "total_processing_time": total_time,
            "timestamp": time.time()
        }
        logger.info(f"快速评估完成,总耗时={total_time:.2f}s")
        return result
    except Exception as e:
        logger.exception(f"快速评估失败: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"快速评估失败: {str(e)}"
        ) from e
# 应用启动函数 | |
def start_server():
    """Serve ``main:app`` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable (default 8001).
    Auto-reload is off, which suits cloud deployment.
    """
    import os

    listen_port = int(os.getenv("PORT", "8001"))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
        log_level="info",
    )


# Script entry point: `python main.py` starts the server.
if __name__ == "__main__":
    start_server()