ReaLens / data /image_folder.py
Abubakar740
Upload LDR to HDR application files
2232b2c
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both the given directory and its subdirectories.
"""
import torch.utils.data as data
from pathlib import Path
from PIL import Image
# File-name suffixes treated as images. Each lowercase suffix is immediately
# followed by its uppercase form, e.g. ".jpg", ".JPG", ".jpeg", ".JPEG", ...
IMG_EXTENSIONS = [
    variant
    for suffix in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".tif", ".tiff")
    for variant in (suffix, suffix.upper())
]
def is_image_file(filename):
    """Return True if *filename* ends with one of the ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg`` or
    ``.Png`` are accepted in addition to the spellings listed in
    ``IMG_EXTENSIONS``.

    Args:
        filename: File name (or path string) to check.

    Returns:
        bool: Whether the name carries a recognised image extension.
    """
    # str.endswith accepts a tuple of candidates, so one call covers them all.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset(dir, max_dataset_size=float("inf")):
    """Collect the image files found under *dir*, searching recursively.

    Args:
        dir: Directory to search; anything ``pathlib.Path`` accepts.
        max_dataset_size: Cap on the number of paths returned. The default
            (infinity) returns every image found.

    Returns:
        list[str]: Image paths as strings, ordered by sorting the ``Path``
        objects, sliced to the first ``min(max_dataset_size, len(images))``
        entries.

    Raises:
        AssertionError: If *dir* is not an existing directory.
    """
    dir_path = Path(dir)
    # Raised explicitly rather than with ``assert`` so the check is not
    # stripped under ``python -O``; the exception type stays the same.
    if not dir_path.is_dir():
        raise AssertionError(f"{dir} is not a valid directory")
    images = [
        str(path)
        for path in sorted(dir_path.rglob("*"))
        if path.is_file() and is_image_file(path.name)
    ]
    return images[: min(max_dataset_size, len(images))]
def default_loader(path):
    """Open the image at *path* and return it converted to RGB.

    Args:
        path: Location of the image file, passed to ``Image.open``.

    Returns:
        The image produced by ``convert("RGB")``.
    """
    # Image.open reads lazily and holds the file open; the context manager
    # closes it as soon as convert() has produced the RGB image.
    with Image.open(path) as img:
        return img.convert("RGB")
class ImageFolder(data.Dataset):
    """Dataset over the image files found under a root directory.

    Image paths are gathered once, at construction, by ``make_dataset``.
    Each item is loaded on access with ``loader`` and, when a ``transform``
    is set, passed through it before being returned.
    """

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        """Index the images under *root*.

        Args:
            root: Directory handed to ``make_dataset``.
            transform: Optional callable applied to each loaded image.
            return_paths: If true, items are ``(image, path)`` tuples
                instead of the image alone.
            loader: Callable that takes a path string and returns an image.

        Raises:
            RuntimeError: If no image is found under *root*.
        """
        imgs = make_dataset(root)
        if not imgs:
            supported = ",".join(IMG_EXTENSIONS)
            # Format root rather than concatenating it, so a non-str root
            # (e.g. a pathlib.Path) still produces this error, not a TypeError.
            raise RuntimeError(
                f"Found 0 images in: {root}\nSupported image extensions are: {supported}"
            )
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load the image at *index*, applying ``transform`` if one is set.

        Returns:
            The (possibly transformed) image, or an ``(image, path)`` tuple
            when ``return_paths`` is true.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        return img

    def __len__(self):
        """Return the number of indexed image paths."""
        return len(self.imgs)