from copy import deepcopy
from dataclasses import dataclass
import lightning.pytorch as pl
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable
import os
import numpy as np

from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec

from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer
from ..model.spec import ModelInput

@dataclass
class DatasetConfig(ConfigSpec):
    '''
    Configuration for dataset loading: batching, shuffling, worker settings, and the nested datapath config.
    '''
    # shuffle dataset
    shuffle: bool

    # batch size
    batch_size: int

    # number of workers
    num_workers: int
    
    # datapath
    datapath_config: DatapathConfig
    
    # use pin memory
    pin_memory: bool = True
    
    # use persistent workers
    persistent_workers: bool = True
    
    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
        cls.check_keys(config)
        return DatasetConfig(
            shuffle=config.shuffle,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers,
            datapath_config=DatapathConfig.parse(config.datapath_config),
        )
    
    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
        '''Return one copy of this config per data class, each holding only that class's datapath config.'''
        res: Dict[str, DatasetConfig] = {}
        datapath_config_dict = self.datapath_config.split_by_cls()
        for cls in self.datapath_config.data_path:
            res[cls] = deepcopy(self)
            res[cls].datapath_config = datapath_config_dict[cls]
        return res
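
# Illustrative only: a DatasetConfig is normally parsed from an attribute-style config
# node (hypothetical `cfg` below) and then split per data class; the field layout of the
# nested datapath config is defined in datapath.py.
#
#   dataset_config = DatasetConfig.parse(cfg.dataset)
#   per_cls_config = dataset_config.split_by_cls()   # e.g. {'some_cls': DatasetConfig, ...}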

class UniRigDatasetModule(pl.LightningDataModule):
    '''
    LightningDataModule that builds per-class prediction datasets and dataloaders,
    either from per-class DatasetConfigs or from a single Datapath.
    '''
    def __init__(
        self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
        predict_transform_config: Union[TransformConfig, None]=None,
        tokenizer_config: Union[TokenizerConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
        datapath: Union[Datapath, None]=None,
        cls: Union[str, None]=None,
    ):
        super().__init__()
        self.process_fn                 = process_fn
        self.predict_dataset_config     = predict_dataset_config
        self.predict_transform_config   = predict_transform_config
        self.tokenizer_config           = tokenizer_config
        self.debug                      = debug
        self.data_name                  = data_name
        
        if debug:
            print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
        
        if datapath is not None:
            self.train_datapath = None
            self.validate_datapath = None
            self.predict_datapath = {
                cls: deepcopy(datapath),
            }
            self.predict_dataset_config = {
                cls: DatasetConfig(
                    shuffle=False,
                    batch_size=1,
                    num_workers=0,
                    datapath_config=deepcopy(datapath),
                    pin_memory=False,
                    persistent_workers=False,
                )
            }
        else:
            # build predict datapath
            if self.predict_dataset_config is not None:
                self.predict_datapath = {
                    cls: Datapath(self.predict_dataset_config[cls].datapath_config)
                    for cls in self.predict_dataset_config
                }
            else:
                self.predict_datapath = None
        
        # get tokenizer
        if tokenizer_config is None:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(config=tokenizer_config)

    def prepare_data(self):
        pass

    def setup(self, stage=None):   
        if self.predict_datapath is not None:
            self._predict_ds = {}
            for cls in self.predict_datapath:
                self._predict_ds[cls] = UniRigDataset(
                    process_fn=self.process_fn,
                    data=self.predict_datapath[cls].get_data(),
                    name=f"predict-{cls}",
                    tokenizer=self.tokenizer,
                    transform_config=self.predict_transform_config,
                    debug=self.debug,
                    data_name=self.data_name,
                )
    
    def predict_dataloader(self):
        if not hasattr(self, "_predict_ds"):
            self.setup()
        return self._create_dataloader(
            dataset=self._predict_ds,
            config=self.predict_dataset_config,
            is_train=False,
            drop_last=False,
        )

    def _create_dataloader(
        self,
        dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
        is_train: bool,
        **kwargs,
    ) -> Union[DataLoader, Dict[str, DataLoader]]:
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
            return DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=config.shuffle,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
                persistent_workers=config.persistent_workers,
                collate_fn=dataset.collate_fn,
                **kwargs,
            )
        if isinstance(dataset, dict):
            return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
        else:
            return create_single_dataloader(dataset, config, **kwargs)

class UniRigDataset(Dataset):
    '''
    Map-style dataset over (cls, directory) pairs. Each item loads `data_name`
    (default `raw_data.npz`) from its directory, applies the configured transforms,
    optionally tokenizes the asset, and returns a ModelInput.
    '''
    def __init__(
        self,
        data: List[Tuple[str, str]], # (cls, part)
        name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        tokenizer: Union[TokenizerSpec, None]=None,
        transform_config: Union[TransformConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
    ) -> None:
        super().__init__()
        
        self.data               = data
        self.name               = name
        self.process_fn         = process_fn
        self.tokenizer          = tokenizer
        self.transform_config   = transform_config
        self.debug              = debug
        self.data_name          = data_name
        
        if not debug:
            assert self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        cls, dir_path = self.data[idx]
        # load the preprocessed sample and wrap it into an Asset
        raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
        asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
        
        # apply the transforms specified by the transform config
        first_augments, second_augments = transform_asset(
            asset=asset,
            transform_config=self.transform_config,
        )
        # tokenize the asset only when a tokenizer is provided and the asset has a joint hierarchy
        if self.tokenizer is not None and asset.parents is not None:
            tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
        else:
            tokens = None
        return ModelInput(
            tokens=tokens,
            pad=None if self.tokenizer is None else self.tokenizer.pad,
            vertices=asset.sampled_vertices.astype(np.float32),
            normals=asset.sampled_normals.astype(np.float32),
            joints=None if asset.joints is None else asset.joints.astype(np.float32),
            tails=None if asset.tails is None else asset.tails.astype(np.float32),
            asset=asset,
            augments=None,
        )

    def _collate_fn_debug(self, batch):
        # debug mode: return the raw list of ModelInput without further processing
        return batch
    
    def _collate_fn(self, batch):
        # run the user-provided process_fn, then batch the result with the default collate
        return data.dataloader.default_collate(self.process_fn(batch))

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)
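
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not exercised by the library). It assumes a
    # `Datapath` built elsewhere from a DatapathConfig whose directories each contain
    # the preprocessed `raw_data.npz`; the placeholder below keeps this block a no-op
    # until a real datapath is supplied.
    example_datapath = None  # e.g. Datapath(DatapathConfig.parse(...)) -- hypothetical config
    if example_datapath is not None:
        dm = UniRigDatasetModule(
            datapath=example_datapath,
            cls='example_cls',  # hypothetical class name
            debug=True,         # debug collate returns raw ModelInput lists, no process_fn needed
        )
        dm.setup()
        # predict_dataloader() returns one DataLoader per class
        for cls_name, loader in dm.predict_dataloader().items():
            for batch in loader:
                print(cls_name, len(batch))
                break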