| import glob |
| import logging |
| import os |
| import sys |
| import time |
|
|
| from absl import app |
| import gin |
| from internal import configs |
| from internal import datasets |
| from internal import models |
| from internal import train_utils |
| from internal import checkpoints |
| from internal import utils |
| from internal import vis |
| from matplotlib import cm |
| import mediapy as media |
| import torch |
| import numpy as np |
| import accelerate |
| import imageio |
| from torch.utils._pytree import tree_map |
|
|
# Runs at import time, before app.run(main) is reached.
# NOTE(review): presumably registers the flags that configs.load_config()
# reads later -- confirm in internal/configs.
configs.define_common_flags()
|
|
|
|
def create_videos(config, base_dir, out_dir, out_name, num_frames):
    """Creates videos out of the images saved to disk.

    For each tag in ('color', 'normals', 'acc', 'distance_mean',
    'distance_median') whose first frame exists in `out_dir`, the frames
    `{tag}_{idx}.png|tiff` are read in order and encoded into
    `{base_dir}/{scene}_{exp}_{out_name}_{tag}.mp4`. Tags whose first frame
    is missing are skipped with a message.

    Args:
      config: config object; `exp_path`, `render_dist_percentile` and
        `render_video_fps` are read from it.
      base_dir: directory receiving the mp4 files (passed to utils.makedirs).
      out_dir: directory holding the per-frame image files.
      out_name: label embedded in every video file name.
      num_frames: number of frames to read for each tag.

    Raises:
      ValueError: if a frame file inside a sequence does not exist.
    """
    names = [n for n in config.exp_path.split('/') if n]
    # The last two non-empty components of exp_path name the video files.
    exp_name, scene_name = names[-2:]
    video_prefix = f'{scene_name}_{exp_name}_{out_name}'

    # Frame indices are zero-padded to at least three digits.
    zpad = max(3, len(str(num_frames - 1)))

    def idx_to_str(idx):
        return str(idx).zfill(zpad)

    utils.makedirs(base_dir)

    # The depth color-map range is taken from percentiles of the first
    # distance_mean frame and then reused for every distance frame.
    depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff')
    depth_frame = utils.load_img(depth_file)
    shape = depth_frame.shape
    p = config.render_dist_percentile
    lo, hi = np.percentile(depth_frame.flatten(), [p, 100 - p])

    def depth_curve_fn(x):
        # Negative log of the distance; eps keeps log(0) finite.
        return -np.log(x + np.finfo(np.float32).eps)

    print(f'Video shape is {shape[:2]}')

    image_tags = ('color', 'normals')
    for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']:
        video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
        is_image = k in image_tags
        file_ext = 'png' if is_image else 'tiff'
        file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')
        if not utils.file_exists(file0):
            print(f'Images missing for tag {k}')
            continue
        print(f'Making video {video_file}...')

        # Look the colormap up once per distance tag instead of once per frame.
        cmap = cm.get_cmap('turbo') if k.startswith('distance') else None

        writer = imageio.get_writer(video_file, fps=config.render_video_fps)
        try:
            for idx in range(num_frames):
                img_file = os.path.join(
                    out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
                if not utils.file_exists(img_file):
                    # The original built this exception without raising it.
                    raise ValueError(f'Image file {img_file} does not exist.')

                img = utils.load_img(img_file)
                if is_image:
                    # 8-bit images are rescaled to [0, 1].
                    img = img / 255.
                elif cmap is not None:
                    img = vis.visualize_cmap(
                        img, np.ones_like(img), cmap, lo, hi,
                        curve_fn=depth_curve_fn)

                frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(
                    np.uint8)
                writer.append_data(frame)
        finally:
            # Close the writer even when a frame fails part-way through.
            writer.close()
|
|
|
|
def main(unused_argv):
    """Render the 'test' split with a restored checkpoint and save the results.

    The main process writes color / normals / distance / acc images for each
    view under `<exp_path>/render/<out_name>`. Views whose color image is
    already on disk are skipped. When the number of `acc_*.tiff` files equals
    `dataset.size`, the main process calls `create_videos`.

    Args:
      unused_argv: positional command-line arguments; not used.
    """
    config = configs.load_config()
    config.exp_path = os.path.join('exp', config.exp_name)
    config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
    config.render_dir = os.path.join(config.exp_path, 'render')

    accelerator = accelerate.Accelerator()

    # Log to stdout and to log_render.txt inside the experiment directory.
    logging.basicConfig(
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(
                os.path.join(config.exp_path, 'log_render.txt')),
        ],
        level=logging.INFO,
    )
    sys.excepthook = utils.handle_exception
    logger = accelerate.logging.get_logger(__name__)
    logger.info(config)
    logger.info(accelerator.state, main_process_only=False)

    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    accelerate.utils.set_seed(config.seed, device_specific=True)
    model = models.Model(config=config)
    model.eval()

    dataset = datasets.load_dataset('test', config.data_dir, config)
    # Sequential, one-item batches: the loader yields items in index order.
    dataloader = torch.utils.data.DataLoader(
        np.arange(len(dataset)),
        shuffle=False,
        batch_size=1,
        collate_fn=dataset.collate_fn,
    )
    dataiter = iter(dataloader)
    if config.rawnerf_mode:
        postprocess_fn = dataset.metadata['postprocess_fn']
    else:
        postprocess_fn = lambda z: z

    model = accelerator.prepare(model)
    step = checkpoints.restore_checkpoint(
        config.checkpoint_dir, accelerator, logger)

    logger.info(f'Rendering checkpoint at step {step}.')

    out_name = 'path_renders' if config.render_path else 'test_preds'
    # NOTE(review): the trailing '2' after the step looks like a leftover; it
    # is kept because it determines the output directory name -- confirm.
    out_name = f'{out_name}_step_{step}2'
    out_dir = os.path.join(config.render_dir, out_name)
    utils.makedirs(out_dir)

    def path_fn(x):
        return os.path.join(out_dir, x)

    # Frame indices are zero-padded to at least three digits.
    zpad = max(3, len(str(dataset.size - 1)))

    def idx_to_str(idx):
        return str(idx).zfill(zpad)

    for idx in range(dataset.size):
        idx_str = idx_to_str(idx)
        # Pull the batch BEFORE the skip check. The loader is sequential, so
        # skipping without advancing it would leave the iterator behind `idx`
        # and later frames would be rendered from the wrong batch.
        batch = next(dataiter)
        curr_file = path_fn(f'color_{idx_str}.png')
        if utils.file_exists(curr_file):
            logger.info(
                f'Image {idx + 1}/{dataset.size} already exists, skipping')
            continue
        batch = tree_map(
            lambda x: x.to(accelerator.device) if x is not None else None,
            batch)
        logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
        eval_start_time = time.time()
        rendering = models.render_image(model, accelerator,
                                        batch, False, 1, config)

        logger.info(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

        # Only the main process writes files.
        if accelerator.is_main_process:
            rendering['rgb'] = postprocess_fn(rendering['rgb'])
            rendering = tree_map(
                lambda x: x.detach().cpu().numpy() if x is not None else None,
                rendering)
            utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx_str}.png'))
            if 'normals' in rendering:
                # Shift normals by /2 + 0.5 before saving as 8-bit.
                utils.save_img_u8(rendering['normals'] / 2. + 0.5,
                                  path_fn(f'normals_{idx_str}.png'))
            utils.save_img_f32(rendering['distance_mean'],
                               path_fn(f'distance_mean_{idx_str}.tiff'))
            utils.save_img_f32(rendering['distance_median'],
                               path_fn(f'distance_median_{idx_str}.tiff'))
            utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx_str}.tiff'))

    # Videos are only built once an acc image exists for every view.
    num_files = len(glob.glob(path_fn('acc_*.tiff')))
    if accelerator.is_main_process and num_files == dataset.size:
        logger.info('All files found, creating videos.')
        create_videos(config, config.render_dir, out_dir, out_name,
                      dataset.size)
    accelerator.wait_for_everyone()
    logger.info('Finish rendering.')
|
|
# Script entry point: hand main() to absl's app runner inside the gin
# configuration scope named 'eval'.
if __name__ == '__main__':
    with gin.config_scope('eval'):
        app.run(main)
|
|