""" 训练模块 包含训练循环、验证、早停等功能 """ import os import time import copy import yaml import torch import torch.nn as nn import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau from torch.cuda.amp import GradScaler, autocast import numpy as np from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix import matplotlib.pyplot as plt import seaborn as sns from tqdm import tqdm from typing import Dict, List, Tuple, Optional import logging from tensorboardX import SummaryWriter from src.data_loader import create_data_loaders, get_class_weights from src.models import create_model, count_parameters, model_size_mb from utils.metrics import calculate_metrics, plot_confusion_matrix class EarlyStopping: """早停机制""" def __init__(self, patience: int = 7, min_delta: float = 0.0, restore_best_weights: bool = True): self.patience = patience self.min_delta = min_delta self.restore_best_weights = restore_best_weights self.best_loss = None self.counter = 0 self.best_weights = None def __call__(self, val_loss: float, model: nn.Module) -> bool: if self.best_loss is None: self.best_loss = val_loss self.best_weights = copy.deepcopy(model.state_dict()) elif val_loss < self.best_loss - self.min_delta: self.best_loss = val_loss self.counter = 0 self.best_weights = copy.deepcopy(model.state_dict()) else: self.counter += 1 if self.counter >= self.patience: if self.restore_best_weights: model.load_state_dict(self.best_weights) return True return False class DRTrainer: def run_qat(self): """量化感知训练(QAT)流程""" qat_cfg = self.config['training'] if not qat_cfg.get('qat', False): return import copy import torch.quantization as tq qat_epochs = qat_cfg.get('qat_epochs', 10) qat_backend = qat_cfg.get('qat_backend', 'fbgemm') export_path = qat_cfg.get('qat_export_path', 'weights/qat_model.onnx') self.logger.info(f"开始QAT微调: epochs={qat_epochs}, backend={qat_backend}") # 1. 准备量化模型 model_qat = copy.deepcopy(self.model).to(self.device) model_qat.train() model_qat.fuse_model = getattr(model_qat, 'fuse_model', None) if model_qat.fuse_model: model_qat.fuse_model() tq.backend = qat_backend model_qat.qconfig = tq.get_default_qat_qconfig(qat_backend) tq.prepare_qat(model_qat, inplace=True) optimizer = torch.optim.Adam(model_qat.parameters(), lr=1e-4) criterion = nn.CrossEntropyLoss() # 2. QAT训练 for epoch in range(qat_epochs): model_qat.train() running_loss = 0.0 correct = 0 total = 0 for images, labels in self.train_loader: images, labels = images.to(self.device), labels.to(self.device) optimizer.zero_grad() outputs = model_qat(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() avg_loss = running_loss / len(self.train_loader) acc = 100. * correct / total self.logger.info(f"[QAT] Epoch {epoch+1}/{qat_epochs} Loss: {avg_loss:.4f} Acc: {acc:.2f}%") # 3. 转换为量化模型 model_qat.eval() model_int8 = tq.convert(model_qat.cpu().eval(), inplace=False) self.logger.info("QAT模型量化完成,准备导出ONNX...") # 4. 导出ONNX dummy = torch.randn(1, 3, self.config['data']['image_size'], self.config['data']['image_size']) torch.onnx.export(model_int8, dummy, export_path, input_names=['input'], output_names=['output'], opset_version=12) self.logger.info(f"QAT量化模型已导出: {export_path}") """糖尿病视网膜病变检测模型训练器""" def __init__(self, config: dict): self.config = config self.device = torch.device( f"cuda:{config['device']['gpu_id']}" if config['device']['use_gpu'] and torch.cuda.is_available() else "cpu" ) # 创建日志目录 os.makedirs(config['logging']['log_dir'], exist_ok=True) os.makedirs(config['logging']['tensorboard_dir'], exist_ok=True) # 确保权重保存目录存在 os.makedirs(os.path.dirname(config['training']['model_save_path']), exist_ok=True) # 设置日志 self._setup_logging() # 初始化模型 self.model = create_model(config).to(self.device) self.logger.info(f"模型参数数量: {count_parameters(self.model):,}") self.logger.info(f"模型大小: {model_size_mb(self.model):.2f} MB") # 创建数据加载器 self.train_loader, self.val_loader, self.test_loader = create_data_loaders(config) # === 知识蒸馏相关 === self.distill = self.config['training'].get('distill', False) self.teacher_model = None if self.distill: from utils.losses import DistillationLoss teacher_name = self.config['training'].get('distill_teacher', 'efficientnet_b3') student_name = self.config['training'].get('distill_student', self.config['model']['name']) # student模型用config['model'],teacher模型用teacher_name teacher_config = copy.deepcopy(self.config) teacher_config['model']['name'] = teacher_name self.teacher_model = create_model(teacher_config).to(self.device) self.teacher_model.eval() # teacher权重加载(如有) teacher_ckpt = self.config['training'].get('distill_teacher_ckpt', None) if teacher_ckpt and os.path.exists(teacher_ckpt): state = torch.load(teacher_ckpt, map_location=self.device) if 'model_state_dict' in state: self.teacher_model.load_state_dict(state['model_state_dict']) else: self.teacher_model.load_state_dict(state) self.logger.info(f"已加载teacher模型权重: {teacher_ckpt}") else: self.logger.warning("未指定teacher权重,teacher模型将使用随机初始化!") alpha = self.config['training'].get('distill_alpha', 0.7) beta = self.config['training'].get('distill_beta', 0.3) temperature = self.config['training'].get('distill_temperature', 4.0) self.criterion = DistillationLoss(alpha=alpha, beta=beta, temperature=temperature) else: # 创建损失函数(支持类别权重、Focal Loss) label_smoothing = self.config['training'].get('label_smoothing', 0.0) use_focal = self.config['training'].get('use_focal_loss', False) class_weights = None if config['data'].get('use_class_weights', False): class_weights = get_class_weights( config['data']['train_dir'], config['model']['num_classes'] ).to(self.device) # 自动写入 config.yaml try: with open('configs/config.yaml', 'r', encoding='utf-8') as f: cfg = yaml.safe_load(f) cfg['training']['class_weights'] = [float(w) for w in class_weights.cpu().numpy()] with open('configs/config.yaml', 'w', encoding='utf-8') as f: yaml.dump(cfg, f, allow_unicode=True) except Exception as e: self.logger.warning(f"自动写入类别权重到 config.yaml 失败: {e}") if use_focal: from utils.losses import FocalLoss gamma = self.config['training'].get('focal_gamma', 2.0) alpha = self.config['training'].get('focal_alpha', None) if alpha is not None: alpha = torch.tensor(alpha, dtype=torch.float32, device=self.device) elif class_weights is not None: alpha = class_weights self.criterion = FocalLoss(alpha=alpha, gamma=gamma) else: self.criterion = nn.CrossEntropyLoss( weight=class_weights, label_smoothing=label_smoothing if label_smoothing > 0 else 0.0, ) # 创建优化器 self.optimizer = self._create_optimizer() # 创建学习率调度器 self.scheduler = self._create_scheduler() # 混合精度训练 self.use_amp = config['device'].get('mixed_precision', False) if self.use_amp: self.scaler = GradScaler() # 早停 early_stopping_config = config['training'] self.early_stopping = EarlyStopping( patience=early_stopping_config.get('early_stopping_patience', 10) ) # TensorBoard self.writer = SummaryWriter(config['logging']['tensorboard_dir']) # 训练历史 self.train_history = { 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': [] } self.best_val_acc = 0.0 self.start_epoch = 0 def _setup_logging(self): """设置日志""" log_file = os.path.join(self.config['logging']['log_dir'], 'training.log') logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler() ] ) self.logger = logging.getLogger(__name__) def _create_optimizer(self) -> optim.Optimizer: """创建优化器""" opt_config = self.config['optimizer'] lr = self.config['training']['learning_rate'] weight_decay = self.config['training']['weight_decay'] if opt_config['name'].lower() == 'adam': optimizer = optim.Adam( self.model.parameters(), lr=lr, weight_decay=weight_decay, betas=(opt_config.get('beta1', 0.9), opt_config.get('beta2', 0.999)) ) elif opt_config['name'].lower() == 'adamw': optimizer = optim.AdamW( self.model.parameters(), lr=lr, weight_decay=weight_decay, betas=(opt_config.get('beta1', 0.9), opt_config.get('beta2', 0.999)) ) elif opt_config['name'].lower() == 'sgd': optimizer = optim.SGD( self.model.parameters(), lr=lr, weight_decay=weight_decay, momentum=opt_config.get('momentum', 0.9) ) else: raise ValueError(f"不支持的优化器: {opt_config['name']}") return optimizer def _create_scheduler(self): """创建学习率调度器""" scheduler_name = self.config['training'].get('scheduler', 'cosine') if scheduler_name == 'cosine': scheduler = CosineAnnealingLR( self.optimizer, T_max=self.config['training']['epochs'] ) elif scheduler_name == 'step': scheduler = StepLR( self.optimizer, step_size=30, gamma=0.1 ) elif scheduler_name == 'plateau': scheduler = ReduceLROnPlateau( self.optimizer, mode='min', factor=0.5, patience=5, verbose=True ) else: scheduler = None return scheduler def train_epoch(self, epoch: int) -> Tuple[float, float]: """训练一个epoch,支持多任务(分级+二分类)""" self.model.train() running_loss = 0.0 correct = 0 total = 0 correct_bin = 0 total_bin = 0 progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}') for batch_idx, batch in enumerate(progress_bar): # 支持(images, label, is_diabetic) 或 (images, label) if len(batch) == 3: images, labels, is_diabetic = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = is_diabetic.to(self.device).float() else: images, labels = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = None self.optimizer.zero_grad() if self.use_amp: with autocast(): outputs = self.model(images) if isinstance(outputs, dict): loss_grading = self.criterion(outputs['grading'], labels) if is_diabetic is not None: loss_diabetic = nn.BCEWithLogitsLoss()(outputs['diabetic'], is_diabetic) loss = loss_grading + loss_diabetic else: loss = loss_grading else: loss = self.criterion(outputs, labels) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() else: outputs = self.model(images) if isinstance(outputs, dict): loss_grading = self.criterion(outputs['grading'], labels) if is_diabetic is not None: loss_diabetic = nn.BCEWithLogitsLoss()(outputs['diabetic'], is_diabetic) loss = loss_grading + loss_diabetic else: loss = loss_grading else: loss = self.criterion(outputs, labels) loss.backward() self.optimizer.step() # 统计分级准确率 if isinstance(outputs, dict): out_grading = outputs['grading'] _, predicted = out_grading.max(1) else: predicted = outputs.max(1)[1] total += labels.size(0) correct += predicted.eq(labels).sum().item() # 统计二分类准确率 if is_diabetic is not None and isinstance(outputs, dict): out_bin = torch.sigmoid(outputs['diabetic']) pred_bin = (out_bin > 0.5).long() correct_bin += pred_bin.eq(is_diabetic.long()).sum().item() total_bin += is_diabetic.size(0) running_loss += loss.item() # 更新进度条 postfix = {'Loss': f'{loss.item():.4f}', 'Acc': f'{100.*correct/total:.2f}%'} if total_bin > 0: postfix['BinAcc'] = f'{100.*correct_bin/total_bin:.2f}%' progress_bar.set_postfix(postfix) epoch_loss = running_loss / len(self.train_loader) epoch_acc = 100. * correct / total return epoch_loss, epoch_acc def validate(self) -> Tuple[float, float, Dict]: """多任务验证,输出分级和二分类准确率""" self.model.eval() running_loss = 0.0 all_predictions = [] all_labels = [] all_bin_preds = [] all_bin_labels = [] with torch.no_grad(): for batch in tqdm(self.val_loader, desc='Validating'): if len(batch) == 3: images, labels, is_diabetic = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = is_diabetic.to(self.device).float() else: images, labels = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = None if self.use_amp: with autocast(): outputs = self.model(images) if isinstance(outputs, dict): loss_grading = self.criterion(outputs['grading'], labels) if is_diabetic is not None: loss_diabetic = nn.BCEWithLogitsLoss()(outputs['diabetic'], is_diabetic) loss = loss_grading + loss_diabetic else: loss = loss_grading else: loss = self.criterion(outputs, labels) else: outputs = self.model(images) if isinstance(outputs, dict): loss_grading = self.criterion(outputs['grading'], labels) if is_diabetic is not None: loss_diabetic = nn.BCEWithLogitsLoss()(outputs['diabetic'], is_diabetic) loss = loss_grading + loss_diabetic else: loss = loss_grading else: loss = self.criterion(outputs, labels) running_loss += loss.item() # 分级预测 if isinstance(outputs, dict): out_grading = outputs['grading'] _, predicted = out_grading.max(1) else: predicted = outputs.max(1)[1] all_predictions.extend(predicted.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) # 二分类预测 if is_diabetic is not None and isinstance(outputs, dict): out_bin = torch.sigmoid(outputs['diabetic']) pred_bin = (out_bin > 0.5).long() all_bin_preds.extend(pred_bin.cpu().numpy()) all_bin_labels.extend(is_diabetic.cpu().numpy()) val_loss = running_loss / len(self.val_loader) val_acc = 100. * accuracy_score(all_labels, all_predictions) metrics = calculate_metrics(all_labels, all_predictions) # 二分类准确率 if all_bin_labels: bin_acc = 100. * accuracy_score(all_bin_labels, all_bin_preds) metrics['bin_acc'] = bin_acc return val_loss, val_acc, metrics def save_checkpoint(self, epoch: int, is_best: bool = False): """保存检查点""" checkpoint = { 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'best_val_acc': self.best_val_acc, 'train_history': self.train_history, 'config': self.config } if self.scheduler: checkpoint['scheduler_state_dict'] = self.scheduler.state_dict() # 保存最新检查点 checkpoint_path = os.path.join( os.path.dirname(self.config['training']['model_save_path']), 'last_checkpoint.pth' ) torch.save(checkpoint, checkpoint_path) # 保存最佳模型 if is_best: best_path = self.config['training']['model_save_path'] torch.save(checkpoint, best_path) self.logger.info(f"保存最佳模型: {best_path}") def load_checkpoint(self, checkpoint_path: str): """加载检查点""" if not os.path.exists(checkpoint_path): self.logger.info("未找到检查点,从头开始训练") return checkpoint = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.best_val_acc = checkpoint.get('best_val_acc', 0.0) self.start_epoch = checkpoint.get('epoch', 0) + 1 self.train_history = checkpoint.get('train_history', self.train_history) if self.scheduler and 'scheduler_state_dict' in checkpoint: self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.logger.info(f"从epoch {self.start_epoch} 恢复训练") def train(self): """完整的训练流程""" self.logger.info("开始训练...") self.logger.info(f"训练设备: {self.device}") self.logger.info(f"训练集大小: {len(self.train_loader.dataset)}") self.logger.info(f"验证集大小: {len(self.val_loader.dataset)}") # 尝试加载检查点 checkpoint_path = os.path.join( os.path.dirname(self.config['training']['model_save_path']), 'last_checkpoint.pth' ) self.load_checkpoint(checkpoint_path) for epoch in range(self.start_epoch, self.config['training']['epochs']): start_time = time.time() # 训练 train_loss, train_acc = self.train_epoch(epoch) # 验证 val_loss, val_acc, val_metrics = self.validate() # 学习率调度 if self.scheduler: if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(val_loss) else: self.scheduler.step() # 记录历史 current_lr = self.optimizer.param_groups[0]['lr'] self.train_history['train_loss'].append(train_loss) self.train_history['train_acc'].append(train_acc) self.train_history['val_loss'].append(val_loss) self.train_history['val_acc'].append(val_acc) self.train_history['lr'].append(current_lr) # TensorBoard记录 self.writer.add_scalar('Loss/Train', train_loss, epoch) self.writer.add_scalar('Loss/Val', val_loss, epoch) self.writer.add_scalar('Accuracy/Train', train_acc, epoch) self.writer.add_scalar('Accuracy/Val', val_acc, epoch) self.writer.add_scalar('Learning_Rate', current_lr, epoch) # 记录验证指标 for metric_name, metric_value in val_metrics.items(): if isinstance(metric_value, (int, float)): self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch) # 保存最佳模型 is_best = val_acc > self.best_val_acc if is_best: self.best_val_acc = val_acc # 定期保存检查点 if (epoch + 1) % self.config['logging']['save_frequency'] == 0 or is_best: self.save_checkpoint(epoch, is_best) # 计算训练时间 epoch_time = time.time() - start_time # 打印结果 self.logger.info( f"Epoch [{epoch+1}/{self.config['training']['epochs']}] " f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% " f"Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}% " f"Time: {epoch_time:.2f}s LR: {current_lr:.6f}" ) # 早停检查 if self.early_stopping(val_loss, self.model): self.logger.info(f"Early stopping at epoch {epoch+1}") break self.logger.info(f"训练完成!最佳验证准确率: {self.best_val_acc:.2f}%") # 绘制训练曲线 self.plot_training_history() # 在测试集上评估 if self.test_loader: self.evaluate_on_test() # === QAT流程 === self.run_qat() def plot_training_history(self): """绘制训练历史曲线""" fig, axes = plt.subplots(2, 2, figsize=(15, 10)) # 损失曲线 axes[0, 0].plot(self.train_history['train_loss'], label='Train Loss') axes[0, 0].plot(self.train_history['val_loss'], label='Val Loss') axes[0, 0].set_title('Loss Curves') axes[0, 0].set_xlabel('Epoch') axes[0, 0].set_ylabel('Loss') axes[0, 0].legend() axes[0, 0].grid(True) # 准确率曲线 axes[0, 1].plot(self.train_history['train_acc'], label='Train Acc') axes[0, 1].plot(self.train_history['val_acc'], label='Val Acc') axes[0, 1].set_title('Accuracy Curves') axes[0, 1].set_xlabel('Epoch') axes[0, 1].set_ylabel('Accuracy (%)') axes[0, 1].legend() axes[0, 1].grid(True) # 学习率曲线 axes[1, 0].plot(self.train_history['lr']) axes[1, 0].set_title('Learning Rate') axes[1, 0].set_xlabel('Epoch') axes[1, 0].set_ylabel('Learning Rate') axes[1, 0].set_yscale('log') axes[1, 0].grid(True) # 最佳性能标记 best_epoch = np.argmax(self.train_history['val_acc']) axes[1, 1].text(0.1, 0.8, f'Best Val Acc: {self.best_val_acc:.2f}%', transform=axes[1, 1].transAxes, fontsize=12) axes[1, 1].text(0.1, 0.7, f'Best Epoch: {best_epoch + 1}', transform=axes[1, 1].transAxes, fontsize=12) axes[1, 1].text(0.1, 0.6, f'Total Epochs: {len(self.train_history["val_acc"])}', transform=axes[1, 1].transAxes, fontsize=12) axes[1, 1].axis('off') plt.tight_layout() plt.savefig(os.path.join(self.config['logging']['log_dir'], 'training_history.png'), dpi=300, bbox_inches='tight') plt.close() def evaluate_on_test(self): """多任务测试集评估""" self.logger.info("在测试集上评估模型...") # 加载最佳模型 best_model_path = self.config['training']['model_save_path'] if os.path.exists(best_model_path): checkpoint = torch.load(best_model_path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.eval() all_predictions = [] all_labels = [] all_bin_preds = [] all_bin_labels = [] with torch.no_grad(): for batch in tqdm(self.test_loader, desc='Testing'): if len(batch) == 3: images, labels, is_diabetic = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = is_diabetic.to(self.device).float() else: images, labels = batch images = images.to(self.device) labels = labels.to(self.device) is_diabetic = None outputs = self.model(images) # 分级预测 if isinstance(outputs, dict): out_grading = outputs['grading'] _, predicted = out_grading.max(1) else: predicted = outputs.max(1)[1] all_predictions.extend(predicted.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) # 二分类预测 if is_diabetic is not None and isinstance(outputs, dict): out_bin = torch.sigmoid(outputs['diabetic']) pred_bin = (out_bin > 0.5).long() all_bin_preds.extend(pred_bin.cpu().numpy()) all_bin_labels.extend(is_diabetic.cpu().numpy()) # 计算指标 test_metrics = calculate_metrics(all_labels, all_predictions) if all_bin_labels: bin_acc = 100. * accuracy_score(all_bin_labels, all_bin_preds) test_metrics['bin_acc'] = bin_acc # 打印结果 self.logger.info("测试集结果:") for metric_name, metric_value in test_metrics.items(): if isinstance(metric_value, (int, float)): self.logger.info(f"{metric_name}: {metric_value:.4f}") # 绘制混淆矩阵 cm = confusion_matrix(all_labels, all_predictions) plot_confusion_matrix( cm, self.config['data']['class_names'], save_path=os.path.join(self.config['logging']['log_dir'], 'confusion_matrix.png') ) if __name__ == "__main__": # 加载配置 with open("configs/config.yaml", 'r', encoding='utf-8') as f: config = yaml.safe_load(f) # 创建训练器并开始训练 trainer = DRTrainer(config) trainer.train()