"""
Visualization utilities.

Provides plots for the training process (loss / accuracy / learning-rate
curves, an interactive dashboard), dataset inspection (class distribution,
sample images) and model interpretation (feature maps, attention weights,
prediction galleries).
"""

import os
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from matplotlib.animation import FuncAnimation
from PIL import Image

# NOTE: this silences *all* warnings process-wide as soon as the module is
# imported. Kept because importers may rely on it.
warnings.filterwarnings('ignore')

# Font setup so that the Chinese labels used in the figures render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# File suffixes (compared case-insensitively) that count as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _save_figure(fig, save_path: Optional[str], label: str) -> None:
    """Save *fig* to *save_path* (if given) and print a confirmation line."""
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"{label}已保存: {save_path}")


def _list_subdirs(root: str) -> List[str]:
    """Return the names of the sub-directories of *root*, sorted by name."""
    return sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )


def _list_images(folder: str) -> List[str]:
    """Return the image file names found directly inside *folder*, sorted."""
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _read_rgb_image(path: str):
    """Read *path* with OpenCV as an RGB array; return None if unreadable."""
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        return None
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _show_image_or_placeholder(ax, path: str) -> bool:
    """Draw the image at *path* on *ax*, or a short notice if it is unreadable.

    Returns:
        True when the image was drawn, False when the placeholder was used.
    """
    img = _read_rgb_image(path)
    if img is None:
        ax.text(0.5, 0.5, '无法读取图像', ha='center', va='center',
                transform=ax.transAxes)
        return False
    ax.imshow(img)
    return True


def _hide_axis_decorations(ax) -> None:
    """Hide ticks and spines of *ax* while keeping axis labels visible."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _summarize_history(train_history: Dict) -> Dict:
    """Compute the summary numbers shown by the training visualizations.

    Args:
        train_history: Mapping with the sequences 'train_loss', 'val_loss'
            and 'val_acc'.

    Returns:
        Dict with 'best_val_acc', 'best_epoch' (1-based, first occurrence of
        the maximum), 'final_train_loss', 'final_val_loss' and 'n_epochs'
        (length of 'train_loss').

    Raises:
        ValueError: If any of the three sequences is empty.
    """
    # list() lets tuples and numpy arrays work too (.index is list-only).
    train_loss = list(train_history['train_loss'])
    val_loss = list(train_history['val_loss'])
    val_acc = list(train_history['val_acc'])
    if not train_loss or not val_loss or not val_acc:
        raise ValueError("train_history 为空, 无法生成训练摘要")

    best_val_acc = max(val_acc)
    return {
        'best_val_acc': best_val_acc,
        'best_epoch': val_acc.index(best_val_acc) + 1,
        'final_train_loss': train_loss[-1],
        'final_val_loss': val_loss[-1],
        'n_epochs': len(train_loss),
    }


def plot_training_curves(train_history: Dict,
                         save_path: Optional[str] = None,
                         figsize: Tuple[int, int] = (15, 10)) -> plt.Figure:
    """
    Plot loss, accuracy and learning-rate curves plus a text summary.

    Args:
        train_history: Mapping with the per-epoch sequences 'train_loss',
            'val_loss', 'train_acc', 'val_acc' and, optionally, 'lr'.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If the loss / validation-accuracy history is empty.
    """
    # Validate before creating the figure so a failure does not leak one.
    summary = _summarize_history(train_history)
    epochs = range(1, summary['n_epochs'] + 1)

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Loss curves
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, train_history['train_loss'], 'b-', label='训练损失', linewidth=2)
    loss_ax.plot(epochs, train_history['val_loss'], 'r-', label='验证损失', linewidth=2)
    loss_ax.set_title('损失曲线', fontsize=14, fontweight='bold')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Accuracy curves
    acc_ax = axes[0, 1]
    acc_ax.plot(epochs, train_history['train_acc'], 'b-', label='训练准确率', linewidth=2)
    acc_ax.plot(epochs, train_history['val_acc'], 'r-', label='验证准确率', linewidth=2)
    acc_ax.set_title('准确率曲线', fontsize=14, fontweight='bold')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)

    # Learning-rate curve (optional)
    lr_ax = axes[1, 0]
    if 'lr' in train_history:
        lr_ax.plot(epochs, train_history['lr'], 'g-', linewidth=2)
        lr_ax.set_title('学习率变化', fontsize=14, fontweight='bold')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)
    else:
        # Without learning-rate data, do not leave an empty framed panel.
        lr_ax.axis('off')

    # Text summary. Built line by line so source indentation never leaks
    # into the rendered text.
    summary_text = "\n".join([
        "训练摘要:",
        "",
        f"最佳验证准确率: {summary['best_val_acc']:.2f}%",
        f"最佳epoch: {summary['best_epoch']}",
        f"最终训练损失: {summary['final_train_loss']:.4f}",
        f"最终验证损失: {summary['final_val_loss']:.4f}",
        f"总训练轮数: {summary['n_epochs']}",
    ])
    text_ax = axes[1, 1]
    text_ax.text(0.1, 0.9, summary_text, transform=text_ax.transAxes,
                 fontsize=12, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    text_ax.set_xlim(0, 1)
    text_ax.set_ylim(0, 1)
    text_ax.axis('off')

    fig.tight_layout()
    _save_figure(fig, save_path, "训练曲线")
    return fig


def plot_class_distribution(data_dir: str,
                            class_names: Optional[List[str]] = None,
                            save_path: Optional[str] = None,
                            figsize: Tuple[int, int] = (12, 8)
                            ) -> Tuple[plt.Figure, Dict[str, int]]:
    """
    Plot the number of images per class as a bar chart and a pie chart.

    Each sub-directory of *data_dir* is treated as one class; files ending in
    .png / .jpg / .jpeg (case-insensitive) are counted.

    Args:
        data_dir: Dataset root directory.
        class_names: Optional display names. The i-th name is applied to the
            i-th sub-directory in sorted order; remaining directories keep
            their folder name.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        Tuple ``(figure, class_counts)`` where ``class_counts`` maps the
        display name of each class to its image count.
    """
    class_counts: Dict[str, int] = {}

    # Only directories are enumerated (and in sorted order) so that stray
    # files in data_dir cannot shift the index used to look up class_names.
    for class_idx, class_folder in enumerate(_list_subdirs(data_dir)):
        count = len(_list_images(os.path.join(data_dir, class_folder)))
        if class_names and class_idx < len(class_names):
            class_name = class_names[class_idx]
        else:
            class_name = class_folder
        class_counts[class_name] = count

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Bar chart
    classes = list(class_counts)
    counts = list(class_counts.values())
    colors = plt.cm.Set3(np.linspace(0, 1, len(classes)))

    bars = ax1.bar(classes, counts, color=colors, edgecolor='black', linewidth=1)
    ax1.set_title('类别样本分布', fontsize=14, fontweight='bold')
    ax1.set_xlabel('类别')
    ax1.set_ylabel('样本数量')
    ax1.tick_params(axis='x', rotation=45)

    # Value labels just above each bar (offset is loop-invariant).
    label_offset = max(counts, default=0) * 0.01
    for bar, count in zip(bars, counts):
        ax1.text(bar.get_x() + bar.get_width() / 2,
                 bar.get_height() + label_offset,
                 str(count), ha='center', va='bottom', fontweight='bold')

    # Pie chart; proportions are undefined when there are no images at all.
    if sum(counts) > 0:
        ax2.pie(counts, labels=classes, autopct='%1.1f%%', colors=colors,
                startangle=90, wedgeprops={'edgecolor': 'black'})
    else:
        ax2.axis('off')
    ax2.set_title('类别比例分布', fontsize=14, fontweight='bold')

    fig.tight_layout()
    _save_figure(fig, save_path, "类别分布图")
    return fig, class_counts


def _draw_class_samples(row_axes, class_dir: str, samples_per_class: int) -> None:
    """Draw up to *samples_per_class* random images from *class_dir* on *row_axes*."""
    image_files = _list_images(class_dir)
    chosen = random.sample(image_files, min(samples_per_class, len(image_files)))
    for ax, img_file in zip(row_axes, chosen):
        _show_image_or_placeholder(ax, os.path.join(class_dir, img_file))


def plot_sample_images(data_dir: str,
                       class_names: List[str],
                       samples_per_class: int = 3,
                       save_path: Optional[str] = None,
                       figsize: Tuple[int, int] = (15, 10)) -> plt.Figure:
    """
    Show randomly chosen sample images for every class, one class per row.

    The directory of a class is ``data_dir/<class_name>``; if that does not
    exist, the sub-directory at the same index (sorted order) is used.

    Args:
        data_dir: Dataset root directory.
        class_names: Class names, one grid row each.
        samples_per_class: Number of columns / images per class.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If *class_names* is empty or *samples_per_class* < 1.
    """
    n_classes = len(class_names)
    if n_classes == 0:
        raise ValueError("class_names 不能为空")
    if samples_per_class < 1:
        raise ValueError("samples_per_class 必须大于等于 1")

    # squeeze=False always yields a 2-D array, so indexing is uniform for
    # single-row and single-column grids as well.
    fig, axes = plt.subplots(n_classes, samples_per_class, figsize=figsize,
                             squeeze=False)

    fallback_dirs = None  # listed lazily, only if a named directory is missing
    for class_idx, class_name in enumerate(class_names):
        row_axes = axes[class_idx]

        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            # Fall back to the sub-directory at the same position.
            if fallback_dirs is None:
                fallback_dirs = _list_subdirs(data_dir)
            if class_idx < len(fallback_dirs):
                class_dir = os.path.join(data_dir, fallback_dirs[class_idx])

        if os.path.isdir(class_dir):
            _draw_class_samples(row_axes, class_dir, samples_per_class)

        # axis('off') would also hide the y-label, so only ticks and spines
        # are hidden; this keeps the class name visible on the first column.
        for ax in row_axes:
            _hide_axis_decorations(ax)
        row_axes[0].set_ylabel(class_name, fontsize=12, fontweight='bold')

    fig.suptitle('各类别样本展示', fontsize=16, fontweight='bold')
    fig.tight_layout()
    _save_figure(fig, save_path, "样本图像")
    return fig


def plot_model_comparison(results: List[Dict],
                          save_path: Optional[str] = None,
                          figsize: Tuple[int, int] = (15, 10)) -> plt.Figure:
    """
    Compare models on accuracy, macro F1, macro precision and macro recall.

    Args:
        results: One dict per model with the keys 'model_name' and 'metrics'
            (a dict of metric name -> value). Missing metrics are plotted as 0.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created matplotlib figure (2x2 grid, one panel per metric).
    """
    metrics_to_plot = ['accuracy', 'macro_f1', 'macro_precision', 'macro_recall']
    model_names = [r['model_name'] for r in results]
    # The same colour per model in every panel; computed once.
    colors = plt.cm.Set2(np.linspace(0, 1, len(model_names)))

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    for ax, metric in zip(axes.ravel(), metrics_to_plot):
        values = [r['metrics'].get(metric, 0) for r in results]

        bars = ax.bar(model_names, values, color=colors, edgecolor='black',
                      linewidth=1)
        ax.set_title(metric.replace("_", " ").title(), fontsize=12,
                     fontweight='bold')
        ax.set_ylabel('Score')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Value labels just above each bar.
        label_offset = max(values, default=0) * 0.01
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + label_offset,
                    f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

    fig.tight_layout()
    _save_figure(fig, save_path, "模型比较图")
    return fig


def create_interactive_training_dashboard(train_history: Dict,
                                          save_path: Optional[str] = None):
    """
    Build an interactive (plotly) dashboard of the training process.

    Args:
        train_history: Mapping with the per-epoch sequences 'train_loss',
            'val_loss', 'train_acc', 'val_acc' and, optionally, 'lr'.
        save_path: If truthy, the dashboard is written to this HTML file.

    Returns:
        The plotly figure.

    Raises:
        ImportError: If plotly cannot be imported.
        ValueError: If the loss / validation-accuracy history is empty.
    """
    # Lazy import: plotly is optional and only needed by this function.
    # Any import-time failure is reported as ImportError, chained to its cause.
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except Exception as exc:
        raise ImportError(
            "使用 create_interactive_training_dashboard 需要安装 plotly,请运行: pip install plotly"
        ) from exc

    summary = _summarize_history(train_history)
    epochs = list(range(1, summary['n_epochs'] + 1))

    # 2x2 layout; the bottom-right cell hosts a table instead of an xy plot.
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('损失曲线', '准确率曲线', '学习率变化', '训练摘要'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"type": "table"}]]
    )

    # Loss curves
    fig.add_trace(
        go.Scatter(x=epochs, y=train_history['train_loss'], name='训练损失',
                   line=dict(color='blue')),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=train_history['val_loss'], name='验证损失',
                   line=dict(color='red')),
        row=1, col=1
    )

    # Accuracy curves
    fig.add_trace(
        go.Scatter(x=epochs, y=train_history['train_acc'], name='训练准确率',
                   line=dict(color='blue')),
        row=1, col=2
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=train_history['val_acc'], name='验证准确率',
                   line=dict(color='red')),
        row=1, col=2
    )

    # Learning rate (optional), on a log axis
    if 'lr' in train_history:
        fig.add_trace(
            go.Scatter(x=epochs, y=train_history['lr'], name='学习率',
                       line=dict(color='green')),
            row=2, col=1
        )
        fig.update_yaxes(type="log", row=2, col=1)

    # Summary table: first row is the header, the rest are (name, value) rows.
    summary_data = [
        ['指标', '数值'],
        ['最佳验证准确率', f"{summary['best_val_acc']:.2f}%"],
        ['最佳Epoch', str(summary['best_epoch'])],
        ['最终训练损失', f"{summary['final_train_loss']:.4f}"],
        ['最终验证损失', f"{summary['final_val_loss']:.4f}"],
        ['总训练轮数', str(summary['n_epochs'])]
    ]

    fig.add_trace(
        go.Table(
            header=dict(values=summary_data[0], fill_color='lightblue'),
            # plotly expects column-wise cell values, hence the transpose.
            cells=dict(values=list(zip(*summary_data[1:])), fill_color='white')
        ),
        row=2, col=2
    )

    fig.update_layout(
        title='训练过程可视化仪表板',
        height=800,
        showlegend=False
    )

    if save_path:
        fig.write_html(save_path)
        print(f"交互式仪表板已保存: {save_path}")

    return fig


def visualize_feature_maps(model: torch.nn.Module,
                           image: torch.Tensor,
                           layer_name: str,
                           save_path: Optional[str] = None,
                           figsize: Tuple[int, int] = (20, 15)
                           ) -> Optional[plt.Figure]:
    """
    Visualize the output channels of one layer for a single input image.

    Args:
        model: PyTorch model. Run in eval mode; its previous train/eval
            mode is restored afterwards.
        image: Input tensor without batch dimension (one is added here).
        layer_name: Name of the layer as reported by ``named_modules()``.
            An exact match is preferred; otherwise the first module whose
            name contains *layer_name* is used.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created figure, or None if no matching layer exists or the layer
        was not executed during the forward pass.

    Raises:
        ValueError: If the captured output (batch dimension removed) is not a
            non-empty 3-D tensor, i.e. cannot be shown as a set of 2-D maps.
    """
    captured = {}

    def hook_fn(module, inputs, output):
        # Some layers return a tuple/list; the first element is used.
        if isinstance(output, (tuple, list)):
            output = output[0]
        captured['feature_map'] = output.detach()

    # Locate the target layer: exact name first, substring match as fallback.
    named = dict(model.named_modules())
    target_layer = named.get(layer_name)
    if target_layer is None:
        target_layer = next(
            (module for name, module in named.items() if layer_name in name),
            None,
        )

    if target_layer is None:
        print(f"未找到层: {layer_name}")
        return None

    was_training = model.training
    handle = target_layer.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hook and restore the caller's train/eval mode,
        # even when the forward pass raises.
        handle.remove()
        model.train(was_training)

    if 'feature_map' not in captured:
        print(f"层 {layer_name} 在前向传播中未被调用")
        return None

    feature_map = captured['feature_map'].squeeze(0)  # drop the batch dimension
    if feature_map.dim() != 3 or feature_map.shape[0] == 0:
        raise ValueError(
            f"层 {layer_name} 的输出形状 {tuple(feature_map.shape)} 无法作为特征图显示"
        )
    n_features = feature_map.shape[0]

    # Grid with a fixed number of columns.
    n_cols = 8
    n_rows = (n_features + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    # Transfer to host memory once rather than once per channel.
    channels = feature_map.float().cpu().numpy()
    for idx, (ax, channel) in enumerate(zip(axes, channels)):
        # Min-max normalize to [0, 1]; the epsilon guards constant maps.
        low, high = channel.min(), channel.max()
        ax.imshow((channel - low) / (high - low + 1e-8), cmap='viridis')
        ax.set_title(f'Feature {idx+1}')

    # Hide axes everywhere, including the unused trailing grid cells.
    for ax in axes:
        ax.axis('off')

    fig.suptitle(f'特征图可视化 - {layer_name}', fontsize=16, fontweight='bold')
    fig.tight_layout()
    _save_figure(fig, save_path, "特征图")
    return fig


def plot_attention_weights(attention_weights: np.ndarray,
                           tokens: Optional[List[str]] = None,
                           save_path: Optional[str] = None,
                           figsize: Tuple[int, int] = (12, 10)) -> plt.Figure:
    """
    Draw an attention-weight matrix as a heatmap (e.g. for a Vision Transformer).

    Args:
        attention_weights: 2-D matrix; rows are labelled as query tokens and
            columns as key tokens.
        tokens: Tick labels used for both axes. If None, 'Token i' labels are
            generated separately for rows and columns.
        save_path: If truthy, the figure is written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If *attention_weights* is not 2-D.
    """
    attention_weights = np.asarray(attention_weights)
    if attention_weights.ndim != 2:
        raise ValueError("attention_weights 必须是二维矩阵")

    if tokens is None:
        n_rows, n_cols = attention_weights.shape
        # Separate label lists so non-square matrices are labelled correctly.
        row_labels = [f'Token {i+1}' for i in range(n_rows)]
        col_labels = [f'Token {i+1}' for i in range(n_cols)]
    else:
        row_labels = col_labels = tokens

    # Explicit figure/axes rather than relying on pyplot's current figure.
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(attention_weights,
                xticklabels=col_labels,
                yticklabels=row_labels,
                cmap='Blues',
                annot=False,
                square=True,
                cbar_kws={'label': 'Attention Weight'},
                ax=ax)

    ax.set_title('注意力权重可视化', fontsize=16, fontweight='bold')
    ax.set_xlabel('Key Tokens')
    ax.set_ylabel('Query Tokens')

    fig.tight_layout()
    _save_figure(fig, save_path, "注意力权重图")
    return fig


def create_prediction_gallery(images: List[str],
                              predictions: List[Dict],
                              true_labels: Optional[List[str]] = None,
                              save_path: Optional[str] = None,
                              images_per_row: int = 4,
                              figsize: Tuple[int, int] = (20, 15)) -> plt.Figure:
    """
    Show images in a grid, each titled with its prediction and confidence.

    When a true label is available for an image, it is added to the title and
    the title is coloured green (correct) or red (wrong).

    Args:
        images: Image file paths.
        predictions: One dict per image with the keys 'predicted_label' and
            'confidence'. Images without a prediction are not drawn.
        true_labels: Optional ground-truth labels, aligned with *images*.
        save_path: If truthy, the figure is written to this path.
        images_per_row: Number of grid columns.
        figsize: Figure size passed to matplotlib.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If *images* is empty or *images_per_row* < 1.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("images 不能为空")
    if images_per_row < 1:
        raise ValueError("images_per_row 必须大于等于 1")

    n_rows = (n_images + images_per_row - 1) // images_per_row

    fig, axes = plt.subplots(n_rows, images_per_row, figsize=figsize,
                             squeeze=False)
    axes = axes.ravel()

    for idx, (ax, img_path, pred) in enumerate(zip(axes, images, predictions)):
        # Unreadable files get a placeholder tile instead of aborting the plot.
        _show_image_or_placeholder(ax, img_path)

        title = f"预测: {pred['predicted_label']}\n"
        title += f"置信度: {pred['confidence']:.3f}"
        title_style = {'fontweight': 'bold'}

        # `is not None` (not truthiness) so numpy arrays of labels work too.
        if true_labels is not None and idx < len(true_labels):
            title = f"真实: {true_labels[idx]}\n" + title
            # Red title for a wrong prediction, green for a correct one.
            is_wrong = pred['predicted_label'] != true_labels[idx]
            title_style['color'] = 'red' if is_wrong else 'green'

        ax.set_title(title, **title_style)

    # Hide axes on every tile, including unused ones.
    for ax in axes:
        ax.axis('off')

    fig.suptitle('预测结果画廊', fontsize=16, fontweight='bold')
    fig.tight_layout()
    _save_figure(fig, save_path, "预测画廊")
    return fig


def _run_demo() -> None:
    """Smoke-test the training visualizations with simulated history data."""
    epochs = 50
    train_history = {
        'train_loss': [1.5 * np.exp(-0.1 * i) + 0.1 + 0.05 * np.random.randn() for i in range(epochs)],
        'val_loss': [1.6 * np.exp(-0.08 * i) + 0.15 + 0.08 * np.random.randn() for i in range(epochs)],
        'train_acc': [60 * (1 - np.exp(-0.1 * i)) + 10 * np.random.randn() for i in range(epochs)],
        'val_acc': [55 * (1 - np.exp(-0.08 * i)) + 12 * np.random.randn() for i in range(epochs)],
        'lr': [0.001 * (0.9 ** (i // 10)) for i in range(epochs)]
    }

    # Static training curves
    fig = plot_training_curves(train_history, 'test_training_curves.png')
    plt.close(fig)

    # Interactive dashboard; plotly is optional, so its absence only skips
    # this step instead of aborting the demo.
    try:
        create_interactive_training_dashboard(train_history, 'test_dashboard.html')
    except ImportError as exc:
        print(exc)

    print("可视化测试完成!")


if __name__ == "__main__":
    _run_demo()