""" 推理模块 支持单张图像和批量推理,以及可视化 """ import os import torch import torch.nn.functional as F import cv2 import numpy as np from PIL import Image import yaml import matplotlib.pyplot as plt from typing import List, Tuple, Dict, Optional, Union import albumentations as A from albumentations.pytorch import ToTensorV2 """ Grad-CAM 相关依赖(pytorch-grad-cam)在部分 Python/平台上可能不可用。 已改为惰性导入:仅在启用 grad_cam 时尝试 import,失败将优雅降级。 """ import time from src.models import create_model from src.data_loader import create_data_transforms class DRPredictor: def print_model_profile(self, input_size: Tuple[int, int] = None): """打印模型参数量、FLOPs、模型大小等信息(支持量化/非量化)。""" from src.models import count_parameters, model_size_mb input_size = input_size or (self.config['data']['image_size'], self.config['data']['image_size']) print("模型参数量: %d" % count_parameters(self.model)) print("模型大小: %.2f MB" % model_size_mb(self.model)) try: from thop import profile dummy = torch.randn(1, 3, input_size[0], input_size[1]).to(self.device) flops, params = profile(self.model, inputs=(dummy,), verbose=False) print("FLOPs: %.2f M" % (flops / 1e6)) except Exception as e: print(f"FLOPs统计失败: {e}") """糖尿病视网膜病变预测器""" def __init__(self, config_path: str, model_path: str = None): """ 初始化预测器 from src.models import create_model from src.data_loader import DiabeticRetinopathyDataset Args: config_path: 配置文件路径 model_path: 模型权重路径,如果为None则使用配置文件中的路径 """ with open(config_path, 'r', encoding='utf-8') as f: self.config = yaml.safe_load(f) self.device = torch.device( f"cuda:{self.config['device']['gpu_id']}" if self.config['device']['use_gpu'] and torch.cuda.is_available() else "cpu" ) # 加载模型 self.model = create_model(self.config).to(self.device) model_path = model_path or self.config['inference']['model_path'] if os.path.exists(model_path): checkpoint = torch.load(model_path, map_location=self.device) if 'model_state_dict' in checkpoint: self.model.load_state_dict(checkpoint['model_state_dict']) else: self.model.load_state_dict(checkpoint) print(f"加载模型权重: {model_path}") else: print(f"警告: 模型文件不存在 {model_path}") self.model.eval() # 类别名称 self.class_names = self.config['data']['class_names'] # 创建预处理变换 self.transform = self._create_transform() # 初始化GradCAM(用于可视化) self.grad_cam = None if self.config['inference'].get('grad_cam', False): self._setup_grad_cam() def load_config(config_path='configs/config.yaml'): with open(config_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) def run_inference(model, dataloader, device): model.eval() results = [] with torch.no_grad(): for batch in dataloader: # 兼容多任务 if len(batch) == 3: images, labels, is_diabetic = batch else: images, labels = batch is_diabetic = None images = images.to(device) outputs = model(images) batch_size = images.size(0) # 获取图片名 image_names = None if hasattr(dataloader.dataset, 'df') and 'id_code' in dataloader.dataset.df.columns: image_names = dataloader.dataset.df['id_code'].values elif hasattr(dataloader.dataset, 'images'): image_names = [os.path.basename(p) for p in dataloader.dataset.images] # 预测 if isinstance(outputs, dict): grading_logits = outputs['grading'] diabetic_logits = outputs['diabetic'] grading_pred = grading_logits.argmax(1).cpu().numpy() diabetic_prob = torch.sigmoid(diabetic_logits).cpu().numpy() diabetic_pred = (diabetic_prob > 0.5).astype(int) grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy() else: grading_logits = outputs grading_pred = grading_logits.argmax(1).cpu().numpy() grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy() diabetic_pred = None 


class DRPredictor:
    """Diabetic retinopathy (DR) predictor."""

    def __init__(self, config_path: str, model_path: str = None):
        """
        Initialize the predictor.

        Args:
            config_path: Path to the configuration file.
            model_path: Path to the model weights; if None, the path from the
                configuration file is used.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.device = torch.device(
            f"cuda:{self.config['device']['gpu_id']}"
            if self.config['device']['use_gpu'] and torch.cuda.is_available()
            else "cpu"
        )

        # Load the model
        self.model = create_model(self.config).to(self.device)
        model_path = model_path or self.config['inference']['model_path']
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)
            print(f"Loaded model weights: {model_path}")
        else:
            print(f"Warning: model file does not exist: {model_path}")
        self.model.eval()

        # Class names
        self.class_names = self.config['data']['class_names']

        # Preprocessing transform
        self.transform = self._create_transform()

        # Grad-CAM (for visualization)
        self.grad_cam = None
        if self.config['inference'].get('grad_cam', False):
            self._setup_grad_cam()

    def print_model_profile(self, input_size: Tuple[int, int] = None):
        """Print parameter count, FLOPs and model size (quantized or not)."""
        from src.models import count_parameters, model_size_mb
        input_size = input_size or (
            self.config['data']['image_size'], self.config['data']['image_size']
        )
        print("Parameters: %d" % count_parameters(self.model))
        print("Model size: %.2f MB" % model_size_mb(self.model))
        try:
            from thop import profile
            dummy = torch.randn(1, 3, input_size[0], input_size[1]).to(self.device)
            flops, params = profile(self.model, inputs=(dummy,), verbose=False)
            print("FLOPs: %.2f M" % (flops / 1e6))
        except Exception as e:
            print(f"FLOPs profiling failed: {e}")

    def quantize_model(self):
        """Apply dynamic quantization to the linear layers (faster, smaller CPU inference)."""
        if self.device.type != 'cpu':
            print("Note: quantization only makes sense for CPU inference, skipping")
            return
        try:
            import torch.nn as nn
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            self.model.eval()
            print("Applied dynamic quantization to the model (Linear -> int8)")
        except Exception as e:
            print(f"Quantization failed, skipping: {e}")

    def _create_transform(self) -> A.Compose:
        """Create the preprocessing transform."""
        image_size = self.config['data']['image_size']
        return A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

    def _setup_grad_cam(self):
        """Set up Grad-CAM (using the in-house module)."""
        try:
            from utils.explainability import GradCAM
        except Exception as e:
            print(f"Failed to load the in-house GradCAM module: {e}")
            self.grad_cam = None
            return

        # Pick the target layer based on the model type
        model_name = self.config['model']['name'].lower()
        if 'efficientnet' in model_name:
            target_layers = [self.model.backbone.conv_head]
        elif 'resnet' in model_name:
            target_layers = [self.model.backbone[-1][-1].conv2]
        elif 'vit' in model_name:
            target_layers = [self.model.backbone.norm]
        else:
            # Fall back to the last convolutional layer
            target_layers = []
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    target_layers = [module]

        if target_layers:
            self.grad_cam = GradCAM(
                model=self.model,
                target_layers=target_layers,
                use_cuda=self.device.type == 'cuda'
            )
            print("Initialized the in-house GradCAM module")

    def preprocess_image(self, image_path: str) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Preprocess an image.

        Args:
            image_path: Path to the image.

        Returns:
            Tuple[torch.Tensor, np.ndarray]: The preprocessed tensor and the
                cropped original image.
        """
        # Read the image
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Cannot read image: {image_path}")

        # Convert color space
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Fundus-specific preprocessing (remove black borders)
        processed_image = self._preprocess_fundus_image(image_rgb)

        # Apply the transform
        transformed = self.transform(image=processed_image)
        tensor_image = transformed['image'].unsqueeze(0)  # Add batch dimension

        return tensor_image, processed_image

    def _preprocess_fundus_image(self, image: np.ndarray) -> np.ndarray:
        """Fundus image preprocessing."""
        # Convert to grayscale to detect the black border
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Threshold to remove the black background
        _, thresh = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
        # Find contours
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            # The largest contour corresponds to the fundus region
            largest_contour = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(largest_contour)
            # Crop the image
            cropped = image[y:y + h, x:x + w]
            return cropped
        return image
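    # Illustrative sketch (not executed): what the preprocessing step returns
    # for one fundus photo. The file paths are placeholder assumptions.
    #
    #   predictor = DRPredictor('configs/config.yaml')
    #   tensor, cropped = predictor.preprocess_image('samples/fundus_001.png')
    #   # tensor:  float32, shape (1, 3, image_size, image_size), ImageNet-normalized
    #   # cropped: RGB numpy array with the black border removed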

    def predict_single(self, image_path: str, return_probs: bool = True) -> Dict:
        """
        Predict a single image.

        Args:
            image_path: Path to the image.
            return_probs: Whether to return the full probability distribution.

        Returns:
            Dict: Prediction result.
        """
        from utils.retina_detector import is_retina_image
        if not is_retina_image(image_path):
            raise ValueError('The image does not meet the requirements; please upload a standard retinal photograph.')

        start_time = time.time()

        # Preprocess
        tensor_image, original_image = self.preprocess_image(image_path)
        tensor_image = tensor_image.to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(tensor_image)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # Result
        predicted_class = predicted.item()
        confidence_score = confidence.item()

        result = {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence_score,
            'inference_time': time.time() - start_time
        }

        if return_probs:
            result['probabilities'] = {
                self.class_names[i]: prob.item()
                for i, prob in enumerate(probabilities[0])
            }

        return result

    def predict_batch(self, image_paths: List[str]) -> List[Dict]:
        """
        Predict a batch of images.

        Args:
            image_paths: List of image paths.

        Returns:
            List[Dict]: List of prediction results.
        """
        results = []
        for image_path in image_paths:
            try:
                result = self.predict_single(image_path)
                result['image_path'] = image_path
                results.append(result)
            except Exception as e:
                print(f"Prediction failed for {image_path}: {e}")
                results.append({
                    'image_path': image_path,
                    'error': str(e)
                })
        return results

    def generate_grad_cam(self, image_path: str, target_class: int = None) -> np.ndarray:
        """
        Generate a Grad-CAM visualization.

        Args:
            image_path: Path to the image.
            target_class: Target class; if None, the predicted class is used.

        Returns:
            np.ndarray: Grad-CAM visualization image.
        """
        if self.grad_cam is None:
            raise ValueError("Grad-CAM is not initialized; enable grad_cam in the config")

        # Preprocess the image
        tensor_image, original_image = self.preprocess_image(image_path)
        tensor_image = tensor_image.to(self.device)

        # If no target class was given, use the predicted class
        if target_class is None:
            with torch.no_grad():
                outputs = self.model(tensor_image)
                _, predicted = torch.max(outputs, 1)
                target_class = predicted.item()

        # Generate the Grad-CAM heatmap (in-house module)
        grayscale_cam = self.grad_cam.forward(tensor_image, target_class)

        # Normalize the original image to [0, 1]
        normalized_image = original_image.astype(np.float32) / 255.0

        # Resize the image to match the CAM
        image_size = self.config['data']['image_size']
        resized_image = cv2.resize(normalized_image, (image_size, image_size))

        # Overlay the CAM on the image (in-house helper)
        from utils.explainability import show_cam_on_image
        visualization = show_cam_on_image(resized_image, grayscale_cam, use_rgb=True)

        return visualization
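    # Illustrative sketch (not executed): saving a Grad-CAM overlay to disk,
    # given a DRPredictor instance `predictor` with grad_cam enabled. The file
    # names are placeholder assumptions. Because the overlay is built with
    # use_rgb=True, it is converted to BGR before handing it to cv2.imwrite.
    #
    #   cam_rgb = predictor.generate_grad_cam('samples/fundus_001.png')
    #   cv2.imwrite('fundus_001_cam.png', cv2.cvtColor(cam_rgb, cv2.COLOR_RGB2BGR))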
    def visualize_prediction(self, image_path: str, save_path: str = None) -> plt.Figure:
        """
        Visualize a prediction.

        Args:
            image_path: Path to the image.
            save_path: Path to save the figure to.

        Returns:
            plt.Figure: The matplotlib figure.
        """
        # Predict
        result = self.predict_single(image_path)

        # Create the figure
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Original image
        original_image = cv2.imread(image_path)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        axes[0].imshow(original_image)
        axes[0].set_title('Original image')
        axes[0].axis('off')

        # Prediction summary
        predicted_label = result['predicted_label']
        confidence = result['confidence']

        axes[1].text(0.1, 0.9, f'Predicted class: {predicted_label}',
                     transform=axes[1].transAxes, fontsize=14, fontweight='bold')
        axes[1].text(0.1, 0.8, f'Confidence: {confidence:.3f}',
                     transform=axes[1].transAxes, fontsize=12)

        # Probability distribution
        if 'probabilities' in result:
            y_pos = 0.7
            for class_name, prob in result['probabilities'].items():
                color = 'red' if class_name == predicted_label else 'black'
                axes[1].text(0.1, y_pos, f'{class_name}: {prob:.3f}',
                             transform=axes[1].transAxes, fontsize=10, color=color)
                y_pos -= 0.08

        axes[1].axis('off')
        axes[1].set_title('Prediction')

        # Grad-CAM visualization
        if self.grad_cam is not None:
            try:
                grad_cam_viz = self.generate_grad_cam(image_path)
                axes[2].imshow(grad_cam_viz)
                axes[2].set_title('Grad-CAM visualization')
            except Exception as e:
                axes[2].text(0.5, 0.5, f'Grad-CAM generation failed:\n{str(e)}',
                             transform=axes[2].transAxes, ha='center', va='center')
                axes[2].set_title('Grad-CAM visualization')
        else:
            axes[2].text(0.5, 0.5, 'Grad-CAM not enabled',
                         transform=axes[2].transAxes, ha='center', va='center')
            axes[2].set_title('Grad-CAM visualization')
        axes[2].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved visualization: {save_path}")

        return fig

    def predict_with_tta(self, image_path: str, tta_transforms: int = 5) -> Dict:
        """
        Predict with test-time augmentation (TTA).

        Args:
            image_path: Path to the image.
            tta_transforms: Number of TTA passes.

        Returns:
            Dict: Prediction result.
        """
        start_time = time.time()

        # Read and preprocess the image
        tensor_image, original_image = self.preprocess_image(image_path)
        all_predictions = []

        # TTA transform
        tta_transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.Resize(self.config['data']['image_size'], self.config['data']['image_size']),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        # Prediction on the original image
        with torch.no_grad():
            tensor_image = tensor_image.to(self.device)
            outputs = self.model(tensor_image)
            probabilities = F.softmax(outputs, dim=1)
            all_predictions.append(probabilities.cpu().numpy())

        # TTA predictions
        for _ in range(tta_transforms):
            augmented = tta_transform(image=original_image)
            aug_tensor = augmented['image'].unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(aug_tensor)
                probabilities = F.softmax(outputs, dim=1)
                all_predictions.append(probabilities.cpu().numpy())

        # Average the predictions
        avg_predictions = np.mean(all_predictions, axis=0)[0]
        predicted_class = int(np.argmax(avg_predictions))
        confidence = float(avg_predictions[predicted_class])

        return {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence,
            'inference_time': time.time() - start_time,
            'probabilities': {
                self.class_names[i]: float(prob)
                for i, prob in enumerate(avg_predictions)
            },
            'tta_used': True
        }
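# Illustrative sketch (not executed): TTA trades latency for a slightly more
# stable prediction. With tta_transforms=5 the model runs six forward passes
# (one on the original image plus five augmented copies) and averages the
# softmax probabilities. The file path is a placeholder assumption.
#
#   result = predictor.predict_with_tta('samples/fundus_001.png', tta_transforms=5)
#   print(result['predicted_label'], result['confidence'])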

def batch_inference(config_path: str, model_path: str, input_dir: str, output_file: str):
    """
    Batch inference utility.

    Args:
        config_path: Path to the configuration file.
        model_path: Path to the model weights.
        input_dir: Directory of input images.
        output_file: Path of the output CSV file.
    """
    # Create the predictor
    predictor = DRPredictor(config_path, model_path)

    # Collect all image files
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(
            [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.lower().endswith(ext)]
        )

    if not image_paths:
        print(f"No image files found in directory {input_dir}")
        return

    print(f"Found {len(image_paths)} images, starting batch inference...")

    # Batch prediction
    results = predictor.predict_batch(image_paths)

    # Collect results
    df_data = []
    for result in results:
        if 'error' not in result:
            row = {
                'image_path': os.path.basename(result['image_path']),
                'predicted_class': result['predicted_class'],
                'predicted_label': result['predicted_label'],
                'confidence': result['confidence'],
                'inference_time': result['inference_time']
            }
            # Add the probability distribution
            if 'probabilities' in result:
                for class_name, prob in result['probabilities'].items():
                    row[f'prob_{class_name}'] = prob
            df_data.append(row)
        else:
            df_data.append({
                'image_path': os.path.basename(result['image_path']),
                'error': result['error']
            })

    # Save results
    df = pd.DataFrame(df_data)
    df.to_csv(output_file, index=False, encoding='utf-8')
    print(f"Results saved to: {output_file}")

    # Summary statistics
    if 'predicted_class' in df.columns:
        print("\nPrediction distribution:")
        print(df['predicted_label'].value_counts())
        print(f"\nMean confidence: {df['confidence'].mean():.3f}")
        print(f"Mean inference time: {df['inference_time'].mean():.3f}s")


# Standalone dataloader-based inference utilities (write
# evaluation_results/predictions.csv for the visualization scripts).

def load_config(config_path='configs/config.yaml'):
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def run_inference(model, dataloader, device):
    model.eval()
    results = []

    # Image names for the whole dataset; indexed with a running offset below
    image_names = None
    if hasattr(dataloader.dataset, 'df') and 'id_code' in dataloader.dataset.df.columns:
        image_names = dataloader.dataset.df['id_code'].values
    elif hasattr(dataloader.dataset, 'images'):
        image_names = [os.path.basename(p) for p in dataloader.dataset.images]

    offset = 0
    with torch.no_grad():
        for batch in dataloader:
            # Support both single-task and multi-task batches
            if len(batch) == 3:
                images, labels, is_diabetic = batch
            else:
                images, labels = batch
                is_diabetic = None

            images = images.to(device)
            outputs = model(images)
            batch_size = images.size(0)

            # Predictions
            if isinstance(outputs, dict):
                grading_logits = outputs['grading']
                diabetic_logits = outputs['diabetic']
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                diabetic_prob = torch.sigmoid(diabetic_logits).cpu().numpy().reshape(-1)
                diabetic_pred = (diabetic_prob > 0.5).astype(int)
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
            else:
                grading_logits = outputs
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
                diabetic_pred = None
                diabetic_prob = None

            for i in range(batch_size):
                result = {}
                # Image name (dataset-level index, not batch-level)
                if image_names is not None:
                    result['image'] = image_names[offset + i]
                # Ground-truth label and predicted grade
                result['label'] = int(labels[i].cpu().numpy())
                result['pred'] = int(grading_pred[i])
                # Per-class probabilities
                for c in range(grading_probs.shape[1]):
                    result[f'pred_{c}'] = float(grading_probs[i, c])
                # Binary (diabetic / non-diabetic) label and probability
                if is_diabetic is not None:
                    result['is_diabetic'] = int(is_diabetic[i].cpu().numpy())
                if diabetic_pred is not None:
                    result['is_diabetic_pred'] = int(diabetic_pred[i])
                    result['is_diabetic_prob'] = float(diabetic_prob[i])
                results.append(result)

            offset += batch_size

    return results


def main():
    """Run test-set inference and write evaluation_results/predictions.csv."""
    config = load_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = create_model(config)
    model_path = config['inference']['model_path']
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)

    # Inference dataset
    test_csv = os.path.join(config['data']['test_dir'], '../test_labels.csv')
    from src.data_loader import DiabeticRetinopathyDataset
    dataset = DiabeticRetinopathyDataset(
        data_dir=config['data']['test_dir'],
        csv_file=test_csv,
        image_size=config['data']['image_size']
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

    results = run_inference(model, dataloader, device)

    # Save results; predictions.csv is consumed by all visualization scripts
    df = pd.DataFrame(results)
    os.makedirs('evaluation_results', exist_ok=True)
    df.to_csv('evaluation_results/predictions.csv', index=False)
    print('Inference results saved to evaluation_results/predictions.csv')


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Diabetic retinopathy prediction')
    parser.add_argument('--config', type=str, default='configs/config.yaml', help='Path to the configuration file')
    parser.add_argument('--model', type=str, help='Path to the model weights')
    parser.add_argument('--image', type=str, help='Predict a single image')
    parser.add_argument('--batch_dir', type=str, help='Directory for batch prediction')
    parser.add_argument('--output', type=str, default='predictions.csv', help='Output file')
    parser.add_argument('--visualize', action='store_true', help='Visualize the prediction')
    parser.add_argument('--tta', action='store_true', help='Enable TTA for single or batch prediction')
    parser.add_argument('--quantize', action='store_true', help='Enable dynamic quantization (faster and smaller on CPU)')

    args = parser.parse_args()

    if args.image:
        # Single-image prediction
        predictor = DRPredictor(args.config, args.model)
        if args.quantize:
            predictor.quantize_model()
        predictor.print_model_profile()
        if args.tta:
            result = predictor.predict_with_tta(args.image)
        else:
            result = predictor.predict_single(args.image)

        print("Prediction:")
        print(f"Class: {result['predicted_label']}")
        print(f"Confidence: {result['confidence']:.3f}")
        print(f"Inference time: {result['inference_time']:.3f}s")

        if 'probabilities' in result:
            print("\nProbability distribution:")
            for class_name, prob in result['probabilities'].items():
                print(f"  {class_name}: {prob:.3f}")

        if args.visualize:
            save_path = os.path.splitext(args.image)[0] + '_prediction.png'
            fig = predictor.visualize_prediction(args.image, save_path)
            plt.show()

    elif args.batch_dir:
        # Batch prediction
        predictor = DRPredictor(args.config, args.model)
        if args.quantize:
            predictor.quantize_model()
        predictor.print_model_profile()
        if args.tta:
            # Batch inference with TTA
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
            image_paths = []
            for ext in image_extensions:
                image_paths.extend(
                    [os.path.join(args.batch_dir, f) for f in os.listdir(args.batch_dir) if f.lower().endswith(ext)]
                )
            results = []
            for image_path in image_paths:
                try:
                    r = predictor.predict_with_tta(image_path)
                    r['image_path'] = image_path
                    results.append(r)
                except Exception as e:
                    results.append({'image_path': image_path, 'error': str(e)})
            df = pd.DataFrame(results)
            df.to_csv(args.output, index=False, encoding='utf-8')
            print(f"Results saved to: {args.output}")
        else:
            # Plain batch inference
            batch_inference(args.config, args.model, args.batch_dir, args.output)
    else:
        print("Please specify --image or --batch_dir")
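# Example invocations (illustrative; the script path and the image/directory
# paths below are assumptions for demonstration):
#
#   # Single image with visualization
#   python inference.py --config configs/config.yaml --image samples/fundus_001.png --visualize
#
#   # Batch inference over a directory, with test-time augmentation
#   python inference.py --batch_dir data/test_images --tta --output predictions.csv
#
#   # CPU inference with dynamic quantization
#   python inference.py --image samples/fundus_001.png --quantize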