|
"""A modified image folder class |
|
|
|
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) |
|
so that this class can load images from both current directory and its subdirectories. |
|
""" |
|
|
|
import torch.utils.data as data |
|
from pathlib import Path |
|
from PIL import Image |
|
|
|
# Recognised image file extensions. Each base extension is listed in its
# lowercase form immediately followed by its uppercase form.
IMG_EXTENSIONS = [
    variant
    for base in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".tif", ".tiff")
    for variant in (base, base.upper())
]
|
|
|
|
|
def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-insensitive, so mixed-case names such as
    ``"photo.Jpg"`` are accepted alongside ``"photo.jpg"`` and
    ``"photo.JPG"``. Every name accepted by an exact-case match is still
    accepted.

    Args:
        filename: File name (or path string) to test.
    """
    # ``str.endswith`` accepts a tuple, checking all suffixes in one C-level call.
    extensions = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    return filename.lower().endswith(extensions)
|
|
|
|
|
def make_dataset(dir, max_dataset_size=float("inf")):
    """Collect image file paths under ``dir``, recursing into subdirectories.

    Args:
        dir: Root directory to scan (``str`` or ``pathlib.Path``).
        max_dataset_size: Upper bound on the number of paths returned. The
            default, ``float("inf")``, keeps every image. Finite values are
            truncated to ``int``; values below zero yield an empty list.

    Returns:
        A list of path strings, one for every regular file under ``dir``
        whose name ``is_image_file`` accepts, ordered by sorting the
        ``pathlib.Path`` objects.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    dir_path = Path(dir)
    # Raise explicitly rather than ``assert`` so the check survives
    # ``python -O``; AssertionError is kept so existing callers still match.
    if not dir_path.is_dir():
        raise AssertionError(f"{dir} is not a valid directory")

    # Apply the cheap name filter before ``is_file()`` (a filesystem stat).
    # Survivors are sorted as Path objects, exactly as before, so the relative
    # order of the returned paths is unchanged.
    candidates = (
        path
        for path in dir_path.rglob("*")
        if is_image_file(path.name) and path.is_file()
    )
    images = [str(path) for path in sorted(candidates)]

    # Slicing needs an int bound; the default ``inf`` never triggers the cut.
    # A plain ``images[:limit]`` with a finite float would raise TypeError,
    # and a negative limit would silently drop items from the end.
    if max_dataset_size < len(images):
        images = images[: max(0, int(max_dataset_size))]
    return images
|
|
|
|
|
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB mode.

    ``convert`` loads the pixel data and returns a new image, so the file
    opened by ``Image.open`` can be closed once conversion is done.
    """
    # Use ``Image.open`` as a context manager so the underlying file handle is
    # released deterministically instead of waiting for garbage collection.
    with Image.open(path) as img:
        return img.convert("RGB")
|
|
|
|
|
class ImageFolder(data.Dataset):
    """Dataset over the image files that ``make_dataset`` finds under ``root``.

    Args:
        root: Directory passed to ``make_dataset``.
        transform: Optional callable applied to each loaded image.
        return_paths: If True, ``__getitem__`` returns ``(img, path)``
            instead of just ``img``.
        loader: Callable mapping a path string to an image; defaults to
            ``default_loader``.

    Raises:
        RuntimeError: If no image files are found under ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            # Interpolate ``root`` via an f-string rather than ``+`` so a
            # ``pathlib.Path`` root yields this RuntimeError instead of an
            # unrelated TypeError from str + Path concatenation.
            supported = ",".join(IMG_EXTENSIONS)
            raise RuntimeError(
                f"Found 0 images in: {root}\nSupported image extensions are: {supported}"
            )

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``transform`` if one is set.

        Returns ``img``, or ``(img, path)`` when ``return_paths`` is True.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Return the number of image paths collected at construction."""
        return len(self.imgs)
|
|