smog / src /datasets /get_dataset.py
vonexel's picture
add: src
fe64bad verified
from .amass import AMASS
def get_dataset(name="amass"):
    """Return the dataset class used by this package.

    ``name`` is accepted for call-site compatibility but is not consulted:
    the ``AMASS`` class is returned regardless of its value.
    """
    dataset_cls = AMASS
    return dataset_cls
def get_datasets(parameters, clip_preprocess, split="train"):
    """Build the ``"train"`` / ``"test"`` AMASS datasets for a given split.

    Args:
        parameters: Mapping of keyword arguments forwarded to the ``AMASS``
            constructor. It is also passed to ``update_parameters`` on the
            created dataset(s).
        clip_preprocess: Forwarded unchanged to ``AMASS`` as
            ``clip_preprocess``.
        split: ``'all'`` builds two independent datasets (splits ``'train'``
            and ``'vald'``). Any other value builds a single dataset for that
            split and exposes a shallow copy of it as the test entry.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``.
    """
    # Function-scope import: keeps this block self-contained.
    from copy import copy

    DATA = AMASS

    if split == 'all':
        train = DATA(split='train', clip_preprocess=clip_preprocess, **parameters)
        test = DATA(split='vald', clip_preprocess=clip_preprocess, **parameters)

        # add specific parameters from the dataset loading
        train.update_parameters(parameters)
        test.update_parameters(parameters)
        return {"train": train, "test": test}

    dataset = DATA(split=split, clip_preprocess=clip_preprocess, **parameters)
    train = dataset

    # test: shallow copy (shares the underlying data with ``train``) that is
    # labelled as the test split.
    test = copy(train)
    # Store the split *name*. Assigning the copy object itself here would make
    # ``test.split`` a self-reference rather than a split label.
    # NOTE(review): how AMASS consumes ``split`` after construction is not
    # visible from this file; confirm "test" is the label it expects.
    test.split = "test"

    # add specific parameters from the dataset loading
    dataset.update_parameters(parameters)

    return {"train": train, "test": test}