# aicomp_demo/src/train.py
# ceasonen
# 我的视网膜检测网站 (My retina detection website)
# 04103fb
"""
Training module.

Contains the training loop, validation, early stopping and related utilities.
"""
import os
import time
import copy
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
import logging
from tensorboardX import SummaryWriter
from src.data_loader import create_data_loaders, get_class_weights
from src.models import create_model, count_parameters, model_size_mb
from utils.metrics import calculate_metrics, plot_confusion_matrix
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    The instance is called once per epoch with the current validation loss.
    It remembers the lowest loss seen so far together with a deep copy of the
    model's ``state_dict`` at that point.

    Args:
        patience: Number of consecutive non-improving calls after which the
            instance signals that training should stop.
        min_delta: Minimum decrease of the loss (relative to the best value)
            that counts as an improvement.
        restore_best_weights: When True, the best recorded weights are loaded
            back into the model at the moment stopping is signalled.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None      # lowest finite validation loss seen so far
        self.counter = 0           # consecutive calls without improvement
        self.best_weights = None   # deep copy of state_dict at best_loss

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """Record *val_loss* and report whether training should stop.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Model whose weights are snapshotted on improvement and
                optionally restored when stopping.

        Returns:
            True when ``patience`` consecutive calls brought no improvement,
            otherwise False.
        """
        # A NaN/inf loss must never become the reference value: every later
        # comparison against NaN is False, which would make early stopping
        # fire after `patience` epochs even while the loss keeps improving.
        is_finite = bool(np.isfinite(val_loss))
        improved = is_finite and (
            self.best_loss is None
            or val_loss < self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            self.best_weights = copy.deepcopy(model.state_dict())
            return False

        self.counter += 1
        if self.counter < self.patience:
            return False

        # best_weights is still None when no finite loss was ever observed.
        if self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)
        return True
class DRTrainer:
    """Trainer for the diabetic retinopathy (DR) detection model.

    Builds the model, data loaders, loss, optimizer, LR scheduler and logging
    from a configuration dict, then runs the train/validate loop with
    checkpointing, early stopping, TensorBoard logging, a final test-set
    evaluation and an optional quantization-aware-training (QAT) pass.

    Batches may be ``(images, labels)`` or ``(images, labels, is_diabetic)``.
    The model may return a plain logits tensor or a dict with the keys
    ``'grading'`` and ``'diabetic'``.
    """

    def __init__(self, config: dict):
        """Create all training components described by *config*.

        Args:
            config: Configuration dict. The sections read here are
                ``device``, ``logging``, ``training``, ``data``, ``model``
                and (via ``_create_optimizer``) ``optimizer``.
        """
        self.config = config
        self.device = torch.device(
            f"cuda:{config['device']['gpu_id']}"
            if config['device']['use_gpu'] and torch.cuda.is_available()
            else "cpu"
        )

        # Create the log directories.
        os.makedirs(config['logging']['log_dir'], exist_ok=True)
        os.makedirs(config['logging']['tensorboard_dir'], exist_ok=True)
        # Make sure the weight directory exists. A bare filename has an empty
        # directory part and os.makedirs('') would raise.
        save_dir = os.path.dirname(config['training']['model_save_path'])
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        self._setup_logging()

        # Model.
        self.model = create_model(config).to(self.device)
        self.logger.info(f"模型参数数量: {count_parameters(self.model):,}")
        self.logger.info(f"模型大小: {model_size_mb(self.model):.2f} MB")

        # Data loaders.
        self.train_loader, self.val_loader, self.test_loader = create_data_loaders(config)

        train_cfg = self.config['training']

        # === Knowledge distillation ===
        self.distill = train_cfg.get('distill', False)
        self.teacher_model = None
        if self.distill:
            from utils.losses import DistillationLoss
            teacher_name = train_cfg.get('distill_teacher', 'efficientnet_b3')
            # The student is built from config['model']; the teacher uses the
            # same config with only the model name swapped.
            teacher_config = copy.deepcopy(self.config)
            teacher_config['model']['name'] = teacher_name
            self.teacher_model = create_model(teacher_config).to(self.device)
            self.teacher_model.eval()
            # Load teacher weights when a checkpoint is configured and present.
            teacher_ckpt = train_cfg.get('distill_teacher_ckpt', None)
            if teacher_ckpt and os.path.exists(teacher_ckpt):
                state = torch.load(teacher_ckpt, map_location=self.device)
                if 'model_state_dict' in state:
                    self.teacher_model.load_state_dict(state['model_state_dict'])
                else:
                    self.teacher_model.load_state_dict(state)
                self.logger.info(f"已加载teacher模型权重: {teacher_ckpt}")
            else:
                self.logger.warning("未指定teacher权重,teacher模型将使用随机初始化!")
            alpha = train_cfg.get('distill_alpha', 0.7)
            beta = train_cfg.get('distill_beta', 0.3)
            temperature = train_cfg.get('distill_temperature', 4.0)
            # NOTE(review): train_epoch/validate call
            # self.criterion(outputs, labels) and never run
            # self.teacher_model, so DistillationLoss must accept that
            # two-argument call - confirm against utils.losses.
            self.criterion = DistillationLoss(alpha=alpha, beta=beta, temperature=temperature)
        else:
            # Loss function (supports class weights and focal loss).
            label_smoothing = train_cfg.get('label_smoothing', 0.0)
            use_focal = train_cfg.get('use_focal_loss', False)
            class_weights = None
            if config['data'].get('use_class_weights', False):
                class_weights = get_class_weights(
                    config['data']['train_dir'],
                    config['model']['num_classes']
                ).to(self.device)
                # Persist the computed weights into configs/config.yaml.
                # Deliberately best-effort: a failure is logged, not raised.
                try:
                    with open('configs/config.yaml', 'r', encoding='utf-8') as f:
                        cfg = yaml.safe_load(f)
                    cfg['training']['class_weights'] = [float(w) for w in class_weights.cpu().numpy()]
                    with open('configs/config.yaml', 'w', encoding='utf-8') as f:
                        yaml.dump(cfg, f, allow_unicode=True)
                except Exception as e:
                    self.logger.warning(f"自动写入类别权重到 config.yaml 失败: {e}")
            if use_focal:
                from utils.losses import FocalLoss
                gamma = train_cfg.get('focal_gamma', 2.0)
                alpha = train_cfg.get('focal_alpha', None)
                if alpha is not None:
                    alpha = torch.tensor(alpha, dtype=torch.float32, device=self.device)
                elif class_weights is not None:
                    # Fall back to the class weights as per-class alpha.
                    alpha = class_weights
                self.criterion = FocalLoss(alpha=alpha, gamma=gamma)
            else:
                self.criterion = nn.CrossEntropyLoss(
                    weight=class_weights,
                    label_smoothing=label_smoothing if label_smoothing > 0 else 0.0,
                )

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Mixed precision uses the CUDA autocast/GradScaler, so it is only
        # enabled when the selected device really is a CUDA device.
        self.use_amp = (
            bool(config['device'].get('mixed_precision', False))
            and self.device.type == 'cuda'
        )
        self.scaler = GradScaler() if self.use_amp else None

        # Early stopping on the validation loss.
        self.early_stopping = EarlyStopping(
            patience=train_cfg.get('early_stopping_patience', 10)
        )

        # TensorBoard.
        self.writer = SummaryWriter(config['logging']['tensorboard_dir'])

        # Per-epoch training history.
        self.train_history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }
        self.best_val_acc = 0.0
        self.start_epoch = 0

    def _setup_logging(self):
        """Configure root logging to ``training.log`` and the console."""
        log_file = os.path.join(self.config['logging']['log_dir'], 'training.log')
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def _create_optimizer(self) -> optim.Optimizer:
        """Build the optimizer named in ``config['optimizer']['name']``.

        Supported names (case-insensitive): ``adam``, ``adamw``, ``sgd``.

        Raises:
            ValueError: If the configured optimizer name is not supported.
        """
        opt_config = self.config['optimizer']
        lr = self.config['training']['learning_rate']
        weight_decay = self.config['training']['weight_decay']
        name = opt_config['name'].lower()

        if name in ('adam', 'adamw'):
            optimizer_cls = optim.Adam if name == 'adam' else optim.AdamW
            return optimizer_cls(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=(opt_config.get('beta1', 0.9), opt_config.get('beta2', 0.999))
            )
        if name == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                momentum=opt_config.get('momentum', 0.9)
            )
        raise ValueError(f"不支持的优化器: {opt_config['name']}")

    def _create_scheduler(self):
        """Build the LR scheduler named in ``config['training']['scheduler']``.

        Returns:
            A ``CosineAnnealingLR`` (default), ``StepLR`` or
            ``ReduceLROnPlateau`` instance, or None for any other name.
        """
        scheduler_name = self.config['training'].get('scheduler', 'cosine')
        if scheduler_name == 'cosine':
            return CosineAnnealingLR(
                self.optimizer,
                T_max=self.config['training']['epochs']
            )
        if scheduler_name == 'step':
            return StepLR(self.optimizer, step_size=30, gamma=0.1)
        if scheduler_name == 'plateau':
            # `verbose` is deprecated in recent PyTorch releases; the current
            # learning rate is already logged every epoch by train().
            return ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=0.5,
                patience=5
            )
        return None

    def _unpack_batch(self, batch):
        """Move a batch to the training device.

        Returns:
            ``(images, labels, is_diabetic)`` where ``is_diabetic`` is a float
            tensor for 3-element batches and None for 2-element batches.
        """
        if len(batch) == 3:
            images, labels, is_diabetic = batch
            is_diabetic = is_diabetic.to(self.device).float()
        else:
            images, labels = batch
            is_diabetic = None
        return images.to(self.device), labels.to(self.device), is_diabetic

    def _compute_loss(self, outputs, labels, is_diabetic):
        """Return the training loss for one batch.

        For dict outputs the grading loss is computed on ``outputs['grading']``
        and, when ``is_diabetic`` is given, a binary cross-entropy on
        ``outputs['diabetic']`` is added. Plain tensor outputs use the
        criterion directly.
        """
        if not isinstance(outputs, dict):
            return self.criterion(outputs, labels)
        loss = self.criterion(outputs['grading'], labels)
        if is_diabetic is not None:
            # Flatten both sides so (B, 1) logits line up with (B,) targets.
            loss = loss + nn.functional.binary_cross_entropy_with_logits(
                outputs['diabetic'].reshape(-1), is_diabetic.reshape(-1)
            )
        return loss

    def _forward_with_loss(self, images, labels, is_diabetic):
        """Run the model and compute the loss, under autocast when AMP is on."""
        if self.use_amp:
            with autocast():
                outputs = self.model(images)
                loss = self._compute_loss(outputs, labels, is_diabetic)
        else:
            outputs = self.model(images)
            loss = self._compute_loss(outputs, labels, is_diabetic)
        return outputs, loss

    @staticmethod
    def _predict(outputs, is_diabetic):
        """Turn model outputs into class predictions.

        Returns:
            ``(grade_pred, bin_pred)``. ``bin_pred`` is a flat 0/1 tensor when
            the outputs are a dict and ``is_diabetic`` is given, else None.
        """
        if not isinstance(outputs, dict):
            return outputs.max(1)[1], None
        grade_pred = outputs['grading'].max(1)[1]
        bin_pred = None
        if is_diabetic is not None:
            bin_pred = (torch.sigmoid(outputs['diabetic']) > 0.5).long().reshape(-1)
        return grade_pred, bin_pred

    def train_epoch(self, epoch: int) -> Tuple[float, float]:
        """Train for one epoch; supports multi-task (grading + binary) models.

        Args:
            epoch: Zero-based epoch index, used only for the progress bar.

        Returns:
            ``(mean batch loss, grading accuracy in percent)``.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        correct_bin = 0
        total_bin = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}')
        for batch in progress_bar:
            images, labels, is_diabetic = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            outputs, loss = self._forward_with_loss(images, labels, is_diabetic)
            if self.use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            # Grading accuracy.
            predicted, pred_bin = self._predict(outputs, is_diabetic)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Binary (diabetic / not diabetic) accuracy.
            if pred_bin is not None:
                target_bin = is_diabetic.long().reshape(-1)
                correct_bin += pred_bin.eq(target_bin).sum().item()
                total_bin += target_bin.size(0)

            batch_loss = loss.item()
            running_loss += batch_loss

            # Progress bar.
            postfix = {'Loss': f'{batch_loss:.4f}', 'Acc': f'{100.*correct/total:.2f}%'}
            if total_bin > 0:
                postfix['BinAcc'] = f'{100.*correct_bin/total_bin:.2f}%'
            progress_bar.set_postfix(postfix)

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def validate(self) -> Tuple[float, float, Dict]:
        """Evaluate on the validation set (grading and binary accuracy).

        Returns:
            ``(mean batch loss, grading accuracy in percent, metrics dict)``.
            The metrics dict comes from ``calculate_metrics`` and additionally
            carries ``'bin_acc'`` when binary labels were present.
        """
        self.model.eval()
        running_loss = 0.0
        all_predictions = []
        all_labels = []
        all_bin_preds = []
        all_bin_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validating'):
                images, labels, is_diabetic = self._unpack_batch(batch)
                outputs, loss = self._forward_with_loss(images, labels, is_diabetic)
                running_loss += loss.item()

                predicted, pred_bin = self._predict(outputs, is_diabetic)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                if pred_bin is not None:
                    all_bin_preds.extend(pred_bin.cpu().numpy())
                    all_bin_labels.extend(is_diabetic.reshape(-1).cpu().numpy())

        val_loss = running_loss / len(self.val_loader)
        # Plain Python floats keep train_history / best_val_acc (which are
        # pickled into checkpoints) free of numpy scalar objects.
        val_acc = float(100. * accuracy_score(all_labels, all_predictions))
        metrics = calculate_metrics(all_labels, all_predictions)
        if all_bin_labels:
            metrics['bin_acc'] = float(100. * accuracy_score(all_bin_labels, all_bin_preds))
        return val_loss, val_acc, metrics

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Write ``last_checkpoint.pth`` and, when *is_best*, the best model.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            is_best: Also save the checkpoint to
                ``config['training']['model_save_path']``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_acc': self.best_val_acc,
            'train_history': self.train_history,
            'config': self.config
        }
        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Latest checkpoint, stored next to the best-model file.
        checkpoint_path = os.path.join(
            os.path.dirname(self.config['training']['model_save_path']),
            'last_checkpoint.pth'
        )
        torch.save(checkpoint, checkpoint_path)

        # Best model.
        if is_best:
            best_path = self.config['training']['model_save_path']
            torch.save(checkpoint, best_path)
            self.logger.info(f"保存最佳模型: {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model/optimizer/scheduler state and history from a file.

        Does nothing except log a message when *checkpoint_path* is missing.
        """
        if not os.path.exists(checkpoint_path):
            self.logger.info("未找到检查点,从头开始训练")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.best_val_acc = checkpoint.get('best_val_acc', 0.0)
        # Resume with the epoch after the one stored in the checkpoint.
        self.start_epoch = checkpoint.get('epoch', 0) + 1
        self.train_history = checkpoint.get('train_history', self.train_history)
        if self.scheduler and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.logger.info(f"从epoch {self.start_epoch} 恢复训练")

    def train(self):
        """Run the full pipeline: resume, train, plot, test, optional QAT."""
        self.logger.info("开始训练...")
        self.logger.info(f"训练设备: {self.device}")
        self.logger.info(f"训练集大小: {len(self.train_loader.dataset)}")
        self.logger.info(f"验证集大小: {len(self.val_loader.dataset)}")

        # Resume from the latest checkpoint when one exists.
        checkpoint_path = os.path.join(
            os.path.dirname(self.config['training']['model_save_path']),
            'last_checkpoint.pth'
        )
        self.load_checkpoint(checkpoint_path)

        total_epochs = self.config['training']['epochs']
        for epoch in range(self.start_epoch, total_epochs):
            start_time = time.time()

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc, val_metrics = self.validate()

            # LR scheduling: the plateau scheduler needs the monitored value.
            if self.scheduler:
                if isinstance(self.scheduler, ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # History.
            current_lr = self.optimizer.param_groups[0]['lr']
            self.train_history['train_loss'].append(train_loss)
            self.train_history['train_acc'].append(train_acc)
            self.train_history['val_loss'].append(val_loss)
            self.train_history['val_acc'].append(val_acc)
            self.train_history['lr'].append(current_lr)

            # TensorBoard.
            self.writer.add_scalar('Loss/Train', train_loss, epoch)
            self.writer.add_scalar('Loss/Val', val_loss, epoch)
            self.writer.add_scalar('Accuracy/Train', train_acc, epoch)
            self.writer.add_scalar('Accuracy/Val', val_acc, epoch)
            self.writer.add_scalar('Learning_Rate', current_lr, epoch)
            for metric_name, metric_value in val_metrics.items():
                if isinstance(metric_value, (int, float)):
                    self.writer.add_scalar(f'Metrics/{metric_name}', metric_value, epoch)

            # The best model is selected by validation accuracy.
            is_best = val_acc > self.best_val_acc
            if is_best:
                self.best_val_acc = val_acc

            # Periodic checkpoint, plus one whenever a new best is reached.
            if (epoch + 1) % self.config['logging']['save_frequency'] == 0 or is_best:
                self.save_checkpoint(epoch, is_best)

            epoch_time = time.time() - start_time
            self.logger.info(
                f"Epoch [{epoch+1}/{total_epochs}] "
                f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% "
                f"Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}% "
                f"Time: {epoch_time:.2f}s LR: {current_lr:.6f}"
            )

            # Early stopping is driven by the validation loss.
            if self.early_stopping(val_loss, self.model):
                self.logger.info(f"Early stopping at epoch {epoch+1}")
                break

        self.logger.info(f"训练完成!最佳验证准确率: {self.best_val_acc:.2f}%")

        # Nothing below writes scalars, so release the event file now.
        self.writer.close()

        self.plot_training_history()

        if self.test_loader:
            self.evaluate_on_test()

        # === QAT ===
        self.run_qat()

    def plot_training_history(self):
        """Save loss/accuracy/LR curves to ``training_history.png``.

        Logs a message and returns without plotting when no epoch has been
        recorded (for example when resuming a run that had already finished
        before any history was stored).
        """
        history = self.train_history
        if not history['val_acc']:
            # np.argmax on an empty sequence raises ValueError.
            self.logger.info("No training history recorded; skipping training curve plot.")
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Loss curves.
        loss_ax = axes[0, 0]
        loss_ax.plot(history['train_loss'], label='Train Loss')
        loss_ax.plot(history['val_loss'], label='Val Loss')
        loss_ax.set_title('Loss Curves')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True)

        # Accuracy curves.
        acc_ax = axes[0, 1]
        acc_ax.plot(history['train_acc'], label='Train Acc')
        acc_ax.plot(history['val_acc'], label='Val Acc')
        acc_ax.set_title('Accuracy Curves')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy (%)')
        acc_ax.legend()
        acc_ax.grid(True)

        # Learning-rate curve.
        lr_ax = axes[1, 0]
        lr_ax.plot(history['lr'])
        lr_ax.set_title('Learning Rate')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True)

        # Text summary of the best result.
        best_epoch = np.argmax(history['val_acc'])
        info_ax = axes[1, 1]
        info_ax.text(0.1, 0.8, f'Best Val Acc: {self.best_val_acc:.2f}%',
                     transform=info_ax.transAxes, fontsize=12)
        info_ax.text(0.1, 0.7, f'Best Epoch: {best_epoch + 1}',
                     transform=info_ax.transAxes, fontsize=12)
        info_ax.text(0.1, 0.6, f'Total Epochs: {len(history["val_acc"])}',
                     transform=info_ax.transAxes, fontsize=12)
        info_ax.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.config['logging']['log_dir'], 'training_history.png'),
                    dpi=300, bbox_inches='tight')
        plt.close()

    def evaluate_on_test(self):
        """Evaluate the best saved model on the test set.

        Loads the checkpoint at ``config['training']['model_save_path']`` when
        it exists, logs the scalar metrics and saves a confusion-matrix plot.
        """
        self.logger.info("在测试集上评估模型...")

        # Load the best model.
        best_model_path = self.config['training']['model_save_path']
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()
        all_predictions = []
        all_labels = []
        all_bin_preds = []
        all_bin_labels = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc='Testing'):
                images, labels, is_diabetic = self._unpack_batch(batch)
                outputs = self.model(images)

                predicted, pred_bin = self._predict(outputs, is_diabetic)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                if pred_bin is not None:
                    all_bin_preds.extend(pred_bin.cpu().numpy())
                    all_bin_labels.extend(is_diabetic.reshape(-1).cpu().numpy())

        # Metrics.
        test_metrics = calculate_metrics(all_labels, all_predictions)
        if all_bin_labels:
            test_metrics['bin_acc'] = float(100. * accuracy_score(all_bin_labels, all_bin_preds))

        self.logger.info("测试集结果:")
        for metric_name, metric_value in test_metrics.items():
            if isinstance(metric_value, (int, float)):
                self.logger.info(f"{metric_name}: {metric_value:.4f}")

        # Confusion matrix.
        cm = confusion_matrix(all_labels, all_predictions)
        plot_confusion_matrix(
            cm,
            self.config['data']['class_names'],
            save_path=os.path.join(self.config['logging']['log_dir'], 'confusion_matrix.png')
        )

    def run_qat(self):
        """Run quantization-aware training (QAT) and export the result to ONNX.

        Does nothing unless ``config['training']['qat']`` is truthy. A deep
        copy of the current model is fine-tuned with fake quantization,
        converted to a quantized model on the CPU and exported.
        """
        qat_cfg = self.config['training']
        if not qat_cfg.get('qat', False):
            return

        import torch.quantization as tq

        qat_epochs = qat_cfg.get('qat_epochs', 10)
        qat_backend = qat_cfg.get('qat_backend', 'fbgemm')
        export_path = qat_cfg.get('qat_export_path', 'weights/qat_model.onnx')
        self.logger.info(f"开始QAT微调: epochs={qat_epochs}, backend={qat_backend}")

        # 1. Prepare the model for QAT.
        model_qat = copy.deepcopy(self.model).to(self.device)
        model_qat.train()
        fuse_model = getattr(model_qat, 'fuse_model', None)
        if callable(fuse_model):
            fuse_model()
        # The quantized engine is selected through torch.backends; assigning
        # an attribute on the torch.quantization module has no effect.
        torch.backends.quantized.engine = qat_backend
        model_qat.qconfig = tq.get_default_qat_qconfig(qat_backend)
        tq.prepare_qat(model_qat, inplace=True)

        optimizer = torch.optim.Adam(model_qat.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        # 2. QAT fine-tuning on the grading task.
        for epoch in range(qat_epochs):
            model_qat.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for batch in self.train_loader:
                # Accept the same 2- or 3-element batches as train_epoch.
                images, labels, _ = self._unpack_batch(batch)
                optimizer.zero_grad()
                outputs = model_qat(images)
                # Multi-task models return a dict; QAT tunes the grading head.
                logits = outputs['grading'] if isinstance(outputs, dict) else outputs
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                predicted = logits.max(1)[1]
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
            avg_loss = running_loss / len(self.train_loader)
            acc = 100. * correct / total
            self.logger.info(f"[QAT] Epoch {epoch+1}/{qat_epochs} Loss: {avg_loss:.4f} Acc: {acc:.2f}%")

        # 3. Convert to a quantized model (on the CPU).
        model_qat.eval()
        model_int8 = tq.convert(model_qat.cpu().eval(), inplace=False)
        self.logger.info("QAT模型量化完成,准备导出ONNX...")

        # 4. Export to ONNX; create the target directory first.
        export_dir = os.path.dirname(export_path)
        if export_dir:
            os.makedirs(export_dir, exist_ok=True)
        image_size = self.config['data']['image_size']
        dummy = torch.randn(1, 3, image_size, image_size)
        torch.onnx.export(model_int8, dummy, export_path, input_names=['input'], output_names=['output'], opset_version=12)
        self.logger.info(f"QAT量化模型已导出: {export_path}")
if __name__ == "__main__":
    # Script entry point: read the YAML configuration from the fixed project
    # path, then build the trainer and run the complete training pipeline.
    with open("configs/config.yaml", 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    trainer = DRTrainer(config)
    trainer.train()