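# Spaces: Running  (appears to be a web-page status banner captured with the file; not part of the module)
"""
Diabetic retinopathy detection project.

Data processing module.
"""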
import os
import warnings
from typing import Tuple, List, Optional, Dict

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import torch
import yaml
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class DiabeticRetinopathyDataset(Dataset):
    """Dataset of fundus images labelled with a diabetic-retinopathy class index.

    Samples are taken either from a label CSV (columns ``id_code``/``diagnosis``
    or the legacy ``image``/``label``; an optional ``is_diabetic`` column adds a
    second target) or, when no usable CSV is given, from a directory tree laid
    out as ``data_dir/<class_name>/<image file>``.

    ``__getitem__`` returns ``None`` for an image that cannot be decoded, so a
    ``DataLoader`` wrapping this dataset needs a ``collate_fn`` that drops
    ``None`` entries.
    """

    # Used when configs/config.yaml is missing or lists no class names.
    _DEFAULT_CLASS_NAMES = ['无病变', '轻度', '中度', '重度', '增殖性病变']
    # Lower-case suffixes recognised when scanning a directory tree.
    _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(
        self,
        data_dir: str,
        csv_file: Optional[str] = None,
        transform: "Optional[A.Compose]" = None,
        image_size: int = 224
    ):
        """Collect image paths and labels.

        Args:
            data_dir: Root directory of the images.
            csv_file: Path of the label CSV. When ``None`` (or the file does
                not exist) labels are derived from the directory structure.
            transform: Albumentations pipeline applied to every image. When
                falsy, a resize + ImageNet-normalise + to-tensor pipeline is
                used instead.
            image_size: Edge length used by the default pipeline.

        Raises:
            ValueError: If the CSV has neither ``id_code``/``diagnosis`` nor
                ``image``/``label`` columns.
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform
        # (image_size, pipeline) pair, built on first use by __getitem__.
        self._default_transform = None

        if csv_file and os.path.exists(csv_file):
            self.df = pd.read_csv(csv_file)
            class_names = self._load_class_names()

            # Accept both the id_code/diagnosis format and the legacy
            # image/label format.
            if 'id_code' in self.df.columns and 'diagnosis' in self.df.columns:
                id_column, label_column = 'id_code', 'diagnosis'
            elif 'image' in self.df.columns and 'label' in self.df.columns:
                id_column, label_column = 'image', 'label'
            else:
                raise ValueError('CSV文件缺少 id_code/diagnosis 或 image/label 字段')

            self.labels = self.df[label_column].tolist()
            # Read the columns directly instead of DataFrame.iterrows():
            # iterrows() upcasts an all-numeric row to float, which breaks the
            # list index and turns an id of 7 into "7.0".
            # Path layout: data_dir/<class name>/<image id>.png
            self.images = [
                os.path.join(self.data_dir, class_names[int(label)], f"{image_id}.png")
                for image_id, label in zip(self.df[id_column].tolist(), self.labels)
            ]
            self.is_diabetic = (
                self.df['is_diabetic'].tolist()
                if 'is_diabetic' in self.df.columns else None
            )
        else:
            # No CSV: infer the labels from the directory structure.
            self.images, self.labels = self._load_from_directory()
            self.is_diabetic = None

    @staticmethod
    def _load_class_names() -> List[str]:
        """Return the class names from configs/config.yaml, or the built-in fallback."""
        config_path = os.path.join(os.path.dirname(__file__), '../..', 'configs', 'config.yaml')
        class_names = None
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f) or {}
            # A config without data/class_names falls through to the fallback
            # instead of raising KeyError.
            class_names = (config.get('data') or {}).get('class_names')
        return class_names or list(DiabeticRetinopathyDataset._DEFAULT_CLASS_NAMES)

    def _load_from_directory(self) -> Tuple[List[str], List[int]]:
        """Scan ``data_dir/<class_name>/<image>`` and return (paths, labels).

        Only sub-directories count as classes, and they are numbered in sorted
        name order, so the mapping is reproducible across machines and stray
        files in ``data_dir`` do not consume a class index.
        """
        # NOTE(review): sorted directory order is not necessarily the order of
        # class_names in the config - confirm the two agree for this dataset.
        class_dirs = sorted(
            name for name in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, name))
        )
        images = []
        labels = []
        for class_idx, class_name in enumerate(class_dirs):
            class_dir = os.path.join(self.data_dir, class_name)
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(self._IMAGE_SUFFIXES):
                    images.append(os.path.join(class_dir, img_file))
                    labels.append(class_idx)
        return images, labels

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx: int):
        """Load, crop and transform one sample.

        Returns:
            ``(image, label, is_diabetic)`` when the CSV provided an
            ``is_diabetic`` column, otherwise ``(image, label)``. ``None`` is
            returned (after a warning) when the image cannot be read.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        is_diabetic = self.is_diabetic[idx] if self.is_diabetic is not None else None

        # The recorded path may carry the wrong extension; try the common ones.
        if not os.path.exists(img_path):
            base, _ = os.path.splitext(img_path)
            for suffix in ['.png', '.jpg', '.jpeg', '.JPG', '.PNG', '.JPEG']:
                candidate = base + suffix
                if os.path.exists(candidate):
                    img_path = candidate
                    break

        image = cv2.imread(img_path)
        if image is None:
            warnings.warn(f"跳过无法读取图像: {img_path}")
            # The DataLoader's collate_fn is expected to filter this out.
            return None

        # OpenCV decodes to BGR; the rest of the pipeline works on RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self._preprocess_image(image)

        pipeline = self.transform if self.transform else self._get_default_transform()
        image = pipeline(image=image)['image']

        if is_diabetic is not None:
            return image, label, is_diabetic
        return image, label

    def _get_default_transform(self):
        """Return the resize/normalise/to-tensor pipeline, building it once per image size."""
        cached = getattr(self, '_default_transform', None)
        if cached is None or cached[0] != self.image_size:
            pipeline = A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                ToTensorV2()
            ])
            cached = (self.image_size, pipeline)
            self._default_transform = cached
        return cached[1]

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Crop an RGB image to the bounding box of its largest non-black region.

        Pixels whose grey value is above 10 count as non-black. The image is
        returned unchanged when no such region is found.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        _, thresh = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
        # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x and
        # (image, contours, hierarchy) on 3.x; [-2] is the contour list in both.
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if contours:
            largest_contour = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(largest_contour)
            image = image[y:y+h, x:x+w]
        return image
def create_data_transforms(config: dict, is_training: bool = True) -> A.Compose:
    """Build the albumentations pipeline for training or evaluation.

    Args:
        config: Project configuration. ``config['data']['image_size']`` is
            required; the optional ``config['augmentation']`` mapping switches
            individual training augmentations on.
        is_training: When true, the configured augmentations are placed in
            front of the resize/normalise/to-tensor steps.

    Returns:
        The composed pipeline, always ending in resize, ImageNet
        normalisation and tensor conversion.
    """
    size = config['data']['image_size']
    pipeline = []

    if is_training:
        aug = config.get('augmentation', {})

        # Contrast-limited adaptive histogram equalisation.
        if aug.get('clahe', False):
            pipeline.append(A.CLAHE(clip_limit=2.0, p=0.5))

        # Random rotation within +/- the configured limit.
        if aug.get('rotation', 0) > 0:
            pipeline.append(A.Rotate(limit=aug['rotation'], p=0.5))

        # Random horizontal flip.
        if aug.get('horizontal_flip', False):
            pipeline.append(A.HorizontalFlip(p=0.5))

        # Brightness / contrast jitter; enabled when either limit is positive,
        # with 0.15 standing in for a limit that is not configured.
        wants_brightness = aug.get('brightness', 0) > 0
        if wants_brightness or aug.get('contrast', 0) > 0:
            jitter = A.RandomBrightnessContrast(
                brightness_limit=aug.get('brightness', 0.15),
                contrast_limit=aug.get('contrast', 0.15),
                p=0.5
            )
            pipeline.append(jitter)

        # Gaussian blur.
        if aug.get('blur', False):
            blur_probability = aug.get('blur_prob', 0.2)
            pipeline.append(A.GaussianBlur(blur_limit=(3, 5), p=blur_probability))

    # Steps shared by the training and evaluation pipelines.
    pipeline += [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
def _collate_skip_unreadable(batch):
    """Collate a batch after dropping samples the dataset returned as ``None``."""
    readable = [sample for sample in batch if sample is not None]
    if not readable:
        raise RuntimeError(f"批次中的 {len(batch)} 张图像全部无法读取")
    return torch.utils.data.dataloader.default_collate(readable)


def create_data_loaders(config: dict) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """Create the train, validation and (optional) test data loaders.

    Args:
        config: Project configuration. ``config['data']`` must provide
            ``train_dir``, ``val_dir``, ``image_size``, ``batch_size`` and
            ``num_workers``; ``test_dir`` is optional.

    Returns:
        ``(train_loader, val_loader, test_loader)``. ``test_loader`` is
        ``None`` when ``test_dir`` is not configured or does not exist.

    Raises:
        RuntimeError: Raised while iterating a loader if every image of a
            batch is unreadable.
    """
    data_config = config['data']

    train_transform = create_data_transforms(config, is_training=True)
    val_transform = create_data_transforms(config, is_training=False)

    # Pinned host memory only helps when batches are copied to a CUDA device;
    # requesting it without one just produces a warning.
    pin_memory = torch.cuda.is_available()

    def build_loader(data_dir, transform, shuffle):
        """Wrap one image directory in a dataset and a DataLoader."""
        dataset = DiabeticRetinopathyDataset(
            data_dir=data_dir,
            transform=transform,
            image_size=data_config['image_size']
        )
        return DataLoader(
            dataset,
            batch_size=data_config['batch_size'],
            shuffle=shuffle,
            num_workers=data_config['num_workers'],
            pin_memory=pin_memory,
            # The dataset yields None for unreadable images; without this
            # filter a single bad file aborts the whole batch.
            collate_fn=_collate_skip_unreadable
        )

    train_loader = build_loader(data_config['train_dir'], train_transform, shuffle=True)
    val_loader = build_loader(data_config['val_dir'], val_transform, shuffle=False)

    # Test set is optional; a missing or null ``test_dir`` simply disables it.
    test_loader = None
    test_dir = data_config.get('test_dir')
    if test_dir and os.path.exists(test_dir):
        test_loader = build_loader(test_dir, val_transform, shuffle=False)

    return train_loader, val_loader, test_loader
def get_class_weights(data_dir: str, num_classes: int = 5) -> torch.Tensor:
    """Compute inverse-frequency class weights for an imbalanced dataset.

    Class directories are numbered in sorted name order and only
    sub-directories of ``data_dir`` count as classes, so stray files do not
    shift the indices. Directories beyond the first ``num_classes`` are
    ignored.

    Args:
        data_dir: Root directory laid out as ``data_dir/<class_name>/<image>``.
        num_classes: Length of the returned weight vector.

    Returns:
        A float32 tensor of length ``num_classes`` holding
        ``total / (num_classes * count)`` per class, or ``0`` for a class
        without images.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    class_dirs = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    class_counts = [0] * num_classes
    for class_idx, class_name in enumerate(class_dirs[:num_classes]):
        class_dir = os.path.join(data_dir, class_name)
        class_counts[class_idx] = sum(
            1 for file_name in os.listdir(class_dir)
            if file_name.lower().endswith(image_suffixes)
        )

    # Inverse frequency; empty classes get weight 0 rather than dividing by zero.
    total_samples = sum(class_counts)
    class_weights = [
        total_samples / (num_classes * count) if count > 0 else 0.0
        for count in class_counts
    ]
    return torch.tensor(class_weights, dtype=torch.float32)
if __name__ == "__main__":
    # Smoke test: build the loaders from configs/config.yaml, report the
    # dataset sizes and pull a single training batch.
    with open("configs/config.yaml", 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    try:
        train_loader, val_loader, test_loader = create_data_loaders(config)

        print(f"训练集样本数: {len(train_loader.dataset)}")
        print(f"验证集样本数: {len(val_loader.dataset)}")
        if test_loader:
            print(f"测试集样本数: {len(test_loader.dataset)}")

        # Fetch only the first batch; an empty loader prints nothing.
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            batch_idx = 0
            images, labels = first_batch
            print(f"批次 {batch_idx}: 图像形状 {images.shape}, 标签形状 {labels.shape}")
    except Exception as exc:
        # Top-level boundary of a manual test run: report and give a hint.
        print(f"数据加载测试失败: {exc}")
        print("请确保数据目录结构正确,或创建示例数据进行测试")