├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── configs
│   ├── bike.txt
│   ├── bike_full.txt
│   ├── bike_tot.txt
│   ├── chair.txt
│   ├── chess.txt
│   ├── chess_full.txt
│   ├── chess_simple.txt
│   ├── cornell.txt
│   ├── drums.txt
│   ├── fern.txt
│   ├── ficus.txt
│   ├── flower.txt
│   ├── fortress.txt
│   ├── horns.txt
│   ├── hotdog.txt
│   ├── leaves.txt
│   ├── lego.txt
│   ├── materials.txt
│   ├── mic.txt
│   ├── orchids.txt
│   ├── room.txt
│   ├── scannet.txt
│   ├── ship.txt
│   └── trex.txt
├── download_example_data.sh
├── image_mem.py
├── image_mem_helpers.py
├── imgs
│   └── pipeline.jpg
├── load_blender.py
├── load_deepvoxels.py
├── load_llff.py
├── requirements.txt
├── run_nerf.py
└── run_nerf_helpers.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gword/Recursive-NeRF/e61ceb302d0bc028fe4524771b363bf514342b5f/.gitmodules
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bmild
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recursive-NeRF: An Efficient and Dynamically Growing NeRF
2 | This is the official implementation of Recursive-NeRF: An Efficient and Dynamically Growing NeRF.
3 |
4 | Paper link: https://ieeexplore.ieee.org/document/9909994
5 |
6 | ## Abstract
7 | View synthesis methods using implicit continuous shape representations learned from a set of images, such as the Neural Radiance Field (NeRF) method, have gained increasing attention due to their high-quality imagery and scalability to high resolution.
8 | However, the heavy computation required by NeRF's volumetric approach prevents it from being useful in practice: rendering a single image of a few megapixels takes minutes.
9 | On the other hand, an image of a scene can be rendered in a level-of-detail manner, so we posit that a complicated region of the scene should be represented by a large neural network, while a small neural network is capable of encoding a simple region, enabling a balance between efficiency and quality.
10 | Recursive-NeRF is our embodiment of this idea, providing an efficient and adaptive rendering and training approach for NeRF.
11 | The core of Recursive-NeRF learns uncertainties for query coordinates, representing the quality of the predicted color and volumetric intensity at each level.
12 | Only query coordinates with high uncertainties are forwarded to the next level, where a bigger neural network with more powerful representational capability handles them.
13 | The final rendered image is a composition of results from the neural networks at all levels.
14 | Our evaluation on public datasets and a large-scale scene dataset we collected shows that Recursive-NeRF is more efficient than NeRF while providing state-of-the-art quality.
--------------------------------------------------------------------------------
/configs/bike.txt:
--------------------------------------------------------------------------------
1 | expname = bike
2 | datadir = ./data/our_seq_bike
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 1
--------------------------------------------------------------------------------
/configs/bike_full.txt:
--------------------------------------------------------------------------------
1 | expname = bike_full
2 | datadir = ./data/our_seq_bike_full
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 1
19 | faketestskip = 8
20 | i_tottest = 400000
21 | do_intrinsic = True
22 | render_test = True
23 | near = 0.3
24 | far = 6.0
25 |
--------------------------------------------------------------------------------
/configs/bike_tot.txt:
--------------------------------------------------------------------------------
1 | expname = bike_tot
2 | datadir = ./data/our_seq_bike_tot
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 8
--------------------------------------------------------------------------------
/configs/chair.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_chair
2 | datadir = ./data/nerf_synthetic/chair
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/chess.txt:
--------------------------------------------------------------------------------
1 | expname = chess
2 | datadir = ./data/chess
3 |
dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = False 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | -------------------------------------------------------------------------------- /configs/chess_full.txt: -------------------------------------------------------------------------------- 1 | expname = chess 2 | datadir = ./data/chess_full 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = False 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 512 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | raw_noise_std = 1e0 -------------------------------------------------------------------------------- /configs/chess_simple.txt: -------------------------------------------------------------------------------- 1 | expname = chess 2 | datadir = ./data/chess_simple 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = False 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 512 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | -------------------------------------------------------------------------------- /configs/cornell.txt: -------------------------------------------------------------------------------- 1 | expname = cornell 2 | datadir = ./data/our_cornell 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | blender_factor = 1 20 | testskip = 1 21 | faketestskip = 8 22 | i_tottest = 400000 23 | do_intrinsic = True 24 | render_test = True 25 | far = 20.0 -------------------------------------------------------------------------------- /configs/drums.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_drums 2 | datadir = ./data/nerf_synthetic/drums 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/fern.txt: -------------------------------------------------------------------------------- 1 | expname = fern_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fern 4 | dataset_type = llff 5 | 6 | factor = 4 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 128 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/ficus.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ficus 2 | datadir = ./data/nerf_synthetic/ficus 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac 
= 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/flower.txt: -------------------------------------------------------------------------------- 1 | expname = flower_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/flower 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/fortress.txt: -------------------------------------------------------------------------------- 1 | expname = fortress_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fortress 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/horns.txt: -------------------------------------------------------------------------------- 1 | expname = horns_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/horns 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/hotdog.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_hotdog 2 | datadir = ./data/nerf_synthetic/hotdog 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/leaves.txt: -------------------------------------------------------------------------------- 1 | expname = leaves_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/leaves 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/lego.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_lego 2 | datadir = ./data/nerf_synthetic/lego 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/materials.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_materials 2 | datadir = ./data/nerf_synthetic/materials 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | 
precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | testskip=1 20 | faketestskip=8 21 | i_tottest = 600000 22 | -------------------------------------------------------------------------------- /configs/mic.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_mic 2 | datadir = ./data/nerf_synthetic/mic 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/orchids.txt: -------------------------------------------------------------------------------- 1 | expname = orchids_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/orchids 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/room.txt: -------------------------------------------------------------------------------- 1 | expname = room_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/room 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/scannet.txt: -------------------------------------------------------------------------------- 1 | expname = scannet 2 | datadir = ./data/scannet_scene0706_00 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = False 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 512 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = False 19 | raw_noise_std = 1e0 -------------------------------------------------------------------------------- /configs/ship.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ship 2 | datadir = ./data/nerf_synthetic/ship 3 | dataset_type = blender 4 | 5 | no_batching = True 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | testskip=1 19 | faketestskip=8 20 | i_tottest = 600000 -------------------------------------------------------------------------------- /configs/trex.txt: -------------------------------------------------------------------------------- 1 | expname = trex_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/trex 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /download_example_data.sh: -------------------------------------------------------------------------------- 1 | wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz 2 | mkdir -p data 3 | cd data 4 | wget 
https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip 5 | unzip nerf_example_data.zip 6 | cd .. 7 | -------------------------------------------------------------------------------- /image_mem.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import jittor as jt 8 | from jittor import nn 9 | from tqdm import tqdm, trange 10 | import datetime 11 | import cv2 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from image_mem_helpers import * 16 | 17 | from load_llff import load_llff_data 18 | from load_deepvoxels import load_dv_data 19 | from load_blender import load_blender_data 20 | from tensorboardX import SummaryWriter 21 | 22 | jt.flags.use_cuda = 1 23 | # np.random.seed(0) 24 | DEBUG = False 25 | 26 | def config_parser(): 27 | gpu = "gpu"+os.environ["CUDA_VISIBLE_DEVICES"] 28 | import configargparse 29 | parser = configargparse.ArgumentParser() 30 | parser.add_argument('--config', is_config_file=True, 31 | help='config file path') 32 | parser.add_argument("--expname", type=str, 33 | help='experiment name') 34 | parser.add_argument("--basedir", type=str, default='./logs/'+gpu+"/", 35 | help='where to store ckpts and logs') 36 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 37 | help='input data directory') 38 | 39 | # training options 40 | parser.add_argument("--netdepth", type=int, default=8, 41 | help='layers in network') 42 | parser.add_argument("--step1", type=int, default=10000, 43 | help='?') 44 | parser.add_argument("--step2", type=int, default=20000, 45 | help='?') 46 | parser.add_argument("--step3", type=int, default=3000000, 47 | help='?') 48 | parser.add_argument("--netwidth", type=int, default=256, 49 | help='channels per layer') 50 | parser.add_argument("--netdepth_fine", type=int, default=8, 51 | help='layers in fine network') 52 | parser.add_argument("--netwidth_fine", type=int, default=256, 53 | help='channels per layer in fine network') 54 | parser.add_argument("--N_rand", type=int, default=32*32*4, 55 | help='batch size (number of random rays per gradient step)') 56 | parser.add_argument("--lrate", type=float, default=5e-4, 57 | help='learning rate') 58 | parser.add_argument("--lrate_decay", type=int, default=250, 59 | help='exponential learning rate decay (in 1000 steps)') 60 | parser.add_argument("--chunk", type=int, default=1024*8, 61 | help='number of rays processed in parallel, decrease if running out of memory') 62 | parser.add_argument("--netchunk", type=int, default=1024*64, 63 | help='number of pts sent through network in parallel, decrease if running out of memory') 64 | parser.add_argument("--no_batching", action='store_true', 65 | help='only take random rays from 1 image at a time') 66 | parser.add_argument("--no_reload", action='store_true', 67 | help='do not reload weights from saved ckpt') 68 | parser.add_argument("--ft_path", type=str, default=None, 69 | help='specific weights npy file to reload for coarse network') 70 | parser.add_argument("--threshold", type=float, default=0, 71 | help='threshold') 72 | 73 | # rendering options 74 | parser.add_argument("--N_samples", type=int, default=64, 75 | help='number of coarse samples per ray') 76 | parser.add_argument("--N_importance", type=int, default=0, 77 | help='number of additional fine samples per ray') 78 | parser.add_argument("--perturb", type=float, default=1., 79 | help='set to 0. for no jitter, 1. 
for jitter') 80 | parser.add_argument("--use_viewdirs", action='store_true', 81 | help='use full 5D input instead of 3D') 82 | parser.add_argument("--i_embed", type=int, default=0, 83 | help='set 0 for default positional encoding, -1 for none') 84 | parser.add_argument("--multires", type=int, default=10, 85 | help='log2 of max freq for positional encoding (3D location)') 86 | parser.add_argument("--multires_views", type=int, default=4, 87 | help='log2 of max freq for positional encoding (2D direction)') 88 | parser.add_argument("--raw_noise_std", type=float, default=0., 89 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 90 | 91 | parser.add_argument("--render_only", action='store_true', 92 | help='do not optimize, reload weights and render out render_poses path') 93 | parser.add_argument("--render_test", action='store_true', 94 | help='render the test set instead of render_poses path') 95 | parser.add_argument("--render_factor", type=int, default=0, 96 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 97 | 98 | # training options 99 | parser.add_argument("--precrop_iters", type=int, default=0, 100 | help='number of steps to train on central crops') 101 | parser.add_argument("--head_num", type=int, default=8, 102 | help='number of heads') 103 | parser.add_argument("--precrop_frac", type=float, 104 | default=.5, help='fraction of img taken for central crops') 105 | parser.add_argument("--large_scene", action='store_true', 106 | help='use large scene') 107 | 108 | # dataset options 109 | parser.add_argument("--dataset_type", type=str, default='llff', 110 | help='options: llff / blender / deepvoxels') 111 | parser.add_argument("--testskip", type=int, default=8, 112 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 113 | parser.add_argument("--faketestskip", type=int, default=1, 114 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 115 | 116 | ## deepvoxels flags 117 | parser.add_argument("--shape", type=str, default='greek', 118 | help='options : armchair / cube / greek / vase') 119 | 120 | ## blender flags 121 | parser.add_argument("--white_bkgd", action='store_true', 122 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 123 | parser.add_argument("--half_res", action='store_true', 124 | help='load blender synthetic data at 400x400 instead of 800x800') 125 | parser.add_argument("--near", type=float, default=2., 126 | help='downsample factor for LLFF images') 127 | parser.add_argument("--far", type=float, default=6., 128 | help='downsample factor for LLFF images') 129 | parser.add_argument("--do_intrinsic", action='store_true', 130 | help='use intrinsic matrix') 131 | parser.add_argument("--blender_factor", type=int, default=1, 132 | help='downsample factor for blender images') 133 | 134 | ## llff flags 135 | parser.add_argument("--factor", type=int, default=8, 136 | help='downsample factor for LLFF images') 137 | parser.add_argument("--no_ndc", action='store_true', 138 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 139 | parser.add_argument("--lindisp", action='store_true', 140 | help='sampling linearly in disparity rather than depth') 141 | parser.add_argument("--spherify", action='store_true', 142 | help='set for spherical 360 scenes') 143 | parser.add_argument("--llffhold", type=int, default=8, 144 | help='will take every 1/N images as LLFF test set, paper uses 8') 145 | 146 | 
# logging/saving options 147 | parser.add_argument("--i_print", type=int, default=100, 148 | help='frequency of console printout and metric loggin') 149 | parser.add_argument("--i_img", type=int, default=50000, 150 | help='frequency of tensorboard image logging') 151 | parser.add_argument("--i_weights", type=int, default=10000, 152 | help='frequency of weight ckpt saving') 153 | parser.add_argument("--i_testset", type=int, default=5000, 154 | help='frequency of testset saving') 155 | parser.add_argument("--i_tottest", type=int, default=400000, 156 | help='frequency of testset saving') 157 | parser.add_argument("--i_video", type=int, default=50000, 158 | help='frequency of render_poses video saving') 159 | 160 | parser.add_argument("--scaledown", type=int, default=1, 161 | help='frequency of render_poses video saving') 162 | 163 | return parser 164 | 165 | def create_nerf(args): 166 | """Instantiate NeRF's MLP model. 167 | """ 168 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed, 2) 169 | input_ch_views = 0 170 | output_ch = 3 171 | skips = [4] 172 | model = NeRF(D=args.netdepth, W=args.netwidth, 173 | input_ch=input_ch, output_ch=output_ch, skips=skips, 174 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs, 175 | head_num=args.head_num, threshold=args.threshold) 176 | optimizer = jt.optim.Adam(params=model.parameters(), lr=args.lrate, betas=(0.9, 0.999)) 177 | 178 | return model, optimizer, embed_fn 179 | 180 | def render(H, W, x, chunk, embed_fn, model): 181 | rgb = [] 182 | confs = [] 183 | outnums = [] 184 | for i in range(0,x.shape[0],chunk): 185 | xx = x[i:i+chunk] 186 | feat = embed_fn(x[i:i+chunk]) 187 | y, conf, sout_num = model(feat, xx, False) 188 | rgb.append(y[-1]) 189 | confs.append(conf[-1]) 190 | outnums.append(sout_num) 191 | rgb[-1].sync() 192 | confs[-1].sync() 193 | rgb = jt.concat(rgb, 0).reshape(H,W,3) 194 | confs = jt.concat(confs, 0) 195 | outnum = np.concatenate(outnums, axis=0) 196 | print("outnum",outnum.shape) 197 | outnum = outnum.sum(0) 198 | print("outnum",outnum.shape,outnum) 199 | return rgb, confs 200 | 201 | def dfs(t, points, model): 202 | k = len(model.son_list[t]) 203 | print("dfs",t,"k",k) 204 | print("points",points.shape) 205 | if t in model.force_out: 206 | if points.shape[0]>=k: 207 | centroid = points[jt.array(np.random.choice(points.shape[0], k, replace=False))] 208 | print("centroid",centroid.shape) 209 | # print("step",-1,centroid.numpy()) 210 | for step in range(100): 211 | # print("step",step) 212 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt() 213 | min_idx, _ = jt.argmin(dis,-1) 214 | # print("min_idx",min_idx.shape) 215 | for i in range(k): 216 | # print("i",i,(min_idx==i).sum) 217 | centroid[i] = points[min_idx==i].mean(0) 218 | jt.sync_all() 219 | # jt.display_memory_info()z 220 | # print("step",step,centroid.numpy()) 221 | else: 222 | centroid = jt.rand((k,2)) 223 | print("centroid fail",centroid.shape) 224 | print("centroid",centroid.shape,centroid) 225 | setattr(model, model.node_list[t].anchors, centroid.detach()) 226 | # warning mpi 227 | if jt.mpi and False: 228 | # v1 = getattr(model, model.node_list[t].anchors) 229 | # v1.assign(jt.mpi.broadcast(v2, root=0)) 230 | jt.mpi.broadcast(getattr(model, model.node_list[t].anchors), root=0) 231 | print("model", jt.mpi.local_rank(), t, getattr(model, model.node_list[t].anchors)) 232 | for i in model.son_list[t]: 233 | # model.outnet[i].alpha_linear.load_state_dict(model.outnet[t].alpha_linear.state_dict()) 234 | 
model.outnet[i].load_state_dict(model.outnet[t].state_dict()) 235 | return model.son_list[t] 236 | else: 237 | centroid = model.get_anchor(model.node_list[t].anchors) 238 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt() 239 | min_idx, _ = jt.argmin(dis,-1) 240 | res = [] 241 | for i in range(k): 242 | res += dfs(model.son_list[t][i], points[min_idx==i], model) 243 | return res 244 | 245 | def do_kmeans(pts, model): 246 | force_out = dfs(0, pts, model) 247 | model.force_out = force_out 248 | 249 | def print_conf_img(rgb, conf, threshold, testsavedir): 250 | os.makedirs(testsavedir, exist_ok=True) 251 | # threshold_list = [threshold, 2e-2, 3e-2, 4e-2, 5e-2, 1e-1, 5e-1, 1] 252 | threshold_list = [threshold] 253 | H, W, C = rgb.shape 254 | red = np.array([1,0,0]) 255 | for th in threshold_list: 256 | filename = os.path.join(testsavedir, 'confimg_'+str(th)+'.png') 257 | logimg = rgb.copy().reshape([-1,3]) 258 | print("logimg1",logimg.shape) 259 | # bo = (conf= rand_coords.shape[0]: 338 | # print("Shuffle data after an epoch!") 339 | rand_idx = jt.randperm(coords_f.shape[0]) 340 | rand_coords = coords_f[rand_idx] 341 | rand_image = image.reshape([-1,3])[rand_idx] 342 | i_batch = 0 343 | else: 344 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 345 | select_coords = coords[select_inds].long() # (N_rand, 2) 346 | x = coords_f[select_inds] # (N_rand, 2) 347 | target = image[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 348 | # print("select_coords",select_coords.shape,select_coords) 349 | # print("x",x.shape,x) 350 | # print("target",target.shape,target) 351 | 352 | # if i == 100: 353 | # import cProfile 354 | # pr = cProfile.Profile() 355 | # pr.enable() 356 | # jt.flags.trace_py_var=3 357 | # jt.flags.profiler_enable = 1 358 | # elif i == 110: 359 | # pr.disable() 360 | # pr.print_stats() 361 | # jt.flags.profiler_enable = 0 362 | # jt.profiler.report() 363 | # jt.flags.trace_py_var=0 364 | feat = embed_fn(x) 365 | y, conf, _ = model(feat, x, True) 366 | save=((y-target)**2).detach().unsqueeze(-2) 367 | # loss = img2mse(y[-1], target) 368 | loss = img2mse(y, target.unsqueeze(0)) 369 | psnr = mse2psnr(loss) 370 | 371 | sloss = jt.maximum(((y-target)**2).detach()-conf,0.) 372 | loss_conf_loss = sloss.mean() 373 | loss_conf_mean = jt.maximum(conf, 0.).mean() 374 | loss_conf = loss_conf_loss+loss_conf_mean*0.01 375 | loss = loss + loss_conf*0.1 376 | 377 | # optimizer.backward(loss) 378 | optimizer.step(loss) 379 | jt.sync_all() 380 | # jt.display_memory_info() 381 | 382 | # NOTE: IMPORTANT! 
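        # The block below applies exponential learning-rate decay:
        #     new_lrate = lrate * 0.1 ** (sstep / (lrate_decay * 1000))
        # where sstep is the global step minus the most recent split_tree milestone already passed,
        # so the decay clock effectively restarts each time the network grows. With lrate = 5e-4 and
        # lrate_decay = 500 (as in the configs), the rate drops by 10x every 500k post-split steps.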
383 | ### update learning rate ### 384 | decay_rate = 0.1 385 | decay_steps = args.lrate_decay * 1000 386 | sstep = global_step 387 | if sstep>split_tree3: 388 | sstep-=split_tree3 389 | elif sstep>split_tree2: 390 | sstep-=split_tree2 391 | elif sstep>split_tree1: 392 | sstep-=split_tree1 393 | new_lrate = args.lrate * (decay_rate ** (sstep / decay_steps)) 394 | for param_group in optimizer.param_groups: 395 | param_group['lr'] = new_lrate 396 | 397 | if i%args.i_print==0: 398 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 399 | writer.add_scalar("train/loss", loss.item(), global_step) 400 | writer.add_scalar("train/loss_conf", loss_conf.item(), global_step) 401 | writer.add_scalar("train/PSNR", psnr.item(), global_step) 402 | writer.add_scalar("lr/lr", new_lrate, global_step) 403 | if i%args.i_testset==0: 404 | with jt.no_grad(): 405 | rgb, _ = render(H, W, coords_f, N_rand, embed_fn, model) 406 | print("rgb",rgb.shape) 407 | psnr = mse2psnr(img2mse(rgb, image)) 408 | rgb = rgb.numpy() 409 | 410 | writer.add_scalar('test/psnr', psnr.item(), global_step) 411 | if i%args.i_img==0: 412 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 413 | filename = os.path.join(testsavedir, '{:03d}.png'.format(i)) 414 | os.makedirs(testsavedir, exist_ok=True) 415 | imageio.imwrite(filename, to8b(rgb)) 416 | 417 | writer.add_image('test/rgb', to8b(rgb), global_step, dataformats="HWC") 418 | writer.add_image('test/target', image.numpy(), global_step, dataformats="HWC") 419 | if i==split_tree1 or i==split_tree2 or i==split_tree3: 420 | with jt.no_grad(): 421 | rgb, conf = render(H, W, coords_f, N_rand, embed_fn, model) 422 | print("split conf",conf.shape) 423 | print("split coords_f",coords_f.shape) 424 | pts = coords_f[conf.squeeze(1)>=args.threshold] 425 | print("split pts threshold:",pts.shape) 426 | # pts = coords_f[conf>=2e-2] 427 | # print("split pts 2e-2:",pts.shape) 428 | # pts = coords_f[conf>=3e-2] 429 | # print("split pts 3e-2:",pts.shape) 430 | # pts = coords_f[conf>=4e-2] 431 | # print("split pts 4e-2:",pts.shape) 432 | # pts = coords_f[conf>=5e-2] 433 | # print("split pts 5e-2:",pts.shape) 434 | # pts = coords_f[conf>=1e-1] 435 | # print("split pts 1e-1:",pts.shape) 436 | # pts = coords_f[conf>=5e-1] 437 | # print("split pts 5e-2:",pts.shape) 438 | # pts = coords_f[conf>=1] 439 | # print("split pts 1e-0:",pts.shape) 440 | do_kmeans(pts, model) 441 | # print_conf_img(rgb.numpy(), conf.numpy(), args.threshold, os.path.join(basedir, expname, 'confimg_{:06d}'.format(i))) 442 | jt.gc() 443 | global_step += 1 444 | 445 | 446 | if __name__=='__main__': 447 | train() 448 | -------------------------------------------------------------------------------- /image_mem_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jittor as jt 3 | from jittor import nn 4 | import numpy as np 5 | 6 | 7 | # Misc 8 | img2mse = lambda x, y : jt.mean((x - y) ** 2) 9 | mse2psnr = lambda x : -10. 
* jt.log(x) / jt.log(jt.array(np.array([10.]))) 10 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 11 | 12 | 13 | # Positional encoding (section 5.1) 14 | class Embedder: 15 | def __init__(self, **kwargs): 16 | self.kwargs = kwargs 17 | self.create_embedding_fn() 18 | 19 | def create_embedding_fn(self): 20 | embed_fns = [] 21 | d = self.kwargs['input_dims'] 22 | out_dim = 0 23 | if self.kwargs['include_input']: 24 | embed_fns.append(lambda x : x) 25 | out_dim += d 26 | 27 | max_freq = self.kwargs['max_freq_log2'] 28 | N_freqs = self.kwargs['num_freqs'] 29 | 30 | if self.kwargs['log_sampling']: 31 | freq_bands = 2.**jt.linspace(0., max_freq, steps=N_freqs) 32 | else: 33 | freq_bands = jt.linspace(2.**0., 2.**max_freq, steps=N_freqs) 34 | 35 | for freq in freq_bands: 36 | for p_fn in self.kwargs['periodic_fns']: 37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 38 | out_dim += d 39 | 40 | self.embed_fns = embed_fns 41 | self.out_dim = out_dim 42 | 43 | def embed(self, inputs): 44 | return jt.concat([fn(inputs) for fn in self.embed_fns], -1) 45 | 46 | 47 | def get_embedder(multires, i=0, input_dims=3): 48 | if i == -1: 49 | return nn.Identity(), 3 50 | 51 | embed_kwargs = { 52 | 'include_input' : True, 53 | 'input_dims' : input_dims, 54 | 'max_freq_log2' : multires-1, 55 | 'num_freqs' : multires, 56 | 'log_sampling' : True, 57 | 'periodic_fns' : [jt.sin, jt.cos], 58 | } 59 | 60 | embedder_obj = Embedder(**embed_kwargs) 61 | embed = lambda x, eo=embedder_obj : eo.embed(x) 62 | return embed, embedder_obj.out_dim 63 | 64 | #tree class 65 | class Node(): 66 | def __init__(self, anchors, sons, linears): 67 | self.anchors = anchors 68 | self.sons = sons 69 | self.linears = linears 70 | 71 | # Model 72 | class OutputNet(nn.Module): 73 | def __init__(self, W, input_ch_views): 74 | """ 75 | """ 76 | super(OutputNet, self).__init__() 77 | 78 | self.rgb_linear = nn.Linear(W, 3) 79 | 80 | def execute(self, h, input_views): 81 | 82 | rgb = self.rgb_linear(h) 83 | return rgb 84 | 85 | # Model 86 | class NeRF(nn.Module): 87 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False, head_num=8, threshold=3e-2): 88 | """ 89 | """ 90 | super(NeRF, self).__init__() 91 | D=12 92 | print("input_ch",input_ch,"input_ch_views",input_ch_views,"output_ch",output_ch,"skips",skips,"use_viewdirs",use_viewdirs) 93 | # W=128 94 | self.D = D 95 | self.W = W 96 | self.input_ch = input_ch 97 | self.input_ch_views = input_ch_views 98 | skips=[] 99 | # skips = [2,4,6] 100 | self.skips = skips 101 | # self.ress = [1,3,7,11] 102 | # self.ress = [] 103 | # self.outs = [1,3,7,11] 104 | self.force_out = [0] 105 | # self.force_out = [7,8,9,10,11,12,13,14] 106 | self.use_viewdirs = use_viewdirs 107 | assert self.use_viewdirs==False 108 | 109 | self.threshold = threshold 110 | 111 | self.build_tree(head_num) 112 | self.pts_linears = nn.ModuleList( 113 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skip_linear else nn.Linear(W + input_ch, W) for i in range(self.linear_num-1)]) 114 | # for i in range(self.nlinear_list[0]+1,len(self.pts_linears)): 115 | # jt.init.constant_(self.pts_linears[i].weight, 0.0) 116 | # jt.init.constant_(self.pts_linears[i].bias, 0.0) 117 | # self.confidence_linears = nn.ModuleList([nn.Linear(W+ input_ch, 1) for i in range(D)]) 118 | self.confidence_linears = nn.ModuleList([nn.Linear(W, 1) for i in range(self.node_num)]) 119 | # self.outnet = OutputNet(W, input_ch_views) 120 | self.outnet = nn.ModuleList([OutputNet(W, 
input_ch_views) for i in range(self.node_num)]) 121 | 122 | def get_anchor(self, i): 123 | return getattr(self, i) 124 | 125 | def get_son_list(self, son_num, nlinear): 126 | son_list = [] 127 | nlinear_list = [] 128 | skip_linear = [] 129 | 130 | queue = [(0,0)] 131 | head = 0 132 | tot_linear = 0 133 | while head0: 236 | anchor = "anchor"+str(self.anchor_num) 237 | self.anchor_num += 1 238 | setattr(self, anchor, jt.array(self.anchor_list[i])) 239 | # setattr(self, anchor, jt.random([len(son), 3])) 240 | else: 241 | anchor = None 242 | linear = list(range(self.linear_num, self.linear_num+self.nlinear_list[i])) 243 | self.linear_num += self.nlinear_list[i] 244 | self.node_list.append(Node(anchor, son, linear)) 245 | 246 | def my_concat(self, a, b, dim): 247 | if a is None: 248 | return b 249 | elif b is None: 250 | return a 251 | else: 252 | return jt.concat([a,b],dim) 253 | 254 | def search(self, t, p, h, input_pts, input_views, remain_mask): 255 | node = self.node_list[t] 256 | # print("search t",t,"remain_mask",remain_mask.sum()) 257 | identity = h 258 | for i in range(len(node.linears)): 259 | # print("i",i) 260 | # print("h",h.shape) 261 | # print(self.pts_linears[node.linears[i]]) 262 | # print("len",len(self.pts_linears),"node.linears[i]",node.linears[i],"node",node) 263 | # print("t",t,"i",i,"h",h.shape,"line",self.pts_linears[node.linears[i]].weight.shape) 264 | h = self.pts_linears[node.linears[i]](h) 265 | # if t==0 and i==0: 266 | # identity = h 267 | # if i==len(node.linears)-1: 268 | # h = h+identity 269 | if i%2==1: 270 | h += identity 271 | h = jt.nn.relu(h) 272 | if i%2==1 or (t==0 and i==0): 273 | identity = h 274 | if node.linears[i] in self.skip_linear: 275 | h = jt.concat([input_pts, h], -1) 276 | 277 | confidence = self.confidence_linears[t](h).view(-1) 278 | # threshold = 0.0 279 | threshold = self.threshold 280 | # threshold = -1e10 281 | output = self.outnet[t](h, input_views) 282 | # output = self.outnet[0](h, input_views) 283 | out_num = np.zeros((self.node_num)) 284 | 285 | if len(node.sons)>0 and (not t in self.force_out): 286 | son_outputs = None 287 | son_outputs_fuse = None 288 | son_confs = None 289 | son_confs_fuse = None 290 | idxs = None 291 | idxs_fuse = None 292 | 293 | anchor = self.get_anchor(node.anchors) 294 | dis = (anchor.unsqueeze(0)-p.unsqueeze(1)).sqr().sum(-1).sqrt() 295 | min_idx, _ = jt.argmin(dis,-1) 296 | for i in range(len(node.sons)): 297 | # print("t",t,"i",i) 298 | next_t = node.sons[i] 299 | sidx = jt.arange(0,p.shape[0]) 300 | # print("min_idx==i",min_idx==i) 301 | sidx = sidx[min_idx==i] 302 | # print("sidx",sidx) 303 | next_p = p[sidx] 304 | next_h = h[sidx] 305 | next_input_pts = input_pts[sidx] 306 | next_input_views = input_views[sidx] 307 | next_remain_mask = remain_mask[sidx].copy() 308 | next_conf = confidence[sidx] 309 | next_remain_mask[threshold>next_conf] = 0 310 | sidx_fuse = sidx[next_remain_mask==1] 311 | 312 | # print("start t",t,"i",i,"next_t",next_t) 313 | next_outputs, next_outputs_fuse, next_confs, next_confs_fuse, next_out_num = self.search(next_t, next_p, next_h, next_input_pts, next_input_views, next_remain_mask) 314 | out_num = out_num+next_out_num 315 | # print("search", t, next_t) 316 | # print("next_outputs",next_outputs.shape) 317 | # print("next_outputs_fuse",next_outputs_fuse.shape) 318 | # print("next_confs",next_confs.shape) 319 | # print("next_confs_fuse",next_confs_fuse.shape) 320 | 321 | # print("end t",t,"i",i,"next_t",next_t) 322 | son_outputs = self.my_concat(son_outputs, next_outputs, 1) 323 | 
son_outputs_fuse = self.my_concat(son_outputs_fuse, next_outputs_fuse, 1) 324 | son_confs = self.my_concat(son_confs, next_confs, 1) 325 | son_confs_fuse = self.my_concat(son_confs_fuse, next_confs_fuse, 1) 326 | idxs = self.my_concat(idxs, sidx, 0) 327 | idxs_fuse = self.my_concat(idxs_fuse, sidx_fuse, 0) 328 | 329 | # print("t",t) 330 | son_outputs_save = jt.zeros(son_outputs.shape) 331 | son_outputs_save[:,idxs] = son_outputs 332 | son_outputs_save = jt.concat([output.unsqueeze(0), son_outputs_save], 0) 333 | son_confs_save = jt.zeros(son_confs.shape) 334 | son_confs_save[:,idxs] = son_confs 335 | son_confs_save = jt.concat([confidence.unsqueeze(1).unsqueeze(0), son_confs_save], 0) 336 | 337 | out_remain_mask = remain_mask.copy() 338 | out_remain_mask[threshold<=confidence] = 0 339 | idx_out = jt.arange(0,out_remain_mask.shape[0])[out_remain_mask==1] 340 | outputs_out = output[idx_out].unsqueeze(0) 341 | out_num[t] = outputs_out.shape[1] 342 | confs_out = confidence[idx_out].unsqueeze(1).unsqueeze(0) 343 | outputs_out = jt.concat([outputs_out, son_outputs_fuse], 1) 344 | confs_out = jt.concat([confs_out, son_confs_fuse], 1) 345 | idx_out = jt.concat([idx_out, idxs_fuse], 0) 346 | 347 | outputs_out_save = jt.zeros(output.unsqueeze(0).shape) 348 | outputs_out_save[:, idx_out] = outputs_out 349 | outputs_out_save = outputs_out_save[:, remain_mask==1] 350 | confs_out_save = jt.zeros(confidence.unsqueeze(1).unsqueeze(0).shape) 351 | confs_out_save[:, idx_out] = confs_out 352 | confs_out_save = confs_out_save[:, remain_mask==1] 353 | 354 | return son_outputs_save, outputs_out_save, son_confs_save, confs_out_save, out_num 355 | else: 356 | outputs_save = output.unsqueeze(0) 357 | outputs_save_log = outputs_save.copy() 358 | confs_save = confidence.unsqueeze(1).unsqueeze(0) 359 | # print("outputs_save",outputs_save.shape) 360 | # print("remain_mask",remain_mask.shape) 361 | # print("remain_mask==1",remain_mask==1) 362 | outputs_out_save = outputs_save[:, remain_mask==1] 363 | confs_out_save = confs_save[:, remain_mask==1] 364 | out_num[t] = outputs_out_save.shape[1] 365 | 366 | remain_mask[threshold<=confidence] = 0 367 | # print("out:", remain_mask.sum().numpy(), "remain:", remain_mask.shape[0]-remain_mask.sum().numpy()) 368 | 369 | # print("outputs_out_save",outputs_out_save.shape) 370 | if not self.training: 371 | if t%4==0: 372 | outputs_save_log[..., 0] *= 0. 373 | outputs_save_log[..., 1] *= 0. 374 | elif t%4==1: 375 | outputs_save_log[..., 0] *= 0. 376 | outputs_save_log[..., 2] *= 0. 377 | elif t%4==2: 378 | outputs_save_log[..., 1] *= 0. 379 | outputs_save_log[..., 2] *= 0. 380 | elif t%4==3: 381 | outputs_save_log[..., 0] *= 0. 382 | # elif t==1: 383 | # outputs_out_save[..., 1] *= 0. 384 | # elif t==2: 385 | # outputs_out_save[..., 2] *= 0. 
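                # Eval-time only: outputs_save_log is a copy of this leaf's prediction with one or two
                # RGB channels zeroed according to the node index (t % 4). It is stacked alongside the
                # real outputs below, presumably so that per-node renderings come out tinted and the
                # learned spatial partition can be inspected visually; the fused outputs used for the
                # final image are computed from `output` itself and are not affected.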
386 | outputs_save = jt.concat([outputs_save, outputs_save_log], 0) 387 | confs_save = jt.concat([confs_save, confs_save], 0) 388 | # print("outputs_out_save out",outputs_out_save.shape) 389 | 390 | # print("outputs_out_save",outputs_out_save.shape) 391 | 392 | return outputs_save, outputs_out_save, confs_save, confs_out_save, out_num 393 | 394 | def do_train(self, x, p): 395 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1) 396 | remain_mask = jt.ones(input_pts.shape[0]) 397 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask) 398 | 399 | outputs = jt.concat([outputs, outputs_fuse], 0) 400 | confs = jt.concat([confs, confs_fuse], 0) 401 | 402 | return outputs, confs, np.zeros([1]) 403 | 404 | def do_eval(self, x, p): 405 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1) 406 | remain_mask = jt.ones(input_pts.shape[0]) 407 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask) 408 | 409 | log = "out: " 410 | sout_num = list(out_num) 411 | for i in range(len(sout_num)): 412 | log += str(i)+": %d; " % sout_num[i] 413 | # print(log) 414 | sout_num = np.array(sout_num) 415 | outputs = outputs[-1:] 416 | 417 | outputs = jt.concat([outputs, outputs_fuse], 0) 418 | confs = jt.concat([confs, confs_fuse], 0) 419 | 420 | return outputs, confs, sout_num 421 | 422 | def execute(self, x, p, training): 423 | self.training = training 424 | if training: 425 | return self.do_train(x, p) 426 | else: 427 | # return self.do_train(x, p) 428 | return self.do_eval(x, p) 429 | 430 | def load_weights_from_keras(self, weights): 431 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 432 | 433 | # Load pts_linears 434 | for i in range(self.D): 435 | idx_pts_linears = 2 * i 436 | self.pts_linears[i].weight.data = jt.array(np.transpose(weights[idx_pts_linears])) 437 | self.pts_linears[i].bias.data = jt.array(np.transpose(weights[idx_pts_linears+1])) 438 | 439 | # Load feature_linear 440 | idx_feature_linear = 2 * self.D 441 | self.feature_linear.weight.data = jt.array(np.transpose(weights[idx_feature_linear])) 442 | self.feature_linear.bias.data = jt.array(np.transpose(weights[idx_feature_linear+1])) 443 | 444 | # Load views_linears 445 | idx_views_linears = 2 * self.D + 2 446 | self.views_linears[0].weight.data = jt.array(np.transpose(weights[idx_views_linears])) 447 | self.views_linears[0].bias.data = jt.array(np.transpose(weights[idx_views_linears+1])) 448 | 449 | # Load rgb_linear 450 | idx_rbg_linear = 2 * self.D + 4 451 | self.rgb_linear.weight.data = jt.array(np.transpose(weights[idx_rbg_linear])) 452 | self.rgb_linear.bias.data = jt.array(np.transpose(weights[idx_rbg_linear+1])) 453 | 454 | # Load alpha_linear 455 | idx_alpha_linear = 2 * self.D + 6 456 | self.alpha_linear.weight.data = jt.array(np.transpose(weights[idx_alpha_linear])) 457 | self.alpha_linear.bias.data = jt.array(np.transpose(weights[idx_alpha_linear+1])) 458 | 459 | 460 | 461 | # Ray helpers 462 | def get_rays(H, W, focal, c2w, intrinsic = None): 463 | i, j = jt.meshgrid(jt.linspace(0, W-1, W), jt.linspace(0, H-1, H)) 464 | i = i.t() 465 | j = j.t() 466 | if intrinsic is None: 467 | dirs = jt.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -jt.ones_like(i)], -1).unsqueeze(-2) 468 | else: 469 | i+=0.5 470 | j+=0.5 471 | dirs = jt.stack([i, j, jt.ones_like(i)], -1).unsqueeze(-2) 472 | dirs = jt.sum(dirs * intrinsic[:3,:3], 
-1).unsqueeze(-2) 473 | # Rotate ray directions from camera frame to the world frame 474 | rays_d = jt.sum(dirs * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 475 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 476 | rays_o = c2w[:3,-1].expand(rays_d.shape) 477 | return rays_o, rays_d 478 | 479 | 480 | def get_rays_np(H, W, focal, c2w): 481 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 482 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1) 483 | # Rotate ray directions from camera frame to the world frame 484 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 485 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 486 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 487 | return rays_o, rays_d 488 | 489 | 490 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 491 | # Shift ray origins to near plane 492 | t = -(near + rays_o[...,2]) / rays_d[...,2] 493 | rays_o = rays_o + t.unsqueeze(-1) * rays_d 494 | 495 | # Projection 496 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 497 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 498 | o2 = 1. + 2. * near / rays_o[...,2] 499 | 500 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 501 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 502 | d2 = -2. * near / rays_o[...,2] 503 | 504 | rays_o = jt.stack([o0,o1,o2], -1) 505 | rays_d = jt.stack([d0,d1,d2], -1) 506 | 507 | return rays_o, rays_d 508 | 509 | 510 | # Hierarchical sampling (section 5.2) 511 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 512 | # Get pdf 513 | weights = weights + 1e-5 # prevent nans 514 | pdf = weights / jt.sum(weights, -1, keepdims=True) 515 | cdf = jt.cumsum(pdf, -1) 516 | cdf = jt.concat([jt.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 517 | 518 | # Take uniform samples 519 | if det: 520 | u = jt.linspace(0., 1., steps=N_samples) 521 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 522 | else: 523 | u = jt.random(list(cdf.shape[:-1]) + [N_samples]) 524 | 525 | # Pytest, overwrite u with numpy's fixed random numbers 526 | if pytest: 527 | # np.random.seed(0) 528 | new_shape = list(cdf.shape[:-1]) + [N_samples] 529 | if det: 530 | u = np.linspace(0., 1., N_samples) 531 | u = np.broadcast_to(u, new_shape) 532 | else: 533 | u = np.random.rand(*new_shape) 534 | u = jt.array(u) 535 | 536 | # Invert CDF 537 | # u = u.contiguous() 538 | # inds = searchsorted(cdf, u, side='right') 539 | inds = jt.searchsorted(cdf, u, right=True) 540 | below = jt.maximum(jt.zeros_like(inds-1), inds-1) 541 | above = jt.minimum((cdf.shape[-1]-1) * jt.ones_like(inds), inds) 542 | inds_g = jt.stack([below, above], -1) # (batch, N_samples, 2) 543 | 544 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 545 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 546 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 547 | cdf_g = jt.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 548 | bins_g = jt.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 549 | 550 | denom = (cdf_g[...,1]-cdf_g[...,0]) 551 | cond = jt.where(denom<1e-5) 552 | denom[cond] = 1. 
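    # Inverse-transform sampling: each uniform u falls between cdf_g[...,0] and cdf_g[...,1], so t is
    # its fractional position inside that CDF bin (denom was clamped to 1 above to avoid dividing by
    # zero in empty bins), and the sample is linearly interpolated between the corresponding bin edges.
    # E.g. a CDF bin [0.2, 0.6] spanning depths [1.0, 2.0] with u = 0.3 gives t = 0.25 and depth 1.25.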
553 | t = (u-cdf_g[...,0])/denom 554 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 555 | 556 | return samples 557 | -------------------------------------------------------------------------------- /imgs/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gword/Recursive-NeRF/e61ceb302d0bc028fe4524771b363bf514342b5f/imgs/pipeline.jpg -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | import json 5 | import jittor as jt 6 | import cv2 7 | 8 | 9 | trans_t = lambda t : jt.array(np.array([ 10 | [1,0,0,0], 11 | [0,1,0,0], 12 | [0,0,1,t], 13 | [0,0,0,1]]).astype(np.float32)) 14 | 15 | rot_phi = lambda phi : jt.array(np.array([ 16 | [1,0,0,0], 17 | [0,np.cos(phi),-np.sin(phi),0], 18 | [0,np.sin(phi), np.cos(phi),0], 19 | [0,0,0,1]]).astype(np.float32)) 20 | 21 | rot_theta = lambda th : jt.array(np.array([ 22 | [np.cos(th),0,-np.sin(th),0], 23 | [0,1,0,0], 24 | [np.sin(th),0, np.cos(th),0], 25 | [0,0,0,1]]).astype(np.float32)) 26 | 27 | 28 | def pose_spherical(theta, phi, radius): 29 | c2w = trans_t(radius) 30 | c2w = rot_phi(phi/180.*np.pi) @ c2w 31 | c2w = rot_theta(theta/180.*np.pi) @ c2w 32 | c2w = jt.array(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 33 | return c2w 34 | 35 | 36 | def load_blender_data(basedir, half_res=False, testskip=1, factor=1, do_intrinsic = False): 37 | if half_res and factor==1: 38 | factor=2 39 | splits = ['train', 'val', 'test'] 40 | metas = {} 41 | for s in splits: 42 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 43 | metas[s] = json.load(fp) 44 | 45 | all_imgs = [] 46 | all_poses = [] 47 | counts = [0] 48 | for s in splits: 49 | meta = metas[s] 50 | imgs = [] 51 | poses = [] 52 | if s=='train' or testskip==0: 53 | skip = 1 54 | else: 55 | skip = testskip 56 | 57 | for frame in meta['frames'][::skip]: 58 | fname = os.path.join(basedir, frame['file_path'] + '.png') 59 | imgs.append(imageio.imread(fname)) 60 | poses.append(np.array(frame['transform_matrix'])) 61 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 62 | poses = np.array(poses).astype(np.float32) 63 | counts.append(counts[-1] + imgs.shape[0]) 64 | all_imgs.append(imgs) 65 | all_poses.append(poses) 66 | 67 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 68 | 69 | imgs = np.concatenate(all_imgs, 0) 70 | poses = np.concatenate(all_poses, 0) 71 | 72 | H, W = imgs[0].shape[:2] 73 | camera_angle_x = float(meta['camera_angle_x']) 74 | focal = .5 * W / np.tan(.5 * camera_angle_x) 75 | 76 | if do_intrinsic: 77 | a=np.array(meta['intrinsic_matrix']) 78 | # H, W, focal = 480,640,585 79 | # a = np.eye(4).astype(np.float32) 80 | # a[0,0]=focal 81 | # a[1,1]=focal 82 | # a[0,2]=W/2. 83 | # a[1,2]=H/2. 
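        # If the images are downsampled by `factor`, the intrinsic matrix is rescaled to match: its
        # first two rows hold the focal lengths and principal point, so dividing them by the factor
        # keeps the projection consistent with the smaller images. The matrix is then inverted because
        # get_rays() (see image_mem_helpers.py) multiplies homogeneous pixel coordinates (i, j, 1) by
        # it to recover camera-space ray directions.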
84 | if factor>1: 85 | a[:2]/=float(factor) 86 | a=np.linalg.inv(a) 87 | intrinsic=a 88 | print("intrinsic",intrinsic) 89 | 90 | render_poses = jt.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 91 | # render_poses = [] 92 | # meta = metas['test'] 93 | # for frame in meta['frames'][:40]: 94 | # render_poses.append(np.array(frame['transform_matrix'])) 95 | # render_poses = jt.array(np.array(render_poses).astype(np.float32)) 96 | if do_intrinsic: 97 | render_poses = [] 98 | meta = metas['test'] 99 | start = np.array(meta['frames'][0]['transform_matrix']) 100 | render_poses.append(start) 101 | for f in range(50,len(meta['frames']),50): 102 | end = np.array(meta['frames'][f]['transform_matrix']) 103 | for i in range(10): 104 | p=i/9. 105 | render_poses.append(start*(1.0-p)+end*p) 106 | start = end 107 | render_poses = jt.array(np.array(render_poses).astype(np.float32)) 108 | 109 | if factor>1: 110 | H = H//factor 111 | W = W//factor 112 | focal = focal/float(factor) 113 | 114 | imgs_half_res = np.zeros((imgs.shape[0], H, W, imgs.shape[-1])) 115 | for i, img in enumerate(imgs): 116 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 117 | imgs = imgs_half_res 118 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 119 | 120 | 121 | if do_intrinsic: 122 | return imgs, poses, intrinsic, render_poses, [H, W, focal], i_split 123 | else: 124 | return imgs, poses, render_poses, [H, W, focal], i_split 125 | 126 | 127 | -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def 
dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | for r in factors: 11 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = 'images_{}'.format(r) 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | name = 'images_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | print('Minifying', r, basedir) 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' 
'.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | print(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | 60 | 61 | 62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 63 | 64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) 66 | bds = poses_arr[:, -2:].transpose([1,0]) 67 | 68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 70 | sh = imageio.imread(img0).shape 71 | 72 | sfx = '' 73 | 74 | if factor is not None: 75 | sfx = '_{}'.format(factor) 76 | _minify(basedir, factors=[factor]) 77 | factor = factor 78 | elif height is not None: 79 | factor = sh[0] / float(height) 80 | width = int(sh[1] / factor) 81 | _minify(basedir, resolutions=[[height, width]]) 82 | sfx = '_{}x{}'.format(width, height) 83 | elif width is not None: 84 | factor = sh[1] / float(width) 85 | height = int(sh[0] / factor) 86 | _minify(basedir, resolutions=[[height, width]]) 87 | sfx = '_{}x{}'.format(width, height) 88 | else: 89 | factor = 1 90 | 91 | imgdir = os.path.join(basedir, 'images' + sfx) 92 | if not os.path.exists(imgdir): 93 | print( imgdir, 'does not exist, returning' ) 94 | return 95 | 96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 97 | if poses.shape[-1] != len(imgfiles): 98 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 99 | return 100 | 101 | sh = imageio.imread(imgfiles[0]).shape 102 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 103 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 104 | 105 | if not load_imgs: 106 | return poses, bds 107 | 108 | def imread(f): 109 | if f.endswith('png'): 110 | return imageio.imread(f, ignoregamma=True) 111 | else: 112 | return imageio.imread(f) 113 | 114 | imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] 115 | imgs = np.stack(imgs, -1) 116 | 117 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 118 | return poses, bds, imgs 119 | 120 | 121 | 122 | 123 | 124 | 125 | def normalize(x): 126 | return x / np.linalg.norm(x) 127 | 128 | def viewmatrix(z, up, pos): 129 | vec2 = normalize(z) 130 | vec1_avg = up 131 | vec0 = normalize(np.cross(vec1_avg, vec2)) 132 | vec1 = normalize(np.cross(vec2, vec0)) 133 | m = np.stack([vec0, vec1, vec2, pos], 1) 134 | return m 135 | 136 | def ptstocam(pts, c2w): 137 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 138 | return tt 139 | 140 | def poses_avg(poses): 141 | 142 | hwf = poses[0, :3, -1:] 143 | 144 | center = poses[:, :3, 3].mean(0) 145 | vec2 = normalize(poses[:, :3, 2].sum(0)) 146 | up = poses[:, :3, 1].sum(0) 147 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 148 | 149 | return c2w 150 | 151 | 152 | 153 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 154 | render_poses = [] 155 | rads = np.array(list(rads) + [1.]) 156 | hwf = c2w[:,4:5] 157 | 158 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 161 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 162 | return render_poses 163 | 164 | 165 | 166 | def recenter_poses(poses): 167 | 168 | poses_ = poses+0 169 | bottom = np.reshape([0,0,0,1.], [1,4]) 170 | c2w = poses_avg(poses) 171 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 172 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 173 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 174 | 175 | poses = np.linalg.inv(c2w) @ poses 176 | poses_[:,:3,:4] = poses[:,:3,:4] 177 | poses = poses_ 178 | return poses 179 | 180 | 181 | ##################### 182 | 183 | 184 | def spherify_poses(poses, bds): 185 | 186 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 187 | 188 | rays_d = poses[:,:3,2:3] 189 | rays_o = poses[:,:3,3:4] 190 | 191 | def min_line_dist(rays_o, rays_d): 192 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 193 | b_i = -A_i @ rays_o 194 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 195 | return pt_mindist 196 | 197 | pt_mindist = min_line_dist(rays_o, rays_d) 198 | 199 | center = pt_mindist 200 | up = (poses[:,:3,3] - center).mean(0) 201 | 202 | vec0 = normalize(up) 203 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 204 | vec2 = normalize(np.cross(vec0, vec1)) 205 | pos = center 206 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 207 | 208 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 209 | 210 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 211 | 212 | sc = 1./rad 213 | poses_reset[:,:3,3] *= sc 214 | bds *= sc 215 | rad *= sc 216 | 217 | centroid = np.mean(poses_reset[:,:3,3], 0) 218 | zh = centroid[2] 219 | radcircle = np.sqrt(rad**2-zh**2) 220 | new_poses = [] 221 | 222 | for th in np.linspace(0.,2.*np.pi, 120): 223 | 224 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 225 | up = np.array([0,0,-1.]) 226 | 227 | vec2 = normalize(camorigin) 228 | vec0 = normalize(np.cross(vec2, up)) 229 | vec1 = normalize(np.cross(vec2, vec0)) 230 | pos = camorigin 231 | p = np.stack([vec0, vec1, vec2, pos], 1) 232 | 233 | new_poses.append(p) 234 | 235 | new_poses = np.stack(new_poses, 0) 236 | 237 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 238 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 239 | 240 | return poses_reset, new_poses, bds 241 | 242 | 243 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 244 | 245 | 246 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 247 | print('Loaded', basedir, bds.min(), bds.max()) 248 | 249 | # Correct rotation matrix ordering and move variable dim to axis 0 250 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 251 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 252 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 253 | images = imgs 254 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 255 | 256 | # Rescale if bd_factor is provided 257 | sc = 1. 
if bd_factor is None else 1./(bds.min() * bd_factor) 258 | poses[:,:3,3] *= sc 259 | bds *= sc 260 | 261 | if recenter: 262 | poses = recenter_poses(poses) 263 | 264 | if spherify: 265 | poses, render_poses, bds = spherify_poses(poses, bds) 266 | 267 | else: 268 | 269 | c2w = poses_avg(poses) 270 | print('recentered', c2w.shape) 271 | print(c2w[:3,:4]) 272 | 273 | ## Get spiral 274 | # Get average pose 275 | up = normalize(poses[:, :3, 1].sum(0)) 276 | 277 | # Find a reasonable "focus depth" for this dataset 278 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 279 | dt = .75 280 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 281 | focal = mean_dz 282 | 283 | # Get radii for spiral path 284 | shrink_factor = .8 285 | zdelta = close_depth * .2 286 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 287 | rads = np.percentile(np.abs(tt), 90, 0) 288 | c2w_path = c2w 289 | N_views = 120 290 | N_rots = 2 291 | if path_zflat: 292 | # zloc = np.percentile(tt, 10, 0)[2] 293 | zloc = -close_depth * .1 294 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 295 | rads[2] = 0. 296 | N_rots = 1 297 | N_views/=2 298 | 299 | # Generate poses for spiral path 300 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 301 | 302 | 303 | render_poses = np.array(render_poses).astype(np.float32) 304 | 305 | c2w = poses_avg(poses) 306 | print('Data:') 307 | print(poses.shape, images.shape, bds.shape) 308 | 309 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 310 | i_test = np.argmin(dists) 311 | print('HOLDOUT view is', i_test) 312 | 313 | images = images.astype(np.float32) 314 | poses = poses.astype(np.float32) 315 | 316 | return images, poses, bds, render_poses, i_test 317 | 318 | 319 | 320 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jittor 2 | imageio 3 | imageio-ffmpeg 4 | matplotlib 5 | configargparse 6 | tensorboard==1.14.0 7 | tqdm 8 | opencv-python 9 | -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import jittor as jt 8 | from jittor import nn 9 | from tqdm import tqdm, trange 10 | import datetime 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | from run_nerf_helpers import * 15 | 16 | from load_llff import load_llff_data 17 | from load_deepvoxels import load_dv_data 18 | from load_blender import load_blender_data 19 | from tensorboardX import SummaryWriter 20 | 21 | 22 | jt.flags.use_cuda = 1 23 | # np.random.seed(0) 24 | DEBUG = False 25 | 26 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 27 | """Prepares inputs and applies network 'fn'. 
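    Points (and optional view directions) are positionally embedded and pushed
    through `fn` in chunks of `netchunk` points to bound peak memory; the per-chunk
    predictions, confidences and per-branch point counts are concatenated and
    reshaped back to the input batch shape before being returned together with the
    confidence loss term.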
28 | """ 29 | inputs_flat = jt.reshape(inputs, [-1, inputs.shape[-1]]) 30 | # print("x min", inputs_flat[:,0].min(), inputs_flat[:,0].max()) 31 | # print("y min", inputs_flat[:,1].min(), inputs_flat[:,1].max()) 32 | # print("z min", inputs_flat[:,2].min(), inputs_flat[:,2].max()) 33 | embedded = embed_fn(inputs_flat) 34 | training = not jt.flags.no_grad 35 | 36 | if viewdirs is not None: 37 | input_dirs = viewdirs[:,None].expand(inputs.shape) 38 | input_dirs_flat = jt.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 39 | embedded_dirs = embeddirs_fn(input_dirs_flat) 40 | embedded = jt.concat([embedded, embedded_dirs], -1) 41 | 42 | if netchunk is None: 43 | outputs_flat, loss_conf, loss_num = fn(embedded) 44 | else: 45 | flat_list = [] 46 | conf_list = [] 47 | sout_num_list = [] 48 | loss_conf = jt.zeros([1]) 49 | # loss_num = 0 50 | # pre_value = 0 51 | for i in range(0, embedded.shape[0], netchunk): 52 | flat, conf, sout_num = fn(embedded[i:i+netchunk], inputs_flat[i:i+netchunk], training) 53 | flat_list.append(flat) 54 | conf_list.append(conf) 55 | sout_num_list.append(np.expand_dims(sout_num, 0)) 56 | # loss_conf += lconf 57 | # loss_num += num 58 | # pre_value = (flat.reshape((-1))[0] + conf.reshape((-1))[0] + lconf.reshape((-1))[0]) * 0 59 | if not training: 60 | flat_list[-1].sync() 61 | conf_list[-1].sync() 62 | loss_conf.sync() 63 | outputs_flat = jt.concat(flat_list, 1) 64 | conf_flat = jt.concat(conf_list, 1) 65 | sout_flat = np.concatenate(sout_num_list, 0) 66 | sout_flat = sout_flat.sum(0) 67 | # loss_conf /= float(loss_num) 68 | outputs = jt.reshape(outputs_flat, [outputs_flat.shape[0]] + list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 69 | conf = jt.reshape(conf_flat, [conf_flat.shape[0]] + list(inputs.shape[:-1]) + [conf_flat.shape[-1]]) 70 | return outputs, conf, loss_conf, sout_flat 71 | 72 | 73 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 74 | """Render rays in smaller minibatches to avoid OOM. 75 | """ 76 | all_ret = {} 77 | for i in range(0, rays_flat.shape[0], chunk): 78 | # jt.display_memory_info() 79 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 80 | for k in ret: 81 | if k not in all_ret: 82 | all_ret[k] = [] 83 | all_ret[k].append(ret[k]) 84 | # print("all_ret[k]",k, len(all_ret[k]), all_ret[k][0].shape) 85 | # ret[k].sync() 86 | # jt.display_memory_info() 87 | for k in all_ret: 88 | # print("k",k) 89 | if k=="loss_conf" or k=="loss_conf0": 90 | all_ret[k] = jt.concat(all_ret[k], 0) 91 | elif k=="pts": 92 | all_ret[k] = np.concatenate(all_ret[k], axis=0) 93 | elif k=="outnum" or k=="outnum0": 94 | all_ret[k] = np.concatenate(all_ret[k], axis=0).sum(0) 95 | else: 96 | all_ret[k] = jt.concat(all_ret[k], 1) 97 | # all_ret = {k : jt.concat(all_ret[k], 1) for k in all_ret} 98 | return all_ret 99 | 100 | 101 | def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, intrinsic=None, ndc=True, 102 | near=0., far=1., 103 | use_viewdirs=False, c2w_staticcam=None, 104 | **kwargs): 105 | """Render rays 106 | Args: 107 | H: int. Height of image in pixels. 108 | W: int. Width of image in pixels. 109 | focal: float. Focal length of pinhole camera. 110 | chunk: int. Maximum number of rays to process simultaneously. Used to 111 | control maximum memory usage. Does not affect final results. 112 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 113 | each example in batch. 114 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 115 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 
116 | near: float or array of shape [batch_size]. Nearest distance for a ray. 117 | far: float or array of shape [batch_size]. Farthest distance for a ray. 118 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 119 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 120 | camera while using other c2w argument for viewing directions. 121 | Returns: 122 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 123 | disp_map: [batch_size]. Disparity map. Inverse of depth. 124 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 125 | extras: dict with everything returned by render_rays(). 126 | """ 127 | if c2w is not None: 128 | # special case to render full image 129 | print("render c2w",c2w.shape) 130 | rays_o, rays_d = get_rays(H, W, focal, c2w, intrinsic) 131 | else: 132 | # use provided ray batch 133 | rays_o, rays_d = rays 134 | 135 | if use_viewdirs: 136 | # provide ray directions as input 137 | viewdirs = rays_d 138 | if c2w_staticcam is not None: 139 | assert intrinsic is None 140 | # special case to visualize effect of viewdirs 141 | print("render c2w_staticcam",c2w_staticcam.shape) 142 | rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam) 143 | viewdirs = viewdirs / jt.norm(viewdirs, k=2, dim=-1, keepdim=True) 144 | viewdirs = jt.reshape(viewdirs, [-1,3]).float() 145 | 146 | sh = rays_d.shape # [..., 3] 147 | if ndc: 148 | # for forward facing scenes 149 | rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d) 150 | 151 | # Create ray batch 152 | rays_o = jt.reshape(rays_o, [-1,3]).float() 153 | rays_d = jt.reshape(rays_d, [-1,3]).float() 154 | 155 | near, far = near * jt.ones_like(rays_d[...,:1]), far * jt.ones_like(rays_d[...,:1]) 156 | rays = jt.concat([rays_o, rays_d, near, far], -1) 157 | if use_viewdirs: 158 | rays = jt.concat([rays, viewdirs], -1) 159 | 160 | # Render and reshape 161 | all_ret = batchify_rays(rays, chunk, **kwargs) 162 | for k in all_ret: 163 | if k=="loss_conf" or k=="loss_conf0" or k=="pts" or k=="outnum" or k=="outnum0": 164 | continue 165 | k_sh = [all_ret[k].shape[0]] + list(sh[:-1]) + list(all_ret[k].shape[2:]) 166 | all_ret[k] = jt.reshape(all_ret[k], k_sh) 167 | 168 | k_extract = ['rgb_map', 'disp_map', 'acc_map'] 169 | ret_list = [all_ret[k] for k in k_extract] 170 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 171 | return ret_list + [ret_dict] 172 | 173 | 174 | def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0, intrinsic = None, get_points = False, log_path = None, large_scene = False): 175 | 176 | H, W, focal = hwf 177 | 178 | if render_factor!=0: 179 | # Render downsampled for speed 180 | H = H//render_factor 181 | W = W//render_factor 182 | focal = focal/render_factor 183 | 184 | rgbs = [] 185 | rgbs_log = [] 186 | disps = [] 187 | outnum0 = [] 188 | outnum = [] 189 | points = None 190 | 191 | t = time.time() 192 | for i, c2w in enumerate(tqdm(render_poses)): 193 | # import ipdb 194 | # ipdb.set_trace() 195 | print(i, time.time() - t) 196 | t = time.time() 197 | rgb, disp, acc, extras = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], intrinsic=intrinsic, **render_kwargs) 198 | rgbs.append(rgb[-1].numpy()) 199 | rgbs_log.append(rgb[-2].numpy()) 200 | disps.append(disp[-1].numpy()) 201 | outnum0.append(np.expand_dims(extras['outnum0'], 0)) 202 | outnum.append(np.expand_dims(extras['outnum'], 0)) 203 | if get_points: 204 | point = extras['pts'] 205 | # conf = 
extras['conf_map'][-1].numpy() 206 | # point = point.reshape((-1, point.shape[-1])) 207 | # conf = conf.reshape((-1)) 208 | # idx=np.random.choice(point.shape[0], point.shape[0]//render_poses.shape[0]) 209 | # point = point[idx] 210 | if points is None: 211 | points = point 212 | else: 213 | points = np.concatenate((points, point), axis=0) 214 | 215 | if i==0: 216 | print(rgb.shape, disp.shape) 217 | 218 | """ 219 | if gt_imgs is not None and render_factor==0: 220 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 221 | print(p) 222 | """ 223 | 224 | if savedir is not None: 225 | rgb8 = to8b(rgbs[-1]) 226 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 227 | imageio.imwrite(filename, rgb8) 228 | del rgb 229 | del disp 230 | del acc 231 | del extras 232 | 233 | if large_scene: 234 | if get_points and intrinsic is None: 235 | break 236 | else: 237 | if get_points: 238 | break 239 | if jt.mpi and jt.mpi.local_rank()!=0: 240 | break 241 | 242 | outnum0 = np.concatenate(outnum0, axis=0).sum(0) 243 | outnum = np.concatenate(outnum, axis=0).sum(0) 244 | sout_num = list(outnum0) 245 | log = "" 246 | for i in range(len(sout_num)): 247 | log += str(i)+": %d, " % int(sout_num[i]) 248 | sout_num = list(outnum) 249 | log += "\n" 250 | for i in range(len(sout_num)): 251 | log += str(i)+": %d, " % int(sout_num[i]) 252 | print(log) 253 | if log_path is not None: 254 | with open(log_path, 'w') as file_object: 255 | file_object.write(log) 256 | 257 | rgbs = np.stack(rgbs, 0) 258 | rgbs_log = np.stack(rgbs_log, 0) 259 | disps = np.stack(disps, 0) 260 | 261 | if get_points: 262 | return rgbs, disps, rgbs_log, points 263 | else: 264 | return rgbs, disps, rgbs_log 265 | 266 | def create_nerf(args): 267 | """Instantiate NeRF's MLP model. 
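    Builds the coarse and (optionally) fine Recursive-NeRF models plus the
    positional embedders and the Adam optimizer, reloading the newest checkpoint
    found under basedir/expname unless --no_reload is set. Returns
    (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer),
    where `start` is the global step restored from the checkpoint (0 otherwise).

    Typical use (sketch, mirroring train() below):
        render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
        render_kwargs_train.update({'near': near, 'far': far})
        render_kwargs_test.update({'near': near, 'far': far})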
268 | """ 269 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 270 | 271 | input_ch_views = 0 272 | embeddirs_fn = None 273 | if args.use_viewdirs: 274 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 275 | output_ch = 5 if args.N_importance > 0 else 4 276 | skips = [4] 277 | model = NeRF(D=args.netdepth, W=args.netwidth, 278 | input_ch=input_ch, output_ch=output_ch, skips=skips, 279 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs, head_num=args.head_num, threshold=args.threshold*2.0 if args.large_scene else args.threshold*10.0) 280 | grad_vars = list(model.parameters()) 281 | 282 | model_fine = None 283 | if args.N_importance > 0: 284 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 285 | input_ch=input_ch, output_ch=output_ch, skips=skips, 286 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs, head_num=args.head_num, threshold=args.threshold) 287 | grad_vars += list(model_fine.parameters()) 288 | 289 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 290 | embed_fn=embed_fn, 291 | embeddirs_fn=embeddirs_fn, 292 | netchunk=args.netchunk) 293 | 294 | # Create optimizer 295 | optimizer = jt.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 296 | 297 | start = 0 298 | basedir = args.basedir 299 | expname = args.expname 300 | 301 | ########################## 302 | 303 | # Load checkpoints 304 | if args.ft_path is not None and args.ft_path!='None': 305 | ckpts = [args.ft_path] 306 | else: 307 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] 308 | 309 | print('Found ckpts', ckpts) 310 | if len(ckpts) > 0 and not args.no_reload: 311 | ckpt_path = ckpts[-1] 312 | print('Reloading from', ckpt_path) 313 | ckpt = jt.load(ckpt_path) 314 | 315 | start = ckpt['global_step'] 316 | # optimizer.load_state_dict(ckpt['optimizer_state_dict']) 317 | 318 | # Load model 319 | model.load_state_dict(ckpt['network_fn_state_dict']) 320 | if model_fine is not None: 321 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 322 | 323 | ########################## 324 | 325 | render_kwargs_train = { 326 | 'network_query_fn' : network_query_fn, 327 | 'perturb' : args.perturb, 328 | 'N_importance' : args.N_importance, 329 | 'network_fine' : model_fine, 330 | 'N_samples' : args.N_samples, 331 | 'network_fn' : model, 332 | 'use_viewdirs' : args.use_viewdirs, 333 | 'white_bkgd' : args.white_bkgd, 334 | 'raw_noise_std' : args.raw_noise_std, 335 | 'threshold' : args.threshold, 336 | } 337 | 338 | # NDC only good for LLFF-style forward facing data 339 | if args.dataset_type != 'llff' or args.no_ndc: 340 | print('Not ndc!') 341 | render_kwargs_train['ndc'] = False 342 | render_kwargs_train['lindisp'] = args.lindisp 343 | 344 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 345 | render_kwargs_test['perturb'] = False 346 | render_kwargs_test['raw_noise_std'] = 0. 347 | 348 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 349 | 350 | 351 | def raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 352 | """Transforms model's predictions to semantically meaningful values. 353 | Args: 354 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 355 | z_vals: [num_rays, num_samples along ray]. Integration time. 356 | rays_d: [num_rays, 3]. Direction of each ray. 
357 | Returns: 358 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 359 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 360 | acc_map: [num_rays]. Sum of weights along each ray. 361 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 362 | depth_map: [num_rays]. Estimated distance to object. 363 | """ 364 | raw2alpha = lambda raw, dists, act_fn=jt.nn.relu: 1.-jt.exp(-act_fn(raw)*dists) 365 | 366 | dists = z_vals[...,1:] - z_vals[...,:-1] 367 | dists = jt.concat([dists, jt.array(np.array([1e10]).astype(np.float32)).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 368 | 369 | dists = dists * jt.norm(rays_d.unsqueeze(-2), k=2, dim=-1) 370 | 371 | rgb = jt.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 372 | noise = 0. 373 | if raw_noise_std > 0.: 374 | noise = jt.init.gauss(raw[...,3].shape, raw.dtype) * raw_noise_std 375 | 376 | # Overwrite randomly sampled data if pytest 377 | if pytest: 378 | # np.random.seed(0) 379 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 380 | noise = jt.array(noise) 381 | 382 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 383 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 384 | weights = alpha * jt.cumprod(jt.concat([jt.ones((1,alpha.shape[1], 1)), 1.-alpha + 1e-10], -1), -1)[..., :-1] 385 | rgb_map = jt.sum(weights.unsqueeze(-1) * rgb, -2) # [N_rays, 3] 386 | # conf_map = jt.sum(weights.unsqueeze(-1) * conf, -2) # [N_rays, 1] 387 | # conf_map = jt.mean(conf, -2) # [N_rays, 1] 388 | conf_map = conf 389 | 390 | depth_map = jt.sum(weights * z_vals, -1) 391 | disp_map = 1./jt.maximum(1e-10 * jt.ones_like(depth_map), depth_map / jt.sum(weights, -1)) 392 | acc_map = jt.sum(weights, -1) 393 | 394 | if white_bkgd: 395 | rgb_map = rgb_map + (1.-acc_map.unsqueeze(-1)) 396 | 397 | return rgb_map, disp_map, acc_map, weights, depth_map, conf_map 398 | 399 | 400 | def render_rays(ray_batch, 401 | network_fn, 402 | network_query_fn, 403 | N_samples, 404 | retraw=False, 405 | lindisp=False, 406 | perturb=0., 407 | N_importance=0, 408 | network_fine=None, 409 | white_bkgd=False, 410 | raw_noise_std=0., 411 | threshold=3e-2, 412 | verbose=False, 413 | pytest=False): 414 | """Volumetric rendering. 415 | Args: 416 | ray_batch: array of shape [batch_size, ...]. All information necessary 417 | for sampling along a ray, including: ray origin, ray direction, min 418 | dist, max dist, and unit-magnitude viewing direction. 419 | network_fn: function. Model for predicting RGB and density at each point 420 | in space. 421 | network_query_fn: function used for passing queries to network_fn. 422 | N_samples: int. Number of different times to sample along each ray. 423 | retraw: bool. If True, include model's raw, unprocessed predictions. 424 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 425 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 426 | random points in time. 427 | N_importance: int. Number of additional times to sample along each ray. 428 | These samples are only passed to network_fine. 429 | network_fine: "fine" network with same spec as network_fn. 430 | white_bkgd: bool. If True, assume a white background. 431 | raw_noise_std: ... 432 | verbose: bool. If True, print more debugging info. 433 | Returns: 434 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 435 | disp_map: [num_rays]. Disparity map. 1 / depth. 436 | acc_map: [num_rays]. Accumulated opacity along each ray. 
Comes from fine model. 437 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 438 | rgb0: See rgb_map. Output for coarse model. 439 | disp0: See disp_map. Output for coarse model. 440 | acc0: See acc_map. Output for coarse model. 441 | z_std: [num_rays]. Standard deviation of distances along ray for each 442 | sample. 443 | """ 444 | training = not jt.flags.no_grad 445 | N_rays = ray_batch.shape[0] 446 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 447 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 448 | bounds = jt.reshape(ray_batch[...,6:8], [-1,1,2]) 449 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 450 | 451 | t_vals = jt.linspace(0., 1., steps=N_samples) 452 | if not lindisp: 453 | z_vals = near * (1.-t_vals) + far * (t_vals) 454 | else: 455 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 456 | 457 | z_vals = z_vals.expand([N_rays, N_samples]) 458 | 459 | if perturb > 0.: 460 | # get intervals between samples 461 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 462 | upper = jt.concat([mids, z_vals[...,-1:]], -1) 463 | lower = jt.concat([z_vals[...,:1], mids], -1) 464 | # stratified samples in those intervals 465 | t_rand = jt.random(z_vals.shape) 466 | 467 | # Pytest, overwrite u with numpy's fixed random numbers 468 | if pytest: 469 | # np.random.seed(0) 470 | t_rand = np.random.rand(*list(z_vals.shape)) 471 | t_rand = jt.array(t_rand) 472 | 473 | z_vals = lower + (upper - lower) * t_rand 474 | 475 | pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N_rays, N_samples, 3] 476 | 477 | 478 | # raw = run_network(pts) 479 | raw, conf, loss_conf, outnum = network_query_fn(pts, viewdirs, network_fn) 480 | rgb_map, disp_map, acc_map, weights, depth_map, conf_map = raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 481 | 482 | if N_importance > 0: 483 | 484 | rgb_map_0, disp_map_0, acc_map_0, conf_map_0, loss_conf_0, weights_0, outnum_0 = rgb_map, disp_map, acc_map, conf_map, loss_conf, weights, outnum 485 | 486 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 487 | z_samples = sample_pdf(z_vals_mid, weights[-1,...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 488 | z_samples = z_samples.detach() 489 | 490 | _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1) 491 | pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N_rays, N_samples + N_importance, 3] 492 | 493 | run_fn = network_fn if network_fine is None else network_fine 494 | # raw = run_network(pts, fn=run_fn) 495 | raw, conf, loss_conf, outnum = network_query_fn(pts, viewdirs, run_fn) 496 | # print("raw",raw.shape) 497 | # print("pts",pts.shape) 498 | 499 | rgb_map, disp_map, acc_map, weights, depth_map, conf_map = raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 500 | # print("rgb_map",rgb_map.shape) 501 | # print("conf_map",conf_map.shape) 502 | if training: 503 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'conf_map' : conf_map, 'loss_conf' : loss_conf, 'weights' : weights} 504 | if retraw: 505 | ret['raw'] = raw 506 | if N_importance > 0: 507 | ret['rgb0'] = rgb_map_0 508 | ret['disp0'] = disp_map_0 509 | ret['acc0'] = acc_map_0 510 | ret['conf_map0'] = conf_map_0 511 | ret['loss_conf0'] = loss_conf_0 512 | ret['weights0'] = weights_0 513 | # ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] TODO: support jt.std 514 | else: 515 | point = pts.numpy() 516 | conf = 
conf_map[-1].numpy() 517 | point = point.reshape((-1, point.shape[-1])) 518 | conf = conf.reshape((-1)) 519 | idx=np.random.choice(point.shape[0], point.shape[0]//10, replace=False) 520 | point = point[idx] 521 | conf = conf[idx] 522 | # threshold = 0.0 523 | point = point[conf>=threshold] 524 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'pts' : point, 'outnum' : np.expand_dims(outnum,0)} 525 | if retraw: 526 | ret['raw'] = raw 527 | if N_importance > 0: 528 | ret['rgb0'] = rgb_map_0 529 | ret['disp0'] = disp_map_0 530 | ret['acc0'] = acc_map_0 531 | ret['outnum0'] = np.expand_dims(outnum_0,0) 532 | 533 | # for k in ret: 534 | # if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 535 | # print(f"! [Numerical Error] {k} contains nan or inf.") 536 | 537 | return ret 538 | 539 | def dfs(t, points, model, model_fine): 540 | k = len(model.son_list[t]) 541 | if t in model.force_out: 542 | if points.shape[0]>=k: 543 | centroid = points[jt.array(np.random.choice(points.shape[0], k, replace=False))] 544 | print("centroid",centroid.shape) 545 | # print("step",-1,centroid.numpy()) 546 | for step in range(100): 547 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt() 548 | min_idx, _ = jt.argmin(dis,-1) 549 | # print("min_idx",min_idx.shape) 550 | for i in range(k): 551 | # print("i",i,(min_idx==i).sum()) 552 | centroid[i] = points[min_idx==i].mean(0) 553 | # print("step",step,centroid.numpy()) 554 | else: 555 | centroid = jt.rand((k,3)) 556 | print("centroid fail",centroid.shape) 557 | setattr(model, model.node_list[t].anchors, centroid.detach()) 558 | setattr(model_fine, model_fine.node_list[t].anchors, centroid.detach()) 559 | if jt.mpi: 560 | # v1 = getattr(model, model.node_list[t].anchors) 561 | # v1.assign(jt.mpi.broadcast(v2, root=0)) 562 | # v2 = getattr(model_fine, model_fine.node_list[t].anchors) 563 | # v2.assign(jt.mpi.broadcast(v2, root=0)) 564 | jt.mpi.broadcast(getattr(model, model.node_list[t].anchors), root=0) 565 | jt.mpi.broadcast(getattr(model_fine, model_fine.node_list[t].anchors), root=0) 566 | print("model", jt.mpi.local_rank(), t, getattr(model, model.node_list[t].anchors)) 567 | print("modelfine", jt.mpi.local_rank(), t, getattr(model_fine, model_fine.node_list[t].anchors)) 568 | for i in model.son_list[t]: 569 | model.outnet[i].alpha_linear.load_state_dict(model.outnet[t].alpha_linear.state_dict()) 570 | model_fine.outnet[i].alpha_linear.load_state_dict(model_fine.outnet[t].alpha_linear.state_dict()) 571 | # model.outnet[i].load_state_dict(model.outnet[t].state_dict()) 572 | # model_fine.outnet[i].load_state_dict(model_fine.outnet[t].state_dict()) 573 | return model.son_list[t] 574 | else: 575 | centroid = model.get_anchor(model.node_list[t].anchors) 576 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt() 577 | min_idx, _ = jt.argmin(dis,-1) 578 | res = [] 579 | for i in range(k): 580 | res += dfs(model.son_list[t][i], points[min_idx==i], model, model_fine) 581 | return res 582 | 583 | def do_kmeans(points, model, model_fine): 584 | # points = points.reshape(-1, points.shape[-1]) 585 | # confs = confs.reshape(-1) 586 | print("do_kmeans",points.shape) 587 | points = jt.array(points) 588 | force_out = dfs(0, points, model, model_fine) 589 | 590 | model.force_out = force_out 591 | model_fine.force_out = force_out 592 | 593 | def config_parser(): 594 | gpu = "gpu"+os.environ["CUDA_VISIBLE_DEVICES"] 595 | import configargparse 596 | parser = configargparse.ArgumentParser() 597 | 
parser.add_argument('--config', is_config_file=True, 598 | help='config file path') 599 | parser.add_argument("--expname", type=str, 600 | help='experiment name') 601 | parser.add_argument("--basedir", type=str, default='./logs/'+gpu+"/", 602 | help='where to store ckpts and logs') 603 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 604 | help='input data directory') 605 | 606 | # training options 607 | parser.add_argument("--netdepth", type=int, default=8, 608 | help='layers in network') 609 | parser.add_argument("--step1", type=int, default=5000, 610 | help='training step at which the first network split (k-means tree growth) is triggered') 611 | parser.add_argument("--step2", type=int, default=10000, 612 | help='training step at which the second network split is triggered') 613 | parser.add_argument("--step3", type=int, default=15000, 614 | help='training step at which the third network split is triggered') 615 | parser.add_argument("--netwidth", type=int, default=256, 616 | help='channels per layer') 617 | parser.add_argument("--netdepth_fine", type=int, default=8, 618 | help='layers in fine network') 619 | parser.add_argument("--netwidth_fine", type=int, default=256, 620 | help='channels per layer in fine network') 621 | parser.add_argument("--N_rand", type=int, default=32*32*4, 622 | help='batch size (number of random rays per gradient step)') 623 | parser.add_argument("--lrate", type=float, default=5e-4, 624 | help='learning rate') 625 | parser.add_argument("--lrate_decay", type=int, default=250, 626 | help='exponential learning rate decay (in 1000 steps)') 627 | parser.add_argument("--chunk", type=int, default=1024*8, 628 | help='number of rays processed in parallel, decrease if running out of memory') 629 | parser.add_argument("--netchunk", type=int, default=1024*64, 630 | help='number of pts sent through network in parallel, decrease if running out of memory') 631 | parser.add_argument("--no_batching", action='store_true', 632 | help='only take random rays from 1 image at a time') 633 | parser.add_argument("--no_reload", action='store_true', 634 | help='do not reload weights from saved ckpt') 635 | parser.add_argument("--ft_path", type=str, default=None, 636 | help='specific weights npy file to reload for coarse network') 637 | parser.add_argument("--threshold", type=float, default=1e-2, 638 | help='uncertainty threshold; query points whose predicted uncertainty exceeds it are forwarded to deeper sub-networks and collected for splitting') 639 | 640 | # rendering options 641 | parser.add_argument("--N_samples", type=int, default=64, 642 | help='number of coarse samples per ray') 643 | parser.add_argument("--N_importance", type=int, default=0, 644 | help='number of additional fine samples per ray') 645 | parser.add_argument("--perturb", type=float, default=1., 646 | help='set to 0. for no jitter, 1.
for jitter') 647 | parser.add_argument("--use_viewdirs", action='store_true', 648 | help='use full 5D input instead of 3D') 649 | parser.add_argument("--i_embed", type=int, default=0, 650 | help='set 0 for default positional encoding, -1 for none') 651 | parser.add_argument("--multires", type=int, default=10, 652 | help='log2 of max freq for positional encoding (3D location)') 653 | parser.add_argument("--multires_views", type=int, default=4, 654 | help='log2 of max freq for positional encoding (2D direction)') 655 | parser.add_argument("--raw_noise_std", type=float, default=0., 656 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 657 | 658 | parser.add_argument("--render_only", action='store_true', 659 | help='do not optimize, reload weights and render out render_poses path') 660 | parser.add_argument("--render_test", action='store_true', 661 | help='render the test set instead of render_poses path') 662 | parser.add_argument("--render_factor", type=int, default=0, 663 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 664 | 665 | # training options 666 | parser.add_argument("--precrop_iters", type=int, default=0, 667 | help='number of steps to train on central crops') 668 | parser.add_argument("--head_num", type=int, default=8, 669 | help='number of heads') 670 | parser.add_argument("--precrop_frac", type=float, 671 | default=.5, help='fraction of img taken for central crops') 672 | parser.add_argument("--large_scene", action='store_true', 673 | help='use large scene') 674 | 675 | # dataset options 676 | parser.add_argument("--dataset_type", type=str, default='llff', 677 | help='options: llff / blender / deepvoxels') 678 | parser.add_argument("--testskip", type=int, default=8, 679 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 680 | parser.add_argument("--faketestskip", type=int, default=1, 681 | help='will further keep only every Nth loaded test image for the intermediate test renders done during training') 682 | 683 | ## deepvoxels flags 684 | parser.add_argument("--shape", type=str, default='greek', 685 | help='options: armchair / cube / greek / vase') 686 | 687 | ## blender flags 688 | parser.add_argument("--white_bkgd", action='store_true', 689 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 690 | parser.add_argument("--half_res", action='store_true', 691 | help='load blender synthetic data at 400x400 instead of 800x800') 692 | parser.add_argument("--near", type=float, default=2., 693 | help='near bound for rays (blender-style scenes)') 694 | parser.add_argument("--far", type=float, default=6., 695 | help='far bound for rays (blender-style scenes)') 696 | parser.add_argument("--do_intrinsic", action='store_true', 697 | help='use intrinsic matrix') 698 | parser.add_argument("--blender_factor", type=int, default=1, 699 | help='downsample factor for blender images') 700 | 701 | ## llff flags 702 | parser.add_argument("--factor", type=int, default=8, 703 | help='downsample factor for LLFF images') 704 | parser.add_argument("--no_ndc", action='store_true', 705 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 706 | parser.add_argument("--lindisp", action='store_true', 707 | help='sample linearly in disparity rather than depth') 708 | parser.add_argument("--spherify", action='store_true', 709 | help='set for spherical 360 scenes') 710 | parser.add_argument("--llffhold", type=int, default=8, 711 | help='will take every 1/N images as LLFF test set, paper
uses 8') 712 | 713 | # logging/saving options 714 | parser.add_argument("--i_print", type=int, default=100, 715 | help='frequency of console printout and metric loggin') 716 | parser.add_argument("--i_img", type=int, default=5000, 717 | help='frequency of tensorboard image logging') 718 | parser.add_argument("--i_weights", type=int, default=10000, 719 | help='frequency of weight ckpt saving') 720 | parser.add_argument("--i_testset", type=int, default=50000, 721 | help='frequency of testset saving') 722 | parser.add_argument("--i_tottest", type=int, default=400000, 723 | help='frequency of testset saving') 724 | parser.add_argument("--i_video", type=int, default=50000, 725 | help='frequency of render_poses video saving') 726 | 727 | return parser 728 | 729 | 730 | def train(): 731 | 732 | parser = config_parser() 733 | args = parser.parse_args() 734 | 735 | # Load data 736 | intrinsic = None 737 | if args.dataset_type == 'llff': 738 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 739 | recenter=True, bd_factor=.75, 740 | spherify=args.spherify) 741 | hwf = poses[0,:3,-1] 742 | poses = poses[:,:3,:4] 743 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 744 | if not isinstance(i_test, list): 745 | i_test = [i_test] 746 | 747 | if args.llffhold > 0: 748 | print('Auto LLFF holdout,', args.llffhold) 749 | i_test = np.arange(images.shape[0])[::args.llffhold] 750 | 751 | i_val = i_test 752 | i_test_tot = i_test 753 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 754 | (i not in i_test and i not in i_val)]) 755 | 756 | print('DEFINING BOUNDS') 757 | if args.no_ndc: 758 | near = np.ndarray.min(bds) * .9 759 | far = np.ndarray.max(bds) * 1. 760 | 761 | else: 762 | near = 0. 763 | far = 1. 764 | print('NEAR FAR', near, far) 765 | 766 | elif args.dataset_type == 'blender': 767 | testskip = args.testskip 768 | faketestskip = args.faketestskip 769 | if jt.mpi and jt.mpi.local_rank()!=0: 770 | testskip = faketestskip 771 | faketestskip = 1 772 | if args.do_intrinsic: 773 | images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, testskip, args.blender_factor, True) 774 | else: 775 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, testskip, args.blender_factor) 776 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 777 | i_train, i_val, i_test = i_split 778 | i_test_tot = i_test 779 | i_test = i_test[::faketestskip] 780 | 781 | near = args.near 782 | far = args.far 783 | print("near", near) 784 | print("far", far) 785 | 786 | if args.white_bkgd: 787 | # accs = images[...,-1] 788 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 789 | else: 790 | images = images[...,:3] 791 | 792 | elif args.dataset_type == 'deepvoxels': 793 | 794 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 795 | basedir=args.datadir, 796 | testskip=args.testskip) 797 | 798 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) 799 | i_train, i_val, i_test = i_split 800 | 801 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) 802 | near = hemi_R-1. 803 | far = hemi_R+1. 
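        # DeepVoxels cameras sit roughly on a hemisphere around the object:
        # hemi_R is the mean camera distance from the origin, and the ray
        # bounds are taken one unit inside and outside that radius.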
804 | 805 | else: 806 | print('Unknown dataset type', args.dataset_type, 'exiting') 807 | return 808 | 809 | # Cast intrinsics to right types 810 | H, W, focal = hwf 811 | H, W = int(H), int(W) 812 | hwf = [H, W, focal] 813 | 814 | if args.render_test: 815 | render_poses = np.array(poses[i_test]) 816 | 817 | # Create log dir and copy the config file 818 | basedir = args.basedir 819 | expname = args.expname 820 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 821 | f = os.path.join(basedir, expname, 'args.txt') 822 | with open(f, 'w') as file: 823 | for arg in sorted(vars(args)): 824 | attr = getattr(args, arg) 825 | file.write('{} = {}\n'.format(arg, attr)) 826 | if args.config is not None: 827 | f = os.path.join(basedir, expname, 'config.txt') 828 | with open(f, 'w') as file: 829 | file.write(open(args.config, 'r').read()) 830 | 831 | # Create nerf model 832 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 833 | global_step = start 834 | 835 | bds_dict = { 836 | 'near' : near, 837 | 'far' : far, 838 | } 839 | render_kwargs_train.update(bds_dict) 840 | render_kwargs_test.update(bds_dict) 841 | 842 | # Move testing data to GPU 843 | render_poses = jt.array(render_poses) 844 | 845 | # Short circuit if only rendering out from trained model 846 | if args.render_only: 847 | print('RENDER ONLY') 848 | with jt.no_grad(): 849 | if args.render_test: 850 | # render_test switches to test poses 851 | images = images[i_test] 852 | else: 853 | # Default is smoother render_poses path 854 | images = None 855 | 856 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 857 | os.makedirs(testsavedir, exist_ok=True) 858 | print('test poses shape', render_poses.shape) 859 | 860 | rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor, intrinsic = intrinsic) 861 | print('Done rendering', testsavedir) 862 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 863 | 864 | return 865 | 866 | # Prepare raybatch tensor if batching random rays 867 | accumulation_steps = 2 868 | N_rand = args.N_rand//accumulation_steps 869 | use_batching = not args.no_batching 870 | if use_batching: 871 | # For random ray batching 872 | print('get rays') 873 | rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 874 | print('done, concats') 875 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 876 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 877 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 878 | rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] 879 | rays_rgb = rays_rgb.astype(np.float32) 880 | print('shuffle rays') 881 | np.random.shuffle(rays_rgb) 882 | 883 | print('done') 884 | i_batch = 0 885 | 886 | # Move training data to GPU 887 | images = jt.array(images.astype(np.float32)) 888 | # accs = jt.array(accs.astype(np.float32)) 889 | # a=images[0].copy() 890 | # b=images[1].copy() 891 | # c=images[2].copy() 892 | # print("images0",a.sum()) 893 | # print("images1",b.sum()) 894 | # print("images2",c.sum()) 895 | # print("images0",images[0].numpy().sum()) 896 | # print("images1",images[1].numpy().sum()) 897 | # print("images2",images[2].numpy().sum()) 898 | # print("images0",images[0].sum().numpy()) 899 | # 
print("images1",images[1].sum().numpy()) 900 | # print("images2",images[2].sum().numpy()) 901 | poses = jt.array(poses) 902 | if use_batching: 903 | rays_rgb = jt.array(rays_rgb) 904 | 905 | 906 | N_iters = 300000*accumulation_steps + 1 907 | split_tree1 = args.step1*accumulation_steps 908 | split_tree2 = args.step2*accumulation_steps 909 | split_tree3 = args.step3*accumulation_steps 910 | # split_tree1 = args.i_img 911 | # split_tree2 = args.i_img*2 912 | print('Begin') 913 | print('TRAIN views are', i_train) 914 | print('TEST views are', i_test) 915 | print('VAL views are', i_val) 916 | 917 | # Summary writers 918 | # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) 919 | if not jt.mpi or jt.mpi.local_rank()==0: 920 | date = str(datetime.datetime.now()) 921 | date = date[:date.rfind(":")].replace("-", "")\ 922 | .replace(":", "")\ 923 | .replace(" ", "_") 924 | gpu_idx = os.environ["CUDA_VISIBLE_DEVICES"] 925 | log_dir = os.path.join("./logs", "summaries", f"log_{date}_gpu{gpu_idx}_{args.expname}") 926 | if not os.path.exists(log_dir): 927 | os.makedirs(log_dir) 928 | writer = SummaryWriter(log_dir=log_dir) 929 | 930 | start = start + 1 931 | # if not use_batching and jt.mpi: 932 | # img_i_list = np.random.choice(i_train, N_iters) 933 | # print("before img_i_list",img_i_list.sum()) 934 | # jt.mpi.broadcast(img_i_list, root=0) 935 | # print("after img_i_list",img_i_list.sum()) 936 | for i in trange(start, N_iters): 937 | # print("i",i,"jt.mpi.local_rank()",jt.mpi.local_rank()) 938 | # jt.display_memory_info() 939 | time0 = time.time() 940 | 941 | # Sample random ray batch 942 | if use_batching: 943 | # Random over all images 944 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] 945 | batch = jt.transpose(batch, (1, 0, 2)) 946 | batch_rays, target_s = batch[:2], batch[2] 947 | 948 | i_batch += N_rand 949 | if i_batch >= rays_rgb.shape[0]: 950 | print("Shuffle data after an epoch!") 951 | rand_idx = jt.randperm(rays_rgb.shape[0]) 952 | rays_rgb = rays_rgb[rand_idx] 953 | i_batch = 0 954 | 955 | else: 956 | # Random from one image 957 | # if jt.mpi: 958 | # img_i = img_i_list[i] 959 | # else: 960 | # img_i = np.random.choice(i_train) 961 | img_i = np.random.choice(i_train) 962 | target = images[img_i]#.squeeze(0) 963 | # acc_target = accs[img_i] 964 | pose = poses[img_i, :3,:4]#.squeeze(0) 965 | 966 | if N_rand is not None: 967 | rays_o, rays_d = get_rays(H, W, focal, pose, intrinsic) # (H, W, 3), (H, W, 3) 968 | 969 | if i < args.precrop_iters: 970 | dH = int(H//2 * args.precrop_frac) 971 | dW = int(W//2 * args.precrop_frac) 972 | coords = jt.stack( 973 | jt.meshgrid( 974 | jt.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 975 | jt.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) 976 | ), -1) 977 | if i == start: 978 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") 979 | else: 980 | coords = jt.stack(jt.meshgrid(jt.linspace(0, H-1, H), jt.linspace(0, W-1, W)), -1) # (H, W, 2) 981 | 982 | coords = jt.reshape(coords, [-1,2]) # (H * W, 2) 983 | # if jt.mpi: 984 | # assert coords.shape[0]%jt.mpi.world_size()==0 985 | # select_inds = np.random.choice(coords.shape[0]//jt.mpi.world_size(), size=[N_rand], replace=False) # (N_rand,) 986 | # select_inds += (coords.shape[0]//jt.mpi.world_size())*jt.mpi.local_rank() 987 | # else: 988 | # select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 989 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 990 | 
select_coords = coords[select_inds].int() # (N_rand, 2) 991 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 992 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 993 | batch_rays = jt.stack([rays_o, rays_d], 0) 994 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 995 | # target_a = acc_target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 1) 996 | 997 | ##### Core optimization loop ##### 998 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, 999 | verbose=i < 10, retraw=True, 1000 | **render_kwargs_train) 1001 | # print("rgb",rgb.shape) 1002 | # print("target_s",target_s.shape) 1003 | img_loss = img2mse(rgb, target_s.unsqueeze(0)) 1004 | # img_loss = (img2mse(rgb[-1], target_s.unsqueeze(0))*10.0+img2mse(rgb[:-1], target_s.unsqueeze(0)))/11.0 1005 | trans = extras['raw'][...,-1] 1006 | loss = img_loss 1007 | # loss_conf = img2mse(extras['conf_map'], (rgb-target_s)**2) 1008 | # loss_conf = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.)+jt.maximum(extras['conf_map']+0.01,0.).mean() 1009 | sloss = jt.maximum(((rgb-target_s)**2).detach().unsqueeze(-2)-extras['conf_map'],0.) 1010 | # loss_conf_loss = jt.sum(extras['weights'].unsqueeze(-1) * sloss, -2).mean() 1011 | loss_conf_loss = sloss.mean() 1012 | # loss_conf_loss = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.).mean() 1013 | loss_conf_mean = jt.maximum(extras['conf_map'], 0.).mean() 1014 | # print("(rgb-target_s)",(rgb-target_s).shape) 1015 | # print("loss_conf_loss",loss_conf_loss.shape) 1016 | # print("loss_conf_mean",loss_conf_mean.shape) 1017 | # loss_conf = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.).sum() 1018 | # print("ssqr",ssqr.shape,"mean",ssqr.mean(),"max",ssqr.max()) 1019 | # print("extras['conf_map']",extras['conf_map'].shape,"mean",extras['conf_map'].mean(),"max",extras['conf_map'].max()) 1020 | # print("extras['loss_conf']",extras['loss_conf'].shape,"mean",extras['loss_conf'].mean(),"max",extras['loss_conf'].max()) 1021 | # psnr = mse2psnr(img_loss) 1022 | # psnr = mse2psnr(img2mse(rgb[-2], target_s.unsqueeze(0))) 1023 | psnr = mse2psnr(img2mse(rgb[-1], target_s.unsqueeze(0))) 1024 | # loss_acc = jt.zeros([1]) 1025 | # if args.white_bkgd: 1026 | # # print("acc",acc.shape) 1027 | # # print("target_a",target_a.shape) 1028 | # loss_acc = loss_acc+img2mse(acc, target_a.unsqueeze(0)) 1029 | 1030 | if 'rgb0' in extras: 1031 | img_loss0 = img2mse(extras['rgb0'], target_s) 1032 | # img_loss0 = (img2mse(extras['rgb0'][-1], target_s.unsqueeze(0))*10.0+img2mse(extras['rgb0'][:-1], target_s.unsqueeze(0)))/11.0 1033 | loss = loss + img_loss0 1034 | # loss_conf = loss_conf + img2mse(extras['conf_map0'], (extras['rgb0']-target_s)**2) 1035 | # loss_conf = loss_conf + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.)+jt.maximum(extras['conf_map0']+0.01,0.) 1036 | sloss = jt.maximum(((extras['rgb0']-target_s)**2).detach().unsqueeze(-2)-extras['conf_map0'],0.) 
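        # Same hinge as for the fine outputs above: the coarse confidence is only
        # penalized where it underestimates the (detached) squared color error.
        # In effect, the confidence objective assembled below is roughly
        #   loss_conf = mean(max((rgb - gt)^2 - conf, 0)) + 0.01 * mean(max(conf, 0))
        # summed over coarse and fine branches, and added to the image loss with
        # weight 0.1.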
1037 | # loss_conf_loss = loss_conf_loss + jt.sum(extras['weights0'].unsqueeze(-1) * sloss, -2).mean() 1038 | loss_conf_loss = loss_conf_loss + sloss.mean() 1039 | # loss_conf_loss = loss_conf_loss + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.).mean() 1040 | loss_conf_mean = loss_conf_mean + jt.maximum(extras['conf_map0'], 0.).mean() 1041 | # loss_conf = loss_conf + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.).sum() 1042 | psnr0 = mse2psnr(img_loss0) 1043 | # if args.white_bkgd: 1044 | # loss_acc = loss_acc+img2mse(extras['acc0'], target_a.unsqueeze(0)) 1045 | loss_conf_loss = loss_conf_loss 1046 | # print("loss_conf_loss",loss_conf_loss) 1047 | # print("loss_conf_mean",loss_conf_mean) 1048 | loss_conf = loss_conf_loss+loss_conf_mean*0.01 1049 | loss = loss + loss_conf*0.1 1050 | # loss = loss + loss_acc 1051 | 1052 | jt.sync_all() 1053 | optimizer.backward(loss / accumulation_steps) 1054 | if i % accumulation_steps == 0: 1055 | optimizer.step() 1056 | jt.sync_all() 1057 | # optimizer.step(loss) 1058 | 1059 | # if global_step==10000: 1060 | # if global_step==0: 1061 | # render_kwargs_train['network_fn'].force_out = 15 1062 | # render_kwargs_train['network_fine'].force_out = 15 1063 | 1064 | # NOTE: IMPORTANT! 1065 | ### update learning rate ### 1066 | decay_rate = 0.1 1067 | decay_steps = args.lrate_decay * accumulation_steps * 1000 1068 | sstep = global_step 1069 | if sstep>split_tree3: 1070 | sstep-=split_tree3 1071 | elif sstep>split_tree2: 1072 | sstep-=split_tree2 1073 | elif sstep>split_tree1: 1074 | sstep-=split_tree1 1075 | new_lrate = args.lrate * (decay_rate ** (sstep / decay_steps)) 1076 | for param_group in optimizer.param_groups: 1077 | param_group['lr'] = new_lrate 1078 | ################################ 1079 | 1080 | dt = time.time()-time0 1081 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") 1082 | ##### end ##### 1083 | 1084 | # Rest is logging 1085 | if (i+1)%args.i_weights==0: 1086 | if (not jt.mpi or jt.mpi.local_rank()==0): 1087 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 1088 | else: 1089 | path = os.path.join(basedir, expname, 'tmp.tar') 1090 | jt.save({ 1091 | 'global_step': global_step, 1092 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 1093 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 1094 | # 'optimizer_state_dict': optimizer.state_dict(), 1095 | }, path) 1096 | print('Saved checkpoints at', path) 1097 | 1098 | # jt.display_memory_info() 1099 | if i%args.i_video==0 and i > 0 or i==split_tree1 or i==split_tree2 or i==split_tree3: 1100 | # import ipdb 1101 | # ipdb.set_trace() 1102 | # Turn on testing mode 1103 | if not jt.mpi or jt.mpi.local_rank()==0: 1104 | with jt.no_grad(): 1105 | rgbs, disps, rgbs_log, points = render_path(render_poses, hwf, args.chunk, render_kwargs_test, intrinsic = intrinsic, get_points = True, large_scene = args.large_scene) 1106 | else: 1107 | points = jt.random([1000,3]) 1108 | if i==split_tree1 or i==split_tree2 or i==split_tree3: 1109 | do_kmeans(points, render_kwargs_train['network_fn'], render_kwargs_train['network_fine']) 1110 | # jt.display_memory_info() 1111 | if not jt.mpi or jt.mpi.local_rank()==0: 1112 | print('Done, saving', rgbs.shape, disps.shape) 1113 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 1114 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) 1115 | imageio.mimwrite(moviebase + 'rgb_log.mp4', to8b(rgbs_log), fps=30, 
quality=8) 1116 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) 1117 | jt.gc() 1118 | 1119 | # if args.use_viewdirs: 1120 | # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] 1121 | # with jt.no_grad(): 1122 | # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) 1123 | # render_kwargs_test['c2w_staticcam'] = None 1124 | # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) 1125 | if i%args.i_testset==0 and i > 0 and (not jt.mpi or jt.mpi.local_rank()==0): 1126 | si_test = i_test_tot if i%args.i_tottest==0 else i_test 1127 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 1128 | os.makedirs(testsavedir, exist_ok=True) 1129 | print('test poses shape', poses[si_test].shape) 1130 | with jt.no_grad(): 1131 | rgbs, disps, rgbs_log = render_path(jt.array(poses[si_test]), hwf, args.chunk, render_kwargs_test, gt_imgs=images[si_test], savedir=testsavedir, intrinsic = intrinsic, log_path = os.path.join(basedir, expname, 'outnum_{:06d}.txt'.format(i))) 1132 | tars = images[si_test] 1133 | testpsnr = mse2psnr(img2mse(jt.array(rgbs), tars)).item() 1134 | if not jt.mpi or jt.mpi.local_rank()==0: 1135 | writer.add_scalar('test/psnr_tot', testpsnr, global_step) 1136 | print('Saved test set') 1137 | 1138 | 1139 | 1140 | 1141 | if i%args.i_print==0: 1142 | tqdm.write(f"[TRAIN] Iter: {i} expname: {args.expname} Loss: {loss.item()} LossConf: {loss_conf.item()} PSNR: {psnr.item()}") 1143 | a=psnr0.item() 1144 | # print("before:",jt.mpi.local_rank()) 1145 | if not jt.mpi or jt.mpi.local_rank()==0: 1146 | writer.add_scalar("train/loss", loss.item(), global_step) 1147 | # writer.add_scalar("train/loss_acc", loss_acc.item(), global_step) 1148 | writer.add_scalar("train/loss_conf", loss_conf.item(), global_step) 1149 | writer.add_scalar("train/PSNR", psnr.item(), global_step) 1150 | writer.add_scalar("lr/lr", new_lrate, global_step) 1151 | # writer.add_histogram('tran', trans.numpy(), global_step) 1152 | if args.N_importance > 0: 1153 | writer.add_scalar("train/PSNR0", psnr0.item(), global_step) 1154 | # print("done:",jt.mpi.local_rank()) 1155 | # print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) 1156 | # print('iter time {:.05f}'.format(dt)) 1157 | 1158 | # with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): 1159 | # tf.contrib.summary.scalar('loss', loss) 1160 | # tf.contrib.summary.scalar('psnr', psnr) 1161 | # tf.contrib.summary.histogram('tran', trans) 1162 | # if args.N_importance > 0: 1163 | # tf.contrib.summary.scalar('psnr0', psnr0) 1164 | 1165 | 1166 | if i%args.i_img==0: 1167 | 1168 | # Log a rendered validation view to Tensorboard 1169 | img_i=np.random.choice(i_val) 1170 | target = images[img_i] 1171 | pose = poses[img_i, :3,:4] 1172 | with jt.no_grad(): 1173 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, intrinsic=intrinsic, 1174 | **render_kwargs_test) 1175 | sout_num = list(extras['outnum0']) 1176 | log = "i_img out_num0:" 1177 | for k in range(len(sout_num)): 1178 | log += str(k)+": %d; " % int(sout_num[k]) 1179 | print(log) 1180 | sout_num = list(extras['outnum']) 1181 | log = "i_img out_num:" 1182 | for k in range(len(sout_num)): 1183 | log += str(k)+": %d; " % int(sout_num[k]) 1184 | print(log) 1185 | psnr = mse2psnr(img2mse(rgb[-1], target)) 1186 | # psnr = mse2psnr(img2mse(rgb[-2], target)) 1187 | rgb_log = rgb[-2].numpy() 1188 | rgb = rgb[-1].numpy() 1189 | rgb0 = 
extras['rgb0'][-1].numpy() 1190 | # rgb = extras['rgb0'][-1].numpy() 1191 | disp = disp[-1].numpy() 1192 | acc = acc[-1].numpy() 1193 | acc0 = extras['acc0'][-1].numpy() 1194 | a = target.numpy() 1195 | a = psnr.item() 1196 | 1197 | if not jt.mpi or jt.mpi.local_rank()==0: 1198 | writer.add_image('test/rgb', to8b(rgb), global_step, dataformats="HWC") 1199 | writer.add_image('test/rgb0', to8b(rgb0), global_step, dataformats="HWC") 1200 | writer.add_image('log/rgb', to8b(rgb_log), global_step, dataformats="HWC") 1201 | # writer.add_image('test/disp', disp[...,np.newaxis], global_step, dataformats="HWC") 1202 | # writer.add_image('test/acc', acc[...,np.newaxis], global_step, dataformats="HWC") 1203 | # writer.add_image('test/acc0', acc0[...,np.newaxis], global_step, dataformats="HWC") 1204 | writer.add_image('test/target', target.numpy(), global_step, dataformats="HWC") 1205 | 1206 | writer.add_scalar('test/psnr', psnr.item(), global_step) 1207 | jt.gc() 1208 | # jt.display_memory_info() 1209 | # jt.display_max_memory_info() 1210 | # a=images[0].numpy() 1211 | # b=images[1].numpy() 1212 | # c=images[2].numpy() 1213 | # print("images0",a.shape,a.sum()) 1214 | # print("images1",b.shape,b.sum()) 1215 | # print("images2",c.shape,c.sum()) 1216 | # writer.add_image('test/rgb_target0', a, global_step, dataformats="HWC") 1217 | # writer.add_image('test/rgb_target1', b, global_step, dataformats="HWC") 1218 | # writer.add_image('test/rgb_target2', c, global_step, dataformats="HWC") 1219 | 1220 | 1221 | # if args.N_importance > 0: 1222 | 1223 | # with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): 1224 | # tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis]) 1225 | # tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis]) 1226 | # tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis]) 1227 | 1228 | 1229 | global_step += 1 1230 | 1231 | 1232 | if __name__=='__main__': 1233 | train() 1234 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jittor as jt 3 | from jittor import nn 4 | import numpy as np 5 | 6 | 7 | # Misc 8 | img2mse = lambda x, y : jt.mean((x - y) ** 2) 9 | mse2psnr = lambda x : -10. 
* jt.log(x) / jt.log(jt.array(np.array([10.]))) 10 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 11 | 12 | 13 | # Positional encoding (section 5.1) 14 | class Embedder: 15 | def __init__(self, **kwargs): 16 | self.kwargs = kwargs 17 | self.create_embedding_fn() 18 | 19 | def create_embedding_fn(self): 20 | embed_fns = [] 21 | d = self.kwargs['input_dims'] 22 | out_dim = 0 23 | if self.kwargs['include_input']: 24 | embed_fns.append(lambda x : x) 25 | out_dim += d 26 | 27 | max_freq = self.kwargs['max_freq_log2'] 28 | N_freqs = self.kwargs['num_freqs'] 29 | 30 | if self.kwargs['log_sampling']: 31 | freq_bands = 2.**jt.linspace(0., max_freq, steps=N_freqs) 32 | else: 33 | freq_bands = jt.linspace(2.**0., 2.**max_freq, steps=N_freqs) 34 | 35 | for freq in freq_bands: 36 | for p_fn in self.kwargs['periodic_fns']: 37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 38 | out_dim += d 39 | 40 | self.embed_fns = embed_fns 41 | self.out_dim = out_dim 42 | 43 | def embed(self, inputs): 44 | return jt.concat([fn(inputs) for fn in self.embed_fns], -1) 45 | 46 | 47 | def get_embedder(multires, i=0): 48 | if i == -1: 49 | return nn.Identity(), 3 50 | 51 | embed_kwargs = { 52 | 'include_input' : True, 53 | 'input_dims' : 3, 54 | 'max_freq_log2' : multires-1, 55 | 'num_freqs' : multires, 56 | 'log_sampling' : True, 57 | 'periodic_fns' : [jt.sin, jt.cos], 58 | } 59 | 60 | embedder_obj = Embedder(**embed_kwargs) 61 | embed = lambda x, eo=embedder_obj : eo.embed(x) 62 | return embed, embedder_obj.out_dim 63 | 64 | #tree class 65 | class Node(): 66 | def __init__(self, anchors, sons, linears): 67 | self.anchors = anchors 68 | self.sons = sons 69 | self.linears = linears 70 | 71 | # Model 72 | class OutputNet(nn.Module): 73 | def __init__(self, W, input_ch_views): 74 | """ 75 | """ 76 | super(OutputNet, self).__init__() 77 | 78 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 79 | self.feature_linear = nn.Linear(W, W) 80 | self.alpha_linear = nn.Linear(W, 1) 81 | self.rgb_linear = nn.Linear(W//2, 3) 82 | 83 | def execute(self, h, input_views): 84 | alpha = self.alpha_linear(h) 85 | feature = self.feature_linear(h) 86 | h = jt.concat([feature, input_views], -1) 87 | 88 | for i, l in enumerate(self.views_linears): 89 | h = self.views_linears[i](h) 90 | h = jt.nn.relu(h) 91 | 92 | rgb = self.rgb_linear(h) 93 | outputs = jt.concat([rgb, alpha], -1) 94 | return outputs 95 | 96 | # Model 97 | class NeRF(nn.Module): 98 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False, head_num=8, threshold=3e-2): 99 | """ 100 | """ 101 | super(NeRF, self).__init__() 102 | D=12 103 | self.D = D 104 | self.W = W 105 | self.input_ch = input_ch 106 | self.input_ch_views = input_ch_views 107 | skips=[] 108 | # skips = [2,4,6] 109 | self.skips = skips 110 | # self.ress = [1,3,7,11] 111 | # self.ress = [] 112 | # self.outs = [1,3,7,11] 113 | self.force_out = [0] 114 | # self.force_out = [7,8,9,10,11,12,13,14] 115 | self.use_viewdirs = use_viewdirs 116 | assert self.use_viewdirs==True 117 | 118 | self.threshold = threshold 119 | 120 | self.build_tree(head_num) 121 | self.pts_linears = nn.ModuleList( 122 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skip_linear else nn.Linear(W + input_ch, W) for i in range(self.linear_num-1)]) 123 | # for i in range(self.nlinear_list[0]+1,len(self.pts_linears)): 124 | # jt.init.constant_(self.pts_linears[i].weight, 0.0) 125 | # jt.init.constant_(self.pts_linears[i].bias, 0.0) 126 
| # self.confidence_linears = nn.ModuleList([nn.Linear(W+ input_ch, 1) for i in range(D)]) 127 | self.confidence_linears = nn.ModuleList([nn.Linear(W, 1) for i in range(self.node_num)]) 128 | # self.outnet = OutputNet(W, input_ch_views) 129 | self.outnet = nn.ModuleList([OutputNet(W, input_ch_views) for i in range(self.node_num)]) 130 | 131 | def get_anchor(self, i): 132 | return getattr(self, i) 133 | 134 | def build_tree(self, head_num): 135 | # self.son_list = [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16],[17,18,19,20],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]] 136 | # self.nlinear_list = [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2] 137 | 138 | # self.son_list = [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[],[],[],[],[],[],[],[]] 139 | # self.nlinear_list = [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2] 140 | 141 | # self.son_list = [[1,2],[3,4],[5,6],[],[],[],[]] 142 | # self.nlinear_list = [2,2,2,4,4,4,4] 143 | # self.skip_linear = [6,10,14,18] 144 | 145 | if head_num == 1: 146 | # 1 head 147 | self.son_list = [[1],[2],[3],[]] 148 | self.nlinear_list = [2,2,4,4] 149 | self.skip_linear = [4] 150 | elif head_num == 4: 151 | # 4 head 152 | self.son_list = [[1,2],[3,4],[5,6],[7],[8],[9],[10],[],[],[],[]] 153 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4] 154 | self.skip_linear = [6,10,14,18] 155 | elif head_num == 8: 156 | # 8 head 157 | self.son_list = [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[],[],[],[],[],[],[],[]] 158 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4,4,4,4,4] 159 | self.skip_linear = [6,10,14,18] 160 | elif head_num == 16: 161 | # 16 head 162 | self.son_list = [[1,2,3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[15,16],[17,18],[19,20],[21,22],[23,24],[25,26],[27,28],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]] 163 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4] 164 | self.skip_linear = [10,14,18,22,26,30,34,38] 165 | 166 | # self.anchor_list = [np.array([[-2,0,0],[2,0,0]]).astype(np.float32), 167 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32), 168 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32), 169 | # np.array([[-1,0,0],[0,0,0]]).astype(np.float32), 170 | # np.array([[-1,0,0],[0,0,0]]).astype(np.float32), 171 | # np.array([[0,0,0],[1,0,0]]).astype(np.float32), 172 | # np.array([[0,0,0],[1,0,0]]).astype(np.float32)] 173 | # self.anchor_list = [np.array([[-2,0,0],[2,0,0]]).astype(np.float32), 174 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32), 175 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32), 176 | # np.array([[0,0,0]]).astype(np.float32), 177 | # np.array([[0,0,0]]).astype(np.float32), 178 | # np.array([[0,0,0]]).astype(np.float32), 179 | # np.array([[0,0,0]]).astype(np.float32)] 180 | assert len(self.son_list) == len(self.nlinear_list) 181 | self.anchor_list = [np.array([[0,0,0]]).astype(np.float32)]*len(self.son_list) 182 | self.node_list = [] 183 | self.node_num = len(self.son_list) 184 | self.anchor_num = 0 185 | self.linear_num = 0 186 | for i in range(len(self.son_list)): 187 | son = self.son_list[i] 188 | if len(son)>0: 189 | anchor = "anchor"+str(self.anchor_num) 190 | self.anchor_num += 1 191 | setattr(self, anchor, jt.array(self.anchor_list[i])) 192 | # setattr(self, anchor, jt.random([len(son), 3])) 193 | else: 194 | anchor = None 195 | linear = list(range(self.linear_num, self.linear_num+self.nlinear_list[i])) 196 | self.linear_num += self.nlinear_list[i] 197 | self.node_list.append(Node(anchor, son, linear)) 198 | 199 | def my_concat(self, a, b, dim): 200 | if a is None: 201 | return b 202 | elif b is 
None: 203 | return a 204 | else: 205 | return jt.concat([a,b],dim) 206 | 207 | def search(self, t, p, h, input_pts, input_views, remain_mask): 208 | node = self.node_list[t] 209 | # print("search t",t,"remain_mask",remain_mask.sum()) 210 | identity = h 211 | for i in range(len(node.linears)): 212 | # print("i",i) 213 | # print("h",h.shape) 214 | # print(self.pts_linears[node.linears[i]]) 215 | # print("len",len(self.pts_linears),"node.linears[i]",node.linears[i],"node",node) 216 | # print("t",t,"i",i,"h",h.shape,"line",self.pts_linears[node.linears[i]].weight.shape) 217 | h = self.pts_linears[node.linears[i]](h) 218 | if t==0 and i==0: 219 | identity = h 220 | if i==len(node.linears)-1: 221 | h = h+identity 222 | h = jt.nn.relu(h) 223 | if node.linears[i] in self.skip_linear: 224 | h = jt.concat([input_pts, h], -1) 225 | 226 | confidence = self.confidence_linears[t](h).view(-1) 227 | # threshold = 0.0 228 | threshold = self.threshold 229 | # threshold = -1e10 230 | output = self.outnet[t](h, input_views) 231 | # output = self.outnet[0](h, input_views) 232 | out_num = np.zeros((self.node_num)) 233 | 234 | if len(node.sons)>0 and (not t in self.force_out): 235 | son_outputs = None 236 | son_outputs_fuse = None 237 | son_confs = None 238 | son_confs_fuse = None 239 | idxs = None 240 | idxs_fuse = None 241 | 242 | anchor = self.get_anchor(node.anchors) 243 | dis = (anchor.unsqueeze(0)-p.unsqueeze(1)).sqr().sum(-1).sqrt() 244 | min_idx, _ = jt.argmin(dis,-1) 245 | for i in range(len(node.sons)): 246 | # print("t",t,"i",i) 247 | next_t = node.sons[i] 248 | sidx = jt.arange(0,p.shape[0]) 249 | # print("min_idx==i",min_idx==i) 250 | sidx = sidx[min_idx==i] 251 | # print("sidx",sidx) 252 | next_p = p[sidx] 253 | next_h = h[sidx] 254 | next_input_pts = input_pts[sidx] 255 | next_input_views = input_views[sidx] 256 | next_remain_mask = remain_mask[sidx].copy() 257 | next_conf = confidence[sidx] 258 | next_remain_mask[threshold>next_conf] = 0 259 | sidx_fuse = sidx[next_remain_mask==1] 260 | 261 | # print("start t",t,"i",i,"next_t",next_t) 262 | next_outputs, next_outputs_fuse, next_confs, next_confs_fuse, next_out_num = self.search(next_t, next_p, next_h, next_input_pts, next_input_views, next_remain_mask) 263 | out_num = out_num+next_out_num 264 | # print("search", t, next_t) 265 | # print("next_outputs",next_outputs.shape) 266 | # print("next_outputs_fuse",next_outputs_fuse.shape) 267 | # print("next_confs",next_confs.shape) 268 | # print("next_confs_fuse",next_confs_fuse.shape) 269 | 270 | # print("end t",t,"i",i,"next_t",next_t) 271 | son_outputs = self.my_concat(son_outputs, next_outputs, 1) 272 | son_outputs_fuse = self.my_concat(son_outputs_fuse, next_outputs_fuse, 1) 273 | son_confs = self.my_concat(son_confs, next_confs, 1) 274 | son_confs_fuse = self.my_concat(son_confs_fuse, next_confs_fuse, 1) 275 | idxs = self.my_concat(idxs, sidx, 0) 276 | idxs_fuse = self.my_concat(idxs_fuse, sidx_fuse, 0) 277 | 278 | # print("t",t) 279 | son_outputs_save = jt.zeros(son_outputs.shape) 280 | son_outputs_save[:,idxs] = son_outputs 281 | son_outputs_save = jt.concat([output.unsqueeze(0), son_outputs_save], 0) 282 | son_confs_save = jt.zeros(son_confs.shape) 283 | son_confs_save[:,idxs] = son_confs 284 | son_confs_save = jt.concat([confidence.unsqueeze(1).unsqueeze(0), son_confs_save], 0) 285 | 286 | out_remain_mask = remain_mask.copy() 287 | out_remain_mask[threshold<=confidence] = 0 288 | idx_out = jt.arange(0,out_remain_mask.shape[0])[out_remain_mask==1] 289 | outputs_out = output[idx_out].unsqueeze(0) 
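# Early-exit bookkeeping: points whose predicted confidence at this node is below
# self.threshold leave the tree here (idx_out) and keep this node's prediction, while
# more uncertain points are taken from the child heads (son_outputs_fuse), each child
# having processed the coordinates nearest to its anchor; out_num[t] below counts how
# many points exit at node t.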
290 | out_num[t] = outputs_out.shape[1] 291 | confs_out = confidence[idx_out].unsqueeze(1).unsqueeze(0) 292 | outputs_out = jt.concat([outputs_out, son_outputs_fuse], 1) 293 | confs_out = jt.concat([confs_out, son_confs_fuse], 1) 294 | idx_out = jt.concat([idx_out, idxs_fuse], 0) 295 | 296 | outputs_out_save = jt.zeros(output.unsqueeze(0).shape) 297 | outputs_out_save[:, idx_out] = outputs_out 298 | outputs_out_save = outputs_out_save[:, remain_mask==1] 299 | confs_out_save = jt.zeros(confidence.unsqueeze(1).unsqueeze(0).shape) 300 | confs_out_save[:, idx_out] = confs_out 301 | confs_out_save = confs_out_save[:, remain_mask==1] 302 | 303 | return son_outputs_save, outputs_out_save, son_confs_save, confs_out_save, out_num 304 | else: 305 | outputs_save = output.unsqueeze(0) 306 | outputs_save_log = outputs_save.copy() 307 | confs_save = confidence.unsqueeze(1).unsqueeze(0) 308 | # print("outputs_save",outputs_save.shape) 309 | # print("remain_mask",remain_mask.shape) 310 | # print("remain_mask==1",remain_mask==1) 311 | outputs_out_save = outputs_save[:, remain_mask==1] 312 | confs_out_save = confs_save[:, remain_mask==1] 313 | out_num[t] = outputs_out_save.shape[1] 314 | 315 | remain_mask[threshold<=confidence] = 0 316 | # print("out:", remain_mask.sum().numpy(), "remain:", remain_mask.shape[0]-remain_mask.sum().numpy()) 317 | 318 | # print("outputs_out_save",outputs_out_save.shape) 319 | if not self.training: 320 | if t%4==0: 321 | outputs_save_log[..., 0] *= 0. 322 | outputs_save_log[..., 1] *= 0. 323 | elif t%4==1: 324 | outputs_save_log[..., 0] *= 0. 325 | outputs_save_log[..., 2] *= 0. 326 | elif t%4==2: 327 | outputs_save_log[..., 1] *= 0. 328 | outputs_save_log[..., 2] *= 0. 329 | elif t%4==3: 330 | outputs_save_log[..., 0] *= 0. 331 | # elif t==1: 332 | # outputs_out_save[..., 1] *= 0. 333 | # elif t==2: 334 | # outputs_out_save[..., 2] *= 0. 
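# Evaluation-only visualisation: the leaf output is duplicated with its colour channels
# zeroed according to the node index (t mod 4), and the copy is appended as an extra
# slot below so the logged rgb_log render colour-codes which head produced each sample;
# the confidences are duplicated alongside to keep shapes aligned.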
335 | outputs_save = jt.concat([outputs_save, outputs_save_log], 0) 336 | confs_save = jt.concat([confs_save, confs_save], 0) 337 | # print("outputs_out_save out",outputs_out_save.shape) 338 | 339 | # print("outputs_out_save",outputs_out_save.shape) 340 | 341 | return outputs_save, outputs_out_save, confs_save, confs_out_save, out_num 342 | 343 | def do_train(self, x, p): 344 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1) 345 | remain_mask = jt.ones(input_pts.shape[0]) 346 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask) 347 | 348 | outputs = jt.concat([outputs, outputs_fuse], 0) 349 | confs = jt.concat([confs, confs_fuse], 0) 350 | 351 | return outputs, confs, np.zeros([1]) 352 | 353 | def do_eval(self, x, p): 354 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1) 355 | remain_mask = jt.ones(input_pts.shape[0]) 356 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask) 357 | 358 | log = "out: " 359 | sout_num = list(out_num) 360 | for i in range(len(sout_num)): 361 | log += str(i)+": %d; " % sout_num[i] 362 | # print(log) 363 | sout_num = np.array(sout_num) 364 | outputs = outputs[-1:] 365 | 366 | outputs = jt.concat([outputs, outputs_fuse], 0) 367 | confs = jt.concat([confs, confs_fuse], 0) 368 | 369 | return outputs, confs, sout_num 370 | 371 | def execute(self, x, p, training): 372 | self.training = training 373 | if training: 374 | return self.do_train(x, p) 375 | else: 376 | # return self.do_train(x, p) 377 | return self.do_eval(x, p) 378 | 379 | def load_weights_from_keras(self, weights): 380 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 381 | 382 | # Load pts_linears 383 | for i in range(self.D): 384 | idx_pts_linears = 2 * i 385 | self.pts_linears[i].weight.data = jt.array(np.transpose(weights[idx_pts_linears])) 386 | self.pts_linears[i].bias.data = jt.array(np.transpose(weights[idx_pts_linears+1])) 387 | 388 | # Load feature_linear 389 | idx_feature_linear = 2 * self.D 390 | self.feature_linear.weight.data = jt.array(np.transpose(weights[idx_feature_linear])) 391 | self.feature_linear.bias.data = jt.array(np.transpose(weights[idx_feature_linear+1])) 392 | 393 | # Load views_linears 394 | idx_views_linears = 2 * self.D + 2 395 | self.views_linears[0].weight.data = jt.array(np.transpose(weights[idx_views_linears])) 396 | self.views_linears[0].bias.data = jt.array(np.transpose(weights[idx_views_linears+1])) 397 | 398 | # Load rgb_linear 399 | idx_rbg_linear = 2 * self.D + 4 400 | self.rgb_linear.weight.data = jt.array(np.transpose(weights[idx_rbg_linear])) 401 | self.rgb_linear.bias.data = jt.array(np.transpose(weights[idx_rbg_linear+1])) 402 | 403 | # Load alpha_linear 404 | idx_alpha_linear = 2 * self.D + 6 405 | self.alpha_linear.weight.data = jt.array(np.transpose(weights[idx_alpha_linear])) 406 | self.alpha_linear.bias.data = jt.array(np.transpose(weights[idx_alpha_linear+1])) 407 | 408 | 409 | 410 | # Ray helpers 411 | def get_rays(H, W, focal, c2w, intrinsic = None): 412 | i, j = jt.meshgrid(jt.linspace(0, W-1, W), jt.linspace(0, H-1, H)) 413 | i = i.t() 414 | j = j.t() 415 | if intrinsic is None: 416 | dirs = jt.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -jt.ones_like(i)], -1).unsqueeze(-2) 417 | else: 418 | i+=0.5 419 | j+=0.5 420 | dirs = jt.stack([i, j, jt.ones_like(i)], -1).unsqueeze(-2) 421 | dirs = jt.sum(dirs * intrinsic[:3,:3], 
-1).unsqueeze(-2) 422 | # Rotate ray directions from camera frame to the world frame 423 | rays_d = jt.sum(dirs * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 424 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 425 | rays_o = c2w[:3,-1].expand(rays_d.shape) 426 | return rays_o, rays_d 427 | 428 | 429 | def get_rays_np(H, W, focal, c2w): 430 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 431 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1) 432 | # Rotate ray directions from camera frame to the world frame 433 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 434 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 435 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 436 | return rays_o, rays_d 437 | 438 | 439 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 440 | # Shift ray origins to near plane 441 | t = -(near + rays_o[...,2]) / rays_d[...,2] 442 | rays_o = rays_o + t.unsqueeze(-1) * rays_d 443 | 444 | # Projection 445 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 446 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 447 | o2 = 1. + 2. * near / rays_o[...,2] 448 | 449 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 450 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 451 | d2 = -2. * near / rays_o[...,2] 452 | 453 | rays_o = jt.stack([o0,o1,o2], -1) 454 | rays_d = jt.stack([d0,d1,d2], -1) 455 | 456 | return rays_o, rays_d 457 | 458 | 459 | # Hierarchical sampling (section 5.2) 460 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 461 | # Get pdf 462 | weights = weights + 1e-5 # prevent nans 463 | pdf = weights / jt.sum(weights, -1, keepdims=True) 464 | cdf = jt.cumsum(pdf, -1) 465 | cdf = jt.concat([jt.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 466 | 467 | # Take uniform samples 468 | if det: 469 | u = jt.linspace(0., 1., steps=N_samples) 470 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 471 | else: 472 | u = jt.random(list(cdf.shape[:-1]) + [N_samples]) 473 | 474 | # Pytest, overwrite u with numpy's fixed random numbers 475 | if pytest: 476 | # np.random.seed(0) 477 | new_shape = list(cdf.shape[:-1]) + [N_samples] 478 | if det: 479 | u = np.linspace(0., 1., N_samples) 480 | u = np.broadcast_to(u, new_shape) 481 | else: 482 | u = np.random.rand(*new_shape) 483 | u = jt.array(u) 484 | 485 | # Invert CDF 486 | # u = u.contiguous() 487 | # inds = searchsorted(cdf, u, side='right') 488 | inds = jt.searchsorted(cdf, u, right=True) 489 | below = jt.maximum(jt.zeros_like(inds-1), inds-1) 490 | above = jt.minimum((cdf.shape[-1]-1) * jt.ones_like(inds), inds) 491 | inds_g = jt.stack([below, above], -1) # (batch, N_samples, 2) 492 | 493 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 494 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 495 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 496 | cdf_g = jt.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 497 | bins_g = jt.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 498 | 499 | denom = (cdf_g[...,1]-cdf_g[...,0]) 500 | cond = jt.where(denom<1e-5) 501 | denom[cond] = 1. 
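# denom is the CDF mass of the bin each u fell into; bins with (near-)zero mass have
# denom forced to 1 so the division below is safe, and t then places each sample by
# linear interpolation inside its selected bin.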
502 | t = (u-cdf_g[...,0])/denom 503 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 504 | 505 | return samples 506 | --------------------------------------------------------------------------------
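# Minimal usage sketch (illustrative only, not part of the original sources): how
# get_embedder and sample_pdf above fit together, assuming this file is importable
# as run_nerf_helpers and that Jittor is installed.
import jittor as jt
from run_nerf_helpers import get_embedder, sample_pdf

# Positional encoding: multires=10 gives the identity plus 10 sin/cos frequency
# bands, i.e. 3 + 3*2*10 = 63 channels for 3-D points.
embed_fn, input_ch = get_embedder(10)
pts = jt.random([4096, 3])
embedded = embed_fn(pts)
assert input_ch == 63

# Hierarchical sampling: draw extra depths per ray from the piecewise-constant PDF
# defined by the (coarse) weights over the midpoints of the coarse depth samples.
z_vals = jt.linspace(2.0, 6.0, 64).expand([1024, 64])
weights = jt.random([1024, 64])
z_mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
z_samples = sample_pdf(z_mids, weights[..., 1:-1], N_samples=128)
print(embedded.shape, z_samples.shape)  # -> (4096, 63) and (1024, 128)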