import json

from torch.utils.data import Dataset


def default_get(key):
    """Return a getter that looks up *key* in the record passed to it."""

    def getter(data):
        return data[key]

    return getter


class PromptDataset(Dataset):
    """In-memory dataset of dict records built from an iterable of mappings.

    Each stored record is assembled from two sources:

    * ``keys`` -- names copied verbatim from the source record; a key that is
      absent from a given record is silently skipped for that record.
    * ``projections`` -- ``name=callable`` pairs; each callable receives the
      whole source record and its return value is stored under ``name``.

    When a projection name equals a copied key, the projection value wins.
    """

    def __init__(self, data_dir, *keys, **projections) -> None:
        """Build the record list.

        Args:
            data_dir: Iterable of source records (despite the name, this is
                the data itself, not a path). Each record must support
                ``in`` and ``[]`` lookups for the requested keys.
            *keys: Names to copy from each source record when present.
            **projections: Output name -> callable taking the source record.
        """
        super().__init__()
        self.data = [
            self._build_record(record, keys, projections) for record in data_dir
        ]

    @staticmethod
    def _build_record(record, keys, projections) -> dict:
        """Assemble one output record from copied keys and projections."""
        copied = {key: default_get(key)(record) for key in keys if key in record}
        projected = {name: project(record) for name, project in projections.items()}
        # Projections are merged last so they override copied keys on collision.
        return {**copied, **projected}

    def __getitem__(self, index) -> dict:
        """Return the stored record at *index*."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of stored records."""
        return len(self.data)


class FileDataset(PromptDataset):
    """``PromptDataset`` whose source records are read from a JSON file."""

    def __init__(self, data_dir, *keys, **projections) -> None:
        """Load records from the JSON file at *data_dir* and build the dataset.

        Args:
            data_dir: Path of a UTF-8 encoded JSON file containing a sequence
                of records.
            *keys: Names to copy from each record. When omitted, the keys of
                the first record are used for every record.
            **projections: Output name -> callable taking the source record.
        """
        with open(data_dir, 'r', encoding='utf-8') as file:
            records = json.load(file)
        # Fall back to the first record's keys only when there is a first
        # record; an empty file produces an empty dataset instead of raising
        # IndexError.
        if not keys and records:
            keys = records[0].keys()
        super().__init__(records, *keys, **projections)