import torch
from torch.utils.data import Dataset


class HumanActionDataset(Dataset):
    """Map-style PyTorch dataset wrapping a Hugging Face dataset split.

    Each item of the wrapped split is read as a mapping. The image and the
    label are pulled out by key, optionally transformed, and returned as an
    ``(image, label)`` tuple.
    """

    def __init__(
        self,
        hf_dataset_split,
        transform=None,
        *,
        target_transform=None,
        image_key="image",
        label_key="labels",
    ):
        """Wrap *hf_dataset_split* for use with a ``DataLoader``.

        Args:
            hf_dataset_split: Hugging Face dataset split, e.g. ``ds['train']``.
                Any object supporting ``len()`` and integer indexing that
                yields mappings works.
            transform: Optional callable applied to the image
                (e.g. torchvision transforms). ``None`` disables it.
            target_transform: Optional callable applied to the label.
                ``None`` (default) leaves the label unchanged.
            image_key: Record key holding the image (default ``"image"``).
            label_key: Record key holding the label (default ``"labels"``).
        """
        self.dataset = hf_dataset_split
        self.transform = transform
        self.target_transform = target_transform
        self.image_key = image_key
        self.label_key = label_key

    def __len__(self):
        """Return the number of records in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for record *idx*, with transforms applied.

        Raises:
            KeyError: if the record lacks ``image_key`` or ``label_key``.
        """
        item = self.dataset[idx]
        # Presumably a PIL.Image.Image for HF image datasets; confirm per dataset.
        image = item[self.image_key]
        # Presumably an integer class index; confirm against the dataset features.
        label = item[self.label_key]

        # Compare against None rather than relying on truthiness: a callable
        # transform object that defines __len__ == 0 would otherwise be skipped.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label