Spaces:
Running
Running
""" | |
Training module.
Provides the training loop, validation, early stopping, and related utilities.
""" | |
import os | |
import time | |
import copy | |
import yaml | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau | |
from torch.cuda.amp import GradScaler, autocast | |
import numpy as np | |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from tqdm import tqdm | |
from typing import Dict, List, Tuple, Optional | |
import logging | |
from tensorboardX import SummaryWriter | |
from src.data_loader import create_data_loaders, get_class_weights | |
from src.models import create_model, count_parameters, model_size_mb | |
from utils.metrics import calculate_metrics, plot_confusion_matrix | |
class EarlyStopping:
    """Signal when training should stop because validation loss has stalled.

    The instance is called once per epoch with the current validation loss.
    It remembers the lowest loss seen so far together with a deep copy of the
    model's ``state_dict`` taken at that moment.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            the instance reports that training should stop.
        min_delta: Margin by which the loss must drop below the best loss to
            count as an improvement.
        restore_best_weights: When True, the stored best weights are loaded
            back into the model on the call that triggers the stop.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None     # lowest validation loss observed so far
        self.counter = 0          # consecutive calls without improvement
        self.best_weights = None  # state_dict snapshot taken at best_loss

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """Record ``val_loss`` and return True when training should stop.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Model whose weights are snapshotted on improvement and,
                if configured, restored when stopping.

        Returns:
            True once ``patience`` non-improving calls have accumulated,
            otherwise False.
        """
        first_call = self.best_loss is None
        improved = (not first_call) and val_loss < self.best_loss - self.min_delta

        if first_call or improved:
            self.best_loss = val_loss
            self.best_weights = copy.deepcopy(model.state_dict())
            if improved:
                self.counter = 0
            return False

        # No improvement this round.
        self.counter += 1
        if self.counter < self.patience:
            return False

        if self.restore_best_weights:
            model.load_state_dict(self.best_weights)
        return True
class DRTrainer:
    """Trainer for the diabetic retinopathy detection model.

    Builds the model, data loaders, loss, optimizer, LR scheduler, early
    stopping and TensorBoard writer from a nested config dict, then drives
    training, validation, checkpointing, test-set evaluation and an optional
    quantization-aware training (QAT) pass.

    Batches may be ``(images, labels)`` or ``(images, labels, is_diabetic)``.
    The model may return either a tensor of grading logits or a dict with
    ``'grading'`` and ``'diabetic'`` entries (multi-task mode).
    """

    def __init__(self, config: dict):
        """Create every training component described by ``config``.

        Args:
            config: Nested configuration dict. The sections read by this
                class are ``device``, ``logging``, ``training``, ``data``,
                ``model`` and ``optimizer``.
        """
        self.config = config
        self.device = torch.device(
            f"cuda:{config['device']['gpu_id']}"
            if config['device']['use_gpu'] and torch.cuda.is_available()
            else "cpu"
        )

        # Output directories. dirname() is '' for a bare file name, and
        # os.makedirs('') raises, so fall back to the current directory.
        os.makedirs(config['logging']['log_dir'], exist_ok=True)
        os.makedirs(config['logging']['tensorboard_dir'], exist_ok=True)
        os.makedirs(
            os.path.dirname(config['training']['model_save_path']) or '.',
            exist_ok=True,
        )

        self._setup_logging()

        self.model = create_model(config).to(self.device)
        self.logger.info(f"模型参数数量: {count_parameters(self.model):,}")
        self.logger.info(f"模型大小: {model_size_mb(self.model):.2f} MB")

        self.train_loader, self.val_loader, self.test_loader = create_data_loaders(config)

        # Loss: knowledge distillation or plain supervised training.
        self.distill = self.config['training'].get('distill', False)
        self.teacher_model = None
        if self.distill:
            self.criterion = self._build_distillation_criterion()
        else:
            self.criterion = self._build_supervised_criterion()

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Mixed precision relies on CUDA autocast/GradScaler, so it is only
        # enabled when the trainer actually runs on a CUDA device.
        self.use_amp = (
            bool(config['device'].get('mixed_precision', False))
            and self.device.type == 'cuda'
        )
        self.scaler = GradScaler() if self.use_amp else None

        self.early_stopping = EarlyStopping(
            patience=config['training'].get('early_stopping_patience', 10)
        )

        self.writer = SummaryWriter(config['logging']['tensorboard_dir'])

        # Per-epoch history, also stored inside checkpoints.
        self.train_history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }
        self.best_val_acc = 0.0
        self.start_epoch = 0

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    def _build_distillation_criterion(self):
        """Create the teacher model and return the distillation loss."""
        from utils.losses import DistillationLoss

        training_cfg = self.config['training']
        teacher_name = training_cfg.get('distill_teacher', 'efficientnet_b3')

        # The teacher shares the config except for the architecture name.
        teacher_config = copy.deepcopy(self.config)
        teacher_config['model']['name'] = teacher_name
        self.teacher_model = create_model(teacher_config).to(self.device)
        self.teacher_model.eval()

        teacher_ckpt = training_cfg.get('distill_teacher_ckpt', None)
        if teacher_ckpt and os.path.exists(teacher_ckpt):
            # NOTE: torch.load unpickles data; only load trusted checkpoints.
            state = torch.load(teacher_ckpt, map_location=self.device)
            if 'model_state_dict' in state:
                self.teacher_model.load_state_dict(state['model_state_dict'])
            else:
                self.teacher_model.load_state_dict(state)
            self.logger.info(f"已加载teacher模型权重: {teacher_ckpt}")
        else:
            self.logger.warning("未指定teacher权重,teacher模型将使用随机初始化!")

        # NOTE(review): the training/validation loops call
        # self.criterion(outputs, labels) and never run self.teacher_model.
        # DistillationLoss is defined outside this file - confirm it accepts
        # that call, or wire the teacher logits in.
        return DistillationLoss(
            alpha=training_cfg.get('distill_alpha', 0.7),
            beta=training_cfg.get('distill_beta', 0.3),
            temperature=training_cfg.get('distill_temperature', 4.0),
        )

    def _build_supervised_criterion(self):
        """Return FocalLoss or CrossEntropyLoss, optionally class-weighted."""
        training_cfg = self.config['training']
        label_smoothing = training_cfg.get('label_smoothing', 0.0)
        use_focal = training_cfg.get('use_focal_loss', False)

        class_weights = None
        if self.config['data'].get('use_class_weights', False):
            class_weights = get_class_weights(
                self.config['data']['train_dir'],
                self.config['model']['num_classes']
            ).to(self.device)
            self._persist_class_weights(class_weights)

        if use_focal:
            from utils.losses import FocalLoss
            gamma = training_cfg.get('focal_gamma', 2.0)
            alpha = training_cfg.get('focal_alpha', None)
            if alpha is not None:
                alpha = torch.tensor(alpha, dtype=torch.float32, device=self.device)
            elif class_weights is not None:
                # No explicit alpha configured: reuse the class weights.
                alpha = class_weights
            return FocalLoss(alpha=alpha, gamma=gamma)

        return nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=label_smoothing if label_smoothing > 0 else 0.0,
        )

    def _persist_class_weights(self, class_weights) -> None:
        """Best-effort write of the class weights into configs/config.yaml.

        The file is re-serialised with ``yaml.dump``, so comments in it are
        not preserved. Any failure is logged as a warning and ignored.
        """
        try:
            with open('configs/config.yaml', 'r', encoding='utf-8') as f:
                cfg = yaml.safe_load(f)
            cfg['training']['class_weights'] = [float(w) for w in class_weights.cpu().numpy()]
            with open('configs/config.yaml', 'w', encoding='utf-8') as f:
                yaml.dump(cfg, f, allow_unicode=True)
        except Exception as e:
            self.logger.warning(f"自动写入类别权重到 config.yaml 失败: {e}")

    def _setup_logging(self):
        """Configure root logging to ``training.log`` and to the console."""
        log_file = os.path.join(self.config['logging']['log_dir'], 'training.log')
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def _create_optimizer(self) -> optim.Optimizer:
        """Build the optimizer named in ``config['optimizer']['name']``.

        Supported names (case-insensitive): ``adam``, ``adamw``, ``sgd``.

        Raises:
            ValueError: If the configured name is not supported.
        """
        opt_config = self.config['optimizer']
        lr = self.config['training']['learning_rate']
        weight_decay = self.config['training']['weight_decay']
        name = opt_config['name'].lower()

        if name in ('adam', 'adamw'):
            optimizer_cls = optim.Adam if name == 'adam' else optim.AdamW
            return optimizer_cls(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=(opt_config.get('beta1', 0.9), opt_config.get('beta2', 0.999))
            )
        if name == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                momentum=opt_config.get('momentum', 0.9)
            )
        raise ValueError(f"不支持的优化器: {opt_config['name']}")

    def _create_scheduler(self):
        """Build the LR scheduler named in ``config['training']['scheduler']``.

        Returns:
            A ``CosineAnnealingLR`` (``'cosine'``, the default), ``StepLR``
            (``'step'``) or ``ReduceLROnPlateau`` (``'plateau'``) instance,
            or None for any other name.
        """
        scheduler_name = self.config['training'].get('scheduler', 'cosine')
        if scheduler_name == 'cosine':
            return CosineAnnealingLR(
                self.optimizer,
                T_max=self.config['training']['epochs']
            )
        if scheduler_name == 'step':
            return StepLR(self.optimizer, step_size=30, gamma=0.1)
        if scheduler_name == 'plateau':
            # The deprecated `verbose` flag is not passed; the learning rate
            # is logged every epoch by train().
            return ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=0.5,
                patience=5
            )
        return None

    # ------------------------------------------------------------------
    # Batch / loss / prediction helpers shared by train, val, test and QAT
    # ------------------------------------------------------------------
    def _unpack_batch(self, batch):
        """Move a batch to the device.

        Returns:
            ``(images, labels, is_diabetic)`` where ``is_diabetic`` is a
            float tensor for 3-element batches and None otherwise.
        """
        if len(batch) == 3:
            images, labels, is_diabetic = batch
            is_diabetic = is_diabetic.to(self.device).float()
        else:
            images, labels = batch
            is_diabetic = None
        return images.to(self.device), labels.to(self.device), is_diabetic

    def _compute_loss(self, outputs, labels, is_diabetic):
        """Return the grading loss, plus the binary loss in multi-task mode."""
        if not isinstance(outputs, dict):
            return self.criterion(outputs, labels)

        loss = self.criterion(outputs['grading'], labels)
        if is_diabetic is not None:
            # Flatten both sides so (B, 1) logits match (B,) targets;
            # BCEWithLogitsLoss requires identical shapes.
            loss = loss + nn.BCEWithLogitsLoss()(
                outputs['diabetic'].reshape(-1), is_diabetic.reshape(-1)
            )
        return loss

    def _forward_loss(self, images, labels, is_diabetic):
        """Run the model and compute the loss, under autocast when AMP is on."""
        if self.use_amp:
            with autocast():
                outputs = self.model(images)
                loss = self._compute_loss(outputs, labels, is_diabetic)
        else:
            outputs = self.model(images)
            loss = self._compute_loss(outputs, labels, is_diabetic)
        return outputs, loss

    @staticmethod
    def _grading_predictions(outputs):
        """Return the argmax class index per sample for the grading head."""
        logits = outputs['grading'] if isinstance(outputs, dict) else outputs
        return logits.max(1)[1]

    @staticmethod
    def _binary_predictions(outputs):
        """Return flattened 0/1 predictions of the binary head (threshold 0.5)."""
        return (torch.sigmoid(outputs['diabetic']) > 0.5).long().reshape(-1)

    def _last_checkpoint_path(self) -> str:
        """Path of the rolling checkpoint, stored next to the best model."""
        return os.path.join(
            os.path.dirname(self.config['training']['model_save_path']),
            'last_checkpoint.pth'
        )

    # ------------------------------------------------------------------
    # Training / validation
    # ------------------------------------------------------------------
    def train_epoch(self, epoch: int) -> Tuple[float, float]:
        """Train for one epoch (single- or multi-task).

        Args:
            epoch: Zero-based epoch index, used for the progress-bar label.

        Returns:
            ``(mean batch loss, grading accuracy in percent)``.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        correct_bin = 0
        total_bin = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}')
        for batch in progress_bar:
            images, labels, is_diabetic = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            outputs, loss = self._forward_loss(images, labels, is_diabetic)
            if self.use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            # Grading accuracy.
            predicted = self._grading_predictions(outputs)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Binary accuracy (multi-task batches only).
            if is_diabetic is not None and isinstance(outputs, dict):
                pred_bin = self._binary_predictions(outputs)
                correct_bin += pred_bin.eq(is_diabetic.long().reshape(-1)).sum().item()
                total_bin += is_diabetic.size(0)

            batch_loss = loss.item()
            running_loss += batch_loss

            postfix = {
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{100.*correct/max(total, 1):.2f}%'
            }
            if total_bin > 0:
                postfix['BinAcc'] = f'{100.*correct_bin/total_bin:.2f}%'
            progress_bar.set_postfix(postfix)

        # Guard against an empty loader instead of dividing by zero.
        num_batches = len(self.train_loader)
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100. * correct / total if total else 0.0
        return epoch_loss, epoch_acc

    def validate(self) -> Tuple[float, float, Dict]:
        """Evaluate on the validation loader.

        Returns:
            ``(mean batch loss, grading accuracy in percent, metrics)`` where
            ``metrics`` is the result of ``calculate_metrics`` with an extra
            ``'bin_acc'`` entry when binary labels were present.
        """
        self.model.eval()
        running_loss = 0.0
        all_predictions = []
        all_labels = []
        all_bin_preds = []
        all_bin_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validating'):
                images, labels, is_diabetic = self._unpack_batch(batch)
                outputs, loss = self._forward_loss(images, labels, is_diabetic)
                running_loss += loss.item()

                predicted = self._grading_predictions(outputs)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                if is_diabetic is not None and isinstance(outputs, dict):
                    all_bin_preds.extend(self._binary_predictions(outputs).cpu().numpy())
                    all_bin_labels.extend(is_diabetic.reshape(-1).cpu().numpy())

        val_loss = running_loss / max(len(self.val_loader), 1)
        # Plain floats keep numpy scalars out of the pickled checkpoints.
        val_acc = float(100. * accuracy_score(all_labels, all_predictions))
        metrics = calculate_metrics(all_labels, all_predictions)
        if all_bin_labels:
            metrics['bin_acc'] = float(100. * accuracy_score(all_bin_labels, all_bin_preds))
        return val_loss, val_acc, metrics

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------
    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Write ``last_checkpoint.pth`` and, if ``is_best``, the best model.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            is_best: Also save to ``config['training']['model_save_path']``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_acc': self.best_val_acc,
            'train_history': self.train_history,
            'config': self.config
        }
        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        torch.save(checkpoint, self._last_checkpoint_path())

        if is_best:
            best_path = self.config['training']['model_save_path']
            torch.save(checkpoint, best_path)
            self.logger.info(f"保存最佳模型: {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer, scheduler and history from a checkpoint.

        Does nothing except log a message when the file does not exist.

        Args:
            checkpoint_path: Path to a checkpoint written by
                ``save_checkpoint``.
        """
        if not os.path.exists(checkpoint_path):
            self.logger.info("未找到检查点,从头开始训练")
            return

        # NOTE: torch.load unpickles data; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.best_val_acc = checkpoint.get('best_val_acc', 0.0)
        self.start_epoch = checkpoint.get('epoch', 0) + 1
        self.train_history = checkpoint.get('train_history', self.train_history)
        if self.scheduler and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.logger.info(f"从epoch {self.start_epoch} 恢复训练")

    # ------------------------------------------------------------------
    # Full training run
    # ------------------------------------------------------------------
    def _run_epoch(self, epoch: int) -> bool:
        """Train and validate one epoch, then log and checkpoint.

        Returns:
            True when early stopping requests that training stop.
        """
        start_time = time.time()
        train_loss, train_acc = self.train_epoch(epoch)
        val_loss, val_acc, val_metrics = self.validate()

        # Read the LR before stepping the scheduler so the recorded value is
        # the rate this epoch was actually trained with.
        current_lr = self.optimizer.param_groups[0]['lr']
        if self.scheduler:
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_loss)
            else:
                self.scheduler.step()

        self.train_history['train_loss'].append(train_loss)
        self.train_history['train_acc'].append(train_acc)
        self.train_history['val_loss'].append(val_loss)
        self.train_history['val_acc'].append(val_acc)
        self.train_history['lr'].append(current_lr)

        self.writer.add_scalar('Loss/Train', train_loss, epoch)
        self.writer.add_scalar('Loss/Val', val_loss, epoch)
        self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
        self.writer.add_scalar('Accuracy/Val', val_acc, epoch)
        self.writer.add_scalar('Learning_Rate', current_lr, epoch)
        for metric_name, metric_value in val_metrics.items():
            if isinstance(metric_value, (int, float)):
                self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch)

        is_best = val_acc > self.best_val_acc
        if is_best:
            self.best_val_acc = val_acc
        if (epoch + 1) % self.config['logging']['save_frequency'] == 0 or is_best:
            self.save_checkpoint(epoch, is_best)

        epoch_time = time.time() - start_time
        self.logger.info(
            f"Epoch [{epoch+1}/{self.config['training']['epochs']}] "
            f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% "
            f"Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}% "
            f"Time: {epoch_time:.2f}s LR: {current_lr:.6f}"
        )
        return self.early_stopping(val_loss, self.model)

    def train(self):
        """Run the complete pipeline.

        Resumes from ``last_checkpoint.pth`` if present, trains until the
        configured epoch count or early stopping, plots the history,
        evaluates on the test set when one exists, and finally runs QAT if
        it is enabled in the config.
        """
        self.logger.info("开始训练...")
        self.logger.info(f"训练设备: {self.device}")
        self.logger.info(f"训练集大小: {len(self.train_loader.dataset)}")
        self.logger.info(f"验证集大小: {len(self.val_loader.dataset)}")

        self.load_checkpoint(self._last_checkpoint_path())

        try:
            for epoch in range(self.start_epoch, self.config['training']['epochs']):
                if self._run_epoch(epoch):
                    self.logger.info(f"Early stopping at epoch {epoch+1}")
                    break
        finally:
            # Flush and release the TensorBoard event file even on errors.
            self.writer.close()

        self.logger.info(f"训练完成!最佳验证准确率: {self.best_val_acc:.2f}%")

        self.plot_training_history()

        if self.test_loader:
            self.evaluate_on_test()

        self.run_qat()

    def plot_training_history(self):
        """Save loss, accuracy and LR curves to ``training_history.png``."""
        history = self.train_history
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Loss curves.
        axes[0, 0].plot(history['train_loss'], label='Train Loss')
        axes[0, 0].plot(history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss Curves')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Accuracy curves.
        axes[0, 1].plot(history['train_acc'], label='Train Acc')
        axes[0, 1].plot(history['val_acc'], label='Val Acc')
        axes[0, 1].set_title('Accuracy Curves')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # Learning-rate curve.
        axes[1, 0].plot(history['lr'])
        axes[1, 0].set_title('Learning Rate')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True)

        # Summary panel. np.argmax raises on an empty sequence, which
        # happens when no epoch ran (e.g. resuming a finished run).
        val_acc_history = history['val_acc']
        if val_acc_history:
            best_epoch_label = str(int(np.argmax(val_acc_history)) + 1)
        else:
            best_epoch_label = 'N/A'
        axes[1, 1].text(0.1, 0.8, f'Best Val Acc: {self.best_val_acc:.2f}%',
                        transform=axes[1, 1].transAxes, fontsize=12)
        axes[1, 1].text(0.1, 0.7, f'Best Epoch: {best_epoch_label}',
                        transform=axes[1, 1].transAxes, fontsize=12)
        axes[1, 1].text(0.1, 0.6, f'Total Epochs: {len(val_acc_history)}',
                        transform=axes[1, 1].transAxes, fontsize=12)
        axes[1, 1].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.config['logging']['log_dir'], 'training_history.png'),
                    dpi=300, bbox_inches='tight')
        plt.close()

    def evaluate_on_test(self):
        """Evaluate the best saved model on the test loader.

        Loads the checkpoint at ``config['training']['model_save_path']``
        when it exists, logs scalar metrics and saves a confusion-matrix
        plot to the log directory.
        """
        self.logger.info("在测试集上评估模型...")

        best_model_path = self.config['training']['model_save_path']
        if os.path.exists(best_model_path):
            # NOTE: torch.load unpickles data; only load trusted checkpoints.
            checkpoint = torch.load(best_model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()
        all_predictions = []
        all_labels = []
        all_bin_preds = []
        all_bin_labels = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc='Testing'):
                images, labels, is_diabetic = self._unpack_batch(batch)
                outputs = self.model(images)

                predicted = self._grading_predictions(outputs)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                if is_diabetic is not None and isinstance(outputs, dict):
                    all_bin_preds.extend(self._binary_predictions(outputs).cpu().numpy())
                    all_bin_labels.extend(is_diabetic.reshape(-1).cpu().numpy())

        test_metrics = calculate_metrics(all_labels, all_predictions)
        if all_bin_labels:
            test_metrics['bin_acc'] = float(100. * accuracy_score(all_bin_labels, all_bin_preds))

        self.logger.info("测试集结果:")
        for metric_name, metric_value in test_metrics.items():
            if isinstance(metric_value, (int, float)):
                self.logger.info(f"{metric_name}: {metric_value:.4f}")

        cm = confusion_matrix(all_labels, all_predictions)
        plot_confusion_matrix(
            cm,
            self.config['data']['class_names'],
            save_path=os.path.join(self.config['logging']['log_dir'], 'confusion_matrix.png')
        )

    # ------------------------------------------------------------------
    # Quantization-aware training
    # ------------------------------------------------------------------
    def run_qat(self):
        """Fine-tune a copy of the model with QAT and export it to ONNX.

        Does nothing unless ``config['training']['qat']`` is truthy. Only
        the grading output is trained when the model returns a dict.
        """
        qat_cfg = self.config['training']
        if not qat_cfg.get('qat', False):
            return

        import torch.quantization as tq

        qat_epochs = qat_cfg.get('qat_epochs', 10)
        qat_backend = qat_cfg.get('qat_backend', 'fbgemm')
        export_path = qat_cfg.get('qat_export_path', 'weights/qat_model.onnx')
        self.logger.info(f"开始QAT微调: epochs={qat_epochs}, backend={qat_backend}")

        # 1. Prepare a QAT copy; the trained model itself is left untouched.
        model_qat = copy.deepcopy(self.model).to(self.device)
        model_qat.train()
        fuse_model = getattr(model_qat, 'fuse_model', None)
        if callable(fuse_model):
            fuse_model()

        # Select the quantized engine. Assigning an unsupported engine
        # raises, so only set it when this build provides it.
        if qat_backend in torch.backends.quantized.supported_engines:
            torch.backends.quantized.engine = qat_backend
        else:
            self.logger.warning(
                f"Quantized engine '{qat_backend}' is not supported on this build; "
                f"keeping '{torch.backends.quantized.engine}'"
            )
        model_qat.qconfig = tq.get_default_qat_qconfig(qat_backend)
        tq.prepare_qat(model_qat, inplace=True)

        optimizer = torch.optim.Adam(model_qat.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        # 2. QAT fine-tuning.
        for epoch in range(qat_epochs):
            model_qat.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for batch in self.train_loader:
                # Accept the same 2- or 3-element batches as train_epoch.
                images, labels, _ = self._unpack_batch(batch)
                optimizer.zero_grad()
                outputs = model_qat(images)
                logits = outputs['grading'] if isinstance(outputs, dict) else outputs
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = logits.max(1)[1]
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            avg_loss = running_loss / max(len(self.train_loader), 1)
            acc = 100. * correct / total if total else 0.0
            self.logger.info(f"[QAT] Epoch {epoch+1}/{qat_epochs} Loss: {avg_loss:.4f} Acc: {acc:.2f}%")

        # 3. Convert to an int8 model on the CPU.
        model_qat.eval()
        model_int8 = tq.convert(model_qat.cpu().eval(), inplace=False)
        self.logger.info("QAT模型量化完成,准备导出ONNX...")

        # 4. Export to ONNX; make sure the target directory exists first.
        os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
        image_size = self.config['data']['image_size']
        dummy = torch.randn(1, 3, image_size, image_size)
        torch.onnx.export(model_int8, dummy, export_path, input_names=['input'],
                          output_names=['output'], opset_version=12)
        self.logger.info(f"QAT量化模型已导出: {export_path}")
if __name__ == "__main__":
    # Script entry point: read the YAML configuration from its fixed
    # location relative to the working directory.
    with open("configs/config.yaml", 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    # Build the trainer from that configuration and run the full pipeline.
    trainer = DRTrainer(config)
    trainer.train()