| | import glob |
| | import logging |
| | import os |
| | import sys |
| | import time |
| |
|
| | from absl import app |
| | import gin |
| | from internal import configs |
| | from internal import datasets |
| | from internal import models |
| | from internal import train_utils |
| | from internal import checkpoints |
| | from internal import utils |
| | from internal import vis |
| | from matplotlib import cm |
| | import mediapy as media |
| | import torch |
| | import numpy as np |
| | import accelerate |
| | import imageio |
| | from torch.utils._pytree import tree_map |
| |
|
# Define the project's command-line flags (presumably the set shared with the
# other entry points -- see internal.configs) at import time, so they exist
# before app.run(main) below parses argv.
configs.define_common_flags()
| |
|
| |
|
def create_videos(config, base_dir, out_dir, out_name, num_frames):
    """Creates videos out of the images saved to disk.

    Reads the per-frame images found in ``out_dir`` and encodes one MP4 per
    tag ('color', 'normals', 'acc', 'distance_mean', 'distance_median') into
    ``base_dir``. A tag whose first frame is missing is skipped entirely.

    Args:
      config: experiment config; reads ``exp_path``,
        ``render_dist_percentile`` and ``render_video_fps``.
      base_dir: directory the videos are written to (created if needed).
      out_dir: directory holding the rendered frames.
      out_name: suffix appended to every video file name.
      num_frames: number of frames to read for each tag.

    Raises:
      ValueError: if, for a tag being encoded, one of its frames is missing.
    """
    # The last two non-empty path components of exp_path name the videos.
    names = [n for n in config.exp_path.split('/') if n]
    exp_name, scene_name = names[-2:]
    video_prefix = f'{scene_name}_{exp_name}_{out_name}'

    # Same zero-padding formula as the one used to name the frames.
    zpad = max(3, len(str(num_frames - 1)))
    idx_to_str = lambda idx: str(idx).zfill(zpad)

    utils.makedirs(base_dir)

    # The colormap range for all distance videos is fixed from percentiles of
    # the first mean-distance frame, so colors are comparable across frames.
    depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff')
    depth_frame = utils.load_img(depth_file)
    shape = depth_frame.shape
    p = config.render_dist_percentile
    lo, hi = np.percentile(depth_frame.flatten(), [p, 100 - p])
    # The small epsilon keeps log() finite for zero distances.
    depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps)
    print(f'Video shape is {shape[:2]}')

    # Look the colormap up once rather than per frame. cm.get_cmap is gone in
    # newer Matplotlib releases; fall back to the colormap registry there.
    try:
        cmap = cm.get_cmap('turbo')
    except AttributeError:
        import matplotlib
        cmap = matplotlib.colormaps['turbo']

    for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']:
        video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
        file_ext = 'png' if k in ['color', 'normals'] else 'tiff'
        file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')
        if not utils.file_exists(file0):
            print(f'Images missing for tag {k}')
            continue
        print(f'Making video {video_file}...')

        # The context manager closes (finalizes) the video even if a frame
        # below fails to load.
        with imageio.get_writer(video_file, fps=config.render_video_fps) as writer:
            for idx in range(num_frames):
                img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
                if not utils.file_exists(img_file):
                    # This error used to be constructed but never raised.
                    raise ValueError(f'Image file {img_file} does not exist.')

                img = utils.load_img(img_file)
                if k in ['color', 'normals']:
                    # 8-bit PNGs -> [0, 1].
                    img = img / 255.
                elif k.startswith('distance'):
                    img = vis.visualize_cmap(img, np.ones_like(img), cmap, lo, hi,
                                             curve_fn=depth_curve_fn)

                frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
                writer.append_data(frame)
| |
|
| |
|
def _save_rendering(rendering, path_fn, idx_str):
    """Writes one frame's outputs: color, optional normals, distances and acc.

    Args:
      rendering: dict of numpy arrays (already moved to CPU) holding 'rgb',
        'distance_mean', 'distance_median', 'acc' and optionally 'normals'.
      path_fn: maps a bare file name to its path in the output directory.
      idx_str: zero-padded frame index used in every file name.
    """
    utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx_str}.png'))
    if 'normals' in rendering:
        # Map [-1, 1] to [0, 1] before 8-bit quantization.
        utils.save_img_u8(rendering['normals'] / 2. + 0.5,
                          path_fn(f'normals_{idx_str}.png'))
    utils.save_img_f32(rendering['distance_mean'],
                       path_fn(f'distance_mean_{idx_str}.tiff'))
    utils.save_img_f32(rendering['distance_median'],
                       path_fn(f'distance_median_{idx_str}.tiff'))
    # 'acc' is written last; main counts these files to detect completion.
    utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx_str}.tiff'))


def main(unused_argv):
    """Renders every frame of the 'test' split with the restored checkpoint.

    Frames go to ``<exp_path>/render/<path_renders|test_preds>_step_<step>/``.
    Frames whose color image already exists are skipped, so an interrupted run
    can be resumed. Once an ``acc`` image exists for every frame, the main
    process also encodes videos with ``create_videos``.

    Args:
      unused_argv: leftover positional command-line arguments (ignored).
    """
    config = configs.load_config()
    config.exp_path = os.path.join('exp', config.exp_name)
    config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
    config.render_dir = os.path.join(config.exp_path, 'render')

    accelerator = accelerate.Accelerator()

    # logging.FileHandler does not create missing parent directories.
    os.makedirs(config.exp_path, exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
        handlers=[logging.StreamHandler(sys.stdout),
                  logging.FileHandler(os.path.join(config.exp_path, 'log_render.txt'))],
        level=logging.INFO,
    )
    sys.excepthook = utils.handle_exception
    logger = accelerate.logging.get_logger(__name__)
    logger.info(config)
    logger.info(accelerator.state, main_process_only=False)

    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    accelerate.utils.set_seed(config.seed, device_specific=True)
    model = models.Model(config=config)
    model.eval()

    dataset = datasets.load_dataset('test', config.data_dir, config)
    # Frames are served one at a time, in index order (no shuffling).
    dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
                                             shuffle=False,
                                             batch_size=1,
                                             collate_fn=dataset.collate_fn)
    dataiter = iter(dataloader)
    if config.rawnerf_mode:
        postprocess_fn = dataset.metadata['postprocess_fn']
    else:
        postprocess_fn = lambda z: z

    model = accelerator.prepare(model)
    step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)

    logger.info(f'Rendering checkpoint at step {step}.')

    out_name = 'path_renders' if config.render_path else 'test_preds'
    # A stray literal '2' used to follow the step here (step 25000 became
    # 'step_250002'), which misreported the checkpoint in the directory name.
    out_name = f'{out_name}_step_{step}'
    out_dir = os.path.join(config.render_dir, out_name)
    utils.makedirs(out_dir)

    path_fn = lambda x: os.path.join(out_dir, x)

    # Zero-pad frame indices so file names sort in frame order; create_videos
    # uses the same formula to find them again.
    zpad = max(3, len(str(dataset.size - 1)))
    idx_to_str = lambda idx: str(idx).zfill(zpad)

    for idx in range(dataset.size):
        # Consume the batch BEFORE the skip check: the loader yields frames in
        # order, so skipping without advancing it would pair every later index
        # with an earlier frame's batch when resuming a partial render.
        batch = next(dataiter)
        idx_str = idx_to_str(idx)
        curr_file = path_fn(f'color_{idx_str}.png')
        if utils.file_exists(curr_file):
            logger.info(f'Image {idx + 1}/{dataset.size} already exists, skipping')
            continue
        batch = tree_map(lambda x: x.to(accelerator.device) if x is not None else None, batch)
        logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
        eval_start_time = time.time()
        rendering = models.render_image(model, accelerator, batch, False, 1, config)
        logger.info(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

        # Only the main process converts and writes the outputs.
        if accelerator.is_main_process:
            rendering['rgb'] = postprocess_fn(rendering['rgb'])
            rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None,
                                 rendering)
            _save_rendering(rendering, path_fn, idx_str)

    if accelerator.is_main_process:
        # Escape out_dir so glob metacharacters in the path match literally.
        acc_pattern = os.path.join(glob.escape(out_dir), 'acc_*.tiff')
        if len(glob.glob(acc_pattern)) == dataset.size:
            logger.info('All files found, creating videos.')
            create_videos(config, config.render_dir, out_dir, out_name, dataset.size)
    accelerator.wait_for_everyone()
    logger.info('Finish rendering.')
| |
|
if __name__ == '__main__':
    # Run inside gin's 'eval' scope so any 'eval/...'-scoped bindings in the
    # loaded gin config take effect; app.run parses the absl flags and then
    # calls main.
    with gin.config_scope('eval'):
        app.run(main)
| |
|