import copy
import os
import random

import numpy as np
import pandas as pd
import torch
import lightning as L
import kornia as K
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
class XrayDataset(Dataset):
    """Reads grayscale X-ray images and their scalar labels from a dataframe."""

    def __init__(self,
                 data_frame,
                 root_dir,
                 transform=None,
                 apply_equalization=False):
        self.data_frame = data_frame
        self.root_dir = root_dir
        self.transform = transform
        self.apply_equalization = apply_equalization

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        row = self.data_frame.iloc[idx]
        img_path = os.path.join(self.root_dir, row["file_name"])
        img = Image.open(img_path)
        img = img.convert("L")  # single-channel grayscale

        if self.transform:
            img = self.transform(img)

        # Apply CLAHE if flag is set; equalize_clahe expects a float tensor in
        # [0, 1], which the ToTensor step of the transform provides.
        if self.apply_equalization:
            img = K.enhance.equalize_clahe(img.unsqueeze(0)).squeeze(0)

        label = torch.tensor(row["value"],
                             dtype=torch.float)  # Ensure label is float
        return img, label, row["file_name"]
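# Standalone sanity check for XrayDataset (a sketch, not part of the original
# file; the CSV/image paths are placeholder assumptions):
#
#   df = pd.read_csv("data/labels.csv")            # columns: file_name, value
#   ds = XrayDataset(df, "data/images",
#                    transform=transforms.ToTensor(),
#                    apply_equalization=True)
#   img, label, name = ds[0]
#   print(img.shape, label, name)                  # e.g. torch.Size([1, H, W])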
class XrayData(L.LightningDataModule):
    common_seed = 42

    @staticmethod
    def seed_worker(worker_id):
        # Standard PyTorch recipe: derive each DataLoader worker's seed from its
        # torch seed so numpy and random stay reproducible across workers.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
    def __init__(
        self,
        root_dir,
        label_csv,
        batch_size=32,
        val_split=0.2,
        apply_equalization=False,
    ):
        super().__init__()
        self.root_dir = root_dir
        self.label_csv = label_csv
        self.batch_size = batch_size
        self.val_split = val_split
        self.apply_equalization = apply_equalization

        # Seed torch (CPU and CUDA) for reproducible shuffling and splits
        torch.manual_seed(self.common_seed)
        torch.cuda.manual_seed_all(self.common_seed)
        torch.backends.cudnn.deterministic = True

        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomRotation(20),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.full_dataset = None
    def setup(self, stage=None):
        data_frame = pd.read_csv(self.label_csv)
        data_frame = data_frame.sample(
            frac=1, random_state=self.common_seed).reset_index(drop=True)

        dataset_size = len(data_frame)
        val_size = int(dataset_size * self.val_split)
        train_size = dataset_size - val_size

        self.full_dataset = XrayDataset(
            data_frame,
            self.root_dir,
            transform=None,  # per-split transforms are assigned below
            apply_equalization=self.apply_equalization,
        )

        # Split the dataset using random_split with a fixed generator
        self.train_dataset, self.val_dataset = random_split(
            self.full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.common_seed),
        )

        # random_split returns two Subset views over the SAME underlying dataset,
        # so assigning per-split transforms through .dataset would let the second
        # assignment override the first. Give the validation split its own
        # shallow copy before attaching the transforms.
        self.val_dataset.dataset = copy.copy(self.full_dataset)
        self.train_dataset.dataset.transform = self.train_transform
        self.val_dataset.dataset.transform = self.val_transform
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            worker_init_fn=self.seed_worker,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            worker_init_fn=self.seed_worker,
        )
    def test_dataloader(self):
        # The validation split doubles as the test set here
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            worker_init_fn=self.seed_worker,
        )