"""Dataset and Lightning data module for grayscale X-ray images with float labels."""

import os
import random

import kornia as K
import lightning as L
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import transforms


class XrayDataset(Dataset):
    """Map-style dataset that loads grayscale images listed in a data frame.

    Each row of ``data_frame`` must provide a ``file_name`` column (path
    relative to ``root_dir``) and a ``value`` column (the label).

    Args:
        data_frame: Table with ``file_name`` and ``value`` columns.
        root_dir: Directory that ``file_name`` entries are joined onto.
        transform: Optional callable applied to the grayscale PIL image.
        apply_equalization: When true, CLAHE is applied after ``transform``.
    """

    def __init__(self, data_frame, root_dir, transform=None, apply_equalization=False):
        self.data_frame = data_frame
        self.root_dir = root_dir
        self.transform = transform
        self.apply_equalization = apply_equalization

    def __len__(self):
        """Return the number of rows (samples) in the data frame."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A ``(image, label, file_name)`` tuple, where ``label`` is a float
            tensor built from the row's ``value`` column.
        """
        row = self.data_frame.iloc[idx]
        file_name = row["file_name"]
        img_path = os.path.join(self.root_dir, file_name)

        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load, so the handle can be closed right after.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("L")

        if self.transform is not None:
            img = self.transform(img)

        if self.apply_equalization:
            # CLAHE operates on tensors; convert here if the transform (or its
            # absence) left us with a PIL image instead of failing on unsqueeze.
            if not isinstance(img, torch.Tensor):
                img = transforms.ToTensor()(img)
            # equalize_clahe is given a batched input, so add and then drop a
            # leading batch dimension.
            img = K.enhance.equalize_clahe(img.unsqueeze(0)).squeeze(0)

        # Labels are always float tensors, regardless of the CSV column dtype.
        label = torch.tensor(row["value"], dtype=torch.float)
        return img, label, file_name


class XrayData(L.LightningDataModule):
    """Lightning data module producing seeded train/validation splits.

    Args:
        root_dir: Directory containing the image files.
        label_csv: CSV file with ``file_name`` and ``value`` columns.
        batch_size: Batch size used by every dataloader.
        val_split: Fraction of the samples assigned to the validation split.
        apply_equalization: Forwarded to :class:`XrayDataset` (CLAHE on/off).
        num_workers: Number of worker processes used by every dataloader.

    Note:
        Constructing this module seeds the global torch RNGs and enables
        deterministic cuDNN as a process-wide side effect.
    """

    common_seed = 42

    @staticmethod
    def seed_worker(worker_id):
        """Seed NumPy and ``random`` inside a dataloader worker.

        The seed is derived from the worker's torch seed, reduced modulo
        2**32 before being handed to NumPy and ``random``.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def __init__(
        self,
        root_dir,
        label_csv,
        batch_size=32,
        val_split=0.2,
        apply_equalization=False,
        num_workers=4,
    ):
        super().__init__()
        self.root_dir = root_dir
        self.label_csv = label_csv
        self.batch_size = batch_size
        self.val_split = val_split
        self.apply_equalization = apply_equalization
        self.num_workers = num_workers

        # Global determinism settings (process-wide side effects).
        torch.manual_seed(self.common_seed)
        torch.cuda.manual_seed_all(self.common_seed)
        torch.backends.cudnn.deterministic = True

        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            # Augmentations currently disabled:
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomRotation(20),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        # Retained for backward compatibility; setup() does not populate it.
        self.full_dataset = None

    def _build_dataset(self, data_frame, transform):
        """Create an XrayDataset over ``data_frame`` with the given transform."""
        return XrayDataset(
            data_frame,
            self.root_dir,
            transform=transform,
            apply_equalization=self.apply_equalization,
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            worker_init_fn=self.seed_worker,
        )

    def setup(self, stage=None):
        """Read the label CSV and create the train/validation subsets.

        The rows are shuffled with ``common_seed`` and then split with a
        generator seeded by the same value, so the split is reproducible.
        ``stage`` is accepted for the Lightning hook signature but unused.
        """
        data_frame = pd.read_csv(self.label_csv)
        data_frame = data_frame.sample(
            frac=1, random_state=self.common_seed
        ).reset_index(drop=True)

        dataset_size = len(data_frame)
        val_size = int(dataset_size * self.val_split)
        train_size = dataset_size - val_size

        # Each split gets its OWN underlying dataset. Subsets returned by
        # random_split share a single parent dataset, so assigning a transform
        # through one subset would silently overwrite the other's transform.
        train_base = self._build_dataset(data_frame, self.train_transform)
        val_base = self._build_dataset(data_frame, self.val_transform)

        train_subset, val_subset = random_split(
            train_base,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.common_seed),
        )

        self.train_dataset = train_subset
        # Reuse the validation indices from the same split so the two subsets
        # stay disjoint, but back them with the validation-transform dataset.
        self.val_dataset = Subset(val_base, val_subset.indices)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation subset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return a dataloader over the validation subset.

        No separate test split is created, so the validation data is reused.
        """
        return self._make_loader(self.val_dataset, shuffle=False)