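"""Prediction entry point.

Loads a task config, assembles the dataset, tokenizer, model, and Lightning
trainer it describes, then runs inference and writes the results via the
configured writer callback.

Illustrative invocation (script name and config path are placeholders; point
--task at a real predict task YAML):

    python run.py --task configs/task/my_predict_task.yaml --input examples/model.glb --output results/model.fbx
"""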
import argparse
import os
import time
from math import ceil
from typing import List

import numpy as np
import torch
import yaml
import lightning as L
from box import Box
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from lightning.pytorch.strategies import FSDPStrategy, DDPStrategy
from tqdm import tqdm

from src.data.asset import Asset
from src.data.datapath import Datapath
from src.data.dataset import UniRigDatasetModule, DatasetConfig, ModelInput
from src.data.extract import get_files
from src.data.transform import TransformConfig
from src.inference.download import download
from src.model.parse import get_model
from src.system.parse import get_system, get_writer
from src.tokenizer.parse import get_tokenizer
from src.tokenizer.spec import TokenizerConfig
def load(task: str, path: str) -> Box:
    """Load a YAML config into a Box, appending '.yaml' to the path if missing."""
    if not path.endswith('.yaml'):
        path += '.yaml'
    print(f"\033[92mload {task} config: {path}\033[0m")
    with open(path, 'r') as f:
        return Box(yaml.safe_load(f))
def nullable_string(val):
    """argparse type converter: map the empty string to None."""
    if not val:
        return None
    return val
if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--seed", type=int, required=False, default=123,
                        help="random seed")
    parser.add_argument("--input", type=nullable_string, required=False, default=None,
                        help="a single input file, or multiple files separated by commas")
    parser.add_argument("--input_dir", type=nullable_string, required=False, default=None,
                        help="input directory")
    parser.add_argument("--output", type=nullable_string, required=False, default=None,
                        help="filename for a single output")
    parser.add_argument("--output_dir", type=nullable_string, required=False, default=None,
                        help="output directory")
    parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp',
                        help="intermediate npz directory")
    parser.add_argument("--cls", type=nullable_string, required=False, default=None,
                        help="class name")
    parser.add_argument("--data_name", type=nullable_string, required=False, default=None,
                        help="npz filename from the skeleton phase")
    args = parser.parse_args()
    L.seed_everything(args.seed, workers=True)

    task = load('task', args.task)
    mode = task.mode
    assert mode in ['predict']
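
    # The task YAML drives everything below. An illustrative sketch of the
    # keys this script reads (values are placeholders, not canonical):
    #
    #   mode: predict
    #   experiment_name: my_experiment
    #   resume_from_checkpoint: path/to/checkpoint
    #   components:
    #     data: my_data                # resolved under configs/data/
    #     transform: my_transform      # resolved under configs/transform/
    #     tokenizer: my_tokenizer      # optional, under configs/tokenizer/
    #     model: my_model              # optional, under configs/model/
    #     system: my_system            # optional, under configs/system/
    #     data_name: raw_data.npz
    #   checkpoint: {...}              # optional ModelCheckpoint kwargs
    #   writer: {...}                  # optional writer kwargs
    #   trainer: {...}                 # optional L.Trainer kwargs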
    # Resolve raw inputs (--input / --input_dir) into intermediate npz files
    # under --npz_dir and wrap them in a Datapath for the dataset module.
    if args.input is not None or args.input_dir is not None:
        assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified'
        assert args.npz_dir is not None, 'npz_dir must be specified'
        files = get_files(
            data_name=task.components.data_name,
            inputs=args.input,
            input_dataset_dir=args.input_dir,
            output_dataset_dir=args.npz_dir,
            force_override=True,
            warning=False,
        )
        files = [f[1] for f in files]
        if len(files) > 1 and args.output is not None:
            print("\033[92mwarning: --output is specified but multiple input files were detected; every result will be written to the same output name.\033[0m")
        datapath = Datapath(files=files, cls=args.cls)
    else:
        datapath = None
    data_config = load('data', os.path.join('configs/data', task.components.data))
    transform_config = load('transform', os.path.join('configs/transform', task.components.transform))

    # get tokenizer
    tokenizer_config = task.components.get('tokenizer', None)
    if tokenizer_config is not None:
        tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
        tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)

    # get data name
    data_name = task.components.get('data_name', 'raw_data.npz')
    if args.data_name is not None:
        data_name = args.data_name

    # get predict dataset
    predict_dataset_config = data_config.get('predict_dataset_config', None)
    if predict_dataset_config is not None:
        predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()

    # get predict transform
    predict_transform_config = transform_config.get('predict_transform_config', None)
    if predict_transform_config is not None:
        predict_transform_config = TransformConfig.parse(config=predict_transform_config)
    # get model
    model_config = task.components.get('model', None)
    if model_config is not None:
        model_config = load('model', os.path.join('configs/model', model_config))
        if tokenizer_config is not None:
            tokenizer = get_tokenizer(config=tokenizer_config)
        else:
            tokenizer = None
        model = get_model(tokenizer=tokenizer, **model_config)
    else:
        model = None
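
    # Note: the datamodule below receives the model's `_process_fn` (when a
    # model is configured) so that samples are collated the way the model
    # expects at predict time.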
    # set data
    data = UniRigDatasetModule(
        process_fn=None if model is None else model._process_fn,
        predict_dataset_config=predict_dataset_config,
        predict_transform_config=predict_transform_config,
        tokenizer_config=tokenizer_config,
        debug=False,
        data_name=data_name,
        datapath=datapath,
        cls=args.cls,
    )
    # add callbacks
    callbacks = []

    ## get checkpoint callback
    checkpoint_config = task.get('checkpoint', None)
    if checkpoint_config is not None:
        checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name)
        callbacks.append(ModelCheckpoint(**checkpoint_config))

    ## get writer callback
    writer_config = task.get('writer', None)
    if writer_config is not None:
        assert predict_transform_config is not None, 'missing predict_transform_config in transform'
        if args.output_dir is not None or args.output is not None:
            if args.output is not None:
                assert args.output.endswith('.fbx'), 'output must be .fbx'
            writer_config['npz_dir'] = args.npz_dir
            writer_config['output_dir'] = args.output_dir
            writer_config['output_name'] = args.output
            writer_config['user_mode'] = True
        callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
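
    # Predictions are not collected in memory (trainer.predict() below is
    # called with return_predictions=False); the writer callback is what
    # actually persists results to disk.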
    # get trainer
    trainer_config = task.get('trainer', {})

    # get system
    system_config = task.components.get('system', None)
    if system_config is not None:
        system_config = load('system', os.path.join('configs/system', system_config))
        system = get_system(
            **system_config,
            model=model,
            steps_per_epoch=1,
        )
    else:
        system = None

    logger = None

    # set ckpt path
    resume_from_checkpoint = task.get('resume_from_checkpoint', None)
    resume_from_checkpoint = download(resume_from_checkpoint)
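    # resume_from_checkpoint should now be a local path; download() presumably
    # fetches remote checkpoint references and passes local paths through.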
    trainer = L.Trainer(
        callbacks=callbacks,
        logger=logger,
        **trainer_config,
    )

    if mode == 'predict':
        assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task'
        trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
    else:
        raise ValueError(f"unsupported mode: {mode}")