# Spaces:
# Running
# Running
"""
模型加载和管理工具
支持筛查类和建议类模型的统一管理
"""
import logging
import os
import pickle
import warnings
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd

# NOTE(review): this silences every warning issued through the `warnings`
# module for the whole process (including anything libraries emit while
# unpickling models) - consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Secure model loading: models may come from a private Hugging Face repo.
# Without huggingface_hub the manager falls back to local model files.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub未安装,将使用本地模型文件")
else:
    HF_HUB_AVAILABLE = True

# Logging setup: root logger at INFO, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    """Manage the screening (high-recall) and advisory (high-precision) models.

    One screening and one advisory model are kept for each of ``'sarcoI'`` and
    ``'sarcoII'``. Model files are downloaded from a private Hugging Face repo
    when ``USE_HF_MODELS=true`` and ``huggingface_hub`` is importable;
    otherwise they are read from ``<app_path>/models/...`` on local disk.
    """

    # Risk bands: (model_type, mode) -> (medium_floor, medium_floor_inclusive, high_floor).
    #   probability >= high_floor                          -> 'high'
    #   probability >  medium_floor (>= if inclusive=True) -> 'medium'
    #   otherwise                                          -> 'low'
    # The high floors are also the decision thresholds stored in ``self.thresholds``,
    # so risk grading and the published thresholds come from this single table.
    _RISK_BANDS = {
        ('sarcoI', 'screening'): (0.12, False, 0.15),
        ('sarcoI', 'advisory'): (0.31, False, 0.36),
        ('sarcoII', 'screening'): (0.06, False, 0.09),
        ('sarcoII', 'advisory'): (0.47, True, 0.52),
    }

    # Column order each model expects; must match the order used at training time.
    _FEATURES = {
        ('sarcoI', 'screening'): ('age_years', 'WWI', 'body_mass_index'),
        ('sarcoI', 'advisory'): (
            'body_mass_index', 'race_ethnicity', 'WWI', 'age_years',
            'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week', 'Vigorous_MET_Ratio',
        ),
        ('sarcoII', 'screening'): ('WWI', 'age_years', 'race_ethnicity', 'body_mass_index'),
        ('sarcoII', 'advisory'): (
            'body_mass_index', 'race_ethnicity', 'age_years', 'Activity_Sedentary_Ratio',
            'Activity_Diversity_Index', 'WWI', 'Vigorous_MET_Ratio', 'sedentary_minutes',
        ),
    }

    # Values substituted for features that are missing from (or None in) the input.
    # Features not listed here fall back to 0.
    # NOTE(review): core inputs (age_years, WWI, body_mass_index, race_ethnicity)
    # therefore default to 0 and still produce a prediction - confirm this is intended.
    _FEATURE_DEFAULTS = {
        'Total_Moderate_Minutes_week': 150,  # WHO-recommended value
        'Total_Moderate_Equivalent_Minutes': 150,  # WHO-recommended value
        'Activity_Diversity_Index': 2,  # moderate diversity
        'sedentary_minutes': 480,  # 8 hours
        'Avg_Vigorous_Duration_Per_Bout': 0,  # no vigorous activity
        'arthritis': 0,  # no arthritis
        'diabetes': 0,  # no diabetes
        'Activity_Sedentary_Ratio': 0.5,  # default activity ratio
        'Vigorous_MET_Ratio': 0.2,  # default vigorous share
    }

    # Keys tried, in order, when a pickled advisory model turns out to be a dict wrapper.
    _WRAPPED_MODEL_KEYS = (
        'model', 'classifier', 'estimator', 'catboost_model', 'cb_model',
        'best_model', 'trained_model', 'final_model', 'best_estimator',
    )

    def __init__(self):
        self.screening_models = {}
        self.advisory_models = {}
        self.model_configs = {}
        self.thresholds = {}
        # Application root: two directories above this file.
        self.app_path = Path(__file__).parent.parent
        # HF mode is opt-in: "false" by default for local runs; the HF deployment sets it to "true".
        self.use_hf_models = os.getenv("USE_HF_MODELS", "false").lower() == "true"
        self.hf_model_repo = os.getenv("HF_MODEL_REPO", "Ning311/sarco-advisor-models")
        self.hf_token = os.getenv("HF_TOKEN", None)
        # Verbose environment dump for debugging HF deployments (the token value is never logged).
        logger.info(f"🔍 环境变量检查:")
        logger.info(f"  USE_HF_MODELS = {os.getenv('USE_HF_MODELS', 'NOT_SET')}")
        logger.info(f"  HF_MODEL_REPO = {self.hf_model_repo}")
        logger.info(f"  HF_TOKEN = {'SET' if self.hf_token else 'NOT_SET'}")
        logger.info(f"  use_hf_models = {self.use_hf_models}")
        logger.info(f"  HF_HUB_AVAILABLE = {HF_HUB_AVAILABLE}")
        if self.use_hf_models and HF_HUB_AVAILABLE:
            logger.info(f"🔒 使用HF私有仓库模型: {self.hf_model_repo}")
            # Repo-relative directories inside the HF model repository.
            self.screening_paths = {
                'sarcoI': "models/screening/sarcoI",
                'sarcoII': "models/screening/sarcoII"
            }
            self.advisory_paths = {
                'sarcoI': "models/advisory/sarcoI",
                'sarcoII': "models/advisory/sarcoII"
            }
        else:
            logger.info("📁 使用本地模型文件")
            # Absolute directories under the application root.
            self.screening_paths = {
                'sarcoI': self.app_path / "models/screening/sarcoI",
                'sarcoII': self.app_path / "models/screening/sarcoII"
            }
            self.advisory_paths = {
                'sarcoI': self.app_path / "models/advisory/sarcoI",
                'sarcoII': self.app_path / "models/advisory/sarcoII"
            }

    def load_all_models(self):
        """Load screening models, advisory models and thresholds (in that order).

        Raises:
            Exception: any loading error is logged and re-raised unchanged.
        """
        try:
            self._load_screening_models()
            self._load_advisory_models()
            self._load_thresholds()
            logger.info("所有模型加载完成")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def _resolve_model_file(self, kind: str, model_type: str, filename: str, label: str):
        """Return a local path to one model file, downloading it in HF mode.

        Args:
            kind: 'screening' or 'advisory'.
            model_type: 'sarcoI' or 'sarcoII'.
            filename: file name inside the model directory.
            label: human-readable model name, used only in log messages.

        Returns:
            The path returned by ``hf_hub_download`` in HF mode; otherwise a
            ``Path`` under ``screening_paths``/``advisory_paths[model_type]``.
        """
        if self.use_hf_models and HF_HUB_AVAILABLE:
            path = hf_hub_download(
                repo_id=self.hf_model_repo,
                filename=f"models/{kind}/{model_type}/{filename}",
                token=self.hf_token
            )
            logger.info(f"从HF下载{label}: {path}")
        else:
            base_dirs = self.screening_paths if kind == 'screening' else self.advisory_paths
            path = Path(base_dirs[model_type]) / filename
            logger.info(f"使用本地{label}: {path}")
        return path

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point this at the trusted model repository / bundled model files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _unwrap_estimator(loaded, label: str, candidate_keys, scan_all_values: bool):
        """Return ``loaded`` or an estimator found inside it when it is a dict.

        An object counts as an estimator when it has a ``predict_proba`` attribute.

        Args:
            loaded: the unpickled object.
            label: model name used in log messages.
            candidate_keys: dict keys tried first, in order.
            scan_all_values: if no candidate key matches, also scan every dict
                value in insertion order.

        Raises:
            ValueError: no estimator could be found.
        """
        if hasattr(loaded, 'predict_proba'):
            logger.info(f"✅ {label}加载成功 (pickle格式)")
            return loaded
        if not isinstance(loaded, dict):
            raise ValueError(f"加载的对象不是有效的机器学习模型: {type(loaded)}")
        logger.info(f"🔍 字典键: {list(loaded.keys())}")
        for key in candidate_keys:
            if key in loaded and hasattr(loaded[key], 'predict_proba'):
                logger.info(f"✅ 从字典提取{label}成功 (键: {key})")
                return loaded[key]
        if scan_all_values:
            for key, value in loaded.items():
                if hasattr(value, 'predict_proba'):
                    logger.info(f"✅ 从字典提取{label}成功 (键: {key})")
                    return value
        logger.error(f"❌ 字典内容详情: {[(k, type(v).__name__) for k, v in loaded.items()]}")
        raise ValueError("字典中没有找到有效的机器学习模型")

    def _load_screening_models(self):
        """Load the screening models: sarcoI (pickled RandomForest), sarcoII (CatBoost .cbm).

        Raises:
            ImportError: catboost is not installed.
            Exception: any download/file/deserialisation error is logged and re-raised.
        """
        try:
            sarcoI_path = self._resolve_model_file(
                'screening', 'sarcoI', "randomforest_model.pkl", "SarcoI筛查模型")
            self.screening_models['sarcoI'] = self._load_pickle(sarcoI_path)
            logger.info("✅ SarcoI筛查模型加载成功")

            sarcoII_path = self._resolve_model_file(
                'screening', 'sarcoII', "catboost_model.cbm", "SarcoII筛查模型")
            try:
                import catboost as cb
            except ImportError:
                logger.error("CatBoost未安装,无法加载SarcoII筛查模型")
                raise
            # Register the model only after load_model succeeds, so a failed load
            # never leaves an untrained classifier in the registry.
            cat_model = cb.CatBoostClassifier()
            cat_model.load_model(str(sarcoII_path))
            self.screening_models['sarcoII'] = cat_model
            logger.info("✅ SarcoII筛查模型加载成功")
            logger.info("筛查模型加载成功")
        except Exception as e:
            logger.error(f"筛查模型加载失败: {str(e)}")
            raise

    def _load_advisory_models(self):
        """Load the advisory (high-precision) models from pickle files.

        sarcoI may be stored bare or wrapped in a dict (several keys tried, then all
        values); sarcoII may be stored bare or wrapped under the ``'model'`` key.

        Raises:
            ValueError: a file contains no object with ``predict_proba``.
            Exception: any other loading error is logged and re-raised.
        """
        try:
            sarcoI_path = self._resolve_model_file(
                'advisory', 'sarcoI', "CatBoost_model.pkl", "SarcoI建议模型")
            loaded = self._load_pickle(sarcoI_path)
            logger.info(f"🔍 SarcoI建议模型类型: {type(loaded)}")
            self.advisory_models['sarcoI'] = self._unwrap_estimator(
                loaded, "SarcoI建议模型", self._WRAPPED_MODEL_KEYS, scan_all_values=True)

            sarcoII_path = self._resolve_model_file(
                'advisory', 'sarcoII', "RandomForest_model.pkl", "SarcoII建议模型")
            loaded = self._load_pickle(sarcoII_path)
            logger.info(f"🔍 SarcoII建议模型类型: {type(loaded)}")
            self.advisory_models['sarcoII'] = self._unwrap_estimator(
                loaded, "SarcoII建议模型", ('model',), scan_all_values=False)
            logger.info("建议模型加载成功")
        except Exception as e:
            logger.error(f"建议模型加载失败: {str(e)}")
            raise

    def _load_thresholds(self):
        """Populate ``self.thresholds`` with the fixed high-risk cut-offs.

        Values are the ``high_floor`` entries of ``_RISK_BANDS``
        (sarcoI 0.15 / 0.36, sarcoII 0.09 / 0.52 for screening / advisory).
        Fixed values keep risk grading identical across environments.
        """
        self.thresholds = {
            model_type: {
                mode: self._RISK_BANDS[(model_type, mode)][2]
                for mode in ('screening', 'advisory')
            }
            for model_type in ('sarcoI', 'sarcoII')
        }
        logger.info("使用标准化固定阈值,确保风险评估一致性")
        logger.info("标准化阈值配置加载成功")
        logger.info(f"阈值详情: {self.thresholds}")

    def _predict(self, models: Dict[str, Any], label: str, user_data: Dict[str, float],
                 model_type: str, mode: str) -> Dict[str, Any]:
        """Run one model and package probability, risk level and threshold.

        Raises:
            KeyError: ``model_type`` is not present in ``models``.
        """
        if model_type not in models:
            raise KeyError(f"{label} '{model_type}' 不存在,可用模型: {list(models.keys())}")
        features_df = self._prepare_features(user_data, model_type, mode=mode)
        model = models[model_type]
        logger.debug(f"🔍 使用模型: {type(model)}")
        # Column 1 of predict_proba is the positive-class probability.
        probability = float(model.predict_proba(features_df)[0, 1])
        threshold = self.thresholds[model_type][mode]
        risk_level = self._get_risk_level(probability, threshold, mode=mode, model_type=model_type)
        return {
            'probability': probability,
            'risk_level': risk_level,
            'threshold': float(threshold),
            'model_type': f"{model_type}_{mode}"
        }

    def predict_screening(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Screening prediction (high recall).

        Returns:
            Dict with ``probability``, ``risk_level``, ``threshold`` and
            ``model_type`` (``'<model_type>_screening'``).

        Raises:
            KeyError: no screening model is loaded for ``model_type``.
        """
        try:
            return self._predict(self.screening_models, "筛查模型", user_data, model_type, 'screening')
        except Exception as e:
            logger.error(f"{model_type}筛查预测失败: {str(e)}")
            raise

    def predict_advisory(self, user_data: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Advisory prediction (high precision).

        Returns:
            Dict with ``probability``, ``risk_level``, ``threshold`` and
            ``model_type`` (``'<model_type>_advisory'``).

        Raises:
            KeyError: no advisory model is loaded for ``model_type``.
        """
        try:
            logger.debug(f"🔍 建议预测调试 - 模型类型: {model_type}")
            logger.debug(f"🔍 可用建议模型: {list(self.advisory_models.keys())}")
            return self._predict(self.advisory_models, "建议模型", user_data, model_type, 'advisory')
        except Exception as e:
            logger.error(f"{model_type}建议预测失败: {str(e)}")
            raise

    def _prepare_features(self, user_data: Dict[str, float], model_type: str, mode: str) -> pd.DataFrame:
        """Build the single-row feature frame a model expects.

        Args:
            user_data: raw input values keyed by feature name.
            model_type: 'sarcoI' or 'sarcoII'.
            mode: 'screening' or 'advisory'.

        Returns:
            A 1-row DataFrame with columns in ``_FEATURES[(model_type, mode)]`` order;
            missing/None values are filled from ``_FEATURE_DEFAULTS`` (else 0).

        Raises:
            ValueError: unknown ``(model_type, mode)`` combination.
        """
        try:
            features = self._FEATURES[(model_type, mode)]
        except KeyError:
            raise ValueError(
                f"未知的模型/模式组合: ({model_type!r}, {mode!r}), 可选: {sorted(self._FEATURES)}"
            ) from None
        values = []
        missing_features = []
        for feature in features:
            value = user_data.get(feature)
            if value is None:
                value = self._FEATURE_DEFAULTS.get(feature, 0)
                missing_features.append(feature)
                logger.warning(f"特征 '{feature}' 缺失,使用默认值: {value}")
            values.append(value)
        if missing_features:
            logger.info(f"{model_type}_{mode} 缺失特征: {missing_features}")
        features_df = pd.DataFrame([values], columns=list(features))
        logger.info(f"{model_type}_{mode} 特征准备完成: {list(features_df.columns)} | 形状: {features_df.shape}")
        return features_df

    def _get_risk_level(self, probability: float, threshold: float, mode: str,
                        model_type: Optional[str] = None) -> str:
        """Grade a probability as 'low', 'medium' or 'high'.

        For 'sarcoI'/'sarcoII' the fixed bands in ``_RISK_BANDS`` are used and
        ``threshold`` is ignored; any ``mode`` other than 'screening' uses the
        advisory band. For any other ``model_type``: 'high' at >= threshold,
        'medium' at >= 0.75 * threshold, else 'low'.

        Args:
            probability: predicted positive-class probability.
            threshold: decision threshold (used only for unknown model types).
            mode: 'screening' or 'advisory'.
            model_type: 'sarcoI', 'sarcoII' or None.
        """
        if model_type in ('sarcoI', 'sarcoII'):
            band_mode = 'screening' if mode == 'screening' else 'advisory'
            medium_floor, medium_inclusive, high_floor = self._RISK_BANDS[(model_type, band_mode)]
            if probability >= high_floor:
                return 'high'
            if probability > medium_floor or (medium_inclusive and probability == medium_floor):
                return 'medium'
            return 'low'
        # Backward-compatible fallback driven by the caller's threshold.
        if probability >= threshold:
            return 'high'
        if probability >= threshold * 0.75:
            return 'medium'
        return 'low'

    def _fuse_risk(self, model_type: str, screening_result: Dict,
                   advisory_result: Optional[Dict]) -> Dict:
        """Combine one screening and one (optional) advisory result, advisory first."""
        p_recall = screening_result['probability']
        p_precision = advisory_result['probability'] if advisory_result else 0.0
        thresholds = self.thresholds[model_type]
        if p_precision >= thresholds['advisory']:
            risk, reason = "high", "advisory_model_high_risk"
        elif p_recall >= thresholds['screening']:
            risk, reason = "medium", "screening_model_risk"
        else:
            risk, reason = "low", "both_models_low_risk"
        return {
            'comprehensive_risk': risk,
            'screening_probability': p_recall,
            'advisory_probability': p_precision,
            'risk_reason': reason
        }

    def get_comprehensive_risk(self, sarcoI_screening_result: Dict, sarcoI_advisory_result: Optional[Dict] = None,
                               sarcoII_screening_result: Optional[Dict] = None,
                               sarcoII_advisory_result: Optional[Dict] = None) -> Dict:
        """Fuse screening and advisory results per model, giving the advisory model priority.

        For each model whose screening result is truthy: 'high' when the advisory
        probability reaches the advisory threshold, else 'medium' when the screening
        probability reaches the screening threshold, else 'low'. A missing advisory
        result counts as probability 0.0.

        Returns:
            Dict keyed by 'sarcoI' / 'sarcoII' (only for supplied screening results).
        """
        results = {}
        if sarcoI_screening_result:
            results['sarcoI'] = self._fuse_risk('sarcoI', sarcoI_screening_result, sarcoI_advisory_result)
        if sarcoII_screening_result:
            results['sarcoII'] = self._fuse_risk('sarcoII', sarcoII_screening_result, sarcoII_advisory_result)
        return results

    def _result_threshold(self, result: Dict, model_type: str) -> float:
        """Threshold of the model that produced ``result``; advisory threshold if absent."""
        threshold = result.get('threshold')
        return threshold if threshold is not None else self.thresholds[model_type]['advisory']

    def get_overall_risk(self, sarcoI_result: Dict, sarcoII_result: Dict) -> str:
        """Combine a sarcoI and a sarcoII prediction into one risk level (legacy API).

        'high' if either is high; 'medium' if neither is low and at least one is
        medium; otherwise 'medium' when the mean threshold-normalised probability
        is >= 0.75, else 'low'.
        """
        risks = (sarcoI_result['risk_level'], sarcoII_result['risk_level'])
        if 'high' in risks:
            return 'high'
        if 'medium' in risks and 'low' not in risks:
            return 'medium'
        # Normalise each probability by the threshold of the model that produced it,
        # so screening and advisory results are both scaled correctly.
        normalized_sarcoI = sarcoI_result['probability'] / self._result_threshold(sarcoI_result, 'sarcoI')
        normalized_sarcoII = sarcoII_result['probability'] / self._result_threshold(sarcoII_result, 'sarcoII')
        weighted_score = (normalized_sarcoI + normalized_sarcoII) / 2
        return 'medium' if weighted_score >= 0.75 else 'low'
# Module-level singleton shared by the rest of the application.
model_manager = ModelManager()
# Load every model eagerly at import time.
try:
    model_manager.load_all_models()
    logger.info("🚀 全局模型管理器初始化完成")
except Exception as e:
    # Deliberately not re-raised so the application can still start. There is no
    # fallback model, though: predict_* will raise for any model that failed to load.
    # logger.exception keeps the full traceback, which str(e) alone would lose.
    logger.exception(f"❌ 全局模型管理器初始化失败: {str(e)}")