""" | |
评估指标计算工具 | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from sklearn.metrics import ( | |
accuracy_score, precision_recall_fscore_support, | |
confusion_matrix, classification_report, cohen_kappa_score, | |
roc_auc_score, roc_curve, auc | |
) | |
from sklearn.preprocessing import label_binarize | |
import pandas as pd | |
from typing import List, Dict, Tuple, Optional | |
import warnings | |
warnings.filterwarnings('ignore') | |
def calculate_metrics(y_true: List[int], y_pred: List[int],
                      class_names: List[str] = None) -> Dict:
    """
    Compute classification metrics.

    Args:
        y_true: ground-truth labels
        y_pred: predicted labels
        class_names: class names

    Returns:
        Dict: dictionary of metrics
    """
    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )

    # Macro, micro, and weighted averages
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='micro', zero_division=0
    )
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )

    # Cohen's Kappa (agreement corrected for chance)
    kappa = cohen_kappa_score(y_true, y_pred)

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    metrics = {
        'accuracy': accuracy,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
        'micro_precision': micro_precision,
        'micro_recall': micro_recall,
        'micro_f1': micro_f1,
        'weighted_precision': weighted_precision,
        'weighted_recall': weighted_recall,
        'weighted_f1': weighted_f1,
        'cohen_kappa': kappa,
        'confusion_matrix': cm,
        'support': support
    }

    # Per-class metrics
    if class_names is None:
        class_names = [f'Class_{i}' for i in range(len(precision))]

    for i, class_name in enumerate(class_names):
        if i < len(precision):
            metrics[f'{class_name}_precision'] = precision[i]
            metrics[f'{class_name}_recall'] = recall[i]
            metrics[f'{class_name}_f1'] = f1[i]
            metrics[f'{class_name}_support'] = support[i]
    return metrics
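
# Quick usage sketch for calculate_metrics (toy labels, for illustration only):
# with y_true = [0, 1, 2, 2] and y_pred = [0, 2, 2, 2], accuracy is 3/4 = 0.75,
# and class 2 has precision 2/3 (two of three predictions correct) and
# recall 1.0 (both true samples recovered):
#
#   m = calculate_metrics([0, 1, 2, 2], [0, 2, 2, 2], ['a', 'b', 'c'])
#   m['accuracy']   # 0.75
#   m['c_recall']   # 1.0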
def calculate_multiclass_auc(y_true: np.ndarray, y_scores: np.ndarray,
                             num_classes: int) -> Dict:
    """
    Compute multi-class AUC metrics.

    Args:
        y_true: ground-truth labels (one-hot or label-encoded)
        y_scores: predicted probabilities
        num_classes: number of classes

    Returns:
        Dict: AUC metrics
    """
    # Convert to one-hot encoding if labels come in as a 1-D array
    if y_true.ndim == 1:
        y_true_binary = label_binarize(y_true, classes=range(num_classes))
        if num_classes == 2:
            # label_binarize yields a single column for binary problems;
            # expand to two columns so per-class indexing below works.
            y_true_binary = np.hstack([1 - y_true_binary, y_true_binary])
    else:
        y_true_binary = y_true

    # Per-class ROC curves and AUCs
    auc_scores = {}
    fpr = {}
    tpr = {}
    for i in range(num_classes):
        if np.sum(y_true_binary[:, i]) > 0:  # skip classes absent from y_true
            fpr[i], tpr[i], _ = roc_curve(y_true_binary[:, i], y_scores[:, i])
            auc_scores[f'class_{i}_auc'] = auc(fpr[i], tpr[i])

    # Macro-average AUC (unweighted mean of the per-class AUCs)
    if len(auc_scores) > 0:
        auc_scores['macro_auc'] = np.mean(list(auc_scores.values()))

    # Micro-average AUC; y_true_binary is in multilabel-indicator format here,
    # so no multi_class argument is needed
    try:
        micro_auc = roc_auc_score(y_true_binary, y_scores, average='micro')
        auc_scores['micro_auc'] = micro_auc
    except ValueError:
        # roc_auc_score raises when a class never occurs; skip in that case
        pass

    auc_scores['fpr'] = fpr
    auc_scores['tpr'] = tpr
    return auc_scores
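
# calculate_multiclass_auc expects y_scores to already be probabilities. If a
# model emits raw logits instead, a row-wise softmax is needed first. The
# helper below is a convenience sketch added here, not part of the original API.
def softmax_scores(logits: np.ndarray) -> np.ndarray:
    """Convert raw logits to row-wise probabilities via a stable softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # guard against overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)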
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str],
                          normalize: bool = False, title: str = 'Confusion Matrix',
                          save_path: str = None, figsize: Tuple[int, int] = (10, 8)):
    """
    Plot a confusion matrix.

    Args:
        cm: confusion matrix
        class_names: class names
        normalize: whether to normalize each row to proportions
        title: figure title
        save_path: path to save the figure
        figsize: figure size
    """
    if normalize:
        # Normalize each row so cells show the fraction of that true class
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
    else:
        fmt = 'd'

    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Proportion' if normalize else 'Count'})
    plt.title(title, fontsize=16)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.ylabel('True Label', fontsize=14)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")
    return plt.gcf()
def plot_roc_curves(fpr: Dict, tpr: Dict, auc_scores: Dict,
                    class_names: List[str], save_path: str = None,
                    figsize: Tuple[int, int] = (12, 8)):
    """
    Plot ROC curves.

    Args:
        fpr: false positive rates per class
        tpr: true positive rates per class
        auc_scores: AUC scores per class
        class_names: class names
        save_path: path to save the figure
        figsize: figure size
    """
    plt.figure(figsize=figsize)

    # One ROC curve per class
    colors = plt.cm.Set1(np.linspace(0, 1, len(class_names)))
    for i, (color, class_name) in enumerate(zip(colors, class_names)):
        if i in fpr and i in tpr:
            auc_score = auc_scores.get(f'class_{i}_auc', 0)
            plt.plot(fpr[i], tpr[i], color=color, lw=2,
                     label=f'{class_name} (AUC = {auc_score:.3f})')

    # Diagonal reference line (chance level)
    plt.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.8)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16)
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"ROC curves saved to: {save_path}")
    return plt.gcf()
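
# To overlay a micro-averaged ROC curve (pooling every class's decisions),
# the binarized labels and scores can be raveled before roc_curve -- a sketch,
# not part of the original plotting function:
#
#   fpr_micro, tpr_micro, _ = roc_curve(y_true_binary.ravel(), y_scores.ravel())
#   plt.plot(fpr_micro, tpr_micro, 'k:', lw=2,
#            label=f'micro-average (AUC = {auc(fpr_micro, tpr_micro):.3f})')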
def plot_precision_recall_curve(y_true: np.ndarray, y_scores: np.ndarray,
                                class_names: List[str], save_path: str = None,
                                figsize: Tuple[int, int] = (12, 8)):
    """
    Plot precision-recall curves.

    Args:
        y_true: ground-truth labels (one-hot encoded)
        y_scores: predicted probabilities
        class_names: class names
        save_path: path to save the figure
        figsize: figure size
    """
    from sklearn.metrics import precision_recall_curve, average_precision_score

    plt.figure(figsize=figsize)
    colors = plt.cm.Set1(np.linspace(0, 1, len(class_names)))
    for i, (color, class_name) in enumerate(zip(colors, class_names)):
        if i < y_scores.shape[1]:
            precision, recall, _ = precision_recall_curve(y_true[:, i], y_scores[:, i])
            ap_score = average_precision_score(y_true[:, i], y_scores[:, i])
            plt.plot(recall, precision, color=color, lw=2,
                     label=f'{class_name} (AP = {ap_score:.3f})')

    plt.xlabel('Recall', fontsize=14)
    plt.ylabel('Precision', fontsize=14)
    plt.title('Precision-Recall Curves', fontsize=16)
    plt.legend(loc="lower left")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Precision-recall curves saved to: {save_path}")
    return plt.gcf()
def create_classification_report(y_true: List[int], y_pred: List[int],
                                 class_names: List[str] = None,
                                 save_path: str = None) -> pd.DataFrame:
    """
    Build a classification report as a DataFrame.

    Args:
        y_true: ground-truth labels
        y_pred: predicted labels
        class_names: class names
        save_path: path to save the CSV

    Returns:
        pd.DataFrame: classification report
    """
    report_dict = classification_report(
        y_true, y_pred,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )
    df_report = pd.DataFrame(report_dict).transpose()

    if save_path:
        df_report.to_csv(save_path, encoding='utf-8')
        print(f"Classification report saved to: {save_path}")
    return df_report
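
# Shape of the returned frame: one row per class plus 'accuracy', 'macro avg',
# and 'weighted avg', with columns precision / recall / f1-score / support.
# Note the 'accuracy' row repeats the accuracy value across all columns -- a
# quirk of classification_report(output_dict=True) being a scalar there.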
def evaluate_model_comprehensive(y_true: List[int], y_pred: List[int],
                                 y_scores: Optional[np.ndarray] = None,
                                 class_names: List[str] = None,
                                 output_dir: str = 'evaluation_results') -> Dict:
    """
    Run a comprehensive model evaluation.

    Args:
        y_true: ground-truth labels
        y_pred: predicted labels
        y_scores: predicted probabilities (needed for AUC metrics)
        class_names: class names
        output_dir: output directory

    Returns:
        Dict: evaluation results
    """
    os.makedirs(output_dir, exist_ok=True)

    # Basic metrics
    metrics = calculate_metrics(y_true, y_pred, class_names)

    # Confusion matrix visualizations: raw counts and row-normalized
    cm = metrics['confusion_matrix']
    plot_confusion_matrix(
        cm, class_names, normalize=False,
        title='Confusion Matrix (Counts)',
        save_path=os.path.join(output_dir, 'confusion_matrix_counts.png')
    )
    plt.close()

    plot_confusion_matrix(
        cm, class_names, normalize=True,
        title='Confusion Matrix (Normalized)',
        save_path=os.path.join(output_dir, 'confusion_matrix_normalized.png')
    )
    plt.close()

    # Classification report
    report_df = create_classification_report(
        y_true, y_pred, class_names,
        save_path=os.path.join(output_dir, 'classification_report.csv')
    )

    # AUC metrics (only if probabilities were provided)
    if y_scores is not None:
        num_classes = len(class_names) if class_names else len(np.unique(y_true))

        # One-hot encode the labels for per-class curves
        y_true_binary = label_binarize(y_true, classes=range(num_classes))
        if num_classes == 2:
            y_true_binary = np.hstack([1 - y_true_binary, y_true_binary])

        # AUC scores and ROC data
        auc_metrics = calculate_multiclass_auc(y_true_binary, y_scores, num_classes)
        metrics.update(auc_metrics)

        # ROC curves
        plot_roc_curves(
            auc_metrics['fpr'], auc_metrics['tpr'], auc_metrics,
            class_names, save_path=os.path.join(output_dir, 'roc_curves.png')
        )
        plt.close()

        # Precision-recall curves
        plot_precision_recall_curve(
            y_true_binary, y_scores, class_names,
            save_path=os.path.join(output_dir, 'precision_recall_curves.png')
        )
        plt.close()

    # Persist scalar metrics to CSV (arrays and curve data are filtered out)
    metrics_df = pd.DataFrame([
        {'metric': k, 'value': v} for k, v in metrics.items()
        if isinstance(v, (int, float, np.number))
    ])
    metrics_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'), index=False)

    print(f"Evaluation results saved to directory: {output_dir}")
    return metrics
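
# Typical call site from a validation loop (the names `model` and `val_loader`
# here are hypothetical, shown only to illustrate the expected shapes):
#
#   y_true, y_pred, scores = [], [], []
#   for x, y in val_loader:
#       probs = model(x)                 # (batch, num_classes) probabilities
#       scores.append(probs)
#       y_pred.extend(probs.argmax(axis=1))
#       y_true.extend(y)
#   results = evaluate_model_comprehensive(
#       y_true, y_pred, np.vstack(scores), class_names, 'evaluation_results')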
def print_metrics_summary(metrics: Dict, class_names: List[str] = None):
    """
    Print a summary of the metrics.

    Args:
        metrics: metric dictionary
        class_names: class names
    """
    print("=" * 60)
    print("Model Evaluation Summary")
    print("=" * 60)
    print(f"Overall accuracy: {metrics['accuracy']:.4f}")
    print(f"Cohen's Kappa: {metrics['cohen_kappa']:.4f}")
    print()

    print("Macro averages:")
    print(f"  Precision: {metrics['macro_precision']:.4f}")
    print(f"  Recall:    {metrics['macro_recall']:.4f}")
    print(f"  F1 score:  {metrics['macro_f1']:.4f}")
    print()

    print("Weighted averages:")
    print(f"  Precision: {metrics['weighted_precision']:.4f}")
    print(f"  Recall:    {metrics['weighted_recall']:.4f}")
    print(f"  F1 score:  {metrics['weighted_f1']:.4f}")
    print()

    if class_names:
        print("Per-class performance:")
        print("-" * 60)
        print(f"{'Class':<15} {'Precision':<10} {'Recall':<10} {'F1':<10} {'Support':<10}")
        print("-" * 60)
        for i, class_name in enumerate(class_names):
            precision_key = f'{class_name}_precision'
            recall_key = f'{class_name}_recall'
            f1_key = f'{class_name}_f1'
            support_key = f'{class_name}_support'
            if all(key in metrics for key in [precision_key, recall_key, f1_key, support_key]):
                print(f"{class_name:<15} "
                      f"{metrics[precision_key]:<10.4f} "
                      f"{metrics[recall_key]:<10.4f} "
                      f"{metrics[f1_key]:<10.4f} "
                      f"{int(metrics[support_key]):<10}")

    if 'macro_auc' in metrics:
        print(f"\nMacro-average AUC: {metrics['macro_auc']:.4f}")
    if 'micro_auc' in metrics:
        print(f"Micro-average AUC: {metrics['micro_auc']:.4f}")
    print("=" * 60)
if __name__ == "__main__": | |
# 测试评估指标 | |
np.random.seed(42) | |
# 生成模拟数据 | |
n_samples = 1000 | |
n_classes = 5 | |
y_true = np.random.randint(0, n_classes, n_samples) | |
y_pred = np.random.randint(0, n_classes, n_samples) | |
y_scores = np.random.rand(n_samples, n_classes) | |
y_scores = y_scores / y_scores.sum(axis=1, keepdims=True) # 归一化为概率 | |
class_names = ['无病变', '轻度', '中度', '重度', '增殖性'] | |
# 综合评估 | |
results = evaluate_model_comprehensive( | |
y_true, y_pred, y_scores, class_names, 'test_evaluation' | |
) | |
# 打印摘要 | |
print_metrics_summary(results, class_names) | |