Sarco-Monitor / utils /model_loader.py
Ning311's picture
Update utils/model_loader.py
47da25e verified
"""
模型加载和管理工具
支持筛查类和建议类模型的统一管理
"""
import pickle
import pandas as pd
import numpy as np
import logging
import os
from pathlib import Path
from typing import Dict, Any, Optional
import warnings
warnings.filterwarnings('ignore')
# Optional dependency: huggingface_hub is only needed when models are pulled
# from the (private) Hugging Face repository. Without it the module still
# imports and HF_HUB_AVAILABLE tells the rest of the file to use local files.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub未安装,将使用本地模型文件")
else:
    HF_HUB_AVAILABLE = True

# Logging setup: INFO level on the root logger, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    """Load the sarcopenia models and turn their probabilities into risk levels.

    For each target ('sarcoI' and 'sarcoII') two models are managed: a
    screening model (described in this file as the high-recall one) and an
    advisory model (described as the high-precision one). Model files come
    either from a Hugging Face repository or from ``<app>/models/...`` on
    disk, depending on the ``USE_HF_MODELS`` environment variable and on
    whether ``huggingface_hub`` could be imported.

    Attributes:
        screening_models: model_type -> loaded screening estimator.
        advisory_models: model_type -> loaded advisory estimator.
        model_configs: empty dict, kept for compatibility (not used here).
        thresholds: model_type -> {'screening': x, 'advisory': y} high-risk
            cut-offs. Populated in ``__init__`` so that the risk helpers keep
            working even when loading the model files fails.
    """

    _MODEL_TYPES = ('sarcoI', 'sarcoII')
    _MODES = ('screening', 'advisory')

    # Single source of truth for the standardized risk bands:
    # (model_type, mode) -> (medium_floor, high_floor, medium_floor_inclusive).
    # probability >= high_floor is 'high'; above medium_floor is 'medium'
    # (for sarcoII advisory the floor itself already counts as 'medium').
    _RISK_BANDS = {
        ('sarcoI', 'screening'): (0.12, 0.15, False),
        ('sarcoI', 'advisory'): (0.31, 0.36, False),
        ('sarcoII', 'screening'): (0.06, 0.09, False),
        ('sarcoII', 'advisory'): (0.47, 0.52, True),
    }

    # Column order each model is fed, per (model_type, mode).
    _FEATURE_ORDER = {
        ('sarcoI', 'screening'): (
            'age_years', 'WWI', 'body_mass_index',
        ),
        ('sarcoI', 'advisory'): (
            'body_mass_index', 'race_ethnicity', 'WWI', 'age_years',
            'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week',
            'Vigorous_MET_Ratio',
        ),
        ('sarcoII', 'screening'): (
            'WWI', 'age_years', 'race_ethnicity', 'body_mass_index',
        ),
        ('sarcoII', 'advisory'): (
            'body_mass_index', 'race_ethnicity', 'age_years',
            'Activity_Sedentary_Ratio', 'Activity_Diversity_Index', 'WWI',
            'Vigorous_MET_Ratio', 'sedentary_minutes',
        ),
    }

    # Values substituted for features the caller did not supply.
    # Features not listed here fall back to 0.
    _DEFAULT_FEATURE_VALUES = {
        'Total_Moderate_Minutes_week': 150,        # WHO-recommended value
        'Total_Moderate_Equivalent_Minutes': 150,  # WHO-recommended value
        'Activity_Diversity_Index': 2,             # medium diversity
        'sedentary_minutes': 480,                  # 8 hours
        'Avg_Vigorous_Duration_Per_Bout': 0,       # no vigorous activity
        'arthritis': 0,                            # no arthritis
        'diabetes': 0,                             # no diabetes
        'Activity_Sedentary_Ratio': 0.5,           # default activity ratio
        'Vigorous_MET_Ratio': 0.2,                 # default vigorous ratio
    }

    # Keys tried first when a pickle contains a dict wrapping the estimator.
    _PREFERRED_MODEL_KEYS = (
        'model', 'classifier', 'estimator', 'catboost_model', 'cb_model',
        'best_model', 'trained_model', 'final_model', 'best_estimator',
    )

    def __init__(self):
        """Read the deployment configuration; no model file is touched here."""
        self.screening_models = {}
        self.advisory_models = {}
        self.model_configs = {}
        # Thresholds are constants, so make them available immediately
        # instead of only after a fully successful load_all_models().
        self.thresholds = self._standard_thresholds()

        # Application root: the parent of the directory holding this file.
        self.app_path = Path(__file__).parent.parent

        # HF models are opt-in: defaults to local files unless USE_HF_MODELS=true.
        self.use_hf_models = os.getenv("USE_HF_MODELS", "false").lower() == "true"
        self.hf_model_repo = os.getenv("HF_MODEL_REPO", "Ning311/sarco-advisor-models")
        self.hf_token = os.getenv("HF_TOKEN", None)

        # Verbose configuration dump to help debug HF deployments.
        logger.info("🔍 环境变量检查:")
        logger.info(f" USE_HF_MODELS = {os.getenv('USE_HF_MODELS', 'NOT_SET')}")
        logger.info(f" HF_MODEL_REPO = {self.hf_model_repo}")
        logger.info(f" HF_TOKEN = {'SET' if self.hf_token else 'NOT_SET'}")
        logger.info(f" use_hf_models = {self.use_hf_models}")
        logger.info(f" HF_HUB_AVAILABLE = {HF_HUB_AVAILABLE}")

        if self._remote_models_enabled():
            logger.info(f"🔒 使用HF私有仓库模型: {self.hf_model_repo}")
            # In HF mode the paths are repository-relative strings.
            self.screening_paths = {
                name: f"models/screening/{name}" for name in self._MODEL_TYPES
            }
            self.advisory_paths = {
                name: f"models/advisory/{name}" for name in self._MODEL_TYPES
            }
        else:
            logger.info("📁 使用本地模型文件")
            # In local mode the paths are Path objects under the app root.
            self.screening_paths = {
                name: self.app_path / f"models/screening/{name}"
                for name in self._MODEL_TYPES
            }
            self.advisory_paths = {
                name: self.app_path / f"models/advisory/{name}"
                for name in self._MODEL_TYPES
            }

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_all_models(self):
        """Load screening models, advisory models and thresholds.

        Raises:
            Exception: whatever the underlying loader raised, after logging it.
        """
        try:
            self._load_screening_models()
            self._load_advisory_models()
            self._load_thresholds()
            logger.info("所有模型加载完成")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def _remote_models_enabled(self) -> bool:
        """Return True when HF loading is requested and huggingface_hub is importable."""
        return bool(self.use_hf_models and HF_HUB_AVAILABLE)

    def _resolve_model_file(self, kind: str, model_type: str, filename: str, label: str):
        """Return a readable path for one model file, downloading it if needed.

        Args:
            kind: 'screening' or 'advisory' (selects the directory).
            model_type: 'sarcoI' or 'sarcoII'.
            filename: file name inside the model directory.
            label: human-readable model name used in the log line.
        """
        if self._remote_models_enabled():
            model_file = hf_hub_download(
                repo_id=self.hf_model_repo,
                filename=f"models/{kind}/{model_type}/{filename}",
                token=self.hf_token,
            )
            logger.info(f"从HF下载{label}: {model_file}")
            return model_file

        directories = self.screening_paths if kind == 'screening' else self.advisory_paths
        model_file = Path(directories[model_type]) / filename
        logger.info(f"使用本地{label}: {model_file}")
        return model_file

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        # SECURITY NOTE: pickle.load can execute arbitrary code contained in
        # the file. Only point HF_MODEL_REPO / the local models directory at
        # sources you fully trust.
        with open(path, 'rb') as f:
            return pickle.load(f)

    @classmethod
    def _extract_estimator(cls, loaded, label: str):
        """Return the object exposing ``predict_proba`` from a loaded pickle.

        The pickle may hold the estimator itself or a dict wrapping it. For a
        dict, the well-known keys are tried first, then every remaining value
        in dict order.

        Raises:
            ValueError: if nothing with ``predict_proba`` can be found.
        """
        if hasattr(loaded, 'predict_proba'):
            return loaded
        if not isinstance(loaded, dict):
            raise ValueError(f"加载的对象不是有效的机器学习模型: {type(loaded)}")

        logger.info(f"🔍 字典键: {list(loaded.keys())}")
        preferred = [key for key in cls._PREFERRED_MODEL_KEYS if key in loaded]
        remaining = [key for key in loaded if key not in cls._PREFERRED_MODEL_KEYS]
        for key in preferred + remaining:
            candidate = loaded[key]
            logger.info(f"🔍 尝试键 '{key}': {type(candidate)}")
            if hasattr(candidate, 'predict_proba'):
                logger.info(f"✅ 从字典提取{label}成功 (键: {key})")
                return candidate

        logger.error(f"❌ 字典内容详情: {[(k, type(v)) for k, v in loaded.items()]}")
        raise ValueError("字典中没有找到有效的机器学习模型")

    def _load_screening_models(self):
        """Load the screening models (sarcoI: pickled model, sarcoII: CatBoost .cbm).

        Raises:
            ImportError: if catboost is not installed.
            Exception: any download, file or deserialization error, after logging.
        """
        try:
            # SarcoI screening model - RandomForest stored as a pickle.
            sarcoI_rf_path = self._resolve_model_file(
                'screening', 'sarcoI', "randomforest_model.pkl", "SarcoI筛查模型"
            )
            self.screening_models['sarcoI'] = self._load_pickle(sarcoI_rf_path)
            logger.info("✅ SarcoI筛查模型加载成功")

            # SarcoII screening model - CatBoost native .cbm format.
            sarcoII_cat_path = self._resolve_model_file(
                'screening', 'sarcoII', "catboost_model.cbm", "SarcoII筛查模型"
            )
            try:
                import catboost as cb
            except ImportError:
                logger.error("CatBoost未安装,无法加载SarcoII筛查模型")
                raise
            classifier = cb.CatBoostClassifier()
            classifier.load_model(str(sarcoII_cat_path))
            # Register only after load_model succeeded, so a failed load never
            # leaves an empty classifier behind in screening_models.
            self.screening_models['sarcoII'] = classifier
            logger.info("✅ SarcoII筛查模型加载成功")

            logger.info("筛查模型加载成功")
        except Exception as e:
            logger.error(f"筛查模型加载失败: {str(e)}")
            raise

    def _load_advisory_models(self):
        """Load the pickled advisory models for sarcoI and sarcoII.

        Raises:
            ValueError: if a pickle contains nothing exposing ``predict_proba``.
            Exception: any download, file or deserialization error, after logging.
        """
        advisory_files = (
            ('sarcoI', "CatBoost_model.pkl", "SarcoI建议模型"),
            ('sarcoII', "RandomForest_model.pkl", "SarcoII建议模型"),
        )
        try:
            for model_type, filename, label in advisory_files:
                model_file = self._resolve_model_file('advisory', model_type, filename, label)
                loaded = self._load_pickle(model_file)
                logger.info(f"🔍 {label}类型: {type(loaded)}")
                self.advisory_models[model_type] = self._extract_estimator(loaded, label)
                logger.info(f"✅ {label}加载成功")
            logger.info("建议模型加载成功")
        except Exception as e:
            logger.error(f"建议模型加载失败: {str(e)}")
            raise

    @classmethod
    def _standard_thresholds(cls) -> Dict[str, Dict[str, float]]:
        """Return the fixed high-risk thresholds derived from ``_RISK_BANDS``."""
        return {
            model_type: {
                mode: cls._RISK_BANDS[(model_type, mode)][1] for mode in cls._MODES
            }
            for model_type in cls._MODEL_TYPES
        }

    def _load_thresholds(self):
        """(Re)apply the standardized fixed thresholds and log them.

        Fixed values are used on purpose so every environment grades risk
        identically; nothing is read from disk.
        """
        self.thresholds = self._standard_thresholds()
        logger.info("使用标准化固定阈值,确保风险评估一致性")
        log_labels = (
            ('sarcoI', 'screening', "SarcoI筛查阈值"),
            ('sarcoI', 'advisory', "SarcoI建议阈值"),
            ('sarcoII', 'screening', "SarcoII筛查阈值"),
            ('sarcoII', 'advisory', "SarcoII建议阈值"),
        )
        for model_type, mode, label in log_labels:
            value = self.thresholds[model_type][mode]
            logger.info(f"{label}: {value} (高风险≥{value})")
        logger.info("标准化阈值配置加载成功")
        logger.info(f"阈值详情: {self.thresholds}")

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def _run_prediction(self, model, user_data: Dict[str, float], model_type: str,
                        mode: str) -> Dict[str, Any]:
        """Build the feature row, score it with ``model`` and grade the result."""
        features_df = self._prepare_features(user_data, model_type, mode=mode)
        # Column 1 of predict_proba is taken as the positive-class probability.
        probability = model.predict_proba(features_df)[0, 1]
        threshold = self.thresholds[model_type][mode]
        risk_level = self._get_risk_level(
            probability, threshold, mode=mode, model_type=model_type
        )
        return {
            'probability': float(probability),
            'risk_level': risk_level,
            'threshold': float(threshold),
            'model_type': f"{model_type}_{mode}",
        }

    def predict_screening(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Score ``user_data`` with the screening model of ``model_type``.

        Returns:
            Dict with 'probability', 'risk_level', 'threshold' and 'model_type'.

        Raises:
            KeyError: if no screening model is loaded for ``model_type``.
            Exception: any error raised while predicting, after logging it.
        """
        try:
            if model_type not in self.screening_models:
                raise KeyError(
                    f"筛查模型 '{model_type}' 不存在,可用模型: {list(self.screening_models.keys())}"
                )
            return self._run_prediction(
                self.screening_models[model_type], user_data, model_type, 'screening'
            )
        except Exception as e:
            logger.error(f"{model_type}筛查预测失败: {str(e)}")
            raise

    def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Score ``user_data`` with the advisory model of ``model_type``.

        Returns:
            Dict with 'probability', 'risk_level', 'threshold' and 'model_type'.

        Raises:
            KeyError: if no advisory model is loaded for ``model_type``.
            Exception: any error raised while predicting, after logging it.
        """
        try:
            logger.info(f"🔍 建议预测调试 - 模型类型: {model_type}")
            logger.info(f"🔍 可用建议模型: {list(self.advisory_models.keys())}")
            if model_type not in self.advisory_models:
                raise KeyError(
                    f"建议模型 '{model_type}' 不存在,可用模型: {list(self.advisory_models.keys())}"
                )
            model = self.advisory_models[model_type]
            logger.info(f"🔍 使用模型: {type(model)}")
            return self._run_prediction(model, user_data, model_type, 'advisory')
        except Exception as e:
            logger.error(f"{model_type}建议预测失败: {str(e)}")
            raise

    def _prepare_features(self, user_data: Dict[str, float], model_type: str, mode: str) -> pd.DataFrame:
        """Build the single-row DataFrame a model expects, in its column order.

        Features that are absent from ``user_data`` (or are None) are replaced
        by ``_DEFAULT_FEATURE_VALUES`` (0 when no default is listed) and a
        warning is logged for each of them.

        Args:
            user_data: feature name -> value.
            model_type: 'sarcoI'; any other value selects the sarcoII layout.
            mode: 'screening'; any other value selects the advisory layout.
        """
        # Same fall-through as the historical if/else chain: anything that is
        # not 'sarcoI' is treated as sarcoII, anything not 'screening' as advisory.
        type_key = 'sarcoI' if model_type == 'sarcoI' else 'sarcoII'
        mode_key = 'screening' if mode == 'screening' else 'advisory'
        features = list(self._FEATURE_ORDER[(type_key, mode_key)])

        feature_values = []
        missing_features = []
        for feature in features:
            value = user_data.get(feature)
            if value is None:
                value = self._DEFAULT_FEATURE_VALUES.get(feature, 0)
                missing_features.append(feature)
                logger.warning(f"特征 '{feature}' 缺失,使用默认值: {value}")
            feature_values.append(value)

        if missing_features:
            logger.info(f"{model_type}_{mode} 缺失特征: {missing_features}")

        features_df = pd.DataFrame([feature_values], columns=features)
        logger.info(f"{model_type}_{mode} 特征准备完成: {list(features_df.columns)} | 形状: {features_df.shape}")
        return features_df

    # ------------------------------------------------------------------
    # Risk grading
    # ------------------------------------------------------------------
    def _get_risk_level(self, probability: float, threshold: float, mode: str,
                        model_type: Optional[str] = None) -> str:
        """Map a probability to 'low', 'medium' or 'high'.

        For 'sarcoI' and 'sarcoII' the standardized bands in ``_RISK_BANDS``
        are used and ``threshold`` is ignored. For any other ``model_type``
        the legacy rule applies: >= threshold is 'high', >= 75% of threshold
        is 'medium', otherwise 'low'.

        Args:
            probability: predicted probability.
            threshold: decision threshold, only used by the legacy rule.
            mode: 'screening'; any other value is graded as advisory.
            model_type: 'sarcoI', 'sarcoII', or anything else for the legacy rule.
        """
        mode_key = 'screening' if mode == 'screening' else 'advisory'
        band = self._RISK_BANDS.get((model_type, mode_key))

        if band is None:
            # Backward-compatible default based on the supplied threshold.
            if probability >= threshold:
                return 'high'
            if probability >= threshold * 0.75:
                return 'medium'
            return 'low'

        medium_floor, high_floor, floor_inclusive = band
        if probability >= high_floor:
            return 'high'
        if probability > medium_floor or (floor_inclusive and probability == medium_floor):
            return 'medium'
        return 'low'

    def _fuse_model_risk(self, model_type: str, screening_result: Dict,
                         advisory_result: Optional[Dict]) -> Dict:
        """Combine one target's screening and advisory results (advisory first).

        A missing advisory result is treated as probability 0.0.
        """
        screening_probability = screening_result['probability']
        advisory_probability = advisory_result['probability'] if advisory_result else 0.0
        limits = self.thresholds[model_type]

        if advisory_probability >= limits['advisory']:
            risk, reason = "high", "advisory_model_high_risk"
        elif screening_probability >= limits['screening']:
            risk, reason = "medium", "screening_model_risk"
        else:
            risk, reason = "low", "both_models_low_risk"

        return {
            'comprehensive_risk': risk,
            'screening_probability': screening_probability,
            'advisory_probability': advisory_probability,
            'risk_reason': reason,
        }

    def get_comprehensive_risk(self, sarcoI_screening_result: Dict,
                               sarcoI_advisory_result: Optional[Dict] = None,
                               sarcoII_screening_result: Optional[Dict] = None,
                               sarcoII_advisory_result: Optional[Dict] = None) -> Dict:
        """Compute the advisory-first comprehensive risk for each target.

        A target is 'high' when its advisory probability reaches the advisory
        threshold, 'medium' when only its screening probability reaches the
        screening threshold, and 'low' otherwise.

        Args:
            sarcoI_screening_result: SarcoI screening result (needs 'probability').
            sarcoI_advisory_result: SarcoI advisory result (optional).
            sarcoII_screening_result: SarcoII screening result (optional).
            sarcoII_advisory_result: SarcoII advisory result (optional).

        Returns:
            Dict keyed by 'sarcoI' / 'sarcoII'; a target is present only when
            its screening result was supplied.
        """
        inputs = (
            ('sarcoI', sarcoI_screening_result, sarcoI_advisory_result),
            ('sarcoII', sarcoII_screening_result, sarcoII_advisory_result),
        )
        results = {}
        for model_type, screening_result, advisory_result in inputs:
            if screening_result:
                results[model_type] = self._fuse_model_risk(
                    model_type, screening_result, advisory_result
                )
        return results

    def get_overall_risk(self, sarcoI_result: Dict, sarcoII_result: Dict) -> str:
        """Compute a single overall risk level (kept for backward compatibility).

        Rules, in order: any 'high' gives 'high'; two 'medium' levels give
        'medium'; otherwise each probability is divided by its advisory
        threshold and the mean of the two ratios gives 'medium' when it is
        at least 0.75, else 'low'.

        Args:
            sarcoI_result: dict with 'risk_level' and 'probability'.
            sarcoII_result: dict with 'risk_level' and 'probability'.
        """
        sarcoI_risk = sarcoI_result['risk_level']
        sarcoII_risk = sarcoII_result['risk_level']
        sarcoI_prob = sarcoI_result['probability']
        sarcoII_prob = sarcoII_result['probability']

        if 'high' in (sarcoI_risk, sarcoII_risk):
            return 'high'

        # One side is medium and the other is not low.
        if (sarcoI_risk == 'medium' and sarcoII_risk != 'low') or \
                (sarcoII_risk == 'medium' and sarcoI_risk != 'low'):
            return 'medium'

        # Normalize each probability by its own advisory threshold, then
        # average the two ratios with equal weights.
        normalized_sarcoI = sarcoI_prob / self.thresholds['sarcoI']['advisory']
        normalized_sarcoII = sarcoII_prob / self.thresholds['sarcoII']['advisory']
        weighted_score = (normalized_sarcoI + normalized_sarcoII) / 2

        return 'medium' if weighted_score >= 0.75 else 'low'
# Module-level manager instance shared by everything importing this module.
model_manager = ModelManager()

# Load every model eagerly at import time. A failure is logged but not
# re-raised, so the application can keep running without loaded models.
try:
    model_manager.load_all_models()
except Exception as e:
    logger.error(f"❌ 全局模型管理器初始化失败: {str(e)}")
else:
    logger.info("🚀 全局模型管理器初始化完成")