# xray-reg/dataset/xray_loader.py
# SuperSecureHuman's picture
# Upload 59 files
# 465d7e4 verified
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import lightning as L
import kornia as K
import numpy as np
import random
class XrayDataset(Dataset):
    """Map-style dataset of grayscale X-ray images with a scalar target.

    Every row of ``data_frame`` is read through two columns: ``file_name``
    (image path relative to ``root_dir``) and ``value`` (numeric target).

    Args:
        data_frame: pandas DataFrame indexed positionally via ``.iloc``.
        root_dir: Directory that ``file_name`` entries are joined onto.
        transform: Optional callable applied to the grayscale PIL image.
        apply_equalization: When true, kornia's CLAHE is applied to the
            (tensor) image after ``transform``.
    """

    def __init__(self,
                 data_frame,
                 root_dir,
                 transform=None,
                 apply_equalization=False):
        self.data_frame = data_frame
        self.root_dir = root_dir
        self.transform = transform
        self.apply_equalization = apply_equalization

    def __len__(self):
        """Return the number of rows (samples) in the data frame."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Positional row index into ``data_frame``.

        Returns:
            Tuple ``(img, label, file_name)`` where ``img`` is the transformed
            image, ``label`` is a float32 scalar tensor built from the
            ``value`` column and ``file_name`` is the row's raw file name.
        """
        row = self.data_frame.iloc[idx]
        img_path = os.path.join(self.root_dir, row["file_name"])

        # Context manager releases the file handle deterministically;
        # ``convert`` loads the pixel data, so ``img`` stays usable afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert("L")

        if self.transform:
            img = self.transform(img)

        if self.apply_equalization:
            # CLAHE below uses tensor methods. If no transform produced a
            # tensor (e.g. ``transform=None``), convert here instead of
            # failing with AttributeError on a PIL image.
            if not isinstance(img, torch.Tensor):
                img = transforms.ToTensor()(img)
            # A batch dimension is added for the call and removed afterwards.
            # NOTE(review): kornia documents CLAHE for float images in
            # [0, 1]; confirm custom transforms keep that range.
            img = K.enhance.equalize_clahe(img.unsqueeze(0)).squeeze(0)

        # Labels are explicitly float.
        label = torch.tensor(row["value"], dtype=torch.float)
        return img, label, row["file_name"]
class XrayData(L.LightningDataModule):
    """Lightning data module serving train/validation splits of X-ray images.

    The label CSV is shuffled and split with a fixed seed (``common_seed``)
    so the train/validation partition is reproducible across runs.

    Args:
        root_dir: Directory containing the image files.
        label_csv: Path to the CSV read in :meth:`setup`.
        batch_size: Batch size used by every dataloader.
        val_split: Fraction of the samples assigned to the validation split.
        apply_equalization: Forwarded to ``XrayDataset`` (CLAHE on/off).
        num_workers: Number of worker processes for every dataloader.
    """

    common_seed = 42

    @staticmethod
    def seed_worker(worker_id):
        """Seed NumPy and ``random`` inside a dataloader worker.

        The seed is derived from ``torch.initial_seed()`` of the worker
        process, reduced modulo 2**32.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def __init__(
        self,
        root_dir,
        label_csv,
        batch_size=32,
        val_split=0.2,
        apply_equalization=False,
        num_workers=4,
    ):
        super().__init__()
        self.root_dir = root_dir
        self.label_csv = label_csv
        self.batch_size = batch_size
        self.val_split = val_split
        self.apply_equalization = apply_equalization
        self.num_workers = num_workers

        # NOTE: these calls change process-wide torch state, not just this
        # instance.
        torch.manual_seed(self.common_seed)
        torch.cuda.manual_seed_all(self.common_seed)
        torch.backends.cudnn.deterministic = True

        # Train and validation pipelines are kept separate so augmentations
        # can be added to the training one without touching validation.
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        # Retained for backward compatibility; nothing in this class sets it.
        self.full_dataset = None

    def setup(self, stage=None):
        """Read the label CSV and build the train/validation datasets.

        Args:
            stage: Unused; accepted for the Lightning hook signature.
        """
        data_frame = pd.read_csv(self.label_csv)
        data_frame = data_frame.sample(
            frac=1, random_state=self.common_seed).reset_index(drop=True)

        dataset_size = len(data_frame)
        val_size = int(dataset_size * self.val_split)
        train_size = dataset_size - val_size

        # One dataset object per transform. Sharing a single dataset between
        # both subsets would make the last assigned transform apply to both
        # splits.
        train_base = XrayDataset(
            data_frame,
            self.root_dir,
            transform=self.train_transform,
            apply_equalization=self.apply_equalization,
        )
        val_base = XrayDataset(
            data_frame,
            self.root_dir,
            transform=self.val_transform,
            apply_equalization=self.apply_equalization,
        )

        # Seeded split; the validation indices it produces are re-pointed at
        # the dataset carrying the validation transform.
        self.train_dataset, val_subset = random_split(
            train_base,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.common_seed),
        )
        self.val_dataset = torch.utils.data.Subset(val_base,
                                                   val_subset.indices)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            worker_init_fn=self.seed_worker,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled dataloader over the validation split.

        There is no separate test split; the validation data is reused.
        """
        return self._make_loader(self.val_dataset, shuffle=False)