'''
- This file defines a custom dataset class that inherits from
torchvision.datasets.ImageFolder. We need it only because we use CLIP to
filter out children's drawings and fossil images of dinosaurs.
- We override the __getitem__ method to return images, labels, and paths.
With the paths available in each batch, we can use DataLoader for batch
processing and clean our data much faster.
'''
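from torchvision.datasets import ImageFolder


class DinoDataset(ImageFolder):
    '''
    Custom dataset class that inherits from torchvision.datasets.ImageFolder
    and also returns the file path of each sample.
    '''

    def __init__(self, root, transform=None):
        super().__init__(root, transform)
        # ImageFolder stores each sample as (path, class_index), and that path
        # already includes root. Joining root again would duplicate it when
        # root is a relative path, so we take the stored path as-is.
        self.paths = [path for path, _ in self.samples]

    def __getitem__(self, idx):
        img, label = super().__getitem__(idx)
        return img, label, self.paths[idx]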
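

# A minimal sketch of how the (images, labels, paths) batches might be consumed
# during cleaning. Both collect_flagged_paths and flag_fn are hypothetical names
# and not part of the original module. flag_fn stands in for the CLIP-based
# check and is assumed to return one truthy/falsy value per image in the batch.
def collect_flagged_paths(batches, flag_fn):
    '''
    Return the paths of all samples that flag_fn marks as unwanted.

    batches: any iterable yielding (images, labels, paths) triples, for example
        a DataLoader wrapped around DinoDataset.
    flag_fn: callable taking a batch of images and returning an iterable of
        per-image flags (True means "remove this image").
    '''
    flagged = []
    for images, _labels, paths in batches:
        flags = flag_fn(images)
        # Pair each path with its flag and keep only the flagged ones.
        flagged.extend(path for path, bad in zip(paths, flags) if bad)
    return flagged


# Possible usage, assuming a transform that converts images to tensors so the
# default DataLoader collation can stack them:
#     loader = torch.utils.data.DataLoader(DinoDataset(root, transform), batch_size=32)
#     to_remove = collect_flagged_paths(loader, flag_fn)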