""" 安全模型加载器 - 从私有HF仓库加载模型 用于公开Space但保护模型文件 """ import pickle import pandas as pd import numpy as np import logging import os from pathlib import Path from typing import Dict, Any, Optional import warnings warnings.filterwarnings('ignore') # 安全模型加载 - 从私有HF仓库加载 try: from huggingface_hub import hf_hub_download HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False print("⚠️ huggingface_hub未安装,将使用本地模型文件") # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class SecureModelManager: """安全模型管理器 - 从私有HF仓库加载模型""" def __init__(self): """初始化安全模型管理器""" self.screening_models = {} self.advisory_models = {} self.thresholds = {} # HF私有仓库配置 self.hf_repo_id = os.getenv("HF_MODEL_REPO", "YOUR_USERNAME/sarco-advisor-models") self.hf_token = os.getenv("HF_TOKEN") self.use_hf_models = HF_HUB_AVAILABLE and self.hf_token if self.use_hf_models: logger.info(f"🔒 使用HF私有仓库加载模型: {self.hf_repo_id}") else: logger.info("📁 回退到本地模型文件") # 导入原始模型管理器作为备用 from .model_loader import ModelManager self.fallback_manager = ModelManager() # 加载所有模型 self.load_all_models() def load_model_from_hf(self, model_path: str): """从HF私有仓库加载模型""" try: # 下载模型文件到临时位置 local_path = hf_hub_download( repo_id=self.hf_repo_id, filename=model_path, token=self.hf_token, cache_dir="/tmp/hf_models" # 临时缓存,不会被下载 ) # 加载模型 with open(local_path, 'rb') as f: model = pickle.load(f) logger.info(f"✅ 从HF仓库加载模型: {model_path}") return model except Exception as e: logger.error(f"❌ HF模型加载失败 {model_path}: {str(e)}") return None def load_all_models(self): """加载所有模型""" if self.use_hf_models: self._load_models_from_hf() else: self._load_models_locally() def _load_models_from_hf(self): """从HF私有仓库加载所有模型""" logger.info("🔒 从HF私有仓库加载模型...") # 模型文件映射 model_files = { # 筛查模型 'sarcoI_screening': 'models/screening/sarcoI/randomforest_model.pkl', # 建议模型 'sarcoI_advisory': 'models/advisory/sarcoI/CatBoost_model.pkl', 'sarcoII_advisory': 'models/advisory/sarcoII/RandomForest_model.pkl' } # 阈值文件 threshold_files = { 'sarcoI_screening': 'models/screening/sarcoI/optimization_results.pkl', 'sarcoII_screening': 'models/screening/sarcoII/optimization_results.pkl' } # 加载模型 for model_name, model_path in model_files.items(): model = self.load_model_from_hf(model_path) if model: if 'screening' in model_name: model_type = model_name.replace('_screening', '') self.screening_models[model_type] = model elif 'advisory' in model_name: model_type = model_name.replace('_advisory', '') self.advisory_models[model_type] = model # 加载阈值 for threshold_name, threshold_path in threshold_files.items(): threshold_data = self.load_model_from_hf(threshold_path) if threshold_data: model_type = threshold_name.replace('_screening', '') # 解析阈值数据 if model_type == 'sarcoI': if 'rf_best_threshold' in threshold_data: self.thresholds[model_type] = { 'screening': threshold_data['rf_best_threshold'], 'advisory': 0.36 # 默认建议模型阈值 } elif model_type == 'sarcoII': if 'catboost_best_threshold' in threshold_data: self.thresholds[model_type] = { 'screening': threshold_data['catboost_best_threshold'], 'advisory': 0.52 # 默认建议模型阈值 } logger.info("✅ HF模型加载完成") def _load_models_locally(self): """回退到本地模型加载""" logger.info("📁 使用本地模型文件...") if hasattr(self, 'fallback_manager'): self.fallback_manager.load_all_models() # 复制模型和阈值 self.screening_models = self.fallback_manager.screening_models self.advisory_models = self.fallback_manager.advisory_models self.thresholds = self.fallback_manager.thresholds logger.info("✅ 本地模型加载完成") def predict_screening(self, user_data: Dict, model_type: str) -> Dict: """筛查预测""" if hasattr(self, 'fallback_manager') and not self.use_hf_models: return self.fallback_manager.predict_screening(user_data, model_type) # HF模型预测逻辑 if model_type not in self.screening_models: raise ValueError(f"筛查模型 {model_type} 未找到") model = self.screening_models[model_type] threshold = self.thresholds.get(model_type, {}).get('screening', 0.5) # 准备特征数据 if model_type == 'sarcoI': features = ['age_years', 'WWI', 'body_mass_index'] else: features = ['age_years', 'race_ethnicity', 'WWI', 'body_mass_index'] X = np.array([[user_data[f] for f in features]]) # 预测 probability = model.predict_proba(X)[0][1] risk_level = 'high' if probability >= threshold else 'low' return { 'probability': probability, 'risk_level': risk_level, 'threshold': threshold, 'model_type': model_type } def predict_advisory(self, user_data: Dict, model_type: str) -> Dict: """建议预测""" if hasattr(self, 'fallback_manager') and not self.use_hf_models: return self.fallback_manager.predict_advisory(user_data, model_type) # HF模型预测逻辑 if model_type not in self.advisory_models: raise ValueError(f"建议模型 {model_type} 未找到") model = self.advisory_models[model_type] threshold = self.thresholds.get(model_type, {}).get('advisory', 0.5) # 准备特征数据(简化版本) if model_type == 'sarcoI': features = ['body_mass_index', 'race_ethnicity', 'WWI', 'age_years', 'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week', 'Vigorous_MET_Ratio'] else: features = ['body_mass_index', 'race_ethnicity', 'age_years', 'Activity_Sedentary_Ratio', 'Activity_Diversity_Index', 'WWI', 'Vigorous_MET_Ratio', 'sedentary_minutes'] # 检查特征是否存在 available_features = [] for f in features: if f in user_data: available_features.append(user_data[f]) else: available_features.append(0.0) # 默认值 X = np.array([available_features]) # 预测 probability = model.predict_proba(X)[0][1] risk_level = 'high' if probability >= threshold else 'low' return { 'probability': probability, 'risk_level': risk_level, 'threshold': threshold, 'model_type': model_type } def get_comprehensive_risk(self, sarcoI_screening_result: Dict, sarcoI_advisory_result: Dict = None, sarcoII_screening_result: Dict = None, sarcoII_advisory_result: Dict = None) -> Dict: """综合风险评估""" if hasattr(self, 'fallback_manager') and not self.use_hf_models: return self.fallback_manager.get_comprehensive_risk( sarcoI_screening_result, sarcoI_advisory_result, sarcoII_screening_result, sarcoII_advisory_result ) # 使用与原始模型管理器相同的逻辑 results = {} # SarcoI 综合风险判定 if sarcoI_screening_result: P_recall_I = sarcoI_screening_result['probability'] P_precision_I = sarcoI_advisory_result['probability'] if sarcoI_advisory_result else 0.0 sarcoI_advisory_threshold = self.thresholds.get('sarcoI', {}).get('advisory', 0.36) sarcoI_screening_threshold = self.thresholds.get('sarcoI', {}).get('screening', 0.23) if P_precision_I >= sarcoI_advisory_threshold: sarcoI_comprehensive_risk = "high" sarcoI_risk_reason = "advisory_model_high_risk" elif P_recall_I >= sarcoI_screening_threshold: sarcoI_comprehensive_risk = "medium" sarcoI_risk_reason = "screening_model_risk" else: sarcoI_comprehensive_risk = "low" sarcoI_risk_reason = "both_models_low_risk" results['sarcoI'] = { 'comprehensive_risk': sarcoI_comprehensive_risk, 'screening_probability': P_recall_I, 'advisory_probability': P_precision_I, 'risk_reason': sarcoI_risk_reason } # SarcoII 综合风险判定 if sarcoII_screening_result: P_recall_II = sarcoII_screening_result['probability'] P_precision_II = sarcoII_advisory_result['probability'] if sarcoII_advisory_result else 0.0 sarcoII_advisory_threshold = self.thresholds.get('sarcoII', {}).get('advisory', 0.52) sarcoII_screening_threshold = self.thresholds.get('sarcoII', {}).get('screening', 0.15) if P_precision_II >= sarcoII_advisory_threshold: sarcoII_comprehensive_risk = "high" sarcoII_risk_reason = "advisory_model_high_risk" elif P_recall_II >= sarcoII_screening_threshold: sarcoII_comprehensive_risk = "medium" sarcoII_risk_reason = "screening_model_risk" else: sarcoII_comprehensive_risk = "low" sarcoII_risk_reason = "both_models_low_risk" results['sarcoII'] = { 'comprehensive_risk': sarcoII_comprehensive_risk, 'screening_probability': P_recall_II, 'advisory_probability': P_precision_II, 'risk_reason': sarcoII_risk_reason } return results