from collections import defaultdict import lightning as L import os import torch import numpy as np from torch import Tensor from typing import Dict, Union, List from lightning.pytorch.callbacks import BasePredictionWriter from numpy import ndarray from ..data.raw_data import RawData from ..data.order import OrderConfig, get_order from ..model.spec import ModelSpec from ..tokenizer.spec import DetokenzeOutput class ARSystem(L.LightningModule): def __init__( self, steps_per_epoch: int, model: ModelSpec, generate_kwargs: Dict={}, output_path: Union[str, None]=None, record_res: Union[bool]=False, validate_cast: str='bfloat16', val_interval: Union[int, None]=None, val_start_from: Union[int, None]=None, ): super().__init__() self.save_hyperparameters(ignore="model") self.steps_per_epoch = steps_per_epoch self.model = model self.generate_kwargs = generate_kwargs self.output_path = output_path self.record_res = record_res self.validate_cast = validate_cast self.val_interval = val_interval self.val_start_from = val_start_from if self.record_res: assert self.output_path is not None, "record_res is True, but output_path in ar is None" def _predict_step(self, batch, batch_idx, dataloader_idx=None): batch['generate_kwargs'] = self.generate_kwargs res = self.model.predict_step(batch) assert isinstance(res, list), f"expect type of prediction from {self.model.__class__} to be a list, found: {type(res)}" return res def predict_step(self, batch, batch_idx, dataloader_idx=None): try: prediction: List[DetokenzeOutput] = self._predict_step(batch=batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) return prediction except Exception as e: print(str(e)) return [] class ARWriter(BasePredictionWriter): def __init__( self, output_dir: Union[str, None], order_config: Union[OrderConfig, None]=None, **kwargs ): super().__init__('batch') self.output_dir = output_dir self.npz_dir = kwargs.get('npz_dir', None) self.user_mode = kwargs.get('user_mode', False) self.output_name = kwargs.get('output_name', None) # for a single name self.repeat = kwargs.get('repeat', 1) self.add_num = kwargs.get('add_num', False) self.export_npz = kwargs.get('export_npz', None) self.export_obj = kwargs.get('export_obj', None) self.export_fbx = kwargs.get('export_fbx', None) self.export_pc = kwargs.get('export_pc', None) if order_config is not None: self.order = get_order(config=order_config) else: self.order = None self._epoch = 0 def on_predict_end(self, trainer, pl_module): if self._epoch < self.repeat - 1: print(f"Finished prediction run {self._epoch + 1}/{self.repeat}, starting next run...") self._epoch += 1 trainer.predict_dataloader = trainer.datamodule.predict_dataloader() trainer.predict_loop.run() def write_on_batch_end(self, trainer, pl_module: ARSystem, prediction: List[Dict], batch_indices, batch, batch_idx, dataloader_idx): assert 'path' in batch paths = batch['path'] detokenize_output_list: List[DetokenzeOutput] = prediction vertices = batch['vertices'] origin_vertices = batch['origin_vertices'] origin_vertex_normals = batch['origin_vertex_normals'] origin_faces = batch['origin_faces'] origin_face_normals = batch['origin_face_normals'] num_points = batch['num_points'] num_faces = batch['num_faces'] if isinstance(origin_vertices, torch.Tensor): origin_vertices = origin_vertices.detach().cpu().numpy() if isinstance(origin_vertex_normals, torch.Tensor): origin_vertex_normals = origin_vertex_normals.detach().cpu().numpy() if isinstance(origin_faces, torch.Tensor): origin_faces = origin_faces.detach().cpu().numpy() if isinstance(origin_face_normals, torch.Tensor): origin_face_normals = origin_face_normals.detach().cpu().numpy() if isinstance(num_points, torch.Tensor): num_points = num_points.detach().cpu().numpy() if isinstance(num_faces, torch.Tensor): num_faces = num_faces.detach().cpu().numpy() for (id, detokenize_output) in enumerate(detokenize_output_list): assert isinstance(detokenize_output, DetokenzeOutput), f"expect item of the list to be DetokenzeOutput, found: {type(detokenize_output)}" def make_path(save_name: str, suffix: str, trim: bool=False): if trim: path = os.path.relpath(paths[id], self.npz_dir) else: path = paths[id] if self.output_dir is not None: path = os.path.join(self.output_dir, path) if self.add_num: path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}") else: path = os.path.join(path, f"{save_name}.{suffix}") return path num_p = num_points[id] num_f = num_faces[id] raw_data = RawData( vertices=origin_vertices[id, :num_p], vertex_normals=origin_vertex_normals[id, :num_p], faces=origin_faces[id, :num_f], face_normals=origin_face_normals[id, :num_f], joints=detokenize_output.joints, tails=detokenize_output.tails, parents=detokenize_output.parents, skin=None, no_skin=detokenize_output.no_skin, names=detokenize_output.names, matrix_local=None, path=None, cls=detokenize_output.cls, ) if not self.user_mode and self.export_npz is not None: print(make_path(self.export_npz, 'npz')) raw_data.save(path=make_path(self.export_npz, 'npz')) if not self.user_mode and self.export_obj is not None: raw_data.export_skeleton(path=make_path(self.export_obj, 'obj')) if not self.user_mode and self.export_pc is not None: raw_data.export_pc(path=make_path(self.export_pc, 'obj')) if self.export_fbx is not None: if not self.user_mode: raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx')) else: if self.output_name is not None: raw_data.export_fbx(path=self.output_name) else: raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx', trim=True))