""" | |
推理模块 | |
支持单张图像和批量推理,以及可视化 | |
""" | |
import os | |
import torch | |
import torch.nn.functional as F | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import yaml | |
import matplotlib.pyplot as plt | |
from typing import List, Tuple, Dict, Optional, Union | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
""" | |
Grad-CAM 相关依赖(pytorch-grad-cam)在部分 Python/平台上可能不可用。 | |
已改为惰性导入:仅在启用 grad_cam 时尝试 import,失败将优雅降级。 | |
""" | |
import time | |
from src.models import create_model | |
from src.data_loader import create_data_transforms | |
class DRPredictor:
    """Diabetic retinopathy predictor."""

    def print_model_profile(self, input_size: Optional[Tuple[int, int]] = None):
        """Print parameter count, FLOPs and model size (works for quantized and non-quantized models)."""
        from src.models import count_parameters, model_size_mb
        input_size = input_size or (self.config['data']['image_size'], self.config['data']['image_size'])
        print("Parameters: %d" % count_parameters(self.model))
        print("Model size: %.2f MB" % model_size_mb(self.model))
        try:
            from thop import profile
            dummy = torch.randn(1, 3, input_size[0], input_size[1]).to(self.device)
            flops, params = profile(self.model, inputs=(dummy,), verbose=False)
            print("FLOPs: %.2f M" % (flops / 1e6))
        except Exception as e:
            print(f"FLOPs profiling failed: {e}")
    def __init__(self, config_path: str, model_path: str = None):
        """
        Initialize the predictor.

        Args:
            config_path: Path to the config file.
            model_path: Path to the model weights; if None, the path from the config file is used.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)
        self.device = torch.device(
            f"cuda:{self.config['device']['gpu_id']}"
            if self.config['device']['use_gpu'] and torch.cuda.is_available()
            else "cpu"
        )

        # Load the model
        self.model = create_model(self.config).to(self.device)
        model_path = model_path or self.config['inference']['model_path']
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)
            print(f"Loaded model weights: {model_path}")
        else:
            print(f"Warning: model file does not exist: {model_path}")
        self.model.eval()

        # Class names
        self.class_names = self.config['data']['class_names']

        # Preprocessing transform
        self.transform = self._create_transform()

        # Initialize Grad-CAM (for visualization)
        self.grad_cam = None
        if self.config['inference'].get('grad_cam', False):
            self._setup_grad_cam()
    def quantize_model(self):
        """Apply dynamic quantization to the linear layers (faster, smaller CPU inference)."""
        if self.device.type != 'cpu':
            print("Note: quantization only makes sense for CPU inference, skipping")
            return
        try:
            import torch.nn as nn
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            self.model.eval()
            print("Applied dynamic quantization to the model (Linear -> int8)")
        except Exception as e:
            print(f"Quantization failed, skipping: {e}")
    def _create_transform(self) -> A.Compose:
        """Create the preprocessing transform."""
        image_size = self.config['data']['image_size']
        return A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    def _setup_grad_cam(self):
        """Set up Grad-CAM (using the in-house module)."""
        try:
            from utils.explainability import GradCAM
        except Exception as e:
            print(f"Failed to load the in-house GradCAM module: {e}")
            self.grad_cam = None
            return

        # Pick the target layer based on the model type
        model_name = self.config['model']['name'].lower()
        if 'efficientnet' in model_name:
            target_layers = [self.model.backbone.conv_head]
        elif 'resnet' in model_name:
            target_layers = [self.model.backbone[-1][-1].conv2]
        elif 'vit' in model_name:
            target_layers = [self.model.backbone.norm]
        else:
            # Fall back to the last convolutional layer
            target_layers = []
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    target_layers = [module]
        if target_layers:
            self.grad_cam = GradCAM(
                model=self.model,
                target_layers=target_layers,
                use_cuda=self.device.type == 'cuda'
            )
            print("Initialized the in-house GradCAM module")
    def preprocess_image(self, image_path: str) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Preprocess an image.

        Args:
            image_path: Path to the image.

        Returns:
            Tuple[torch.Tensor, np.ndarray]: the preprocessed tensor and the cropped original image.
        """
        # Read the image
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Unable to read image: {image_path}")

        # Convert color space
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Fundus-specific preprocessing (remove black borders)
        processed_image = self._preprocess_fundus_image(image_rgb)

        # Apply the transform
        transformed = self.transform(image=processed_image)
        tensor_image = transformed['image'].unsqueeze(0)  # add batch dimension

        return tensor_image, processed_image
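    # Shape notes (descriptive only): tensor_image is a float32 tensor of shape
    # (1, 3, image_size, image_size) normalized with ImageNet statistics, while
    # processed_image is the cropped RGB uint8 array at its original resolution.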
    def _preprocess_fundus_image(self, image: np.ndarray) -> np.ndarray:
        """Fundus image preprocessing."""
        # Convert to grayscale to detect the black border
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Threshold away the black background
        _, thresh = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)

        # Find contours
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            # The largest contour corresponds to the fundus region
            largest_contour = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(largest_contour)
            # Crop the image
            cropped = image[y:y+h, x:x+w]
            return cropped
        return image
    def predict_single(self, image_path: str, return_probs: bool = True) -> Dict:
        """
        Predict a single image.

        Args:
            image_path: Path to the image.
            return_probs: Whether to return the probability distribution.

        Returns:
            Dict: the prediction result.
        """
        from utils.retina_detector import is_retina_image
        if not is_retina_image(image_path):
            raise ValueError('The image does not look like a standard retinal photograph; please upload one.')

        start_time = time.time()

        # Preprocess
        tensor_image, original_image = self.preprocess_image(image_path)
        tensor_image = tensor_image.to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(tensor_image)
            if isinstance(outputs, dict):
                # Multi-task models return a dict; use the grading head (cf. run_inference)
                outputs = outputs['grading']
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # Results
        predicted_class = predicted.item()
        confidence_score = confidence.item()
        result = {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence_score,
            'inference_time': time.time() - start_time
        }
        if return_probs:
            result['probabilities'] = {
                self.class_names[i]: prob.item()
                for i, prob in enumerate(probabilities[0])
            }
        return result
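    # Usage sketch (hypothetical paths, assuming a trained checkpoint and a matching config):
    #   predictor = DRPredictor('configs/config.yaml', 'checkpoints/best_model.pth')
    #   result = predictor.predict_single('samples/fundus_001.jpg')
    #   print(result['predicted_label'], f"{result['confidence']:.3f}")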
    def predict_batch(self, image_paths: List[str]) -> List[Dict]:
        """
        Batch prediction.

        Args:
            image_paths: List of image paths.

        Returns:
            List[Dict]: list of prediction results.
        """
        results = []
        for image_path in image_paths:
            try:
                result = self.predict_single(image_path)
                result['image_path'] = image_path
                results.append(result)
            except Exception as e:
                print(f"Prediction failed for {image_path}: {e}")
                results.append({
                    'image_path': image_path,
                    'error': str(e)
                })
        return results
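    # Usage sketch (hypothetical paths): failed images stay in the output with an
    # 'error' key instead of raising, so one bad file does not abort the batch.
    #   results = predictor.predict_batch(['samples/a.jpg', 'samples/b.jpg'])
    #   ok = [r for r in results if 'error' not in r]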
    def generate_grad_cam(self, image_path: str, target_class: int = None) -> np.ndarray:
        """
        Generate a Grad-CAM visualization.

        Args:
            image_path: Path to the image.
            target_class: Target class; if None, the predicted class is used.

        Returns:
            np.ndarray: the Grad-CAM visualization image.
        """
        if self.grad_cam is None:
            raise ValueError("Grad-CAM is not initialized; enable grad_cam in the config")

        # Preprocess the image
        tensor_image, original_image = self.preprocess_image(image_path)
        tensor_image = tensor_image.to(self.device)

        # If no target class is given, use the predicted class
        if target_class is None:
            with torch.no_grad():
                outputs = self.model(tensor_image)
                if isinstance(outputs, dict):
                    # Multi-task models return a dict; use the grading head
                    outputs = outputs['grading']
                _, predicted = torch.max(outputs, 1)
                target_class = predicted.item()

        # Generate the CAM (in-house module)
        grayscale_cam = self.grad_cam.forward(tensor_image, target_class)

        # Normalize the original image to [0, 1]
        normalized_image = original_image.astype(np.float32) / 255.0

        # Resize the image to match the CAM
        image_size = self.config['data']['image_size']
        resized_image = cv2.resize(normalized_image, (image_size, image_size))

        # Overlay the CAM on the image (in-house helper)
        from utils.explainability import show_cam_on_image
        visualization = show_cam_on_image(resized_image, grayscale_cam, use_rgb=True)
        return visualization
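    # Usage sketch (hypothetical output path; requires inference.grad_cam to be
    # enabled in the config so self.grad_cam is initialized):
    #   cam = predictor.generate_grad_cam('samples/fundus_001.jpg')
    #   cv2.imwrite('outputs/fundus_001_cam.png', cv2.cvtColor(cam, cv2.COLOR_RGB2BGR))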
    def visualize_prediction(self, image_path: str, save_path: str = None) -> plt.Figure:
        """
        Visualize a prediction.

        Args:
            image_path: Path to the image.
            save_path: Path to save the figure.

        Returns:
            plt.Figure: the matplotlib figure object.
        """
        # Predict
        result = self.predict_single(image_path)

        # Create the figure
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Original image
        original_image = cv2.imread(image_path)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        axes[0].imshow(original_image)
        axes[0].set_title('Original image')
        axes[0].axis('off')

        # Prediction
        predicted_label = result['predicted_label']
        confidence = result['confidence']
        axes[1].text(0.1, 0.9, f'Predicted class: {predicted_label}', transform=axes[1].transAxes,
                     fontsize=14, fontweight='bold')
        axes[1].text(0.1, 0.8, f'Confidence: {confidence:.3f}', transform=axes[1].transAxes,
                     fontsize=12)

        # Probability distribution
        if 'probabilities' in result:
            y_pos = 0.7
            for class_name, prob in result['probabilities'].items():
                color = 'red' if class_name == predicted_label else 'black'
                axes[1].text(0.1, y_pos, f'{class_name}: {prob:.3f}',
                             transform=axes[1].transAxes, fontsize=10, color=color)
                y_pos -= 0.08
        axes[1].axis('off')
        axes[1].set_title('Prediction')

        # Grad-CAM visualization
        if self.grad_cam is not None:
            try:
                grad_cam_viz = self.generate_grad_cam(image_path)
                axes[2].imshow(grad_cam_viz)
                axes[2].set_title('Grad-CAM visualization')
            except Exception as e:
                axes[2].text(0.5, 0.5, f'Grad-CAM generation failed:\n{str(e)}',
                             transform=axes[2].transAxes, ha='center', va='center')
                axes[2].set_title('Grad-CAM visualization')
        else:
            axes[2].text(0.5, 0.5, 'Grad-CAM disabled',
                         transform=axes[2].transAxes, ha='center', va='center')
            axes[2].set_title('Grad-CAM visualization')
        axes[2].axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved visualization to: {save_path}")
        return fig
    def predict_with_tta(self, image_path: str, tta_transforms: int = 5) -> Dict:
        """
        Predict with test-time augmentation (TTA).

        Args:
            image_path: Path to the image.
            tta_transforms: Number of TTA transforms.

        Returns:
            Dict: the prediction result.
        """
        # Read and preprocess the image
        tensor_image, original_image = self.preprocess_image(image_path)
        all_predictions = []

        # TTA transform
        tta_transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.Resize(self.config['data']['image_size'], self.config['data']['image_size']),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        # Prediction on the original image
        with torch.no_grad():
            tensor_image = tensor_image.to(self.device)
            outputs = self.model(tensor_image)
            if isinstance(outputs, dict):
                # Multi-task models return a dict; use the grading head
                outputs = outputs['grading']
            probabilities = F.softmax(outputs, dim=1)
            all_predictions.append(probabilities.cpu().numpy())

        # TTA predictions
        for _ in range(tta_transforms):
            augmented = tta_transform(image=original_image)
            aug_tensor = augmented['image'].unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(aug_tensor)
                if isinstance(outputs, dict):
                    outputs = outputs['grading']
                probabilities = F.softmax(outputs, dim=1)
                all_predictions.append(probabilities.cpu().numpy())

        # Average the predictions
        avg_predictions = np.mean(all_predictions, axis=0)[0]
        predicted_class = int(np.argmax(avg_predictions))
        confidence = float(avg_predictions[predicted_class])

        return {
            'predicted_class': predicted_class,
            'predicted_label': self.class_names[predicted_class],
            'confidence': confidence,
            'probabilities': {
                self.class_names[i]: float(prob)
                for i, prob in enumerate(avg_predictions)
            },
            'tta_used': True
        }


def load_config(config_path='configs/config.yaml'):
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def run_inference(model, dataloader, device):
    model.eval()
    results = []

    # Image names, if the dataset exposes them
    image_names = None
    if hasattr(dataloader.dataset, 'df') and 'id_code' in dataloader.dataset.df.columns:
        image_names = dataloader.dataset.df['id_code'].values
    elif hasattr(dataloader.dataset, 'images'):
        image_names = [os.path.basename(p) for p in dataloader.dataset.images]

    sample_offset = 0  # global index of the first sample in the current batch
    with torch.no_grad():
        for batch in dataloader:
            # Handle the multi-task case
            if len(batch) == 3:
                images, labels, is_diabetic = batch
            else:
                images, labels = batch
                is_diabetic = None
            images = images.to(device)
            outputs = model(images)
            batch_size = images.size(0)

            # Predictions
            if isinstance(outputs, dict):
                grading_logits = outputs['grading']
                diabetic_logits = outputs['diabetic']
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                diabetic_prob = torch.sigmoid(diabetic_logits).cpu().numpy()
                diabetic_pred = (diabetic_prob > 0.5).astype(int)
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
            else:
                grading_logits = outputs
                grading_pred = grading_logits.argmax(1).cpu().numpy()
                grading_probs = torch.softmax(grading_logits, dim=1).cpu().numpy()
                diabetic_pred = None
                diabetic_prob = None

            for i in range(batch_size):
                result = {}
                # Image name (index by the global sample position, not the in-batch index)
                if image_names is not None:
                    result['image'] = image_names[sample_offset + i]
                # Labels
                result['label'] = int(labels[i].cpu().numpy())
                result['pred'] = int(grading_pred[i])
                # Multi-class probabilities
                for c in range(grading_probs.shape[1]):
                    result[f'pred_{c}'] = float(grading_probs[i, c])
                # Binary label/probability
                if is_diabetic is not None:
                    result['is_diabetic'] = int(is_diabetic[i].cpu().numpy())
                if diabetic_pred is not None:
                    result['is_diabetic_pred'] = int(diabetic_pred[i])
                    result['is_diabetic_prob'] = float(diabetic_prob[i])
                results.append(result)
            sample_offset += batch_size
    return results


def main():
    import pandas as pd

    config = load_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(config)
    model_path = config['inference']['model_path']
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)

    # Inference dataset
    test_csv = os.path.join(config['data']['test_dir'], '../test_labels.csv')
    from src.data_loader import DiabeticRetinopathyDataset
    dataset = DiabeticRetinopathyDataset(
        data_dir=config['data']['test_dir'],
        csv_file=test_csv,
        image_size=config['data']['image_size']
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    results = run_inference(model, dataloader, device)

    # Save results: predictions.csv is generated for all visualization scripts
    df = pd.DataFrame(results)
    os.makedirs('evaluation_results', exist_ok=True)
    df.to_csv('evaluation_results/predictions.csv', index=False)
    print('Saved inference results to evaluation_results/predictions.csv')

def batch_inference(config_path: str, model_path: str,
                    input_dir: str, output_file: str):
    """
    Batch inference utility.

    Args:
        config_path: Path to the config file.
        model_path: Path to the model weights.
        input_dir: Directory of input images.
        output_file: Path of the output CSV file.
    """
    import pandas as pd

    # Create the predictor
    predictor = DRPredictor(config_path, model_path)

    # Collect all image files
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(
            [os.path.join(input_dir, f) for f in os.listdir(input_dir)
             if f.lower().endswith(ext)]
        )
    if not image_paths:
        print(f"No image files found in directory {input_dir}")
        return
    print(f"Found {len(image_paths)} images, starting batch inference...")

    # Batch prediction
    results = predictor.predict_batch(image_paths)

    # Collect results
    df_data = []
    for result in results:
        if 'error' not in result:
            row = {
                'image_path': os.path.basename(result['image_path']),
                'predicted_class': result['predicted_class'],
                'predicted_label': result['predicted_label'],
                'confidence': result['confidence'],
                'inference_time': result['inference_time']
            }
            # Add the probability distribution
            if 'probabilities' in result:
                for class_name, prob in result['probabilities'].items():
                    row[f'prob_{class_name}'] = prob
            df_data.append(row)
        else:
            df_data.append({
                'image_path': os.path.basename(result['image_path']),
                'error': result['error']
            })

    # Save results
    df = pd.DataFrame(df_data)
    df.to_csv(output_file, index=False, encoding='utf-8')
    print(f"Results saved to: {output_file}")

    # Statistics
    if 'predicted_class' in df.columns:
        print("\nPrediction distribution:")
        print(df['predicted_label'].value_counts())
        print(f"\nMean confidence: {df['confidence'].mean():.3f}")
        print(f"Mean inference time: {df['inference_time'].mean():.3f}s")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Diabetic retinopathy prediction')
    parser.add_argument('--config', type=str, default='configs/config.yaml', help='Path to config file')
    parser.add_argument('--model', type=str, help='Path to model weights')
    parser.add_argument('--image', type=str, help='Predict a single image')
    parser.add_argument('--batch_dir', type=str, help='Directory for batch prediction')
    parser.add_argument('--output', type=str, default='predictions.csv', help='Output file')
    parser.add_argument('--visualize', action='store_true', help='Visualize the prediction')
    parser.add_argument('--tta', action='store_true', help='Enable TTA for single or batch prediction')
    parser.add_argument('--quantize', action='store_true', help='Enable dynamic quantization (faster, smaller on CPU)')
    args = parser.parse_args()

    if args.image:
        # Single-image prediction
        predictor = DRPredictor(args.config, args.model)
        if args.quantize:
            predictor.quantize_model()
        predictor.print_model_profile()
        if args.tta:
            result = predictor.predict_with_tta(args.image)
        else:
            result = predictor.predict_single(args.image)
        print("Prediction:")
        print(f"Class: {result['predicted_label']}")
        print(f"Confidence: {result['confidence']:.3f}")
        if 'inference_time' in result:
            print(f"Inference time: {result['inference_time']:.3f}s")
        if 'probabilities' in result:
            print("\nProbability distribution:")
            for class_name, prob in result['probabilities'].items():
                print(f"  {class_name}: {prob:.3f}")
        if args.visualize:
            save_path = os.path.splitext(args.image)[0] + '_prediction.png'
            fig = predictor.visualize_prediction(args.image, save_path)
            plt.show()
    elif args.batch_dir:
        # Batch prediction
        predictor = DRPredictor(args.config, args.model)
        if args.quantize:
            predictor.quantize_model()
        predictor.print_model_profile()
        if args.tta:
            # Batch inference with TTA
            import pandas as pd
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
            image_paths = []
            for ext in image_extensions:
                image_paths.extend(
                    [os.path.join(args.batch_dir, f) for f in os.listdir(args.batch_dir)
                     if f.lower().endswith(ext)]
                )
            results = []
            for image_path in image_paths:
                try:
                    r = predictor.predict_with_tta(image_path)
                    r['image_path'] = image_path
                    results.append(r)
                except Exception as e:
                    results.append({'image_path': image_path, 'error': str(e)})
            df = pd.DataFrame(results)
            df.to_csv(args.output, index=False, encoding='utf-8')
            print(f"Results saved to: {args.output}")
        else:
            # Plain batch inference
            batch_inference(args.config, args.model, args.batch_dir, args.output)
    else:
        print("Please specify --image or --batch_dir")