from .dataset import Dataset, ValDataset, TestDataset | |
from torch.utils.data import DataLoader | |
def find_dataset_using_name(name):
    """Return the dataset class registered under ``name``.

    Recognised names are ``"VideoTrain"`` (``Dataset``), ``"VideoVal"``
    (``ValDataset``) and ``"VideoTest"`` (``TestDataset``).

    Args:
        name: Registry key, typically taken from a split's ``type`` field.

    Returns:
        The matching dataset class (not an instance).

    Raises:
        ValueError: If ``name`` is not one of the registered keys.
    """
    registry = {
        "VideoTrain": Dataset,
        "VideoVal": ValDataset,
        "VideoTest": TestDataset,
    }
    if name not in registry:
        raise ValueError(f"Fail to find dataset {name}")
    return registry[name]
def create_dataset(metainfo, split):
    """Instantiate the dataset for ``split`` and wrap it in a DataLoader.

    Args:
        metainfo: Passed unchanged to the dataset constructor. Its
            structure is defined by the dataset classes, not here.
        split: Config object. The attributes read here are ``type`` (the
            registry key passed to ``find_dataset_using_name``),
            ``batch_size``, ``drop_last``, ``shuffle`` and ``worker``
            (the number of loader worker processes). The whole object is
            also passed to the dataset constructor.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset,
        with ``pin_memory`` always enabled.

    Raises:
        ValueError: Propagated from ``find_dataset_using_name`` when
            ``split.type`` is not a registered dataset name.
    """
    # Build the dataset before reading the loader settings, so `split`
    # attributes are read in the same order as before.
    chosen_cls = find_dataset_using_name(split.type)
    data_source = chosen_cls(metainfo, split)

    loader_options = {
        "batch_size": split.batch_size,
        "drop_last": split.drop_last,
        "shuffle": split.shuffle,
        "num_workers": split.worker,
        "pin_memory": True,
    }
    return DataLoader(data_source, **loader_options)