Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
class SchemaStringDataset(Dataset):
    """Dataset that renders each record as a single ``key:value`` string.

    ``data`` is indexed with an integer and each item must expose
    ``.items()`` (i.e. behave like a mapping). ``config`` must contain the
    key ``"dataset_size"``, which is reported as the dataset length.
    """

    def __init__(self, data, config):
        self.data = data
        self.config = config

    def __len__(self) -> int:
        """Return the dataset size specified in the configuration."""
        # NOTE(review): this is not checked against the size of ``data``; a
        # configured size larger than the data would presumably surface as an
        # indexing error in ``__getitem__`` -- confirm against callers.
        return self.config["dataset_size"]

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True when ``value`` is ``None`` or a floating-point NaN."""
        if value is None:
            return True
        # NaN is never identical to (or equal to) an arbitrary other NaN, so
        # an ``is np.nan`` test only catches the numpy singleton. Test by
        # value instead; the isinstance guard keeps strings and other
        # non-float values away from ``np.isnan``.
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    def transform_entry(self, entry) -> str:
        """Serialize ``entry`` as space-separated ``key:value`` pairs.

        Pairs whose value is ``None`` or NaN are dropped. Returns an empty
        string when no pair survives the filtering, so the return type is
        always ``str``.
        """
        parts = [
            f"{key}:{value}"
            for key, value in entry.items()
            if not self._is_missing(value)
        ]
        # Joining an empty list yields '', which covers the "nothing left
        # after filtering" case without a separate branch.
        return ' '.join(parts)

    def __getitem__(self, idx):
        """Return ``{'inputs': <serialized record>, 'idx': idx}`` for ``idx``."""
        item = self.data[idx]
        return {
            'inputs': self.transform_entry(item),
            'idx': idx,
        }