# aicomp_demo/utils/metrics.py
"""
评估指标计算工具
"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
accuracy_score, precision_recall_fscore_support,
confusion_matrix, classification_report, cohen_kappa_score,
roc_auc_score, roc_curve, auc
)
from sklearn.preprocessing import label_binarize
import pandas as pd
from typing import List, Dict, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')
def calculate_metrics(y_true: List[int], y_pred: List[int],
class_names: Optional[List[str]] = None) -> Dict:
"""
Compute classification metrics.
Args:
y_true: Ground-truth labels
y_pred: Predicted labels
class_names: Class names
Returns:
Dict: Dictionary of metrics
"""
# Basic metrics
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, support = precision_recall_fscore_support(
y_true, y_pred, average=None, zero_division=0
)
# Macro, micro and weighted averages
macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
y_true, y_pred, average='macro', zero_division=0
)
micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
y_true, y_pred, average='micro', zero_division=0
)
weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
y_true, y_pred, average='weighted', zero_division=0
)
# Cohen's Kappa
kappa = cohen_kappa_score(y_true, y_pred)
# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
metrics = {
'accuracy': accuracy,
'macro_precision': macro_precision,
'macro_recall': macro_recall,
'macro_f1': macro_f1,
'micro_precision': micro_precision,
'micro_recall': micro_recall,
'micro_f1': micro_f1,
'weighted_precision': weighted_precision,
'weighted_recall': weighted_recall,
'weighted_f1': weighted_f1,
'cohen_kappa': kappa,
'confusion_matrix': cm,
'support': support
}
# Per-class metrics
if class_names is None:
class_names = [f'Class_{i}' for i in range(len(precision))]
for i, class_name in enumerate(class_names):
if i < len(precision):
metrics[f'{class_name}_precision'] = precision[i]
metrics[f'{class_name}_recall'] = recall[i]
metrics[f'{class_name}_f1'] = f1[i]
metrics[f'{class_name}_support'] = support[i]
return metrics
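# A minimal usage sketch for calculate_metrics (the labels and class names below
# are toy values for illustration, not part of the original pipeline):
#   m = calculate_metrics([0, 1, 2, 1], [0, 2, 2, 1], class_names=['A', 'B', 'C'])
#   print(m['accuracy'], m['macro_f1'], m['cohen_kappa'])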
def calculate_multiclass_auc(y_true: np.ndarray, y_scores: np.ndarray,
num_classes: int) -> Dict:
"""
计算多类别AUC指标
Args:
y_true: 真实标签 (one-hot或标签编码)
y_scores: 预测概率
num_classes: 类别数量
Returns:
Dict: AUC指标
"""
# Convert labels to one-hot encoding if needed
if y_true.ndim == 1:
y_true_binary = label_binarize(y_true, classes=range(num_classes))
if num_classes == 2:
y_true_binary = np.hstack([1-y_true_binary, y_true_binary])
else:
y_true_binary = y_true
# Per-class ROC curves and AUC
auc_scores = {}
fpr = {}
tpr = {}
for i in range(num_classes):
if np.sum(y_true_binary[:, i]) > 0:  # only evaluate classes present in y_true
fpr[i], tpr[i], _ = roc_curve(y_true_binary[:, i], y_scores[:, i])
auc_scores[f'class_{i}_auc'] = auc(fpr[i], tpr[i])
# Macro-average AUC
if len(auc_scores) > 0:
auc_scores['macro_auc'] = np.mean(list(auc_scores.values()))
# Micro-average AUC (multi-class)
try:
micro_auc = roc_auc_score(y_true_binary, y_scores, average='micro', multi_class='ovr')
auc_scores['micro_auc'] = micro_auc
except Exception:
# the micro-average AUC is undefined when some classes never appear in y_true
pass
auc_scores['fpr'] = fpr
auc_scores['tpr'] = tpr
return auc_scores
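# A minimal usage sketch for calculate_multiclass_auc (random scores are used
# purely for illustration; real scores would come from a model's softmax output):
#   labels = np.array([0, 1, 2, 1, 0, 2])
#   scores = np.random.rand(6, 3)
#   scores = scores / scores.sum(axis=1, keepdims=True)
#   aucs = calculate_multiclass_auc(labels, scores, num_classes=3)
#   print(aucs.get('macro_auc'))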
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str],
normalize: bool = False, title: str = 'Confusion Matrix',
save_path: Optional[str] = None, figsize: Tuple[int, int] = (10, 8)):
"""
Plot a confusion matrix.
Args:
cm: Confusion matrix
class_names: Class names
normalize: Whether to normalize each row
title: Plot title
save_path: Path to save the figure
figsize: Figure size
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fmt = '.2f'
else:
fmt = 'd'
plt.figure(figsize=figsize)
sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
xticklabels=class_names, yticklabels=class_names,
cbar_kws={'label': 'Count' if not normalize else 'Proportion'})
plt.title(title, fontsize=16)
plt.xlabel('Predicted Label', fontsize=14)
plt.ylabel('True Label', fontsize=14)
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"混淆矩阵已保存: {save_path}")
return plt.gcf()
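# A minimal usage sketch for plot_confusion_matrix, assuming `m` is a metrics
# dict returned by calculate_metrics above (the output file name is illustrative):
#   plot_confusion_matrix(m['confusion_matrix'], ['A', 'B', 'C'],
#                         normalize=True, save_path='cm_normalized.png')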
def plot_roc_curves(fpr: Dict, tpr: Dict, auc_scores: Dict,
class_names: List[str], save_path: Optional[str] = None,
figsize: Tuple[int, int] = (12, 8)):
"""
Plot ROC curves.
Args:
fpr: False positive rates per class
tpr: True positive rates per class
auc_scores: AUC scores per class
class_names: Class names
save_path: Path to save the figure
figsize: Figure size
"""
plt.figure(figsize=figsize)
# Plot one ROC curve per class
colors = plt.cm.Set1(np.linspace(0, 1, len(class_names)))
for i, (color, class_name) in enumerate(zip(colors, class_names)):
if i in fpr and i in tpr:
auc_score = auc_scores.get(f'class_{i}_auc', 0)
plt.plot(fpr[i], tpr[i], color=color, lw=2,
label=f'{class_name} (AUC = {auc_score:.3f})')
# Chance-level diagonal
plt.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.8)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=14)
plt.ylabel('True Positive Rate', fontsize=14)
plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16)
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"ROC曲线已保存: {save_path}")
return plt.gcf()
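# A minimal usage sketch for plot_roc_curves, assuming `aucs` is the dict
# returned by calculate_multiclass_auc above (the output file name is illustrative):
#   plot_roc_curves(aucs['fpr'], aucs['tpr'], aucs, ['A', 'B', 'C'],
#                   save_path='roc_curves.png')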
def plot_precision_recall_curve(y_true: np.ndarray, y_scores: np.ndarray,
class_names: List[str], save_path: Optional[str] = None,
figsize: Tuple[int, int] = (12, 8)):
"""
Plot precision-recall curves.
Args:
y_true: Ground-truth labels (one-hot encoded)
y_scores: Predicted probabilities
class_names: Class names
save_path: Path to save the figure
figsize: Figure size
"""
from sklearn.metrics import precision_recall_curve, average_precision_score
plt.figure(figsize=figsize)
colors = plt.cm.Set1(np.linspace(0, 1, len(class_names)))
for i, (color, class_name) in enumerate(zip(colors, class_names)):
if i < y_scores.shape[1]:
precision, recall, _ = precision_recall_curve(y_true[:, i], y_scores[:, i])
ap_score = average_precision_score(y_true[:, i], y_scores[:, i])
plt.plot(recall, precision, color=color, lw=2,
label=f'{class_name} (AP = {ap_score:.3f})')
plt.xlabel('Recall', fontsize=14)
plt.ylabel('Precision', fontsize=14)
plt.title('Precision-Recall Curves', fontsize=16)
plt.legend(loc="lower left")
plt.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Precision-Recall曲线已保存: {save_path}")
return plt.gcf()
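# A minimal usage sketch for plot_precision_recall_curve; the labels must be
# one-hot encoded first (label_binarize is imported at the top of this module,
# and `labels`/`scores` refer to toy arrays like those sketched earlier):
#   onehot = label_binarize(labels, classes=range(3))
#   plot_precision_recall_curve(onehot, scores, ['A', 'B', 'C'], save_path='pr_curves.png')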
def create_classification_report(y_true: List[int], y_pred: List[int],
class_names: Optional[List[str]] = None,
save_path: Optional[str] = None) -> pd.DataFrame:
"""
Build a classification report as a DataFrame.
Args:
y_true: Ground-truth labels
y_pred: Predicted labels
class_names: Class names
save_path: Path to save the CSV report
Returns:
pd.DataFrame: Classification report
"""
report_dict = classification_report(
y_true, y_pred,
target_names=class_names,
output_dict=True,
zero_division=0
)
df_report = pd.DataFrame(report_dict).transpose()
if save_path:
df_report.to_csv(save_path, encoding='utf-8')
print(f"分类报告已保存: {save_path}")
return df_report
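# A minimal usage sketch for create_classification_report (the labels and CSV path
# are illustrative):
#   report = create_classification_report([0, 1, 2, 1], [0, 2, 2, 1],
#                                          class_names=['A', 'B', 'C'],
#                                          save_path='report.csv')
#   print(report)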
def evaluate_model_comprehensive(y_true: List[int], y_pred: List[int],
y_scores: Optional[np.ndarray] = None,
class_names: Optional[List[str]] = None,
output_dir: str = 'evaluation_results') -> Dict:
"""
Run a comprehensive model evaluation.
Args:
y_true: Ground-truth labels
y_pred: Predicted labels
y_scores: Predicted probabilities (required for AUC-based metrics)
class_names: Class names
output_dir: Output directory
Returns:
Dict: Evaluation results
"""
import os
os.makedirs(output_dir, exist_ok=True)
# Basic metrics
metrics = calculate_metrics(y_true, y_pred, class_names)
# Fall back to generic class names so the plotting helpers below always receive a label list
if class_names is None:
class_names = [f'Class_{i}' for i in range(len(np.unique(y_true)))]
# Confusion matrix visualizations
cm = metrics['confusion_matrix']
# Raw counts
plot_confusion_matrix(
cm, class_names, normalize=False,
title='Confusion Matrix (Counts)',
save_path=os.path.join(output_dir, 'confusion_matrix_counts.png')
)
plt.close()
# Row-normalized
plot_confusion_matrix(
cm, class_names, normalize=True,
title='Confusion Matrix (Normalized)',
save_path=os.path.join(output_dir, 'confusion_matrix_normalized.png')
)
plt.close()
# Classification report
report_df = create_classification_report(
y_true, y_pred, class_names,
save_path=os.path.join(output_dir, 'classification_report.csv')
)
# AUC-based metrics (only if probabilities were provided)
if y_scores is not None:
num_classes = len(class_names) if class_names else len(np.unique(y_true))
# Convert labels to one-hot encoding
y_true_binary = label_binarize(y_true, classes=range(num_classes))
if num_classes == 2:
y_true_binary = np.hstack([1-y_true_binary, y_true_binary])
# Compute AUC metrics
auc_metrics = calculate_multiclass_auc(y_true_binary, y_scores, num_classes)
metrics.update(auc_metrics)
# ROC curves
plot_roc_curves(
auc_metrics['fpr'], auc_metrics['tpr'], auc_metrics,
class_names, save_path=os.path.join(output_dir, 'roc_curves.png')
)
plt.close()
# Precision-recall curves
plot_precision_recall_curve(
y_true_binary, y_scores, class_names,
save_path=os.path.join(output_dir, 'precision_recall_curves.png')
)
plt.close()
# Save scalar metrics to file
metrics_df = pd.DataFrame([
{'metric': k, 'value': v} for k, v in metrics.items()
if isinstance(v, (int, float, np.number))
])
metrics_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'), index=False)
print(f"评估结果已保存到目录: {output_dir}")
return metrics
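# A minimal usage sketch for evaluate_model_comprehensive; y_scores is optional
# and enables the ROC / PR outputs (the toy arrays and directory name are illustrative):
#   results = evaluate_model_comprehensive(labels, list(np.argmax(scores, axis=1)),
#                                          y_scores=scores, class_names=['A', 'B', 'C'],
#                                          output_dir='eval_out')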
def print_metrics_summary(metrics: Dict, class_names: Optional[List[str]] = None):
"""
Print a summary of the metrics.
Args:
metrics: Dictionary of metrics
class_names: Class names
"""
print("=" * 60)
print("模型评估结果摘要")
print("=" * 60)
print(f"总体准确率: {metrics['accuracy']:.4f}")
print(f"Cohen's Kappa: {metrics['cohen_kappa']:.4f}")
print()
print("宏观平均:")
print(f" 精确率: {metrics['macro_precision']:.4f}")
print(f" 召回率: {metrics['macro_recall']:.4f}")
print(f" F1分数: {metrics['macro_f1']:.4f}")
print()
print("加权平均:")
print(f" 精确率: {metrics['weighted_precision']:.4f}")
print(f" 召回率: {metrics['weighted_recall']:.4f}")
print(f" F1分数: {metrics['weighted_f1']:.4f}")
print()
if class_names:
print("各类别性能:")
print("-" * 60)
print(f"{'类别':<15} {'精确率':<10} {'召回率':<10} {'F1分数':<10} {'样本数':<10}")
print("-" * 60)
for i, class_name in enumerate(class_names):
precision_key = f'{class_name}_precision'
recall_key = f'{class_name}_recall'
f1_key = f'{class_name}_f1'
support_key = f'{class_name}_support'
if all(key in metrics for key in [precision_key, recall_key, f1_key, support_key]):
print(f"{class_name:<15} "
f"{metrics[precision_key]:<10.4f} "
f"{metrics[recall_key]:<10.4f} "
f"{metrics[f1_key]:<10.4f} "
f"{int(metrics[support_key]):<10}")
if 'macro_auc' in metrics:
print(f"\n宏观平均AUC: {metrics['macro_auc']:.4f}")
if 'micro_auc' in metrics:
print(f"微观平均AUC: {metrics['micro_auc']:.4f}")
print("=" * 60)
if __name__ == "__main__":
# Quick self-test of the evaluation utilities
np.random.seed(42)
# Generate synthetic data
n_samples = 1000
n_classes = 5
y_true = np.random.randint(0, n_classes, n_samples)
y_pred = np.random.randint(0, n_classes, n_samples)
y_scores = np.random.rand(n_samples, n_classes)
y_scores = y_scores / y_scores.sum(axis=1, keepdims=True)  # normalize scores into probabilities
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']
# Run the full evaluation
results = evaluate_model_comprehensive(
y_true, y_pred, y_scores, class_names, 'test_evaluation'
)
# Print the summary
print_metrics_summary(results, class_names)