├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── configs
│   ├── bike.txt
│   ├── bike_full.txt
│   ├── bike_tot.txt
│   ├── chair.txt
│   ├── chess.txt
│   ├── chess_full.txt
│   ├── chess_simple.txt
│   ├── cornell.txt
│   ├── drums.txt
│   ├── fern.txt
│   ├── ficus.txt
│   ├── flower.txt
│   ├── fortress.txt
│   ├── horns.txt
│   ├── hotdog.txt
│   ├── leaves.txt
│   ├── lego.txt
│   ├── materials.txt
│   ├── mic.txt
│   ├── orchids.txt
│   ├── room.txt
│   ├── scannet.txt
│   ├── ship.txt
│   └── trex.txt
├── download_example_data.sh
├── image_mem.py
├── image_mem_helpers.py
├── imgs
│   └── pipeline.jpg
├── load_blender.py
├── load_deepvoxels.py
├── load_llff.py
├── requirements.txt
├── run_nerf.py
└── run_nerf_helpers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gword/Recursive-NeRF/e61ceb302d0bc028fe4524771b363bf514342b5f/.gitmodules
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bmild
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recursive-NeRF: An Efficient and Dynamically Growing NeRF
2 | This is the official implementation of Recursive-NeRF: An Efficient and Dynamically Growing NeRF.
3 |
4 | Paper link: https://ieeexplore.ieee.org/document/9909994
5 |
6 | ## Abstract
7 | View synthesis methods using implicit continuous shape representations learned from a set of images, such as the Neural Radiance Field (NeRF) method, have gained increasing attention due to their high quality imagery and scalability to high resolution.
8 | However, the heavy computation required by its volumetric approach prevents NeRF from being useful in practice; minutes are taken to render a single image of a few megapixels.
9 | Since an image of a scene can be rendered in a level-of-detail manner, we posit that a complicated region of the scene should be represented by a large neural network, while a small neural network suffices to encode a simple region, enabling a balance between efficiency and quality.
10 | Recursive-NeRF is our embodiment of this idea, providing an efficient and adaptive rendering and training approach for NeRF.
11 | The core of Recursive-NeRF learns uncertainties for query coordinates, representing the quality of the predicted color and volumetric intensity at each level.
12 | Only query coordinates with high uncertainty are forwarded to the next level, where a larger neural network with greater representational capacity handles them.
13 | The final rendered image is a composition of results from neural networks of all levels.
14 | Our evaluation on public datasets and a large-scale scene dataset we collected shows that Recursive-NeRF is more efficient than NeRF while providing state-of-the-art quality.
--------------------------------------------------------------------------------
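
The recursion described in the abstract can be summarised in a few lines of framework-free Python. This is only an illustrative sketch of the idea, not the repository's API: `predict` and `uncertainty` below are hypothetical stand-ins for a level's colour and confidence heads (in the code they correspond to the per-node `outnet` and `confidence_linears` used inside `NeRF.search` in image_mem_helpers.py).

```python
# Illustrative sketch of uncertainty-gated recursion (hypothetical helper, not
# part of this repository). Each level predicts a colour and an uncertainty for
# every query point; only points whose uncertainty is still above the threshold
# are forwarded to the next, larger sub-network, and the final image is
# composed from the level at which each point stopped.
import numpy as np

def recursive_query(levels, x, threshold):
    """levels: list of (predict, uncertainty) callables, ordered small to large."""
    out = np.zeros((x.shape[0], 3))
    remaining = np.arange(x.shape[0])            # queries not yet resolved
    for i, (predict, uncertainty) in enumerate(levels):
        rgb = predict(x[remaining])              # (len(remaining), 3) colours
        conf = uncertainty(x[remaining])         # per-point uncertainty
        done = conf < threshold                  # confident enough: stop here
        if i == len(levels) - 1:
            done[:] = True                       # last level keeps every remaining point
        out[remaining[done]] = rgb[done]
        remaining = remaining[~done]
        if remaining.size == 0:
            break
    return out

# toy usage with constant "networks": the first level is always uncertain,
# so every query ends up coloured by the second (larger) level
levels = [(lambda q: np.full((len(q), 3), 0.5), lambda q: np.full(len(q), 1.0)),
          (lambda q: np.full((len(q), 3), 0.8), lambda q: np.full(len(q), 0.0))]
print(recursive_query(levels, np.random.rand(5, 2), threshold=0.1))
```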
/configs/bike.txt:
--------------------------------------------------------------------------------
1 | expname = bike
2 | datadir = ./data/our_seq_bike
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 1
--------------------------------------------------------------------------------
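
These config files are plain `key = value` text consumed through configargparse (see `config_parser` in image_mem.py). The snippet below is a trimmed, hypothetical illustration of that mechanism, not the project's actual entry point; it declares only a handful of the options defined in the real parser.

```python
# Minimal sketch of how a config such as configs/bike.txt is parsed
# (assumes the file exists locally; only a few of the real options are shown).
import configargparse

parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument('--expname', type=str)
parser.add_argument('--datadir', type=str)
parser.add_argument('--dataset_type', type=str, default='llff')
parser.add_argument('--N_samples', type=int, default=64)
parser.add_argument('--N_importance', type=int, default=0)
parser.add_argument('--N_rand', type=int, default=32*32*4)
parser.add_argument('--use_viewdirs', action='store_true')

# parse_known_args tolerates the keys this trimmed parser does not declare
args, _ = parser.parse_known_args(['--config', 'configs/bike.txt'])
print(args.expname, args.dataset_type, args.N_rand)   # -> bike blender 512
```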
/configs/bike_full.txt:
--------------------------------------------------------------------------------
1 | expname = bike_full
2 | datadir = ./data/our_seq_bike_full
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 1
19 | faketestskip = 8
20 | i_tottest = 400000
21 | do_intrinsic = True
22 | render_test = True
23 | near = 0.3
24 | far = 6.0
25 |
--------------------------------------------------------------------------------
/configs/bike_tot.txt:
--------------------------------------------------------------------------------
1 | expname = bike_tot
2 | datadir = ./data/our_seq_bike_tot
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | half_res = False
16 | raw_noise_std = 1e0
17 | blender_factor = 8
18 | testskip = 8
--------------------------------------------------------------------------------
/configs/chair.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_chair
2 | datadir = ./data/nerf_synthetic/chair
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/chess.txt:
--------------------------------------------------------------------------------
1 | expname = chess
2 | datadir = ./data/chess
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 |
--------------------------------------------------------------------------------
/configs/chess_full.txt:
--------------------------------------------------------------------------------
1 | expname = chess
2 | datadir = ./data/chess_full
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 | raw_noise_std = 1e0
--------------------------------------------------------------------------------
/configs/chess_simple.txt:
--------------------------------------------------------------------------------
1 | expname = chess
2 | datadir = ./data/chess_simple
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 |
--------------------------------------------------------------------------------
/configs/cornell.txt:
--------------------------------------------------------------------------------
1 | expname = cornell
2 | datadir = ./data/our_cornell
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 | blender_factor = 1
20 | testskip = 1
21 | faketestskip = 8
22 | i_tottest = 400000
23 | do_intrinsic = True
24 | render_test = True
25 | far = 20.0
--------------------------------------------------------------------------------
/configs/drums.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_drums
2 | datadir = ./data/nerf_synthetic/drums
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/fern.txt:
--------------------------------------------------------------------------------
1 | expname = fern_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fern
4 | dataset_type = llff
5 |
6 | factor = 4
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 128
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/ficus.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ficus
2 | datadir = ./data/nerf_synthetic/ficus
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/flower.txt:
--------------------------------------------------------------------------------
1 | expname = flower_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/flower
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/fortress.txt:
--------------------------------------------------------------------------------
1 | expname = fortress_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fortress
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/horns.txt:
--------------------------------------------------------------------------------
1 | expname = horns_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/horns
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/hotdog.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_hotdog
2 | datadir = ./data/nerf_synthetic/hotdog
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/leaves.txt:
--------------------------------------------------------------------------------
1 | expname = leaves_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/leaves
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/lego.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_lego
2 | datadir = ./data/nerf_synthetic/lego
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/materials.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_materials
2 | datadir = ./data/nerf_synthetic/materials
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 | testskip=1
20 | faketestskip=8
21 | i_tottest = 600000
22 |
--------------------------------------------------------------------------------
/configs/mic.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_mic
2 | datadir = ./data/nerf_synthetic/mic
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/orchids.txt:
--------------------------------------------------------------------------------
1 | expname = orchids_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/orchids
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/room.txt:
--------------------------------------------------------------------------------
1 | expname = room_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/room
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/scannet.txt:
--------------------------------------------------------------------------------
1 | expname = scannet
2 | datadir = ./data/scannet_scene0706_00
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = False
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 512
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = False
19 | raw_noise_std = 1e0
--------------------------------------------------------------------------------
/configs/ship.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ship
2 | datadir = ./data/nerf_synthetic/ship
3 | dataset_type = blender
4 |
5 | no_batching = True
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | testskip=1
19 | faketestskip=8
20 | i_tottest = 600000
--------------------------------------------------------------------------------
/configs/trex.txt:
--------------------------------------------------------------------------------
1 | expname = trex_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/trex
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/download_example_data.sh:
--------------------------------------------------------------------------------
1 | wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
2 | mkdir -p data
3 | cd data
4 | wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
5 | unzip nerf_example_data.zip
6 | cd ..
7 |
--------------------------------------------------------------------------------
/image_mem.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import imageio
4 | import json
5 | import random
6 | import time
7 | import jittor as jt
8 | from jittor import nn
9 | from tqdm import tqdm, trange
10 | import datetime
11 | import cv2
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from image_mem_helpers import *
16 |
17 | from load_llff import load_llff_data
18 | from load_deepvoxels import load_dv_data
19 | from load_blender import load_blender_data
20 | from tensorboardX import SummaryWriter
21 |
22 | jt.flags.use_cuda = 1
23 | # np.random.seed(0)
24 | DEBUG = False
25 |
26 | def config_parser():
27 | gpu = "gpu"+os.environ["CUDA_VISIBLE_DEVICES"]
28 | import configargparse
29 | parser = configargparse.ArgumentParser()
30 | parser.add_argument('--config', is_config_file=True,
31 | help='config file path')
32 | parser.add_argument("--expname", type=str,
33 | help='experiment name')
34 | parser.add_argument("--basedir", type=str, default='./logs/'+gpu+"/",
35 | help='where to store ckpts and logs')
36 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
37 | help='input data directory')
38 |
39 | # training options
40 | parser.add_argument("--netdepth", type=int, default=8,
41 | help='layers in network')
42 |     parser.add_argument("--step1", type=int, default=10000,
43 |                         help='training iteration at which the first network split is scheduled')
44 |     parser.add_argument("--step2", type=int, default=20000,
45 |                         help='training iteration at which the second network split is scheduled')
46 |     parser.add_argument("--step3", type=int, default=3000000,
47 |                         help='training iteration at which the third network split is scheduled')
48 | parser.add_argument("--netwidth", type=int, default=256,
49 | help='channels per layer')
50 | parser.add_argument("--netdepth_fine", type=int, default=8,
51 | help='layers in fine network')
52 | parser.add_argument("--netwidth_fine", type=int, default=256,
53 | help='channels per layer in fine network')
54 | parser.add_argument("--N_rand", type=int, default=32*32*4,
55 | help='batch size (number of random rays per gradient step)')
56 | parser.add_argument("--lrate", type=float, default=5e-4,
57 | help='learning rate')
58 | parser.add_argument("--lrate_decay", type=int, default=250,
59 | help='exponential learning rate decay (in 1000 steps)')
60 | parser.add_argument("--chunk", type=int, default=1024*8,
61 | help='number of rays processed in parallel, decrease if running out of memory')
62 | parser.add_argument("--netchunk", type=int, default=1024*64,
63 | help='number of pts sent through network in parallel, decrease if running out of memory')
64 | parser.add_argument("--no_batching", action='store_true',
65 | help='only take random rays from 1 image at a time')
66 | parser.add_argument("--no_reload", action='store_true',
67 | help='do not reload weights from saved ckpt')
68 | parser.add_argument("--ft_path", type=str, default=None,
69 | help='specific weights npy file to reload for coarse network')
70 | parser.add_argument("--threshold", type=float, default=0,
71 | help='threshold')
72 |
73 | # rendering options
74 | parser.add_argument("--N_samples", type=int, default=64,
75 | help='number of coarse samples per ray')
76 | parser.add_argument("--N_importance", type=int, default=0,
77 | help='number of additional fine samples per ray')
78 | parser.add_argument("--perturb", type=float, default=1.,
79 | help='set to 0. for no jitter, 1. for jitter')
80 | parser.add_argument("--use_viewdirs", action='store_true',
81 | help='use full 5D input instead of 3D')
82 | parser.add_argument("--i_embed", type=int, default=0,
83 | help='set 0 for default positional encoding, -1 for none')
84 | parser.add_argument("--multires", type=int, default=10,
85 | help='log2 of max freq for positional encoding (3D location)')
86 | parser.add_argument("--multires_views", type=int, default=4,
87 | help='log2 of max freq for positional encoding (2D direction)')
88 | parser.add_argument("--raw_noise_std", type=float, default=0.,
89 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
90 |
91 | parser.add_argument("--render_only", action='store_true',
92 | help='do not optimize, reload weights and render out render_poses path')
93 | parser.add_argument("--render_test", action='store_true',
94 | help='render the test set instead of render_poses path')
95 | parser.add_argument("--render_factor", type=int, default=0,
96 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
97 |
98 | # training options
99 | parser.add_argument("--precrop_iters", type=int, default=0,
100 | help='number of steps to train on central crops')
101 | parser.add_argument("--head_num", type=int, default=8,
102 | help='number of heads')
103 | parser.add_argument("--precrop_frac", type=float,
104 | default=.5, help='fraction of img taken for central crops')
105 | parser.add_argument("--large_scene", action='store_true',
106 | help='use large scene')
107 |
108 | # dataset options
109 | parser.add_argument("--dataset_type", type=str, default='llff',
110 | help='options: llff / blender / deepvoxels')
111 | parser.add_argument("--testskip", type=int, default=8,
112 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
113 | parser.add_argument("--faketestskip", type=int, default=1,
114 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
115 |
116 | ## deepvoxels flags
117 | parser.add_argument("--shape", type=str, default='greek',
118 | help='options : armchair / cube / greek / vase')
119 |
120 | ## blender flags
121 | parser.add_argument("--white_bkgd", action='store_true',
122 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
123 | parser.add_argument("--half_res", action='store_true',
124 | help='load blender synthetic data at 400x400 instead of 800x800')
125 |     parser.add_argument("--near", type=float, default=2.,
126 |                         help='near bound of the rendering range')
127 |     parser.add_argument("--far", type=float, default=6.,
128 |                         help='far bound of the rendering range')
129 | parser.add_argument("--do_intrinsic", action='store_true',
130 | help='use intrinsic matrix')
131 | parser.add_argument("--blender_factor", type=int, default=1,
132 | help='downsample factor for blender images')
133 |
134 | ## llff flags
135 | parser.add_argument("--factor", type=int, default=8,
136 | help='downsample factor for LLFF images')
137 | parser.add_argument("--no_ndc", action='store_true',
138 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
139 | parser.add_argument("--lindisp", action='store_true',
140 | help='sampling linearly in disparity rather than depth')
141 | parser.add_argument("--spherify", action='store_true',
142 | help='set for spherical 360 scenes')
143 | parser.add_argument("--llffhold", type=int, default=8,
144 | help='will take every 1/N images as LLFF test set, paper uses 8')
145 |
146 | # logging/saving options
147 | parser.add_argument("--i_print", type=int, default=100,
148 |                         help='frequency of console printout and metric logging')
149 | parser.add_argument("--i_img", type=int, default=50000,
150 | help='frequency of tensorboard image logging')
151 | parser.add_argument("--i_weights", type=int, default=10000,
152 | help='frequency of weight ckpt saving')
153 | parser.add_argument("--i_testset", type=int, default=5000,
154 | help='frequency of testset saving')
155 | parser.add_argument("--i_tottest", type=int, default=400000,
156 |                         help='frequency of full test-set rendering')
157 | parser.add_argument("--i_video", type=int, default=50000,
158 | help='frequency of render_poses video saving')
159 |
160 | parser.add_argument("--scaledown", type=int, default=1,
161 |                         help='downscale factor')
162 |
163 | return parser
164 |
165 | def create_nerf(args):
166 | """Instantiate NeRF's MLP model.
167 | """
168 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed, 2)
169 | input_ch_views = 0
170 | output_ch = 3
171 | skips = [4]
172 | model = NeRF(D=args.netdepth, W=args.netwidth,
173 | input_ch=input_ch, output_ch=output_ch, skips=skips,
174 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs,
175 | head_num=args.head_num, threshold=args.threshold)
176 | optimizer = jt.optim.Adam(params=model.parameters(), lr=args.lrate, betas=(0.9, 0.999))
177 |
178 | return model, optimizer, embed_fn
179 |
180 | def render(H, W, x, chunk, embed_fn, model):
181 | rgb = []
182 | confs = []
183 | outnums = []
184 | for i in range(0,x.shape[0],chunk):
185 | xx = x[i:i+chunk]
186 | feat = embed_fn(x[i:i+chunk])
187 | y, conf, sout_num = model(feat, xx, False)
188 | rgb.append(y[-1])
189 | confs.append(conf[-1])
190 | outnums.append(sout_num)
191 | rgb[-1].sync()
192 | confs[-1].sync()
193 | rgb = jt.concat(rgb, 0).reshape(H,W,3)
194 | confs = jt.concat(confs, 0)
195 | outnum = np.concatenate(outnums, axis=0)
196 | print("outnum",outnum.shape)
197 | outnum = outnum.sum(0)
198 | print("outnum",outnum.shape,outnum)
199 | return rgb, confs
200 |
201 | def dfs(t, points, model):
202 | k = len(model.son_list[t])
203 | print("dfs",t,"k",k)
204 | print("points",points.shape)
205 | if t in model.force_out:
206 | if points.shape[0]>=k:
207 | centroid = points[jt.array(np.random.choice(points.shape[0], k, replace=False))]
208 | print("centroid",centroid.shape)
209 | # print("step",-1,centroid.numpy())
210 | for step in range(100):
211 | # print("step",step)
212 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt()
213 | min_idx, _ = jt.argmin(dis,-1)
214 | # print("min_idx",min_idx.shape)
215 | for i in range(k):
216 | # print("i",i,(min_idx==i).sum)
217 | centroid[i] = points[min_idx==i].mean(0)
218 | jt.sync_all()
219 |         # jt.display_memory_info()
220 | # print("step",step,centroid.numpy())
221 | else:
222 | centroid = jt.rand((k,2))
223 | print("centroid fail",centroid.shape)
224 | print("centroid",centroid.shape,centroid)
225 | setattr(model, model.node_list[t].anchors, centroid.detach())
226 | # warning mpi
227 | if jt.mpi and False:
228 | # v1 = getattr(model, model.node_list[t].anchors)
229 | # v1.assign(jt.mpi.broadcast(v2, root=0))
230 | jt.mpi.broadcast(getattr(model, model.node_list[t].anchors), root=0)
231 | print("model", jt.mpi.local_rank(), t, getattr(model, model.node_list[t].anchors))
232 | for i in model.son_list[t]:
233 | # model.outnet[i].alpha_linear.load_state_dict(model.outnet[t].alpha_linear.state_dict())
234 | model.outnet[i].load_state_dict(model.outnet[t].state_dict())
235 | return model.son_list[t]
236 | else:
237 | centroid = model.get_anchor(model.node_list[t].anchors)
238 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt()
239 | min_idx, _ = jt.argmin(dis,-1)
240 | res = []
241 | for i in range(k):
242 | res += dfs(model.son_list[t][i], points[min_idx==i], model)
243 | return res
244 |
245 | def do_kmeans(pts, model):
246 | force_out = dfs(0, pts, model)
247 | model.force_out = force_out
248 |
249 | def print_conf_img(rgb, conf, threshold, testsavedir):
250 | os.makedirs(testsavedir, exist_ok=True)
251 | # threshold_list = [threshold, 2e-2, 3e-2, 4e-2, 5e-2, 1e-1, 5e-1, 1]
252 | threshold_list = [threshold]
253 | H, W, C = rgb.shape
254 | red = np.array([1,0,0])
255 | for th in threshold_list:
256 | filename = os.path.join(testsavedir, 'confimg_'+str(th)+'.png')
257 | logimg = rgb.copy().reshape([-1,3])
258 | print("logimg1",logimg.shape)
259 | # bo = (conf
= rand_coords.shape[0]:
338 | # print("Shuffle data after an epoch!")
339 | rand_idx = jt.randperm(coords_f.shape[0])
340 | rand_coords = coords_f[rand_idx]
341 | rand_image = image.reshape([-1,3])[rand_idx]
342 | i_batch = 0
343 | else:
344 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
345 | select_coords = coords[select_inds].long() # (N_rand, 2)
346 | x = coords_f[select_inds] # (N_rand, 2)
347 | target = image[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
348 | # print("select_coords",select_coords.shape,select_coords)
349 | # print("x",x.shape,x)
350 | # print("target",target.shape,target)
351 |
352 | # if i == 100:
353 | # import cProfile
354 | # pr = cProfile.Profile()
355 | # pr.enable()
356 | # jt.flags.trace_py_var=3
357 | # jt.flags.profiler_enable = 1
358 | # elif i == 110:
359 | # pr.disable()
360 | # pr.print_stats()
361 | # jt.flags.profiler_enable = 0
362 | # jt.profiler.report()
363 | # jt.flags.trace_py_var=0
364 | feat = embed_fn(x)
365 | y, conf, _ = model(feat, x, True)
366 | save=((y-target)**2).detach().unsqueeze(-2)
367 | # loss = img2mse(y[-1], target)
368 | loss = img2mse(y, target.unsqueeze(0))
369 | psnr = mse2psnr(loss)
370 |
371 | sloss = jt.maximum(((y-target)**2).detach()-conf,0.)
372 | loss_conf_loss = sloss.mean()
373 | loss_conf_mean = jt.maximum(conf, 0.).mean()
374 | loss_conf = loss_conf_loss+loss_conf_mean*0.01
375 | loss = loss + loss_conf*0.1
376 |
377 | # optimizer.backward(loss)
378 | optimizer.step(loss)
379 | jt.sync_all()
380 | # jt.display_memory_info()
381 |
382 | # NOTE: IMPORTANT!
383 | ### update learning rate ###
384 | decay_rate = 0.1
385 | decay_steps = args.lrate_decay * 1000
386 | sstep = global_step
387 | if sstep>split_tree3:
388 | sstep-=split_tree3
389 | elif sstep>split_tree2:
390 | sstep-=split_tree2
391 | elif sstep>split_tree1:
392 | sstep-=split_tree1
393 | new_lrate = args.lrate * (decay_rate ** (sstep / decay_steps))
394 | for param_group in optimizer.param_groups:
395 | param_group['lr'] = new_lrate
396 |
397 | if i%args.i_print==0:
398 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
399 | writer.add_scalar("train/loss", loss.item(), global_step)
400 | writer.add_scalar("train/loss_conf", loss_conf.item(), global_step)
401 | writer.add_scalar("train/PSNR", psnr.item(), global_step)
402 | writer.add_scalar("lr/lr", new_lrate, global_step)
403 | if i%args.i_testset==0:
404 | with jt.no_grad():
405 | rgb, _ = render(H, W, coords_f, N_rand, embed_fn, model)
406 | print("rgb",rgb.shape)
407 | psnr = mse2psnr(img2mse(rgb, image))
408 | rgb = rgb.numpy()
409 |
410 | writer.add_scalar('test/psnr', psnr.item(), global_step)
411 | if i%args.i_img==0:
412 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
413 | filename = os.path.join(testsavedir, '{:03d}.png'.format(i))
414 | os.makedirs(testsavedir, exist_ok=True)
415 | imageio.imwrite(filename, to8b(rgb))
416 |
417 | writer.add_image('test/rgb', to8b(rgb), global_step, dataformats="HWC")
418 | writer.add_image('test/target', image.numpy(), global_step, dataformats="HWC")
419 | if i==split_tree1 or i==split_tree2 or i==split_tree3:
420 | with jt.no_grad():
421 | rgb, conf = render(H, W, coords_f, N_rand, embed_fn, model)
422 | print("split conf",conf.shape)
423 | print("split coords_f",coords_f.shape)
424 | pts = coords_f[conf.squeeze(1)>=args.threshold]
425 | print("split pts threshold:",pts.shape)
426 | # pts = coords_f[conf>=2e-2]
427 | # print("split pts 2e-2:",pts.shape)
428 | # pts = coords_f[conf>=3e-2]
429 | # print("split pts 3e-2:",pts.shape)
430 | # pts = coords_f[conf>=4e-2]
431 | # print("split pts 4e-2:",pts.shape)
432 | # pts = coords_f[conf>=5e-2]
433 | # print("split pts 5e-2:",pts.shape)
434 | # pts = coords_f[conf>=1e-1]
435 | # print("split pts 1e-1:",pts.shape)
436 | # pts = coords_f[conf>=5e-1]
437 | # print("split pts 5e-2:",pts.shape)
438 | # pts = coords_f[conf>=1]
439 | # print("split pts 1e-0:",pts.shape)
440 | do_kmeans(pts, model)
441 | # print_conf_img(rgb.numpy(), conf.numpy(), args.threshold, os.path.join(basedir, expname, 'confimg_{:06d}'.format(i)))
442 | jt.gc()
443 | global_step += 1
444 |
445 |
446 | if __name__=='__main__':
447 | train()
448 |
--------------------------------------------------------------------------------
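
The uncertainty loss in the training loop above (the `sloss` / `loss_conf` lines) trains the per-point confidence to act as an upper bound on the squared colour error. Below is a plain-NumPy restatement for reference only; shapes are simplified and `confidence_loss` is a hypothetical name, not a function in the repository.

```python
# NumPy restatement of the confidence loss used in image_mem.py:
#   sloss      = max((y - target)^2 - conf, 0)   # the error is detached in the original
#   loss_conf  = mean(sloss) + 0.01 * mean(max(conf, 0))
#   total loss = img2mse(y, target) + 0.1 * loss_conf
import numpy as np

def confidence_loss(y, target, conf):
    err = (y - target) ** 2                  # squared colour error (treated as constant)
    sloss = np.maximum(err - conf, 0.0)      # penalise confidences below the error
    reg = np.maximum(conf, 0.0)              # keep confidences from growing freely
    return sloss.mean() + 0.01 * reg.mean()
```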
/image_mem_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jittor as jt
3 | from jittor import nn
4 | import numpy as np
5 |
6 |
7 | # Misc
8 | img2mse = lambda x, y : jt.mean((x - y) ** 2)
9 | mse2psnr = lambda x : -10. * jt.log(x) / jt.log(jt.array(np.array([10.])))
10 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
11 |
12 |
13 | # Positional encoding (section 5.1)
14 | class Embedder:
15 | def __init__(self, **kwargs):
16 | self.kwargs = kwargs
17 | self.create_embedding_fn()
18 |
19 | def create_embedding_fn(self):
20 | embed_fns = []
21 | d = self.kwargs['input_dims']
22 | out_dim = 0
23 | if self.kwargs['include_input']:
24 | embed_fns.append(lambda x : x)
25 | out_dim += d
26 |
27 | max_freq = self.kwargs['max_freq_log2']
28 | N_freqs = self.kwargs['num_freqs']
29 |
30 | if self.kwargs['log_sampling']:
31 | freq_bands = 2.**jt.linspace(0., max_freq, steps=N_freqs)
32 | else:
33 | freq_bands = jt.linspace(2.**0., 2.**max_freq, steps=N_freqs)
34 |
35 | for freq in freq_bands:
36 | for p_fn in self.kwargs['periodic_fns']:
37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
38 | out_dim += d
39 |
40 | self.embed_fns = embed_fns
41 | self.out_dim = out_dim
42 |
43 | def embed(self, inputs):
44 | return jt.concat([fn(inputs) for fn in self.embed_fns], -1)
45 |
46 |
47 | def get_embedder(multires, i=0, input_dims=3):
48 | if i == -1:
49 | return nn.Identity(), 3
50 |
51 | embed_kwargs = {
52 | 'include_input' : True,
53 | 'input_dims' : input_dims,
54 | 'max_freq_log2' : multires-1,
55 | 'num_freqs' : multires,
56 | 'log_sampling' : True,
57 | 'periodic_fns' : [jt.sin, jt.cos],
58 | }
59 |
60 | embedder_obj = Embedder(**embed_kwargs)
61 | embed = lambda x, eo=embedder_obj : eo.embed(x)
62 | return embed, embedder_obj.out_dim
63 |
64 | #tree class
65 | class Node():
66 | def __init__(self, anchors, sons, linears):
67 | self.anchors = anchors
68 | self.sons = sons
69 | self.linears = linears
70 |
71 | # Model
72 | class OutputNet(nn.Module):
73 | def __init__(self, W, input_ch_views):
74 |         """Output head for one tree node: maps the hidden feature vector to RGB (input_views is unused here).
75 |         """
76 | super(OutputNet, self).__init__()
77 |
78 | self.rgb_linear = nn.Linear(W, 3)
79 |
80 | def execute(self, h, input_views):
81 |
82 | rgb = self.rgb_linear(h)
83 | return rgb
84 |
85 | # Model
86 | class NeRF(nn.Module):
87 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False, head_num=8, threshold=3e-2):
88 |         """Recursive NeRF model organised as a tree of sub-networks: each node owns a slice of pts_linears, a confidence (uncertainty) head and an RGB output head.
89 |         """
90 | super(NeRF, self).__init__()
91 |         D=12  # note: the depth passed in as D is overridden here
92 | print("input_ch",input_ch,"input_ch_views",input_ch_views,"output_ch",output_ch,"skips",skips,"use_viewdirs",use_viewdirs)
93 | # W=128
94 | self.D = D
95 | self.W = W
96 | self.input_ch = input_ch
97 | self.input_ch_views = input_ch_views
98 | skips=[]
99 | # skips = [2,4,6]
100 | self.skips = skips
101 | # self.ress = [1,3,7,11]
102 | # self.ress = []
103 | # self.outs = [1,3,7,11]
104 | self.force_out = [0]
105 | # self.force_out = [7,8,9,10,11,12,13,14]
106 | self.use_viewdirs = use_viewdirs
107 | assert self.use_viewdirs==False
108 |
109 | self.threshold = threshold
110 |
111 | self.build_tree(head_num)
112 | self.pts_linears = nn.ModuleList(
113 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skip_linear else nn.Linear(W + input_ch, W) for i in range(self.linear_num-1)])
114 | # for i in range(self.nlinear_list[0]+1,len(self.pts_linears)):
115 | # jt.init.constant_(self.pts_linears[i].weight, 0.0)
116 | # jt.init.constant_(self.pts_linears[i].bias, 0.0)
117 | # self.confidence_linears = nn.ModuleList([nn.Linear(W+ input_ch, 1) for i in range(D)])
118 | self.confidence_linears = nn.ModuleList([nn.Linear(W, 1) for i in range(self.node_num)])
119 | # self.outnet = OutputNet(W, input_ch_views)
120 | self.outnet = nn.ModuleList([OutputNet(W, input_ch_views) for i in range(self.node_num)])
121 |
122 | def get_anchor(self, i):
123 | return getattr(self, i)
124 |
125 | def get_son_list(self, son_num, nlinear):
126 | son_list = []
127 | nlinear_list = []
128 | skip_linear = []
129 |
130 | queue = [(0,0)]
131 | head = 0
132 | tot_linear = 0
133 | while head0:
236 | anchor = "anchor"+str(self.anchor_num)
237 | self.anchor_num += 1
238 | setattr(self, anchor, jt.array(self.anchor_list[i]))
239 | # setattr(self, anchor, jt.random([len(son), 3]))
240 | else:
241 | anchor = None
242 | linear = list(range(self.linear_num, self.linear_num+self.nlinear_list[i]))
243 | self.linear_num += self.nlinear_list[i]
244 | self.node_list.append(Node(anchor, son, linear))
245 |
246 | def my_concat(self, a, b, dim):
247 | if a is None:
248 | return b
249 | elif b is None:
250 | return a
251 | else:
252 | return jt.concat([a,b],dim)
253 |
254 | def search(self, t, p, h, input_pts, input_views, remain_mask):
255 | node = self.node_list[t]
256 | # print("search t",t,"remain_mask",remain_mask.sum())
257 | identity = h
258 | for i in range(len(node.linears)):
259 | # print("i",i)
260 | # print("h",h.shape)
261 | # print(self.pts_linears[node.linears[i]])
262 | # print("len",len(self.pts_linears),"node.linears[i]",node.linears[i],"node",node)
263 | # print("t",t,"i",i,"h",h.shape,"line",self.pts_linears[node.linears[i]].weight.shape)
264 | h = self.pts_linears[node.linears[i]](h)
265 | # if t==0 and i==0:
266 | # identity = h
267 | # if i==len(node.linears)-1:
268 | # h = h+identity
269 | if i%2==1:
270 | h += identity
271 | h = jt.nn.relu(h)
272 | if i%2==1 or (t==0 and i==0):
273 | identity = h
274 | if node.linears[i] in self.skip_linear:
275 | h = jt.concat([input_pts, h], -1)
276 |
277 | confidence = self.confidence_linears[t](h).view(-1)
278 | # threshold = 0.0
279 | threshold = self.threshold
280 | # threshold = -1e10
281 | output = self.outnet[t](h, input_views)
282 | # output = self.outnet[0](h, input_views)
283 | out_num = np.zeros((self.node_num))
284 |
285 | if len(node.sons)>0 and (not t in self.force_out):
286 | son_outputs = None
287 | son_outputs_fuse = None
288 | son_confs = None
289 | son_confs_fuse = None
290 | idxs = None
291 | idxs_fuse = None
292 |
293 | anchor = self.get_anchor(node.anchors)
294 | dis = (anchor.unsqueeze(0)-p.unsqueeze(1)).sqr().sum(-1).sqrt()
295 | min_idx, _ = jt.argmin(dis,-1)
296 | for i in range(len(node.sons)):
297 | # print("t",t,"i",i)
298 | next_t = node.sons[i]
299 | sidx = jt.arange(0,p.shape[0])
300 | # print("min_idx==i",min_idx==i)
301 | sidx = sidx[min_idx==i]
302 | # print("sidx",sidx)
303 | next_p = p[sidx]
304 | next_h = h[sidx]
305 | next_input_pts = input_pts[sidx]
306 | next_input_views = input_views[sidx]
307 | next_remain_mask = remain_mask[sidx].copy()
308 | next_conf = confidence[sidx]
309 | next_remain_mask[threshold>next_conf] = 0
310 | sidx_fuse = sidx[next_remain_mask==1]
311 |
312 | # print("start t",t,"i",i,"next_t",next_t)
313 | next_outputs, next_outputs_fuse, next_confs, next_confs_fuse, next_out_num = self.search(next_t, next_p, next_h, next_input_pts, next_input_views, next_remain_mask)
314 | out_num = out_num+next_out_num
315 | # print("search", t, next_t)
316 | # print("next_outputs",next_outputs.shape)
317 | # print("next_outputs_fuse",next_outputs_fuse.shape)
318 | # print("next_confs",next_confs.shape)
319 | # print("next_confs_fuse",next_confs_fuse.shape)
320 |
321 | # print("end t",t,"i",i,"next_t",next_t)
322 | son_outputs = self.my_concat(son_outputs, next_outputs, 1)
323 | son_outputs_fuse = self.my_concat(son_outputs_fuse, next_outputs_fuse, 1)
324 | son_confs = self.my_concat(son_confs, next_confs, 1)
325 | son_confs_fuse = self.my_concat(son_confs_fuse, next_confs_fuse, 1)
326 | idxs = self.my_concat(idxs, sidx, 0)
327 | idxs_fuse = self.my_concat(idxs_fuse, sidx_fuse, 0)
328 |
329 | # print("t",t)
330 | son_outputs_save = jt.zeros(son_outputs.shape)
331 | son_outputs_save[:,idxs] = son_outputs
332 | son_outputs_save = jt.concat([output.unsqueeze(0), son_outputs_save], 0)
333 | son_confs_save = jt.zeros(son_confs.shape)
334 | son_confs_save[:,idxs] = son_confs
335 | son_confs_save = jt.concat([confidence.unsqueeze(1).unsqueeze(0), son_confs_save], 0)
336 |
337 | out_remain_mask = remain_mask.copy()
338 | out_remain_mask[threshold<=confidence] = 0
339 | idx_out = jt.arange(0,out_remain_mask.shape[0])[out_remain_mask==1]
340 | outputs_out = output[idx_out].unsqueeze(0)
341 | out_num[t] = outputs_out.shape[1]
342 | confs_out = confidence[idx_out].unsqueeze(1).unsqueeze(0)
343 | outputs_out = jt.concat([outputs_out, son_outputs_fuse], 1)
344 | confs_out = jt.concat([confs_out, son_confs_fuse], 1)
345 | idx_out = jt.concat([idx_out, idxs_fuse], 0)
346 |
347 | outputs_out_save = jt.zeros(output.unsqueeze(0).shape)
348 | outputs_out_save[:, idx_out] = outputs_out
349 | outputs_out_save = outputs_out_save[:, remain_mask==1]
350 | confs_out_save = jt.zeros(confidence.unsqueeze(1).unsqueeze(0).shape)
351 | confs_out_save[:, idx_out] = confs_out
352 | confs_out_save = confs_out_save[:, remain_mask==1]
353 |
354 | return son_outputs_save, outputs_out_save, son_confs_save, confs_out_save, out_num
355 | else:
356 | outputs_save = output.unsqueeze(0)
357 | outputs_save_log = outputs_save.copy()
358 | confs_save = confidence.unsqueeze(1).unsqueeze(0)
359 | # print("outputs_save",outputs_save.shape)
360 | # print("remain_mask",remain_mask.shape)
361 | # print("remain_mask==1",remain_mask==1)
362 | outputs_out_save = outputs_save[:, remain_mask==1]
363 | confs_out_save = confs_save[:, remain_mask==1]
364 | out_num[t] = outputs_out_save.shape[1]
365 |
366 | remain_mask[threshold<=confidence] = 0
367 | # print("out:", remain_mask.sum().numpy(), "remain:", remain_mask.shape[0]-remain_mask.sum().numpy())
368 |
369 | # print("outputs_out_save",outputs_out_save.shape)
370 | if not self.training:
371 | if t%4==0:
372 | outputs_save_log[..., 0] *= 0.
373 | outputs_save_log[..., 1] *= 0.
374 | elif t%4==1:
375 | outputs_save_log[..., 0] *= 0.
376 | outputs_save_log[..., 2] *= 0.
377 | elif t%4==2:
378 | outputs_save_log[..., 1] *= 0.
379 | outputs_save_log[..., 2] *= 0.
380 | elif t%4==3:
381 | outputs_save_log[..., 0] *= 0.
382 | # elif t==1:
383 | # outputs_out_save[..., 1] *= 0.
384 | # elif t==2:
385 | # outputs_out_save[..., 2] *= 0.
386 | outputs_save = jt.concat([outputs_save, outputs_save_log], 0)
387 | confs_save = jt.concat([confs_save, confs_save], 0)
388 | # print("outputs_out_save out",outputs_out_save.shape)
389 |
390 | # print("outputs_out_save",outputs_out_save.shape)
391 |
392 | return outputs_save, outputs_out_save, confs_save, confs_out_save, out_num
393 |
394 | def do_train(self, x, p):
395 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1)
396 | remain_mask = jt.ones(input_pts.shape[0])
397 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask)
398 |
399 | outputs = jt.concat([outputs, outputs_fuse], 0)
400 | confs = jt.concat([confs, confs_fuse], 0)
401 |
402 | return outputs, confs, np.zeros([1])
403 |
404 | def do_eval(self, x, p):
405 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1)
406 | remain_mask = jt.ones(input_pts.shape[0])
407 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask)
408 |
409 | log = "out: "
410 | sout_num = list(out_num)
411 | for i in range(len(sout_num)):
412 | log += str(i)+": %d; " % sout_num[i]
413 | # print(log)
414 | sout_num = np.array(sout_num)
415 | outputs = outputs[-1:]
416 |
417 | outputs = jt.concat([outputs, outputs_fuse], 0)
418 | confs = jt.concat([confs, confs_fuse], 0)
419 |
420 | return outputs, confs, sout_num
421 |
422 | def execute(self, x, p, training):
423 | self.training = training
424 | if training:
425 | return self.do_train(x, p)
426 | else:
427 | # return self.do_train(x, p)
428 | return self.do_eval(x, p)
429 |
430 | def load_weights_from_keras(self, weights):
431 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
432 |
433 | # Load pts_linears
434 | for i in range(self.D):
435 | idx_pts_linears = 2 * i
436 | self.pts_linears[i].weight.data = jt.array(np.transpose(weights[idx_pts_linears]))
437 | self.pts_linears[i].bias.data = jt.array(np.transpose(weights[idx_pts_linears+1]))
438 |
439 | # Load feature_linear
440 | idx_feature_linear = 2 * self.D
441 | self.feature_linear.weight.data = jt.array(np.transpose(weights[idx_feature_linear]))
442 | self.feature_linear.bias.data = jt.array(np.transpose(weights[idx_feature_linear+1]))
443 |
444 | # Load views_linears
445 | idx_views_linears = 2 * self.D + 2
446 | self.views_linears[0].weight.data = jt.array(np.transpose(weights[idx_views_linears]))
447 | self.views_linears[0].bias.data = jt.array(np.transpose(weights[idx_views_linears+1]))
448 |
449 | # Load rgb_linear
450 | idx_rbg_linear = 2 * self.D + 4
451 | self.rgb_linear.weight.data = jt.array(np.transpose(weights[idx_rbg_linear]))
452 | self.rgb_linear.bias.data = jt.array(np.transpose(weights[idx_rbg_linear+1]))
453 |
454 | # Load alpha_linear
455 | idx_alpha_linear = 2 * self.D + 6
456 | self.alpha_linear.weight.data = jt.array(np.transpose(weights[idx_alpha_linear]))
457 | self.alpha_linear.bias.data = jt.array(np.transpose(weights[idx_alpha_linear+1]))
458 |
459 |
460 |
461 | # Ray helpers
462 | def get_rays(H, W, focal, c2w, intrinsic = None):
463 | i, j = jt.meshgrid(jt.linspace(0, W-1, W), jt.linspace(0, H-1, H))
464 | i = i.t()
465 | j = j.t()
466 | if intrinsic is None:
467 | dirs = jt.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -jt.ones_like(i)], -1).unsqueeze(-2)
468 | else:
469 | i+=0.5
470 | j+=0.5
471 | dirs = jt.stack([i, j, jt.ones_like(i)], -1).unsqueeze(-2)
472 | dirs = jt.sum(dirs * intrinsic[:3,:3], -1).unsqueeze(-2)
473 | # Rotate ray directions from camera frame to the world frame
474 | rays_d = jt.sum(dirs * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
475 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
476 | rays_o = c2w[:3,-1].expand(rays_d.shape)
477 | return rays_o, rays_d
478 |
479 |
480 | def get_rays_np(H, W, focal, c2w):
481 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
482 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
483 | # Rotate ray directions from camera frame to the world frame
484 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
485 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
486 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
487 | return rays_o, rays_d
488 |
489 |
490 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
491 | # Shift ray origins to near plane
492 | t = -(near + rays_o[...,2]) / rays_d[...,2]
493 | rays_o = rays_o + t.unsqueeze(-1) * rays_d
494 |
495 | # Projection
496 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
497 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
498 | o2 = 1. + 2. * near / rays_o[...,2]
499 |
500 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
501 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
502 | d2 = -2. * near / rays_o[...,2]
503 |
504 | rays_o = jt.stack([o0,o1,o2], -1)
505 | rays_d = jt.stack([d0,d1,d2], -1)
506 |
507 | return rays_o, rays_d
508 |
509 |
510 | # Hierarchical sampling (section 5.2)
511 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
512 | # Get pdf
513 | weights = weights + 1e-5 # prevent nans
514 | pdf = weights / jt.sum(weights, -1, keepdims=True)
515 | cdf = jt.cumsum(pdf, -1)
516 | cdf = jt.concat([jt.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
517 |
518 | # Take uniform samples
519 | if det:
520 | u = jt.linspace(0., 1., steps=N_samples)
521 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
522 | else:
523 | u = jt.random(list(cdf.shape[:-1]) + [N_samples])
524 |
525 | # Pytest, overwrite u with numpy's fixed random numbers
526 | if pytest:
527 | # np.random.seed(0)
528 | new_shape = list(cdf.shape[:-1]) + [N_samples]
529 | if det:
530 | u = np.linspace(0., 1., N_samples)
531 | u = np.broadcast_to(u, new_shape)
532 | else:
533 | u = np.random.rand(*new_shape)
534 | u = jt.array(u)
535 |
536 | # Invert CDF
537 | # u = u.contiguous()
538 | # inds = searchsorted(cdf, u, side='right')
539 | inds = jt.searchsorted(cdf, u, right=True)
540 | below = jt.maximum(jt.zeros_like(inds-1), inds-1)
541 | above = jt.minimum((cdf.shape[-1]-1) * jt.ones_like(inds), inds)
542 | inds_g = jt.stack([below, above], -1) # (batch, N_samples, 2)
543 |
544 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
545 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
546 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
547 | cdf_g = jt.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
548 | bins_g = jt.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
549 |
550 | denom = (cdf_g[...,1]-cdf_g[...,0])
551 | cond = jt.where(denom<1e-5)
552 | denom[cond] = 1.
553 | t = (u-cdf_g[...,0])/denom
554 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
555 |
556 | return samples
557 |
--------------------------------------------------------------------------------
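
A quick NumPy check of the positional encoding implemented by `Embedder`/`get_embedder` above: with `include_input=True`, `multires` frequency bands and both `sin` and `cos`, the output dimension is `d + 2*d*multires` (42 for the 2-D coordinates used in image_mem.py with `multires=10`, 63 for 3-D points). `embed_np` is a hypothetical re-implementation used here only for illustration.

```python
# Stand-alone re-implementation of the log-sampled positional encoding,
# used to verify the output dimension of get_embedder(multires, 0, d).
import numpy as np

def embed_np(x, multires):
    freqs = 2.0 ** np.linspace(0.0, multires - 1, multires)   # log-spaced bands
    feats = [x]                                                # include_input=True
    for f in freqs:
        feats.append(np.sin(x * f))
        feats.append(np.cos(x * f))
    return np.concatenate(feats, -1)

pts = np.random.rand(4, 2).astype(np.float32)
print(embed_np(pts, 10).shape)   # -> (4, 42), i.e. 2 + 2*2*10
```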
/imgs/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gword/Recursive-NeRF/e61ceb302d0bc028fe4524771b363bf514342b5f/imgs/pipeline.jpg
--------------------------------------------------------------------------------
/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import json
5 | import jittor as jt
6 | import cv2
7 |
8 |
9 | trans_t = lambda t : jt.array(np.array([
10 | [1,0,0,0],
11 | [0,1,0,0],
12 | [0,0,1,t],
13 | [0,0,0,1]]).astype(np.float32))
14 |
15 | rot_phi = lambda phi : jt.array(np.array([
16 | [1,0,0,0],
17 | [0,np.cos(phi),-np.sin(phi),0],
18 | [0,np.sin(phi), np.cos(phi),0],
19 | [0,0,0,1]]).astype(np.float32))
20 |
21 | rot_theta = lambda th : jt.array(np.array([
22 | [np.cos(th),0,-np.sin(th),0],
23 | [0,1,0,0],
24 | [np.sin(th),0, np.cos(th),0],
25 | [0,0,0,1]]).astype(np.float32))
26 |
27 |
28 | def pose_spherical(theta, phi, radius):
29 | c2w = trans_t(radius)
30 | c2w = rot_phi(phi/180.*np.pi) @ c2w
31 | c2w = rot_theta(theta/180.*np.pi) @ c2w
32 | c2w = jt.array(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
33 | return c2w
34 |
35 |
36 | def load_blender_data(basedir, half_res=False, testskip=1, factor=1, do_intrinsic = False):
37 | if half_res and factor==1:
38 | factor=2
39 | splits = ['train', 'val', 'test']
40 | metas = {}
41 | for s in splits:
42 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
43 | metas[s] = json.load(fp)
44 |
45 | all_imgs = []
46 | all_poses = []
47 | counts = [0]
48 | for s in splits:
49 | meta = metas[s]
50 | imgs = []
51 | poses = []
52 | if s=='train' or testskip==0:
53 | skip = 1
54 | else:
55 | skip = testskip
56 |
57 | for frame in meta['frames'][::skip]:
58 | fname = os.path.join(basedir, frame['file_path'] + '.png')
59 | imgs.append(imageio.imread(fname))
60 | poses.append(np.array(frame['transform_matrix']))
61 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
62 | poses = np.array(poses).astype(np.float32)
63 | counts.append(counts[-1] + imgs.shape[0])
64 | all_imgs.append(imgs)
65 | all_poses.append(poses)
66 |
67 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
68 |
69 | imgs = np.concatenate(all_imgs, 0)
70 | poses = np.concatenate(all_poses, 0)
71 |
72 | H, W = imgs[0].shape[:2]
73 | camera_angle_x = float(meta['camera_angle_x'])
74 | focal = .5 * W / np.tan(.5 * camera_angle_x)
75 |
76 | if do_intrinsic:
77 | a=np.array(meta['intrinsic_matrix'])
78 | # H, W, focal = 480,640,585
79 | # a = np.eye(4).astype(np.float32)
80 | # a[0,0]=focal
81 | # a[1,1]=focal
82 | # a[0,2]=W/2.
83 | # a[1,2]=H/2.
84 | if factor>1:
85 | a[:2]/=float(factor)
86 | a=np.linalg.inv(a)
87 | intrinsic=a
88 | print("intrinsic",intrinsic)
89 |
90 | render_poses = jt.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
91 | # render_poses = []
92 | # meta = metas['test']
93 | # for frame in meta['frames'][:40]:
94 | # render_poses.append(np.array(frame['transform_matrix']))
95 | # render_poses = jt.array(np.array(render_poses).astype(np.float32))
96 | if do_intrinsic:
97 | render_poses = []
98 | meta = metas['test']
99 | start = np.array(meta['frames'][0]['transform_matrix'])
100 | render_poses.append(start)
101 | for f in range(50,len(meta['frames']),50):
102 | end = np.array(meta['frames'][f]['transform_matrix'])
103 | for i in range(10):
104 | p=i/9.
105 | render_poses.append(start*(1.0-p)+end*p)
106 | start = end
107 | render_poses = jt.array(np.array(render_poses).astype(np.float32))
108 |
109 | if factor>1:
110 | H = H//factor
111 | W = W//factor
112 | focal = focal/float(factor)
113 |
114 | imgs_half_res = np.zeros((imgs.shape[0], H, W, imgs.shape[-1]))
115 | for i, img in enumerate(imgs):
116 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
117 | imgs = imgs_half_res
118 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
119 |
120 |
121 | if do_intrinsic:
122 | return imgs, poses, intrinsic, render_poses, [H, W, focal], i_split
123 | else:
124 | return imgs, poses, render_poses, [H, W, focal], i_split
125 |
126 |
127 |
--------------------------------------------------------------------------------
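
For reference, the focal length used by `load_blender_data` comes straight from the field of view stored in the scene's transforms_*.json: `focal = 0.5 * W / tan(0.5 * camera_angle_x)`, and downsampling by `factor` divides both the resolution and the focal length. A small worked example follows; the `camera_angle_x` value is approximately the one found in the NeRF synthetic scenes and is only illustrative.

```python
# Worked example of the intrinsics computed in load_blender_data.
import numpy as np

W = H = 800                                  # native resolution of the synthetic renders
camera_angle_x = 0.6911112070083618          # example horizontal FoV (radians)
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
print(round(focal, 2))                       # -> ~1111.11 pixels

factor = 2                                   # e.g. half-resolution loading
print(W // factor, H // factor, focal / factor)   # -> 400 400 ~555.56
```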
/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 |
8 |
9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
10 | # Get camera intrinsics
11 | with open(filepath, 'r') as file:
12 | f, cx, cy = list(map(float, file.readline().split()))[:3]
13 | grid_barycenter = np.array(list(map(float, file.readline().split())))
14 | near_plane = float(file.readline())
15 | scale = float(file.readline())
16 | height, width = map(float, file.readline().split())
17 |
18 | try:
19 | world2cam_poses = int(file.readline())
20 | except ValueError:
21 | world2cam_poses = None
22 |
23 | if world2cam_poses is None:
24 | world2cam_poses = False
25 |
26 | world2cam_poses = bool(world2cam_poses)
27 |
28 | print(cx,cy,f,height,width)
29 |
30 | cx = cx / width * trgt_sidelength
31 | cy = cy / height * trgt_sidelength
32 | f = trgt_sidelength / height * f
33 |
34 | fx = f
35 | if invert_y:
36 | fy = -f
37 | else:
38 | fy = f
39 |
40 | # Build the intrinsic matrices
41 | full_intrinsic = np.array([[fx, 0., cx, 0.],
42 | [0., fy, cy, 0],
43 | [0., 0, 1, 0],
44 | [0, 0, 0, 1]])
45 |
46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
47 |
48 |
49 | def load_pose(filename):
50 | assert os.path.isfile(filename)
51 | nums = open(filename).read().split()
52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
53 |
54 |
55 | H = 512
56 | W = 512
57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
58 |
59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
61 | focal = full_intrinsic[0,0]
62 | print(H, W, focal)
63 |
64 |
65 | def dir2poses(posedir):
66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
67 | transf = np.array([
68 | [1,0,0,0],
69 | [0,-1,0,0],
70 | [0,0,-1,0],
71 | [0,0,0,1.],
72 | ])
73 | poses = poses @ transf
74 | poses = poses[:,:3,:4].astype(np.float32)
75 | return poses
76 |
77 | posedir = os.path.join(deepvoxels_base, 'pose')
78 | poses = dir2poses(posedir)
79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
80 | testposes = testposes[::testskip]
81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
82 | valposes = valposes[::testskip]
83 |
84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
86 |
87 |
88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
91 |
92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
95 |
96 | all_imgs = [imgs, valimgs, testimgs]
97 | counts = [0] + [x.shape[0] for x in all_imgs]
98 | counts = np.cumsum(counts)
99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
100 |
101 | imgs = np.concatenate(all_imgs, 0)
102 | poses = np.concatenate([poses, valposes, testposes], 0)
103 |
104 | render_poses = testposes
105 |
106 | print(poses.shape, imgs.shape)
107 |
108 | return imgs, poses, render_poses, [H,W,focal], i_split
109 |
110 |
111 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, imageio
3 |
4 |
5 | ########## Slightly modified version of LLFF data loading code
6 | ########## see https://github.com/Fyusion/LLFF for original
7 |
8 | def _minify(basedir, factors=[], resolutions=[]):
9 | needtoload = False
10 | for r in factors:
11 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
12 | if not os.path.exists(imgdir):
13 | needtoload = True
14 | for r in resolutions:
15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
16 | if not os.path.exists(imgdir):
17 | needtoload = True
18 | if not needtoload:
19 | return
20 |
21 | from shutil import copy
22 | from subprocess import check_output
23 |
24 | imgdir = os.path.join(basedir, 'images')
25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
27 | imgdir_orig = imgdir
28 |
29 | wd = os.getcwd()
30 |
31 | for r in factors + resolutions:
32 | if isinstance(r, int):
33 | name = 'images_{}'.format(r)
34 | resizearg = '{}%'.format(100./r)
35 | else:
36 | name = 'images_{}x{}'.format(r[1], r[0])
37 | resizearg = '{}x{}'.format(r[1], r[0])
38 | imgdir = os.path.join(basedir, name)
39 | if os.path.exists(imgdir):
40 | continue
41 |
42 | print('Minifying', r, basedir)
43 |
44 | os.makedirs(imgdir)
45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
46 |
47 | ext = imgs[0].split('.')[-1]
48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
49 | print(args)
50 | os.chdir(imgdir)
51 | check_output(args, shell=True)
52 | os.chdir(wd)
53 |
54 | if ext != 'png':
55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
56 | print('Removed duplicates')
57 | print('Done')
58 |
59 |
60 |
61 |
62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
63 |
64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
66 | bds = poses_arr[:, -2:].transpose([1,0])
67 |
68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
70 | sh = imageio.imread(img0).shape
71 |
72 | sfx = ''
73 |
74 | if factor is not None:
75 | sfx = '_{}'.format(factor)
76 | _minify(basedir, factors=[factor])
77 | factor = factor
78 | elif height is not None:
79 | factor = sh[0] / float(height)
80 | width = int(sh[1] / factor)
81 | _minify(basedir, resolutions=[[height, width]])
82 | sfx = '_{}x{}'.format(width, height)
83 | elif width is not None:
84 | factor = sh[1] / float(width)
85 | height = int(sh[0] / factor)
86 | _minify(basedir, resolutions=[[height, width]])
87 | sfx = '_{}x{}'.format(width, height)
88 | else:
89 | factor = 1
90 |
91 | imgdir = os.path.join(basedir, 'images' + sfx)
92 | if not os.path.exists(imgdir):
93 | print( imgdir, 'does not exist, returning' )
94 | return
95 |
96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
97 | if poses.shape[-1] != len(imgfiles):
98 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
99 | return
100 |
101 | sh = imageio.imread(imgfiles[0]).shape
102 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
103 | poses[2, 4, :] = poses[2, 4, :] * 1./factor
104 |
105 | if not load_imgs:
106 | return poses, bds
107 |
108 | def imread(f):
109 | if f.endswith('png'):
110 | return imageio.imread(f, ignoregamma=True)
111 | else:
112 | return imageio.imread(f)
113 |
114 |     imgs = [imread(f)[...,:3]/255. for f in imgfiles]
115 | imgs = np.stack(imgs, -1)
116 |
117 | print('Loaded image data', imgs.shape, poses[:,-1,0])
118 | return poses, bds, imgs
119 |
120 |
121 |
122 |
123 |
124 |
125 | def normalize(x):
126 | return x / np.linalg.norm(x)
127 |
128 | def viewmatrix(z, up, pos):
129 | vec2 = normalize(z)
130 | vec1_avg = up
131 | vec0 = normalize(np.cross(vec1_avg, vec2))
132 | vec1 = normalize(np.cross(vec2, vec0))
133 | m = np.stack([vec0, vec1, vec2, pos], 1)
134 | return m
135 |
136 | def ptstocam(pts, c2w):
137 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
138 | return tt
139 |
140 | def poses_avg(poses):
141 |
142 | hwf = poses[0, :3, -1:]
143 |
144 | center = poses[:, :3, 3].mean(0)
145 | vec2 = normalize(poses[:, :3, 2].sum(0))
146 | up = poses[:, :3, 1].sum(0)
147 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
148 |
149 | return c2w
150 |
151 |
152 |
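# For reference: render_path_spiral generates N camera poses on a spiral around the
# average pose c2w, with radii `rads`, each pose looking toward a point `focal` units
# in front of the average camera (the z offset oscillates at rate `zrate`).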
153 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
154 | render_poses = []
155 | rads = np.array(list(rads) + [1.])
156 | hwf = c2w[:,4:5]
157 |
158 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
161 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
162 | return render_poses
163 |
164 |
165 |
166 | def recenter_poses(poses):
167 |
168 | poses_ = poses+0
169 | bottom = np.reshape([0,0,0,1.], [1,4])
170 | c2w = poses_avg(poses)
171 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
172 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
173 | poses = np.concatenate([poses[:,:3,:4], bottom], -2)
174 |
175 | poses = np.linalg.inv(c2w) @ poses
176 | poses_[:,:3,:4] = poses[:,:3,:4]
177 | poses = poses_
178 | return poses
179 |
180 |
181 | #####################
182 |
183 |
184 | def spherify_poses(poses, bds):
185 |
186 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
187 |
188 | rays_d = poses[:,:3,2:3]
189 | rays_o = poses[:,:3,3:4]
190 |
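    # min_line_dist below solves, in closed form (least squares), for the 3D point closest
    # to all camera optical axes; the poses are then recentred and rescaled so the cameras
    # lie roughly on a unit sphere around it, and a circular render path is built.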
191 | def min_line_dist(rays_o, rays_d):
192 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
193 | b_i = -A_i @ rays_o
194 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
195 | return pt_mindist
196 |
197 | pt_mindist = min_line_dist(rays_o, rays_d)
198 |
199 | center = pt_mindist
200 | up = (poses[:,:3,3] - center).mean(0)
201 |
202 | vec0 = normalize(up)
203 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
204 | vec2 = normalize(np.cross(vec0, vec1))
205 | pos = center
206 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
207 |
208 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
209 |
210 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
211 |
212 | sc = 1./rad
213 | poses_reset[:,:3,3] *= sc
214 | bds *= sc
215 | rad *= sc
216 |
217 | centroid = np.mean(poses_reset[:,:3,3], 0)
218 | zh = centroid[2]
219 | radcircle = np.sqrt(rad**2-zh**2)
220 | new_poses = []
221 |
222 | for th in np.linspace(0.,2.*np.pi, 120):
223 |
224 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
225 | up = np.array([0,0,-1.])
226 |
227 | vec2 = normalize(camorigin)
228 | vec0 = normalize(np.cross(vec2, up))
229 | vec1 = normalize(np.cross(vec2, vec0))
230 | pos = camorigin
231 | p = np.stack([vec0, vec1, vec2, pos], 1)
232 |
233 | new_poses.append(p)
234 |
235 | new_poses = np.stack(new_poses, 0)
236 |
237 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
238 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
239 |
240 | return poses_reset, new_poses, bds
241 |
242 |
243 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):
244 |
245 |
246 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
247 | print('Loaded', basedir, bds.min(), bds.max())
248 |
249 | # Correct rotation matrix ordering and move variable dim to axis 0
250 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
251 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
252 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
253 | images = imgs
254 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
255 |
256 | # Rescale if bd_factor is provided
257 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
258 | poses[:,:3,3] *= sc
259 | bds *= sc
260 |
261 | if recenter:
262 | poses = recenter_poses(poses)
263 |
264 | if spherify:
265 | poses, render_poses, bds = spherify_poses(poses, bds)
266 |
267 | else:
268 |
269 | c2w = poses_avg(poses)
270 | print('recentered', c2w.shape)
271 | print(c2w[:3,:4])
272 |
273 | ## Get spiral
274 | # Get average pose
275 | up = normalize(poses[:, :3, 1].sum(0))
276 |
277 | # Find a reasonable "focus depth" for this dataset
278 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
279 | dt = .75
280 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
281 | focal = mean_dz
282 |
283 | # Get radii for spiral path
284 | shrink_factor = .8
285 | zdelta = close_depth * .2
286 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T
287 | rads = np.percentile(np.abs(tt), 90, 0)
288 | c2w_path = c2w
289 | N_views = 120
290 | N_rots = 2
291 | if path_zflat:
292 | # zloc = np.percentile(tt, 10, 0)[2]
293 | zloc = -close_depth * .1
294 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
295 | rads[2] = 0.
296 | N_rots = 1
297 | N_views/=2
298 |
299 | # Generate poses for spiral path
300 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
301 |
302 |
303 | render_poses = np.array(render_poses).astype(np.float32)
304 |
305 | c2w = poses_avg(poses)
306 | print('Data:')
307 | print(poses.shape, images.shape, bds.shape)
308 |
309 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
310 | i_test = np.argmin(dists)
311 | print('HOLDOUT view is', i_test)
312 |
313 | images = images.astype(np.float32)
314 | poses = poses.astype(np.float32)
315 |
316 | return images, poses, bds, render_poses, i_test
317 |
318 |
319 |
320 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jittor
2 | imageio
3 | imageio-ffmpeg
4 | matplotlib
5 | configargparse
6 | tensorboard==1.14.0
7 | tqdm
8 | opencv-python
9 |
--------------------------------------------------------------------------------
/run_nerf.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import imageio
4 | import json
5 | import random
6 | import time
7 | import jittor as jt
8 | from jittor import nn
9 | from tqdm import tqdm, trange
10 | import datetime
11 |
12 | import matplotlib.pyplot as plt
13 |
14 | from run_nerf_helpers import *
15 |
16 | from load_llff import load_llff_data
17 | from load_deepvoxels import load_dv_data
18 | from load_blender import load_blender_data
19 | from tensorboardX import SummaryWriter
20 |
21 |
22 | jt.flags.use_cuda = 1
23 | # np.random.seed(0)
24 | DEBUG = False
25 |
26 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
27 | """Prepares inputs and applies network 'fn'.
28 | """
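    # For reference: the flattened points are pushed through `fn` in `netchunk`-sized
    # pieces to bound peak memory; in this Recursive-NeRF fork `fn` returns, per chunk,
    # raw radiance/density predictions, a per-point confidence, and per-head output
    # counts, which are reassembled below into (outputs, conf, loss_conf, sout_flat).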
29 | inputs_flat = jt.reshape(inputs, [-1, inputs.shape[-1]])
30 | # print("x min", inputs_flat[:,0].min(), inputs_flat[:,0].max())
31 | # print("y min", inputs_flat[:,1].min(), inputs_flat[:,1].max())
32 | # print("z min", inputs_flat[:,2].min(), inputs_flat[:,2].max())
33 | embedded = embed_fn(inputs_flat)
34 | training = not jt.flags.no_grad
35 |
36 | if viewdirs is not None:
37 | input_dirs = viewdirs[:,None].expand(inputs.shape)
38 | input_dirs_flat = jt.reshape(input_dirs, [-1, input_dirs.shape[-1]])
39 | embedded_dirs = embeddirs_fn(input_dirs_flat)
40 | embedded = jt.concat([embedded, embedded_dirs], -1)
41 |
42 | if netchunk is None:
43 | outputs_flat, loss_conf, loss_num = fn(embedded)
44 | else:
45 | flat_list = []
46 | conf_list = []
47 | sout_num_list = []
48 | loss_conf = jt.zeros([1])
49 | # loss_num = 0
50 | # pre_value = 0
51 | for i in range(0, embedded.shape[0], netchunk):
52 | flat, conf, sout_num = fn(embedded[i:i+netchunk], inputs_flat[i:i+netchunk], training)
53 | flat_list.append(flat)
54 | conf_list.append(conf)
55 | sout_num_list.append(np.expand_dims(sout_num, 0))
56 | # loss_conf += lconf
57 | # loss_num += num
58 | # pre_value = (flat.reshape((-1))[0] + conf.reshape((-1))[0] + lconf.reshape((-1))[0]) * 0
59 | if not training:
60 | flat_list[-1].sync()
61 | conf_list[-1].sync()
62 | loss_conf.sync()
63 | outputs_flat = jt.concat(flat_list, 1)
64 | conf_flat = jt.concat(conf_list, 1)
65 | sout_flat = np.concatenate(sout_num_list, 0)
66 | sout_flat = sout_flat.sum(0)
67 | # loss_conf /= float(loss_num)
68 | outputs = jt.reshape(outputs_flat, [outputs_flat.shape[0]] + list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
69 | conf = jt.reshape(conf_flat, [conf_flat.shape[0]] + list(inputs.shape[:-1]) + [conf_flat.shape[-1]])
70 | return outputs, conf, loss_conf, sout_flat
71 |
72 |
73 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
74 | """Render rays in smaller minibatches to avoid OOM.
75 | """
76 | all_ret = {}
77 | for i in range(0, rays_flat.shape[0], chunk):
78 | # jt.display_memory_info()
79 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
80 | for k in ret:
81 | if k not in all_ret:
82 | all_ret[k] = []
83 | all_ret[k].append(ret[k])
84 | # print("all_ret[k]",k, len(all_ret[k]), all_ret[k][0].shape)
85 | # ret[k].sync()
86 | # jt.display_memory_info()
87 | for k in all_ret:
88 | # print("k",k)
89 | if k=="loss_conf" or k=="loss_conf0":
90 | all_ret[k] = jt.concat(all_ret[k], 0)
91 | elif k=="pts":
92 | all_ret[k] = np.concatenate(all_ret[k], axis=0)
93 | elif k=="outnum" or k=="outnum0":
94 | all_ret[k] = np.concatenate(all_ret[k], axis=0).sum(0)
95 | else:
96 | all_ret[k] = jt.concat(all_ret[k], 1)
97 | # all_ret = {k : jt.concat(all_ret[k], 1) for k in all_ret}
98 | return all_ret
99 |
100 |
101 | def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, intrinsic=None, ndc=True,
102 | near=0., far=1.,
103 | use_viewdirs=False, c2w_staticcam=None,
104 | **kwargs):
105 | """Render rays
106 | Args:
107 | H: int. Height of image in pixels.
108 | W: int. Width of image in pixels.
109 | focal: float. Focal length of pinhole camera.
110 | chunk: int. Maximum number of rays to process simultaneously. Used to
111 | control maximum memory usage. Does not affect final results.
112 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
113 | each example in batch.
114 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
115 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
116 | near: float or array of shape [batch_size]. Nearest distance for a ray.
117 | far: float or array of shape [batch_size]. Farthest distance for a ray.
118 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
119 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
120 | camera while using other c2w argument for viewing directions.
121 | Returns:
122 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
123 | disp_map: [batch_size]. Disparity map. Inverse of depth.
124 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
125 | extras: dict with everything returned by render_rays().
126 | """
127 | if c2w is not None:
128 | # special case to render full image
129 | print("render c2w",c2w.shape)
130 | rays_o, rays_d = get_rays(H, W, focal, c2w, intrinsic)
131 | else:
132 | # use provided ray batch
133 | rays_o, rays_d = rays
134 |
135 | if use_viewdirs:
136 | # provide ray directions as input
137 | viewdirs = rays_d
138 | if c2w_staticcam is not None:
139 | assert intrinsic is None
140 | # special case to visualize effect of viewdirs
141 | print("render c2w_staticcam",c2w_staticcam.shape)
142 | rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
143 | viewdirs = viewdirs / jt.norm(viewdirs, k=2, dim=-1, keepdim=True)
144 | viewdirs = jt.reshape(viewdirs, [-1,3]).float()
145 |
146 | sh = rays_d.shape # [..., 3]
147 | if ndc:
148 | # for forward facing scenes
149 | rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
150 |
151 | # Create ray batch
152 | rays_o = jt.reshape(rays_o, [-1,3]).float()
153 | rays_d = jt.reshape(rays_d, [-1,3]).float()
154 |
155 | near, far = near * jt.ones_like(rays_d[...,:1]), far * jt.ones_like(rays_d[...,:1])
156 | rays = jt.concat([rays_o, rays_d, near, far], -1)
157 | if use_viewdirs:
158 | rays = jt.concat([rays, viewdirs], -1)
159 |
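    # Each packed ray row below is [origin(3), direction(3), near(1), far(1), viewdir(3)],
    # i.e. 8 or 11 floats, which render_rays later unpacks by column index.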
160 | # Render and reshape
161 | all_ret = batchify_rays(rays, chunk, **kwargs)
162 | for k in all_ret:
163 | if k=="loss_conf" or k=="loss_conf0" or k=="pts" or k=="outnum" or k=="outnum0":
164 | continue
165 | k_sh = [all_ret[k].shape[0]] + list(sh[:-1]) + list(all_ret[k].shape[2:])
166 | all_ret[k] = jt.reshape(all_ret[k], k_sh)
167 |
168 | k_extract = ['rgb_map', 'disp_map', 'acc_map']
169 | ret_list = [all_ret[k] for k in k_extract]
170 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
171 | return ret_list + [ret_dict]
172 |
173 |
174 | def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0, intrinsic = None, get_points = False, log_path = None, large_scene = False):
175 |
176 | H, W, focal = hwf
177 |
178 | if render_factor!=0:
179 | # Render downsampled for speed
180 | H = H//render_factor
181 | W = W//render_factor
182 | focal = focal/render_factor
183 |
184 | rgbs = []
185 | rgbs_log = []
186 | disps = []
187 | outnum0 = []
188 | outnum = []
189 | points = None
190 |
191 | t = time.time()
192 | for i, c2w in enumerate(tqdm(render_poses)):
193 | # import ipdb
194 | # ipdb.set_trace()
195 | print(i, time.time() - t)
196 | t = time.time()
197 | rgb, disp, acc, extras = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], intrinsic=intrinsic, **render_kwargs)
198 | rgbs.append(rgb[-1].numpy())
199 | rgbs_log.append(rgb[-2].numpy())
200 | disps.append(disp[-1].numpy())
201 | outnum0.append(np.expand_dims(extras['outnum0'], 0))
202 | outnum.append(np.expand_dims(extras['outnum'], 0))
203 | if get_points:
204 | point = extras['pts']
205 | # conf = extras['conf_map'][-1].numpy()
206 | # point = point.reshape((-1, point.shape[-1]))
207 | # conf = conf.reshape((-1))
208 | # idx=np.random.choice(point.shape[0], point.shape[0]//render_poses.shape[0])
209 | # point = point[idx]
210 | if points is None:
211 | points = point
212 | else:
213 | points = np.concatenate((points, point), axis=0)
214 |
215 | if i==0:
216 | print(rgb.shape, disp.shape)
217 |
218 | """
219 | if gt_imgs is not None and render_factor==0:
220 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
221 | print(p)
222 | """
223 |
224 | if savedir is not None:
225 | rgb8 = to8b(rgbs[-1])
226 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
227 | imageio.imwrite(filename, rgb8)
228 | del rgb
229 | del disp
230 | del acc
231 | del extras
232 |
233 | if large_scene:
234 | if get_points and intrinsic is None:
235 | break
236 | else:
237 | if get_points:
238 | break
239 | if jt.mpi and jt.mpi.local_rank()!=0:
240 | break
241 |
242 | outnum0 = np.concatenate(outnum0, axis=0).sum(0)
243 | outnum = np.concatenate(outnum, axis=0).sum(0)
244 | sout_num = list(outnum0)
245 | log = ""
246 | for i in range(len(sout_num)):
247 | log += str(i)+": %d, " % int(sout_num[i])
248 | sout_num = list(outnum)
249 | log += "\n"
250 | for i in range(len(sout_num)):
251 | log += str(i)+": %d, " % int(sout_num[i])
252 | print(log)
253 | if log_path is not None:
254 | with open(log_path, 'w') as file_object:
255 | file_object.write(log)
256 |
257 | rgbs = np.stack(rgbs, 0)
258 | rgbs_log = np.stack(rgbs_log, 0)
259 | disps = np.stack(disps, 0)
260 |
261 | if get_points:
262 | return rgbs, disps, rgbs_log, points
263 | else:
264 | return rgbs, disps, rgbs_log
265 |
266 | def create_nerf(args):
267 | """Instantiate NeRF's MLP model.
268 | """
269 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
270 |
271 | input_ch_views = 0
272 | embeddirs_fn = None
273 | if args.use_viewdirs:
274 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
275 | output_ch = 5 if args.N_importance > 0 else 4
276 | skips = [4]
277 | model = NeRF(D=args.netdepth, W=args.netwidth,
278 | input_ch=input_ch, output_ch=output_ch, skips=skips,
279 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs, head_num=args.head_num, threshold=args.threshold*2.0 if args.large_scene else args.threshold*10.0)
280 | grad_vars = list(model.parameters())
281 |
282 | model_fine = None
283 | if args.N_importance > 0:
284 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
285 | input_ch=input_ch, output_ch=output_ch, skips=skips,
286 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs, head_num=args.head_num, threshold=args.threshold)
287 | grad_vars += list(model_fine.parameters())
288 |
289 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
290 | embed_fn=embed_fn,
291 | embeddirs_fn=embeddirs_fn,
292 | netchunk=args.netchunk)
293 |
294 | # Create optimizer
295 | optimizer = jt.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
296 |
297 | start = 0
298 | basedir = args.basedir
299 | expname = args.expname
300 |
301 | ##########################
302 |
303 | # Load checkpoints
304 | if args.ft_path is not None and args.ft_path!='None':
305 | ckpts = [args.ft_path]
306 | else:
307 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
308 |
309 | print('Found ckpts', ckpts)
310 | if len(ckpts) > 0 and not args.no_reload:
311 | ckpt_path = ckpts[-1]
312 | print('Reloading from', ckpt_path)
313 | ckpt = jt.load(ckpt_path)
314 |
315 | start = ckpt['global_step']
316 | # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
317 |
318 | # Load model
319 | model.load_state_dict(ckpt['network_fn_state_dict'])
320 | if model_fine is not None:
321 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
322 |
323 | ##########################
324 |
325 | render_kwargs_train = {
326 | 'network_query_fn' : network_query_fn,
327 | 'perturb' : args.perturb,
328 | 'N_importance' : args.N_importance,
329 | 'network_fine' : model_fine,
330 | 'N_samples' : args.N_samples,
331 | 'network_fn' : model,
332 | 'use_viewdirs' : args.use_viewdirs,
333 | 'white_bkgd' : args.white_bkgd,
334 | 'raw_noise_std' : args.raw_noise_std,
335 | 'threshold' : args.threshold,
336 | }
337 |
338 | # NDC only good for LLFF-style forward facing data
339 | if args.dataset_type != 'llff' or args.no_ndc:
340 | print('Not ndc!')
341 | render_kwargs_train['ndc'] = False
342 | render_kwargs_train['lindisp'] = args.lindisp
343 |
344 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
345 | render_kwargs_test['perturb'] = False
346 | render_kwargs_test['raw_noise_std'] = 0.
347 |
348 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
349 |
350 |
351 | def raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
352 | """Transforms model's predictions to semantically meaningful values.
353 | Args:
354 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
355 | z_vals: [num_rays, num_samples along ray]. Integration time.
356 | rays_d: [num_rays, 3]. Direction of each ray.
357 | Returns:
358 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
359 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
360 | acc_map: [num_rays]. Sum of weights along each ray.
361 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
362 | depth_map: [num_rays]. Estimated distance to object.
363 |     """
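    # Compositing computed below (standard NeRF volume rendering), written out for
    # reference with sigma_i the raw density and delta_i = z_{i+1} - z_i:
    #   alpha_i = 1 - exp(-relu(sigma_i) * delta_i)
    #   w_i     = alpha_i * prod_{j<i} (1 - alpha_j + 1e-10)
    #   rgb     = sum_i w_i c_i,  depth = sum_i w_i z_i,  acc = sum_i w_i
    # and disp = 1 / max(1e-10, depth / acc).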
364 | raw2alpha = lambda raw, dists, act_fn=jt.nn.relu: 1.-jt.exp(-act_fn(raw)*dists)
365 |
366 | dists = z_vals[...,1:] - z_vals[...,:-1]
367 | dists = jt.concat([dists, jt.array(np.array([1e10]).astype(np.float32)).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
368 |
369 | dists = dists * jt.norm(rays_d.unsqueeze(-2), k=2, dim=-1)
370 |
371 | rgb = jt.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
372 | noise = 0.
373 | if raw_noise_std > 0.:
374 | noise = jt.init.gauss(raw[...,3].shape, raw.dtype) * raw_noise_std
375 |
376 | # Overwrite randomly sampled data if pytest
377 | if pytest:
378 | # np.random.seed(0)
379 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
380 | noise = jt.array(noise)
381 |
382 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
383 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
384 | weights = alpha * jt.cumprod(jt.concat([jt.ones((1,alpha.shape[1], 1)), 1.-alpha + 1e-10], -1), -1)[..., :-1]
385 | rgb_map = jt.sum(weights.unsqueeze(-1) * rgb, -2) # [N_rays, 3]
386 | # conf_map = jt.sum(weights.unsqueeze(-1) * conf, -2) # [N_rays, 1]
387 | # conf_map = jt.mean(conf, -2) # [N_rays, 1]
388 | conf_map = conf
389 |
390 | depth_map = jt.sum(weights * z_vals, -1)
391 | disp_map = 1./jt.maximum(1e-10 * jt.ones_like(depth_map), depth_map / jt.sum(weights, -1))
392 | acc_map = jt.sum(weights, -1)
393 |
394 | if white_bkgd:
395 | rgb_map = rgb_map + (1.-acc_map.unsqueeze(-1))
396 |
397 | return rgb_map, disp_map, acc_map, weights, depth_map, conf_map
398 |
399 |
400 | def render_rays(ray_batch,
401 | network_fn,
402 | network_query_fn,
403 | N_samples,
404 | retraw=False,
405 | lindisp=False,
406 | perturb=0.,
407 | N_importance=0,
408 | network_fine=None,
409 | white_bkgd=False,
410 | raw_noise_std=0.,
411 | threshold=3e-2,
412 | verbose=False,
413 | pytest=False):
414 | """Volumetric rendering.
415 | Args:
416 | ray_batch: array of shape [batch_size, ...]. All information necessary
417 | for sampling along a ray, including: ray origin, ray direction, min
418 | dist, max dist, and unit-magnitude viewing direction.
419 | network_fn: function. Model for predicting RGB and density at each point
420 | in space.
421 | network_query_fn: function used for passing queries to network_fn.
422 | N_samples: int. Number of different times to sample along each ray.
423 | retraw: bool. If True, include model's raw, unprocessed predictions.
424 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
425 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
426 | random points in time.
427 | N_importance: int. Number of additional times to sample along each ray.
428 | These samples are only passed to network_fine.
429 | network_fine: "fine" network with same spec as network_fn.
430 | white_bkgd: bool. If True, assume a white background.
431 |       raw_noise_std: float. Std dev of noise added to raw density predictions for regularization.
432 | verbose: bool. If True, print more debugging info.
433 | Returns:
434 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
435 | disp_map: [num_rays]. Disparity map. 1 / depth.
436 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
437 | raw: [num_rays, num_samples, 4]. Raw predictions from model.
438 | rgb0: See rgb_map. Output for coarse model.
439 | disp0: See disp_map. Output for coarse model.
440 | acc0: See acc_map. Output for coarse model.
441 | z_std: [num_rays]. Standard deviation of distances along ray for each
442 | sample.
443 | """
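    # Flow for reference: sample N_samples depths per ray (stratified if perturb > 0),
    # query the coarse network and composite via raw2outputs; if N_importance > 0, draw
    # extra depths from the coarse weight distribution (sample_pdf) and repeat with the
    # fine network. In eval mode this fork also keeps a random 1/10 subset of sample
    # points whose confidence is at least `threshold`, later used for k-means splitting.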
444 | training = not jt.flags.no_grad
445 | N_rays = ray_batch.shape[0]
446 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
447 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
448 | bounds = jt.reshape(ray_batch[...,6:8], [-1,1,2])
449 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
450 |
451 | t_vals = jt.linspace(0., 1., steps=N_samples)
452 | if not lindisp:
453 | z_vals = near * (1.-t_vals) + far * (t_vals)
454 | else:
455 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
456 |
457 | z_vals = z_vals.expand([N_rays, N_samples])
458 |
459 | if perturb > 0.:
460 | # get intervals between samples
461 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
462 | upper = jt.concat([mids, z_vals[...,-1:]], -1)
463 | lower = jt.concat([z_vals[...,:1], mids], -1)
464 | # stratified samples in those intervals
465 | t_rand = jt.random(z_vals.shape)
466 |
467 | # Pytest, overwrite u with numpy's fixed random numbers
468 | if pytest:
469 | # np.random.seed(0)
470 | t_rand = np.random.rand(*list(z_vals.shape))
471 | t_rand = jt.array(t_rand)
472 |
473 | z_vals = lower + (upper - lower) * t_rand
474 |
475 | pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N_rays, N_samples, 3]
476 |
477 |
478 | # raw = run_network(pts)
479 | raw, conf, loss_conf, outnum = network_query_fn(pts, viewdirs, network_fn)
480 | rgb_map, disp_map, acc_map, weights, depth_map, conf_map = raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
481 |
482 | if N_importance > 0:
483 |
484 | rgb_map_0, disp_map_0, acc_map_0, conf_map_0, loss_conf_0, weights_0, outnum_0 = rgb_map, disp_map, acc_map, conf_map, loss_conf, weights, outnum
485 |
486 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
487 | z_samples = sample_pdf(z_vals_mid, weights[-1,...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
488 | z_samples = z_samples.detach()
489 |
490 | _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1)
491 | pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N_rays, N_samples + N_importance, 3]
492 |
493 | run_fn = network_fn if network_fine is None else network_fine
494 | # raw = run_network(pts, fn=run_fn)
495 | raw, conf, loss_conf, outnum = network_query_fn(pts, viewdirs, run_fn)
496 | # print("raw",raw.shape)
497 | # print("pts",pts.shape)
498 |
499 | rgb_map, disp_map, acc_map, weights, depth_map, conf_map = raw2outputs(raw, conf, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
500 | # print("rgb_map",rgb_map.shape)
501 | # print("conf_map",conf_map.shape)
502 | if training:
503 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'conf_map' : conf_map, 'loss_conf' : loss_conf, 'weights' : weights}
504 | if retraw:
505 | ret['raw'] = raw
506 | if N_importance > 0:
507 | ret['rgb0'] = rgb_map_0
508 | ret['disp0'] = disp_map_0
509 | ret['acc0'] = acc_map_0
510 | ret['conf_map0'] = conf_map_0
511 | ret['loss_conf0'] = loss_conf_0
512 | ret['weights0'] = weights_0
513 | # ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] TODO: support jt.std
514 | else:
515 | point = pts.numpy()
516 | conf = conf_map[-1].numpy()
517 | point = point.reshape((-1, point.shape[-1]))
518 | conf = conf.reshape((-1))
519 | idx=np.random.choice(point.shape[0], point.shape[0]//10, replace=False)
520 | point = point[idx]
521 | conf = conf[idx]
522 | # threshold = 0.0
523 | point = point[conf>=threshold]
524 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'pts' : point, 'outnum' : np.expand_dims(outnum,0)}
525 | if retraw:
526 | ret['raw'] = raw
527 | if N_importance > 0:
528 | ret['rgb0'] = rgb_map_0
529 | ret['disp0'] = disp_map_0
530 | ret['acc0'] = acc_map_0
531 | ret['outnum0'] = np.expand_dims(outnum_0,0)
532 |
533 | # for k in ret:
534 | # if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
535 | # print(f"! [Numerical Error] {k} contains nan or inf.")
536 |
537 | return ret
538 |
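# Recursive-NeRF growth step, summarised for reference: dfs walks the node tree; at a
# node currently in force_out it initialises its children's anchors with plain
# Lloyd-style k-means (k = number of children, 100 iterations) over the collected
# high-uncertainty points, copies the parent's alpha head into each child, and returns
# the children as the new force_out set; other nodes just route their points to the
# nearest anchor and recurse.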
539 | def dfs(t, points, model, model_fine):
540 | k = len(model.son_list[t])
541 | if t in model.force_out:
542 | if points.shape[0]>=k:
543 | centroid = points[jt.array(np.random.choice(points.shape[0], k, replace=False))]
544 | print("centroid",centroid.shape)
545 | # print("step",-1,centroid.numpy())
546 | for step in range(100):
547 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt()
548 | min_idx, _ = jt.argmin(dis,-1)
549 | # print("min_idx",min_idx.shape)
550 | for i in range(k):
551 | # print("i",i,(min_idx==i).sum())
552 | centroid[i] = points[min_idx==i].mean(0)
553 | # print("step",step,centroid.numpy())
554 | else:
555 | centroid = jt.rand((k,3))
556 | print("centroid fail",centroid.shape)
557 | setattr(model, model.node_list[t].anchors, centroid.detach())
558 | setattr(model_fine, model_fine.node_list[t].anchors, centroid.detach())
559 | if jt.mpi:
560 | # v1 = getattr(model, model.node_list[t].anchors)
561 | # v1.assign(jt.mpi.broadcast(v2, root=0))
562 | # v2 = getattr(model_fine, model_fine.node_list[t].anchors)
563 | # v2.assign(jt.mpi.broadcast(v2, root=0))
564 | jt.mpi.broadcast(getattr(model, model.node_list[t].anchors), root=0)
565 | jt.mpi.broadcast(getattr(model_fine, model_fine.node_list[t].anchors), root=0)
566 | print("model", jt.mpi.local_rank(), t, getattr(model, model.node_list[t].anchors))
567 | print("modelfine", jt.mpi.local_rank(), t, getattr(model_fine, model_fine.node_list[t].anchors))
568 | for i in model.son_list[t]:
569 | model.outnet[i].alpha_linear.load_state_dict(model.outnet[t].alpha_linear.state_dict())
570 | model_fine.outnet[i].alpha_linear.load_state_dict(model_fine.outnet[t].alpha_linear.state_dict())
571 | # model.outnet[i].load_state_dict(model.outnet[t].state_dict())
572 | # model_fine.outnet[i].load_state_dict(model_fine.outnet[t].state_dict())
573 | return model.son_list[t]
574 | else:
575 | centroid = model.get_anchor(model.node_list[t].anchors)
576 | dis = (points.unsqueeze(1) - centroid.unsqueeze(0)).sqr().sum(-1).sqrt()
577 | min_idx, _ = jt.argmin(dis,-1)
578 | res = []
579 | for i in range(k):
580 | res += dfs(model.son_list[t][i], points[min_idx==i], model, model_fine)
581 | return res
582 |
583 | def do_kmeans(points, model, model_fine):
584 | # points = points.reshape(-1, points.shape[-1])
585 | # confs = confs.reshape(-1)
586 | print("do_kmeans",points.shape)
587 | points = jt.array(points)
588 | force_out = dfs(0, points, model, model_fine)
589 |
590 | model.force_out = force_out
591 | model_fine.force_out = force_out
592 |
593 | def config_parser():
594 | gpu = "gpu"+os.environ["CUDA_VISIBLE_DEVICES"]
595 | import configargparse
596 | parser = configargparse.ArgumentParser()
597 | parser.add_argument('--config', is_config_file=True,
598 | help='config file path')
599 | parser.add_argument("--expname", type=str,
600 | help='experiment name')
601 | parser.add_argument("--basedir", type=str, default='./logs/'+gpu+"/",
602 | help='where to store ckpts and logs')
603 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
604 | help='input data directory')
605 |
606 | # training options
607 | parser.add_argument("--netdepth", type=int, default=8,
608 | help='layers in network')
609 | parser.add_argument("--step1", type=int, default=5000,
610 |                         help='number of optimizer steps before the first network split (tree growth)')
611 | parser.add_argument("--step2", type=int, default=10000,
612 |                         help='number of optimizer steps before the second network split (tree growth)')
613 | parser.add_argument("--step3", type=int, default=15000,
614 |                         help='number of optimizer steps before the third network split (tree growth)')
615 | parser.add_argument("--netwidth", type=int, default=256,
616 | help='channels per layer')
617 | parser.add_argument("--netdepth_fine", type=int, default=8,
618 | help='layers in fine network')
619 | parser.add_argument("--netwidth_fine", type=int, default=256,
620 | help='channels per layer in fine network')
621 | parser.add_argument("--N_rand", type=int, default=32*32*4,
622 | help='batch size (number of random rays per gradient step)')
623 | parser.add_argument("--lrate", type=float, default=5e-4,
624 | help='learning rate')
625 | parser.add_argument("--lrate_decay", type=int, default=250,
626 | help='exponential learning rate decay (in 1000 steps)')
627 | parser.add_argument("--chunk", type=int, default=1024*8,
628 | help='number of rays processed in parallel, decrease if running out of memory')
629 | parser.add_argument("--netchunk", type=int, default=1024*64,
630 | help='number of pts sent through network in parallel, decrease if running out of memory')
631 | parser.add_argument("--no_batching", action='store_true',
632 | help='only take random rays from 1 image at a time')
633 | parser.add_argument("--no_reload", action='store_true',
634 | help='do not reload weights from saved ckpt')
635 | parser.add_argument("--ft_path", type=str, default=None,
636 | help='specific weights npy file to reload for coarse network')
637 | parser.add_argument("--threshold", type=float, default=1e-2,
638 | help='threshold')
639 |
640 | # rendering options
641 | parser.add_argument("--N_samples", type=int, default=64,
642 | help='number of coarse samples per ray')
643 | parser.add_argument("--N_importance", type=int, default=0,
644 | help='number of additional fine samples per ray')
645 | parser.add_argument("--perturb", type=float, default=1.,
646 | help='set to 0. for no jitter, 1. for jitter')
647 | parser.add_argument("--use_viewdirs", action='store_true',
648 | help='use full 5D input instead of 3D')
649 | parser.add_argument("--i_embed", type=int, default=0,
650 | help='set 0 for default positional encoding, -1 for none')
651 | parser.add_argument("--multires", type=int, default=10,
652 | help='log2 of max freq for positional encoding (3D location)')
653 | parser.add_argument("--multires_views", type=int, default=4,
654 | help='log2 of max freq for positional encoding (2D direction)')
655 | parser.add_argument("--raw_noise_std", type=float, default=0.,
656 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
657 |
658 | parser.add_argument("--render_only", action='store_true',
659 | help='do not optimize, reload weights and render out render_poses path')
660 | parser.add_argument("--render_test", action='store_true',
661 | help='render the test set instead of render_poses path')
662 | parser.add_argument("--render_factor", type=int, default=0,
663 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
664 |
665 | # training options
666 | parser.add_argument("--precrop_iters", type=int, default=0,
667 | help='number of steps to train on central crops')
668 | parser.add_argument("--head_num", type=int, default=8,
669 | help='number of heads')
670 | parser.add_argument("--precrop_frac", type=float,
671 | default=.5, help='fraction of img taken for central crops')
672 | parser.add_argument("--large_scene", action='store_true',
673 |                         help='treat the dataset as a large scene (affects the split threshold and point collection)')
674 |
675 | # dataset options
676 | parser.add_argument("--dataset_type", type=str, default='llff',
677 | help='options: llff / blender / deepvoxels')
678 | parser.add_argument("--testskip", type=int, default=8,
679 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
680 | parser.add_argument("--faketestskip", type=int, default=1,
681 |                         help='further subsample the test set by 1/N for periodic in-training evaluation (full set rendered every i_tottest iterations)')
682 |
683 | ## deepvoxels flags
684 | parser.add_argument("--shape", type=str, default='greek',
685 | help='options : armchair / cube / greek / vase')
686 |
687 | ## blender flags
688 | parser.add_argument("--white_bkgd", action='store_true',
689 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
690 | parser.add_argument("--half_res", action='store_true',
691 | help='load blender synthetic data at 400x400 instead of 800x800')
692 |     parser.add_argument("--near", type=float, default=2.,
693 |                         help='near bound for rendering blender scenes')
694 |     parser.add_argument("--far", type=float, default=6.,
695 |                         help='far bound for rendering blender scenes')
696 | parser.add_argument("--do_intrinsic", action='store_true',
697 | help='use intrinsic matrix')
698 | parser.add_argument("--blender_factor", type=int, default=1,
699 | help='downsample factor for blender images')
700 |
701 | ## llff flags
702 | parser.add_argument("--factor", type=int, default=8,
703 | help='downsample factor for LLFF images')
704 | parser.add_argument("--no_ndc", action='store_true',
705 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
706 | parser.add_argument("--lindisp", action='store_true',
707 | help='sampling linearly in disparity rather than depth')
708 | parser.add_argument("--spherify", action='store_true',
709 | help='set for spherical 360 scenes')
710 | parser.add_argument("--llffhold", type=int, default=8,
711 | help='will take every 1/N images as LLFF test set, paper uses 8')
712 |
713 | # logging/saving options
714 | parser.add_argument("--i_print", type=int, default=100,
715 |                         help='frequency of console printout and metric logging')
716 | parser.add_argument("--i_img", type=int, default=5000,
717 | help='frequency of tensorboard image logging')
718 | parser.add_argument("--i_weights", type=int, default=10000,
719 | help='frequency of weight ckpt saving')
720 | parser.add_argument("--i_testset", type=int, default=50000,
721 | help='frequency of testset saving')
722 | parser.add_argument("--i_tottest", type=int, default=400000,
723 |                         help='frequency of rendering the full (non-subsampled) test set')
724 | parser.add_argument("--i_video", type=int, default=50000,
725 | help='frequency of render_poses video saving')
726 |
727 | return parser
728 |
729 |
730 | def train():
731 |
732 | parser = config_parser()
733 | args = parser.parse_args()
734 |
735 | # Load data
736 | intrinsic = None
737 | if args.dataset_type == 'llff':
738 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
739 | recenter=True, bd_factor=.75,
740 | spherify=args.spherify)
741 | hwf = poses[0,:3,-1]
742 | poses = poses[:,:3,:4]
743 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
744 | if not isinstance(i_test, list):
745 | i_test = [i_test]
746 |
747 | if args.llffhold > 0:
748 | print('Auto LLFF holdout,', args.llffhold)
749 | i_test = np.arange(images.shape[0])[::args.llffhold]
750 |
751 | i_val = i_test
752 | i_test_tot = i_test
753 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
754 | (i not in i_test and i not in i_val)])
755 |
756 | print('DEFINING BOUNDS')
757 | if args.no_ndc:
758 | near = np.ndarray.min(bds) * .9
759 | far = np.ndarray.max(bds) * 1.
760 |
761 | else:
762 | near = 0.
763 | far = 1.
764 | print('NEAR FAR', near, far)
765 |
766 | elif args.dataset_type == 'blender':
767 | testskip = args.testskip
768 | faketestskip = args.faketestskip
769 | if jt.mpi and jt.mpi.local_rank()!=0:
770 | testskip = faketestskip
771 | faketestskip = 1
772 | if args.do_intrinsic:
773 | images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, testskip, args.blender_factor, True)
774 | else:
775 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, testskip, args.blender_factor)
776 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
777 | i_train, i_val, i_test = i_split
778 | i_test_tot = i_test
779 | i_test = i_test[::faketestskip]
780 |
781 | near = args.near
782 | far = args.far
783 | print("near", near)
784 | print("far", far)
785 |
786 | if args.white_bkgd:
787 | # accs = images[...,-1]
788 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
789 | else:
790 | images = images[...,:3]
791 |
792 | elif args.dataset_type == 'deepvoxels':
793 |
794 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
795 | basedir=args.datadir,
796 | testskip=args.testskip)
797 |
798 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
799 | i_train, i_val, i_test = i_split
800 |
801 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
802 | near = hemi_R-1.
803 | far = hemi_R+1.
804 |
805 | else:
806 | print('Unknown dataset type', args.dataset_type, 'exiting')
807 | return
808 |
809 | # Cast intrinsics to right types
810 | H, W, focal = hwf
811 | H, W = int(H), int(W)
812 | hwf = [H, W, focal]
813 |
814 | if args.render_test:
815 | render_poses = np.array(poses[i_test])
816 |
817 | # Create log dir and copy the config file
818 | basedir = args.basedir
819 | expname = args.expname
820 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
821 | f = os.path.join(basedir, expname, 'args.txt')
822 | with open(f, 'w') as file:
823 | for arg in sorted(vars(args)):
824 | attr = getattr(args, arg)
825 | file.write('{} = {}\n'.format(arg, attr))
826 | if args.config is not None:
827 | f = os.path.join(basedir, expname, 'config.txt')
828 | with open(f, 'w') as file:
829 | file.write(open(args.config, 'r').read())
830 |
831 | # Create nerf model
832 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
833 | global_step = start
834 |
835 | bds_dict = {
836 | 'near' : near,
837 | 'far' : far,
838 | }
839 | render_kwargs_train.update(bds_dict)
840 | render_kwargs_test.update(bds_dict)
841 |
842 | # Move testing data to GPU
843 | render_poses = jt.array(render_poses)
844 |
845 | # Short circuit if only rendering out from trained model
846 | if args.render_only:
847 | print('RENDER ONLY')
848 | with jt.no_grad():
849 | if args.render_test:
850 | # render_test switches to test poses
851 | images = images[i_test]
852 | else:
853 | # Default is smoother render_poses path
854 | images = None
855 |
856 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
857 | os.makedirs(testsavedir, exist_ok=True)
858 | print('test poses shape', render_poses.shape)
859 |
860 |             rgbs, _, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor, intrinsic = intrinsic)
861 | print('Done rendering', testsavedir)
862 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
863 |
864 | return
865 |
866 | # Prepare raybatch tensor if batching random rays
867 | accumulation_steps = 2
868 | N_rand = args.N_rand//accumulation_steps
869 | use_batching = not args.no_batching
870 | if use_batching:
871 | # For random ray batching
872 | print('get rays')
873 | rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
874 | print('done, concats')
875 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
876 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
877 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
878 | rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
879 | rays_rgb = rays_rgb.astype(np.float32)
880 | print('shuffle rays')
881 | np.random.shuffle(rays_rgb)
882 |
883 | print('done')
884 | i_batch = 0
885 |
886 | # Move training data to GPU
887 | images = jt.array(images.astype(np.float32))
888 | # accs = jt.array(accs.astype(np.float32))
889 | # a=images[0].copy()
890 | # b=images[1].copy()
891 | # c=images[2].copy()
892 | # print("images0",a.sum())
893 | # print("images1",b.sum())
894 | # print("images2",c.sum())
895 | # print("images0",images[0].numpy().sum())
896 | # print("images1",images[1].numpy().sum())
897 | # print("images2",images[2].numpy().sum())
898 | # print("images0",images[0].sum().numpy())
899 | # print("images1",images[1].sum().numpy())
900 | # print("images2",images[2].sum().numpy())
901 | poses = jt.array(poses)
902 | if use_batching:
903 | rays_rgb = jt.array(rays_rgb)
904 |
905 |
906 | N_iters = 300000*accumulation_steps + 1
907 | split_tree1 = args.step1*accumulation_steps
908 | split_tree2 = args.step2*accumulation_steps
909 | split_tree3 = args.step3*accumulation_steps
910 | # split_tree1 = args.i_img
911 | # split_tree2 = args.i_img*2
912 | print('Begin')
913 | print('TRAIN views are', i_train)
914 | print('TEST views are', i_test)
915 | print('VAL views are', i_val)
916 |
917 | # Summary writers
918 | # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
919 | if not jt.mpi or jt.mpi.local_rank()==0:
920 | date = str(datetime.datetime.now())
921 | date = date[:date.rfind(":")].replace("-", "")\
922 | .replace(":", "")\
923 | .replace(" ", "_")
924 | gpu_idx = os.environ["CUDA_VISIBLE_DEVICES"]
925 | log_dir = os.path.join("./logs", "summaries", f"log_{date}_gpu{gpu_idx}_{args.expname}")
926 | if not os.path.exists(log_dir):
927 | os.makedirs(log_dir)
928 | writer = SummaryWriter(log_dir=log_dir)
929 |
930 | start = start + 1
931 | # if not use_batching and jt.mpi:
932 | # img_i_list = np.random.choice(i_train, N_iters)
933 | # print("before img_i_list",img_i_list.sum())
934 | # jt.mpi.broadcast(img_i_list, root=0)
935 | # print("after img_i_list",img_i_list.sum())
936 | for i in trange(start, N_iters):
937 | # print("i",i,"jt.mpi.local_rank()",jt.mpi.local_rank())
938 | # jt.display_memory_info()
939 | time0 = time.time()
940 |
941 | # Sample random ray batch
942 | if use_batching:
943 | # Random over all images
944 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
945 | batch = jt.transpose(batch, (1, 0, 2))
946 | batch_rays, target_s = batch[:2], batch[2]
947 |
948 | i_batch += N_rand
949 | if i_batch >= rays_rgb.shape[0]:
950 | print("Shuffle data after an epoch!")
951 | rand_idx = jt.randperm(rays_rgb.shape[0])
952 | rays_rgb = rays_rgb[rand_idx]
953 | i_batch = 0
954 |
955 | else:
956 | # Random from one image
957 | # if jt.mpi:
958 | # img_i = img_i_list[i]
959 | # else:
960 | # img_i = np.random.choice(i_train)
961 | img_i = np.random.choice(i_train)
962 | target = images[img_i]#.squeeze(0)
963 | # acc_target = accs[img_i]
964 | pose = poses[img_i, :3,:4]#.squeeze(0)
965 |
966 | if N_rand is not None:
967 | rays_o, rays_d = get_rays(H, W, focal, pose, intrinsic) # (H, W, 3), (H, W, 3)
968 |
969 | if i < args.precrop_iters:
970 | dH = int(H//2 * args.precrop_frac)
971 | dW = int(W//2 * args.precrop_frac)
972 | coords = jt.stack(
973 | jt.meshgrid(
974 | jt.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
975 | jt.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
976 | ), -1)
977 | if i == start:
978 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
979 | else:
980 | coords = jt.stack(jt.meshgrid(jt.linspace(0, H-1, H), jt.linspace(0, W-1, W)), -1) # (H, W, 2)
981 |
982 | coords = jt.reshape(coords, [-1,2]) # (H * W, 2)
983 | # if jt.mpi:
984 | # assert coords.shape[0]%jt.mpi.world_size()==0
985 | # select_inds = np.random.choice(coords.shape[0]//jt.mpi.world_size(), size=[N_rand], replace=False) # (N_rand,)
986 | # select_inds += (coords.shape[0]//jt.mpi.world_size())*jt.mpi.local_rank()
987 | # else:
988 | # select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
989 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
990 | select_coords = coords[select_inds].int() # (N_rand, 2)
991 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
992 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
993 | batch_rays = jt.stack([rays_o, rays_d], 0)
994 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
995 | # target_a = acc_target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 1)
996 |
997 | ##### Core optimization loop #####
998 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
999 | verbose=i < 10, retraw=True,
1000 | **render_kwargs_train)
1001 | # print("rgb",rgb.shape)
1002 | # print("target_s",target_s.shape)
1003 | img_loss = img2mse(rgb, target_s.unsqueeze(0))
1004 | # img_loss = (img2mse(rgb[-1], target_s.unsqueeze(0))*10.0+img2mse(rgb[:-1], target_s.unsqueeze(0)))/11.0
1005 | trans = extras['raw'][...,-1]
1006 | loss = img_loss
1007 | # loss_conf = img2mse(extras['conf_map'], (rgb-target_s)**2)
1008 | # loss_conf = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.)+jt.maximum(extras['conf_map']+0.01,0.).mean()
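            # Confidence loss as computed below: the predicted confidence is trained to
            # upper-bound the squared colour error via max((rgb - target)^2 - conf, 0),
            # plus a small mean(conf) term (weight 0.01) to keep the bound tight; the
            # combined term is added to the total loss with weight 0.1 further down.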
1009 | sloss = jt.maximum(((rgb-target_s)**2).detach().unsqueeze(-2)-extras['conf_map'],0.)
1010 | # loss_conf_loss = jt.sum(extras['weights'].unsqueeze(-1) * sloss, -2).mean()
1011 | loss_conf_loss = sloss.mean()
1012 | # loss_conf_loss = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.).mean()
1013 | loss_conf_mean = jt.maximum(extras['conf_map'], 0.).mean()
1014 | # print("(rgb-target_s)",(rgb-target_s).shape)
1015 | # print("loss_conf_loss",loss_conf_loss.shape)
1016 | # print("loss_conf_mean",loss_conf_mean.shape)
1017 | # loss_conf = jt.maximum((rgb-target_s)**2-extras['conf_map'],0.).sum()
1018 | # print("ssqr",ssqr.shape,"mean",ssqr.mean(),"max",ssqr.max())
1019 | # print("extras['conf_map']",extras['conf_map'].shape,"mean",extras['conf_map'].mean(),"max",extras['conf_map'].max())
1020 | # print("extras['loss_conf']",extras['loss_conf'].shape,"mean",extras['loss_conf'].mean(),"max",extras['loss_conf'].max())
1021 | # psnr = mse2psnr(img_loss)
1022 | # psnr = mse2psnr(img2mse(rgb[-2], target_s.unsqueeze(0)))
1023 | psnr = mse2psnr(img2mse(rgb[-1], target_s.unsqueeze(0)))
1024 | # loss_acc = jt.zeros([1])
1025 | # if args.white_bkgd:
1026 | # # print("acc",acc.shape)
1027 | # # print("target_a",target_a.shape)
1028 | # loss_acc = loss_acc+img2mse(acc, target_a.unsqueeze(0))
1029 |
1030 | if 'rgb0' in extras:
1031 | img_loss0 = img2mse(extras['rgb0'], target_s)
1032 | # img_loss0 = (img2mse(extras['rgb0'][-1], target_s.unsqueeze(0))*10.0+img2mse(extras['rgb0'][:-1], target_s.unsqueeze(0)))/11.0
1033 | loss = loss + img_loss0
1034 | # loss_conf = loss_conf + img2mse(extras['conf_map0'], (extras['rgb0']-target_s)**2)
1035 | # loss_conf = loss_conf + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.)+jt.maximum(extras['conf_map0']+0.01,0.)
1036 | sloss = jt.maximum(((extras['rgb0']-target_s)**2).detach().unsqueeze(-2)-extras['conf_map0'],0.)
1037 | # loss_conf_loss = loss_conf_loss + jt.sum(extras['weights0'].unsqueeze(-1) * sloss, -2).mean()
1038 | loss_conf_loss = loss_conf_loss + sloss.mean()
1039 | # loss_conf_loss = loss_conf_loss + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.).mean()
1040 | loss_conf_mean = loss_conf_mean + jt.maximum(extras['conf_map0'], 0.).mean()
1041 | # loss_conf = loss_conf + jt.maximum((extras['rgb0']-target_s)**2-extras['conf_map0'],0.).sum()
1042 | psnr0 = mse2psnr(img_loss0)
1043 | # if args.white_bkgd:
1044 | # loss_acc = loss_acc+img2mse(extras['acc0'], target_a.unsqueeze(0))
1045 | loss_conf_loss = loss_conf_loss
1046 | # print("loss_conf_loss",loss_conf_loss)
1047 | # print("loss_conf_mean",loss_conf_mean)
1048 | loss_conf = loss_conf_loss+loss_conf_mean*0.01
1049 | loss = loss + loss_conf*0.1
1050 | # loss = loss + loss_acc
1051 |
1052 | jt.sync_all()
1053 | optimizer.backward(loss / accumulation_steps)
1054 | if i % accumulation_steps == 0:
1055 | optimizer.step()
1056 | jt.sync_all()
1057 | # optimizer.step(loss)
1058 |
1059 | # if global_step==10000:
1060 | # if global_step==0:
1061 | # render_kwargs_train['network_fn'].force_out = 15
1062 | # render_kwargs_train['network_fine'].force_out = 15
1063 |
1064 | # NOTE: IMPORTANT!
1065 | ### update learning rate ###
1066 | decay_rate = 0.1
1067 | decay_steps = args.lrate_decay * accumulation_steps * 1000
1068 | sstep = global_step
1069 | if sstep>split_tree3:
1070 | sstep-=split_tree3
1071 | elif sstep>split_tree2:
1072 | sstep-=split_tree2
1073 | elif sstep>split_tree1:
1074 | sstep-=split_tree1
1075 | new_lrate = args.lrate * (decay_rate ** (sstep / decay_steps))
1076 | for param_group in optimizer.param_groups:
1077 | param_group['lr'] = new_lrate
1078 | ################################
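# Worked example of the schedule above (illustrative values, not the config defaults):
# with lrate=5e-4, lrate_decay=250 and accumulation_steps=1, decay_steps = 250*1*1000 =
# 250000, so the learning rate drops by 10x every 250k steps; sstep is counted from the
# most recent split milestone (split_tree1/2/3), so the decay restarts after each growth stage.
#   new_lrate = 5e-4 * (0.1 ** (125000 / 250000))   # ~1.58e-4 halfway through a stage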
1079 |
1080 | dt = time.time()-time0
1081 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
1082 | ##### end #####
1083 |
1084 | # Rest is logging
1085 | if (i+1)%args.i_weights==0:
1086 | if (not jt.mpi or jt.mpi.local_rank()==0):
1087 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
1088 | else:
1089 | path = os.path.join(basedir, expname, 'tmp.tar')
1090 | jt.save({
1091 | 'global_step': global_step,
1092 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
1093 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
1094 | # 'optimizer_state_dict': optimizer.state_dict(),
1095 | }, path)
1096 | print('Saved checkpoints at', path)
1097 |
1098 | # jt.display_memory_info()
1099 | if i%args.i_video==0 and i > 0 or i==split_tree1 or i==split_tree2 or i==split_tree3:
1100 | # import ipdb
1101 | # ipdb.set_trace()
1102 | # Turn on testing mode
1103 | if not jt.mpi or jt.mpi.local_rank()==0:
1104 | with jt.no_grad():
1105 | rgbs, disps, rgbs_log, points = render_path(render_poses, hwf, args.chunk, render_kwargs_test, intrinsic = intrinsic, get_points = True, large_scene = args.large_scene)
1106 | else:
1107 | points = jt.random([1000,3])
1108 | if i==split_tree1 or i==split_tree2 or i==split_tree3:
1109 | do_kmeans(points, render_kwargs_train['network_fn'], render_kwargs_train['network_fine'])
1110 | # jt.display_memory_info()
1111 | if not jt.mpi or jt.mpi.local_rank()==0:
1112 | print('Done, saving', rgbs.shape, disps.shape)
1113 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
1114 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
1115 | imageio.mimwrite(moviebase + 'rgb_log.mp4', to8b(rgbs_log), fps=30, quality=8)
1116 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
1117 | jt.gc()
1118 |
1119 | # if args.use_viewdirs:
1120 | # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
1121 | # with jt.no_grad():
1122 | # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
1123 | # render_kwargs_test['c2w_staticcam'] = None
1124 | # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
1125 | if i%args.i_testset==0 and i > 0 and (not jt.mpi or jt.mpi.local_rank()==0):
1126 | si_test = i_test_tot if i%args.i_tottest==0 else i_test
1127 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
1128 | os.makedirs(testsavedir, exist_ok=True)
1129 | print('test poses shape', poses[si_test].shape)
1130 | with jt.no_grad():
1131 | rgbs, disps, rgbs_log = render_path(jt.array(poses[si_test]), hwf, args.chunk, render_kwargs_test, gt_imgs=images[si_test], savedir=testsavedir, intrinsic = intrinsic, log_path = os.path.join(basedir, expname, 'outnum_{:06d}.txt'.format(i)))
1132 | tars = images[si_test]
1133 | testpsnr = mse2psnr(img2mse(jt.array(rgbs), tars)).item()
1134 | if not jt.mpi or jt.mpi.local_rank()==0:
1135 | writer.add_scalar('test/psnr_tot', testpsnr, global_step)
1136 | print('Saved test set')
1137 |
1138 |
1139 |
1140 |
1141 | if i%args.i_print==0:
1142 | tqdm.write(f"[TRAIN] Iter: {i} expname: {args.expname} Loss: {loss.item()} LossConf: {loss_conf.item()} PSNR: {psnr.item()}")
1143 | a=psnr0.item()
1144 | # print("before:",jt.mpi.local_rank())
1145 | if not jt.mpi or jt.mpi.local_rank()==0:
1146 | writer.add_scalar("train/loss", loss.item(), global_step)
1147 | # writer.add_scalar("train/loss_acc", loss_acc.item(), global_step)
1148 | writer.add_scalar("train/loss_conf", loss_conf.item(), global_step)
1149 | writer.add_scalar("train/PSNR", psnr.item(), global_step)
1150 | writer.add_scalar("lr/lr", new_lrate, global_step)
1151 | # writer.add_histogram('tran', trans.numpy(), global_step)
1152 | if args.N_importance > 0:
1153 | writer.add_scalar("train/PSNR0", psnr0.item(), global_step)
1154 | # print("done:",jt.mpi.local_rank())
1155 | # print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
1156 | # print('iter time {:.05f}'.format(dt))
1157 |
1158 | # with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
1159 | # tf.contrib.summary.scalar('loss', loss)
1160 | # tf.contrib.summary.scalar('psnr', psnr)
1161 | # tf.contrib.summary.histogram('tran', trans)
1162 | # if args.N_importance > 0:
1163 | # tf.contrib.summary.scalar('psnr0', psnr0)
1164 |
1165 |
1166 | if i%args.i_img==0:
1167 |
1168 | # Log a rendered validation view to Tensorboard
1169 | img_i=np.random.choice(i_val)
1170 | target = images[img_i]
1171 | pose = poses[img_i, :3,:4]
1172 | with jt.no_grad():
1173 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, intrinsic=intrinsic,
1174 | **render_kwargs_test)
1175 | sout_num = list(extras['outnum0'])
1176 | log = "i_img out_num0:"
1177 | for k in range(len(sout_num)):
1178 | log += str(k)+": %d; " % int(sout_num[k])
1179 | print(log)
1180 | sout_num = list(extras['outnum'])
1181 | log = "i_img out_num:"
1182 | for k in range(len(sout_num)):
1183 | log += str(k)+": %d; " % int(sout_num[k])
1184 | print(log)
1185 | psnr = mse2psnr(img2mse(rgb[-1], target))
1186 | # psnr = mse2psnr(img2mse(rgb[-2], target))
1187 | rgb_log = rgb[-2].numpy()
1188 | rgb = rgb[-1].numpy()
1189 | rgb0 = extras['rgb0'][-1].numpy()
1190 | # rgb = extras['rgb0'][-1].numpy()
1191 | disp = disp[-1].numpy()
1192 | acc = acc[-1].numpy()
1193 | acc0 = extras['acc0'][-1].numpy()
1194 | a = target.numpy()
1195 | a = psnr.item()
1196 |
1197 | if not jt.mpi or jt.mpi.local_rank()==0:
1198 | writer.add_image('test/rgb', to8b(rgb), global_step, dataformats="HWC")
1199 | writer.add_image('test/rgb0', to8b(rgb0), global_step, dataformats="HWC")
1200 | writer.add_image('log/rgb', to8b(rgb_log), global_step, dataformats="HWC")
1201 | # writer.add_image('test/disp', disp[...,np.newaxis], global_step, dataformats="HWC")
1202 | # writer.add_image('test/acc', acc[...,np.newaxis], global_step, dataformats="HWC")
1203 | # writer.add_image('test/acc0', acc0[...,np.newaxis], global_step, dataformats="HWC")
1204 | writer.add_image('test/target', target.numpy(), global_step, dataformats="HWC")
1205 |
1206 | writer.add_scalar('test/psnr', psnr.item(), global_step)
1207 | jt.gc()
1208 | # jt.display_memory_info()
1209 | # jt.display_max_memory_info()
1210 | # a=images[0].numpy()
1211 | # b=images[1].numpy()
1212 | # c=images[2].numpy()
1213 | # print("images0",a.shape,a.sum())
1214 | # print("images1",b.shape,b.sum())
1215 | # print("images2",c.shape,c.sum())
1216 | # writer.add_image('test/rgb_target0', a, global_step, dataformats="HWC")
1217 | # writer.add_image('test/rgb_target1', b, global_step, dataformats="HWC")
1218 | # writer.add_image('test/rgb_target2', c, global_step, dataformats="HWC")
1219 |
1220 |
1221 | # if args.N_importance > 0:
1222 |
1223 | # with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
1224 | # tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
1225 | # tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
1226 | # tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
1227 |
1228 |
1229 | global_step += 1
1230 |
1231 |
1232 | if __name__=='__main__':
1233 | train()
1234 |
--------------------------------------------------------------------------------
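The confidence terms assembled in the optimization loop above reduce to a hinge on the detached per-pixel squared error plus a small magnitude penalty, weighted into the total loss with factor 0.1. Below is a minimal NumPy sketch of that penalty with simplified shapes and illustrative inputs; only the 0.01 and 0.1 weights are taken from the loop, all names and shapes here are illustrative.

    import numpy as np

    def conf_penalty(rgb, target, conf_map):
        # Hinge term: the predicted confidence should upper-bound the (detached) squared error.
        err = (rgb - target) ** 2
        loss_conf_loss = np.maximum(err - conf_map, 0.).mean()
        # Regularizer: keep the confidence map from growing without bound.
        loss_conf_mean = np.maximum(conf_map, 0.).mean()
        return loss_conf_loss + 0.01 * loss_conf_mean

    rgb = np.random.rand(8, 3)      # rendered colours for a mini-batch of rays
    target = np.random.rand(8, 3)   # ground-truth colours
    conf = np.zeros((8, 3))         # predicted confidence map (simplified shape)
    extra = 0.1 * conf_penalty(rgb, target, conf)   # added on top of the image MSE loss
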
/run_nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jittor as jt
3 | from jittor import nn
4 |
5 |
6 |
7 | # Misc
8 | img2mse = lambda x, y : jt.mean((x - y) ** 2)
9 | mse2psnr = lambda x : -10. * jt.log(x) / jt.log(jt.array(np.array([10.])))
10 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
11 |
12 |
13 | # Positional encoding (section 5.1)
14 | class Embedder:
15 | def __init__(self, **kwargs):
16 | self.kwargs = kwargs
17 | self.create_embedding_fn()
18 |
19 | def create_embedding_fn(self):
20 | embed_fns = []
21 | d = self.kwargs['input_dims']
22 | out_dim = 0
23 | if self.kwargs['include_input']:
24 | embed_fns.append(lambda x : x)
25 | out_dim += d
26 |
27 | max_freq = self.kwargs['max_freq_log2']
28 | N_freqs = self.kwargs['num_freqs']
29 |
30 | if self.kwargs['log_sampling']:
31 | freq_bands = 2.**jt.linspace(0., max_freq, steps=N_freqs)
32 | else:
33 | freq_bands = jt.linspace(2.**0., 2.**max_freq, steps=N_freqs)
34 |
35 | for freq in freq_bands:
36 | for p_fn in self.kwargs['periodic_fns']:
37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
38 | out_dim += d
39 |
40 | self.embed_fns = embed_fns
41 | self.out_dim = out_dim
42 |
43 | def embed(self, inputs):
44 | return jt.concat([fn(inputs) for fn in self.embed_fns], -1)
45 |
46 |
47 | def get_embedder(multires, i=0):
48 | if i == -1:
49 | return nn.Identity(), 3
50 |
51 | embed_kwargs = {
52 | 'include_input' : True,
53 | 'input_dims' : 3,
54 | 'max_freq_log2' : multires-1,
55 | 'num_freqs' : multires,
56 | 'log_sampling' : True,
57 | 'periodic_fns' : [jt.sin, jt.cos],
58 | }
59 |
60 | embedder_obj = Embedder(**embed_kwargs)
61 | embed = lambda x, eo=embedder_obj : eo.embed(x)
62 | return embed, embedder_obj.out_dim
63 |
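# Usage sketch for get_embedder (illustrative, not called in this file): with multires=10
# the embedding keeps the raw 3-D input and adds sin/cos at 10 frequencies per axis, so
# out_dim = 3 + 3*2*10 = 63.
#   embed_fn, input_ch = get_embedder(10)
#   y = embed_fn(jt.random([1024, 3]))   # y.shape == [1024, 63]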
64 | #tree class
65 | class Node():
66 | def __init__(self, anchors, sons, linears):
67 | self.anchors = anchors
68 | self.sons = sons
69 | self.linears = linears
70 |
71 | # Model
72 | class OutputNet(nn.Module):
73 | def __init__(self, W, input_ch_views):
74 | """
75 | """
76 | super(OutputNet, self).__init__()
77 |
78 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
79 | self.feature_linear = nn.Linear(W, W)
80 | self.alpha_linear = nn.Linear(W, 1)
81 | self.rgb_linear = nn.Linear(W//2, 3)
82 |
83 | def execute(self, h, input_views):
84 | alpha = self.alpha_linear(h)
85 | feature = self.feature_linear(h)
86 | h = jt.concat([feature, input_views], -1)
87 |
88 | for i, l in enumerate(self.views_linears):
89 | h = self.views_linears[i](h)
90 | h = jt.nn.relu(h)
91 |
92 | rgb = self.rgb_linear(h)
93 | outputs = jt.concat([rgb, alpha], -1)
94 | return outputs
95 |
96 | # Model
97 | class NeRF(nn.Module):
98 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False, head_num=8, threshold=3e-2):
99 | """
100 | """
101 | super(NeRF, self).__init__()
102 | D=12
103 | self.D = D
104 | self.W = W
105 | self.input_ch = input_ch
106 | self.input_ch_views = input_ch_views
107 | skips=[]
108 | # skips = [2,4,6]
109 | self.skips = skips
110 | # self.ress = [1,3,7,11]
111 | # self.ress = []
112 | # self.outs = [1,3,7,11]
113 | self.force_out = [0]
114 | # self.force_out = [7,8,9,10,11,12,13,14]
115 | self.use_viewdirs = use_viewdirs
116 | assert self.use_viewdirs==True
117 |
118 | self.threshold = threshold
119 |
120 | self.build_tree(head_num)
121 | self.pts_linears = nn.ModuleList(
122 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skip_linear else nn.Linear(W + input_ch, W) for i in range(self.linear_num-1)])
123 | # for i in range(self.nlinear_list[0]+1,len(self.pts_linears)):
124 | # jt.init.constant_(self.pts_linears[i].weight, 0.0)
125 | # jt.init.constant_(self.pts_linears[i].bias, 0.0)
126 | # self.confidence_linears = nn.ModuleList([nn.Linear(W+ input_ch, 1) for i in range(D)])
127 | self.confidence_linears = nn.ModuleList([nn.Linear(W, 1) for i in range(self.node_num)])
128 | # self.outnet = OutputNet(W, input_ch_views)
129 | self.outnet = nn.ModuleList([OutputNet(W, input_ch_views) for i in range(self.node_num)])
130 |
131 | def get_anchor(self, i):
132 | return getattr(self, i)
133 |
134 | def build_tree(self, head_num):
135 | # self.son_list = [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16],[17,18,19,20],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]]
136 | # self.nlinear_list = [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2]
137 |
138 | # self.son_list = [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[],[],[],[],[],[],[],[]]
139 | # self.nlinear_list = [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2]
140 |
141 | # self.son_list = [[1,2],[3,4],[5,6],[],[],[],[]]
142 | # self.nlinear_list = [2,2,2,4,4,4,4]
143 | # self.skip_linear = [6,10,14,18]
144 |
145 | if head_num == 1:
146 | # 1 head
147 | self.son_list = [[1],[2],[3],[]]
148 | self.nlinear_list = [2,2,4,4]
149 | self.skip_linear = [4]
150 | elif head_num == 4:
151 | # 4 head
152 | self.son_list = [[1,2],[3,4],[5,6],[7],[8],[9],[10],[],[],[],[]]
153 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4]
154 | self.skip_linear = [6,10,14,18]
155 | elif head_num == 8:
156 | # 8 head
157 | self.son_list = [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[],[],[],[],[],[],[],[]]
158 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4,4,4,4,4]
159 | self.skip_linear = [6,10,14,18]
160 | elif head_num == 16:
161 | # 16 head
162 | self.son_list = [[1,2,3,4],[5,6],[7,8],[9,10],[11,12],[13,14],[15,16],[17,18],[19,20],[21,22],[23,24],[25,26],[27,28],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]]
163 | self.nlinear_list = [2,2,2,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4]
164 | self.skip_linear = [10,14,18,22,26,30,34,38]
165 |
166 | # self.anchor_list = [np.array([[-2,0,0],[2,0,0]]).astype(np.float32),
167 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32),
168 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32),
169 | # np.array([[-1,0,0],[0,0,0]]).astype(np.float32),
170 | # np.array([[-1,0,0],[0,0,0]]).astype(np.float32),
171 | # np.array([[0,0,0],[1,0,0]]).astype(np.float32),
172 | # np.array([[0,0,0],[1,0,0]]).astype(np.float32)]
173 | # self.anchor_list = [np.array([[-2,0,0],[2,0,0]]).astype(np.float32),
174 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32),
175 | # np.array([[0,-2,0],[0,2,0]]).astype(np.float32),
176 | # np.array([[0,0,0]]).astype(np.float32),
177 | # np.array([[0,0,0]]).astype(np.float32),
178 | # np.array([[0,0,0]]).astype(np.float32),
179 | # np.array([[0,0,0]]).astype(np.float32)]
180 | assert len(self.son_list) == len(self.nlinear_list)
181 | self.anchor_list = [np.array([[0,0,0]]).astype(np.float32)]*len(self.son_list)
182 | self.node_list = []
183 | self.node_num = len(self.son_list)
184 | self.anchor_num = 0
185 | self.linear_num = 0
186 | for i in range(len(self.son_list)):
187 | son = self.son_list[i]
188 | if len(son)>0:
189 | anchor = "anchor"+str(self.anchor_num)
190 | self.anchor_num += 1
191 | setattr(self, anchor, jt.array(self.anchor_list[i]))
192 | # setattr(self, anchor, jt.random([len(son), 3]))
193 | else:
194 | anchor = None
195 | linear = list(range(self.linear_num, self.linear_num+self.nlinear_list[i]))
196 | self.linear_num += self.nlinear_list[i]
197 | self.node_list.append(Node(anchor, son, linear))
198 |
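# Worked example of the layout built above (head_num=8): son_list describes a 15-node
# binary tree (root 0 -> 1,2; leaves 7..14).  nlinear_list allocates 2+2+2 + 12*4 = 54
# shared pts_linears handed out consecutively, so node 0 owns layers [0,1], node 3 owns
# [6..9], node 4 owns [10..13], and so on.  skip_linear (6, 10, 14, 18) marks the first
# layer of nodes 3-6: after each of those layers the positional encoding is concatenated
# back in, so the following layer is built with input width W + input_ch.  Every node
# additionally gets its own confidence_linears[t] and outnet[t] head.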
199 | def my_concat(self, a, b, dim):
200 | if a is None:
201 | return b
202 | elif b is None:
203 | return a
204 | else:
205 | return jt.concat([a,b],dim)
206 |
207 | def search(self, t, p, h, input_pts, input_views, remain_mask):
208 | node = self.node_list[t]
209 | # print("search t",t,"remain_mask",remain_mask.sum())
210 | identity = h
211 | for i in range(len(node.linears)):
212 | # print("i",i)
213 | # print("h",h.shape)
214 | # print(self.pts_linears[node.linears[i]])
215 | # print("len",len(self.pts_linears),"node.linears[i]",node.linears[i],"node",node)
216 | # print("t",t,"i",i,"h",h.shape,"line",self.pts_linears[node.linears[i]].weight.shape)
217 | h = self.pts_linears[node.linears[i]](h)
218 | if t==0 and i==0:
219 | identity = h
220 | if i==len(node.linears)-1:
221 | h = h+identity
222 | h = jt.nn.relu(h)
223 | if node.linears[i] in self.skip_linear:
224 | h = jt.concat([input_pts, h], -1)
225 |
226 | confidence = self.confidence_linears[t](h).view(-1)
227 | # threshold = 0.0
228 | threshold = self.threshold
229 | # threshold = -1e10
230 | output = self.outnet[t](h, input_views)
231 | # output = self.outnet[0](h, input_views)
232 | out_num = np.zeros((self.node_num))
233 |
234 | if len(node.sons)>0 and (not t in self.force_out):
235 | son_outputs = None
236 | son_outputs_fuse = None
237 | son_confs = None
238 | son_confs_fuse = None
239 | idxs = None
240 | idxs_fuse = None
241 |
242 | anchor = self.get_anchor(node.anchors)
243 | dis = (anchor.unsqueeze(0)-p.unsqueeze(1)).sqr().sum(-1).sqrt()
244 | min_idx, _ = jt.argmin(dis,-1)
245 | for i in range(len(node.sons)):
246 | # print("t",t,"i",i)
247 | next_t = node.sons[i]
248 | sidx = jt.arange(0,p.shape[0])
249 | # print("min_idx==i",min_idx==i)
250 | sidx = sidx[min_idx==i]
251 | # print("sidx",sidx)
252 | next_p = p[sidx]
253 | next_h = h[sidx]
254 | next_input_pts = input_pts[sidx]
255 | next_input_views = input_views[sidx]
256 | next_remain_mask = remain_mask[sidx].copy()
257 | next_conf = confidence[sidx]
258 | next_remain_mask[threshold>next_conf] = 0
259 | sidx_fuse = sidx[next_remain_mask==1]
260 |
261 | # print("start t",t,"i",i,"next_t",next_t)
262 | next_outputs, next_outputs_fuse, next_confs, next_confs_fuse, next_out_num = self.search(next_t, next_p, next_h, next_input_pts, next_input_views, next_remain_mask)
263 | out_num = out_num+next_out_num
264 | # print("search", t, next_t)
265 | # print("next_outputs",next_outputs.shape)
266 | # print("next_outputs_fuse",next_outputs_fuse.shape)
267 | # print("next_confs",next_confs.shape)
268 | # print("next_confs_fuse",next_confs_fuse.shape)
269 |
270 | # print("end t",t,"i",i,"next_t",next_t)
271 | son_outputs = self.my_concat(son_outputs, next_outputs, 1)
272 | son_outputs_fuse = self.my_concat(son_outputs_fuse, next_outputs_fuse, 1)
273 | son_confs = self.my_concat(son_confs, next_confs, 1)
274 | son_confs_fuse = self.my_concat(son_confs_fuse, next_confs_fuse, 1)
275 | idxs = self.my_concat(idxs, sidx, 0)
276 | idxs_fuse = self.my_concat(idxs_fuse, sidx_fuse, 0)
277 |
278 | # print("t",t)
279 | son_outputs_save = jt.zeros(son_outputs.shape)
280 | son_outputs_save[:,idxs] = son_outputs
281 | son_outputs_save = jt.concat([output.unsqueeze(0), son_outputs_save], 0)
282 | son_confs_save = jt.zeros(son_confs.shape)
283 | son_confs_save[:,idxs] = son_confs
284 | son_confs_save = jt.concat([confidence.unsqueeze(1).unsqueeze(0), son_confs_save], 0)
285 |
286 | out_remain_mask = remain_mask.copy()
287 | out_remain_mask[threshold<=confidence] = 0
288 | idx_out = jt.arange(0,out_remain_mask.shape[0])[out_remain_mask==1]
289 | outputs_out = output[idx_out].unsqueeze(0)
290 | out_num[t] = outputs_out.shape[1]
291 | confs_out = confidence[idx_out].unsqueeze(1).unsqueeze(0)
292 | outputs_out = jt.concat([outputs_out, son_outputs_fuse], 1)
293 | confs_out = jt.concat([confs_out, son_confs_fuse], 1)
294 | idx_out = jt.concat([idx_out, idxs_fuse], 0)
295 |
296 | outputs_out_save = jt.zeros(output.unsqueeze(0).shape)
297 | outputs_out_save[:, idx_out] = outputs_out
298 | outputs_out_save = outputs_out_save[:, remain_mask==1]
299 | confs_out_save = jt.zeros(confidence.unsqueeze(1).unsqueeze(0).shape)
300 | confs_out_save[:, idx_out] = confs_out
301 | confs_out_save = confs_out_save[:, remain_mask==1]
302 |
303 | return son_outputs_save, outputs_out_save, son_confs_save, confs_out_save, out_num
304 | else:
305 | outputs_save = output.unsqueeze(0)
306 | outputs_save_log = outputs_save.copy()
307 | confs_save = confidence.unsqueeze(1).unsqueeze(0)
308 | # print("outputs_save",outputs_save.shape)
309 | # print("remain_mask",remain_mask.shape)
310 | # print("remain_mask==1",remain_mask==1)
311 | outputs_out_save = outputs_save[:, remain_mask==1]
312 | confs_out_save = confs_save[:, remain_mask==1]
313 | out_num[t] = outputs_out_save.shape[1]
314 |
315 | remain_mask[threshold<=confidence] = 0
316 | # print("out:", remain_mask.sum().numpy(), "remain:", remain_mask.shape[0]-remain_mask.sum().numpy())
317 |
318 | # print("outputs_out_save",outputs_out_save.shape)
319 | if not self.training:
320 | if t%4==0:
321 | outputs_save_log[..., 0] *= 0.
322 | outputs_save_log[..., 1] *= 0.
323 | elif t%4==1:
324 | outputs_save_log[..., 0] *= 0.
325 | outputs_save_log[..., 2] *= 0.
326 | elif t%4==2:
327 | outputs_save_log[..., 1] *= 0.
328 | outputs_save_log[..., 2] *= 0.
329 | elif t%4==3:
330 | outputs_save_log[..., 0] *= 0.
331 | # elif t==1:
332 | # outputs_out_save[..., 1] *= 0.
333 | # elif t==2:
334 | # outputs_out_save[..., 2] *= 0.
335 | outputs_save = jt.concat([outputs_save, outputs_save_log], 0)
336 | confs_save = jt.concat([confs_save, confs_save], 0)
337 | # print("outputs_out_save out",outputs_out_save.shape)
338 |
339 | # print("outputs_out_save",outputs_out_save.shape)
340 |
341 | return outputs_save, outputs_out_save, confs_save, confs_out_save, out_num
342 |
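# Reading the recursion above: at node t the shared trunk layers owned by t are run on the
# points routed to it, producing a per-point confidence (a predicted error, see the training
# loss) and a candidate output.  Points whose confidence is already below self.threshold
# exit with node t's output; the rest descend to the son whose anchor (updated by do_kmeans
# at the split iterations) is nearest.  Leaf nodes, and nodes listed in force_out, emit
# everything that has not exited yet.  search() returns the per-node outputs stacked for
# training supervision, the fused early-exit outputs, the matching confidences, and out_num,
# the per-node count of exiting points.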
343 | def do_train(self, x, p):
344 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1)
345 | remain_mask = jt.ones(input_pts.shape[0])
346 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask)
347 |
348 | outputs = jt.concat([outputs, outputs_fuse], 0)
349 | confs = jt.concat([confs, confs_fuse], 0)
350 |
351 | return outputs, confs, np.zeros([1])
352 |
353 | def do_eval(self, x, p):
354 | input_pts, input_views = jt.split(x, [self.input_ch, self.input_ch_views], dim=-1)
355 | remain_mask = jt.ones(input_pts.shape[0])
356 | outputs, outputs_fuse, confs, confs_fuse, out_num = self.search(0, p, input_pts, input_pts, input_views, remain_mask)
357 |
358 | log = "out: "
359 | sout_num = list(out_num)
360 | for i in range(len(sout_num)):
361 | log += str(i)+": %d; " % sout_num[i]
362 | # print(log)
363 | sout_num = np.array(sout_num)
364 | outputs = outputs[-1:]
365 |
366 | outputs = jt.concat([outputs, outputs_fuse], 0)
367 | confs = jt.concat([confs, confs_fuse], 0)
368 |
369 | return outputs, confs, sout_num
370 |
371 | def execute(self, x, p, training):
372 | self.training = training
373 | if training:
374 | return self.do_train(x, p)
375 | else:
376 | # return self.do_train(x, p)
377 | return self.do_eval(x, p)
378 |
379 | def load_weights_from_keras(self, weights):
380 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
381 |
382 | # Load pts_linears
383 | for i in range(self.D):
384 | idx_pts_linears = 2 * i
385 | self.pts_linears[i].weight.data = jt.array(np.transpose(weights[idx_pts_linears]))
386 | self.pts_linears[i].bias.data = jt.array(np.transpose(weights[idx_pts_linears+1]))
387 |
388 | # Load feature_linear
389 | idx_feature_linear = 2 * self.D
390 | self.feature_linear.weight.data = jt.array(np.transpose(weights[idx_feature_linear]))
391 | self.feature_linear.bias.data = jt.array(np.transpose(weights[idx_feature_linear+1]))
392 |
393 | # Load views_linears
394 | idx_views_linears = 2 * self.D + 2
395 | self.views_linears[0].weight.data = jt.array(np.transpose(weights[idx_views_linears]))
396 | self.views_linears[0].bias.data = jt.array(np.transpose(weights[idx_views_linears+1]))
397 |
398 | # Load rgb_linear
399 | idx_rgb_linear = 2 * self.D + 4
400 | self.rgb_linear.weight.data = jt.array(np.transpose(weights[idx_rgb_linear]))
401 | self.rgb_linear.bias.data = jt.array(np.transpose(weights[idx_rgb_linear+1]))
402 |
403 | # Load alpha_linear
404 | idx_alpha_linear = 2 * self.D + 6
405 | self.alpha_linear.weight.data = jt.array(np.transpose(weights[idx_alpha_linear]))
406 | self.alpha_linear.bias.data = jt.array(np.transpose(weights[idx_alpha_linear+1]))
407 |
408 |
409 |
410 | # Ray helpers
411 | def get_rays(H, W, focal, c2w, intrinsic = None):
412 | i, j = jt.meshgrid(jt.linspace(0, W-1, W), jt.linspace(0, H-1, H))
413 | i = i.t()
414 | j = j.t()
415 | if intrinsic is None:
416 | dirs = jt.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -jt.ones_like(i)], -1).unsqueeze(-2)
417 | else:
418 | i+=0.5
419 | j+=0.5
420 | dirs = jt.stack([i, j, jt.ones_like(i)], -1).unsqueeze(-2)
421 | dirs = jt.sum(dirs * intrinsic[:3,:3], -1).unsqueeze(-2)
422 | # Rotate ray directions from camera frame to the world frame
423 | rays_d = jt.sum(dirs * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
424 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
425 | rays_o = c2w[:3,-1].expand(rays_d.shape)
426 | return rays_o, rays_d
427 |
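# Usage sketch (illustrative): rays_o and rays_d both come back with shape (H, W, 3); the
# ray through pixel (u, v) is r(t) = rays_o[v, u] + t * rays_d[v, u] in world coordinates.
#   rays_o, rays_d = get_rays(H, W, focal, c2w)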
428 |
429 | def get_rays_np(H, W, focal, c2w):
430 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
431 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
432 | # Rotate ray directions from camera frame to the world frame
433 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
434 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
435 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
436 | return rays_o, rays_d
437 |
438 |
439 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
440 | # Shift ray origins to near plane
441 | t = -(near + rays_o[...,2]) / rays_d[...,2]
442 | rays_o = rays_o + t.unsqueeze(-1) * rays_d
443 |
444 | # Projection
445 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
446 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
447 | o2 = 1. + 2. * near / rays_o[...,2]
448 |
449 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
450 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
451 | d2 = -2. * near / rays_o[...,2]
452 |
453 | rays_o = jt.stack([o0,o1,o2], -1)
454 | rays_d = jt.stack([d0,d1,d2], -1)
455 |
456 | return rays_o, rays_d
457 |
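# Note on the transform above: this is the standard NeRF normalized-device-coordinate
# parameterization for forward-facing (LLFF-style) scenes -- origins are shifted to the
# near plane and rays re-expressed so that depth maps into [-1, 1], i.e. sampling
# linearly in NDC depth corresponds to sampling linearly in disparity.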
458 |
459 | # Hierarchical sampling (section 5.2)
460 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
461 | # Get pdf
462 | weights = weights + 1e-5 # prevent nans
463 | pdf = weights / jt.sum(weights, -1, keepdims=True)
464 | cdf = jt.cumsum(pdf, -1)
465 | cdf = jt.concat([jt.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
466 |
467 | # Take uniform samples
468 | if det:
469 | u = jt.linspace(0., 1., steps=N_samples)
470 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
471 | else:
472 | u = jt.random(list(cdf.shape[:-1]) + [N_samples])
473 |
474 | # Pytest, overwrite u with numpy's fixed random numbers
475 | if pytest:
476 | # np.random.seed(0)
477 | new_shape = list(cdf.shape[:-1]) + [N_samples]
478 | if det:
479 | u = np.linspace(0., 1., N_samples)
480 | u = np.broadcast_to(u, new_shape)
481 | else:
482 | u = np.random.rand(*new_shape)
483 | u = jt.array(u)
484 |
485 | # Invert CDF
486 | # u = u.contiguous()
487 | # inds = searchsorted(cdf, u, side='right')
488 | inds = jt.searchsorted(cdf, u, right=True)
489 | below = jt.maximum(jt.zeros_like(inds-1), inds-1)
490 | above = jt.minimum((cdf.shape[-1]-1) * jt.ones_like(inds), inds)
491 | inds_g = jt.stack([below, above], -1) # (batch, N_samples, 2)
492 |
493 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
494 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
495 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
496 | cdf_g = jt.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
497 | bins_g = jt.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
498 |
499 | denom = (cdf_g[...,1]-cdf_g[...,0])
500 | cond = jt.where(denom<1e-5)
501 | denom[cond] = 1.
502 | t = (u-cdf_g[...,0])/denom
503 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
504 |
505 | return samples
506 |
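# Usage sketch (an illustrative call, mirroring how a fine pass typically resamples along
# each ray from the coarse weights):
#   z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
#   z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.))
# The coarse weights define a piecewise-constant pdf over the depth bins; inverting the
# CDF places the extra samples where the coarse pass found density.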
--------------------------------------------------------------------------------