├── .gitignore ├── LICENSE ├── README.md ├── configs ├── meetroom.json ├── meetroom_full.json ├── meetroom_init.json ├── n3dv.json ├── n3dv_full.json ├── n3dv_full_hightv.json └── n3dv_init.json ├── opt.py ├── prepare_dataset.py ├── render_delta.py ├── train_video_n3dv_base.py ├── train_video_n3dv_full.py ├── train_video_n3dv_pilot.py └── util ├── __init__.py ├── co3d_dataset.py ├── config_util.py ├── dataset.py ├── dataset_base.py ├── llff_dataset.py ├── load_llff.py ├── nerf_dataset.py ├── nsvf_dataset.py └── util.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlgoHunt/StreamRF/be8d2b800ce5ee9aa2e46f2efef4c9959478efcc/.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, the Plenoxels authors 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Streaming Radiance Fields for 3D Video Synthesis 2 | 3 | Lingzhi Li, Zhen Shen, Zhongshu Wang, Li Shen, Ping Tan 4 | 5 | Alibaba Group 6 | 7 | Citation: 8 | ``` 9 | 10 | @article{li2022streaming, 11 | title={Streaming radiance fields for 3d video synthesis}, 12 | author={Li, Lingzhi and Shen, Zhen and Wang, Zhongshu and Shen, Li and Tan, Ping}, 13 | journal={Advances in Neural Information Processing Systems}, 14 | volume={35}, 15 | pages={13485--13498}, 16 | year={2022} 17 | } 18 | 19 | ``` 20 | 21 | arXiv: 22 | 23 | 24 | 25 | https://user-images.githubusercontent.com/28325733/210695784-a309dce8-533b-4c93-b637-da369e2a288e.mp4 26 | 27 | Due to size limit, this is a downsampled video, check full resolution video [here](https://github.com/AlgoHunt/VideoHolder/releases/download/StreamRF/StreamRF-Camera.Ready.Video.mp4). 
## Dataset
**Meet Room Dataset**: [ModelScope](https://www.modelscope.cn/datasets/DAMOXR/dynamic_nerf_meeting_room_dataset/summary), [Google Drive](https://drive.google.com/drive/folders/1lNmQ6_ykyKjT6UKy-SnqWoSlI5yjh3l_?usp=share_link)

> We will keep adding more data on ModelScope, so stay tuned.

**N3DV Dataset**:
https://github.com/facebookresearch/Neural_3D_Video

## Training StreamRF

Follow the [setup](https://github.com/sxyu/svox2#setup) of the original Plenoxels repository to install the `svox2` extension.

For each scene, extract the frames from every video and arrange them into the structure below (frames are written to `<video_dir>/frames` by default; use `-o` to choose a different output directory):

```bash
python prepare_dataset.py <video_dir>
```

```
├── 0000
|   ├── poses_bounds.npy
|   └── images
|       └── cam[00/01/02/.../20].png
...
└── 0299
    ├── poses_bounds.npy
    └── images
        └── cam[00/01/02/.../20].png
```

We provide the poses_bounds.npy files for both datasets in the Meet Room Dataset's [link](https://drive.google.com/drive/folders/1lNmQ6_ykyKjT6UKy-SnqWoSlI5yjh3l_?usp=share_link). If you want to generate poses_bounds.npy yourself, check DS-NeRF's [repo](https://github.com/dunbar12138/DSNeRF#generate-camera-poses-and-sparse-depth-information-using-colmap-optional).
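Before training, it can help to sanity-check the prepared frame folders. The snippet below is a minimal sketch and is not part of the repository; the helper name `check_frames`, the default frame count of 300, and the 21-camera assumption are illustrative and should be adapted to your scene:

```python
import os

def check_frames(root, n_frames=300, n_cams=21):
    """Verify the per-frame layout written by prepare_dataset.py."""
    missing = []
    for i in range(n_frames):
        frame_dir = os.path.join(root, f'{i:04d}')
        # every frame folder needs its camera poses ...
        if not os.path.isfile(os.path.join(frame_dir, 'poses_bounds.npy')):
            missing.append(os.path.join(frame_dir, 'poses_bounds.npy'))
        # ... and one extracted image per camera
        for cam in range(n_cams):
            img = os.path.join(frame_dir, 'images', f'cam{cam:02d}.png')
            if not os.path.isfile(img):
                missing.append(img)
    if missing:
        print(f'{len(missing)} files missing, e.g. {missing[0]}')
    else:
        print('dataset layout looks complete')

# check_frames('<data_dir>')  # e.g. the frames/ folder produced by prepare_dataset.py
```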
In the commands below, `<data_dir>` is the per-frame dataset root prepared above, `<ckpt_dir>` is the checkpoint and logging directory, and `<first_frame_ckpt>` is the `ckpt.npz` saved by the first-frame initialization step.

### Meet Room Dataset

1. Initialize the first frame model

```bash
python opt.py -t <ckpt_dir> <data_dir>/0000 -c configs/meetroom_init.json --scale 1.0
```

2. Train the pilot model

```bash
python train_video_n3dv_pilot.py <data_dir> -t <ckpt_dir> -c configs/meetroom.json --batch_size 20000 --pretrained <first_frame_ckpt> --n_iters 1000 --lr_sigma 0.3 --lr_sigma_final 0.3 --lr_sh 1e-2 --lr_sh_final 1e-4 --lr_sigma_decay_steps 1000 --lr_sh_decay_steps 1000 --frame_end 300 --fps 30 --train_use_all 0 --scale 1.0 --sh_keep_thres 1.0 --sh_prune_thres 0.1 --performance_mode --dilate_rate_before 1 --dilate_rate_after 1 --stop_thres 0.01 --compress_saving --save_delta --pilot_factor 2
```

3. Train the full model

```bash
python train_video_n3dv_full.py <data_dir> -t <ckpt_dir> -c configs/meetroom_full.json --batch_size 20000 --pretrained <first_frame_ckpt> --n_iters 500 --lr_sigma 1.0 --lr_sigma_final 1.0 --lr_sh 1e-2 --lr_sh_final 1e-2 --lr_sigma_decay_steps 500 --lr_sh_decay_steps 500 --frame_end 300 --fps 30 --train_use_all 0 --scale 1.0 --sh_keep_thres 1.5 --sh_prune_thres 0.3 --performance_mode --dilate_rate_before 2 --dilate_rate_after 2 --compress_saving --save_delta --apply_narrow_band
```

### N3DV Dataset

1. Initialize the first frame model

```bash
python opt.py -t <ckpt_dir> <data_dir>/0000 -c configs/n3dv_init.json --offset 500 --scale 0.5 --nosphereinit
```

2. Train the pilot model

```bash
python train_video_n3dv_pilot.py <data_dir> -t <ckpt_dir> -c configs/n3dv.json --batch_size 20000 --pretrained <first_frame_ckpt> --n_iters 750 --lr_sigma 1.0 --lr_sigma_final 1.0 --lr_sh 1e-2 --lr_sh_final 1e-3 --lr_sigma_decay_steps 750 --lr_sh_decay_steps 750 --frame_end 300 --fps 30 --train_use_all 0 --offset 750 --scale 0.5 --sh_keep_thres 0.5 --sh_prune_thres 0.1 --performance_mode --dilate_rate_before 1 --dilate_rate_after 1 --stop_thres 0.01 --compress_saving --save_delta --pilot_factor 2
```

3. Train the full model

```bash
python train_video_n3dv_full.py <data_dir> -t <ckpt_dir> -c configs/n3dv_full.json --batch_size 20000 --pretrained <first_frame_ckpt> --n_iters 500 --lr_sigma 1.0 --lr_sigma_final 1.0 --lr_sh 1e-2 --lr_sh_final 3e-3 --lr_sigma_decay_steps 500 --lr_sh_decay_steps 300 --frame_end 300 --fps 30 --train_use_all 0 --offset 1500 --scale 0.5 --sh_keep_thres 1.0 --sh_prune_thres 0.2 --performance_mode --dilate_rate_before 2 --dilate_rate_after 2 --stop_thres 0.01 --compress_saving --save_delta --apply_narrow_band
```

You can trade a little PSNR for better perceptual video quality by training the full model with a higher TV loss weight:

```bash
python train_video_n3dv_full.py <data_dir> -t <ckpt_dir> -c configs/n3dv_full_hightv.json --batch_size 20000 --pretrained <first_frame_ckpt> --n_iters 500 --lr_sigma 1.0 --lr_sigma_final 1.0 --lr_sh 1e-2 --lr_sh_final 3e-3 --lr_sigma_decay_steps 500 --lr_sh_decay_steps 300 --frame_end 300 --fps 30 --train_use_all 0 --offset 1500 --scale 0.5 --sh_keep_thres 1.0 --sh_prune_thres 0.2 --performance_mode --dilate_rate_before 2 --dilate_rate_after 2 --stop_thres 0.01 --compress_saving --save_delta --apply_narrow_band
```

## Testing StreamRF

Here `-t` should point to the training directory of the full model (it contains the saved per-frame deltas), and `--pretrained` to the first-frame checkpoint.

For the Meet Room Dataset:
```bash
python render_delta.py <data_dir> -t <ckpt_dir> -c configs/meetroom_full.json --batch_size 20000 --pretrained <first_frame_ckpt> --frame_end 300 --fps 30 --scale 1.0 --performance_mode
```

For the N3DV Dataset:
```bash
python render_delta.py <data_dir> -t <ckpt_dir> -c configs/n3dv_full.json --batch_size 20000 --pretrained <first_frame_ckpt> --frame_end 300 --fps 30 --scale 0.5 --performance_mode
```
--------------------------------------------------------------------------------
/configs/meetroom.json:
--------------------------------------------------------------------------------
1 | { 2 | "dataset_type": "llff", 3 | "thresh_type": "sigma", 4 | "density_thresh": 5, 5 | "lambda_tv": 5e-3, 6 | "lambda_tv_sh": 5e-2, 7 | "lambda_sparsity": 1e-12, 8 | "background_brightness": 0.5, 9 | "last_sample_opaque": false, 10 | "tv_early_only": 0, 11 | "llffhold": 100, 12 | "eval_every": 1, 13 | "print_every": 1, 14 | "lr_sigma_delay_steps":0, 15 | "lr_sh_delay_steps":0 16 | 17 | } 18 |
--------------------------------------------------------------------------------
/configs/meetroom_full.json:
--------------------------------------------------------------------------------
1 | { 2 | "dataset_type": "llff", 3 | "thresh_type": "sigma", 4 | "density_thresh": 5, 5 | "lambda_tv": 5e-3, 6 | "lambda_tv_sh": 5e-2, 7 | "lambda_sparsity": 1e-12, 8 | "background_brightness": 0.5, 9 | "last_sample_opaque": false, 10 | "tv_early_only": 0, 11 | "llffhold": 100, 12 | "eval_every": 1, 13 | "print_every": 1, 14 | "lr_sigma_delay_steps":0, 15 | "lr_sh_delay_steps":0 16 | 17 | } 18 |
--------------------------------------------------------------------------------
/configs/meetroom_init.json:
--------------------------------------------------------------------------------
1 | { 2 | "reso": "[[256, 256, 128], [512, 512, 128],[768,768,128]]", 3 | "upsamp_every": 38400, 4 | "lr_sigma": 3e1, 5 | "lr_sh": 1e-2, 6 | "thresh_type": "sigma", 7 | "density_thresh": 5, 8 | "lambda_tv": 5e-3, 9 | "lambda_tv_sh": 5e-2, 10 | "lambda_sparsity": 1e-12, 11 | "background_brightness": 0.5, 12 | "last_sample_opaque": false, 13 | "tv_early_only": 0, 14 | "llffhold": 100 15 | } 16 |
--------------------------------------------------------------------------------
/configs/n3dv.json:
--------------------------------------------------------------------------------
1 |
{ 2 | "dataset_type": "llff", 3 | "thresh_type": "sigma", 4 | "density_thresh": 5, 5 | "lambda_tv": 5e-3, 6 | "lambda_tv_sh": 5e-2, 7 | "lambda_sparsity": 1e-12, 8 | "background_brightness": 0.5, 9 | "last_sample_opaque": false, 10 | "tv_early_only": 0, 11 | "llffhold": 100, 12 | "eval_every": 1, 13 | "print_every": 1, 14 | "lr_sigma_delay_steps":0, 15 | "lr_sh_delay_steps":0 16 | 17 | } 18 | -------------------------------------------------------------------------------- /configs/n3dv_full.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "llff", 3 | "thresh_type": "sigma", 4 | "density_thresh": 5, 5 | "lambda_tv": 5e-4, 6 | "lambda_tv_sh": 5e-3, 7 | "lambda_sparsity": 1e-12, 8 | "background_brightness": 0.5, 9 | "last_sample_opaque": false, 10 | "tv_early_only": 0, 11 | "llffhold": 100, 12 | "eval_every": 1, 13 | "print_every": 1, 14 | "lr_sigma_delay_steps":0, 15 | "lr_sh_delay_steps":0 16 | 17 | } 18 | -------------------------------------------------------------------------------- /configs/n3dv_full_hightv.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "llff", 3 | "thresh_type": "sigma", 4 | "density_thresh": 5, 5 | "lambda_tv": 5e-3, 6 | "lambda_tv_sh": 5e-2, 7 | "lambda_sparsity": 1e-12, 8 | "background_brightness": 0.5, 9 | "last_sample_opaque": false, 10 | "tv_early_only": 0, 11 | "llffhold": 100, 12 | "eval_every": 1, 13 | "print_every": 1, 14 | "lr_sigma_delay_steps":0, 15 | "lr_sh_delay_steps":0 16 | 17 | } 18 | -------------------------------------------------------------------------------- /configs/n3dv_init.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[128,128,128], [256,256,128], [512,512,128]]", 3 | "upsamp_every": 38400, 4 | "lr_sigma": 3e1, 5 | "lr_sh": 1e-2, 6 | "thresh_type": "sigma", 7 | "density_thresh": 5, 8 | "lambda_tv": 5e-4, 9 | "lambda_tv_sh": 5e-3, 10 | "lambda_sparsity": 1e-12, 11 | "background_brightness": 0.5, 12 | "last_sample_opaque": false, 13 | "tv_early_only": 0, 14 | "llffhold": 100 15 | } 16 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Alex Yu 2 | 3 | # First, install svox2 4 | # Then, python opt.py /nerf_synthetic/ -t ckpt/ 5 | # or use launching script: sh launch.sh 6 | import torch 7 | import torch.cuda 8 | import torch.optim 9 | import torch.nn.functional as F 10 | import svox2 11 | import json 12 | import imageio 13 | import os 14 | from os import path 15 | import shutil 16 | import gc 17 | import numpy as np 18 | import math 19 | import argparse 20 | import cv2 21 | from util.dataset import datasets 22 | from util.util import Timing, get_expon_lr_func, generate_dirs_equirect, viridis_cmap 23 | from util import config_util 24 | 25 | from warnings import warn 26 | from datetime import datetime 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | from tqdm import tqdm 30 | from typing import NamedTuple, Optional, Union 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | 34 | parser = argparse.ArgumentParser() 35 | config_util.define_common_args(parser) 36 | 37 | 38 | group = parser.add_argument_group("general") 39 | group.add_argument('--train_dir', '-t', type=str, default='ckpt', 40 | help='checkpoint and logging directory') 41 | 42 | group.add_argument('--reso', 43 | type=str, 44 | 
default= 45 | "[[256, 256, 256], [512, 512, 512]]", 46 | help='List of grid resolution (will be evaled as json);' 47 | 'resamples to the next one every upsamp_every iters, then ' + 48 | 'stays at the last one; ' + 49 | 'should be a list where each item is a list of 3 ints or an int') 50 | group.add_argument('--upsamp_every', type=int, default= 51 | 3 * 12800, 52 | help='upsample the grid every x iters') 53 | group.add_argument('--init_iters', type=int, default= 54 | 0, 55 | help='do not upsample for first x iters') 56 | group.add_argument('--upsample_density_add', type=float, default= 57 | 0.0, 58 | help='add the remaining density by this amount when upsampling') 59 | 60 | group.add_argument('--basis_type', 61 | choices=['sh', '3d_texture', 'mlp'], 62 | default='sh', 63 | help='Basis function type') 64 | 65 | group.add_argument('--basis_reso', type=int, default=32, 66 | help='basis grid resolution (only for learned texture)') 67 | group.add_argument('--sh_dim', type=int, default=9, help='SH/learned basis dimensions (at most 10)') 68 | 69 | group.add_argument('--mlp_posenc_size', type=int, default=4, help='Positional encoding size if using MLP basis; 0 to disable') 70 | group.add_argument('--mlp_width', type=int, default=32, help='MLP width if using MLP basis') 71 | 72 | group.add_argument('--background_nlayers', type=int, default=0,#32, 73 | help='Number of background layers (0=disable BG model)') 74 | group.add_argument('--background_reso', type=int, default=512, help='Background resolution') 75 | 76 | 77 | 78 | group = parser.add_argument_group("optimization") 79 | group.add_argument('--n_iters', type=int, default=10 * 12800, help='total number of iters to optimize for') 80 | group.add_argument('--batch_size', type=int, default= 81 | 5000, 82 | #100000, 83 | # 2000, 84 | help='batch size') 85 | 86 | 87 | # TODO: make the lr higher near the end 88 | group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer") 89 | group.add_argument('--lr_sigma', type=float, default=3e1, help='SGD/rmsprop lr for sigma') 90 | group.add_argument('--lr_sigma_final', type=float, default=5e-2) 91 | group.add_argument('--lr_sigma_decay_steps', type=int, default=250000) 92 | group.add_argument('--lr_sigma_delay_steps', type=int, default=15000, 93 | help="Reverse cosine steps (0 means disable)") 94 | group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)#1e-4)#1e-4) 95 | 96 | 97 | group.add_argument('--sh_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="SH optimizer") 98 | group.add_argument('--lr_sh', type=float, default= 99 | 1e-2, 100 | help='SGD/rmsprop lr for SH') 101 | group.add_argument('--lr_sh_final', type=float, 102 | default= 103 | 5e-6 104 | ) 105 | group.add_argument('--lr_sh_decay_steps', type=int, default=250000) 106 | group.add_argument('--lr_sh_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)") 107 | group.add_argument('--lr_sh_delay_mult', type=float, default=1e-2) 108 | 109 | group.add_argument('--lr_fg_begin_step', type=int, default=0, help="Foreground begins training at given step number") 110 | 111 | # BG LRs 112 | group.add_argument('--bg_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Background optimizer") 113 | group.add_argument('--lr_sigma_bg', type=float, default=3e0, 114 | help='SGD/rmsprop lr for background') 115 | group.add_argument('--lr_sigma_bg_final', type=float, default=3e-3, 116 | help='SGD/rmsprop lr for background') 117 | 
group.add_argument('--lr_sigma_bg_decay_steps', type=int, default=250000) 118 | group.add_argument('--lr_sigma_bg_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)") 119 | group.add_argument('--lr_sigma_bg_delay_mult', type=float, default=1e-2) 120 | 121 | group.add_argument('--lr_color_bg', type=float, default=1e-1, 122 | help='SGD/rmsprop lr for background') 123 | group.add_argument('--lr_color_bg_final', type=float, default=5e-6,#1e-4, 124 | help='SGD/rmsprop lr for background') 125 | group.add_argument('--lr_color_bg_decay_steps', type=int, default=250000) 126 | group.add_argument('--lr_color_bg_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)") 127 | group.add_argument('--lr_color_bg_delay_mult', type=float, default=1e-2) 128 | # END BG LRs 129 | 130 | group.add_argument('--basis_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Learned basis optimizer") 131 | group.add_argument('--lr_basis', type=float, default=#2e6, 132 | 1e-6, 133 | help='SGD/rmsprop lr for SH') 134 | group.add_argument('--lr_basis_final', type=float, 135 | default= 136 | 1e-6 137 | ) 138 | group.add_argument('--lr_basis_decay_steps', type=int, default=250000) 139 | group.add_argument('--lr_basis_delay_steps', type=int, default=0,#15000, 140 | help="Reverse cosine steps (0 means disable)") 141 | group.add_argument('--lr_basis_begin_step', type=int, default=0)#4 * 12800) 142 | group.add_argument('--lr_basis_delay_mult', type=float, default=1e-2) 143 | 144 | group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor") 145 | 146 | group.add_argument('--print_every', type=int, default=20, help='print every') 147 | group.add_argument('--save_every', type=int, default=5, 148 | help='save every x epochs') 149 | group.add_argument('--eval_every', type=int, default=1, 150 | help='evaluate every x epochs') 151 | 152 | group.add_argument('--init_sigma', type=float, 153 | default=0.1, 154 | help='initialization sigma') 155 | group.add_argument('--init_sigma_bg', type=float, 156 | default=0.1, 157 | help='initialization sigma (for BG)') 158 | 159 | # Extra logging 160 | group.add_argument('--log_mse_image', action='store_true', default=False) 161 | group.add_argument('--log_depth_map', action='store_true', default=False) 162 | group.add_argument('--log_depth_map_use_thresh', type=float, default=None, 163 | help="If specified, uses the Dex-neRF version of depth with given thresh; else returns expected term") 164 | 165 | 166 | group = parser.add_argument_group("misc experiments") 167 | group.add_argument('--thresh_type', 168 | choices=["weight", "sigma"], 169 | default="weight", 170 | help='Upsample threshold type') 171 | group.add_argument('--weight_thresh', type=float, 172 | default=0.0005 * 512, 173 | # default=0.025 * 512, 174 | help='Upsample weight threshold; will be divided by resulting z-resolution') 175 | group.add_argument('--density_thresh', type=float, 176 | default=5.0, 177 | help='Upsample sigma threshold') 178 | group.add_argument('--background_density_thresh', type=float, 179 | default=1.0+1e-9, 180 | help='Background sigma threshold for sparsification') 181 | group.add_argument('--max_grid_elements', type=int, 182 | default=44_000_000, 183 | help='Max items to store after upsampling ' 184 | '(the number here is given for 22GB memory)') 185 | 186 | group.add_argument('--tune_mode', action='store_true', default=False, 187 | help='hypertuning mode (do not save, for speed)') 188 | 
group.add_argument('--tune_nosave', action='store_true', default=False, 189 | help='do not save any checkpoint even at the end') 190 | 191 | 192 | 193 | group = parser.add_argument_group("losses") 194 | # Foreground TV 195 | group.add_argument('--lambda_tv', type=float, default=1e-5) 196 | group.add_argument('--tv_sparsity', type=float, default=0.01) 197 | group.add_argument('--tv_logalpha', action='store_true', default=False, 198 | help='Use log(1-exp(-delta * sigma)) as in neural volumes') 199 | 200 | group.add_argument('--lambda_tv_sh', type=float, default=1e-3) 201 | group.add_argument('--tv_sh_sparsity', type=float, default=0.01) 202 | 203 | group.add_argument('--lambda_tv_lumisphere', type=float, default=0.0)#1e-2)#1e-3) 204 | group.add_argument('--tv_lumisphere_sparsity', type=float, default=0.01) 205 | group.add_argument('--tv_lumisphere_dir_factor', type=float, default=0.0) 206 | 207 | group.add_argument('--tv_decay', type=float, default=1.0) 208 | 209 | group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4) 210 | group.add_argument('--tv_early_only', type=int, default=1, help="Turn off TV regularization after the first split/prune") 211 | 212 | group.add_argument('--tv_contiguous', type=int, default=1, 213 | help="Apply TV only on contiguous link chunks, which is faster") 214 | # End Foreground TV 215 | 216 | group.add_argument('--lambda_sparsity', type=float, default= 217 | 0.0, 218 | help="Weight for sparsity loss as in SNeRG/PlenOctrees " + 219 | "(but applied on the ray)") 220 | group.add_argument('--lambda_beta', type=float, default= 221 | 0.0, 222 | help="Weight for beta distribution sparsity loss as in neural volumes") 223 | 224 | 225 | # Background TV 226 | group.add_argument('--lambda_tv_background_sigma', type=float, default=1e-2) 227 | group.add_argument('--lambda_tv_background_color', type=float, default=1e-2) 228 | 229 | group.add_argument('--tv_background_sparsity', type=float, default=0.01) 230 | # End Background TV 231 | 232 | # Basis TV 233 | group.add_argument('--lambda_tv_basis', type=float, default=0.0, 234 | help='Learned basis total variation loss') 235 | # End Basis TV 236 | 237 | group.add_argument('--weight_decay_sigma', type=float, default=1.0) 238 | group.add_argument('--weight_decay_sh', type=float, default=1.0) 239 | 240 | group.add_argument('--lr_decay', action='store_true', default=True) 241 | 242 | group.add_argument('--n_train', type=int, default=None, help='Number of training images. 
Defaults to use all avaiable.') 243 | 244 | group.add_argument('--nosphereinit', action='store_true', default=False, 245 | help='do not start with sphere bounds (please do not use for 360)') 246 | 247 | group.add_argument('--offset', type=int, default=250) 248 | 249 | args = parser.parse_args() 250 | config_util.maybe_merge_config_file(args) 251 | 252 | assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final" 253 | assert args.lr_sh_final <= args.lr_sh, "lr_sh must be >= lr_sh_final" 254 | assert args.lr_basis_final <= args.lr_basis, "lr_basis must be >= lr_basis_final" 255 | 256 | os.makedirs(args.train_dir, exist_ok=True) 257 | summary_writer = SummaryWriter(args.train_dir) 258 | 259 | reso_list = json.loads(args.reso) 260 | reso_id = 0 261 | 262 | with open(path.join(args.train_dir, 'args.json'), 'w') as f: 263 | json.dump(args.__dict__, f, indent=2) 264 | # Changed name to prevent errors 265 | shutil.copyfile(__file__, path.join(args.train_dir, 'opt_frozen.py')) 266 | 267 | torch.manual_seed(20200823) 268 | np.random.seed(20200823) 269 | 270 | factor = 1 271 | def deploy_dset(dset): 272 | dset.c2w = torch.from_numpy(dset.c2w) 273 | dset.gt = torch.from_numpy(dset.gt).float() 274 | if not dset.is_train_split: 275 | dset.render_c2w = torch.from_numpy(dset.render_c2w) 276 | else: 277 | dset.gen_rays() 278 | return dset 279 | dset = datasets[args.dataset_type]( 280 | args.data_dir, 281 | split="train", 282 | device=device, 283 | factor=factor, 284 | n_images=args.n_train, 285 | offset=args.offset, 286 | **config_util.build_data_options(args)) 287 | 288 | if args.background_nlayers > 0 and not dset.should_use_background: 289 | warn('Using a background model for dataset type ' + str(type(dset)) + ' which typically does not use background') 290 | 291 | dset_test = datasets[args.dataset_type]( 292 | args.data_dir, split="test", **config_util.build_data_options(args)) 293 | deploy_dset(dset) 294 | deploy_dset(dset_test) 295 | global_start_time = datetime.now() 296 | 297 | grid = svox2.SparseGrid(reso=reso_list[reso_id], 298 | center=dset.scene_center, 299 | radius=dset.scene_radius, 300 | use_sphere_bound=dset.use_sphere_bound and not args.nosphereinit, 301 | basis_dim=args.sh_dim, 302 | use_z_order=True, 303 | device=device, 304 | basis_reso=args.basis_reso, 305 | basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()], 306 | mlp_posenc_size=args.mlp_posenc_size, 307 | mlp_width=args.mlp_width, 308 | background_nlayers=args.background_nlayers, 309 | background_reso=args.background_reso) 310 | 311 | # DC -> gray; mind the SH scaling! 
312 | grid.sh_data.data[:] = 0.0 313 | grid.density_data.data[:] = 0.0 if args.lr_fg_begin_step > 0 else args.init_sigma 314 | 315 | if grid.use_background: 316 | grid.background_data.data[..., -1] = args.init_sigma_bg 317 | # grid.background_data.data[..., :-1] = 0.5 / svox2.utils.SH_C0 318 | 319 | # grid.sh_data.data[:, 0] = 4.0 320 | # osh = grid.density_data.data.shape 321 | # den = grid.density_data.data.view(grid.links.shape) 322 | # # den[:] = 0.00 323 | # # den[:, :256, :] = 1e9 324 | # # den[:, :, 0] = 1e9 325 | # grid.density_data.data = den.view(osh) 326 | 327 | optim_basis_mlp = None 328 | 329 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 330 | grid.reinit_learned_bases(init_type='sh') 331 | # grid.reinit_learned_bases(init_type='fourier') 332 | # grid.reinit_learned_bases(init_type='sg', upper_hemi=True) 333 | # grid.basis_data.data.normal_(mean=0.28209479177387814, std=0.001) 334 | 335 | elif grid.basis_type == svox2.BASIS_TYPE_MLP: 336 | # MLP! 337 | optim_basis_mlp = torch.optim.Adam( 338 | grid.basis_mlp.parameters(), 339 | lr=args.lr_basis 340 | ) 341 | 342 | 343 | grid.requires_grad_(True) 344 | config_util.setup_render_opts(grid.opt, args) 345 | print('Render options', grid.opt) 346 | 347 | gstep_id_base = 0 348 | 349 | resample_cameras = [ 350 | svox2.Camera(c2w.to(device=device), 351 | dset.intrins.get('fx', i), 352 | dset.intrins.get('fy', i), 353 | dset.intrins.get('cx', i), 354 | dset.intrins.get('cy', i), 355 | width=dset.get_image_size(i)[1], 356 | height=dset.get_image_size(i)[0], 357 | ndc_coeffs=dset.ndc_coeffs) for i, c2w in enumerate(dset.c2w) 358 | ] 359 | ckpt_path = path.join(args.train_dir, 'ckpt.npz') 360 | 361 | lr_sigma_func = get_expon_lr_func(args.lr_sigma, args.lr_sigma_final, args.lr_sigma_delay_steps, 362 | args.lr_sigma_delay_mult, args.lr_sigma_decay_steps) 363 | lr_sh_func = get_expon_lr_func(args.lr_sh, args.lr_sh_final, args.lr_sh_delay_steps, 364 | args.lr_sh_delay_mult, args.lr_sh_decay_steps) 365 | lr_basis_func = get_expon_lr_func(args.lr_basis, args.lr_basis_final, args.lr_basis_delay_steps, 366 | args.lr_basis_delay_mult, args.lr_basis_decay_steps) 367 | lr_sigma_bg_func = get_expon_lr_func(args.lr_sigma_bg, args.lr_sigma_bg_final, args.lr_sigma_bg_delay_steps, 368 | args.lr_sigma_bg_delay_mult, args.lr_sigma_bg_decay_steps) 369 | lr_color_bg_func = get_expon_lr_func(args.lr_color_bg, args.lr_color_bg_final, args.lr_color_bg_delay_steps, 370 | args.lr_color_bg_delay_mult, args.lr_color_bg_decay_steps) 371 | lr_sigma_factor = 1.0 372 | lr_sh_factor = 1.0 373 | lr_basis_factor = 1.0 374 | 375 | last_upsamp_step = args.init_iters 376 | 377 | if args.enable_random: 378 | warn("Randomness is enabled for training (normal for LLFF & scenes with background)") 379 | 380 | epoch_id = -1 381 | while True: 382 | dset.shuffle_rays() 383 | epoch_id += 1 384 | epoch_size = dset.rays.origins.size(0) 385 | batches_per_epoch = (epoch_size-1)//args.batch_size+1 386 | # Test 387 | def eval_step(): 388 | # Put in a function to avoid memory leak 389 | print('Eval step') 390 | with torch.no_grad(): 391 | stats_test = {'psnr' : 0.0, 'mse' : 0.0} 392 | 393 | # Standard set 394 | N_IMGS_TO_EVAL = min(20 if epoch_id > 0 else 5, dset_test.n_images) 395 | N_IMGS_TO_SAVE = N_IMGS_TO_EVAL # if not args.tune_mode else 1 396 | img_eval_interval = dset_test.n_images // N_IMGS_TO_EVAL 397 | img_save_interval = (N_IMGS_TO_EVAL // N_IMGS_TO_SAVE) 398 | img_ids = range(0, dset_test.n_images, img_eval_interval) 399 | 400 | # Special 'very hard' specular + fuzz set 
401 | # img_ids = [2, 5, 7, 9, 21, 402 | # 44, 45, 47, 49, 56, 403 | # 80, 88, 99, 115, 120, 404 | # 154] 405 | # img_save_interval = 1 406 | 407 | n_images_gen = 0 408 | for i, img_id in tqdm(enumerate(img_ids), total=len(img_ids)): 409 | c2w = dset_test.c2w[img_id].to(device=device) 410 | cam = svox2.Camera(c2w, 411 | dset_test.intrins.get('fx', img_id), 412 | dset_test.intrins.get('fy', img_id), 413 | dset_test.intrins.get('cx', img_id), 414 | dset_test.intrins.get('cy', img_id), 415 | width=dset_test.get_image_size(img_id)[1], 416 | height=dset_test.get_image_size(img_id)[0], 417 | ndc_coeffs=dset_test.ndc_coeffs) 418 | rgb_pred_test = grid.volume_render_image(cam, use_kernel=True) 419 | rgb_gt_test = dset_test.gt[img_id].to(device=device) 420 | all_mses = ((rgb_gt_test - rgb_pred_test) ** 2).cpu() 421 | if i % img_save_interval == 0: 422 | img_pred = rgb_pred_test.cpu() 423 | img_pred.clamp_max_(1.0) 424 | summary_writer.add_image(f'test/image_{img_id:04d}', 425 | img_pred, global_step=gstep_id_base, dataformats='HWC') 426 | if args.log_mse_image: 427 | mse_img = all_mses / all_mses.max() 428 | summary_writer.add_image(f'test/mse_map_{img_id:04d}', 429 | mse_img, global_step=gstep_id_base, dataformats='HWC') 430 | if args.log_depth_map: 431 | depth_img = grid.volume_render_depth_image(cam, 432 | args.log_depth_map_use_thresh if 433 | args.log_depth_map_use_thresh else None 434 | ) 435 | depth_img = viridis_cmap(depth_img.cpu()) 436 | summary_writer.add_image(f'test/depth_map_{img_id:04d}', 437 | depth_img, 438 | global_step=gstep_id_base, dataformats='HWC') 439 | 440 | rgb_pred_test = rgb_gt_test = None 441 | mse_num : float = all_mses.mean().item() 442 | psnr = -10.0 * math.log10(mse_num) 443 | if math.isnan(psnr): 444 | print('NAN PSNR', i, img_id, mse_num) 445 | assert False 446 | stats_test['mse'] += mse_num 447 | stats_test['psnr'] += psnr 448 | n_images_gen += 1 449 | 450 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE or \ 451 | grid.basis_type == svox2.BASIS_TYPE_MLP: 452 | # Add spherical map visualization 453 | EQ_RESO = 256 454 | eq_dirs = generate_dirs_equirect(EQ_RESO * 2, EQ_RESO) 455 | eq_dirs = torch.from_numpy(eq_dirs).to(device=device).view(-1, 3) 456 | 457 | if grid.basis_type == svox2.BASIS_TYPE_MLP: 458 | sphfuncs = grid._eval_basis_mlp(eq_dirs) 459 | else: 460 | sphfuncs = grid._eval_learned_bases(eq_dirs) 461 | sphfuncs = sphfuncs.view(EQ_RESO, EQ_RESO*2, -1).permute([2, 0, 1]).cpu().numpy() 462 | 463 | stats = [(sphfunc.min(), sphfunc.mean(), sphfunc.max()) 464 | for sphfunc in sphfuncs] 465 | sphfuncs_cmapped = [viridis_cmap(sphfunc) for sphfunc in sphfuncs] 466 | for im, (minv, meanv, maxv) in zip(sphfuncs_cmapped, stats): 467 | cv2.putText(im, f"{minv=:.4f} {meanv=:.4f} {maxv=:.4f}", (10, 20), 468 | 0, 0.5, [255, 0, 0]) 469 | sphfuncs_cmapped = np.concatenate(sphfuncs_cmapped, axis=0) 470 | summary_writer.add_image(f'test/spheric', 471 | sphfuncs_cmapped, global_step=gstep_id_base, dataformats='HWC') 472 | # END add spherical map visualization 473 | 474 | stats_test['mse'] /= n_images_gen 475 | stats_test['psnr'] /= n_images_gen 476 | for stat_name in stats_test: 477 | summary_writer.add_scalar('test/' + stat_name, 478 | stats_test[stat_name], global_step=gstep_id_base) 479 | summary_writer.add_scalar('epoch_id', float(epoch_id), global_step=gstep_id_base) 480 | print('eval stats:', stats_test) 481 | if epoch_id % max(factor, args.eval_every) == 0: #and (epoch_id > 0 or not args.tune_mode): 482 | # NOTE: we do an eval sanity check, if not in tune_mode 
483 | eval_step() 484 | gc.collect() 485 | 486 | def train_step(): 487 | print('Train step') 488 | pbar = tqdm(enumerate(range(0, epoch_size, args.batch_size)), total=batches_per_epoch) 489 | stats = {"mse" : 0.0, "psnr" : 0.0, "invsqr_mse" : 0.0} 490 | for iter_id, batch_begin in pbar: 491 | gstep_id = iter_id + gstep_id_base 492 | if args.lr_fg_begin_step > 0 and gstep_id == args.lr_fg_begin_step: 493 | grid.density_data.data[:] = args.init_sigma 494 | lr_sigma = lr_sigma_func(gstep_id) * lr_sigma_factor 495 | lr_sh = lr_sh_func(gstep_id) * lr_sh_factor 496 | lr_basis = lr_basis_func(gstep_id - args.lr_basis_begin_step) * lr_basis_factor 497 | lr_sigma_bg = lr_sigma_bg_func(gstep_id - args.lr_basis_begin_step) * lr_basis_factor 498 | lr_color_bg = lr_color_bg_func(gstep_id - args.lr_basis_begin_step) * lr_basis_factor 499 | if not args.lr_decay: 500 | lr_sigma = args.lr_sigma * lr_sigma_factor 501 | lr_sh = args.lr_sh * lr_sh_factor 502 | lr_basis = args.lr_basis * lr_basis_factor 503 | 504 | batch_end = min(batch_begin + args.batch_size, epoch_size) 505 | batch_origins = dset.rays.origins[batch_begin: batch_end] 506 | batch_dirs = dset.rays.dirs[batch_begin: batch_end] 507 | rgb_gt = dset.rays.gt[batch_begin: batch_end] 508 | rays = svox2.Rays(batch_origins, batch_dirs) 509 | 510 | # with Timing("volrend_fused"): 511 | rgb_pred = grid.volume_render_fused(rays, rgb_gt, 512 | beta_loss=args.lambda_beta, 513 | sparsity_loss=args.lambda_sparsity, 514 | randomize=args.enable_random) 515 | 516 | # with Timing("loss_comp"): 517 | mse = F.mse_loss(rgb_gt, rgb_pred) 518 | 519 | # Stats 520 | mse_num : float = mse.detach().item() 521 | psnr = -10.0 * math.log10(mse_num) 522 | stats['mse'] += mse_num 523 | stats['psnr'] += psnr 524 | stats['invsqr_mse'] += 1.0 / mse_num ** 2 525 | 526 | if (iter_id + 1) % args.print_every == 0: 527 | # Print averaged stats 528 | pbar.set_description(f'epoch {epoch_id} psnr={psnr:.2f}') 529 | for stat_name in stats: 530 | stat_val = stats[stat_name] / args.print_every 531 | summary_writer.add_scalar(stat_name, stat_val, global_step=gstep_id) 532 | stats[stat_name] = 0.0 533 | # if args.lambda_tv > 0.0: 534 | # with torch.no_grad(): 535 | # tv = grid.tv(logalpha=args.tv_logalpha, ndc_coeffs=dset.ndc_coeffs) 536 | # summary_writer.add_scalar("loss_tv", tv, global_step=gstep_id) 537 | # if args.lambda_tv_sh > 0.0: 538 | # with torch.no_grad(): 539 | # tv_sh = grid.tv_color() 540 | # summary_writer.add_scalar("loss_tv_sh", tv_sh, global_step=gstep_id) 541 | # with torch.no_grad(): 542 | # tv_basis = grid.tv_basis() # summary_writer.add_scalar("loss_tv_basis", tv_basis, global_step=gstep_id) 543 | summary_writer.add_scalar("lr_sh", lr_sh, global_step=gstep_id) 544 | summary_writer.add_scalar("lr_sigma", lr_sigma, global_step=gstep_id) 545 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 546 | summary_writer.add_scalar("lr_basis", lr_basis, global_step=gstep_id) 547 | if grid.use_background: 548 | summary_writer.add_scalar("lr_sigma_bg", lr_sigma_bg, global_step=gstep_id) 549 | summary_writer.add_scalar("lr_color_bg", lr_color_bg, global_step=gstep_id) 550 | 551 | if args.weight_decay_sh < 1.0: 552 | grid.sh_data.data *= args.weight_decay_sigma 553 | if args.weight_decay_sigma < 1.0: 554 | grid.density_data.data *= args.weight_decay_sh 555 | 556 | # # For outputting the % sparsity of the gradient 557 | # indexer = grid.sparse_sh_grad_indexer 558 | # if indexer is not None: 559 | # if indexer.dtype == torch.bool: 560 | # nz = torch.count_nonzero(indexer) 561 | # else: 
562 | # nz = indexer.size() 563 | # with open(os.path.join(args.train_dir, 'grad_sparsity.txt'), 'a') as sparsity_file: 564 | # sparsity_file.write(f"{gstep_id} {nz}\n") 565 | 566 | # Apply TV/Sparsity regularizers 567 | if args.lambda_tv > 0.0: 568 | # with Timing("tv_inpl"): 569 | grid.inplace_tv_grad(grid.density_data.grad, 570 | scaling=args.lambda_tv, 571 | sparse_frac=args.tv_sparsity, 572 | logalpha=args.tv_logalpha, 573 | ndc_coeffs=dset.ndc_coeffs, 574 | contiguous=args.tv_contiguous) 575 | if args.lambda_tv_sh > 0.0: 576 | # with Timing("tv_color_inpl"): 577 | grid.inplace_tv_color_grad(grid.sh_data.grad, 578 | scaling=args.lambda_tv_sh, 579 | sparse_frac=args.tv_sh_sparsity, 580 | ndc_coeffs=dset.ndc_coeffs, 581 | contiguous=args.tv_contiguous) 582 | if args.lambda_tv_lumisphere > 0.0: 583 | grid.inplace_tv_lumisphere_grad(grid.sh_data.grad, 584 | scaling=args.lambda_tv_lumisphere, 585 | dir_factor=args.tv_lumisphere_dir_factor, 586 | sparse_frac=args.tv_lumisphere_sparsity, 587 | ndc_coeffs=dset.ndc_coeffs) 588 | if args.lambda_l2_sh > 0.0: 589 | grid.inplace_l2_color_grad(grid.sh_data.grad, 590 | scaling=args.lambda_l2_sh) 591 | if grid.use_background and (args.lambda_tv_background_sigma > 0.0 or args.lambda_tv_background_color > 0.0): 592 | grid.inplace_tv_background_grad(grid.background_data.grad, 593 | scaling=args.lambda_tv_background_color, 594 | scaling_density=args.lambda_tv_background_sigma, 595 | sparse_frac=args.tv_background_sparsity, 596 | contiguous=args.tv_contiguous) 597 | if args.lambda_tv_basis > 0.0: 598 | tv_basis = grid.tv_basis() 599 | loss_tv_basis = tv_basis * args.lambda_tv_basis 600 | loss_tv_basis.backward() 601 | # print('nz density', torch.count_nonzero(grid.sparse_grad_indexer).item(), 602 | # ' sh', torch.count_nonzero(grid.sparse_sh_grad_indexer).item()) 603 | 604 | # Manual SGD/rmsprop step 605 | if gstep_id >= args.lr_fg_begin_step: 606 | grid.optim_density_step(lr_sigma, beta=args.rms_beta, optim=args.sigma_optim) 607 | grid.optim_sh_step(lr_sh, beta=args.rms_beta, optim=args.sh_optim) 608 | if grid.use_background: 609 | grid.optim_background_step(lr_sigma_bg, lr_color_bg, beta=args.rms_beta, optim=args.bg_optim) 610 | if gstep_id >= args.lr_basis_begin_step: 611 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 612 | grid.optim_basis_step(lr_basis, beta=args.rms_beta, optim=args.basis_optim) 613 | elif grid.basis_type == svox2.BASIS_TYPE_MLP: 614 | optim_basis_mlp.step() 615 | optim_basis_mlp.zero_grad() 616 | 617 | train_step() 618 | gc.collect() 619 | gstep_id_base += batches_per_epoch 620 | 621 | # ckpt_path = path.join(args.train_dir, f'ckpt_{epoch_id:05d}.npz') 622 | # Overwrite prev checkpoints since they are very huge 623 | if args.save_every > 0 and (epoch_id + 1) % max( 624 | factor, args.save_every) == 0 and not args.tune_mode: 625 | print('Saving', ckpt_path) 626 | grid.save(ckpt_path) 627 | 628 | if (gstep_id_base - last_upsamp_step) >= args.upsamp_every: 629 | last_upsamp_step = gstep_id_base 630 | if reso_id < len(reso_list) - 1: 631 | print('* Upsampling from', reso_list[reso_id], 'to', reso_list[reso_id + 1]) 632 | if args.tv_early_only > 0: 633 | print('turning off TV regularization') 634 | args.lambda_tv = 0.0 635 | args.lambda_tv_sh = 0.0 636 | elif args.tv_decay != 1.0: 637 | args.lambda_tv *= args.tv_decay 638 | args.lambda_tv_sh *= args.tv_decay 639 | 640 | reso_id += 1 641 | use_sparsify = True 642 | z_reso = reso_list[reso_id] if isinstance(reso_list[reso_id], int) else reso_list[reso_id][2] 643 | 
grid.resample(reso=reso_list[reso_id], 644 | sigma_thresh=args.density_thresh, 645 | weight_thresh=args.weight_thresh / z_reso if use_sparsify else 0.0, 646 | dilate=2, #use_sparsify, 647 | cameras=resample_cameras if args.thresh_type == 'weight' else None, 648 | max_elements=args.max_grid_elements) 649 | 650 | if grid.use_background and reso_id <= 1: 651 | grid.sparsify_background(args.background_density_thresh) 652 | 653 | if args.upsample_density_add: 654 | grid.density_data.data[:] += args.upsample_density_add 655 | 656 | if factor > 1 and reso_id < len(reso_list) - 1: 657 | print('* Using higher resolution images due to large grid; new factor', factor) 658 | factor //= 2 659 | dset.gen_rays(factor=factor) 660 | dset.shuffle_rays() 661 | 662 | if gstep_id_base >= args.n_iters: 663 | print('* Final eval and save') 664 | eval_step() 665 | global_stop_time = datetime.now() 666 | secs = (global_stop_time - global_start_time).total_seconds() 667 | timings_file = open(os.path.join(args.train_dir, 'time_mins.txt'), 'a') 668 | timings_file.write(f"{secs / 60}\n") 669 | if not args.tune_nosave: 670 | grid.save(ckpt_path) 671 | break 672 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import argparse 5 | import cv2 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('data_dir', type=str) 11 | parser.add_argument('-n','--frames_num', type=int, default=300) 12 | parser.add_argument('-o','--output_dir', type=str, default=None) 13 | return parser.parse_args() 14 | 15 | """ 16 | Note: please install opencv first 17 | Extract every frame from video 18 | """ 19 | def opencv_extractor(filename, outdir, prefix): 20 | cap = cv2.VideoCapture(filename) 21 | fps = cap.get(cv2.CAP_PROP_FPS) 22 | count = 0 23 | while cap.isOpened(): 24 | is_read, frame = cap.read() 25 | if not is_read: 26 | break 27 | cv2.imwrite(os.path.join(outdir, f'{count:04d}', 'images', prefix+'.png'), frame) 28 | count += 1 29 | 30 | 31 | def prepare_data(data_dir, outdir, frame_num): 32 | files = os.listdir(data_dir) 33 | pose_bound = os.path.join(data_dir, 'poses_bounds.npy') 34 | assert os.path.exists(pose_bound), f'{pose_bound} file not found' 35 | print(f'step 1: copy [poses_bounds] to each directory') 36 | for i in range(frame_num): 37 | os.makedirs(os.path.join(outdir, f'{i:04d}'), exist_ok=True) 38 | shutil.copy(pose_bound, os.path.join(outdir, f'{i:04d}', 'poses_bounds.npy')) 39 | os.makedirs(os.path.join(outdir, f'{i:04d}', 'images'), exist_ok=True) 40 | 41 | print(f'step 2: extract frames from video') 42 | videos = [osp.join(data_dir, f) for f in files if f.endswith('mp4')] 43 | for video in videos: 44 | print(f'processing {video}') 45 | cam_num = int(video.split('/')[-1].split('.')[0].split('_')[1]) 46 | opencv_extractor(video, outdir, prefix=f'cam{cam_num:02d}') 47 | 48 | print(f'Done!') 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | datadir = args.data_dir 55 | frameN = args.frames_num 56 | output_dir = args.output_dir 57 | if output_dir is None: 58 | output_dir = os.path.join(datadir, 'frames') 59 | os.makedirs(output_dir, exist_ok=True) 60 | prepare_data(datadir, output_dir, frameN) 61 | -------------------------------------------------------------------------------- /render_delta.py: 
-------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from operator import index 3 | from pydoc import describe 4 | import torch 5 | import torch.cuda 6 | import torch.optim 7 | import torch.nn.functional as F 8 | import svox2 9 | import svox2.csrc as _C 10 | import svox2.utils 11 | import json 12 | import imageio 13 | import os 14 | from os import path 15 | import shutil 16 | import gc 17 | import numpy as np 18 | import math 19 | import argparse 20 | from util.dataset import datasets 21 | from util.util import Timing, get_expon_lr_func, viridis_cmap 22 | from util import config_util 23 | 24 | from warnings import warn 25 | from datetime import datetime 26 | from torch.utils.tensorboard import SummaryWriter 27 | 28 | from tqdm import tqdm 29 | from typing import NamedTuple, Optional, Union 30 | from loguru import logger 31 | import time 32 | 33 | # runtime_svox2file = os.path.join(os.path.dirname(svox2.__file__), 'svox2.py') 34 | # update_svox2file = '../svox2/svox2.py' 35 | # if md5(open(runtime_svox2file,'rb').read()).hexdigest() != md5(open(update_svox2file,'rb').read()).hexdigest(): 36 | # raise Exception("Not INSTALL the NEWEST svox2.py") 37 | 38 | device = "cuda" if torch.cuda.is_available() else "cpu" 39 | 40 | parser = argparse.ArgumentParser() 41 | config_util.define_common_args(parser) 42 | 43 | group = parser.add_argument_group("general") 44 | group.add_argument('--train_dir', '-t', type=str, default='ckpt', 45 | help='checkpoint and logging directory') 46 | group.add_argument('--basis_type', 47 | choices=['sh', '3d_texture', 'mlp'], 48 | default='sh', 49 | help='Basis function type') 50 | group.add_argument('--sh_dim', type=int, default=9, help='SH/learned basis dimensions (at most 10)') 51 | 52 | group = parser.add_argument_group("optimization") 53 | group.add_argument('--n_iters', type=int, default=10 * 12800, help='total number of iters to optimize for') 54 | group.add_argument('--batch_size', type=int, default= 55 | 20000, 56 | help='batch size') 57 | group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer") 58 | group.add_argument('--lr_sigma', type=float, default=3e1, help='SGD/rmsprop lr for sigma') 59 | group.add_argument('--lr_sigma_final', type=float, default=5e-2) 60 | group.add_argument('--lr_sigma_decay_steps', type=int, default=250000) 61 | group.add_argument('--lr_sigma_delay_steps', type=int, default=15000, 62 | help="Reverse cosine steps (0 means disable)") 63 | group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)#1e-4)#1e-4) 64 | 65 | 66 | group.add_argument('--sh_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="SH optimizer") 67 | group.add_argument('--lr_sh', type=float, default=1e-2,help='SGD/rmsprop lr for SH') 68 | group.add_argument('--lr_sh_final', type=float,default=5e-6) 69 | group.add_argument('--lr_sh_decay_steps', type=int, default=250000) 70 | group.add_argument('--lr_sh_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)") 71 | group.add_argument('--lr_sh_delay_mult', type=float, default=1e-2) 72 | 73 | group.add_argument('--lr_fg_begin_step', type=int, default=0, help="Foreground begins training at given step number") 74 | 75 | group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor") 76 | 77 | group.add_argument('--print_every', type=int, default=20, help='print every') 78 | group.add_argument('--save_every', type=int, default=5, 79 | 
help='save every x epochs') 80 | group.add_argument('--eval_every', type=int, default=1, 81 | help='evaluate every x epochs') 82 | 83 | group.add_argument('--init_sigma', type=float, 84 | default=0.1, 85 | help='initialization sigma') 86 | group.add_argument('--log_mse_image', action='store_true', default=False) 87 | group.add_argument('--log_depth_map', action='store_true', default=False) 88 | group.add_argument('--log_depth_map_use_thresh', type=float, default=None, 89 | help="If specified, uses the Dex-neRF version of depth with given thresh; else returns expected term") 90 | 91 | 92 | group = parser.add_argument_group("misc experiments") 93 | group.add_argument('--thresh_type', 94 | choices=["weight", "sigma"], 95 | default="weight", 96 | help='Upsample threshold type') 97 | group.add_argument('--weight_thresh', type=float, 98 | default=0.0005 * 512, 99 | # default=0.025 * 512, 100 | help='Upsample weight threshold; will be divided by resulting z-resolution') 101 | group.add_argument('--density_thresh', type=float, 102 | default=5.0, 103 | help='Upsample sigma threshold') 104 | group.add_argument('--background_density_thresh', type=float, 105 | default=1.0+1e-9, 106 | help='Background sigma threshold for sparsification') 107 | group.add_argument('--max_grid_elements', type=int, 108 | default=44_000_000, 109 | help='Max items to store after upsampling ' 110 | '(the number here is given for 22GB memory)') 111 | 112 | # group.add_argument('--tune_mode', action='store_true', default=False, 113 | # help='hypertuning mode (do not save, for speed)') 114 | group.add_argument('--tune_nosave', action='store_true', default=False, 115 | help='do not save any checkpoint even at the end') 116 | group = parser.add_argument_group("losses") 117 | # Foreground TV 118 | group.add_argument('--lambda_tv', type=float, default=1e-5) 119 | group.add_argument('--tv_sparsity', type=float, default=0.01) 120 | group.add_argument('--tv_logalpha', action='store_true', default=False, 121 | help='Use log(1-exp(-delta * sigma)) as in neural volumes') 122 | 123 | group.add_argument('--lambda_tv_sh', type=float, default=1e-3) 124 | group.add_argument('--tv_sh_sparsity', type=float, default=0.01) 125 | 126 | group.add_argument('--lambda_tv_lumisphere', type=float, default=0.0)#1e-2)#1e-3) 127 | group.add_argument('--tv_lumisphere_sparsity', type=float, default=0.01) 128 | group.add_argument('--tv_lumisphere_dir_factor', type=float, default=0.0) 129 | 130 | group.add_argument('--tv_decay', type=float, default=1.0) 131 | 132 | group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4) 133 | group.add_argument('--tv_early_only', type=int, default=1, help="Turn off TV regularization after the first split/prune") 134 | 135 | group.add_argument('--tv_contiguous', type=int, default=1, 136 | help="Apply TV only on contiguous link chunks, which is faster") 137 | # End Foreground TV 138 | 139 | group.add_argument('--lr_decay', action='store_true', default=True) 140 | group.add_argument('--n_train', type=int, default=None, help='Number of training images. 
Defaults to use all avaiable.') 141 | 142 | 143 | group.add_argument('--lambda_sparsity', type=float, default= 144 | 0.0, 145 | help="Weight for sparsity loss as in SNeRG/PlenOctrees " + 146 | "(but applied on the ray)") 147 | group.add_argument('--lambda_beta', type=float, default= 148 | 0.0, 149 | help="Weight for beta distribution sparsity loss as in neural volumes") 150 | 151 | # ---------------- Finetune video related-------------- 152 | group = parser.add_argument_group("finetune") 153 | group.add_argument('--pretrained', type=str, default=None, 154 | help='pretrained model') 155 | group.add_argument('--strategy', type=int, default=0, 156 | help='specfic sample startegy') 157 | group.add_argument('--mask_grad_after_reg', type=int, default=1, 158 | help='mask out unwanted gradient after TV and other regularization') 159 | group.add_argument('--view_count_thres', type=int, default=-1, 160 | help='mask out unwanted gradient after TV and other regularization') 161 | 162 | group.add_argument('--frame_start', type=int, default=1, help='train frame among [frame_start, frame_end]') 163 | group.add_argument('--frame_end', type=int, default=30, help='train frame among [1, frame_end]') 164 | group.add_argument('--fps', type=int, default=30, help='video save fps') 165 | group.add_argument('--save_every_frame', action='store_true', default=False) 166 | group.add_argument('--dilate_rate', type=int, default=2, help="dilation rate for grid.links") 167 | group.add_argument('--use_grad_mask', action="store_true", default=False, help="dilation rate for grid.links") 168 | group.add_argument('--offset', type=int, default=250) 169 | 170 | group.add_argument('--sh_keep_thres', type=float, default=1) 171 | group.add_argument('--performance_mode', action="store_true", default=False, help="use perfomance_mode skip any unecessary code ") 172 | group.add_argument('--debug', action="store_true", default=False,help="switch on debug mode") 173 | group.add_argument('--keep_rms_data', action="store_true", default=False,help="switch on debug mode") 174 | 175 | 176 | group.add_argument('--render_all', action="store_true", default=False,help="render all camera in sequence") 177 | 178 | 179 | args = parser.parse_args() 180 | config_util.maybe_merge_config_file(args) 181 | 182 | DEBUG = args.debug 183 | assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final" 184 | assert args.lr_sh_final <= args.lr_sh, "lr_sh must be >= lr_sh_final" 185 | 186 | os.makedirs(args.train_dir, exist_ok=True) 187 | os.makedirs(os.path.join(args.train_dir, 'test_images_sc'), exist_ok=True) 188 | os.makedirs(os.path.join(args.train_dir, 'test_images_depth_sc'), exist_ok=True) 189 | 190 | summary_writer = SummaryWriter(args.train_dir) 191 | 192 | with open(path.join(args.train_dir, 'args.json'), 'w') as f: 193 | json.dump(args.__dict__, f, indent=2) 194 | # Changed name to prevent errors 195 | shutil.copyfile(__file__, path.join(args.train_dir, 'opt_frozen.py')) 196 | 197 | torch.manual_seed(20200823) 198 | np.random.seed(20200823) 199 | 200 | assert os.path.exists(args.pretrained), "pretrained model not exist, please train the first frame!" 
201 | print("Load pretrained model from ", args.pretrained) 202 | grid = svox2.SparseGrid.load(args.pretrained, device=device) 203 | config_util.setup_render_opts(grid.opt, args) 204 | print("Load pretrained model Done!") 205 | 206 | from copy import deepcopy 207 | from torch import nn 208 | 209 | 210 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 211 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) 212 | 213 | 214 | def grid_copy( old_grid: svox2.SparseGrid, device: Union[torch.device, str] = "cpu"): 215 | """ 216 | Load from path 217 | """ 218 | 219 | sh_data = old_grid.sh_data.clone() 220 | density_data = old_grid.density_data.clone() 221 | logger.error(f"copy grid cap {(old_grid.links>=0).sum()}") 222 | if hasattr(old_grid, "background_links") : 223 | background_data = old_grid.background_data 224 | background_links = old_grid.background_links 225 | else: 226 | background_data = None 227 | background_links = None 228 | 229 | links = old_grid.links.clone() 230 | basis_dim = (sh_data.shape[1]) // 3 231 | radius = deepcopy(old_grid.radius ) 232 | center = deepcopy(old_grid.center) 233 | grid_new = svox2.SparseGrid( 234 | 1, 235 | radius=radius, 236 | center=center, 237 | basis_dim=basis_dim, 238 | use_z_order=False, 239 | device="cpu", 240 | basis_type=old_grid.basis_type , 241 | mlp_posenc_size=old_grid.mlp_posenc_size, 242 | mlp_width=old_grid.mlp_width, 243 | background_nlayers=0, 244 | ) 245 | 246 | grid_new.viewcount_helper = torch.zeros_like(density_data, dtype=torch.int, device=device) 247 | grid_new.sh_data = nn.Parameter(sh_data).to(device=device) 248 | grid_new.density_data = nn.Parameter(density_data).to(device=device) 249 | grid_new.links = links.to(device=device) # torch.from_numpy(links).to(device=device) 250 | grid_new.capacity = grid_new.sh_data.size(0) 251 | if args.keep_rms_data: 252 | grid_new.sh_rms = old_grid.sh_rms 253 | grid_new.density_rms = old_grid.density_rms 254 | 255 | if background_data is not None: 256 | background_data = torch.from_numpy(background_data).to(device=device) 257 | grid_new.background_nlayers = background_data.shape[1] 258 | grid_new.background_reso = background_links.shape[1] 259 | grid_new.background_data = nn.Parameter(background_data) 260 | grid_new.background_links = torch.from_numpy(background_links).to(device=device) 261 | else: 262 | grid_new.background_data.data = grid_new.background_data.data.to(device=device) 263 | 264 | if grid_new.links.is_cuda: 265 | grid_new.accelerate() 266 | config_util.setup_render_opts(grid_new.opt, args) 267 | return grid_new 268 | 269 | from torch import nn 270 | 271 | 272 | def compress_loading(grid_pre, delta_path): 273 | 274 | delta = torch.load(delta_path) 275 | 276 | mask_next = delta['mask_next'].to(device) 277 | addition_density = delta['addition_density'].to(device) 278 | addition_sh = delta['addition_sh'].to(device) 279 | keep_density = delta['keep_density'].to(device) 280 | keep_sh = delta['keep_sh'].to(device) 281 | part2_keep_area =delta['part2_keep_area'] .to(device) 282 | 283 | mask_pre = grid_pre.links>=0 284 | new_cap = mask_next.sum() 285 | diff_area = torch.logical_xor(mask_pre, mask_next) 286 | 287 | add_area = (diff_area & mask_next) 288 | minus_area = (diff_area & mask_pre) 289 | 290 | 291 | (abs(addition_sh).sum(-1)<=0.9).sum() 292 | # import ipdb;ipdb.set_trace() 293 | 294 | logger.debug(f"diff area: {diff_area.sum()} add area: {add_area.sum()} minus area: {minus_area.sum()} ") 295 | remain_idx = 
grid_pre.links[mask_pre & ~ minus_area] 296 | remain_idx = remain_idx.long() 297 | 298 | remain_sh_data = grid_pre.sh_data[remain_idx] 299 | remain_density_data = grid_pre.density_data[remain_idx] 300 | 301 | 302 | new_sh_data = torch.zeros((new_cap,27), device=device).float() 303 | new_density_data = torch.zeros((new_cap,1), device=device).float() 304 | 305 | add_area_in_saprse = add_area[mask_next] 306 | 307 | # we also save voxel where sh change a lot 308 | 309 | 310 | # import ipdb;ipdb.set_trace() 311 | keep_numel = part2_keep_area.sum() 312 | add_numel = add_area.sum() 313 | 314 | keep_percent = (keep_numel/new_cap) * 100 315 | add_percent = (add_numel/new_cap) * 100 316 | keep_size = (keep_numel*2*28)/(1024*1024) 317 | add_size = (add_numel*2*28)/(1024*1024) 318 | 319 | logger.info(f"keep element: {keep_numel}/{keep_percent:.2f}/{keep_size:.2f} MB, add element: {add_numel}/{add_percent:.2f}/{add_size:.2f} MB") 320 | 321 | remain_sh_data[part2_keep_area] = keep_sh 322 | remain_density_data[part2_keep_area] = keep_density 323 | 324 | new_sh_data[add_area_in_saprse,:] = addition_sh 325 | new_density_data[add_area_in_saprse,:] = addition_density 326 | new_sh_data[~add_area_in_saprse,:] = remain_sh_data 327 | new_density_data[~add_area_in_saprse,:] = remain_density_data 328 | 329 | # though new_links equal to grid_next.links, we still calculate a mask for better code scalability 330 | new_mask = torch.logical_or(add_area, mask_pre) 331 | new_mask = torch.logical_and(new_mask, ~minus_area) 332 | new_links = torch.cumsum(new_mask.view(-1).to(torch.int32), dim=-1).int() - 1 333 | new_links[~new_mask.view(-1)] = -1 334 | 335 | # import ipdb;ipdb.set_trace() 336 | grid_pre.sh_data = nn.Parameter(new_sh_data) 337 | grid_pre.density_data = nn.Parameter(new_density_data) 338 | grid_pre.links = new_links.view(grid_pre.links.shape).to(device=device) 339 | 340 | return grid_pre 341 | 342 | 343 | 344 | def dilated_voxel_grid(dilate_rate = 2): 345 | active_mask = grid.links >= 0 346 | dilate_before = active_mask 347 | for i in range(dilate_rate): 348 | active_mask = _C.dilate(active_mask) 349 | # reactivate = torch.logical_xor(active_mask, dilate_before) 350 | new_cap = active_mask.sum() 351 | previous_sparse_area = dilate_before[active_mask] 352 | 353 | new_density = torch.zeros((new_cap,1), device=device).float() 354 | new_sh = torch.zeros((new_cap, grid.basis_dim*3), device=device).float() 355 | 356 | new_density[previous_sparse_area,:] = grid.density_data.data 357 | new_sh[previous_sparse_area,:] = grid.sh_data.data 358 | 359 | active_mask = active_mask.view(-1) 360 | new_links = torch.cumsum(active_mask.to(torch.int32), dim=-1).int() - 1 361 | new_links[~active_mask] = -1 362 | 363 | grid.density_data = torch.nn.Parameter(new_density) 364 | grid.sh_data = torch.nn.Parameter(new_sh) 365 | grid.links = new_links.view(grid.links.shape).to(device=device) 366 | 367 | 368 | 369 | def sparsify_voxel_grid(): 370 | reso = grid.links.shape 371 | grid.resample(reso=reso, 372 | sigma_thresh=args.density_thresh, 373 | weight_thresh=0.0, 374 | dilate=2, 375 | cameras= None, 376 | max_elements=args.max_grid_elements, 377 | accelerate=False) 378 | 379 | def sparsify_voxel_grid_fast(dilate=2): 380 | reso = grid.links.shape 381 | sample_vals_mask = grid.density_data >= args.density_thresh 382 | max_elements=args.max_grid_elements 383 | if max_elements > 0 and max_elements < grid.density_data.numel() \ 384 | and max_elements < torch.count_nonzero(sample_vals_mask): 385 | # To bound the memory usage 386 | 
sigma_thresh_bounded = torch.topk(grid.density_data.view(-1), 387 | k=max_elements, sorted=False).values.min().item() 388 | sigma_thresh = max(sigma_thresh, sigma_thresh_bounded) 389 | print(' Readjusted sigma thresh to fit to memory:', sigma_thresh) 390 | sample_vals_mask = grid.density_data >= sigma_thresh 391 | if grid.opt.last_sample_opaque: 392 | # Don't delete the last z layer 393 | sample_vals_mask[:, :, -1] = 1 394 | if dilate: 395 | for i in range(int(dilate)): 396 | sample_vals_mask = _C.dilate(sample_vals_mask) 397 | sample_vals_density = grid.density_data[sample_vals_mask] 398 | cnz = torch.count_nonzero(sample_vals_mask).item() 399 | sample_vals_sh = grid.sh_data[sample_vals_mask] 400 | init_links = ( 401 | torch.cumsum(sample_vals_mask.to(torch.int32), dim=-1).int() - 1 402 | ) 403 | init_links[~sample_vals_mask] = -1 404 | grid.capacity = cnz 405 | print(" New cap:", grid.capacity) 406 | del sample_vals_mask 407 | print('density', sample_vals_density.shape, sample_vals_density.dtype) 408 | print('sh', sample_vals_sh.shape, sample_vals_sh.dtype) 409 | print('links', init_links.shape, init_links.dtype) 410 | grid.density_data = nn.Parameter(sample_vals_density.view(-1, 1).to(device=device)) 411 | grid.sh_data = nn.Parameter(sample_vals_sh.to(device=device)) 412 | grid.links = init_links.view(reso).to(device=device) 413 | 414 | if args.dilate_rate > 0: 415 | logger.info("sparsify first!!!!") 416 | 417 | # grid_pre = grid_copy(grid, device=device) 418 | sparsify_voxel_grid() 419 | # slow_sparse_grid = grid_copy(grid, device=device) 420 | # grid = grid_pre 421 | # sparsify_voxel_grid() 422 | # import ipdb;ipdb.set_trace() 423 | 424 | # LR related 425 | lr_sigma_func = get_expon_lr_func(args.lr_sigma, args.lr_sigma_final, args.lr_sigma_delay_steps, 426 | args.lr_sigma_delay_mult, args.lr_sigma_decay_steps) 427 | lr_sh_func = get_expon_lr_func(args.lr_sh, args.lr_sh_final, args.lr_sh_delay_steps, 428 | args.lr_sh_delay_mult, args.lr_sh_decay_steps) 429 | lr_sigma_factor = 1.0 430 | lr_sh_factor = 1.0 431 | 432 | 433 | 434 | grid_raw = grid_copy(grid, device=device) 435 | 436 | 437 | 438 | 439 | 440 | def eval_one_frame(frame_idx, global_step_base): 441 | 442 | 443 | data_dir = os.path.join(args.data_dir, f'{frame_idx:04d}') 444 | train_dir = args.train_dir 445 | factor = 1 446 | 447 | dset_test = datasets[args.dataset_type]( 448 | data_dir, split= 'train' if args.render_all else "test", train_use_all=1 if args.render_all else 0,offset=args.offset, **config_util.build_data_options(args)) 449 | 450 | dset_eval = datasets[args.dataset_type]( 451 | data_dir, split="test", train_use_all=0,offset=args.offset, **config_util.build_data_options(args)) 452 | 453 | torch.save(dset_test.c2w,os.path.join(args.train_dir, 'c2w.pth')) 454 | epoch_id = -1 455 | 456 | gstep_id_base = 0 457 | 458 | # dset_motion.epoch_size = args.n_iters * args.batch_size 459 | # indexer_motion = dset_motion.shuffle_rays(strategy=strategy, replace=replacement) 460 | 461 | delta_path = os.path.join(args.train_dir,'grid_delta_full',f'{frame_idx:04d}.pth') 462 | compress_loading(grid_pre=grid, delta_path=delta_path) 463 | 464 | 465 | 466 | def eval_step(): 467 | 468 | stats_test = {'mse' : 0.0, 'psnr' : 0.0} 469 | # Standard set 470 | N_IMGS_TO_EVAL = min(20 if epoch_id > 0 else 5, dset_eval.n_images) 471 | N_IMGS_TO_SAVE = N_IMGS_TO_EVAL # if not args.tune_mode else 1 472 | img_eval_interval = dset_eval.n_images // N_IMGS_TO_EVAL 473 | img_save_interval = (N_IMGS_TO_EVAL // N_IMGS_TO_SAVE) 474 | img_ids = range(0, 
dset_eval.n_images, img_eval_interval) 475 | n_images_gen = 0 476 | for i, img_id in tqdm(enumerate(img_ids), total=len(img_ids)): 477 | c2w = torch.from_numpy(dset_eval.c2w[img_id]).to(device=device) 478 | cam = svox2.Camera(c2w, 479 | dset_eval.intrins.get('fx', img_id), 480 | dset_eval.intrins.get('fy', img_id), 481 | dset_eval.intrins.get('cx', img_id), 482 | dset_eval.intrins.get('cy', img_id), 483 | width=dset_eval.get_image_size(img_id)[1], 484 | height=dset_eval.get_image_size(img_id)[0], 485 | ndc_coeffs=dset_eval.ndc_coeffs) 486 | rgb_pred_test = grid.volume_render_image(cam, use_kernel=True) 487 | rgb_gt_test = torch.from_numpy(dset_eval.gt[img_id]).to(device=device) 488 | all_mses = ((rgb_gt_test - rgb_pred_test) ** 2).cpu() 489 | if i % img_save_interval == 0: 490 | img_pred = rgb_pred_test.cpu() 491 | img_pred.clamp_max_(1.0) 492 | summary_writer.add_image(f'test/image_{img_id:04d}', 493 | img_pred, global_step=frame_idx, dataformats='HWC') 494 | if args.log_mse_image: 495 | mse_img = all_mses / all_mses.max() 496 | summary_writer.add_image(f'test/mse_map_{img_id:04d}', 497 | mse_img, global_step=frame_idx, dataformats='HWC') 498 | if False or args.log_depth_map: 499 | depth_img = grid.volume_render_depth_image(cam, 500 | args.log_depth_map_use_thresh if 501 | args.log_depth_map_use_thresh else None 502 | ) 503 | depth_img = viridis_cmap(depth_img.cpu()) 504 | summary_writer.add_image(f'test/depth_map_{img_id:04d}', 505 | depth_img, 506 | global_step=frame_idx, dataformats='HWC') 507 | 508 | rgb_pred_test = rgb_gt_test = None 509 | mse_num : float = all_mses.mean().item() 510 | psnr = -10.0 * math.log10(mse_num) 511 | if math.isnan(psnr): 512 | print('NAN PSNR', i, img_id, mse_num) 513 | assert False 514 | stats_test['mse'] += mse_num 515 | stats_test['psnr'] += psnr 516 | n_images_gen += 1 517 | stats_test['mse'] /= n_images_gen 518 | stats_test['psnr'] /= n_images_gen 519 | for stat_name in stats_test: 520 | summary_writer.add_scalar('test/' + stat_name, 521 | stats_test[stat_name], global_step=gstep_id_base+global_step_base) 522 | summary_writer.add_scalar('epoch_id', float(epoch_id), global_step=gstep_id_base+global_step_base) 523 | print('eval stats:', stats_test) 524 | 525 | 526 | 527 | eval_step() 528 | 529 | 530 | def render_img(): 531 | 532 | c2ws = torch.from_numpy(dset_test.c2w).to(device=device) 533 | 534 | n_images = dset_test.n_images 535 | img_eval_interval = 1 536 | for img_id in tqdm(range(0, n_images, img_eval_interval)): 537 | 538 | dset_h, dset_w = dset_test.get_image_size(img_id) 539 | im_size = dset_h * dset_w 540 | w = dset_w #if args.crop == 1.0 else int(dset_w * args.crop) 541 | h = dset_h #if args.crop == 1.0 else int(dset_h * args.crop) 542 | 543 | if args.render_all: 544 | im_path = os.path.join(train_dir, 'test_images_sc', f'{frame_idx:04d}_{img_id:02d}.png' ) 545 | depth_path = os.path.join(train_dir, 'test_images_depth_sc', f'{frame_idx:04d}_{img_id:02d}.png' ) 546 | else: 547 | im_path = os.path.join(train_dir, 'test_images_sc', f'{frame_idx:04d}.png' ) 548 | cam = svox2.Camera(c2ws[img_id], 549 | dset_test.intrins.get('fx', img_id), 550 | dset_test.intrins.get('fy', img_id), 551 | dset_test.intrins.get('cx', img_id) + (w - dset_w) * 0.5, 552 | dset_test.intrins.get('cy', img_id) + (h - dset_h) * 0.5, 553 | w, h, 554 | ndc_coeffs=dset_test.ndc_coeffs) 555 | tic = time.time() 556 | 557 | im = grid.volume_render_image(cam, use_kernel=True, return_raylen=False) 558 | if DEBUG: 559 | torch.cuda.synchronize() 560 | logger.debug(f'rgb rendeing time: 
{time.time() - tic}') 561 | im.clamp_(0.0, 1.0) 562 | im = im.cpu().numpy() 563 | im = (im * 255).astype(np.uint8) 564 | imageio.imwrite(im_path, im) 565 | if not args.render_all: 566 | break 567 | return im 568 | 569 | with torch.no_grad(): 570 | eval_step() 571 | return render_img() 572 | 573 | 574 | 575 | train_start_time = datetime.now() 576 | train_frame_num = 0 577 | global_step_base = 0 578 | frames = [] 579 | for frame_idx in range(args.frame_start, args.frame_end): 580 | frames.append(eval_one_frame(frame_idx, global_step_base)) 581 | global_step_base += args.n_iters 582 | train_frame_num += 1 583 | 584 | 585 | train_end_time = datetime.now() 586 | secs = (train_end_time-train_start_time).total_seconds() 587 | 588 | if train_frame_num: 589 | average_time = secs / train_frame_num 590 | print(f'train {train_frame_num} images, cost {secs} s, average {average_time}s per image') 591 | tag = os.path.basename(args.train_dir) 592 | vid_path = os.path.join(args.train_dir, tag+'_from_saved_delta.mp4') 593 | # dep_vid_path = os.path.join(args.train_dir, 'render_depth.mp4') 594 | imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) 595 | print('video write into', vid_path) 596 | print('Final save ckpt') 597 | 598 | -------------------------------------------------------------------------------- /train_video_n3dv_base.py: -------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from multiprocessing import process 3 | from operator import index 4 | from pydoc import describe 5 | import torch 6 | import torch.cuda 7 | import torch.optim 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | import svox2 13 | import svox2.csrc as _C 14 | import svox2.utils 15 | import json 16 | import imageio 17 | import os 18 | from os import path 19 | import time 20 | import shutil 21 | import gc 22 | import math 23 | import argparse 24 | 25 | import numpy as np 26 | 27 | from util.dataset import datasets 28 | from util.util import Timing, get_expon_lr_func, viridis_cmap 29 | from util import config_util 30 | 31 | from warnings import warn 32 | from datetime import datetime 33 | 34 | from torch.utils.tensorboard import SummaryWriter 35 | 36 | from tqdm import tqdm 37 | from typing import NamedTuple, Optional, Union 38 | from loguru import logger 39 | from multiprocess import Pool 40 | 41 | 42 | # runtime_svox2file = os.path.join(os.path.dirname(svox2.__file__), 'svox2.py') 43 | # update_svox2file = '../svox2/svox2.py' 44 | # if md5(open(runtime_svox2file,'rb').read()).hexdigest() != md5(open(update_svox2file,'rb').read()).hexdigest(): 45 | # raise Exception("Not INSTALL the NEWEST svox2.py") 46 | 47 | device = "cuda" if torch.cuda.is_available() else "cpu" 48 | 49 | parser = argparse.ArgumentParser() 50 | config_util.define_common_args(parser) 51 | 52 | group = parser.add_argument_group("general") 53 | group.add_argument('--train_dir', '-t', type=str, default='ckpt', 54 | help='checkpoint and logging directory') 55 | group.add_argument('--basis_type', 56 | choices=['sh', '3d_texture', 'mlp'], 57 | default='sh', 58 | help='Basis function type') 59 | group.add_argument('--sh_dim', type=int, default=9, help='SH/learned basis dimensions (at most 10)') 60 | 61 | group = parser.add_argument_group("optimization") 62 | group.add_argument('--n_iters', type=int, default=10 * 12800, help='total number of iters to optimize for') 63 | group.add_argument('--batch_size', type=int, default= 64 | 20000, 65 | help='batch size') 66 | 
group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer") 67 | group.add_argument('--lr_sigma', type=float, default=3e1, help='SGD/rmsprop lr for sigma') 68 | group.add_argument('--lr_sigma_final', type=float, default=5e-2) 69 | group.add_argument('--lr_sigma_decay_steps', type=int, default=250000) 70 | group.add_argument('--lr_sigma_delay_steps', type=int, default=15000, 71 | help="Reverse cosine steps (0 means disable)") 72 | group.add_argument('--lr_sigma_delay_mult', type=float, default=1e-2)#1e-4)#1e-4) 73 | 74 | 75 | group.add_argument('--sh_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="SH optimizer") 76 | group.add_argument('--lr_sh', type=float, default=1e-2,help='SGD/rmsprop lr for SH') 77 | group.add_argument('--lr_sh_final', type=float,default=5e-6) 78 | group.add_argument('--lr_sh_decay_steps', type=int, default=250000) 79 | group.add_argument('--lr_sh_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)") 80 | group.add_argument('--lr_sh_delay_mult', type=float, default=1e-2) 81 | 82 | group.add_argument('--lr_fg_begin_step', type=int, default=0, help="Foreground begins training at given step number") 83 | 84 | group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor") 85 | 86 | group.add_argument('--print_every', type=int, default=20, help='print every') 87 | group.add_argument('--save_every', type=int, default=5, 88 | help='save every x epochs') 89 | group.add_argument('--eval_every', type=int, default=1, 90 | help='evaluate every x epochs') 91 | 92 | group.add_argument('--init_sigma', type=float, 93 | default=0.1, 94 | help='initialization sigma') 95 | group.add_argument('--log_mse_image', action='store_true', default=False) 96 | group.add_argument('--log_depth_map', action='store_true', default=False) 97 | group.add_argument('--log_depth_map_use_thresh', type=float, default=None, 98 | help="If specified, uses the Dex-neRF version of depth with given thresh; else returns expected term") 99 | 100 | 101 | group = parser.add_argument_group("misc experiments") 102 | group.add_argument('--thresh_type', 103 | choices=["weight", "sigma"], 104 | default="weight", 105 | help='Upsample threshold type') 106 | group.add_argument('--weight_thresh', type=float, 107 | default=0.0005 * 512, 108 | # default=0.025 * 512, 109 | help='Upsample weight threshold; will be divided by resulting z-resolution') 110 | group.add_argument('--density_thresh', type=float, 111 | default=5.0, 112 | help='Upsample sigma threshold') 113 | group.add_argument('--background_density_thresh', type=float, 114 | default=1.0+1e-9, 115 | help='Background sigma threshold for sparsification') 116 | group.add_argument('--max_grid_elements', type=int, 117 | default=44_000_000, 118 | help='Max items to store after upsampling ' 119 | '(the number here is given for 22GB memory)') 120 | 121 | 122 | 123 | group = parser.add_argument_group("losses") 124 | # Foreground TV 125 | group.add_argument('--lambda_tv', type=float, default=1e-5) 126 | group.add_argument('--tv_sparsity', type=float, default=0.01) 127 | group.add_argument('--tv_logalpha', action='store_true', default=False, 128 | help='Use log(1-exp(-delta * sigma)) as in neural volumes') 129 | 130 | group.add_argument('--lambda_tv_sh', type=float, default=1e-3) 131 | group.add_argument('--tv_sh_sparsity', type=float, default=0.01) 132 | 133 | group.add_argument('--lambda_tv_lumisphere', type=float, default=0.0)#1e-2)#1e-3) 134 | 
group.add_argument('--tv_lumisphere_sparsity', type=float, default=0.01) 135 | group.add_argument('--tv_lumisphere_dir_factor', type=float, default=0.0) 136 | 137 | group.add_argument('--tv_decay', type=float, default=1.0) 138 | 139 | group.add_argument('--lambda_l2_sh', type=float, default=0.0)#1e-4) 140 | group.add_argument('--tv_early_only', type=int, default=1, help="Turn off TV regularization after the first split/prune") 141 | 142 | group.add_argument('--tv_contiguous', type=int, default=1, 143 | help="Apply TV only on contiguous link chunks, which is faster") 144 | # End Foreground TV 145 | 146 | group.add_argument('--lr_decay', action='store_true', default=True) 147 | group.add_argument('--n_train', type=int, default=None, help='Number of training images. Defaults to use all available.') 148 | 149 | 150 | group.add_argument('--lambda_sparsity', type=float, default= 151 | 0.0, 152 | help="Weight for sparsity loss as in SNeRG/PlenOctrees " + 153 | "(but applied on the ray)") 154 | group.add_argument('--lambda_beta', type=float, default= 155 | 0.0, 156 | help="Weight for beta distribution sparsity loss as in neural volumes") 157 | 158 | # ---------------- Finetune video related-------------- 159 | group = parser.add_argument_group("finetune") 160 | group.add_argument('--pretrained', type=str, default=None, 161 | help='pretrained model') 162 | 163 | group.add_argument('--mask_grad_after_reg', type=int, default=1, 164 | help='mask out unwanted gradient after TV and other regularization') 165 | 166 | group.add_argument('--frame_start', type=int, default=1, help='train frame among [frame_start, frame_end]') 167 | group.add_argument('--frame_end', type=int, default=30, help='train frame among [1, frame_end]') 168 | group.add_argument('--fps', type=int, default=30, help='video save fps') 169 | 170 | group.add_argument('--train_use_all', type=int, default=0, help='whether to use all images as the training set') 171 | group.add_argument('--save_every_frame', action='store_true', default=False) 172 | group.add_argument('--dilate_rate_before', type=int, default=2, help="dilation rate for grid.links before training") 173 | group.add_argument('--dilate_rate_after', type=int, default=2, help="dilation rate for grid.links after training") 174 | 175 | 176 | group.add_argument('--offset', type=int, default=250) 177 | 178 | # fancy idea 179 | group.add_argument('--compress_saving', action="store_true", default=False, help="save compressed per-frame deltas instead of full grids") 180 | group.add_argument('--sh_keep_thres', type=float, default=1) 181 | group.add_argument('--sh_prune_thres', type=float, default=0.2) 182 | 183 | group.add_argument('--performance_mode', action="store_true", default=False, help="performance mode: skip any unnecessary code") 184 | group.add_argument('--debug', action="store_true", default=False, help="switch on debug mode") 185 | group.add_argument('--keep_rms_data', action="store_true", default=False, help="keep RMSProp running statistics when copying the grid") 186 | 187 | 188 | group.add_argument('--apply_narrow_band', action="store_true", default=False, help="restrict gradient updates to a narrow band around the occupied region") 189 | group.add_argument('--render_all', action="store_true", default=False, help="render all cameras in sequence") 190 | group.add_argument('--save_delta', action="store_true", default=False, help="save delta files when compress_saving is enabled") 191 | 192 | args = parser.parse_args() 193 | config_util.maybe_merge_config_file(args) 194 | 195 | DEBUG = args.debug 196 | assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final" 197 | assert args.lr_sh_final 
<= args.lr_sh, "lr_sh must be >= lr_sh_final" 198 | 199 | os.makedirs(args.train_dir, exist_ok=True) 200 | os.makedirs(os.path.join(args.train_dir, 'grid_delta'), exist_ok=True) 201 | os.makedirs(os.path.join(args.train_dir, 'grid_delta_z'), exist_ok=True) 202 | os.makedirs(os.path.join(args.train_dir, 'test_images'), exist_ok=True) 203 | os.makedirs(os.path.join(args.train_dir, 'test_images_depth'), exist_ok=True) 204 | 205 | logfolder = args.train_dir 206 | if os.path.exists(f'{logfolder}/log_base.log'): 207 | os.remove(f'{logfolder}/log_base.log') 208 | logger.add(f'{logfolder}/log_base.log' , format="{level} {message}", level='DEBUG' if args.debug else 'INFO') 209 | 210 | summary_writer = SummaryWriter(args.train_dir) 211 | 212 | with open(path.join(args.train_dir, 'args.json'), 'w') as f: 213 | json.dump(args.__dict__, f, indent=2) 214 | # Changed name to prevent errors 215 | shutil.copyfile(__file__, path.join(args.train_dir, 'train_frozen.py')) 216 | 217 | torch.manual_seed(20200823) 218 | np.random.seed(20200823) 219 | 220 | assert os.path.exists(args.pretrained), "pretrained model not exist, please train the first frame!" 221 | print("Load pretrained model from ", args.pretrained) 222 | grid = svox2.SparseGrid.load(args.pretrained, device=device) 223 | config_util.setup_render_opts(grid.opt, args) 224 | print("Load pretrained model Done!") 225 | 226 | from copy import deepcopy 227 | from torch import nn 228 | 229 | 230 | 231 | 232 | def grid_copy( old_grid: svox2.SparseGrid, device: Union[torch.device, str] = "cpu"): 233 | """ 234 | Load from path 235 | """ 236 | 237 | sh_data = old_grid.sh_data.clone() 238 | density_data = old_grid.density_data.clone() 239 | logger.debug(f"copy grid cap {(old_grid.links>=0).sum()}") 240 | if hasattr(old_grid, "background_links") : 241 | background_data = old_grid.background_data 242 | background_links = old_grid.background_links 243 | else: 244 | background_data = None 245 | background_links = None 246 | 247 | links = old_grid.links.clone() 248 | basis_dim = (sh_data.shape[1]) // 3 249 | radius = deepcopy(old_grid.radius ) 250 | center = deepcopy(old_grid.center) 251 | grid_new = svox2.SparseGrid( 252 | 1, 253 | radius=radius, 254 | center=center, 255 | basis_dim=basis_dim, 256 | use_z_order=False, 257 | device="cpu", 258 | basis_type=old_grid.basis_type , 259 | mlp_posenc_size=old_grid.mlp_posenc_size, 260 | mlp_width=old_grid.mlp_width, 261 | background_nlayers=0, 262 | ) 263 | 264 | grid_new.sh_data = nn.Parameter(sh_data).to(device=device) 265 | grid_new.density_data = nn.Parameter(density_data).to(device=device) 266 | grid_new.links = links.to(device=device) # torch.from_numpy(links).to(device=device) 267 | grid_new.capacity = grid_new.sh_data.size(0) 268 | if args.keep_rms_data: 269 | grid_new.sh_rms = old_grid.sh_rms 270 | grid_new.density_rms = old_grid.density_rms 271 | 272 | if background_data is not None: 273 | background_data = torch.from_numpy(background_data).to(device=device) 274 | grid_new.background_nlayers = background_data.shape[1] 275 | grid_new.background_reso = background_links.shape[1] 276 | grid_new.background_data = nn.Parameter(background_data) 277 | grid_new.background_links = torch.from_numpy(background_links).to(device=device) 278 | else: 279 | grid_new.background_data.data = grid_new.background_data.data.to(device=device) 280 | 281 | if grid_new.links.is_cuda: 282 | grid_new.accelerate() 283 | config_util.setup_render_opts(grid_new.opt, args) 284 | logger.debug(f"grid copy finish") 285 | return grid_new 286 | 287 | 
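# ---------------------------------------------------------------------------
# Illustrative sketch (the helper name is ours and it is not called anywhere
# in this repository): delete_area(), dilated_voxel_grid() and
# compress_saving() below all rebuild the sparse `links` index volume from a
# boolean occupancy mask with the same cumsum trick. Active voxels receive
# consecutive row indices into the packed density/SH tensors, and inactive
# voxels are marked with -1.
def _links_from_mask_sketch(mask: torch.Tensor) -> torch.Tensor:
    flat = mask.view(-1)
    links = torch.cumsum(flat.to(torch.int32), dim=-1).int() - 1
    links[~flat] = -1
    # e.g. mask = [True, False, True, True]  ->  links = [0, -1, 1, 2]
    return links.view(mask.shape)
# ---------------------------------------------------------------------------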
def delete_area(grid, delet_mask): 288 | new_mask = torch.logical_and(grid.links>=0, ~delet_mask) 289 | 290 | delet_mask = delet_mask[grid.links>=0] 291 | grid.density_data = nn.Parameter(grid.density_data[~delet_mask,:]) 292 | grid.sh_data = nn.Parameter(grid.sh_data[~delet_mask,:]) 293 | if args.keep_rms_data: 294 | grid.sh_rms = None 295 | grid.density_rms = None 296 | 297 | new_links = torch.cumsum(new_mask.view(-1).to(torch.int32), dim=-1).int() - 1 298 | new_links[~new_mask.view(-1)] = -1 299 | grid.links = new_links.view(grid.links.shape) 300 | 301 | @torch.no_grad() 302 | def compress_saving(grid_pre, grid_next, grid_holder, save_delta=False,saving_name=None): 303 | mask_pre = grid_pre.links>=0 304 | mask_next = grid_next.links>=0 305 | new_cap = mask_next.sum() 306 | 307 | diff_area = torch.logical_xor(mask_pre, mask_next) 308 | 309 | add_area = (diff_area & mask_next) 310 | minus_area = (diff_area & mask_pre) 311 | 312 | addition_density = grid_next.density_data[grid_next.links[add_area].long()] 313 | addition_sh = grid_next.sh_data[grid_next.links[add_area].long()] 314 | 315 | logger.debug(f"diff area: {diff_area.sum()} add area: {add_area.sum()} minus area: {minus_area.sum()} ") 316 | remain_idx = grid_pre.links[mask_pre & ~ minus_area] 317 | remain_idx = remain_idx.long() 318 | 319 | remain_sh_data = grid_pre.sh_data[remain_idx] 320 | remain_density_data = grid_pre.density_data[remain_idx] 321 | 322 | new_sh_data = torch.zeros((new_cap,27), device=device).float() 323 | new_density_data = torch.zeros((new_cap,1), device=device).float() 324 | 325 | add_area_in_saprse = add_area[mask_next] 326 | 327 | # we also save voxel where sh change a lot 328 | next_sh_data = grid_next.sh_data[~add_area_in_saprse,:] 329 | next_density_data = grid_next.density_data[~add_area_in_saprse,:] 330 | part2_keep_area = (abs(next_sh_data - remain_sh_data).sum(-1) > args.sh_keep_thres) 331 | keep_numel = part2_keep_area.sum() 332 | add_numel = add_area.sum() 333 | 334 | keep_percent = (keep_numel/new_cap) * 100 335 | add_percent = (add_numel/new_cap) * 100 336 | keep_size = (keep_numel*2*28)/(1024*1024) 337 | add_size = (add_numel*2*28)/(1024*1024) 338 | if save_delta: 339 | save_dict = {'mask_next':mask_next, 340 | 'addition_density':addition_density, 341 | 'addition_sh':addition_sh, 342 | 'part2_keep_area':part2_keep_area, 343 | 'keep_density':next_density_data[part2_keep_area], 344 | 'keep_sh':next_sh_data[part2_keep_area] 345 | } 346 | save_path = os.path.join(args.train_dir,'grid_delta',f'{saving_name}.pth') 347 | logger.info(f'svaing delta to : {save_path} ') 348 | torch.save(save_dict, save_path) 349 | logger.info(f"keep element: {keep_numel}/{keep_percent:.2f}/{keep_size:.2f} MB, add element: {add_numel}/{add_percent:.2f}/{add_size:.2f} MB") 350 | 351 | if save_delta: 352 | all_in_one = { 353 | 'mask_next':np.packbits(mask_next.cpu().numpy()), 354 | 'mask_keep':np.packbits(part2_keep_area.cpu().numpy()) , 355 | 'addition_density':addition_density.cpu().numpy().astype(np.float16), 356 | 'addition_sh':addition_sh.cpu().numpy().astype(np.float16), 357 | 'keep_density':next_density_data[part2_keep_area].cpu().numpy().astype(np.float16), 358 | 'keep_sh':next_sh_data[part2_keep_area].cpu().numpy().astype(np.float16) 359 | } 360 | save_path = os.path.join(args.train_dir,'grid_delta_z',f'{saving_name}.npz') 361 | 362 | np.savez_compressed(save_path, all_in_one) 363 | logger.info(f'saving delta z to : {save_path}') 364 | logger.info(f'saving size after compression: 
{os.path.getsize(save_path)/(1024*1024):.2f} MB') 365 | 366 | remain_sh_data[part2_keep_area] = next_sh_data[part2_keep_area] 367 | remain_density_data[part2_keep_area] = next_density_data[part2_keep_area] 368 | 369 | new_sh_data[add_area_in_saprse,:] = addition_sh 370 | new_density_data[add_area_in_saprse,:] = addition_density 371 | new_sh_data[~add_area_in_saprse,:] = remain_sh_data 372 | new_density_data[~add_area_in_saprse,:] = remain_density_data 373 | # though new_links equal to grid_next.links, we still calculate a mask for better scalability 374 | new_mask = torch.logical_or(add_area, mask_pre) 375 | new_mask = torch.logical_and(new_mask, ~minus_area) 376 | new_links = torch.cumsum(new_mask.view(-1).to(torch.int32), dim=-1).int() - 1 377 | new_links[~new_mask.view(-1)] = -1 378 | 379 | 380 | grid_holder.sh_data = nn.Parameter(new_sh_data) 381 | grid_holder.density_data = nn.Parameter(new_density_data) 382 | grid_holder.links = new_links.view(grid_next.links.shape).to(device=device) 383 | 384 | if args.keep_rms_data: 385 | grid_holder.sh_rms = grid_next.sh_rms 386 | grid_holder.density_rms = grid_next.density_rms 387 | 388 | logger.debug(f"compress saving finish") 389 | 390 | return grid_holder 391 | 392 | 393 | def dilated_voxel_grid(dilate_rate = 2): 394 | active_mask = grid.links >= 0 395 | dilate_before = active_mask 396 | for i in range(dilate_rate): 397 | active_mask = _C.dilate(active_mask) 398 | # reactivate = torch.logical_xor(active_mask, dilate_before) 399 | new_cap = active_mask.sum() 400 | previous_sparse_area = dilate_before[active_mask] 401 | 402 | new_density = torch.zeros((new_cap,1), device=device).float() 403 | new_sh = torch.zeros((new_cap, grid.basis_dim*3), device=device).float() 404 | 405 | new_density[previous_sparse_area,:] = grid.density_data.data 406 | new_sh[previous_sparse_area,:] = grid.sh_data.data 407 | 408 | active_mask = active_mask.view(-1) 409 | new_links = torch.cumsum(active_mask.to(torch.int32), dim=-1).int() - 1 410 | new_links[~active_mask] = -1 411 | 412 | grid.density_data = torch.nn.Parameter(new_density) 413 | grid.sh_data = torch.nn.Parameter(new_sh) 414 | grid.links = new_links.view(grid.links.shape).to(device=device) 415 | 416 | 417 | def sparsify_voxel_grid(grid, factor=[1,1,1],dilate=2): 418 | reso = grid.links.shape 419 | reso = [int(r * fac) for r, fac in zip(reso, factor)] 420 | grid.resample(reso=reso, 421 | sigma_thresh=args.density_thresh, 422 | weight_thresh=0.0, 423 | dilate=dilate, 424 | cameras= None, 425 | max_elements=args.max_grid_elements, 426 | accelerate=False) 427 | 428 | 429 | 430 | if args.dilate_rate_after > 0: 431 | logger.debug("sparsify first!!!!") 432 | sparsify_voxel_grid(grid,dilate=args.dilate_rate_after) 433 | 434 | 435 | # LR related 436 | lr_sigma_func = get_expon_lr_func(args.lr_sigma, args.lr_sigma_final, args.lr_sigma_delay_steps, 437 | args.lr_sigma_delay_mult, args.lr_sigma_decay_steps) 438 | lr_sh_func = get_expon_lr_func(args.lr_sh, args.lr_sh_final, args.lr_sh_delay_steps, 439 | args.lr_sh_delay_mult, args.lr_sh_decay_steps) 440 | lr_sigma_factor = 1.0 441 | lr_sh_factor = 1.0 442 | 443 | 444 | 445 | grid_raw = grid_copy(grid, device=device) 446 | 447 | 448 | from torch.multiprocessing import Queue, Process 449 | from queue import Empty 450 | frame_idx_queue = Queue() 451 | dset_queue = Queue() 452 | 453 | def pre_fetch_dataset(): 454 | while True: 455 | try: 456 | frame_idx = frame_idx_queue.get(block=True,timeout=60) 457 | except Empty: 458 | logger.debug('ending data prefetch process') 459 | 
return 460 | data_dir = os.path.join(args.data_dir, f'{frame_idx:04d}') 461 | train_dir = args.train_dir 462 | factor = 1 463 | dset_train = datasets[args.dataset_type]( 464 | data_dir, 465 | split="train", 466 | device=device, 467 | factor=factor, 468 | n_images=args.n_train, 469 | train_dir = train_dir, 470 | train_use_all=args.train_use_all, 471 | offset=args.offset, 472 | verbose=False, 473 | **config_util.build_data_options(args)) 474 | 475 | # dataset used to render test image, can include training camera for better visualization 476 | dset_test = datasets[args.dataset_type]( 477 | data_dir, split= 'train' if args.render_all else "test", train_use_all=1 if args.render_all else 0,offset=args.offset, verbose=False, **config_util.build_data_options(args)) 478 | 479 | # # dataset used for PSNR caculation 480 | dset_eval = datasets[args.dataset_type]( 481 | data_dir, split="test", train_use_all=0,offset=args.offset, verbose=False, **config_util.build_data_options(args)) 482 | 483 | logger.debug(f"finish loading frame:{frame_idx}") 484 | dset_queue.put((dset_train,dset_test, dset_eval)) 485 | return dset_train, dset_test, dset_eval 486 | 487 | def pre_fetch_dataset_standalone(frame_idx): 488 | data_dir = os.path.join(args.data_dir, f'{frame_idx:04d}') 489 | train_dir = args.train_dir 490 | factor = 1 491 | dset_train = datasets[args.dataset_type]( 492 | data_dir, 493 | split="train", 494 | device=device, 495 | factor=factor, 496 | n_images=args.n_train, 497 | train_dir = train_dir, 498 | train_use_all=args.train_use_all, 499 | offset=args.offset, 500 | verbose=False, 501 | **config_util.build_data_options(args)) 502 | 503 | # dataset used to render test image, can include training camera for better visualization 504 | dset_test = datasets[args.dataset_type]( 505 | data_dir, split= 'train' if args.render_all else "test", train_use_all=1 if args.render_all else 0,offset=args.offset, verbose=False, **config_util.build_data_options(args)) 506 | 507 | # # dataset used for PSNR caculation 508 | dset_eval = datasets[args.dataset_type]( 509 | data_dir, split="test", train_use_all=0,offset=args.offset, verbose=False, **config_util.build_data_options(args)) 510 | 511 | logger.debug(f"finish loading frame:{frame_idx}") 512 | return dset_train, dset_test, dset_eval 513 | 514 | 515 | 516 | def deploy_dset(dset): 517 | dset.c2w = torch.from_numpy(dset.c2w) 518 | dset.gt = torch.from_numpy(dset.gt).float() 519 | if not dset.is_train_split: 520 | dset.render_c2w = torch.from_numpy(dset.render_c2w) 521 | else: 522 | dset.gen_rays() 523 | return dset 524 | 525 | 526 | 527 | def finetune_one_frame(frame_idx, global_step_base, dsets): 528 | if args.compress_saving: 529 | grid_pre = grid_copy(old_grid = grid, device=device) 530 | with torch.no_grad(): 531 | if args.apply_narrow_band: 532 | active_mask = grid.links>= 0 533 | dmask = active_mask.clone() 534 | for _ in range(args.dilate_rate_before): 535 | dmask = _C.dilate(dmask) 536 | emask = ~active_mask 537 | for _ in range(6): 538 | emask = _C.dilate(emask) 539 | emask = ~emask 540 | narrow_band = torch.logical_xor(dmask, emask) 541 | 542 | 543 | if args.dilate_rate_before > 0: 544 | dilated_voxel_grid(dilate_rate=args.dilate_rate_before) 545 | 546 | if args.apply_narrow_band: 547 | grad_mask = narrow_band[grid.links>=0] 548 | grad_mask = grad_mask.view(-1) 549 | else: 550 | grad_mask = (torch.ones([1]).float().cuda() == 1) 551 | 552 | 553 | train_dir = args.train_dir 554 | 555 | dset_train, dset_test, dset_eval = dsets 556 | dset_train = 
deploy_dset(dset_train) 557 | dset_test = deploy_dset(dset_test) 558 | dset_eval = deploy_dset(dset_eval) 559 | 560 | epoch_id = -1 561 | global_start_time = datetime.now() 562 | gstep_id_base = 0 563 | 564 | 565 | shuffle_step = args.n_iters 566 | dset_train.epoch_size = shuffle_step * args.batch_size 567 | 568 | 569 | timer_dict = {'forward':0, 'regularization':0, 'optimization':0,'preparation':0,'narrowband':0} 570 | max_step = args.n_iters * (10 if frame_idx == 0 else 1) 571 | grid.accelerate() 572 | for gstep_id in tqdm(range(0, max_step)): 573 | if gstep_id==0 or gstep_id % shuffle_step == 0: 574 | with torch.no_grad(): 575 | dset_train.shuffle_rays() 576 | logger.debug('shuffle') 577 | 578 | def train_step(timer_dict): 579 | 580 | #============================= ray preparation stage ============================= 581 | tic = time.time() 582 | stats = {"mse" : 0.0, "psnr" : 0.0, "invsqr_mse" : 0.0} 583 | 584 | bstep_id = gstep_id % shuffle_step 585 | batch_begin = bstep_id * args.batch_size 586 | 587 | lr_sigma = lr_sigma_func(gstep_id) * lr_sigma_factor 588 | lr_sh = lr_sh_func(gstep_id) * lr_sh_factor 589 | if not args.lr_decay: 590 | lr_sigma = args.lr_sigma * lr_sigma_factor 591 | lr_sh = args.lr_sh * lr_sh_factor 592 | 593 | 594 | batch_end = batch_begin + args.batch_size 595 | batch_origins = dset_train.rays.origins[batch_begin: batch_end] 596 | batch_dirs = dset_train.rays.dirs[batch_begin: batch_end] 597 | 598 | rgb_gt = dset_train.rays.gt[batch_begin: batch_end] 599 | rays = svox2.Rays(batch_origins, batch_dirs) 600 | if args.debug: 601 | torch.cuda.synchronize() 602 | timer_dict['preparation'] += (time.time() - tic) 603 | 604 | 605 | #============================= forward stage ============================= 606 | tic = time.time() 607 | 608 | rgb_pred = grid.volume_render_fused(rays, rgb_gt, 609 | beta_loss=args.lambda_beta, 610 | sparsity_loss=args.lambda_sparsity, 611 | randomize=args.enable_random) 612 | 613 | if args.debug: 614 | torch.cuda.synchronize() 615 | timer_dict['forward'] += (time.time() - tic) 616 | 617 | if not args.performance_mode: 618 | with torch.no_grad(): 619 | mse = F.mse_loss(rgb_gt, rgb_pred) 620 | mse_num : float = mse.detach().item() 621 | psnr = -10.0 * math.log10(mse_num) 622 | stats['mse'] += mse_num 623 | stats['psnr'] += psnr 624 | stats['invsqr_mse'] += 1.0 / mse_num ** 2 625 | 626 | if (gstep_id + 1) % args.print_every == 0: 627 | 628 | for stat_name in stats: 629 | stat_val = stats[stat_name] / args.print_every 630 | summary_writer.add_scalar(stat_name, stat_val, global_step=gstep_id+global_step_base) 631 | stats[stat_name] = 0.0 632 | summary_writer.add_scalar("lr_sh", lr_sh, global_step=gstep_id+global_step_base) 633 | summary_writer.add_scalar("lr_sigma", lr_sigma, global_step=gstep_id+global_step_base) 634 | 635 | #============================= regularization stage ============================= 636 | tic = time.time() 637 | # Apply TV/Sparsity regularizers 638 | if args.lambda_tv > 0.0: 639 | # with Timing("tv_inpl"): 640 | grid.inplace_tv_grad(grid.density_data.grad, 641 | scaling=args.lambda_tv, 642 | sparse_frac=args.tv_sparsity , 643 | logalpha=args.tv_logalpha, 644 | ndc_coeffs=dset_train.ndc_coeffs, 645 | contiguous=args.tv_contiguous) 646 | 647 | if args.lambda_tv_sh > 0.0: 648 | # with Timing("tv_color_inpl"): 649 | grid.inplace_tv_color_grad(grid.sh_data.grad, 650 | scaling=args.lambda_tv_sh, 651 | sparse_frac=args.tv_sh_sparsity, 652 | ndc_coeffs=dset_train.ndc_coeffs, 653 | contiguous=args.tv_contiguous) 654 | 655 | if 
args.lambda_tv_lumisphere > 0.0: 656 | grid.inplace_tv_lumisphere_grad(grid.sh_data.grad, 657 | scaling=args.lambda_tv_lumisphere, 658 | dir_factor=args.tv_lumisphere_dir_factor, 659 | sparse_frac=args.tv_lumisphere_sparsity, 660 | ndc_coeffs=dset_train.ndc_coeffs) 661 | if args.lambda_l2_sh > 0.0: 662 | grid.inplace_l2_color_grad(grid.sh_data.grad, 663 | scaling=args.lambda_l2_sh) 664 | 665 | if args.debug: 666 | torch.cuda.synchronize() 667 | timer_dict['regularization'] += (time.time() - tic) 668 | 669 | 670 | #============================= narrow band stage ============================= 671 | tic = time.time() 672 | grid.sparse_sh_grad_indexer &= grad_mask 673 | grid.sparse_grad_indexer &= grad_mask 674 | if args.debug: 675 | torch.cuda.synchronize() 676 | timer_dict['narrowband'] += (time.time() - tic) 677 | 678 | #============================= optimization stage ============================= 679 | tic = time.time() 680 | grid.optim_density_step(lr_sigma, beta=args.rms_beta, optim=args.sigma_optim) 681 | grid.optim_sh_step(lr_sh, beta=args.rms_beta, optim=args.sh_optim) 682 | if args.debug: 683 | torch.cuda.synchronize() 684 | timer_dict['optimization'] += (time.time() - tic) 685 | 686 | def eval_step(): 687 | with torch.no_grad(): 688 | stats_test = {'mse' : 0.0, 'psnr' : 0.0} 689 | # Standard set 690 | N_IMGS_TO_EVAL = min(20 if epoch_id > 0 else 5, dset_eval.n_images) 691 | N_IMGS_TO_SAVE = N_IMGS_TO_EVAL # if not args.tune_mode else 1 692 | img_eval_interval = dset_eval.n_images // N_IMGS_TO_EVAL 693 | img_save_interval = (N_IMGS_TO_EVAL // N_IMGS_TO_SAVE) 694 | img_ids = range(0, dset_eval.n_images, img_eval_interval) 695 | n_images_gen = 0 696 | for i, img_id in tqdm(enumerate(img_ids), total=len(img_ids)): 697 | c2w = dset_eval.c2w[img_id].to(device=device) 698 | cam = svox2.Camera(c2w, 699 | dset_eval.intrins.get('fx', img_id), 700 | dset_eval.intrins.get('fy', img_id), 701 | dset_eval.intrins.get('cx', img_id), 702 | dset_eval.intrins.get('cy', img_id), 703 | width=dset_eval.get_image_size(img_id)[1], 704 | height=dset_eval.get_image_size(img_id)[0], 705 | ndc_coeffs=dset_eval.ndc_coeffs) 706 | rgb_pred_test = grid.volume_render_image(cam, use_kernel=True) 707 | rgb_gt_test = dset_eval.gt[img_id].to(device=device) 708 | all_mses = ((rgb_gt_test - rgb_pred_test) ** 2).cpu() 709 | if i % img_save_interval == 0: 710 | img_pred = rgb_pred_test.cpu() 711 | img_pred.clamp_max_(1.0) 712 | summary_writer.add_image(f'test/image_{img_id:04d}', 713 | img_pred, global_step=frame_idx, dataformats='HWC') 714 | if args.log_mse_image: 715 | mse_img = all_mses / all_mses.max() 716 | summary_writer.add_image(f'test/mse_map_{img_id:04d}', 717 | mse_img, global_step=frame_idx, dataformats='HWC') 718 | if False or args.log_depth_map: 719 | depth_img = grid.volume_render_depth_image(cam, 720 | args.log_depth_map_use_thresh if 721 | args.log_depth_map_use_thresh else None 722 | ) 723 | depth_img = viridis_cmap(depth_img.cpu()) 724 | summary_writer.add_image(f'test/depth_map_{img_id:04d}', 725 | depth_img, 726 | global_step=frame_idx, dataformats='HWC') 727 | 728 | rgb_pred_test = rgb_gt_test = None 729 | mse_num : float = all_mses.mean().item() 730 | psnr = -10.0 * math.log10(mse_num) 731 | if math.isnan(psnr): 732 | print('NAN PSNR', i, img_id, mse_num) 733 | assert False 734 | stats_test['mse'] += mse_num 735 | stats_test['psnr'] += psnr 736 | n_images_gen += 1 737 | stats_test['mse'] /= n_images_gen 738 | stats_test['psnr'] /= n_images_gen 739 | for stat_name in stats_test: 740 | 
summary_writer.add_scalar('test/' + stat_name, 741 | stats_test[stat_name], global_step=gstep_id_base+global_step_base) 742 | summary_writer.add_scalar('epoch_id', float(epoch_id), global_step=gstep_id_base+global_step_base) 743 | print('eval stats:', stats_test) 744 | logger.critical(f"per_frame_psnr: {frame_idx} {psnr}") 745 | return psnr 746 | if args.debug: 747 | torch.cuda.synchronize() 748 | 749 | tic = time.time() 750 | train_step(timer_dict) 751 | if args.debug: 752 | torch.cuda.synchronize() 753 | 754 | if gstep_id == max_step - 1: 755 | global_stop_time = datetime.now() 756 | line = '' 757 | for k,v in timer_dict.items(): 758 | line += f'{k}:sum: {v:.3f} sec / avg:{(v*1000)/max_step:.3f} ms, ' 759 | logger.info(line) 760 | secs = (global_stop_time - global_start_time).total_seconds() 761 | logger.info(f'cost: {secs}, s') 762 | psnr = eval_step() 763 | 764 | break 765 | 766 | 767 | if args.dilate_rate_after or args.dilate_rate_before: 768 | 769 | sparsify_voxel_grid(grid, dilate=args.dilate_rate_after) 770 | 771 | @torch.no_grad() 772 | def preprune(grid_pre, grid_next): 773 | 774 | mask_pre = grid_pre.links>=0 775 | mask_next = grid_next.links>=0 776 | new_cap = mask_next.sum() 777 | 778 | diff_area = torch.logical_xor(mask_pre, mask_next) 779 | 780 | add_area = (diff_area & mask_next) 781 | minus_area = (diff_area & mask_pre) 782 | logger.info(f"diff area before preprune: {diff_area.sum()} add area: {add_area.sum()} minus area: {minus_area.sum()} ") 783 | addition_density = grid_next.density_data[grid_next.links[add_area].long()] 784 | addition_sh = grid_next.sh_data[grid_next.links[add_area].long()] 785 | 786 | no_need_area = (abs(addition_sh).sum(-1) name mappings 171 | """ 172 | print(' Loading single CO3D sequence:', 173 | self.seq_cats[sequence_id], self.seq_names[sequence_id]) 174 | self.curr_seq_cat = self.seq_cats[sequence_id] 175 | self.curr_seq_name = self.seq_names[sequence_id] 176 | self.curr_offset = self.seq_offsets[sequence_id] 177 | self.next_offset = self.seq_offsets[sequence_id + 1] 178 | self.gt = [] 179 | fxs, fys, cxs, cys = [], [], [], [] 180 | image_sizes = [] 181 | c2ws = [] 182 | ref_c2ws = [] 183 | for i in tqdm(range(self.curr_offset, self.next_offset)): 184 | is_train = i % self.hold_every != 0 185 | ref_c2ws.append(self.image_pose[i]) 186 | if self.split.endswith('train') != is_train: 187 | continue 188 | im = cv2.imread(path.join(self.data_dir, self.image_path[i])) 189 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 190 | im = im[..., :3] 191 | h, w, _ = im.shape 192 | max_hw = max(h, w) 193 | approx_scale = self.max_image_dim / max_hw 194 | 195 | if approx_scale < 1.0: 196 | h2 = int(approx_scale * h) 197 | w2 = int(approx_scale * w) 198 | im = cv2.resize(im, (w2, h2), interpolation=cv2.INTER_AREA) 199 | else: 200 | h2 = h 201 | w2 = w 202 | scale = np.array([w2 / w, h2 / h], dtype=np.float32) 203 | image_sizes.append(np.array([h2, w2])) 204 | cxy = self.cxy[i] * scale 205 | fxy = self.fxy[i] * scale 206 | fxs.append(fxy[0]) 207 | fys.append(fxy[1]) 208 | cxs.append(cxy[0]) 209 | cys.append(cxy[1]) 210 | # grid = data_util.gen_grid(h2, w2, cxy.astype(np.float32), normalize_scale=False) 211 | # grid /= fxy.astype(np.float32) 212 | self.gt.append(torch.from_numpy(im)) 213 | c2ws.append(self.image_pose[i]) 214 | c2w = np.stack(c2ws, axis=0) 215 | ref_c2ws = np.stack(ref_c2ws, axis=0) # For rescaling scene 216 | self.image_size = np.stack(image_sizes) 217 | fxs = torch.tensor(fxs) 218 | fys = torch.tensor(fys) 219 | cxs = 
torch.tensor(cxs) 220 | cys = torch.tensor(cys) 221 | 222 | # Filter out crazy poses 223 | dists = np.linalg.norm(c2w[:, :3, 3] - np.median(c2w[:, :3, 3], axis=0), axis=-1) 224 | med = np.median(dists) 225 | good_mask = dists < med * self.max_pose_dist 226 | c2w = c2w[good_mask] 227 | self.image_size = self.image_size[good_mask] 228 | good_idx = np.where(good_mask)[0] 229 | self.gt = [self.gt[i] for i in good_idx] 230 | 231 | self.intrins_full = Intrin(fxs[good_mask], fys[good_mask], 232 | cxs[good_mask], cys[good_mask]) 233 | 234 | # Normalize 235 | # c2w[:, :3, 3] -= np.mean(c2w[:, :3, 3], axis=0) 236 | # dists = np.linalg.norm(c2w[:, :3, 3], axis=-1) 237 | # c2w[:, :3, 3] *= self.cam_scale_factor / np.median(dists) 238 | 239 | T, sscale = similarity_from_cameras(ref_c2ws) 240 | c2w = T @ c2w 241 | c2w[:, :3, 3] *= self.cam_scale_factor * sscale 242 | 243 | self.c2w = torch.from_numpy(c2w).float() 244 | self.cam_n_rays = self.image_size[:, 0] * self.image_size[:, 1] 245 | self.n_images = len(self.gt) 246 | self.image_size_full = self.image_size 247 | 248 | if self.split == "train": 249 | self.gen_rays(factor=1) 250 | else: 251 | # Rays are not needed for testing 252 | self.intrins : Intrin = self.intrins_full 253 | 254 | 255 | def gen_rays(self, factor=1): 256 | print(" Generating rays, scaling factor", factor) 257 | # Generate rays 258 | self.factor = factor 259 | self.image_size = self.image_size_full // factor 260 | true_factor = self.image_size_full[:, 0] / self.image_size[:, 0] 261 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 262 | 263 | all_origins = [] 264 | all_dirs = [] 265 | all_gts = [] 266 | for i in tqdm(range(self.n_images)): 267 | yy, xx = torch.meshgrid( 268 | torch.arange(self.image_size[i, 0], dtype=torch.float32) + 0.5, 269 | torch.arange(self.image_size[i, 1], dtype=torch.float32) + 0.5, 270 | ) 271 | xx = (xx - self.intrins.get('cx', i)) / self.intrins.get('fx', i) 272 | yy = (yy - self.intrins.get('cy', i)) / self.intrins.get('fy', i) 273 | zz = torch.ones_like(xx) 274 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 275 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 276 | dirs = dirs.reshape(-1, 3, 1) 277 | del xx, yy, zz 278 | dirs = (self.c2w[i, None, :3, :3] @ dirs)[..., 0] 279 | 280 | if factor != 1: 281 | gt = F.interpolate( 282 | self.gt[i].permute([2, 0, 1])[None], size=(self.image_size[i, 0], 283 | self.image_size[i, 1]), 284 | mode="area" 285 | )[0].permute([1, 2, 0]) 286 | gt = gt.reshape(-1, 3) 287 | else: 288 | gt = self.gt[i].reshape(-1, 3) 289 | origins = self.c2w[i, None, :3, 3].expand(self.image_size[i, 0] * 290 | self.image_size[i, 1], -1).contiguous() 291 | all_origins.append(origins) 292 | all_dirs.append(dirs) 293 | all_gts.append(gt) 294 | origins = all_origins 295 | dirs = all_dirs 296 | gt = all_gts 297 | 298 | if self.split == "train": 299 | origins = torch.cat([o.view(-1, 3) for o in origins], dim=0) 300 | dirs = torch.cat([o.view(-1, 3) for o in dirs], dim=0) 301 | gt = torch.cat([o.reshape(-1, 3) for o in gt], dim=0) 302 | 303 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 304 | self.rays = self.rays_init 305 | -------------------------------------------------------------------------------- /util/config_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from util.dataset import datasets 4 | import json 5 | 6 | 7 | def define_common_args(parser : argparse.ArgumentParser): 8 | parser.add_argument('data_dir', type=str) 9 
| 10 | parser.add_argument('--config', '-c', 11 | type=str, 12 | default=None, 13 | help="Config yaml file (will override args)") 14 | 15 | group = parser.add_argument_group("Data loading") 16 | group.add_argument('--dataset_type', 17 | choices=list(datasets.keys()) + ["auto"], 18 | default="auto", 19 | help="Dataset type (specify type or use auto)") 20 | group.add_argument('--scene_scale', 21 | type=float, 22 | default=None, 23 | help="Global scene scaling (or use dataset default)") 24 | group.add_argument('--scale', 25 | type=float, 26 | default=None, 27 | help="Image scale, e.g. 0.5 for half resolution (or use dataset default)") 28 | group.add_argument('--seq_id', 29 | type=int, 30 | default=1000, 31 | help="Sequence ID (for CO3D only)") 32 | group.add_argument('--epoch_size', 33 | type=int, 34 | default=12800, 35 | help="Pseudo-epoch size in term of batches (to be consistent across datasets)") 36 | group.add_argument('--white_bkgd', 37 | type=bool, 38 | default=True, 39 | help="Whether to use white background (ignored in some datasets)") 40 | group.add_argument('--llffhold', 41 | type=int, 42 | default=8, 43 | help="LLFF holdout every") 44 | group.add_argument('--normalize_by_bbox', 45 | type=bool, 46 | default=False, 47 | help="Normalize by bounding box in bbox.txt, if available (NSVF dataset only); precedes normalize_by_camera") 48 | group.add_argument('--data_bbox_scale', 49 | type=float, 50 | default=1.2, 51 | help="Data bbox scaling (NSVF dataset only)") 52 | group.add_argument('--cam_scale_factor', 53 | type=float, 54 | default=0.95, 55 | help="Camera autoscale factor (NSVF/CO3D dataset only)") 56 | group.add_argument('--normalize_by_camera', 57 | type=bool, 58 | default=True, 59 | help="Normalize using cameras, assuming a 360 capture (NSVF dataset only); only used if not normalize_by_bbox") 60 | group.add_argument('--perm', action='store_true', default=False, 61 | help='sample by permutation of rays (true epoch) instead of ' 62 | 'uniformly random rays') 63 | 64 | group = parser.add_argument_group("Render options") 65 | group.add_argument('--step_size', 66 | type=float, 67 | default=0.5, 68 | help="Render step size (in voxel size units)") 69 | group.add_argument('--sigma_thresh', 70 | type=float, 71 | default=1e-8, 72 | help="Skips voxels with sigma < this") 73 | group.add_argument('--stop_thresh', 74 | type=float, 75 | default=1e-7, 76 | help="Ray march stopping threshold") 77 | group.add_argument('--background_brightness', 78 | type=float, 79 | default=1.0, 80 | help="Brightness of the infinite background") 81 | group.add_argument('--renderer_backend', '-B', 82 | choices=['cuvol', 'svox1', 'nvol'], 83 | default='cuvol', 84 | help="Renderer backend") 85 | group.add_argument('--random_sigma_std', 86 | type=float, 87 | default=0.0, 88 | help="Random Gaussian std to add to density values (only if enable_random)") 89 | group.add_argument('--random_sigma_std_background', 90 | type=float, 91 | default=0.0, 92 | help="Random Gaussian std to add to density values for BG (only if enable_random)") 93 | group.add_argument('--near_clip', 94 | type=float, 95 | default=0.00, 96 | help="Near clip distance (in world space distance units, only for FG)") 97 | group.add_argument('--use_spheric_clip', 98 | action='store_true', 99 | default=False, 100 | help="Use spheric ray clipping instead of voxel grid AABB " 101 | "(only for FG; changes near_clip to mean 1-near_intersection_radius; " 102 | "far intersection is always at radius 1)") 103 | group.add_argument('--enable_random', 104 | 
action='store_true', 105 | default=False, 106 | help="Random Gaussian std to add to density values") 107 | group.add_argument('--last_sample_opaque', 108 | action='store_true', 109 | default=False, 110 | help="Last sample has +1e9 density (used for LLFF)") 111 | 112 | 113 | def build_data_options(args): 114 | """ 115 | Arguments to pass as kwargs to the dataset constructor 116 | """ 117 | return { 118 | 'dataset_type': args.dataset_type, 119 | 'seq_id': args.seq_id, 120 | 'epoch_size': args.epoch_size * args.__dict__.get('batch_size', 5000), 121 | 'scene_scale': args.scene_scale, 122 | 'scale': args.scale, 123 | 'white_bkgd': args.white_bkgd, 124 | 'hold_every': args.llffhold, 125 | 'normalize_by_bbox': args.normalize_by_bbox, 126 | 'data_bbox_scale': args.data_bbox_scale, 127 | 'cam_scale_factor': args.cam_scale_factor, 128 | 'normalize_by_camera': args.normalize_by_camera, 129 | 'permutation': args.perm 130 | } 131 | 132 | def maybe_merge_config_file(args, allow_invalid=False): 133 | """ 134 | Load json config file if specified and merge the arguments 135 | """ 136 | if args.config is not None: 137 | with open(args.config, "r") as config_file: 138 | configs = json.load(config_file) 139 | invalid_args = list(set(configs.keys()) - set(dir(args))) 140 | if invalid_args and not allow_invalid: 141 | raise ValueError(f"Invalid args {invalid_args} in {args.config}.") 142 | args.__dict__.update(configs) 143 | 144 | def setup_render_opts(opt, args): 145 | """ 146 | Pass render arguments to the SparseGrid renderer options 147 | """ 148 | opt.step_size = args.step_size 149 | opt.sigma_thresh = args.sigma_thresh 150 | opt.stop_thresh = args.stop_thresh 151 | opt.background_brightness = args.background_brightness 152 | opt.backend = args.renderer_backend 153 | opt.random_sigma_std = args.random_sigma_std 154 | opt.random_sigma_std_background = args.random_sigma_std_background 155 | opt.last_sample_opaque = args.last_sample_opaque 156 | opt.near_clip = args.near_clip 157 | opt.use_spheric_clip = args.use_spheric_clip 158 | -------------------------------------------------------------------------------- /util/dataset.py: -------------------------------------------------------------------------------- 1 | from .nerf_dataset import NeRFDataset 2 | from .llff_dataset import LLFFDataset 3 | from .nsvf_dataset import NSVFDataset 4 | from .co3d_dataset import CO3DDataset 5 | from os import path 6 | 7 | def auto_dataset(root : str, *args, **kwargs): 8 | if path.isfile(path.join(root, 'apple', 'eval_batches_multisequence.json')): 9 | print("Detected CO3D dataset") 10 | return CO3DDataset(root, *args, **kwargs) 11 | elif path.isfile(path.join(root, 'poses_bounds.npy')): 12 | print("Detected LLFF dataset") 13 | return LLFFDataset(root, *args, **kwargs) 14 | elif path.isfile(path.join(root, 'transforms.json')) or \ 15 | path.isfile(path.join(root, 'transforms_train.json')): 16 | print("Detected NeRF (Blender) dataset") 17 | return NeRFDataset(root, *args, **kwargs) 18 | else: 19 | print("Defaulting to extended NSVF dataset") 20 | return NSVFDataset(root, *args, **kwargs) 21 | 22 | datasets = { 23 | 'nerf': NeRFDataset, 24 | 'llff': LLFFDataset, 25 | 'nsvf': NSVFDataset, 26 | 'co3d': CO3DDataset, 27 | 'auto': auto_dataset 28 | } 29 | -------------------------------------------------------------------------------- /util/dataset_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Union, Optional, List 4 
| from .util import select_or_shuffle_rays, Rays, Intrin 5 | 6 | class DatasetBase: 7 | split: str 8 | permutation: bool 9 | epoch_size: Optional[int] 10 | n_images: int 11 | h_full: int 12 | w_full: int 13 | intrins_full: Intrin 14 | c2w: torch.Tensor # C2W OpenCV poses 15 | gt: Union[torch.Tensor, List[torch.Tensor]] # RGB images 16 | device : Union[str, torch.device] 17 | 18 | def __init__(self): 19 | self.ndc_coeffs = (-1, -1) 20 | self.use_sphere_bound = False 21 | self.should_use_background = True # a hint 22 | self.use_sphere_bound = True 23 | self.scene_center = [0.0, 0.0, 0.0] 24 | self.scene_radius = [1.0, 1.0, 1.0] 25 | self.permutation = False 26 | 27 | def shuffle_rays(self): 28 | """ 29 | Shuffle all rays 30 | """ 31 | if self.split == "train": 32 | del self.rays 33 | self.rays = select_or_shuffle_rays(self.rays_init, self.permutation, 34 | self.epoch_size, self.device) 35 | 36 | def gen_rays(self, factor=1): 37 | print(" Generating rays, scaling factor", factor) 38 | # Generate rays 39 | self.factor = factor 40 | self.h = self.h_full // factor 41 | self.w = self.w_full // factor 42 | true_factor = self.h_full / self.h 43 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 44 | yy, xx = torch.meshgrid( 45 | torch.arange(self.h, dtype=torch.float32) + 0.5, 46 | torch.arange(self.w, dtype=torch.float32) + 0.5, 47 | ) 48 | xx = (xx - self.intrins.cx) / self.intrins.fx 49 | yy = (yy - self.intrins.cy) / self.intrins.fy 50 | zz = torch.ones_like(xx) 51 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 52 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 53 | dirs = dirs.reshape(1, -1, 3, 1) 54 | del xx, yy, zz 55 | dirs = (self.c2w[:, None, :3, :3] @ dirs)[..., 0] 56 | 57 | if factor != 1: 58 | gt = F.interpolate( 59 | self.gt.permute([0, 3, 1, 2]), size=(self.h, self.w), mode="area" 60 | ).permute([0, 2, 3, 1]) 61 | gt = gt.reshape(self.n_images, -1, 3) 62 | else: 63 | gt = self.gt.reshape(self.n_images, -1, 3) 64 | origins = self.c2w[:, None, :3, 3].expand(-1, self.h * self.w, -1).contiguous() 65 | if self.split == "train": 66 | origins = origins.view(-1, 3) 67 | dirs = dirs.view(-1, 3) 68 | gt = gt.reshape(-1, 3) 69 | 70 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 71 | self.rays = self.rays_init 72 | 73 | def get_image_size(self, i : int): 74 | # H, W 75 | if hasattr(self, 'image_size'): 76 | return tuple(self.image_size[i]) 77 | else: 78 | return self.h, self.w 79 | -------------------------------------------------------------------------------- /util/llff_dataset.py: -------------------------------------------------------------------------------- 1 | # LLFF-format Forward-facing dataset loader 2 | # Please use the LLFF code to run COLMAP & convert 3 | # 4 | # Adapted from NeX data loading code (NOT using their hand picked bounds) 5 | # Entry point: LLFFDataset 6 | # 7 | # Original: 8 | # Copyright (c) 2021 VISTEC - Vidyasirimedhi Institute of Science and Technology 9 | # Distribute under MIT License 10 | 11 | #from torch.utils.data import Dataset 12 | from scipy.spatial.transform import Rotation 13 | import struct 14 | import json 15 | import glob 16 | import copy 17 | 18 | import numpy as np 19 | import os 20 | import torch 21 | import torch.nn.functional as F 22 | from collections import deque 23 | from tqdm import tqdm 24 | import imageio 25 | import cv2 26 | from .util import Rays, Intrin 27 | from .dataset_base import DatasetBase 28 | from .load_llff import load_llff_data 29 | from typing import Union, Optional 30 | 31 | from 
svox2.utils import convert_to_ndc 32 | 33 | class LLFFDataset(DatasetBase): 34 | """ 35 | LLFF dataset loader adapted from NeX code 36 | Some arguments are inherited from them and not super useful in our case 37 | """ 38 | def __init__( 39 | self, 40 | root : str, 41 | split : str, 42 | epoch_size : Optional[int] = None, 43 | device: Union[str, torch.device] = "cpu", 44 | permutation: bool = True, 45 | factor: int = 1, 46 | ref_img: str="", 47 | scale : Optional[float]=1.0/4.0, # 4x downsample 48 | dmin : float=-1, 49 | dmax : int=-1, 50 | invz : int= 0, 51 | transform=None, 52 | render_style="", 53 | hold_every=8, 54 | offset=250, 55 | verbose=False, 56 | **kwargs 57 | ): 58 | super().__init__() 59 | if scale is None: 60 | scale = 1.0 / 4.0 # Default 1/4 size for LLFF data since it's huge 61 | self.scale = scale 62 | self.dataset = root 63 | self.epoch_size = epoch_size 64 | self.device = device 65 | self.permutation = permutation 66 | self.split = split 67 | self.transform = transform 68 | self.verbose = verbose 69 | self.sfm = SfMData( 70 | root, 71 | ref_img=ref_img, 72 | dmin=dmin, 73 | dmax=dmax, 74 | invz=invz, 75 | scale=scale, 76 | render_style=render_style, 77 | offset=offset, 78 | hold_every=hold_every, 79 | verbose=verbose 80 | ) 81 | 82 | assert len(self.sfm.cams) == 1, \ 83 | "Currently assuming 1 camera for simplicity, " \ 84 | "please feel free to extend" 85 | 86 | self.imgs = [] 87 | is_train_split = split.endswith('train') 88 | for i, ind in enumerate(self.sfm.imgs): 89 | img = self.sfm.imgs[ind] 90 | img_train_split = ind % hold_every > 0 91 | if is_train_split == img_train_split: 92 | self.imgs.append(img) 93 | self.is_train_split = is_train_split 94 | 95 | self._load_images() 96 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 97 | assert self.h_full == self.sfm.ref_cam["height"] 98 | assert self.w_full == self.sfm.ref_cam["width"] 99 | 100 | self.intrins_full = Intrin(self.sfm.ref_cam['fx'], 101 | self.sfm.ref_cam['fy'], 102 | self.sfm.ref_cam['px'], 103 | self.sfm.ref_cam['py']) 104 | 105 | self.ndc_coeffs = (2 * self.intrins_full.fx / self.w_full, 106 | 2 * self.intrins_full.fy / self.h_full) 107 | # if self.split == "train": 108 | # self.gen_rays(factor=factor) 109 | # else: 110 | # # Rays are not needed for testing 111 | self.h, self.w = self.h_full, self.w_full 112 | self.intrins = self.intrins_full 113 | self.should_use_background = False # Give warning 114 | 115 | 116 | def _load_images(self): 117 | scale = self.scale 118 | 119 | all_gt = [] 120 | all_c2w = [] 121 | bottom = np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32) 122 | global_w2rc = np.concatenate([self.sfm.ref_img['r'], self.sfm.ref_img['t']], axis=1) 123 | global_w2rc = np.concatenate([global_w2rc, bottom], axis=0).astype(np.float64) 124 | for idx in tqdm(range(len(self.imgs)), disable=not self.verbose): 125 | R = self.imgs[idx]["R"].astype(np.float64) 126 | t = self.imgs[idx]["center"].astype(np.float64) 127 | c2w = np.concatenate([R, t], axis=1) 128 | c2w = np.concatenate([c2w, bottom], axis=0) 129 | # c2w = global_w2rc @ c2w 130 | all_c2w.append(np.expand_dims(c2w.astype(np.float32),0)) 131 | 132 | if 'path' in self.imgs[idx]: 133 | img_path = self.imgs[idx]["path"] 134 | img_path = os.path.join(self.dataset, img_path) 135 | if not os.path.isfile(img_path): 136 | path_noext = os.path.splitext(img_path)[0] 137 | # Hack: also try png 138 | if os.path.exists(path_noext + '.png'): 139 | img_path = path_noext + '.png' 140 | img = imageio.imread(img_path) 141 | if scale != 1 and not 
self.sfm.use_integral_scaling: 142 | h, w = img.shape[:2] 143 | if self.sfm.dataset_type == "deepview": 144 | newh = int(h * scale) # always floor down height 145 | neww = round(w * scale) 146 | else: 147 | newh = round(h * scale) 148 | neww = round(w * scale) 149 | img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA) 150 | all_gt.append(np.expand_dims(img,0)) 151 | 152 | self.gt = np.concatenate(all_gt, axis=0) / 255.0 153 | 154 | if self.gt.shape[-1] == 4: 155 | # Apply alpha channel 156 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 157 | 158 | self.c2w = np.concatenate(all_c2w, axis=0) 159 | bds_scale = 1.0 160 | self.z_bounds = [self.sfm.dmin * bds_scale, self.sfm.dmax * bds_scale] 161 | if bds_scale != 1.0: 162 | self.c2w[:, :3, 3] *= bds_scale 163 | 164 | 165 | if not self.is_train_split: 166 | render_c2w = [] 167 | for idx in tqdm(range(len(self.sfm.render_poses)), disable= not self.verbose): 168 | R = self.sfm.render_poses[idx]["R"].astype(np.float64) 169 | t = self.sfm.render_poses[idx]["center"].astype(np.float64) 170 | c2w = np.concatenate([R, t], axis=1) 171 | c2w = np.concatenate([c2w, bottom], axis=0) 172 | render_c2w.append(np.expand_dims(c2w.astype(np.float32), axis=0)) 173 | self.render_c2w = np.concatenate(render_c2w) 174 | if bds_scale != 1.0: 175 | self.render_c2w[:, :3, 3] *= bds_scale 176 | fx = self.sfm.ref_cam['fx'] 177 | fy = self.sfm.ref_cam['fy'] 178 | width = self.sfm.ref_cam['width'] 179 | height = self.sfm.ref_cam['height'] 180 | if self.verbose: 181 | print('z_bounds from LLFF:', self.z_bounds, '(not used)') 182 | 183 | # Padded bounds 184 | radx = 1 + 2 * self.sfm.offset / self.gt.shape[2] 185 | rady = 1 + 2 * self.sfm.offset / self.gt.shape[1] 186 | radz = 1.0 187 | self.scene_center = [0.0, 0.0, 0.0] 188 | self.scene_radius = [radx, rady, radz] 189 | if self.verbose: 190 | print('scene_radius', self.scene_radius) 191 | self.use_sphere_bound = False 192 | 193 | def gen_rays(self, factor=1): 194 | super().gen_rays(factor) 195 | # To NDC (currently, we are normalizing these rays unlike NeRF, 196 | # may not be ideal) 197 | origins, dirs = convert_to_ndc( 198 | self.rays.origins, 199 | self.rays.dirs, 200 | self.ndc_coeffs) 201 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 202 | 203 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=self.rays.gt) 204 | self.rays = self.rays_init 205 | 206 | 207 | 208 | class SfMData: 209 | def __init__( 210 | self, 211 | root, 212 | ref_img="", 213 | scale=1, 214 | dmin=0, 215 | dmax=0, 216 | invz=0, 217 | render_style="", 218 | offset=200, 219 | hold_every=8, 220 | verbose=False 221 | ): 222 | self.scale = scale 223 | self.ref_cam = None 224 | self.ref_img = None 225 | self.render_poses = None 226 | self.dmin = dmin 227 | self.dmax = dmax 228 | self.invz = invz 229 | self.dataset = root 230 | self.dataset_type = "unknown" 231 | self.render_style = render_style 232 | self.hold_every = hold_every 233 | self.white_background = False # change background to white if transparent. 
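# Note on `offset`: it is given in pixels and is used downstream to pad the
# reconstruction bounds (LLFFDataset sets scene_radius in x/y to
# 1 + 2 * offset / image_size); selectDepth() may also override it from a
# 4-value planes.txt.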
234 | self.index_split = [] # use for split dataset in blender 235 | self.offset = offset 236 | self.verbose = verbose 237 | # Detect dataset type 238 | can_hanle = ( 239 | self.readDeepview(root) 240 | or self.readLLFF(root, ref_img) 241 | or self.readColmap(root) 242 | ) 243 | if not can_hanle: 244 | raise Exception("Unknow dataset type") 245 | # Dataset processing 246 | self.cleanImgs() 247 | self.selectRef(ref_img) 248 | self.scaleAll(scale) 249 | self.selectDepth(dmin, dmax, offset) 250 | 251 | def cleanImgs(self): 252 | """ 253 | Remvoe non exist image from self.imgs 254 | """ 255 | todel = [] 256 | for image in self.imgs: 257 | img_path = self.dataset + "/" + self.imgs[image]["path"] 258 | if "center" not in self.imgs[image] or not os.path.exists(img_path): 259 | todel.append(image) 260 | for it in todel: 261 | del self.imgs[it] 262 | 263 | def selectRef(self, ref_img): 264 | """ 265 | Select Reference image 266 | """ 267 | if ref_img == "" and self.ref_cam is not None and self.ref_img is not None: 268 | return 269 | for img_id, img in self.imgs.items(): 270 | if ref_img in img["path"]: 271 | self.ref_img = img 272 | self.ref_cam = self.cams[img["camera_id"]] 273 | return 274 | raise Exception("reference view not found") 275 | 276 | def selectDepth(self, dmin, dmax, offset): 277 | """ 278 | Select dmin/dmax from planes.txt / bound.txt / argparse 279 | """ 280 | if self.dmin < 0 or self.dmax < 0: 281 | if os.path.exists(self.dataset + "/bounds.txt"): 282 | with open(self.dataset + "/bounds.txt", "r") as fi: 283 | data = [ 284 | np.reshape(np.matrix([float(y) for y in x.split(" ")]), [3, 1]) 285 | for x in fi.readlines()[3:] 286 | ] 287 | ls = [] 288 | for d in data: 289 | v = self.ref_img["r"] * d + self.ref_img["t"] 290 | ls.append(v[2]) 291 | self.dmin = np.min(ls) 292 | self.dmax = np.max(ls) 293 | self.invz = 0 294 | 295 | elif os.path.exists(self.dataset + "/planes.txt"): 296 | with open(self.dataset + "/planes.txt", "r") as fi: 297 | data = [float(x) for x in fi.readline().split(" ")] 298 | if len(data) == 3: 299 | self.dmin, self.dmax, self.invz = data 300 | elif len(data) == 2: 301 | self.dmin, self.dmax = data 302 | elif len(data) == 4: 303 | self.dmin, self.dmax, self.invz, self.offset = data 304 | self.offset = int(self.offset) 305 | if self.verbose: 306 | print(f"Read offset from planes.txt: {self.offset}") 307 | else: 308 | raise Exception("Malform planes.txt") 309 | else: 310 | if self.verbose: 311 | print("no planes.txt or bounds.txt found") 312 | if dmin > 0: 313 | if self.verbose: 314 | print("Overriding dmin %f-> %f" % (self.dmin, dmin)) 315 | self.dmin = dmin 316 | if dmax > 0: 317 | if self.verbose: 318 | print("Overriding dmax %f-> %f" % (self.dmax, dmax)) 319 | self.dmax = dmax 320 | if offset != 200: 321 | if self.verbose: 322 | print(f"Overriding offset {self.offset}-> {offset}") 323 | self.offset = offset 324 | if self.verbose: 325 | print( 326 | "dmin = %f, dmax = %f, invz = %d, offset = %d" 327 | % (self.dmin, self.dmax, self.invz, self.offset) 328 | ) 329 | 330 | def readLLFF(self, dataset, ref_img=""): 331 | """ 332 | Read LLFF 333 | Parameters: 334 | dataset (str): path to datasets 335 | ref_img (str): ref_image file name 336 | Returns: 337 | bool: return True if successful load LLFF data 338 | """ 339 | if not os.path.exists(os.path.join(dataset, "poses_bounds.npy")): 340 | return False 341 | image_dir = os.path.join(dataset, "images") 342 | if not os.path.exists(image_dir) and not os.path.isdir(image_dir): 343 | return False 344 | 345 | 
self.use_integral_scaling = False 346 | scaled_img_dir = 'images' 347 | scale = self.scale 348 | if scale != 1 and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9: 349 | # Integral scaling 350 | scaled_img_dir = "images_" + str(round(1.0 / scale)) 351 | if os.path.isdir(os.path.join(self.dataset, scaled_img_dir)): 352 | self.use_integral_scaling = True 353 | image_dir = os.path.join(self.dataset, scaled_img_dir) 354 | if self.verbose: 355 | print('Using pre-scaled images from', image_dir) 356 | else: 357 | scaled_img_dir = "images" 358 | 359 | # load R,T 360 | ( 361 | reference_depth, 362 | reference_view_id, 363 | render_poses, 364 | poses, 365 | intrinsic 366 | ) = load_llff_data( 367 | dataset, factor=None, split_train_val=self.hold_every, 368 | render_style=self.render_style 369 | ) 370 | 371 | # NSVF-compatible sort key 372 | def nsvf_sort_key(x): 373 | if len(x) > 2 and x[1] == '_': 374 | return x[2:] 375 | else: 376 | return x 377 | def keep_images(x): 378 | exts = ['.png', '.jpg', '.jpeg', '.exr'] 379 | return [y for y in x if not y.startswith('.') and any((y.lower().endswith(ext) for ext in exts))] 380 | 381 | # get all image of this dataset 382 | images_path = [os.path.join(scaled_img_dir, f) for f in sorted(keep_images(os.listdir(image_dir)), key=nsvf_sort_key)] 383 | # import ipdb;ipdb.set_trace() 384 | # LLFF dataset has only single camera in dataset 385 | if len(intrinsic) == 3: 386 | H, W, f = intrinsic 387 | cx = W / 2.0 388 | cy = H / 2.0 389 | fx = f 390 | fy = f 391 | else: 392 | H, W, fx, fy, cx, cy = intrinsic 393 | 394 | self.cams = {0: buildCamera(W, H, fx, fy, cx, cy)} 395 | 396 | # create render_poses for video render 397 | self.render_poses = buildNerfPoses(render_poses) 398 | 399 | # create imgs pytorch dataset 400 | # we store train and validation together 401 | # but it will sperate later by pytorch dataloader 402 | self.imgs = buildNerfPoses(poses, images_path) 403 | 404 | # if not set ref_cam, use LLFF ref_cam 405 | if ref_img == "": 406 | # restore image id back from reference_view_id 407 | # by adding missing validation index 408 | image_id = reference_view_id + 1 # index 0 alway in validation set 409 | image_id = image_id + (image_id // self.hold_every) # every 8 will be validation set 410 | self.ref_cam = self.cams[0] 411 | 412 | self.ref_img = self.imgs[image_id] # here is reference view from train set 413 | 414 | # if not set dmin/dmax, use LLFF dmin/dmax 415 | if (self.dmin < 0 or self.dmax < 0) and ( 416 | not os.path.exists(dataset + "/planes.txt") 417 | ): 418 | self.dmin = reference_depth[0] 419 | self.dmax = reference_depth[1] 420 | self.dataset_type = "llff" 421 | return True 422 | 423 | def scaleAll(self, scale): 424 | self.ocams = copy.deepcopy(self.cams) # original camera 425 | for cam_id in self.cams.keys(): 426 | cam = self.cams[cam_id] 427 | ocam = self.ocams[cam_id] 428 | 429 | nw = round(ocam["width"] * scale) 430 | nh = round(ocam["height"] * scale) 431 | sw = nw / ocam["width"] 432 | sh = nh / ocam["height"] 433 | cam["fx"] = ocam["fx"] * sw 434 | cam["fy"] = ocam["fy"] * sh 435 | # TODO: What is the correct way? 
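# The commented-out variant below treats (px, py) as pixel-center coordinates
# (add 0.5 before scaling, subtract 0.5 after); the active lines scale the
# principal point directly. The two differ by 0.5 * (scale - 1), i.e. under
# half a pixel, so the choice has little practical effect here.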
436 | # cam["px"] = (ocam["px"] + 0.5) * sw - 0.5 437 | # cam["py"] = (ocam["py"] + 0.5) * sh - 0.5 438 | cam["px"] = ocam["px"] * sw 439 | cam["py"] = ocam["py"] * sh 440 | cam["width"] = nw 441 | cam["height"] = nh 442 | 443 | def readDeepview(self, dataset): 444 | if not os.path.exists(os.path.join(dataset, "models.json")): 445 | return False 446 | 447 | self.cams, self.imgs = readCameraDeepview(dataset) 448 | self.dataset_type = "deepview" 449 | return True 450 | 451 | def readColmap(self, dataset): 452 | sparse_folder = dataset + "/dense/sparse/" 453 | image_folder = dataset + "/dense/images/" 454 | if (not os.path.exists(image_folder)) or (not os.path.exists(sparse_folder)): 455 | return False 456 | 457 | self.imgs = readImagesBinary(os.path.join(sparse_folder, "images.bin")) 458 | self.cams = readCamerasBinary(sparse_folder + "/cameras.bin") 459 | self.dataset_type = "colmap" 460 | return True 461 | 462 | 463 | def readCameraDeepview(dataset): 464 | cams = {} 465 | imgs = {} 466 | with open(os.path.join(dataset, "models.json"), "r") as fi: 467 | js = json.load(fi) 468 | for i, cam in enumerate(js): 469 | for j, cam_info in enumerate(cam): 470 | img_id = cam_info["relative_path"] 471 | cam_id = img_id.split("/")[0] 472 | 473 | rotation = ( 474 | Rotation.from_rotvec(np.float32(cam_info["orientation"])) 475 | .as_matrix() 476 | .astype(np.float32) 477 | ) 478 | position = np.array([cam_info["position"]], dtype="f").reshape(3, 1) 479 | 480 | if i == 0: 481 | cams[cam_id] = { 482 | "width": int(cam_info["width"]), 483 | "height": int(cam_info["height"]), 484 | "fx": cam_info["focal_length"], 485 | "fy": cam_info["focal_length"] * cam_info["pixel_aspect_ratio"], 486 | "px": cam_info["principal_point"][0], 487 | "py": cam_info["principal_point"][1], 488 | } 489 | imgs[img_id] = { 490 | "camera_id": cam_id, 491 | "r": rotation, 492 | "t": -np.matmul(rotation, position), 493 | "R": rotation.transpose(), 494 | "center": position, 495 | "path": cam_info["relative_path"], 496 | } 497 | return cams, imgs 498 | 499 | 500 | def readImagesBinary(path): 501 | images = {} 502 | f = open(path, "rb") 503 | num_reg_images = struct.unpack("Q", f.read(8))[0] 504 | for i in range(num_reg_images): 505 | image_id = struct.unpack("I", f.read(4))[0] 506 | qv = np.fromfile(f, np.double, 4) 507 | 508 | tv = np.fromfile(f, np.double, 3) 509 | camera_id = struct.unpack("I", f.read(4))[0] 510 | 511 | name = "" 512 | name_char = -1 513 | while name_char != b"\x00": 514 | name_char = f.read(1) 515 | if name_char != b"\x00": 516 | name += name_char.decode("ascii") 517 | 518 | num_points2D = struct.unpack("Q", f.read(8))[0] 519 | 520 | for i in range(num_points2D): 521 | f.read(8 * 2) # for x and y 522 | f.read(8) # for point3d Iid 523 | 524 | r = Rotation.from_quat([qv[1], qv[2], qv[3], qv[0]]).as_dcm().astype(np.float32) 525 | t = tv.astype(np.float32).reshape(3, 1) 526 | 527 | R = np.transpose(r) 528 | center = -R @ t 529 | # storage is scalar first, from_quat takes scalar last. 
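# `r`/`t` are COLMAP's world-to-camera rotation and translation; R = r^T and
# center = -R @ t recover the camera-to-world rotation and camera center stored
# below. Note that Rotation.as_dcm() (used above) is deprecated and removed in
# SciPy >= 1.6; Rotation.as_matrix() is the drop-in replacement on newer SciPy.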
530 | images[image_id] = { 531 | "camera_id": camera_id, 532 | "r": r, 533 | "t": t, 534 | "R": R, 535 | "center": center, 536 | "path": "dense/images/" + name, 537 | } 538 | 539 | f.close() 540 | return images 541 | 542 | 543 | def readCamerasBinary(path): 544 | cams = {} 545 | f = open(path, "rb") 546 | num_cameras = struct.unpack("Q", f.read(8))[0] 547 | 548 | # becomes pinhole camera model , 4 parameters 549 | for i in range(num_cameras): 550 | camera_id = struct.unpack("I", f.read(4))[0] 551 | model_id = struct.unpack("i", f.read(4))[0] 552 | 553 | width = struct.unpack("Q", f.read(8))[0] 554 | height = struct.unpack("Q", f.read(8))[0] 555 | 556 | fx = struct.unpack("d", f.read(8))[0] 557 | fy = struct.unpack("d", f.read(8))[0] 558 | px = struct.unpack("d", f.read(8))[0] 559 | py = struct.unpack("d", f.read(8))[0] 560 | 561 | cams[camera_id] = { 562 | "width": width, 563 | "height": height, 564 | "fx": fx, 565 | "fy": fy, 566 | "px": px, 567 | "py": py, 568 | } 569 | # fx, fy, cx, cy 570 | f.close() 571 | return cams 572 | 573 | 574 | def nerf_pose_to_ours(cam): 575 | R = cam[:3, :3] 576 | center = cam[:3, 3].reshape([3, 1]) 577 | center[1:] *= -1 578 | R[1:, 0] *= -1 579 | R[0, 1:] *= -1 580 | 581 | r = np.transpose(R) 582 | t = -r @ center 583 | return R, center, r, t 584 | 585 | 586 | def buildCamera(W, H, fx, fy, cx, cy): 587 | return { 588 | "width": int(W), 589 | "height": int(H), 590 | "fx": float(fx), 591 | "fy": float(fy), 592 | "px": float(cx), 593 | "py": float(cy), 594 | } 595 | 596 | 597 | def buildNerfPoses(poses, images_path=None): 598 | output = {} 599 | for poses_id in range(poses.shape[0]): 600 | R, center, r, t = nerf_pose_to_ours(poses[poses_id].astype(np.float32)) 601 | output[poses_id] = {"camera_id": 0, "r": r, "t": t, "R": R, "center": center} 602 | if images_path is not None: 603 | output[poses_id]["path"] = images_path[poses_id] 604 | 605 | return output 606 | -------------------------------------------------------------------------------- /util/load_llff.py: -------------------------------------------------------------------------------- 1 | # Originally from LLFF 2 | # https://github.com/Fyusion/LLFF 3 | # With minor modifications from NeX 4 | # https://github.com/nex-mpi/nex-code 5 | 6 | import numpy as np 7 | import os 8 | import imageio 9 | 10 | def get_image_size(path : str): 11 | """ 12 | Get image size without loading it 13 | """ 14 | from PIL import Image 15 | im = Image.open(path) 16 | return im.size[1], im.size[0] # H, W 17 | 18 | def _minify(basedir, factors=[], resolutions=[]): 19 | needtoload = False 20 | for r in factors: 21 | imgdir = os.path.join(basedir, "images_{}".format(r)) 22 | if not os.path.exists(imgdir): 23 | needtoload = True 24 | for r in resolutions: 25 | imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0])) 26 | if not os.path.exists(imgdir): 27 | needtoload = True 28 | if not needtoload: 29 | return 30 | 31 | from shutil import copy 32 | from subprocess import check_output 33 | 34 | imgdir = os.path.join(basedir, "images") 35 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 36 | imgs = [ 37 | f 38 | for f in imgs 39 | if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]]) 40 | ] 41 | imgdir_orig = imgdir 42 | 43 | wd = os.getcwd() 44 | 45 | for r in factors + resolutions: 46 | if isinstance(r, int): 47 | name = "images_{}".format(r) 48 | resizearg = "{}%".format(100.0 / r) 49 | else: 50 | name = "images_{}x{}".format(r[1], r[0]) 51 | resizearg = "{}x{}".format(r[1], r[0]) 52 | 
imgdir = os.path.join(basedir, name) 53 | if os.path.exists(imgdir): 54 | continue 55 | 56 | print("Minifying", r, basedir) 57 | 58 | os.makedirs(imgdir) 59 | check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True) 60 | 61 | ext = imgs[0].split(".")[-1] 62 | args = " ".join( 63 | ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)] 64 | ) 65 | print(args) 66 | os.chdir(imgdir) 67 | check_output(args, shell=True) 68 | os.chdir(wd) 69 | 70 | if ext != "png": 71 | check_output("rm {}/*.{}".format(imgdir, ext), shell=True) 72 | print("Removed duplicates") 73 | print("Done") 74 | 75 | 76 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 77 | poses_arr = np.load(os.path.join(basedir, "poses_bounds.npy")) 78 | shape = 5 79 | 80 | # poss llff arr [3, 5, images] [R | T | intrinsic] 81 | # intrinsic same for all images 82 | if os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 83 | shape = 4 84 | # h, w, fx, fy, cx, cy 85 | intrinsic_arr = np.load(os.path.join(basedir, "hwf_cxcy.npy")) 86 | 87 | poses = poses_arr[:, :-2].reshape([-1, 3, shape]).transpose([1, 2, 0]) 88 | bds = poses_arr[:, -2:].transpose([1, 0]) 89 | 90 | if not os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 91 | intrinsic_arr = poses[:, 4, 0] 92 | poses = poses[:, :4, :] 93 | 94 | img0 = [ 95 | os.path.join(basedir, "images", f) 96 | for f in sorted(os.listdir(os.path.join(basedir, "images"))) 97 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 98 | ][0] 99 | sh = get_image_size(img0) 100 | 101 | sfx = "" 102 | if factor is not None: 103 | sfx = "_{}".format(factor) 104 | _minify(basedir, factors=[factor]) 105 | factor = factor 106 | elif height is not None: 107 | factor = sh[0] / float(height) 108 | width = int(sh[1] / factor) 109 | _minify(basedir, resolutions=[[height, width]]) 110 | sfx = "_{}x{}".format(width, height) 111 | elif width is not None: 112 | factor = sh[1] / float(width) 113 | height = int(sh[0] / factor) 114 | _minify(basedir, resolutions=[[height, width]]) 115 | sfx = "_{}x{}".format(width, height) 116 | else: 117 | factor = 1 118 | 119 | imgdir = os.path.join(basedir, "images" + sfx) 120 | if not os.path.exists(imgdir): 121 | print(imgdir, "does not exist, returning") 122 | return 123 | 124 | imgfiles = [ 125 | os.path.join(imgdir, f) 126 | for f in sorted(os.listdir(imgdir)) 127 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 128 | ] 129 | if poses.shape[-1] != len(imgfiles): 130 | print( 131 | "Mismatch between imgs {} and poses {} !!!!".format( 132 | len(imgfiles), poses.shape[-1] 133 | ) 134 | ) 135 | return 136 | 137 | if not load_imgs: 138 | return poses, bds, intrinsic_arr 139 | 140 | def imread(f): 141 | if f.endswith("png"): 142 | return imageio.imread(f, ignoregamma=True) 143 | else: 144 | return imageio.imread(f) 145 | 146 | imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles] 147 | imgs = np.stack(imgs, -1) 148 | 149 | print("Loaded image data", imgs.shape, poses[:, -1, 0]) 150 | return poses, bds, imgs, intrinsic_arr 151 | 152 | 153 | def normalize(x): 154 | return x / np.linalg.norm(x) 155 | 156 | 157 | def viewmatrix(z, up, pos): 158 | vec2 = normalize(z) 159 | vec1_avg = up 160 | vec0 = normalize(np.cross(vec1_avg, vec2)) 161 | vec1 = normalize(np.cross(vec2, vec0)) 162 | m = np.stack([vec0, vec1, vec2, pos], 1) 163 | return m 164 | 165 | 166 | def ptstocam(pts, c2w): 167 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 168 | return tt 169 | 170 | 171 | def 
poses_avg(poses): 172 | # poses [images, 3, 4] not [images, 3, 5] 173 | # hwf = poses[0, :3, -1:] 174 | 175 | center = poses[:, :3, 3].mean(0) 176 | vec2 = normalize(poses[:, :3, 2].sum(0)) 177 | up = poses[:, :3, 1].sum(0) 178 | c2w = np.concatenate([viewmatrix(vec2, up, center)], 1) 179 | 180 | return c2w 181 | 182 | 183 | def render_path_axis(c2w, up, ax, rad, focal, N): 184 | render_poses = [] 185 | center = c2w[:, 3] 186 | hwf = c2w[:, 4:5] 187 | v = c2w[:, ax] * rad 188 | for t in np.linspace(-1.0, 1.0, N + 1)[:-1]: 189 | c = center + t * v 190 | z = normalize(c - (center - focal * c2w[:, 2])) 191 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 192 | render_poses.append(viewmatrix(z, up, c)) 193 | return render_poses 194 | 195 | 196 | def render_path_spiral(c2w, up, rads, focal, zrate, rots, N): 197 | render_poses = [] 198 | rads = np.array(list(rads) + [1.0]) 199 | # hwf = c2w[:,4:5] 200 | 201 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 202 | c = np.dot( 203 | c2w[:3, :4], 204 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) 205 | * rads, 206 | ) 207 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 208 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 209 | render_poses.append(viewmatrix(z, up, c)) 210 | return render_poses 211 | 212 | 213 | def recenter_poses(poses): 214 | # poses [images, 3, 4] 215 | poses_ = poses + 0 216 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) 217 | c2w = poses_avg(poses) 218 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 219 | 220 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 221 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 222 | 223 | poses = np.linalg.inv(c2w) @ poses 224 | poses_[:, :3, :4] = poses[:, :3, :4] 225 | poses = poses_ 226 | return poses 227 | 228 | 229 | def spherify_poses(poses, bds): 230 | p34_to_44 = lambda p: np.concatenate( 231 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 232 | ) 233 | 234 | rays_d = poses[:, :3, 2:3] 235 | rays_o = poses[:, :3, 3:4] 236 | 237 | def min_line_dist(rays_o, rays_d): 238 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 239 | b_i = -A_i @ rays_o 240 | pt_mindist = np.squeeze( 241 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0) 242 | ) 243 | return pt_mindist 244 | 245 | pt_mindist = min_line_dist(rays_o, rays_d) 246 | 247 | center = pt_mindist 248 | up = (poses[:, :3, 3] - center).mean(0) 249 | 250 | vec0 = normalize(up) 251 | vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0)) 252 | vec2 = normalize(np.cross(vec0, vec1)) 253 | pos = center 254 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 255 | 256 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 257 | 258 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 259 | 260 | sc = 1.0 / rad 261 | poses_reset[:, :3, 3] *= sc 262 | bds *= sc 263 | rad *= sc 264 | 265 | centroid = np.mean(poses_reset[:, :3, 3], 0) 266 | zh = centroid[2] 267 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 268 | new_poses = [] 269 | 270 | for th in np.linspace(0.0, 2.0 * np.pi, 120): 271 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 272 | up = np.array([0, 0, -1.0]) 273 | 274 | vec2 = normalize(camorigin) 275 | vec0 = normalize(np.cross(vec2, up)) 276 | vec1 = normalize(np.cross(vec2, vec0)) 277 | pos = camorigin 278 | p = np.stack([vec0, vec1, vec2, pos], 1) 279 | 280 | 
new_poses.append(p) 281 | 282 | new_poses = np.stack(new_poses, 0) 283 | 284 | new_poses = np.concatenate( 285 | [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1 286 | ) 287 | poses_reset = np.concatenate( 288 | [ 289 | poses_reset[:, :3, :4], 290 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape), 291 | ], 292 | -1, 293 | ) 294 | 295 | return poses_reset, new_poses, bds 296 | 297 | 298 | def load_llff_data( 299 | basedir, 300 | factor=None, 301 | recenter=True, 302 | bd_factor=0.75, 303 | spherify=False, 304 | # path_zflat=False, 305 | split_train_val=8, 306 | render_style="", 307 | verbose=False 308 | ): 309 | 310 | # poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 311 | poses, bds, intrinsic = _load_data( 312 | basedir, factor=factor, load_imgs=False 313 | ) # factor=8 downsamples original imgs by 8x 314 | if verbose: 315 | print("Loaded LLFF data", basedir, bds.min(), bds.max()) 316 | 317 | # Correct rotation matrix ordering and move variable dim to axis 0 318 | # poses [R | T] [3, 4, images] 319 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 320 | # poses [3, 4, images] --> [images, 3, 4] 321 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 322 | 323 | # imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 324 | # images = imgs 325 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 326 | 327 | # Rescale if bd_factor is provided 328 | sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor) 329 | poses[:, :3, 3] *= sc 330 | bds *= sc 331 | 332 | if recenter: 333 | poses = recenter_poses(poses) 334 | 335 | if spherify: 336 | poses, render_poses, bds = spherify_poses(poses, bds) 337 | else: 338 | c2w = poses_avg(poses) 339 | if verbose: 340 | print("recentered", c2w.shape) 341 | 342 | ## Get spiral 343 | # Get average pose 344 | up = normalize(poses[:, :3, 1].sum(0)) 345 | 346 | close_depth, inf_depth = -1, -1 347 | # Find a reasonable "focus depth" for this dataset 348 | # if os.path.exists(os.path.join(basedir, "planes_spiral.txt")): 349 | # with open(os.path.join(basedir, "planes_spiral.txt"), "r") as fi: 350 | # data = [float(x) for x in fi.readline().split(" ")] 351 | # dmin, dmax = data[:2] 352 | # close_depth = dmin * 0.9 353 | # inf_depth = dmax * 5.0 354 | # elif os.path.exists(os.path.join(basedir, "planes.txt")): 355 | # with open(os.path.join(basedir, "planes.txt"), "r") as fi: 356 | # data = [float(x) for x in fi.readline().split(" ")] 357 | # if len(data) == 3: 358 | # dmin, dmax, invz = data 359 | # elif len(data) == 4: 360 | # dmin, dmax, invz, _ = data 361 | # close_depth = dmin * 0.9 362 | # inf_depth = dmax * 5.0 363 | 364 | prev_close, prev_inf = close_depth, inf_depth 365 | if close_depth < 0 or inf_depth < 0 or render_style == "llff": 366 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 367 | 368 | if render_style == "shiny": 369 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 370 | if close_depth < prev_close: 371 | close_depth = prev_close 372 | if inf_depth > prev_inf: 373 | inf_depth = prev_inf 374 | 375 | dt = 0.75 376 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 377 | focal = mean_dz 378 | 379 | # Get radii for spiral path 380 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 381 | rads = np.percentile(np.abs(tt), 90, 0) 382 | c2w_path = c2w 383 | N_views = 120 384 | N_rots = 2 385 | # if path_zflat: 386 | # # zloc = np.percentile(tt, 10, 0)[2] 387 | # zloc = 
-close_depth * 0.1 388 | # c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2] 389 | # rads[2] = 0.0 390 | # N_rots = 1 391 | # N_views /= 2 392 | 393 | render_poses = render_path_spiral( 394 | c2w_path, up, rads, focal, zrate=0.5, rots=N_rots, N=N_views 395 | ) 396 | 397 | render_poses = np.array(render_poses).astype(np.float32) 398 | # reference_view_id should stay in train set only 399 | validation_ids = np.arange(poses.shape[0]) 400 | validation_ids[::split_train_val] = -1 401 | validation_ids = validation_ids < 0 402 | train_ids = np.logical_not(validation_ids) 403 | train_poses = poses[train_ids] 404 | train_bds = bds[train_ids] 405 | c2w = poses_avg(train_poses) 406 | 407 | dists = np.sum(np.square(c2w[:3, 3] - train_poses[:, :3, 3]), -1) 408 | reference_view_id = np.argmin(dists) 409 | reference_depth = train_bds[reference_view_id] 410 | if verbose: 411 | print(reference_depth) 412 | 413 | return ( 414 | reference_depth, 415 | reference_view_id, 416 | render_poses, 417 | poses, 418 | intrinsic 419 | ) 420 | -------------------------------------------------------------------------------- /util/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | # Standard NeRF Blender dataset loader 2 | from .util import Rays, Intrin, select_or_shuffle_rays 3 | from .dataset_base import DatasetBase 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import NamedTuple, Optional, Union 7 | from os import path 8 | import imageio 9 | from tqdm import tqdm 10 | import cv2 11 | import json 12 | import numpy as np 13 | 14 | 15 | class NeRFDataset(DatasetBase): 16 | """ 17 | NeRF dataset loader 18 | """ 19 | 20 | focal: float 21 | c2w: torch.Tensor # (n_images, 4, 4) 22 | gt: torch.Tensor # (n_images, h, w, 3) 23 | h: int 24 | w: int 25 | n_images: int 26 | rays: Optional[Rays] 27 | split: str 28 | 29 | def __init__( 30 | self, 31 | root, 32 | split, 33 | epoch_size : Optional[int] = None, 34 | device: Union[str, torch.device] = "cpu", 35 | scene_scale: Optional[float] = None, 36 | factor: int = 1, 37 | scale : Optional[float] = None, 38 | permutation: bool = True, 39 | white_bkgd: bool = True, 40 | n_images = None, 41 | **kwargs 42 | ): 43 | super().__init__() 44 | assert path.isdir(root), f"'{root}' is not a directory" 45 | 46 | if scene_scale is None: 47 | scene_scale = 2/3 48 | if scale is None: 49 | scale = 1.0 50 | self.device = device 51 | self.permutation = permutation 52 | self.epoch_size = epoch_size 53 | all_c2w = [] 54 | all_gt = [] 55 | 56 | split_name = split if split != "test_train" else "train" 57 | data_path = path.join(root, split_name) 58 | data_json = path.join(root, "transforms_" + split_name + ".json") 59 | 60 | print("LOAD DATA", data_path) 61 | 62 | j = json.load(open(data_json, "r")) 63 | 64 | # OpenGL -> OpenCV 65 | cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32)) 66 | 67 | for frame in tqdm(j["frames"]): 68 | fpath = path.join(data_path, path.basename(frame["file_path"]) + ".png") 69 | c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32) 70 | c2w = c2w @ cam_trans # To OpenCV 71 | 72 | im_gt = imageio.imread(fpath) 73 | if scale < 1.0: 74 | full_size = list(im_gt.shape[:2]) 75 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 76 | im_gt = cv2.resize(im_gt, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 77 | 78 | all_c2w.append(c2w) 79 | all_gt.append(torch.from_numpy(im_gt)) 80 | focal = float( 81 | 0.5 * all_gt[0].shape[1] / np.tan(0.5 * j["camera_angle_x"]) 82 | 
) 83 | self.c2w = torch.stack(all_c2w) 84 | self.c2w[:, :3, 3] *= scene_scale 85 | 86 | self.gt = torch.stack(all_gt).float() / 255.0 87 | if self.gt.size(-1) == 4: 88 | if white_bkgd: 89 | # Apply alpha channel 90 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 91 | else: 92 | self.gt = self.gt[..., :3] 93 | 94 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 95 | # Choose a subset of training images 96 | if n_images is not None: 97 | if n_images > self.n_images: 98 | print(f'using {self.n_images} available training views instead of the requested {n_images}.') 99 | n_images = self.n_images 100 | self.n_images = n_images 101 | self.gt = self.gt[0:n_images,...] 102 | self.c2w = self.c2w[0:n_images,...] 103 | 104 | self.intrins_full : Intrin = Intrin(focal, focal, 105 | self.w_full * 0.5, 106 | self.h_full * 0.5) 107 | 108 | self.split = split 109 | self.scene_scale = scene_scale 110 | if self.split == "train": 111 | self.gen_rays(factor=factor) 112 | else: 113 | # Rays are not needed for testing 114 | self.h, self.w = self.h_full, self.w_full 115 | self.intrins : Intrin = self.intrins_full 116 | 117 | self.should_use_background = False # Give warning 118 | 119 | -------------------------------------------------------------------------------- /util/nsvf_dataset.py: -------------------------------------------------------------------------------- 1 | # Extended NSVF-format dataset loader 2 | # This is a more sane format vs the NeRF formats 3 | 4 | from .util import Rays, Intrin, similarity_from_cameras 5 | from .dataset_base import DatasetBase 6 | import torch 7 | import torch.nn.functional as F 8 | from typing import NamedTuple, Optional, Union 9 | from os import path 10 | import os 11 | import cv2 12 | import imageio 13 | from tqdm import tqdm 14 | import json 15 | import numpy as np 16 | from warnings import warn 17 | 18 | 19 | class NSVFDataset(DatasetBase): 20 | """ 21 | Extended NSVF dataset loader 22 | """ 23 | 24 | focal: float 25 | c2w: torch.Tensor # (n_images, 4, 4) 26 | gt: torch.Tensor # (n_images, h, w, 3) 27 | h: int 28 | w: int 29 | n_images: int 30 | rays: Optional[Rays] 31 | split: str 32 | 33 | def __init__( 34 | self, 35 | root, 36 | split, 37 | epoch_size : Optional[int] = None, 38 | device: Union[str, torch.device] = "cpu", 39 | scene_scale: Optional[float] = None, # Scene scaling 40 | factor: int = 1, # Image scaling (on ray gen; use gen_rays(factor) to dynamically change scale) 41 | scale : Optional[float] = 1.0, # Image scaling (on load) 42 | permutation: bool = True, 43 | white_bkgd: bool = True, 44 | normalize_by_bbox: bool = False, 45 | data_bbox_scale : float = 1.1, # Only used if normalize_by_bbox 46 | cam_scale_factor : float = 0.95, 47 | normalize_by_camera: bool = True, 48 | **kwargs 49 | ): 50 | super().__init__() 51 | assert path.isdir(root), f"'{root}' is not a directory" 52 | 53 | if scene_scale is None: 54 | scene_scale = 1.0 55 | if scale is None: 56 | scale = 1.0 57 | 58 | self.device = device 59 | self.permutation = permutation 60 | self.epoch_size = epoch_size 61 | all_c2w = [] 62 | all_gt = [] 63 | 64 | split_name = split if split != "test_train" else "train" 65 | 66 | print("LOAD NSVF DATA", root, 'split', split) 67 | 68 | self.split = split 69 | 70 | def sort_key(x): 71 | if len(x) > 2 and x[1] == "_": 72 | return x[2:] 73 | return x 74 | def look_for_dir(cands, required=True): 75 | for cand in cands: 76 | if path.isdir(path.join(root, cand)): 77 | return cand 78 | if required: 79 | assert False, "None of " + 
str(cands) + " found in data directory" 80 | return "" 81 | 82 | img_dir_name = look_for_dir(["images", "image", "rgb"]) 83 | pose_dir_name = look_for_dir(["poses", "pose"]) 84 | # intrin_dir_name = look_for_dir(["intrin"], required=False) 85 | img_files = sorted(os.listdir(path.join(root, img_dir_name)), key=sort_key) 86 | 87 | # Select subset of files 88 | if self.split == "train" or self.split == "test_train": 89 | img_files = [x for x in img_files if x.startswith("0_")] 90 | elif self.split == "val": 91 | img_files = [x for x in img_files if x.startswith("1_")] 92 | elif self.split == "test": 93 | test_img_files = [x for x in img_files if x.startswith("2_")] 94 | if len(test_img_files) == 0: 95 | test_img_files = [x for x in img_files if x.startswith("1_")] 96 | img_files = test_img_files 97 | 98 | assert len(img_files) > 0, "No matching images in directory: " + path.join(data_dir, img_dir_name) 99 | self.img_files = img_files 100 | 101 | dynamic_resize = scale < 1 102 | self.use_integral_scaling = False 103 | scaled_img_dir = '' 104 | if dynamic_resize and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9: 105 | resized_dir = img_dir_name + "_" + str(round(1.0 / scale)) 106 | if path.exists(path.join(root, resized_dir)): 107 | img_dir_name = resized_dir 108 | dynamic_resize = False 109 | print("> Pre-resized images from", img_dir_name) 110 | if dynamic_resize: 111 | print("> WARNING: Dynamically resizing images") 112 | 113 | full_size = [0, 0] 114 | rsz_h = rsz_w = 0 115 | 116 | for img_fname in tqdm(img_files): 117 | img_path = path.join(root, img_dir_name, img_fname) 118 | image = imageio.imread(img_path) 119 | pose_fname = path.splitext(img_fname)[0] + ".txt" 120 | pose_path = path.join(root, pose_dir_name, pose_fname) 121 | # intrin_path = path.join(root, intrin_dir_name, pose_fname) 122 | 123 | cam_mtx = np.loadtxt(pose_path).reshape(-1, 4) 124 | if len(cam_mtx) == 3: 125 | bottom = np.array([[0.0, 0.0, 0.0, 1.0]]) 126 | cam_mtx = np.concatenate([cam_mtx, bottom], axis=0) 127 | all_c2w.append(torch.from_numpy(cam_mtx)) # C2W (4, 4) OpenCV 128 | full_size = list(image.shape[:2]) 129 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 130 | if dynamic_resize: 131 | image = cv2.resize(image, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 132 | 133 | all_gt.append(torch.from_numpy(image)) 134 | 135 | 136 | self.c2w_f64 = torch.stack(all_c2w) 137 | 138 | print('NORMALIZE BY?', 'bbox' if normalize_by_bbox else 'camera' if normalize_by_camera else 'manual') 139 | if normalize_by_bbox: 140 | # Not used, but could be helpful 141 | bbox_path = path.join(root, "bbox.txt") 142 | if path.exists(bbox_path): 143 | bbox_data = np.loadtxt(bbox_path) 144 | center = (bbox_data[:3] + bbox_data[3:6]) * 0.5 145 | radius = (bbox_data[3:6] - bbox_data[:3]) * 0.5 * data_bbox_scale 146 | 147 | # Recenter 148 | self.c2w_f64[:, :3, 3] -= center 149 | # Rescale 150 | scene_scale = 1.0 / radius.max() 151 | else: 152 | warn('normalize_by_bbox=True but bbox.txt was not available') 153 | elif normalize_by_camera: 154 | norm_pose_files = sorted(os.listdir(path.join(root, pose_dir_name)), key=sort_key) 155 | norm_poses = np.stack([np.loadtxt(path.join(root, pose_dir_name, x)).reshape(-1, 4) 156 | for x in norm_pose_files], axis=0) 157 | 158 | # Select subset of files 159 | T, sscale = similarity_from_cameras(norm_poses) 160 | 161 | self.c2w_f64 = torch.from_numpy(T) @ self.c2w_f64 162 | scene_scale = cam_scale_factor * sscale 163 | 164 | # center = np.mean(norm_poses[:, :3, 3], axis=0) 165 | # radius = 
np.median(np.linalg.norm(norm_poses[:, :3, 3] - center, axis=-1)) 166 | # self.c2w_f64[:, :3, 3] -= center 167 | # scene_scale = cam_scale_factor / radius 168 | # print('good', self.c2w_f64[:2], scene_scale) 169 | 170 | print('scene_scale', scene_scale) 171 | self.c2w_f64[:, :3, 3] *= scene_scale 172 | self.c2w = self.c2w_f64.float() 173 | 174 | self.gt = torch.stack(all_gt).double() / 255.0 175 | if self.gt.size(-1) == 4: 176 | if white_bkgd: 177 | # Apply alpha channel 178 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 179 | else: 180 | self.gt = self.gt[..., :3] 181 | self.gt = self.gt.float() 182 | 183 | assert full_size[0] > 0 and full_size[1] > 0, "Empty images" 184 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 185 | 186 | intrin_path = path.join(root, "intrinsics.txt") 187 | assert path.exists(intrin_path), "intrinsics unavailable" 188 | try: 189 | K: np.ndarray = np.loadtxt(intrin_path) 190 | fx = K[0, 0] 191 | fy = K[1, 1] 192 | cx = K[0, 2] 193 | cy = K[1, 2] 194 | except: 195 | # Weird format sometimes in NSVF data 196 | with open(intrin_path, "r") as f: 197 | spl = f.readline().split() 198 | fx = fy = float(spl[0]) 199 | cx = float(spl[1]) 200 | cy = float(spl[2]) 201 | if scale < 1.0: 202 | scale_w = rsz_w / full_size[1] 203 | scale_h = rsz_h / full_size[0] 204 | fx *= scale_w 205 | cx *= scale_w 206 | fy *= scale_h 207 | cy *= scale_h 208 | 209 | self.intrins_full : Intrin = Intrin(fx, fy, cx, cy) 210 | print(' intrinsics (loaded reso)', self.intrins_full) 211 | 212 | self.scene_scale = scene_scale 213 | if self.split == "train": 214 | self.gen_rays(factor=factor) 215 | else: 216 | # Rays are not needed for testing 217 | self.h, self.w = self.h_full, self.w_full 218 | self.intrins : Intrin = self.intrins_full 219 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | import torch.nn.functional as F 4 | from typing import Optional, Union, List 5 | from dataclasses import dataclass 6 | import numpy as np 7 | import cv2 8 | from scipy.spatial.transform import Rotation 9 | from scipy.interpolate import CubicSpline 10 | from matplotlib import pyplot as plt 11 | from warnings import warn 12 | 13 | 14 | @dataclass 15 | class Rays: 16 | origins: Union[torch.Tensor, List[torch.Tensor]] 17 | dirs: Union[torch.Tensor, List[torch.Tensor]] 18 | gt: Union[torch.Tensor, List[torch.Tensor]] 19 | 20 | def to(self, *args, **kwargs): 21 | origins = self.origins.to(*args, **kwargs) 22 | dirs = self.dirs.to(*args, **kwargs) 23 | gt = self.gt.to(*args, **kwargs) 24 | return Rays(origins, dirs, gt) 25 | 26 | def __getitem__(self, key): 27 | origins = self.origins[key] 28 | dirs = self.dirs[key] 29 | gt = self.gt[key] 30 | return Rays(origins, dirs, gt) 31 | 32 | def __len__(self): 33 | return self.origins.size(0) 34 | 35 | @dataclass 36 | class Intrin: 37 | fx: Union[float, torch.Tensor] 38 | fy: Union[float, torch.Tensor] 39 | cx: Union[float, torch.Tensor] 40 | cy: Union[float, torch.Tensor] 41 | 42 | def scale(self, scaling: float): 43 | return Intrin( 44 | self.fx * scaling, 45 | self.fy * scaling, 46 | self.cx * scaling, 47 | self.cy * scaling 48 | ) 49 | 50 | def get(self, field:str, image_id:int=0): 51 | val = self.__dict__[field] 52 | return val if isinstance(val, float) else val[image_id].item() 53 | 54 | 55 | class Timing: 56 | """ 57 | Timing environment 58 | usage: 59 | with 
Timing("message"): 60 | your commands here 61 | will print CUDA runtime in ms 62 | """ 63 | 64 | def __init__(self, name): 65 | self.name = name 66 | 67 | def __enter__(self): 68 | self.start = torch.cuda.Event(enable_timing=True) 69 | self.end = torch.cuda.Event(enable_timing=True) 70 | self.start.record() 71 | 72 | def __exit__(self, type, value, traceback): 73 | self.end.record() 74 | torch.cuda.synchronize() 75 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms") 76 | 77 | 78 | def get_expon_lr_func( 79 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 80 | ): 81 | """ 82 | Continuous learning rate decay function. Adapted from JaxNeRF 83 | 84 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 85 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 86 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 87 | function of lr_delay_mult, such that the initial learning rate is 88 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 89 | to the normal learning rate when steps>lr_delay_steps. 90 | 91 | :param conf: config subtree 'lr' or similar 92 | :param max_steps: int, the number of steps during optimization. 93 | :return HoF which takes step as input 94 | """ 95 | 96 | def helper(step): 97 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 98 | # Disable this parameter 99 | return 0.0 100 | if lr_delay_steps > 0: 101 | # A kind of reverse cosine decay. 102 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 103 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 104 | ) 105 | else: 106 | delay_rate = 1.0 107 | t = np.clip(step / max_steps, 0, 1) 108 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 109 | return delay_rate * log_lerp 110 | 111 | return helper 112 | 113 | 114 | def viridis_cmap(gray: np.ndarray): 115 | """ 116 | Visualize a single-channel image using matplotlib's viridis color map 117 | yellow is high value, blue is low 118 | :param gray: np.ndarray, (H, W) or (H, W, 1) unscaled 119 | :return: (H, W, 3) float32 in [0, 1] 120 | """ 121 | colored = plt.cm.viridis(plt.Normalize()(gray.squeeze()))[..., :-1] 122 | return colored.astype(np.float32) 123 | 124 | 125 | def save_img(img: np.ndarray, path: str): 126 | """Save an image to disk. Image should have values in [0,1].""" 127 | img = np.array((np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)) 128 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 129 | cv2.imwrite(path, img) 130 | 131 | 132 | def equirect2xyz(uv, rows, cols): 133 | """ 134 | Convert equirectangular coordinate to unit vector, 135 | inverse of xyz2equirect 136 | Taken from Vickie Ye 137 | Args: 138 | uv: np.ndarray [..., 2] x, y coordinates in image space in [-1.0, 1.0] 139 | Returns: 140 | xyz: np.ndarray [..., 3] unit vectors 141 | """ 142 | lon = (uv[..., 0] * (1.0 / cols) - 0.5) * (2 * np.pi) 143 | lat = -(uv[..., 1] * (1.0 / rows) - 0.5) * np.pi 144 | coslat = np.cos(lat) 145 | return np.stack( 146 | [ 147 | coslat * np.sin(lon), 148 | np.sin(lat), 149 | coslat * np.cos(lon), 150 | ], 151 | axis=-1, 152 | ) 153 | 154 | def xyz2equirect(bearings, rows, cols): 155 | """ 156 | Convert ray direction vectors into equirectangular pixel coordinates. 157 | Inverse of equirect2xyz. 
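Args:
    bearings: np.ndarray [..., 3] unit direction vectors
    rows: int, equirectangular image height in pixels
    cols: int, equirectangular image width in pixels
Returns:
    np.ndarray [..., 2] (x, y) pixel coordinates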
158 | Taken from Vickie Ye 159 | """ 160 | lat = np.arcsin(bearings[..., 1]) 161 | lon = np.arctan2(bearings[..., 0], bearings[..., 2]) 162 | x = cols * (0.5 + lon / 2 / np.pi) 163 | y = rows * (0.5 - lat / np.pi) 164 | return np.stack([x, y], axis=-1) 165 | 166 | def generate_dirs_equirect(w, h): 167 | x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking 168 | np.arange(w, dtype=np.float32) + 0.5, # X-Axis (columns) 169 | np.arange(h, dtype=np.float32) + 0.5, # Y-Axis (rows) 170 | indexing="xy", 171 | ) 172 | uv = np.stack([x * (2.0 / w) - 1.0, y * (2.0 / h) - 1.0], axis=-1) 173 | camera_dirs = equirect2xyz(uv) 174 | return camera_dirs 175 | 176 | 177 | # Data 178 | def select_or_shuffle_rays(rays_init : Rays, 179 | permutation: int = False, 180 | epoch_size: Optional[int] = None, 181 | device: Union[str, torch.device] = "cpu"): 182 | n_rays = rays_init.origins.size(0) 183 | n_samp = n_rays if (epoch_size is None) else epoch_size 184 | if permutation: 185 | print(" Shuffling rays") 186 | indexer = torch.randperm(n_rays, device='cpu')[:n_samp] 187 | else: 188 | print(" Selecting random rays") 189 | indexer = torch.randint(n_rays, (n_samp,), device='cpu') 190 | return rays_init[indexer].to(device=device) 191 | 192 | 193 | def compute_ssim( 194 | img0, 195 | img1, 196 | max_val=1.0, 197 | filter_size=11, 198 | filter_sigma=1.5, 199 | k1=0.01, 200 | k2=0.03, 201 | return_map=False, 202 | ): 203 | """Computes SSIM from two images. 204 | 205 | This function was modeled after tf.image.ssim, and should produce comparable 206 | output. 207 | 208 | Args: 209 | img0: torch.tensor. An image of size [..., width, height, num_channels]. 210 | img1: torch.tensor. An image of size [..., width, height, num_channels]. 211 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 212 | filter_size: int >= 1. Window size. 213 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 214 | k1: float > 0. One of the SSIM dampening parameters. 215 | k2: float > 0. One of the SSIM dampening parameters. 216 | return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned 217 | 218 | Returns: 219 | Each image's mean SSIM, or a tensor of individual values if `return_map`. 220 | """ 221 | device = img0.device 222 | ori_shape = img0.size() 223 | width, height, num_channels = ori_shape[-3:] 224 | img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 225 | img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 226 | batch_size = img0.shape[0] 227 | 228 | # Construct a 1D Gaussian blur filter. 229 | hw = filter_size // 2 230 | shift = (2 * hw - filter_size + 1) / 2 231 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 232 | filt = torch.exp(-0.5 * f_i) 233 | filt /= torch.sum(filt) 234 | 235 | # Blur in x and y (faster than the 2D convolution). 236 | # z is a tensor of size [B, H, W, C] 237 | filt_fn1 = lambda z: F.conv2d( 238 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 239 | padding=[hw, 0], groups=num_channels) 240 | filt_fn2 = lambda z: F.conv2d( 241 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 242 | padding=[0, hw], groups=num_channels) 243 | 244 | # Vmap the blurs to the tensor size, and then compose them. 
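# (The Gaussian kernel is separable, so composing the horizontal and vertical
# 1D convolutions below is equivalent to a full 2D Gaussian blur at lower cost.)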
245 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 246 | mu0 = filt_fn(img0) 247 | mu1 = filt_fn(img1) 248 | mu00 = mu0 * mu0 249 | mu11 = mu1 * mu1 250 | mu01 = mu0 * mu1 251 | sigma00 = filt_fn(img0 ** 2) - mu00 252 | sigma11 = filt_fn(img1 ** 2) - mu11 253 | sigma01 = filt_fn(img0 * img1) - mu01 254 | 255 | # Clip the variances and covariances to valid values. 256 | # Variance must be non-negative: 257 | sigma00 = torch.clamp(sigma00, min=0.0) 258 | sigma11 = torch.clamp(sigma11, min=0.0) 259 | sigma01 = torch.sign(sigma01) * torch.min( 260 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 261 | ) 262 | 263 | c1 = (k1 * max_val) ** 2 264 | c2 = (k2 * max_val) ** 2 265 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 266 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 267 | ssim_map = numer / denom 268 | ssim = torch.mean(ssim_map.reshape([-1, num_channels*width*height]), dim=-1) 269 | return ssim_map if return_map else ssim 270 | 271 | 272 | def generate_rays(w, h, focal, camtoworlds, equirect=False): 273 | """ 274 | Generate perspective camera rays. Principal point is at center. 275 | Args: 276 | w: int image width 277 | h: int image heigth 278 | focal: float real focal length 279 | camtoworlds: jnp.ndarray [B, 4, 4] c2w homogeneous poses 280 | equirect: if true, generates spherical rays instead of pinhole 281 | Returns: 282 | rays: Rays a namedtuple(origins [B, 3], directions [B, 3], viewdirs [B, 3]) 283 | """ 284 | x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking 285 | np.arange(w, dtype=np.float32), # X-Axis (columns) 286 | np.arange(h, dtype=np.float32), # Y-Axis (rows) 287 | indexing="xy", 288 | ) 289 | 290 | if equirect: 291 | uv = np.stack([x * (2.0 / w) - 1.0, y * (2.0 / h) - 1.0], axis=-1) 292 | camera_dirs = equirect2xyz(uv) 293 | else: 294 | camera_dirs = np.stack( 295 | [ 296 | (x - w * 0.5) / focal, 297 | -(y - h * 0.5) / focal, 298 | -np.ones_like(x), 299 | ], 300 | axis=-1, 301 | ) 302 | 303 | # camera_dirs = camera_dirs / np.linalg.norm(camera_dirs, axis=-1, keepdims=True) 304 | 305 | c2w = camtoworlds[:, None, None, :3, :3] 306 | camera_dirs = camera_dirs[None, Ellipsis, None] 307 | directions = np.matmul(c2w, camera_dirs)[Ellipsis, 0] 308 | origins = np.broadcast_to( 309 | camtoworlds[:, None, None, :3, -1], directions.shape 310 | ) 311 | norms = np.linalg.norm(directions, axis=-1, keepdims=True) 312 | viewdirs = directions / norms 313 | rays = Rays( 314 | origins=origins, directions=directions, viewdirs=viewdirs 315 | ) 316 | return rays 317 | 318 | 319 | def similarity_from_cameras(c2w): 320 | """ 321 | Get a similarity transform to normalize dataset 322 | from c2w (OpenCV convention) cameras 323 | 324 | :param c2w: (N, 4) 325 | 326 | :return T (4,4) , scale (float) 327 | """ 328 | t = c2w[:, :3, 3] 329 | R = c2w[:, :3, :3] 330 | 331 | # (1) Rotate the world so that z+ is the up axis 332 | # we estimate the up axis by averaging the camera up axes 333 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 334 | world_up = np.mean(ups, axis=0) 335 | world_up /= np.linalg.norm(world_up) 336 | 337 | up_camspace = np.array([0.0, -1.0, 0.0]) 338 | c = (up_camspace * world_up).sum() 339 | cross = np.cross(world_up, up_camspace) 340 | skew = np.array([[0.0, -cross[2], cross[1]], 341 | [cross[2], 0.0, -cross[0]], 342 | [-cross[1], cross[0], 0.0]]) 343 | if c > -1: 344 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1+c) 345 | else: 346 | # In the unlikely case the original data has y+ up axis, 347 | # rotate 180-deg about x axis 348 | R_align = 
np.array([[-1.0, 0.0, 0.0], 349 | [0.0, 1.0, 0.0], 350 | [0.0, 0.0, 1.0]]) 351 | 352 | 353 | # R_align = np.eye(3) # DEBUG 354 | R = (R_align @ R) 355 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 356 | t = (R_align @ t[..., None])[..., 0] 357 | 358 | # (2) Recenter the scene using camera center rays 359 | # find the closest point to the origin for each camera's center ray 360 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 361 | 362 | # median for more robustness 363 | translate = -np.median(nearest, axis=0) 364 | 365 | # translate = -np.mean(t, axis=0) # DEBUG 366 | 367 | transform = np.eye(4) 368 | transform[:3, 3] = translate 369 | transform[:3, :3] = R_align 370 | 371 | # (3) Rescale the scene using camera distances 372 | scale = 1.0 / np.median(np.linalg.norm(t + translate, axis=-1)) 373 | return transform, scale 374 | 375 | def jiggle_and_interp_poses(poses : torch.Tensor, 376 | n_inter: int, 377 | noise_std : float=0.0): 378 | """ 379 | For generating a novel trajectory close to known trajectory 380 | 381 | :param poses: torch.Tensor (B, 4, 4) 382 | :param n_inter: int, number of views to interpolate in total 383 | :param noise_std: float, default 0 384 | """ 385 | n_views_in = poses.size(0) 386 | poses_np = poses.cpu().numpy().copy() 387 | rot = Rotation.from_matrix(poses_np[:, :3, :3]) 388 | trans = poses_np[:, :3, 3] 389 | trans += np.random.randn(*trans.shape) * noise_std 390 | pose_quat = rot.as_quat() 391 | 392 | t_in = np.arange(n_views_in, dtype=np.float32) 393 | t_out = np.linspace(t_in[0], t_in[-1], n_inter, dtype=np.float32) 394 | 395 | q_new = CubicSpline(t_in, pose_quat) 396 | q_new : np.ndarray = q_new(t_out) 397 | q_new = q_new / np.linalg.norm(q_new, axis=-1)[..., None] 398 | 399 | t_new = CubicSpline(t_in, trans) 400 | t_new = t_new(t_out) 401 | 402 | rot_new = Rotation.from_quat(q_new) 403 | R_new = rot_new.as_matrix() 404 | 405 | Rt_new = np.concatenate([R_new, t_new[..., None]], axis=-1) 406 | bottom = np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32) 407 | bottom = bottom[None].repeat(Rt_new.shape[0], 0) 408 | Rt_new = np.concatenate([Rt_new, bottom], axis=-2) 409 | Rt_new = torch.from_numpy(Rt_new).to(device=poses.device, dtype=poses.dtype) 410 | return Rt_new 411 | 412 | 413 | # Rather ugly pose generation code, derived from NeRF 414 | def _trans_t(t): 415 | return np.array( 416 | [ 417 | [1, 0, 0, 0], 418 | [0, 1, 0, 0], 419 | [0, 0, 1, t], 420 | [0, 0, 0, 1], 421 | ], 422 | dtype=np.float32, 423 | ) 424 | 425 | 426 | def _rot_phi(phi): 427 | return np.array( 428 | [ 429 | [1, 0, 0, 0], 430 | [0, np.cos(phi), -np.sin(phi), 0], 431 | [0, np.sin(phi), np.cos(phi), 0], 432 | [0, 0, 0, 1], 433 | ], 434 | dtype=np.float32, 435 | ) 436 | 437 | 438 | def _rot_theta(th): 439 | return np.array( 440 | [ 441 | [np.cos(th), 0, -np.sin(th), 0], 442 | [0, 1, 0, 0], 443 | [np.sin(th), 0, np.cos(th), 0], 444 | [0, 0, 0, 1], 445 | ], 446 | dtype=np.float32, 447 | ) 448 | 449 | def pose_spherical(theta : float, phi : float, radius : float, offset : Optional[np.ndarray]=None, 450 | vec_up : Optional[np.ndarray]=None): 451 | """ 452 | Generate spherical rendering poses, from NeRF. 
Forgive the code horror 453 | :return: c2w (4, 4) camera-to-world pose matrix 454 | """ 455 | c2w = _trans_t(radius) 456 | c2w = _rot_phi(phi / 180.0 * np.pi) @ c2w 457 | c2w = _rot_theta(theta / 180.0 * np.pi) @ c2w 458 | c2w = ( 459 | np.array( 460 | [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], 461 | dtype=np.float32, 462 | ) 463 | @ c2w 464 | ) 465 | if vec_up is not None: 466 | vec_up = vec_up / np.linalg.norm(vec_up) 467 | vec_1 = np.array([vec_up[0], -vec_up[2], vec_up[1]]) 468 | vec_2 = np.cross(vec_up, vec_1) 469 | 470 | trans = np.eye(4, 4, dtype=np.float32) 471 | trans[:3, 0] = vec_1 472 | trans[:3, 1] = vec_2 473 | trans[:3, 2] = vec_up 474 | c2w = trans @ c2w 475 | c2w = c2w @ np.diag(np.array([1, -1, -1, 1], dtype=np.float32)) 476 | if offset is not None: 477 | c2w[:3, 3] += offset 478 | return c2w 479 | --------------------------------------------------------------------------------
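A minimal usage sketch (not part of the repository) for two helpers in `util/util.py`, assuming the repository root is on `PYTHONPATH`; all numeric values are illustrative only:

```python
import numpy as np
from util.util import get_expon_lr_func, pose_spherical

# Log-linearly interpolated learning-rate schedule: 1e-1 at step 0, 1e-3 at step 1000.
lr_fn = get_expon_lr_func(lr_init=1e-1, lr_final=1e-3, max_steps=1000)
print(lr_fn(0), lr_fn(500), lr_fn(1000))  # ~0.1, ~0.01, ~0.001

# A ring of 40 NeRF-style spherical render poses around the origin.
render_c2w = np.stack([
    pose_spherical(theta=az, phi=-30.0, radius=4.0)
    for az in np.linspace(-180.0, 180.0, 40, endpoint=False)
])
print(render_c2w.shape)  # (40, 4, 4)
```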