# Spaces: Running
"""
Secure model loader - loads models from a private Hugging Face repository.

Intended for a public Space whose model files must stay protected.
"""
import logging
import os
import pickle
import warnings
from pathlib import Path
from typing import Dict, Any, Optional

import numpy as np
import pandas as pd
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Secure model loading: huggingface_hub is optional. HF_HUB_AVAILABLE records
# whether it could be imported so callers can fall back to local model files.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub未安装,将使用本地模型文件")
else:
    HF_HUB_AVAILABLE = True

# Logging configuration: INFO level on the root logger, plus a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecureModelManager:
    """Model manager that loads pickled models from a private HF repository.

    When ``huggingface_hub`` could be imported and the ``HF_TOKEN``
    environment variable is set, model and threshold files are downloaded
    from the repository named by ``HF_MODEL_REPO``. Otherwise every public
    method delegates to the local ``ModelManager`` from ``.model_loader``.
    """

    # Decision thresholds used when none was loaded for a model type. These
    # are the fallback values get_comprehensive_risk has always used; sharing
    # them with predict_* keeps the per-model risk_level consistent with the
    # comprehensive risk for the same probability.
    DEFAULT_THRESHOLDS = {
        'sarcoI': {'screening': 0.23, 'advisory': 0.36},
        'sarcoII': {'screening': 0.15, 'advisory': 0.52},
    }
    # Threshold for model types that have no entry in DEFAULT_THRESHOLDS.
    GENERIC_THRESHOLD = 0.5

    def __init__(self):
        """Read the HF configuration from the environment and load all models."""
        self.screening_models = {}
        self.advisory_models = {}
        self.thresholds = {}

        # Private HF repository configuration. The default repo id is a
        # placeholder; HF_MODEL_REPO must be set for downloads to succeed.
        self.hf_repo_id = os.getenv("HF_MODEL_REPO", "YOUR_USERNAME/sarco-advisor-models")
        self.hf_token = os.getenv("HF_TOKEN")
        # Coerce to bool so the flag never holds the token string itself.
        self.use_hf_models = bool(HF_HUB_AVAILABLE and self.hf_token)

        if self.use_hf_models:
            logger.info(f"🔒 使用HF私有仓库加载模型: {self.hf_repo_id}")
        else:
            logger.info("📁 回退到本地模型文件")
            # Import the original model manager lazily; it is only needed
            # in fallback mode.
            from .model_loader import ModelManager
            self.fallback_manager = ModelManager()

        self.load_all_models()

    def _fallback(self):
        """Return the local manager when in fallback mode, otherwise None."""
        if self.use_hf_models:
            return None
        return getattr(self, 'fallback_manager', None)

    def _threshold_for(self, model_type: str, kind: str) -> float:
        """Return the loaded threshold for (model_type, kind) or its default."""
        loaded = self.thresholds.get(model_type, {})
        if kind in loaded:
            return loaded[kind]
        defaults = self.DEFAULT_THRESHOLDS.get(model_type, {})
        return defaults.get(kind, self.GENERIC_THRESHOLD)

    @staticmethod
    def _score(model: Any, row: list, threshold: float, model_type: str) -> Dict:
        """Run ``predict_proba`` on a single row and build the result dict."""
        X = np.array([row])
        # Column 1 is taken as the positive-class probability. Converting to
        # a plain float keeps the result free of NumPy scalar types.
        probability = float(model.predict_proba(X)[0][1])
        return {
            'probability': probability,
            'risk_level': 'high' if probability >= threshold else 'low',
            'threshold': threshold,
            'model_type': model_type,
        }

    def load_model_from_hf(self, model_path: str) -> Optional[Any]:
        """Download ``model_path`` from the HF repository and unpickle it.

        Args:
            model_path: Path of the file inside the repository.

        Returns:
            The unpickled object, or ``None`` if the download or the load
            failed (the error is logged, not raised).
        """
        try:
            local_path = hf_hub_download(
                repo_id=self.hf_repo_id,
                filename=model_path,
                token=self.hf_token,
                cache_dir="/tmp/hf_models",  # cache under /tmp
            )
            # SECURITY: pickle.load executes arbitrary code contained in the
            # file. Only point HF_MODEL_REPO at a repository you control.
            with open(local_path, 'rb') as f:
                model = pickle.load(f)
            logger.info(f"✅ 从HF仓库加载模型: {model_path}")
            return model
        except Exception as e:
            # Best-effort by design: callers treat None as "not available".
            logger.error(f"❌ HF模型加载失败 {model_path}: {str(e)}")
            return None

    def load_all_models(self):
        """Load every model from the HF repository or from the local manager."""
        if self.use_hf_models:
            self._load_models_from_hf()
        else:
            self._load_models_locally()

    def _load_models_from_hf(self):
        """Populate models and thresholds from the private HF repository."""
        logger.info("🔒 从HF私有仓库加载模型...")

        # Repository paths of the model files, keyed by "<type>_<role>".
        # NOTE(review): no screening model is mapped for sarcoII although a
        # sarcoII threshold file is loaded below, so predict_screening
        # ('sarcoII') raises ValueError in HF mode — confirm the filename.
        model_files = {
            'sarcoI_screening': 'models/screening/sarcoI/randomforest_model.pkl',
            'sarcoI_advisory': 'models/advisory/sarcoI/CatBoost_model.pkl',
            'sarcoII_advisory': 'models/advisory/sarcoII/RandomForest_model.pkl',
        }
        # Repository paths of the threshold files.
        threshold_files = {
            'sarcoI_screening': 'models/screening/sarcoI/optimization_results.pkl',
            'sarcoII_screening': 'models/screening/sarcoII/optimization_results.pkl',
        }
        # Key inside each threshold file that holds the screening threshold.
        threshold_keys = {
            'sarcoI': 'rf_best_threshold',
            'sarcoII': 'catboost_best_threshold',
        }

        for model_name, model_path in model_files.items():
            model = self.load_model_from_hf(model_path)
            # Compare against None: a loaded object may define __len__ or
            # __bool__, so truthiness is not a reliable "loaded" signal.
            if model is None:
                continue
            if 'screening' in model_name:
                self.screening_models[model_name.replace('_screening', '')] = model
            elif 'advisory' in model_name:
                self.advisory_models[model_name.replace('_advisory', '')] = model

        for threshold_name, threshold_path in threshold_files.items():
            threshold_data = self.load_model_from_hf(threshold_path)
            if threshold_data is None:
                continue
            model_type = threshold_name.replace('_screening', '')
            key = threshold_keys.get(model_type)
            if key is not None and key in threshold_data:
                self.thresholds[model_type] = {
                    'screening': threshold_data[key],
                    # Default advisory threshold for this model type.
                    'advisory': self.DEFAULT_THRESHOLDS[model_type]['advisory'],
                }
            else:
                logger.warning(
                    "Threshold key %r not found in %s; using default thresholds for %s",
                    key, threshold_path, model_type,
                )

        # Surface model types that cannot be served instead of failing only
        # at prediction time.
        for threshold_name in threshold_files:
            model_type = threshold_name.replace('_screening', '')
            if model_type not in self.screening_models:
                logger.warning("No screening model available for %s", model_type)

        logger.info("✅ HF模型加载完成")

    def _load_models_locally(self):
        """Load models through the local fallback manager."""
        logger.info("📁 使用本地模型文件...")
        manager = getattr(self, 'fallback_manager', None)
        if manager is None:
            logger.warning("No fallback manager available; no local models loaded")
            return
        manager.load_all_models()
        # Share the fallback manager's containers (references, not copies).
        self.screening_models = manager.screening_models
        self.advisory_models = manager.advisory_models
        self.thresholds = manager.thresholds
        logger.info("✅ 本地模型加载完成")

    def predict_screening(self, user_data: Dict, model_type: str) -> Dict:
        """Run the screening model for ``model_type`` on ``user_data``.

        Args:
            user_data: Mapping from feature name to value; every screening
                feature must be present.
            model_type: ``'sarcoI'`` or another loaded screening model key.

        Returns:
            Dict with ``probability``, ``risk_level`` (``'high'``/``'low'``),
            ``threshold`` and ``model_type``.

        Raises:
            ValueError: If no screening model is loaded for ``model_type``.
            KeyError: If a required feature is missing from ``user_data``.
        """
        manager = self._fallback()
        if manager is not None:
            return manager.predict_screening(user_data, model_type)

        if model_type not in self.screening_models:
            raise ValueError(f"筛查模型 {model_type} 未找到")
        model = self.screening_models[model_type]
        threshold = self._threshold_for(model_type, 'screening')

        if model_type == 'sarcoI':
            features = ['age_years', 'WWI', 'body_mass_index']
        else:
            features = ['age_years', 'race_ethnicity', 'WWI', 'body_mass_index']
        row = [user_data[name] for name in features]
        return self._score(model, row, threshold, model_type)

    def predict_advisory(self, user_data: Dict, model_type: str) -> Dict:
        """Run the advisory model for ``model_type`` on ``user_data``.

        Features missing from ``user_data`` are replaced by ``0.0`` and a
        warning is logged.

        Args:
            user_data: Mapping from feature name to value.
            model_type: ``'sarcoI'`` or another loaded advisory model key.

        Returns:
            Dict with ``probability``, ``risk_level`` (``'high'``/``'low'``),
            ``threshold`` and ``model_type``.

        Raises:
            ValueError: If no advisory model is loaded for ``model_type``.
        """
        manager = self._fallback()
        if manager is not None:
            return manager.predict_advisory(user_data, model_type)

        if model_type not in self.advisory_models:
            raise ValueError(f"建议模型 {model_type} 未找到")
        model = self.advisory_models[model_type]
        threshold = self._threshold_for(model_type, 'advisory')

        # Simplified feature preparation.
        if model_type == 'sarcoI':
            features = ['body_mass_index', 'race_ethnicity', 'WWI', 'age_years',
                        'Activity_Sedentary_Ratio', 'Total_Moderate_Minutes_week',
                        'Vigorous_MET_Ratio']
        else:
            features = ['body_mass_index', 'race_ethnicity', 'age_years',
                        'Activity_Sedentary_Ratio', 'Activity_Diversity_Index', 'WWI',
                        'Vigorous_MET_Ratio', 'sedentary_minutes']

        missing = [name for name in features if name not in user_data]
        if missing:
            # Zero-filling changes the prediction; make it visible in the logs.
            logger.warning(
                "Advisory features missing for %s, defaulting to 0.0: %s",
                model_type, missing,
            )
        row = [user_data.get(name, 0.0) for name in features]
        return self._score(model, row, threshold, model_type)

    def _combine_risk(self, model_type: str, screening_result: Dict,
                      advisory_result: Optional[Dict]) -> Dict:
        """Merge one screening and one optional advisory result into a verdict."""
        screening_p = screening_result['probability']
        # Without an advisory result the advisory probability counts as 0.0.
        advisory_p = advisory_result['probability'] if advisory_result else 0.0

        # The advisory model decides "high"; the screening model can only
        # raise the verdict to "medium".
        if advisory_p >= self._threshold_for(model_type, 'advisory'):
            risk, reason = "high", "advisory_model_high_risk"
        elif screening_p >= self._threshold_for(model_type, 'screening'):
            risk, reason = "medium", "screening_model_risk"
        else:
            risk, reason = "low", "both_models_low_risk"

        return {
            'comprehensive_risk': risk,
            'screening_probability': screening_p,
            'advisory_probability': advisory_p,
            'risk_reason': reason,
        }

    def get_comprehensive_risk(self, sarcoI_screening_result: Dict,
                               sarcoI_advisory_result: Optional[Dict] = None,
                               sarcoII_screening_result: Optional[Dict] = None,
                               sarcoII_advisory_result: Optional[Dict] = None) -> Dict:
        """Combine screening and advisory results into comprehensive risks.

        Args:
            sarcoI_screening_result: Result dict with a ``'probability'`` key;
                a falsy value skips the sarcoI assessment.
            sarcoI_advisory_result: Optional advisory result for sarcoI.
            sarcoII_screening_result: Optional screening result for sarcoII;
                a falsy value skips the sarcoII assessment.
            sarcoII_advisory_result: Optional advisory result for sarcoII.

        Returns:
            Dict keyed by ``'sarcoI'`` / ``'sarcoII'`` (only those assessed),
            each holding ``comprehensive_risk``, ``screening_probability``,
            ``advisory_probability`` and ``risk_reason``.
        """
        manager = self._fallback()
        if manager is not None:
            return manager.get_comprehensive_risk(
                sarcoI_screening_result, sarcoI_advisory_result,
                sarcoII_screening_result, sarcoII_advisory_result
            )

        results = {}
        if sarcoI_screening_result:
            results['sarcoI'] = self._combine_risk(
                'sarcoI', sarcoI_screening_result, sarcoI_advisory_result)
        if sarcoII_screening_result:
            results['sarcoII'] = self._combine_risk(
                'sarcoII', sarcoII_screening_result, sarcoII_advisory_result)
        return results