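# NOTE(review): page-header residue from the scrape, kept as a comment
# because it is not part of the module: "Spaces: Running / Running".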
"""
Visualization utilities.

Covers training-process visualization, model-interpretability
visualization, and related plotting helpers.
"""
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
from typing import List, Tuple, Dict, Optional, Union | |
import pandas as pd | |
from matplotlib.animation import FuncAnimation | |
import warnings | |
# Suppress every warning raised after this module is imported (process-wide).
warnings.filterwarnings('ignore')
# Configure matplotlib so the Chinese titles and labels used below can render:
# prefer the SimHei font, fall back to DejaVu Sans, and turn off the Unicode
# minus glyph (presumably because CJK fonts often lack it -- TODO confirm).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
def _summarize_history(train_history: Dict) -> Dict:
    """Extract the headline numbers shown in the training-summary panel.

    Args:
        train_history: Mapping with 'train_loss', 'val_loss' and 'val_acc'
            sequences holding one value per epoch.

    Returns:
        Dict with keys 'best_val_acc', 'best_epoch' (1-based),
        'final_train_loss', 'final_val_loss' and 'n_epochs'.

    Raises:
        ValueError: If the history contains no epochs.
    """
    # list() makes tuples and numpy arrays work too; ndarray has no .index().
    val_acc = list(train_history['val_acc'])
    train_loss = train_history['train_loss']
    val_loss = train_history['val_loss']
    if not val_acc or len(train_loss) == 0 or len(val_loss) == 0:
        raise ValueError("train_history is empty; at least one epoch is required")

    best_val_acc = max(val_acc)
    return {
        'best_val_acc': best_val_acc,
        'best_epoch': val_acc.index(best_val_acc) + 1,
        'final_train_loss': train_loss[-1],
        'final_val_loss': val_loss[-1],
        'n_epochs': len(train_loss),
    }


def plot_training_curves(train_history: Dict, save_path: Optional[str] = None,
                         figsize: Tuple[int, int] = (15, 10)):
    """Plot loss, accuracy and learning-rate curves plus a text summary.

    Args:
        train_history: Mapping with per-epoch sequences under 'train_loss',
            'val_loss', 'train_acc' and 'val_acc'; an optional 'lr' sequence
            adds the learning-rate panel.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure holding the 2x2 grid of panels.

    Raises:
        ValueError: If the history contains no epochs.
    """
    # Validate and summarize first so a bad history never leaves a
    # half-built figure behind.
    summary = _summarize_history(train_history)
    epochs = range(1, summary['n_epochs'] + 1)

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Loss curves
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, train_history['train_loss'], 'b-', label='训练损失', linewidth=2)
    loss_ax.plot(epochs, train_history['val_loss'], 'r-', label='验证损失', linewidth=2)
    loss_ax.set_title('损失曲线', fontsize=14, fontweight='bold')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Accuracy curves
    acc_ax = axes[0, 1]
    acc_ax.plot(epochs, train_history['train_acc'], 'b-', label='训练准确率', linewidth=2)
    acc_ax.plot(epochs, train_history['val_acc'], 'r-', label='验证准确率', linewidth=2)
    acc_ax.set_title('准确率曲线', fontsize=14, fontweight='bold')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)

    # Learning-rate curve (optional)
    lr_ax = axes[1, 0]
    if 'lr' in train_history:
        lr_ax.plot(epochs, train_history['lr'], 'g-', linewidth=2)
        lr_ax.set_title('学习率变化', fontsize=14, fontweight='bold')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)
    else:
        # Without learning-rate data, hide the panel instead of showing an
        # empty framed axes.
        lr_ax.axis('off')

    # Text summary panel
    summary_text = "\n".join([
        "训练摘要:",
        f"最佳验证准确率: {summary['best_val_acc']:.2f}%",
        f"最佳epoch: {summary['best_epoch']}",
        f"最终训练损失: {summary['final_train_loss']:.4f}",
        f"最终验证损失: {summary['final_val_loss']:.4f}",
        f"总训练轮数: {summary['n_epochs']}",
    ]) + "\n"
    text_ax = axes[1, 1]
    text_ax.text(0.1, 0.9, summary_text, transform=text_ax.transAxes,
                 fontsize=12, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    text_ax.set_xlim(0, 1)
    text_ax.set_ylim(0, 1)
    text_ax.axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"训练曲线已保存: {save_path}")
    return fig
def _count_images_per_class(data_dir: str,
                            class_names: Optional[List[str]] = None) -> Dict[str, int]:
    """Count image files in each class sub-directory of ``data_dir``.

    Only sub-directories are treated as classes, and they are visited in
    sorted order so that ``class_names[i]`` always maps to the same folder
    regardless of the order ``os.listdir`` happens to return.

    Args:
        data_dir: Directory whose sub-directories each hold one class.
        class_names: Optional display names, matched by position to the
            sorted sub-directories; folders beyond its length keep their
            own name.

    Returns:
        Mapping from display name to number of .png/.jpg/.jpeg files.
    """
    import os

    image_exts = ('.png', '.jpg', '.jpeg')
    class_folders = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )

    class_counts = {}
    for class_idx, class_folder in enumerate(class_folders):
        class_path = os.path.join(data_dir, class_folder)
        n_images = sum(
            1 for fname in os.listdir(class_path)
            if fname.lower().endswith(image_exts)
        )
        if class_names and class_idx < len(class_names):
            display_name = class_names[class_idx]
        else:
            display_name = class_folder
        class_counts[display_name] = n_images
    return class_counts


def plot_class_distribution(data_dir: str, class_names: Optional[List[str]] = None,
                            save_path: Optional[str] = None,
                            figsize: Tuple[int, int] = (12, 8)):
    """Plot the per-class sample counts of a dataset as a bar and a pie chart.

    Args:
        data_dir: Dataset root; each sub-directory is one class.
        class_names: Optional display names, matched by position to the
            sorted class sub-directories.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        Tuple ``(fig, class_counts)`` where ``class_counts`` maps each class
        name to its number of image files.
    """
    class_counts = _count_images_per_class(data_dir, class_names)
    classes = list(class_counts)
    counts = list(class_counts.values())

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    colors = plt.cm.Set3(np.linspace(0, 1, len(classes)))

    # Bar chart
    bars = ax1.bar(classes, counts, color=colors, edgecolor='black', linewidth=1)
    ax1.set_title('类别样本分布', fontsize=14, fontweight='bold')
    ax1.set_xlabel('类别')
    ax1.set_ylabel('样本数量')
    ax1.tick_params(axis='x', rotation=45)

    # Value labels sit slightly above each bar; the offset is loop-invariant.
    label_offset = max(counts, default=0) * 0.01
    for bar, count in zip(bars, counts):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                 str(count), ha='center', va='bottom', fontweight='bold')

    # Pie chart. With a zero total the wedge fractions would be 0/0, so show
    # a placeholder message instead of calling pie().
    if sum(counts) > 0:
        ax2.pie(counts, labels=classes, autopct='%1.1f%%', colors=colors,
                startangle=90, wedgeprops={'edgecolor': 'black'})
    else:
        ax2.text(0.5, 0.5, '无图像样本', ha='center', va='center',
                 transform=ax2.transAxes)
        ax2.axis('off')
    ax2.set_title('类别比例分布', fontsize=14, fontweight='bold')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"类别分布图已保存: {save_path}")
    return fig, class_counts
def _resolve_class_dir(data_dir: str, class_name: str, class_idx: int) -> Optional[str]:
    """Locate the directory that holds the images of one class.

    Tries ``data_dir/class_name`` first; if that is not a directory, falls
    back to the ``class_idx``-th sub-directory of ``data_dir`` in sorted
    order. Returns ``None`` when neither exists.
    """
    import os

    direct = os.path.join(data_dir, class_name)
    if os.path.isdir(direct):
        return direct

    # Sorted so the positional fallback does not depend on listing order.
    subdirs = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    if class_idx < len(subdirs):
        return os.path.join(data_dir, subdirs[class_idx])
    return None


def plot_sample_images(data_dir: str, class_names: List[str],
                       samples_per_class: int = 3,
                       save_path: Optional[str] = None,
                       figsize: Tuple[int, int] = (15, 10)):
    """Show a grid of randomly chosen sample images, one row per class.

    Args:
        data_dir: Dataset root; each class lives in its own sub-directory.
        class_names: Class names; each is looked up as a sub-directory name,
            falling back to the sub-directory at the same sorted position.
        samples_per_class: Number of images (columns) per class.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``class_names`` is empty or ``samples_per_class`` < 1.
    """
    import os
    import random

    n_classes = len(class_names)
    if n_classes == 0:
        raise ValueError("class_names is empty; nothing to display")
    if samples_per_class < 1:
        raise ValueError("samples_per_class must be at least 1")

    # squeeze=False always yields a 2-D axes array, so the same indexing
    # works for a single class and/or a single sample per class.
    fig, axes = plt.subplots(n_classes, samples_per_class,
                             figsize=figsize, squeeze=False)
    image_exts = ('.png', '.jpg', '.jpeg')

    for class_idx, class_name in enumerate(class_names):
        class_dir = _resolve_class_dir(data_dir, class_name, class_idx)
        image_files = []
        if class_dir is not None:
            image_files = [fname for fname in os.listdir(class_dir)
                           if fname.lower().endswith(image_exts)]
        selected_files = random.sample(
            image_files, min(samples_per_class, len(image_files)))

        for sample_idx in range(samples_per_class):
            ax = axes[class_idx, sample_idx]

            img = None
            if sample_idx < len(selected_files):
                # cv2.imread returns None (no exception) for unreadable files.
                img = cv2.imread(os.path.join(class_dir, selected_files[sample_idx]))
            if img is not None:
                ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

            if sample_idx == 0:
                # axis('off') would also hide the y-label, so the first
                # column only drops ticks and frame to keep the class name.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                ax.set_ylabel(class_name, fontsize=12, fontweight='bold')
            else:
                ax.axis('off')

    fig.suptitle('各类别样本展示', fontsize=16, fontweight='bold')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"样本图像已保存: {save_path}")
    return fig
def plot_model_comparison(results: List[Dict], save_path: str = None,
                          figsize: Tuple[int, int] = (15, 10)):
    """Draw one bar chart per metric comparing several models.

    Args:
        results: One dict per model with a 'model_name' entry and a
            'metrics' dict; metrics missing from that dict are plotted as 0.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure with a 2x2 grid (accuracy, macro F1,
        macro precision, macro recall).
    """
    model_names = [entry['model_name'] for entry in results]
    fig, grid = plt.subplots(2, 2, figsize=figsize)

    # One colour per model, shared by all four panels.
    palette = plt.cm.Set2(np.linspace(0, 1, len(model_names)))
    panel_metrics = ('accuracy', 'macro_f1', 'macro_precision', 'macro_recall')

    for ax, metric in zip(grid.ravel(), panel_metrics):
        scores = [entry['metrics'].get(metric, 0) for entry in results]
        bars = ax.bar(model_names, scores, color=palette,
                      edgecolor='black', linewidth=1)
        ax.set_title(metric.replace("_", " ").title(),
                     fontsize=12, fontweight='bold')
        ax.set_ylabel('Score')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Print each score just above its bar.
        for bar, score in zip(bars, scores):
            label_x = bar.get_x() + bar.get_width() / 2
            label_y = bar.get_height() + max(scores) * 0.01
            ax.text(label_x, label_y, f'{score:.3f}',
                    ha='center', va='bottom', fontweight='bold')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"模型比较图已保存: {save_path}")
    return fig
def _dashboard_summary_rows(train_history: Dict) -> List[List[str]]:
    """Build the (label, value) rows of the dashboard's summary table.

    Raises:
        ValueError: If the history contains no epochs.
    """
    # list() makes tuples and numpy arrays work too; ndarray has no .index().
    val_acc = list(train_history['val_acc'])
    train_loss = train_history['train_loss']
    val_loss = train_history['val_loss']
    if not val_acc or len(train_loss) == 0 or len(val_loss) == 0:
        raise ValueError("train_history is empty; at least one epoch is required")

    best_val_acc = max(val_acc)
    return [
        ['最佳验证准确率', f'{best_val_acc:.2f}%'],
        ['最佳Epoch', str(val_acc.index(best_val_acc) + 1)],
        ['最终训练损失', f'{train_loss[-1]:.4f}'],
        ['最终验证损失', f'{val_loss[-1]:.4f}'],
        ['总训练轮数', str(len(train_loss))],
    ]


def create_interactive_training_dashboard(train_history: Dict,
                                          save_path: Optional[str] = None):
    """Build an interactive plotly dashboard of a training run.

    Args:
        train_history: Mapping with per-epoch sequences under 'train_loss',
            'val_loss', 'train_acc' and 'val_acc'; an optional 'lr' sequence
            adds the learning-rate panel.
        save_path: If given, the dashboard is also written there as HTML.

    Returns:
        The plotly figure (2x2 grid: loss, accuracy, learning rate, summary).

    Raises:
        ImportError: If plotly is not installed.
        ValueError: If the history contains no epochs.
    """
    # plotly is an optional dependency, so import it lazily. Only a failed
    # import is translated into the install hint; any other error raised
    # while importing propagates unchanged.
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except ImportError as exc:
        raise ImportError(
            "使用 create_interactive_training_dashboard 需要安装 plotly,请运行: pip install plotly"
        ) from exc

    # Validate and summarize before any figure work.
    summary_rows = _dashboard_summary_rows(train_history)
    epochs = list(range(1, len(train_history['train_loss']) + 1))

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('损失曲线', '准确率曲线', '学习率变化', '训练摘要'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"type": "table"}]]
    )

    # (history key, trace name, colour, row, col) for the four line traces.
    curve_specs = [
        ('train_loss', '训练损失', 'blue', 1, 1),
        ('val_loss', '验证损失', 'red', 1, 1),
        ('train_acc', '训练准确率', 'blue', 1, 2),
        ('val_acc', '验证准确率', 'red', 1, 2),
    ]
    for key, trace_name, color, row, col in curve_specs:
        fig.add_trace(
            go.Scatter(x=epochs, y=train_history[key],
                       name=trace_name, line=dict(color=color)),
            row=row, col=col
        )

    # Learning rate (optional), on a log axis.
    if 'lr' in train_history:
        fig.add_trace(
            go.Scatter(x=epochs, y=train_history['lr'],
                       name='学习率', line=dict(color='green')),
            row=2, col=1
        )
        fig.update_yaxes(type="log", row=2, col=1)

    # Summary table: go.Table takes column-wise values, hence the zip(*rows).
    fig.add_trace(
        go.Table(
            header=dict(values=['指标', '数值'], fill_color='lightblue'),
            cells=dict(values=list(zip(*summary_rows)), fill_color='white')
        ),
        row=2, col=2
    )

    fig.update_layout(
        title='训练过程可视化仪表板',
        height=800,
        showlegend=False
    )

    if save_path:
        fig.write_html(save_path)
        print(f"交互式仪表板已保存: {save_path}")
    return fig
def visualize_feature_maps(model: torch.nn.Module, image: torch.Tensor,
                           layer_name: str, save_path: Optional[str] = None,
                           figsize: Tuple[int, int] = (20, 15)):
    """Plot the activation maps one layer produces for a single image.

    Args:
        model: PyTorch model to run.
        image: Input tensor without a batch dimension; a batch axis is added
            before the forward pass.
        layer_name: Name of the layer to inspect, as reported by
            ``model.named_modules()``. An exact match is preferred; otherwise
            the first module whose name contains ``layer_name`` is used.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure, or ``None`` if the layer was not found or was
        never executed during the forward pass.
    """
    # Prefer the exact name so e.g. "conv1" cannot be shadowed by an earlier
    # module whose name merely contains that substring.
    named = dict(model.named_modules())
    target_layer = named.get(layer_name)
    if target_layer is None:
        target_layer = next(
            (module for name, module in named.items() if layer_name in name),
            None,
        )
    if target_layer is None:
        print(f"未找到层: {layer_name}")
        return None

    captured = {}

    def hook_fn(module, inputs, output):
        captured['feature_map'] = output

    # Remember each module's train/eval flag so the caller's model is handed
    # back in exactly the mode it arrived in.
    previous_modes = [(module, module.training) for module in model.modules()]
    handle = target_layer.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hook, even if the forward pass raised.
        handle.remove()
        for module, was_training in previous_modes:
            module.training = was_training

    if 'feature_map' not in captured:
        print(f"层未在前向传播中被调用: {layer_name}")
        return None

    feature_map = captured['feature_map'].squeeze(0)  # drop the batch axis
    n_features = feature_map.shape[0]

    # Grid layout: 8 maps per row.
    n_cols = 8
    n_rows = (n_features + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for i in range(n_features):
        feature = feature_map[i].cpu().numpy()
        # Min-max normalize to [0, 1]; the epsilon guards constant maps.
        feature = (feature - feature.min()) / (feature.max() - feature.min() + 1e-8)
        axes[i].imshow(feature, cmap='viridis')
        axes[i].set_title(f'Feature {i+1}')
        axes[i].axis('off')

    # Hide the unused cells of the last row.
    for ax in axes[n_features:]:
        ax.axis('off')

    fig.suptitle(f'特征图可视化 - {layer_name}', fontsize=16, fontweight='bold')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"特征图已保存: {save_path}")
    return fig
def plot_attention_weights(attention_weights: np.ndarray, tokens: List[str] = None,
                           save_path: str = None, figsize: Tuple[int, int] = (12, 10)):
    """Render an attention-weight matrix as a heatmap.

    Args:
        attention_weights: Attention matrix; rows are labelled as query
            tokens and columns as key tokens.
        tokens: Tick labels used on both axes. Defaults to "Token 1..N" with
            N taken from the first dimension of ``attention_weights``.
        save_path: If given, the figure is also written to this path.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure containing the heatmap.
    """
    fig = plt.figure(figsize=figsize)

    labels = tokens
    if labels is None:
        n_tokens = attention_weights.shape[0]
        labels = [f'Token {idx+1}' for idx in range(n_tokens)]

    # seaborn draws into the current axes and returns them.
    ax = sns.heatmap(
        attention_weights,
        xticklabels=labels,
        yticklabels=labels,
        cmap='Blues',
        annot=False,
        square=True,
        cbar_kws={'label': 'Attention Weight'},
    )
    ax.set_title('注意力权重可视化', fontsize=16, fontweight='bold')
    ax.set_xlabel('Key Tokens')
    ax.set_ylabel('Query Tokens')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"注意力权重图已保存: {save_path}")
    return fig
def create_prediction_gallery(images: List[str], predictions: List[Dict],
                              true_labels: Optional[List[str]] = None,
                              save_path: Optional[str] = None,
                              images_per_row: int = 4,
                              figsize: Tuple[int, int] = (20, 15)):
    """Lay out images in a grid, each titled with its prediction.

    Args:
        images: Paths of the images to show.
        predictions: One dict per image with 'predicted_label' and
            'confidence' entries.
        true_labels: Optional ground-truth labels; when present the title is
            green for a correct prediction and red for a wrong one.
        save_path: If given, the figure is also written to this path.
        images_per_row: Number of grid columns.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("images is empty; nothing to display")

    n_rows = (n_images + images_per_row - 1) // images_per_row
    # squeeze=False keeps a 2-D axes array for any grid shape.
    fig, axes = plt.subplots(n_rows, images_per_row, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    n_shown = 0
    for idx, (ax, img_path, pred) in enumerate(zip(axes, images, predictions)):
        n_shown = idx + 1

        # cv2.imread returns None (no exception) for unreadable files; show
        # a placeholder so one bad path does not abort the whole gallery.
        img = cv2.imread(img_path)
        if img is None:
            ax.text(0.5, 0.5, '无法读取图像', ha='center', va='center',
                    transform=ax.transAxes)
        else:
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        title = f"预测: {pred['predicted_label']}\n"
        title += f"置信度: {pred['confidence']:.3f}"

        title_kwargs = {'fontweight': 'bold'}
        if true_labels and idx < len(true_labels):
            title = f"真实: {true_labels[idx]}\n" + title
            is_correct = pred['predicted_label'] == true_labels[idx]
            title_kwargs['color'] = 'green' if is_correct else 'red'
        ax.set_title(title, **title_kwargs)
        ax.axis('off')

    # Hide every cell that did not receive an image. This also covers the
    # case where predictions is shorter than images.
    for ax in axes[n_shown:]:
        ax.axis('off')

    fig.suptitle('预测结果画廊', fontsize=16, fontweight='bold')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"预测画廊已保存: {save_path}")
    return fig
if __name__ == "__main__":
    # Smoke test for the plotting helpers, driven by a synthetic history:
    # exponentially decaying losses and saturating accuracies with Gaussian
    # noise, plus a learning rate that is multiplied by 0.9 every 10 epochs.
    n_epochs = 50
    steps = range(n_epochs)

    def _noise(scale):
        # One Gaussian sample scaled by `scale`.
        return scale * np.random.randn()

    train_history = {
        'train_loss': [1.5 * np.exp(-0.1 * i) + 0.1 + _noise(0.05) for i in steps],
        'val_loss': [1.6 * np.exp(-0.08 * i) + 0.15 + _noise(0.08) for i in steps],
        'train_acc': [60 * (1 - np.exp(-0.1 * i)) + _noise(10) for i in steps],
        'val_acc': [55 * (1 - np.exp(-0.08 * i)) + _noise(12) for i in steps],
        'lr': [0.001 * (0.9 ** (i // 10)) for i in steps],
    }

    # Static training curves, written to a PNG in the working directory.
    fig = plot_training_curves(train_history, 'test_training_curves.png')
    plt.close()

    # Interactive dashboard, written to an HTML file.
    interactive_fig = create_interactive_training_dashboard(
        train_history, 'test_dashboard.html')

    print("可视化测试完成!")