from copy import deepcopy
from dataclasses import dataclass
import lightning.pytorch as pl
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable

import os
import numpy as np

from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec

from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer

from ..model.spec import ModelInput

class DatasetConfig(ConfigSpec):
    '''
    Config to handle dataset format.
    '''
    # shuffle dataset
    shuffle: bool
    # batch size
    batch_size: int
    # number of workers
    num_workers: int
    # datapath
    datapath_config: DatapathConfig
    # use pin memory
    pin_memory: bool = True
    # use persistent workers
    persistent_workers: bool = True

    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
        cls.check_keys(config)
        return DatasetConfig(
            shuffle=config.shuffle,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers,
            datapath_config=DatapathConfig.parse(config.datapath_config),
        )

    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
        # Return one copy of this config per data class, each restricted to
        # that class's datapath entries.
        res: Dict[str, 'DatasetConfig'] = {}
        datapath_config_dict = self.datapath_config.split_by_cls()
        for cls in self.datapath_config.data_path:
            res[cls] = deepcopy(self)
            res[cls].datapath_config = datapath_config_dict[cls]
        return res
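
# Usage sketch (illustrative only; `cfg.dataset` and the printed fields are
# assumptions, not taken from an actual experiment config in this repo):
#
#   dataset_config = DatasetConfig.parse(cfg.dataset)   # validated via check_keys
#   per_cls = dataset_config.split_by_cls()             # {'<cls name>': DatasetConfig, ...}
#   for cls_name, sub_config in per_cls.items():
#       print(cls_name, sub_config.batch_size, sub_config.datapath_config)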

class UniRigDatasetModule(pl.LightningDataModule):
    def __init__(
        self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
        predict_transform_config: Union[TransformConfig, None]=None,
        tokenizer_config: Union[TokenizerConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
        datapath: Union[Datapath, None]=None,
        cls: Union[str, None]=None,
    ):
        super().__init__()
        self.process_fn = process_fn
        self.predict_dataset_config = predict_dataset_config
        self.predict_transform_config = predict_transform_config
        self.tokenizer_config = tokenizer_config
        self.debug = debug
        self.data_name = data_name

        if debug:
            print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")

        if datapath is not None:
            # A Datapath was supplied directly (e.g. one-off inference):
            # build a minimal single-class prediction setup around it.
            self.train_datapath = None
            self.validate_datapath = None
            self.predict_datapath = {
                cls: deepcopy(datapath),
            }
            self.predict_dataset_config = {
                cls: DatasetConfig(
                    shuffle=False,
                    batch_size=1,
                    num_workers=0,
                    datapath_config=deepcopy(datapath),
                    pin_memory=False,
                    persistent_workers=False,
                )
            }
        else:
            # build predict datapath from the per-class dataset configs
            if self.predict_dataset_config is not None:
                self.predict_datapath = {
                    cls: Datapath(self.predict_dataset_config[cls].datapath_config)
                    for cls in self.predict_dataset_config
                }
            else:
                self.predict_datapath = None

        # get tokenizer
        if tokenizer_config is None:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(config=tokenizer_config)

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        if self.predict_datapath is not None:
            self._predict_ds = {}
            for cls in self.predict_datapath:
                self._predict_ds[cls] = UniRigDataset(
                    process_fn=self.process_fn,
                    data=self.predict_datapath[cls].get_data(),
                    name=f"predict-{cls}",
                    tokenizer=self.tokenizer,
                    transform_config=self.predict_transform_config,
                    debug=self.debug,
                    data_name=self.data_name,
                )

    def predict_dataloader(self):
        if not hasattr(self, "_predict_ds"):
            self.setup()
        return self._create_dataloader(
            dataset=self._predict_ds,
            config=self.predict_dataset_config,
            is_train=False,
            drop_last=False,
        )

    def _create_dataloader(
        self,
        dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
        is_train: bool,
        **kwargs,
    ) -> Union[DataLoader, Dict[str, DataLoader]]:
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
            return DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=config.shuffle,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
                persistent_workers=config.persistent_workers,
                collate_fn=dataset.collate_fn,
                **kwargs,
            )
        if isinstance(dataset, dict):
            # One dataloader per class, each built from its own DatasetConfig.
            return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
        else:
            return create_single_dataloader(dataset, config, **kwargs)
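
# Usage sketch for the prediction path (wiring is illustrative; `my_process_fn`,
# the config objects and `model` are placeholders, not defined in this file):
#
#   datamodule = UniRigDatasetModule(
#       process_fn=my_process_fn,                  # hypothetical List[ModelInput] -> Dict
#       predict_transform_config=predict_transform_config,
#       tokenizer_config=tokenizer_config,
#       datapath=Datapath(...),                    # single-path inference shortcut
#       cls='inference',
#   )
#   trainer = pl.Trainer(accelerator='auto', devices=1)
#   outputs = trainer.predict(model, datamodule=datamodule)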

class UniRigDataset(Dataset):
    def __init__(
        self,
        data: List[Tuple[str, str]],  # (cls, part)
        name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        tokenizer: Union[TokenizerSpec, None]=None,
        transform_config: Union[TransformConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
    ) -> None:
        super().__init__()
        self.data = data
        self.name = name
        self.process_fn = process_fn
        self.tokenizer = tokenizer
        self.transform_config = transform_config
        self.debug = debug
        self.data_name = data_name

        if not debug:
            assert self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        cls, dir_path = self.data[idx]
        # Load the raw asset and apply the configured transforms/augmentations.
        raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
        asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
        first_augments, second_augments = transform_asset(
            asset=asset,
            transform_config=self.transform_config,
        )
        # Tokenize the skeleton only when a tokenizer is available and the asset has one.
        if self.tokenizer is not None and asset.parents is not None:
            tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
        else:
            tokens = None
        return ModelInput(
            tokens=tokens,
            pad=None if self.tokenizer is None else self.tokenizer.pad,
            vertices=asset.sampled_vertices.astype(np.float32),
            normals=asset.sampled_normals.astype(np.float32),
            joints=None if asset.joints is None else asset.joints.astype(np.float32),
            tails=None if asset.tails is None else asset.tails.astype(np.float32),
            asset=asset,
            augments=None,
        )

    def _collate_fn_debug(self, batch):
        # Debug mode: return the raw list of ModelInput without batching.
        return batch

    def _collate_fn(self, batch):
        # process_fn turns the list of ModelInput into collatable data,
        # which default_collate then converts into tensors.
        return data.dataloader.default_collate(self.process_fn(batch))

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)
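
# Iteration sketch (illustrative; assumes the `datamodule` from the sketch above,
# a non-debug process_fn, and that setup() has been run):
#
#   loaders = datamodule.predict_dataloader()      # Dict[str, DataLoader], one per class
#   for cls_name, loader in loaders.items():
#       for batch in loader:                       # batch = collate_fn(process_fn(samples))
#           ...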