# Sarco-Monitor / utils/secure_model_loader.py
# Author: Ning311 — commit ad05511 (verified): "Upload 40 files"
"""
安全模型加载器 - 从私有HF仓库加载模型
用于公开Space但保护模型文件
"""
import pickle
import pandas as pd
import numpy as np
import logging
import os
from pathlib import Path
from typing import Dict, Any, Optional
import warnings
# Installs a global "ignore" filter at import time, so warnings raised by any
# module in the process are suppressed unless re-enabled later.
warnings.filterwarnings('ignore')

# huggingface_hub is optional: when it cannot be imported, models are loaded
# from the local model files instead (see SecureModelManager).
HF_HUB_AVAILABLE = False
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    print("⚠️ huggingface_hub未安装,将使用本地模型文件")
else:
    HF_HUB_AVAILABLE = True

# Logging configuration for this module (applied at import time).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecureModelManager:
    """Load sarcopenia screening/advisory models from a private HF Hub repo.

    If ``huggingface_hub`` was importable and the ``HF_TOKEN`` environment
    variable is set, models and screening thresholds are downloaded from the
    repo named by ``HF_MODEL_REPO``. Otherwise loading and prediction are
    delegated to the local ``ModelManager`` from ``.model_loader``.

    Attributes (keyed by model type, e.g. ``"sarcoI"`` / ``"sarcoII"``):
        screening_models: screening classifiers.
        advisory_models: advisory classifiers.
        thresholds: ``{"screening": float, "advisory": float}`` per model type.
    """

    # Fallback thresholds used whenever no value was loaded for a model type.
    # The advisory values are also the ones stored when a threshold file is
    # parsed, because those files only provide the screening threshold.
    DEFAULT_THRESHOLDS = {
        'sarcoI': {'screening': 0.23, 'advisory': 0.36},
        'sarcoII': {'screening': 0.15, 'advisory': 0.52},
    }
    # Threshold for model types that have no entry in DEFAULT_THRESHOLDS.
    GENERIC_THRESHOLD = 0.5

    # Key inside each optimization_results.pkl that holds the screening threshold.
    _THRESHOLD_KEYS = {
        'sarcoI': 'rf_best_threshold',
        'sarcoII': 'catboost_best_threshold',
    }

    # Ordered feature columns fed to each model. The order presumably has to
    # match the training data — TODO confirm against the training code.
    _SCREENING_FEATURES = {
        'sarcoI': ('age_years', 'WWI', 'body_mass_index'),
        'sarcoII': ('age_years', 'race_ethnicity', 'WWI', 'body_mass_index'),
    }
    _ADVISORY_FEATURES = {
        'sarcoI': ('body_mass_index', 'race_ethnicity', 'WWI', 'age_years',
                   'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week',
                   'Vigorous_MET_Ratio'),
        'sarcoII': ('body_mass_index', 'race_ethnicity', 'age_years',
                    'Activity_Sedentary_Ratio', 'Activity_Diversity_Index',
                    'WWI', 'Vigorous_MET_Ratio', 'sedentary_minutes'),
    }

    def __init__(self):
        """Read the HF configuration, build the local fallback, load all models."""
        self.screening_models = {}
        self.advisory_models = {}
        self.thresholds = {}

        # Private HF repo configuration.
        self.hf_repo_id = os.getenv("HF_MODEL_REPO", "YOUR_USERNAME/sarco-advisor-models")
        self.hf_token = os.getenv("HF_TOKEN")
        # bool(): without it this flag would hold the token string itself.
        self.use_hf_models = bool(HF_HUB_AVAILABLE and self.hf_token)

        if self.use_hf_models:
            logger.info("🔒 使用HF私有仓库加载模型: %s", self.hf_repo_id)
        else:
            logger.info("📁 回退到本地模型文件")

        # Local manager; predictions are delegated to it when HF is not used.
        from .model_loader import ModelManager
        self.fallback_manager = ModelManager()

        self.load_all_models()

    def load_model_from_hf(self, model_path: str) -> Optional[Any]:
        """Download ``model_path`` from the HF repo and unpickle it.

        Args:
            model_path: path of the file inside ``self.hf_repo_id``.

        Returns:
            The unpickled object, or ``None`` if downloading or unpickling
            failed. The error is logged, not raised, so one bad file does not
            stop the remaining files from loading.
        """
        try:
            local_path = hf_hub_download(
                repo_id=self.hf_repo_id,
                filename=model_path,
                token=self.hf_token,
                cache_dir="/tmp/hf_models",  # local download cache
            )
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only point HF_MODEL_REPO at a repository you control.
            with open(local_path, 'rb') as f:
                model = pickle.load(f)
        except Exception as e:
            logger.error("❌ HF模型加载失败 %s: %s", model_path, e)
            return None
        logger.info("✅ 从HF仓库加载模型: %s", model_path)
        return model

    def load_all_models(self):
        """Load models from the HF repo if enabled, otherwise from local files."""
        if self.use_hf_models:
            self._load_models_from_hf()
        else:
            self._load_models_locally()

    def _load_models_from_hf(self):
        """Populate the model dicts and thresholds from the private HF repo.

        Files that fail to load are skipped (``load_model_from_hf`` returns None).
        """
        logger.info("🔒 从HF私有仓库加载模型...")

        # "<model_type>_<role>" -> path inside the repo.
        # NOTE(review): no sarcoII screening model is listed, although a sarcoII
        # screening threshold is loaded below and predict_screening supports
        # "sarcoII"; on the HF path predict_screening(..., "sarcoII") therefore
        # raises ValueError. Add the file once its repo path is confirmed.
        model_files = {
            'sarcoI_screening': 'models/screening/sarcoI/randomforest_model.pkl',
            'sarcoI_advisory': 'models/advisory/sarcoI/CatBoost_model.pkl',
            'sarcoII_advisory': 'models/advisory/sarcoII/RandomForest_model.pkl',
        }
        threshold_files = {
            'sarcoI_screening': 'models/screening/sarcoI/optimization_results.pkl',
            'sarcoII_screening': 'models/screening/sarcoII/optimization_results.pkl',
        }

        for model_name, model_path in model_files.items():
            model = self.load_model_from_hf(model_path)
            # Compare with None: truthiness of an arbitrary unpickled object
            # is not a reliable "loaded successfully" check.
            if model is None:
                continue
            if 'screening' in model_name:
                self.screening_models[model_name.replace('_screening', '')] = model
            elif 'advisory' in model_name:
                self.advisory_models[model_name.replace('_advisory', '')] = model

        for threshold_name, threshold_path in threshold_files.items():
            threshold_data = self.load_model_from_hf(threshold_path)
            if threshold_data is None:
                continue
            model_type = threshold_name.replace('_screening', '')
            key = self._THRESHOLD_KEYS.get(model_type)
            # assumes threshold_data is a dict-like object — TODO confirm
            if key is not None and key in threshold_data:
                self.thresholds[model_type] = {
                    'screening': threshold_data[key],
                    'advisory': self.DEFAULT_THRESHOLDS[model_type]['advisory'],
                }

        logger.info("✅ HF模型加载完成")

    def _load_models_locally(self):
        """Load models via the local ModelManager and share its dicts."""
        logger.info("📁 使用本地模型文件...")
        if hasattr(self, 'fallback_manager'):
            self.fallback_manager.load_all_models()
            # Share (not copy) the fallback manager's models and thresholds.
            self.screening_models = self.fallback_manager.screening_models
            self.advisory_models = self.fallback_manager.advisory_models
            self.thresholds = self.fallback_manager.thresholds
        logger.info("✅ 本地模型加载完成")

    def _threshold(self, model_type: str, kind: str) -> float:
        """Return the ``kind`` ("screening"/"advisory") threshold for ``model_type``.

        Uses the loaded value if present, else the model-specific entry in
        DEFAULT_THRESHOLDS, else GENERIC_THRESHOLD. All prediction paths use
        this so their risk labels agree.
        """
        loaded = self.thresholds.get(model_type, {})
        if kind in loaded:
            return loaded[kind]
        return self.DEFAULT_THRESHOLDS.get(model_type, {}).get(kind, self.GENERIC_THRESHOLD)

    @staticmethod
    def _score(model: Any, row: list, threshold: float, model_type: str) -> Dict:
        """Run ``predict_proba`` on one feature row and label it against ``threshold``."""
        probability = model.predict_proba(np.array([row]))[0][1]
        return {
            'probability': probability,
            'risk_level': 'high' if probability >= threshold else 'low',
            'threshold': threshold,
            'model_type': model_type,
        }

    def predict_screening(self, user_data: Dict, model_type: str) -> Dict:
        """Score ``user_data`` with the screening model for ``model_type``.

        When HF models are not in use, the call is delegated to the local
        ModelManager. Otherwise:

        Args:
            user_data: feature name -> value; every screening feature of the
                model type must be present.
            model_type: "sarcoI" or "sarcoII".

        Returns:
            dict with ``probability``, ``risk_level`` ("high"/"low"),
            ``threshold`` and ``model_type``.

        Raises:
            ValueError: no screening model is loaded for ``model_type``.
            KeyError: a required feature is missing from ``user_data``.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.predict_screening(user_data, model_type)

        if model_type not in self.screening_models:
            raise ValueError(f"筛查模型 {model_type} 未找到")

        # Any model type other than sarcoI uses the sarcoII feature set.
        features = self._SCREENING_FEATURES.get(model_type, self._SCREENING_FEATURES['sarcoII'])
        row = [user_data[f] for f in features]
        return self._score(self.screening_models[model_type], row,
                           self._threshold(model_type, 'screening'), model_type)

    def predict_advisory(self, user_data: Dict, model_type: str) -> Dict:
        """Score ``user_data`` with the advisory model for ``model_type``.

        When HF models are not in use, the call is delegated to the local
        ModelManager. Otherwise a result dict shaped like
        ``predict_screening``'s is returned.

        Raises:
            ValueError: no advisory model is loaded for ``model_type``.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.predict_advisory(user_data, model_type)

        if model_type not in self.advisory_models:
            raise ValueError(f"建议模型 {model_type} 未找到")

        # Any model type other than sarcoI uses the sarcoII feature set.
        features = self._ADVISORY_FEATURES.get(model_type, self._ADVISORY_FEATURES['sarcoII'])
        # Missing features are silently replaced with 0.0.
        # NOTE(review): 0.0 may not be a meaningful value for every feature
        # (e.g. race_ethnicity if it is categorical) — confirm.
        row = [user_data[f] if f in user_data else 0.0 for f in features]
        return self._score(self.advisory_models[model_type], row,
                           self._threshold(model_type, 'advisory'), model_type)

    def _combine_risk(self, model_type: str, screening_result: Dict,
                      advisory_result: Optional[Dict]) -> Dict:
        """Combine one model type's screening and advisory results into a risk tier.

        Returns "high" if the advisory probability reaches the advisory
        threshold. Otherwise it returns "medium" if the screening probability
        reaches the screening threshold, and "low" if neither does. A missing
        advisory result counts as probability 0.0.
        """
        p_recall = screening_result['probability']
        p_precision = advisory_result['probability'] if advisory_result else 0.0

        if p_precision >= self._threshold(model_type, 'advisory'):
            risk, reason = "high", "advisory_model_high_risk"
        elif p_recall >= self._threshold(model_type, 'screening'):
            risk, reason = "medium", "screening_model_risk"
        else:
            risk, reason = "low", "both_models_low_risk"

        return {
            'comprehensive_risk': risk,
            'screening_probability': p_recall,
            'advisory_probability': p_precision,
            'risk_reason': reason,
        }

    def get_comprehensive_risk(self, sarcoI_screening_result: Dict,
                               sarcoI_advisory_result: Optional[Dict] = None,
                               sarcoII_screening_result: Optional[Dict] = None,
                               sarcoII_advisory_result: Optional[Dict] = None) -> Dict:
        """Combine screening and advisory results into per-type risk tiers.

        When HF models are not in use, the call is delegated to the local
        ModelManager. Otherwise the result maps "sarcoI"/"sarcoII" to a
        ``_combine_risk`` dict. A type is included only when its screening
        result is truthy.
        """
        if hasattr(self, 'fallback_manager') and not self.use_hf_models:
            return self.fallback_manager.get_comprehensive_risk(
                sarcoI_screening_result, sarcoI_advisory_result,
                sarcoII_screening_result, sarcoII_advisory_result
            )

        results = {}
        if sarcoI_screening_result:
            results['sarcoI'] = self._combine_risk(
                'sarcoI', sarcoI_screening_result, sarcoI_advisory_result)
        if sarcoII_screening_result:
            results['sarcoII'] = self._combine_risk(
                'sarcoII', sarcoII_screening_result, sarcoII_advisory_result)
        return results