# aicomp_demo / utils/visualization.py
# Author: ceasonen
# 我的视网膜检测网站 (my retina-detection website)
# Commit: 04103fb
"""
可视化工具模块
包含训练过程可视化、模型解释性可视化等
"""
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Union
import pandas as pd
from matplotlib.animation import FuncAnimation
import warnings
# NOTE(review): this silences *every* warning for the whole process as a side
# effect of importing this module, not only matplotlib's font warnings.
warnings.filterwarnings('ignore')

# Matplotlib font setup so Chinese labels render: try SimHei first and fall
# back to DejaVu Sans. Rendering the minus sign as a plain hyphen is commonly
# done because CJK fonts may lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
def plot_training_curves(train_history: Dict, save_path: str = None,
                         figsize: Tuple[int, int] = (15, 10)):
    """
    Plot loss, accuracy and learning-rate curves plus a text summary (2x2 grid).

    Args:
        train_history: Training-history mapping. Required keys are
            ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``
            (per-epoch sequences; lists or 1-D numpy arrays). ``'lr'`` is
            optional. Accuracies are labelled and formatted as percentages.
        save_path: If given, the figure is also saved there at 300 dpi.
        figsize: Figure size in inches.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If ``'train_loss'``, ``'val_loss'`` or ``'val_acc'`` is empty.
    """
    train_loss = list(train_history['train_loss'])
    val_loss = list(train_history['val_loss'])
    train_acc = list(train_history['train_acc'])
    val_acc = list(train_history['val_acc'])
    if not (train_loss and val_loss and val_acc):
        # Fail before creating a figure; max()/[-1] below would fail anyway.
        raise ValueError(
            "train_history must contain at least one epoch of "
            "'train_loss', 'val_loss' and 'val_acc'"
        )

    def _epochs(values):
        # 1-based epoch axis sized to each series, so series of different
        # lengths (e.g. validation run less often) still plot.
        return range(1, len(values) + 1)

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Loss curves
    ax = axes[0, 0]
    ax.plot(_epochs(train_loss), train_loss, 'b-', label='训练损失', linewidth=2)
    ax.plot(_epochs(val_loss), val_loss, 'r-', label='验证损失', linewidth=2)
    ax.set_title('损失曲线', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Accuracy curves
    ax = axes[0, 1]
    ax.plot(_epochs(train_acc), train_acc, 'b-', label='训练准确率', linewidth=2)
    ax.plot(_epochs(val_acc), val_acc, 'r-', label='验证准确率', linewidth=2)
    ax.set_title('准确率曲线', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Learning-rate schedule on a log scale; hide the panel if no LR was logged
    ax = axes[1, 0]
    if 'lr' in train_history:
        lr = list(train_history['lr'])
        ax.plot(_epochs(lr), lr, 'g-', linewidth=2)
        ax.set_title('学习率变化', fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
    else:
        ax.axis('off')

    # Text summary. np.argmax returns the first index of the maximum, matching
    # list.index(max(...)), but it also works for numpy arrays.
    best_idx = int(np.argmax(val_acc))
    best_val_acc = val_acc[best_idx]
    summary_text = "\n".join([
        "训练摘要:",
        f"最佳验证准确率: {best_val_acc:.2f}%",
        f"最佳epoch: {best_idx + 1}",
        f"最终训练损失: {train_loss[-1]:.4f}",
        f"最终验证损失: {val_loss[-1]:.4f}",
        f"总训练轮数: {len(train_loss)}",
    ]) + "\n"

    ax = axes[1, 1]
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes,
            fontsize=12, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"训练曲线已保存: {save_path}")
    return fig
def plot_class_distribution(data_dir: str, class_names: List[str] = None,
                            save_path: str = None, figsize: Tuple[int, int] = (12, 8)):
    """
    Plot the number of images per class as a bar chart and a pie chart.

    Every sub-directory of ``data_dir`` is treated as one class. Files ending
    in .png/.jpg/.jpeg (case-insensitive) are counted.

    Args:
        data_dir: Dataset directory containing one sub-directory per class.
        class_names: Optional display names. The i-th name labels the i-th
            class directory in sorted order. Directories without a
            corresponding name keep their folder name.
        save_path: Optional path to save the figure (300 dpi).
        figsize: Figure size in inches.

    Returns:
        ``(fig, class_counts)``, where ``class_counts`` maps the display name
        to the image count.
    """
    import os

    image_exts = ('.png', '.jpg', '.jpeg')

    # Only directories are classes. Sort them so the index -> class_names
    # mapping is deterministic: os.listdir() order is arbitrary, and stray
    # files must not shift the index.
    class_folders = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )

    class_counts = {}
    for class_idx, class_folder in enumerate(class_folders):
        class_path = os.path.join(data_dir, class_folder)
        count = sum(1 for f in os.listdir(class_path) if f.lower().endswith(image_exts))
        if class_names and class_idx < len(class_names):
            class_name = class_names[class_idx]
        else:
            class_name = class_folder
        class_counts[class_name] = count

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    classes = list(class_counts.keys())
    counts = list(class_counts.values())
    colors = plt.cm.Set3(np.linspace(0, 1, len(classes)))

    # Bar chart with count labels slightly above each bar
    bars = ax1.bar(classes, counts, color=colors, edgecolor='black', linewidth=1)
    ax1.set_title('类别样本分布', fontsize=14, fontweight='bold')
    ax1.set_xlabel('类别')
    ax1.set_ylabel('样本数量')
    ax1.tick_params(axis='x', rotation=45)
    label_offset = max(counts, default=0) * 0.01
    for bar, count in zip(bars, counts):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                 str(count), ha='center', va='bottom', fontweight='bold')

    # Pie chart; an empty or all-zero distribution cannot be drawn as a pie
    if sum(counts) > 0:
        ax2.pie(counts, labels=classes, autopct='%1.1f%%', colors=colors,
                startangle=90, wedgeprops={'edgecolor': 'black'})
    else:
        ax2.axis('off')
    ax2.set_title('类别比例分布', fontsize=14, fontweight='bold')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"类别分布图已保存: {save_path}")
    return fig, class_counts
def plot_sample_images(data_dir: str, class_names: List[str],
                       samples_per_class: int = 3,
                       save_path: str = None, figsize: Tuple[int, int] = (15, 10),
                       seed: Optional[int] = None):
    """
    Show a grid of randomly chosen sample images, one row per class.

    Args:
        data_dir: Dataset directory with one sub-directory per class.
        class_names: Class names, one per row. A class is looked up in
            ``data_dir/<class_name>``. If that directory does not exist, the
            i-th sub-directory (sorted order) is used instead.
        samples_per_class: Number of images (columns) per class.
        save_path: Optional path to save the figure (300 dpi).
        figsize: Figure size in inches.
        seed: Optional seed for sample selection. ``None`` keeps using the
            global ``random`` module state (the original behaviour).

    Returns:
        The matplotlib ``Figure``.
    """
    import os
    import random

    image_exts = ('.png', '.jpg', '.jpeg')
    sampler = random if seed is None else random.Random(seed)

    # squeeze=False always yields a 2-D (rows, cols) array of Axes, so
    # axes[r, c] works even with a single class and/or a single sample.
    fig, axes = plt.subplots(len(class_names), samples_per_class,
                             figsize=figsize, squeeze=False)

    fallback_dirs = None  # sorted class sub-directories, listed lazily once

    for class_idx, class_name in enumerate(class_names):
        row = axes[class_idx]

        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            # Fall back to the class_idx-th sub-directory (sorted for determinism)
            if fallback_dirs is None:
                fallback_dirs = sorted(
                    d for d in os.listdir(data_dir)
                    if os.path.isdir(os.path.join(data_dir, d))
                )
            if class_idx < len(fallback_dirs):
                class_dir = os.path.join(data_dir, fallback_dirs[class_idx])

        image_files = []
        if os.path.isdir(class_dir):
            image_files = [f for f in os.listdir(class_dir)
                           if f.lower().endswith(image_exts)]
        selected_files = sampler.sample(image_files,
                                        min(samples_per_class, len(image_files)))

        for sample_idx, ax in enumerate(row):
            # Hide ticks and frame but keep the axis "on". ax.axis('off')
            # would also hide the class-name y-label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if sample_idx < len(selected_files):
                img = cv2.imread(os.path.join(class_dir, selected_files[sample_idx]))
                if img is not None:  # cv2.imread returns None for unreadable files
                    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        row[0].set_ylabel(class_name, fontsize=12, fontweight='bold')

    plt.suptitle('各类别样本展示', fontsize=16, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"样本图像已保存: {save_path}")
    return fig
def plot_model_comparison(results: List[Dict], save_path: str = None,
                          figsize: Tuple[int, int] = (15, 10)):
    """
    Compare several models with one bar chart per metric (2x2 grid).

    The metrics shown are accuracy, macro F1, macro precision and macro
    recall. A metric missing from a model's results is plotted as 0.

    Args:
        results: One dict per model, each holding ``'model_name'`` and a
            ``'metrics'`` dict.
        save_path: Optional path to save the figure (300 dpi).
        figsize: Figure size in inches.

    Returns:
        The matplotlib ``Figure``.
    """
    names = [entry['model_name'] for entry in results]
    fig, grid = plt.subplots(2, 2, figsize=figsize)
    palette = plt.cm.Set2(np.linspace(0, 1, len(names)))

    metric_keys = ['accuracy', 'macro_f1', 'macro_precision', 'macro_recall']
    for ax, metric_key in zip(grid.ravel(), metric_keys):
        scores = [entry['metrics'].get(metric_key, 0) for entry in results]
        bar_container = ax.bar(names, scores, color=palette,
                               edgecolor='black', linewidth=1)
        ax.set_title(metric_key.replace("_", " ").title(),
                     fontsize=12, fontweight='bold')
        ax.set_ylabel('Score')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Label each bar with its score, lifted by 1% of the tallest bar
        if scores:
            nudge = max(scores) * 0.01
            for rect, score in zip(bar_container, scores):
                ax.text(rect.get_x() + rect.get_width() / 2,
                        rect.get_height() + nudge,
                        f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"模型比较图已保存: {save_path}")
    return fig
def create_interactive_training_dashboard(train_history: Dict, save_path: str = None):
    """
    Build an interactive plotly dashboard of a training run (2x2 grid).

    Panels: loss curves, accuracy curves, learning rate (only if ``'lr'`` is
    present; log scale) and a summary table.

    Args:
        train_history: Mapping with per-epoch sequences ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` (lists or 1-D numpy
            arrays) and optionally ``'lr'``. Accuracies are formatted as %.
        save_path: Optional HTML file path written with ``fig.write_html``.

    Returns:
        The plotly ``Figure``.

    Raises:
        ImportError: If plotly is not installed.
        ValueError: If ``'train_loss'``, ``'val_loss'`` or ``'val_acc'`` is empty.
    """
    # Import plotly lazily so the rest of the module works without it. Only
    # ImportError is translated, so unrelated errors are not mislabelled as
    # "install plotly".
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except ImportError as exc:
        raise ImportError(
            "使用 create_interactive_training_dashboard 需要安装 plotly,请运行: pip install plotly"
        ) from exc

    train_loss = list(train_history['train_loss'])
    val_loss = list(train_history['val_loss'])
    train_acc = list(train_history['train_acc'])
    val_acc = list(train_history['val_acc'])
    if not (train_loss and val_loss and val_acc):
        raise ValueError(
            "train_history must contain at least one epoch of "
            "'train_loss', 'val_loss' and 'val_acc'"
        )
    epochs = list(range(1, len(train_loss) + 1))

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('损失曲线', '准确率曲线', '学习率变化', '训练摘要'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"type": "table"}]]
    )

    # Loss curves
    fig.add_trace(go.Scatter(x=epochs, y=train_loss, name='训练损失',
                             line=dict(color='blue')), row=1, col=1)
    fig.add_trace(go.Scatter(x=epochs, y=val_loss, name='验证损失',
                             line=dict(color='red')), row=1, col=1)

    # Accuracy curves
    fig.add_trace(go.Scatter(x=epochs, y=train_acc, name='训练准确率',
                             line=dict(color='blue')), row=1, col=2)
    fig.add_trace(go.Scatter(x=epochs, y=val_acc, name='验证准确率',
                             line=dict(color='red')), row=1, col=2)

    # Learning rate on a log axis
    if 'lr' in train_history:
        fig.add_trace(go.Scatter(x=epochs, y=list(train_history['lr']), name='学习率',
                                 line=dict(color='green')), row=2, col=1)
        fig.update_yaxes(type="log", row=2, col=1)

    # Summary table. np.argmax gives the first best epoch, like
    # list.index(max(...)), but also works for numpy arrays.
    best_idx = int(np.argmax(val_acc))
    header = ['指标', '数值']
    rows = [
        ['最佳验证准确率', f'{val_acc[best_idx]:.2f}%'],
        ['最佳Epoch', str(best_idx + 1)],
        ['最终训练损失', f'{train_loss[-1]:.4f}'],
        ['最终验证损失', f'{val_loss[-1]:.4f}'],
        ['总训练轮数', str(len(epochs))],
    ]
    # go.Table expects cell values column by column
    columns = [[row[0] for row in rows], [row[1] for row in rows]]
    fig.add_trace(
        go.Table(
            header=dict(values=header, fill_color='lightblue'),
            cells=dict(values=columns, fill_color='white')
        ),
        row=2, col=2
    )

    fig.update_layout(
        title='训练过程可视化仪表板',
        height=800,
        showlegend=False
    )

    if save_path:
        fig.write_html(save_path)
        print(f"交互式仪表板已保存: {save_path}")
    return fig
def visualize_feature_maps(model: torch.nn.Module, image: torch.Tensor,
                           layer_name: str, save_path: str = None,
                           figsize: Tuple[int, int] = (20, 15),
                           max_features: Optional[int] = None):
    """
    Plot the per-channel output of one layer of ``model`` for a single image.

    A forward hook captures the layer's output during one ``no_grad`` forward
    pass in eval mode. Each channel is min-max normalised and drawn in an
    8-column grid.

    Args:
        model: PyTorch model. Its train/eval mode is restored afterwards.
        image: One input sample *without* a batch dimension; a batch dim of
            size 1 is added. Presumably (C, H, W) for CNNs. TODO confirm for
            other model types. The image is moved to the device of the
            model's first parameter.
        layer_name: Module name as in ``model.named_modules()``. An exact
            match wins; otherwise the first module whose name contains
            ``layer_name`` is used.
        save_path: Optional path to save the figure (300 dpi).
        figsize: Figure size in inches.
        max_features: If set, only the first ``max_features`` channels are
            drawn. Layers with hundreds of channels otherwise produce huge grids.

    Returns:
        The matplotlib ``Figure``, or ``None`` if no layer matched.

    Raises:
        RuntimeError: If the selected layer did not run in the forward pass.
        ValueError: If the layer output is not 3-D (C, H, W) once the batch
            dimension is removed.
    """
    named_modules = dict(model.named_modules())
    target_layer = named_modules.get(layer_name)
    if target_layer is None:
        target_layer = next(
            (module for name, module in named_modules.items() if layer_name in name),
            None,
        )
    if target_layer is None:
        print(f"未找到层: {layer_name}")
        return None

    captured = {}

    def hook_fn(module, inputs, output):
        captured['feature_map'] = output

    # Run on the model's device when it has parameters
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)

    was_training = model.training
    handle = target_layer.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hook and restore the caller's train/eval mode,
        # even if the forward pass raises.
        handle.remove()
        model.train(was_training)

    if 'feature_map' not in captured:
        raise RuntimeError(f"Layer '{layer_name}' was not executed during the forward pass")

    output = captured['feature_map']
    if isinstance(output, (tuple, list)):
        # Some modules return several tensors; visualise the first one
        output = output[0]
    feature_map = output.detach().squeeze(0).float().cpu()  # drop batch dim
    if feature_map.dim() != 3:
        raise ValueError(
            f"Expected a (C, H, W) feature map from '{layer_name}', "
            f"got shape {tuple(feature_map.shape)}"
        )
    if max_features is not None:
        feature_map = feature_map[:max_features]
    maps = feature_map.numpy()
    n_features = maps.shape[0]

    n_cols = 8
    n_rows = max(1, (n_features + n_cols - 1) // n_cols)
    # squeeze=False -> always a 2-D Axes array, flattened for linear indexing
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes):
        if i < n_features:
            channel = maps[i]
            # Min-max normalise to [0, 1]; epsilon guards constant channels
            channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8)
            ax.imshow(channel, cmap='viridis')
            ax.set_title(f'Feature {i+1}')
        ax.axis('off')

    plt.suptitle(f'特征图可视化 - {layer_name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"特征图已保存: {save_path}")
    return fig
def plot_attention_weights(attention_weights: np.ndarray, tokens: List[str] = None,
                           save_path: str = None, figsize: Tuple[int, int] = (12, 10)):
    """
    Draw an attention-weight matrix as a heatmap (e.g. for a Vision Transformer).

    Args:
        attention_weights: 2-D matrix; rows are plotted as queries (y axis)
            and columns as keys (x axis). A ``torch.Tensor`` is also accepted
            and converted to numpy.
        tokens: Optional tick labels used for both axes, intended for square
            matrices. When omitted, rows and columns are labelled
            ``Token 1..N`` independently, so non-square matrices work.
        save_path: Optional path to save the figure (300 dpi).
        figsize: Figure size in inches.

    Returns:
        The matplotlib ``Figure``.
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()
    weights = np.asarray(attention_weights)

    fig = plt.figure(figsize=figsize)

    if tokens is None:
        # Size each axis' labels from its own dimension (queries x keys)
        query_labels = [f'Token {i+1}' for i in range(weights.shape[0])]
        key_labels = [f'Token {i+1}' for i in range(weights.shape[1])]
    else:
        query_labels = key_labels = tokens

    sns.heatmap(weights,
                xticklabels=key_labels, yticklabels=query_labels,
                cmap='Blues', annot=False, square=True,
                cbar_kws={'label': 'Attention Weight'})

    plt.title('注意力权重可视化', fontsize=16, fontweight='bold')
    plt.xlabel('Key Tokens')
    plt.ylabel('Query Tokens')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"注意力权重图已保存: {save_path}")
    return fig
def create_prediction_gallery(images: List[str], predictions: List[Dict],
                              true_labels: List[str] = None,
                              save_path: str = None,
                              images_per_row: int = 4,
                              figsize: Tuple[int, int] = (20, 15)):
    """
    Show images in a grid, each titled with its prediction and confidence.

    Args:
        images: Image file paths, read with OpenCV. Unreadable files show a
            placeholder text instead of crashing.
        predictions: One dict per image with ``'predicted_label'`` and
            ``'confidence'`` (shown with 3 decimals). Images and predictions
            are paired in order; surplus entries of either list are ignored.
        true_labels: Optional ground-truth labels. When one exists for an
            image, it is shown in the title. The title is then red for a wrong
            prediction and green for a correct one.
        save_path: Optional path to save the figure (300 dpi).
        images_per_row: Number of grid columns; must be >= 1.
        figsize: Figure size in inches.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If ``images_per_row`` is less than 1.
    """
    if images_per_row < 1:
        raise ValueError("images_per_row must be >= 1")

    # At least one row so an empty input still gives a (blank) figure
    # instead of plt.subplots(0, ...) raising.
    n_rows = max(1, (len(images) + images_per_row - 1) // images_per_row)
    fig, axes = plt.subplots(n_rows, images_per_row, figsize=figsize, squeeze=False)
    axes = axes.ravel()
    # Start with every cell hidden; cells without an image/prediction stay blank
    for ax in axes:
        ax.axis('off')

    for i, (img_path, pred) in enumerate(zip(images, predictions)):
        ax = axes[i]
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread returns None for missing/unreadable files
            ax.text(0.5, 0.5, '无法读取图像', ha='center', va='center',
                    transform=ax.transAxes)
        else:
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        title = f"预测: {pred['predicted_label']}\n置信度: {pred['confidence']:.3f}"
        if true_labels is not None and i < len(true_labels):
            title = f"真实: {true_labels[i]}\n" + title
            # Red title for a wrong prediction, green for a correct one
            color = 'red' if pred['predicted_label'] != true_labels[i] else 'green'
            ax.set_title(title, color=color, fontweight='bold')
        else:
            ax.set_title(title, fontweight='bold')

    plt.suptitle('预测结果画廊', fontsize=16, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"预测画廊已保存: {save_path}")
    return fig
if __name__ == "__main__":
    # Smoke test of the plotting helpers with a synthetic training history.
    rng = np.random.default_rng(0)  # fixed seed so the demo is reproducible
    epochs = 50
    steps = np.arange(epochs)

    def _noisy(curve, scale, low, high=None):
        # Add Gaussian noise and clip to a valid range (no negative losses,
        # accuracies within 0-100 %).
        return np.clip(curve + scale * rng.standard_normal(epochs), low, high).tolist()

    train_history = {
        'train_loss': _noisy(1.5 * np.exp(-0.1 * steps) + 0.1, 0.05, 0.0),
        'val_loss': _noisy(1.6 * np.exp(-0.08 * steps) + 0.15, 0.08, 0.0),
        'train_acc': _noisy(60 * (1 - np.exp(-0.1 * steps)), 10, 0.0, 100.0),
        'val_acc': _noisy(55 * (1 - np.exp(-0.08 * steps)), 12, 0.0, 100.0),
        'lr': [0.001 * (0.9 ** (i // 10)) for i in range(epochs)],
    }

    # Static training curves
    fig = plot_training_curves(train_history, 'test_training_curves.png')
    plt.close(fig)

    # Interactive dashboard; plotly is optional, so report and continue
    try:
        create_interactive_training_dashboard(train_history, 'test_dashboard.html')
    except ImportError as exc:
        print(exc)

    print("可视化测试完成!")