# aicomp_demo/src/inference.py
"""
推理模块
支持单张图像和批量推理,以及可视化
"""
import os
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import yaml
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Optional, Union
import albumentations as A
from albumentations.pytorch import ToTensorV2
"""
Grad-CAM 相关依赖(pytorch-grad-cam)在部分 Python/平台上可能不可用。
已改为惰性导入:仅在启用 grad_cam 时尝试 import,失败将优雅降级。
"""
import time
from src.models import create_model
from src.data_loader import create_data_transforms
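# This module reads a handful of config keys. A minimal sketch of the expected YAML layout
# (key names are taken from the accesses in this file; the concrete values are illustrative
# assumptions only):
#
#   data:
#     image_size: 224
#     class_names: ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
#     test_dir: data/test_images
#   device:
#     use_gpu: true
#     gpu_id: 0
#   model:
#     name: efficientnet_b0
#   inference:
#     model_path: checkpoints/best_model.pth
#     grad_cam: false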
class DRPredictor:
    """Diabetic retinopathy predictor."""

    def print_model_profile(self, input_size: Optional[Tuple[int, int]] = None):
        """Print parameter count, model size and FLOPs (works for quantized and non-quantized models)."""
        from src.models import count_parameters, model_size_mb
        input_size = input_size or (self.config['data']['image_size'], self.config['data']['image_size'])
        print("模型参数量: %d" % count_parameters(self.model))
        print("模型大小: %.2f MB" % model_size_mb(self.model))
        try:
            from thop import profile
            dummy = torch.randn(1, 3, input_size[0], input_size[1]).to(self.device)
            flops, _ = profile(self.model, inputs=(dummy,), verbose=False)
            print("FLOPs: %.2f M" % (flops / 1e6))
        except Exception as e:
            print(f"FLOPs统计失败: {e}")
    def __init__(self, config_path: str, model_path: str = None):
        """
        Initialize the predictor.

        Args:
            config_path: path to the YAML config file
            model_path: path to the model weights; if None, the path from the config file is used
        """
with open(config_path, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
self.device = torch.device(
f"cuda:{self.config['device']['gpu_id']}"
if self.config['device']['use_gpu'] and torch.cuda.is_available()
else "cpu"
)
        # Load the model
self.model = create_model(self.config).to(self.device)
model_path = model_path or self.config['inference']['model_path']
if os.path.exists(model_path):
checkpoint = torch.load(model_path, map_location=self.device)
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
else:
self.model.load_state_dict(checkpoint)
print(f"加载模型权重: {model_path}")
else:
print(f"警告: 模型文件不存在 {model_path}")
self.model.eval()
        # Class names
        self.class_names = self.config['data']['class_names']
        # Build the preprocessing transform
        self.transform = self._create_transform()
        # Initialize Grad-CAM (for visualization)
        self.grad_cam = None
        if self.config['inference'].get('grad_cam', False):
            self._setup_grad_cam()
def load_config(config_path='configs/config.yaml'):
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
def run_inference(model, dataloader, device):
    model.eval()
    results = []
    # Resolve image names once for the whole dataset (not per batch)
    image_names = None
    if hasattr(dataloader.dataset, 'df') and 'id_code' in dataloader.dataset.df.columns:
        image_names = dataloader.dataset.df['id_code'].values
    elif hasattr(dataloader.dataset, 'images'):
        image_names = [os.path.basename(p) for p in dataloader.dataset.images]
    sample_offset = 0  # Global index of the first sample in the current batch
    with torch.no_grad():
        for batch in dataloader:
            # Handle both single-task and multi-task batches
            if len(batch) == 3:
                images, labels, is_diabetic = batch
            else:
                images, labels = batch
                is_diabetic = None
            images = images.to(device)
            outputs = model(images)
            batch_size = images.size(0)
            # Predictions
            if isinstance(outputs, dict):
                grading_logits = outputs['grading']
                diabetic_logits = outputs['diabetic']
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                diabetic_prob = torch.sigmoid(diabetic_logits).cpu().numpy()
                diabetic_pred = (diabetic_prob > 0.5).astype(int)
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
            else:
                grading_logits = outputs
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
                diabetic_pred = None
                diabetic_prob = None
            for i in range(batch_size):
                result = {}
                # Image name (indexed by global dataset position, not batch position)
                if image_names is not None:
                    result['image'] = image_names[sample_offset + i]
                # Label and grading prediction
                result['label'] = int(labels[i].cpu().numpy())
                result['pred'] = int(grading_pred[i])
                # Multi-class probabilities
                for c in range(grading_probs.shape[1]):
                    result[f'pred_{c}'] = float(grading_probs[i, c])
                # Binary (diabetic) label / probability
                if is_diabetic is not None:
                    result['is_diabetic'] = int(is_diabetic[i].cpu().numpy())
                if diabetic_pred is not None:
                    result['is_diabetic_pred'] = int(diabetic_pred[i])
                    result['is_diabetic_prob'] = float(diabetic_prob[i])
                results.append(result)
            sample_offset += batch_size
    return results
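# Example (sketch): summarizing run_inference() output. Every result dict carries at least
# 'label' and 'pred', so overall accuracy can be computed directly; variable names are illustrative.
#
#   results = run_inference(model, dataloader, device)
#   correct = sum(1 for r in results if r['pred'] == r['label'])
#   print(f"accuracy: {correct / len(results):.4f}")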
def main():
    import pandas as pd
    config = load_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(config)
    model_path = config['inference']['model_path']
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)
    # Inference dataset
    test_csv = os.path.join(config['data']['test_dir'], '../test_labels.csv')
    from src.data_loader import DiabeticRetinopathyDataset
    dataset = DiabeticRetinopathyDataset(
        data_dir=config['data']['test_dir'],
        csv_file=test_csv,
        image_size=config['data']['image_size']
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    results = run_inference(model, dataloader, device)
    # Save the results; predictions.csv is consumed by all visualization scripts
    df = pd.DataFrame(results)
    os.makedirs('evaluation_results', exist_ok=True)
    df.to_csv('evaluation_results/predictions.csv', index=False)
    print('推理结果已保存为 evaluation_results/predictions.csv')
def quantize_model(self):
"""对模型线性层执行动态量化(CPU 推理提速与减小体积)。"""
if self.device.type != 'cpu':
print("提示: 量化仅在 CPU 推理时有意义,已跳过")
return
try:
import torch.nn as nn
self.model = torch.quantization.quantize_dynamic(
self.model, {nn.Linear}, dtype=torch.qint8
)
self.model.eval()
print("已对模型执行动态量化(Linear→int8)")
except Exception as e:
print(f"量化失败,已跳过: {e}")
def _create_transform(self) -> A.Compose:
"""创建预处理变换"""
image_size = self.config['data']['image_size']
return A.Compose([
A.Resize(image_size, image_size),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
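    # The pipeline above maps an HxWx3 uint8 RGB array to a float32 CHW tensor normalized
    # with ImageNet statistics. A minimal sketch of the expected shapes (illustrative input):
    #
    #   img = np.zeros((512, 512, 3), dtype=np.uint8)
    #   out = self._create_transform()(image=img)['image']
    #   # out.shape -> torch.Size([3, image_size, image_size]); out.dtype -> torch.float32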
def _setup_grad_cam(self):
"""设置GradCAM(使用自研模块)。"""
try:
from utils.explainability import GradCAM
except Exception as e:
print(f"自研 GradCAM 模块加载失败: {e}")
self.grad_cam = None
return
        # Choose target layers according to the model type
model_name = self.config['model']['name'].lower()
if 'efficientnet' in model_name:
target_layers = [self.model.backbone.conv_head]
elif 'resnet' in model_name:
target_layers = [self.model.backbone[-1][-1].conv2]
elif 'vit' in model_name:
target_layers = [self.model.backbone.norm]
else:
            # Default: use the last Conv2d layer found in the model
target_layers = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Conv2d):
target_layers = [module]
if target_layers:
self.grad_cam = GradCAM(
model=self.model,
target_layers=target_layers,
use_cuda=self.device.type == 'cuda'
)
print("已初始化自研 GradCAM 模块")
def preprocess_image(self, image_path: str) -> Tuple[torch.Tensor, np.ndarray]:
"""
预处理图像
Args:
image_path: 图像路径
Returns:
Tuple[torch.Tensor, np.ndarray]: 预处理后的tensor和原始图像
"""
# 读取图像
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法读取图像: {image_path}")
# 转换颜色空间
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 眼底图像预处理(去除黑边)
processed_image = self._preprocess_fundus_image(image_rgb)
# 应用变换
transformed = self.transform(image=processed_image)
tensor_image = transformed['image'].unsqueeze(0) # 添加批次维度
return tensor_image, processed_image
def _preprocess_fundus_image(self, image: np.ndarray) -> np.ndarray:
"""眼底图像预处理"""
# 转换为灰度图用于检测黑色边框
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
# 阈值化去除黑色背景
_, thresh = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
# 找到轮廓
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours:
# 找到最大轮廓(眼底区域)
largest_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(largest_contour)
# 裁剪图像
cropped = image[y:y+h, x:x+w]
return cropped
return image
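    # Cropping behaviour on a synthetic fundus-like image (sketch; values are illustrative):
    #
    #   img = np.zeros((400, 600, 3), dtype=np.uint8)
    #   cv2.circle(img, (300, 200), 150, (180, 90, 60), -1)   # bright disc on a black background
    #   cropped = self._preprocess_fundus_image(img)
    #   # cropped.shape -> about (301, 301, 3): the black margins are trimmed away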
    def predict_single(self, image_path: str, return_probs: bool = True) -> Dict:
        """
        Predict a single image.

        Args:
            image_path: path to the image
            return_probs: whether to return the full probability distribution

        Returns:
            Dict: prediction result
        """
        from utils.retina_detector import is_retina_image
        if not is_retina_image(image_path):
            raise ValueError('图片不符合要求,请上传标准视网膜照片。')
        start_time = time.time()
        # Preprocess
        tensor_image, original_image = self.preprocess_image(image_path)
        tensor_image = tensor_image.to(self.device)
        # Inference
        with torch.no_grad():
            outputs = self.model(tensor_image)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
        # Collect the results
        predicted_class = predicted.item()
        confidence_score = confidence.item()
result = {
'predicted_class': predicted_class,
'predicted_label': self.class_names[predicted_class],
'confidence': confidence_score,
'inference_time': time.time() - start_time
}
if return_probs:
result['probabilities'] = {
self.class_names[i]: prob.item()
for i, prob in enumerate(probabilities[0])
}
return result
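    # Example usage (sketch; the config and image paths are illustrative):
    #
    #   predictor = DRPredictor('configs/config.yaml')
    #   result = predictor.predict_single('sample_fundus.jpg')
    #   print(result['predicted_label'], result['confidence'])
    #   print(result['probabilities'])   # per-class softmax scores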
def predict_batch(self, image_paths: List[str]) -> List[Dict]:
"""
批量预测
Args:
image_paths: 图像路径列表
Returns:
List[Dict]: 预测结果列表
"""
results = []
for image_path in image_paths:
try:
result = self.predict_single(image_path)
result['image_path'] = image_path
results.append(result)
except Exception as e:
print(f"预测失败 {image_path}: {e}")
results.append({
'image_path': image_path,
'error': str(e)
})
return results
def generate_grad_cam(self, image_path: str, target_class: int = None) -> np.ndarray:
"""
生成GradCAM可视化
Args:
image_path: 图像路径
target_class: 目标类别,如果为None则使用预测类别
Returns:
np.ndarray: GradCAM可视化图像
"""
if self.grad_cam is None:
raise ValueError("GradCAM未初始化,请在配置中启用grad_cam")
# 预处理图像
tensor_image, original_image = self.preprocess_image(image_path)
tensor_image = tensor_image.to(self.device)
# 如果没有指定目标类别,使用预测类别
if target_class is None:
with torch.no_grad():
outputs = self.model(tensor_image)
_, predicted = torch.max(outputs, 1)
target_class = predicted.item()
# 生成GradCAM(使用自研模块)
grayscale_cam = self.grad_cam.forward(tensor_image, target_class)
# 将原图像标准化到[0,1]范围
normalized_image = original_image.astype(np.float32) / 255.0
# 调整图像尺寸匹配CAM
image_size = self.config['data']['image_size']
resized_image = cv2.resize(normalized_image, (image_size, image_size))
# 生成可视化图像(使用自研函数)
from utils.explainability import show_cam_on_image
visualization = show_cam_on_image(resized_image, grayscale_cam, use_rgb=True)
return visualization
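    # Example: saving the overlay to disk (sketch; with use_rgb=True the overlay is RGB,
    # so convert to BGR before handing it to cv2.imwrite):
    #
    #   cam_rgb = predictor.generate_grad_cam('sample_fundus.jpg')
    #   cv2.imwrite('sample_fundus_cam.png', cv2.cvtColor(cam_rgb, cv2.COLOR_RGB2BGR))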
def visualize_prediction(self, image_path: str, save_path: str = None) -> plt.Figure:
"""
可视化预测结果
Args:
image_path: 图像路径
save_path: 保存路径
Returns:
plt.Figure: matplotlib图形对象
"""
        # Predict
        result = self.predict_single(image_path)
        # Create the figure
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        # Original image
        original_image = cv2.imread(image_path)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        axes[0].imshow(original_image)
        axes[0].set_title('原始图像')
        axes[0].axis('off')
        # Prediction summary
predicted_label = result['predicted_label']
confidence = result['confidence']
axes[1].text(0.1, 0.9, f'预测类别: {predicted_label}', transform=axes[1].transAxes,
fontsize=14, fontweight='bold')
axes[1].text(0.1, 0.8, f'置信度: {confidence:.3f}', transform=axes[1].transAxes,
fontsize=12)
        # Probability distribution
if 'probabilities' in result:
y_pos = 0.7
for class_name, prob in result['probabilities'].items():
color = 'red' if class_name == predicted_label else 'black'
axes[1].text(0.1, y_pos, f'{class_name}: {prob:.3f}',
transform=axes[1].transAxes, fontsize=10, color=color)
y_pos -= 0.08
axes[1].axis('off')
axes[1].set_title('预测结果')
        # Grad-CAM visualization
if self.grad_cam is not None:
try:
grad_cam_viz = self.generate_grad_cam(image_path)
axes[2].imshow(grad_cam_viz)
axes[2].set_title('GradCAM可视化')
except Exception as e:
axes[2].text(0.5, 0.5, f'GradCAM生成失败:\n{str(e)}',
transform=axes[2].transAxes, ha='center', va='center')
axes[2].set_title('GradCAM可视化')
else:
axes[2].text(0.5, 0.5, 'GradCAM未启用',
transform=axes[2].transAxes, ha='center', va='center')
axes[2].set_title('GradCAM可视化')
axes[2].axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"可视化结果已保存: {save_path}")
return fig
def predict_with_tta(self, image_path: str, tta_transforms: int = 5) -> Dict:
"""
使用测试时增强(TTA)进行预测
Args:
image_path: 图像路径
tta_transforms: TTA变换次数
Returns:
Dict: 预测结果
"""
        start_time = time.time()
        # Read and preprocess the image
        tensor_image, original_image = self.preprocess_image(image_path)
        all_predictions = []
        # TTA augmentation pipeline
        tta_transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.Resize(self.config['data']['image_size'], self.config['data']['image_size']),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        # Prediction on the untouched image
        with torch.no_grad():
            tensor_image = tensor_image.to(self.device)
            outputs = self.model(tensor_image)
            probabilities = F.softmax(outputs, dim=1)
            all_predictions.append(probabilities.cpu().numpy())
        # TTA predictions
        for _ in range(tta_transforms):
            augmented = tta_transform(image=original_image)
            aug_tensor = augmented['image'].unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(aug_tensor)
                probabilities = F.softmax(outputs, dim=1)
                all_predictions.append(probabilities.cpu().numpy())
        # Average the predictions across all passes
        avg_predictions = np.mean(all_predictions, axis=0)[0]
        predicted_class = int(np.argmax(avg_predictions))
        confidence = float(avg_predictions[predicted_class])
        return {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence,
            'inference_time': time.time() - start_time,
            'probabilities': {
                self.class_names[i]: float(prob)
                for i, prob in enumerate(avg_predictions)
            },
            'tta_used': True
        }
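# predict_with_tta averages softmax probabilities over 1 + tta_transforms forward passes and
# takes the argmax of the mean. For two passes (a sketch with made-up numbers):
#
#   p1 = np.array([0.10, 0.70, 0.20])
#   p2 = np.array([0.20, 0.50, 0.30])
#   avg = (p1 + p2) / 2          # -> [0.15, 0.60, 0.25]
#   pred = int(np.argmax(avg))   # -> 1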
def batch_inference(config_path: str, model_path: str,
input_dir: str, output_file: str):
"""
批量推理工具函数
Args:
config_path: 配置文件路径
model_path: 模型路径
input_dir: 输入图像目录
output_file: 输出CSV文件路径
"""
    import pandas as pd
    # Create the predictor
    predictor = DRPredictor(config_path, model_path)
    # Collect all image files in the input directory
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    image_paths = []
for ext in image_extensions:
image_paths.extend(
[os.path.join(input_dir, f) for f in os.listdir(input_dir)
if f.lower().endswith(ext)]
)
if not image_paths:
print(f"在目录 {input_dir} 中未找到图像文件")
return
print(f"找到 {len(image_paths)} 张图像,开始批量推理...")
    # Batch prediction
    results = predictor.predict_batch(image_paths)
    # Organize results into rows
df_data = []
for result in results:
if 'error' not in result:
row = {
'image_path': os.path.basename(result['image_path']),
'predicted_class': result['predicted_class'],
'predicted_label': result['predicted_label'],
'confidence': result['confidence'],
'inference_time': result['inference_time']
}
            # Add the per-class probability distribution
if 'probabilities' in result:
for class_name, prob in result['probabilities'].items():
row[f'prob_{class_name}'] = prob
df_data.append(row)
else:
df_data.append({
'image_path': os.path.basename(result['image_path']),
'error': result['error']
})
    # Save results
df = pd.DataFrame(df_data)
df.to_csv(output_file, index=False, encoding='utf-8')
print(f"结果已保存到: {output_file}")
    # Summary statistics
if 'predicted_class' in df.columns:
print("\n预测分布:")
print(df['predicted_label'].value_counts())
print(f"\n平均置信度: {df['confidence'].mean():.3f}")
print(f"平均推理时间: {df['inference_time'].mean():.3f}s")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='糖尿病视网膜病变预测')
parser.add_argument('--config', type=str, default='configs/config.yaml', help='配置文件路径')
parser.add_argument('--model', type=str, help='模型权重路径')
parser.add_argument('--image', type=str, help='单张图像预测')
parser.add_argument('--batch_dir', type=str, help='批量预测目录')
parser.add_argument('--output', type=str, default='predictions.csv', help='输出文件')
parser.add_argument('--visualize', action='store_true', help='可视化预测结果')
parser.add_argument('--tta', action='store_true', help='单图或批量预测时启用TTA')
parser.add_argument('--quantize', action='store_true', help='启用动态量化(CPU 更快、更小)')
args = parser.parse_args()
if args.image:
        # Single-image prediction
predictor = DRPredictor(args.config, args.model)
if args.quantize:
predictor.quantize_model()
predictor.print_model_profile()
if args.tta:
result = predictor.predict_with_tta(args.image)
else:
result = predictor.predict_single(args.image)
print("预测结果:")
print(f"类别: {result['predicted_label']}")
print(f"置信度: {result['confidence']:.3f}")
print(f"推理时间: {result['inference_time']:.3f}s")
if 'probabilities' in result:
print("\n概率分布:")
for class_name, prob in result['probabilities'].items():
print(f" {class_name}: {prob:.3f}")
if args.visualize:
            save_path = os.path.splitext(args.image)[0] + '_prediction.png'
fig = predictor.visualize_prediction(args.image, save_path)
plt.show()
elif args.batch_dir:
        # Batch prediction
predictor = DRPredictor(args.config, args.model)
if args.quantize:
predictor.quantize_model()
predictor.print_model_profile()
if args.tta:
            # Manual batch inference with TTA
            import pandas as pd
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
image_paths = []
for ext in image_extensions:
image_paths.extend(
[os.path.join(args.batch_dir, f) for f in os.listdir(args.batch_dir)
if f.lower().endswith(ext)]
)
results = []
for image_path in image_paths:
try:
r = predictor.predict_with_tta(image_path)
r['image_path'] = image_path
results.append(r)
except Exception as e:
results.append({'image_path': image_path, 'error': str(e)})
df = pd.DataFrame(results)
df.to_csv(args.output, index=False, encoding='utf-8')
print(f"结果已保存到: {args.output}")
else:
            # Plain batch inference
batch_inference(args.config, args.model, args.batch_dir, args.output)
else:
print("请指定 --image 或 --batch_dir 参数")