""" 评估指标计算工具 """ import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import ( accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report, cohen_kappa_score, roc_auc_score, roc_curve, auc ) from sklearn.preprocessing import label_binarize import pandas as pd from typing import List, Dict, Tuple, Optional import warnings warnings.filterwarnings('ignore') def calculate_metrics(y_true: List[int], y_pred: List[int], class_names: List[str] = None) -> Dict: """ 计算分类指标 Args: y_true: 真实标签 y_pred: 预测标签 class_names: 类别名称 Returns: Dict: 指标字典 """ # 基本指标 accuracy = accuracy_score(y_true, y_pred) precision, recall, f1, support = precision_recall_fscore_support( y_true, y_pred, average=None, zero_division=0 ) # 宏观和微观平均 macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support( y_true, y_pred, average='macro', zero_division=0 ) micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support( y_true, y_pred, average='micro', zero_division=0 ) weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support( y_true, y_pred, average='weighted', zero_division=0 ) # Cohen's Kappa kappa = cohen_kappa_score(y_true, y_pred) # 混淆矩阵 cm = confusion_matrix(y_true, y_pred) metrics = { 'accuracy': accuracy, 'macro_precision': macro_precision, 'macro_recall': macro_recall, 'macro_f1': macro_f1, 'micro_precision': micro_precision, 'micro_recall': micro_recall, 'micro_f1': micro_f1, 'weighted_precision': weighted_precision, 'weighted_recall': weighted_recall, 'weighted_f1': weighted_f1, 'cohen_kappa': kappa, 'confusion_matrix': cm, 'support': support } # 每个类别的指标 if class_names is None: class_names = [f'Class_{i}' for i in range(len(precision))] for i, class_name in enumerate(class_names): if i < len(precision): metrics[f'{class_name}_precision'] = precision[i] metrics[f'{class_name}_recall'] = recall[i] metrics[f'{class_name}_f1'] = f1[i] metrics[f'{class_name}_support'] = support[i] return metrics def calculate_multiclass_auc(y_true: np.ndarray, y_scores: np.ndarray, num_classes: int) -> Dict: """ 计算多类别AUC指标 Args: y_true: 真实标签 (one-hot或标签编码) y_scores: 预测概率 num_classes: 类别数量 Returns: Dict: AUC指标 """ # 转换为one-hot编码 if y_true.ndim == 1: y_true_binary = label_binarize(y_true, classes=range(num_classes)) if num_classes == 2: y_true_binary = np.hstack([1-y_true_binary, y_true_binary]) else: y_true_binary = y_true # 计算每个类别的AUC auc_scores = {} fpr = {} tpr = {} for i in range(num_classes): if np.sum(y_true_binary[:, i]) > 0: # 确保类别存在 fpr[i], tpr[i], _ = roc_curve(y_true_binary[:, i], y_scores[:, i]) auc_scores[f'class_{i}_auc'] = auc(fpr[i], tpr[i]) # 宏观平均AUC if len(auc_scores) > 0: auc_scores['macro_auc'] = np.mean(list(auc_scores.values())) # 微观平均AUC(多类别) try: micro_auc = roc_auc_score(y_true_binary, y_scores, average='micro', multi_class='ovr') auc_scores['micro_auc'] = micro_auc except: pass auc_scores['fpr'] = fpr auc_scores['tpr'] = tpr return auc_scores def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], normalize: bool = False, title: str = 'Confusion Matrix', save_path: str = None, figsize: Tuple[int, int] = (10, 8)): """ 绘制混淆矩阵 Args: cm: 混淆矩阵 class_names: 类别名称 normalize: 是否标准化 title: 图标题 save_path: 保存路径 figsize: 图像大小 """ if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] fmt = '.2f' else: fmt = 'd' plt.figure(figsize=figsize) sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues', xticklabels=class_names, yticklabels=class_names, cbar_kws={'label': 'Count' if not normalize else 'Proportion'}) plt.title(title, fontsize=16) plt.xlabel('Predicted Label', fontsize=14) plt.ylabel('True Label', fontsize=14) plt.xticks(rotation=45) plt.yticks(rotation=0) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"混淆矩阵已保存: {save_path}") return plt.gcf() def plot_roc_curves(fpr: Dict, tpr: Dict, auc_scores: Dict, class_names: List[str], save_path: str = None, figsize: Tuple[int, int] = (12, 8)): """ 绘制ROC曲线 Args: fpr: 假正率字典 tpr: 真正率字典 auc_scores: AUC分数字典 class_names: 类别名称 save_path: 保存路径 figsize: 图像大小 """ plt.figure(figsize=figsize) # 绘制每个类别的ROC曲线 colors = plt.cm.Set1(np.linspace(0, 1, len(class_names))) for i, (color, class_name) in enumerate(zip(colors, class_names)): if i in fpr and i in tpr: auc_score = auc_scores.get(f'class_{i}_auc', 0) plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_name} (AUC = {auc_score:.3f})') # 对角线 plt.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.8) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate', fontsize=14) plt.ylabel('True Positive Rate', fontsize=14) plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16) plt.legend(loc="lower right") plt.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"ROC曲线已保存: {save_path}") return plt.gcf() def plot_precision_recall_curve(y_true: np.ndarray, y_scores: np.ndarray, class_names: List[str], save_path: str = None, figsize: Tuple[int, int] = (12, 8)): """ 绘制Precision-Recall曲线 Args: y_true: 真实标签 (one-hot编码) y_scores: 预测概率 class_names: 类别名称 save_path: 保存路径 figsize: 图像大小 """ from sklearn.metrics import precision_recall_curve, average_precision_score plt.figure(figsize=figsize) colors = plt.cm.Set1(np.linspace(0, 1, len(class_names))) for i, (color, class_name) in enumerate(zip(colors, class_names)): if i < y_scores.shape[1]: precision, recall, _ = precision_recall_curve(y_true[:, i], y_scores[:, i]) ap_score = average_precision_score(y_true[:, i], y_scores[:, i]) plt.plot(recall, precision, color=color, lw=2, label=f'{class_name} (AP = {ap_score:.3f})') plt.xlabel('Recall', fontsize=14) plt.ylabel('Precision', fontsize=14) plt.title('Precision-Recall Curves', fontsize=16) plt.legend(loc="lower left") plt.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"Precision-Recall曲线已保存: {save_path}") return plt.gcf() def create_classification_report(y_true: List[int], y_pred: List[int], class_names: List[str] = None, save_path: str = None) -> pd.DataFrame: """ 创建分类报告DataFrame Args: y_true: 真实标签 y_pred: 预测标签 class_names: 类别名称 save_path: 保存路径 Returns: pd.DataFrame: 分类报告 """ report_dict = classification_report( y_true, y_pred, target_names=class_names, output_dict=True, zero_division=0 ) df_report = pd.DataFrame(report_dict).transpose() if save_path: df_report.to_csv(save_path, encoding='utf-8') print(f"分类报告已保存: {save_path}") return df_report def evaluate_model_comprehensive(y_true: List[int], y_pred: List[int], y_scores: Optional[np.ndarray] = None, class_names: List[str] = None, output_dir: str = 'evaluation_results') -> Dict: """ 综合评估模型 Args: y_true: 真实标签 y_pred: 预测标签 y_scores: 预测概率(用于AUC计算) class_names: 类别名称 output_dir: 输出目录 Returns: Dict: 评估结果 """ import os os.makedirs(output_dir, exist_ok=True) # 基础指标 metrics = calculate_metrics(y_true, y_pred, class_names) # 混淆矩阵可视化 cm = metrics['confusion_matrix'] # 原始计数 plot_confusion_matrix( cm, class_names, normalize=False, title='Confusion Matrix (Counts)', save_path=os.path.join(output_dir, 'confusion_matrix_counts.png') ) plt.close() # 标准化 plot_confusion_matrix( cm, class_names, normalize=True, title='Confusion Matrix (Normalized)', save_path=os.path.join(output_dir, 'confusion_matrix_normalized.png') ) plt.close() # 分类报告 report_df = create_classification_report( y_true, y_pred, class_names, save_path=os.path.join(output_dir, 'classification_report.csv') ) # AUC相关指标(如果提供了概率) if y_scores is not None: num_classes = len(class_names) if class_names else len(np.unique(y_true)) # 转换标签为one-hot编码 y_true_binary = label_binarize(y_true, classes=range(num_classes)) if num_classes == 2: y_true_binary = np.hstack([1-y_true_binary, y_true_binary]) # 计算AUC auc_metrics = calculate_multiclass_auc(y_true_binary, y_scores, num_classes) metrics.update(auc_metrics) # ROC曲线 plot_roc_curves( auc_metrics['fpr'], auc_metrics['tpr'], auc_metrics, class_names, save_path=os.path.join(output_dir, 'roc_curves.png') ) plt.close() # Precision-Recall曲线 plot_precision_recall_curve( y_true_binary, y_scores, class_names, save_path=os.path.join(output_dir, 'precision_recall_curves.png') ) plt.close() # 保存指标到文件 metrics_df = pd.DataFrame([ {'metric': k, 'value': v} for k, v in metrics.items() if isinstance(v, (int, float, np.number)) ]) metrics_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'), index=False) print(f"评估结果已保存到目录: {output_dir}") return metrics def print_metrics_summary(metrics: Dict, class_names: List[str] = None): """ 打印指标摘要 Args: metrics: 指标字典 class_names: 类别名称 """ print("=" * 60) print("模型评估结果摘要") print("=" * 60) print(f"总体准确率: {metrics['accuracy']:.4f}") print(f"Cohen's Kappa: {metrics['cohen_kappa']:.4f}") print() print("宏观平均:") print(f" 精确率: {metrics['macro_precision']:.4f}") print(f" 召回率: {metrics['macro_recall']:.4f}") print(f" F1分数: {metrics['macro_f1']:.4f}") print() print("加权平均:") print(f" 精确率: {metrics['weighted_precision']:.4f}") print(f" 召回率: {metrics['weighted_recall']:.4f}") print(f" F1分数: {metrics['weighted_f1']:.4f}") print() if class_names: print("各类别性能:") print("-" * 60) print(f"{'类别':<15} {'精确率':<10} {'召回率':<10} {'F1分数':<10} {'样本数':<10}") print("-" * 60) for i, class_name in enumerate(class_names): precision_key = f'{class_name}_precision' recall_key = f'{class_name}_recall' f1_key = f'{class_name}_f1' support_key = f'{class_name}_support' if all(key in metrics for key in [precision_key, recall_key, f1_key, support_key]): print(f"{class_name:<15} " f"{metrics[precision_key]:<10.4f} " f"{metrics[recall_key]:<10.4f} " f"{metrics[f1_key]:<10.4f} " f"{int(metrics[support_key]):<10}") if 'macro_auc' in metrics: print(f"\n宏观平均AUC: {metrics['macro_auc']:.4f}") if 'micro_auc' in metrics: print(f"微观平均AUC: {metrics['micro_auc']:.4f}") print("=" * 60) if __name__ == "__main__": # 测试评估指标 np.random.seed(42) # 生成模拟数据 n_samples = 1000 n_classes = 5 y_true = np.random.randint(0, n_classes, n_samples) y_pred = np.random.randint(0, n_classes, n_samples) y_scores = np.random.rand(n_samples, n_classes) y_scores = y_scores / y_scores.sum(axis=1, keepdims=True) # 归一化为概率 class_names = ['无病变', '轻度', '中度', '重度', '增殖性'] # 综合评估 results = evaluate_model_comprehensive( y_true, y_pred, y_scores, class_names, 'test_evaluation' ) # 打印摘要 print_metrics_summary(results, class_names)