from copy import deepcopy
from dataclasses import dataclass
import lightning.pytorch as pl
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable
import os
import numpy as np
from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec
from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer
from ..model.spec import ModelInput


@dataclass
class DatasetConfig(ConfigSpec):
'''
Config to handle dataset format.
'''
# shuffle dataset
shuffle: bool
# batch size
batch_size: int
# number of workers
num_workers: int
# datapath
datapath_config: DatapathConfig
# use pin memory
pin_memory: bool = True
# use persistent workers
persistent_workers: bool = True

    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
cls.check_keys(config)
return DatasetConfig(
shuffle=config.shuffle,
batch_size=config.batch_size,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
persistent_workers=config.persistent_workers,
datapath_config=DatapathConfig.parse(config.datapath_config),
)

    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
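        '''
        Split this config into one DatasetConfig per data class, each keeping only that class's datapath.
        '''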
res: Dict[str, DatasetConfig] = {}
datapath_config_dict = self.datapath_config.split_by_cls()
for cls in self.datapath_config.data_path:
res[cls] = deepcopy(self)
res[cls].datapath_config = datapath_config_dict[cls]
return res


class UniRigDatasetModule(pl.LightningDataModule):
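    '''
    LightningDataModule that builds the prediction datasets and dataloaders for UniRig.

    Example usage (a minimal sketch; `datapath_cfg`, `transform_cfg`, `tokenizer_cfg` and
    `my_process_fn` are assumed to be built elsewhere, e.g. parsed from the project's configs):

        dataset_config = DatasetConfig(
            shuffle=False,
            batch_size=1,
            num_workers=0,
            datapath_config=datapath_cfg,
        )
        datamodule = UniRigDatasetModule(
            process_fn=my_process_fn,
            predict_dataset_config={'inference': dataset_config},
            predict_transform_config=transform_cfg,
            tokenizer_config=tokenizer_cfg,
        )
        loader = datamodule.predict_dataloader()['inference']
    '''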
def __init__(
self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
predict_transform_config: Union[TransformConfig, None]=None,
tokenizer_config: Union[TokenizerConfig, None]=None,
debug: bool=False,
data_name: str='raw_data.npz',
datapath: Union[Datapath, None]=None,
cls: Union[str, None]=None,
):
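        '''
        If `datapath` (and its `cls`) is given, a single-entry prediction config is built from it
        directly; otherwise the prediction datapaths are derived from `predict_dataset_config`.
        '''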
super().__init__()
self.process_fn = process_fn
self.predict_dataset_config = predict_dataset_config
self.predict_transform_config = predict_transform_config
self.tokenizer_config = tokenizer_config
self.debug = debug
self.data_name = data_name
if debug:
print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
if datapath is not None:
self.train_datapath = None
self.validate_datapath = None
self.predict_datapath = {
cls: deepcopy(datapath),
}
self.predict_dataset_config = {
cls: DatasetConfig(
shuffle=False,
batch_size=1,
num_workers=0,
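                    # the Datapath below is stored as a stand-in config; only the loader settings above are read from it later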
datapath_config=deepcopy(datapath),
pin_memory=False,
persistent_workers=False,
)
}
else:
# build predict datapath
if self.predict_dataset_config is not None:
self.predict_datapath = {
cls: Datapath(self.predict_dataset_config[cls].datapath_config)
for cls in self.predict_dataset_config
}
else:
self.predict_datapath = None
# get tokenizer
if tokenizer_config is None:
self.tokenizer = None
else:
self.tokenizer = get_tokenizer(config=tokenizer_config)

    def prepare_data(self):
pass

    def setup(self, stage=None):
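        '''
        Build one UniRigDataset per data class listed in the prediction datapaths.
        '''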
if self.predict_datapath is not None:
self._predict_ds = {}
for cls in self.predict_datapath:
self._predict_ds[cls] = UniRigDataset(
process_fn=self.process_fn,
data=self.predict_datapath[cls].get_data(),
name=f"predict-{cls}",
tokenizer=self.tokenizer,
transform_config=self.predict_transform_config,
debug=self.debug,
data_name=self.data_name,
)

    def predict_dataloader(self):
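        '''
        Return the prediction dataloaders keyed by data class, running setup() first if needed.
        '''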
if not hasattr(self, "_predict_ds"):
self.setup()
return self._create_dataloader(
dataset=self._predict_ds,
config=self.predict_dataset_config,
is_train=False,
drop_last=False,
)

    def _create_dataloader(
self,
dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
is_train: bool,
**kwargs,
) -> Union[DataLoader, Dict[str, DataLoader]]:
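        '''
        Create a DataLoader for a single dataset, or a dict of DataLoaders when given a dict of
        datasets with per-class configs.
        '''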
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
return DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=config.shuffle,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
persistent_workers=config.persistent_workers,
collate_fn=dataset.collate_fn,
**kwargs,
)
if isinstance(dataset, Dict):
return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
else:
return create_single_dataloader(dataset, config, **kwargs)


class UniRigDataset(Dataset):
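    '''
    Dataset that loads a RawData npz for each (cls, directory) entry, builds an Asset, applies the
    configured transforms and optional tokenization, and yields ModelInput samples.
    '''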
def __init__(
self,
        data: List[Tuple[str, str]], # (cls, directory path)
name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
tokenizer: Union[TokenizerSpec, None]=None,
transform_config: Union[TransformConfig, None]=None,
debug: bool=False,
data_name: str='raw_data.npz',
) -> None:
super().__init__()
self.data = data
self.name = name
self.process_fn = process_fn
self.tokenizer = tokenizer
self.transform_config = transform_config
self.debug = debug
self.data_name = data_name
if not debug:
assert self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
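        '''
        Load the raw data for one sample, transform it into an Asset, optionally run the tokenizer,
        and pack the result into a ModelInput.
        '''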
cls, dir_path = self.data[idx]
raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
first_augments, second_augments = transform_asset(
asset=asset,
transform_config=self.transform_config,
)
if self.tokenizer is not None and asset.parents is not None:
tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
else:
tokens = None
return ModelInput(
tokens=tokens,
pad=None if self.tokenizer is None else self.tokenizer.pad,
vertices=asset.sampled_vertices.astype(np.float32),
normals=asset.sampled_normals.astype(np.float32),
joints=None if asset.joints is None else asset.joints.astype(np.float32),
tails=None if asset.tails is None else asset.tails.astype(np.float32),
asset=asset,
augments=None,
)

    def _collate_fn_debug(self, batch):
        # debug mode: return the raw list of ModelInput samples untouched
        return batch

    def _collate_fn(self, batch):
        # run the user-supplied process_fn, then collate the resulting dict with torch's default collate
        return data.dataloader.default_collate(self.process_fn(batch))

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)