├── fig
│   └── teaser.png
├── requirements.txt
├── ours.sh
├── Stanford2D3D.sh
├── Matterport3D.sh
├── LICENSE
├── adjust.py
├── confs
│   ├── ours.conf
│   ├── Matterport3D.conf
│   └── Stanford2D3D.conf
├── README.md
├── models
│   ├── embedder.py
│   ├── fields.py
│   ├── dataset_patch.py
│   ├── dataset.py
│   └── renderer.py
└── exp_runner.py
/fig/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WJ-Chang-42/IndoorPanoDepth/HEAD/fig/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh==3.9.8 2 | numpy==1.19.2 3 | pyhocon==0.3.57 4 | icecream==2.1.0 5 | opencv_python==4.5.2.52 6 | tqdm==4.50.2 7 | scipy==1.7.0 8 | PyMCubes==0.1.2 9 | -------------------------------------------------------------------------------- /ours.sh: -------------------------------------------------------------------------------- 1 | numbers=0 2 | d=1.5 3 | name=ours 4 | SCENES="classroom bedroom kitchen livingroom loft" 5 | mkdir ./exp/"$d"_"$name"/ 6 | for scene in $SCENES; do 7 | /opt/conda/bin/python -u exp_runner.py --conf ./confs/"$name".conf --case "$scene" --d "$d" --n "$numbers" --dir ./exp/"$d"_"$name"/ --random 80 | tee ./exp/"$d"_"$name"/"$numbers"_images_"$scene".txt 8 | done 9 | 10 | -------------------------------------------------------------------------------- /Stanford2D3D.sh: -------------------------------------------------------------------------------- 1 | numbers=0 2 | d=1.5 3 | name=Stanford2D3D 4 | SCENES="1_area_5a1 1_area_5b1 5_area_5a1 10_area_61 207_area_41" 5 | mkdir ./exp/"$d"_"$name"/ 6 | for scene in $SCENES; do 7 | /opt/conda/bin/python -u exp_runner.py --conf ./confs/"$name".conf --case "$scene" --d "$d" --n "$numbers" --dir ./exp/"$d"_"$name"/ --random 80 | tee ./exp/"$d"_"$name"/"$numbers"_images_"$scene".txt 8 | done 9 | 10 | -------------------------------------------------------------------------------- /Matterport3D.sh: -------------------------------------------------------------------------------- 1 | numbers=0 2 | d=1.5 3 | name=Matterport3D 4 | SCENES="0_0b217f59904d4bdf85d35da2cab963471 1_0b724f78b3c04feeb3e744945517073d1 0_a2577698031844e7a5982c8ee0fecdeb1 0_9f2deaf4cf954d7aa43ce5dc70e7abbe1 0_7812e14df5e746388ff6cfe8b043950a1 4_0b724f78b3c04feeb3e744945517073d1 2_0b217f59904d4bdf85d35da2cab963471 1_7812e14df5e746388ff6cfe8b043950a1 47_a2577698031844e7a5982c8ee0fecdeb1 45_a2577698031844e7a5982c8ee0fecdeb1" 5 | mkdir ./exp/"$d"_"$name"/ 6 | for scene in $SCENES; do 7 | /opt/conda/bin/python -u exp_runner.py --conf ./confs/"$name".conf --case "$scene" --d "$d" --n "$numbers" --dir ./exp/"$d"_"$name"/ --random 80 | tee ./exp/"$d"_"$name"/"$numbers"_images_"$scene".txt 8 | done 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Peng Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /adjust.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | import os 5 | def linear_adjust(depth,alpha): 6 | mapping = np.ones_like(depth) 7 | bar = depth - alpha 8 | mapping[bar>0] = 1 + 0.4*bar[bar>0] 9 | return mapping[...,None] 10 | 11 | 12 | 13 | ###Config### 14 | data_dir = '***/Matterport3D/' 15 | save_dir = '***/adjust/' 16 | distance=1.5 17 | ##################### 18 | scenes = ['0_0b217f59904d4bdf85d35da2cab963471', '1_0b724f78b3c04feeb3e744945517073d1', '0_a2577698031844e7a5982c8ee0fecdeb1', '0_9f2deaf4cf954d7aa43ce5dc70e7abbe1', '0_7812e14df5e746388ff6cfe8b043950a1', '4_0b724f78b3c04feeb3e744945517073d1',"2_0b217f59904d4bdf85d35da2cab963471" ,"1_7812e14df5e746388ff6cfe8b043950a1","47_a2577698031844e7a5982c8ee0fecdeb1","45_a2577698031844e7a5982c8ee0fecdeb1"] 19 | for scene in scenes: 20 | positions = ['Right','Up', 'Left_Down'] 21 | for position in positions: 22 | gt_depth = np.array(cv2.imread(data_dir+'%s_depth_0_%s_0.0.exr'%(scene,position),cv2.IMREAD_UNCHANGED)[:,:,0]) 23 | img = np.array(cv2.imread(data_dir+'%s_color_0_%s_0.0.png'%(scene,position), cv2.IMREAD_ANYCOLOR))/255 24 | mapping = linear_adjust(gt_depth,distance) 25 | cv2.imwrite(save_dir+'%s_color_0_%s_0.0.png'%(scene,position),(img*mapping).clip(0,1)*255) 26 | os.system('cp %s%s_depth* %s'%(data_dir,scene,save_dir)) -------------------------------------------------------------------------------- /confs/ours.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./data/ours/CASE_NAME/CASE_NAME 11 | } 12 | 13 | train { 14 | learning_rate = 5e-4 15 | learning_rate_alpha = 0.1 16 | end_iter = 100 17 | 18 | batch_size = 512 19 | validate_resolution_level = 1 20 | warm_up_end = 0 21 | anneal_end = 50 22 | use_white_bkgd = False 23 | 24 | save_freq = 10 25 | val_freq = 10 26 | val_mesh_freq = 100000 27 | report_freq = 1 28 | 29 | igr_weight = 0.0 30 | mask_weight = 0.0 31 | } 32 | 33 | model { 34 | color_network { 35 | D = 8, 36 | d_in = 3, 37 | d_in_view = 3, 38 | W = 256, 39 | multires = 10, 40 | multires_view = 4, 41 | output_ch = 4, 42 | skips=[4], 43 | use_viewdirs=True 44 | } 45 | 46 | 47 | sdf_network { 48 | d_out = 257 49 | d_in = 3 50 | d_hidden = 256 51 | n_layers = 8 52 | skip_in = [4] 53 | multires = 6 54 | bias = 2.5 55 | scale = 1.0 56 | inside_outside = True 57 | geometric_init = True 58 | weight_norm = False 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | renderer { 66 | n_samples = 64 67 | n_importance = 64 68 | perturb = 1.0 69 | } 70 | } 71 | 
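In these configs, `CASE_NAME` is a placeholder that `exp_runner.py` substitutes with the scene name passed via `--case` before parsing the file with pyhocon. A minimal sketch of that loading step (mirroring the logic in `exp_runner.py`; the `load_conf` helper and the example scene name are only illustrative):

```python
# Sketch: how the HOCON configs above are consumed (mirrors exp_runner.py).
from pyhocon import ConfigFactory

def load_conf(conf_path, case):
    # Read the raw config text and replace the CASE_NAME placeholder
    # with the scene name given on the command line (--case).
    with open(conf_path) as f:
        conf_text = f.read().replace('CASE_NAME', case)
    return ConfigFactory.parse_string(conf_text)

# Example: dataset.data_dir resolves to ./data/Stanford2D3D/1_area_5a1
conf = load_conf('./confs/Stanford2D3D.conf', '1_area_5a1')
print(conf.get_string('dataset.data_dir'))
```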
-------------------------------------------------------------------------------- /confs/Matterport3D.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./data/Matterport3D/CASE_NAME 11 | } 12 | 13 | train { 14 | learning_rate = 5e-4 15 | learning_rate_alpha = 0.1 16 | end_iter = 100 17 | 18 | batch_size = 512 19 | validate_resolution_level = 1 20 | warm_up_end = 0 21 | anneal_end = 50 22 | use_white_bkgd = False 23 | 24 | save_freq = 10 25 | val_freq = 10 26 | val_mesh_freq = 100000 27 | report_freq = 1 28 | 29 | igr_weight = 0.0 30 | mask_weight = 0.0 31 | } 32 | 33 | model { 34 | color_network { 35 | D = 8, 36 | d_in = 3, 37 | d_in_view = 3, 38 | W = 256, 39 | multires = 10, 40 | multires_view = 4, 41 | output_ch = 4, 42 | skips=[4], 43 | use_viewdirs=True 44 | } 45 | 46 | 47 | sdf_network { 48 | d_out = 257 49 | d_in = 3 50 | d_hidden = 256 51 | n_layers = 8 52 | skip_in = [4] 53 | multires = 6 54 | bias = 2.5 55 | scale = 1.0 56 | inside_outside = True 57 | geometric_init = True 58 | weight_norm = False 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | renderer { 66 | n_samples = 64 67 | n_importance = 64 68 | perturb = 1.0 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /confs/Stanford2D3D.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./data/Stanford2D3D/CASE_NAME 11 | 12 | } 13 | 14 | train { 15 | learning_rate = 5e-4 16 | learning_rate_alpha = 0.1 17 | end_iter = 100 18 | 19 | batch_size = 512 20 | validate_resolution_level = 1 21 | warm_up_end = 0 22 | anneal_end = 50 23 | use_white_bkgd = False 24 | 25 | save_freq = 10 26 | val_freq = 10 27 | val_mesh_freq = 100000 28 | report_freq = 1 29 | 30 | igr_weight = 0.0 31 | mask_weight = 0.0 32 | } 33 | 34 | model { 35 | color_network { 36 | D = 8, 37 | d_in = 3, 38 | d_in_view = 3, 39 | W = 256, 40 | multires = 10, 41 | multires_view = 4, 42 | output_ch = 4, 43 | skips=[4], 44 | use_viewdirs=True 45 | } 46 | 47 | 48 | sdf_network { 49 | d_out = 257 50 | d_in = 3 51 | d_hidden = 256 52 | n_layers = 8 53 | skip_in = [4] 54 | multires = 6 55 | bias = 2.5 56 | scale = 1.0 57 | inside_outside = True 58 | geometric_init = True 59 | weight_norm = False 60 | } 61 | 62 | variance_network { 63 | init_val = 0.3 64 | } 65 | 66 | renderer { 67 | n_samples = 64 68 | n_importance = 64 69 | perturb = 1.0 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IndoorPanoDepth 2 | We present a novel neural representation based method for depth estimation from a few panoramic images of different views. 3 | ![](./fig/teaser.png) 4 | This is the official repo for the implementation of [Depth Estimation from Indoor Panoramas with Neural Scene Representation (CVPR'2023)](https://openaccess.thecvf.com/content/CVPR2023/html/Chang_Depth_Estimation_From_Indoor_Panoramas_With_Neural_Scene_Representation_CVPR_2023_paper.html). 5 | 6 | ## Usage 7 | For the Matterport3D and Stanford2D3D datasets, we adopt the rerendered version from [3D60](https://vcl3d.github.io/3D60/). 
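For reference, the data loaders in `models/dataset.py` and `models/dataset_patch.py` read three fixed RGB-D panorama pairs per scene, named with the scene prefix as follows (the `Left_Down` view is placed at the origin); additional numbered views (`_color_0_01_0.0.png`, ...) are only loaded when extra images are requested via `--n`:
```
<scene>_color_0_Left_Down_0.0.png   <scene>_depth_0_Left_Down_0.0.exr
<scene>_color_0_Right_0.0.png       <scene>_depth_0_Right_0.0.exr
<scene>_color_0_Up_0.0.png          <scene>_depth_0_Up_0.0.exr
```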
The brightness-adjusted dataset can be generated with 'adjust.py'. 8 | ### Matterport3D 9 | Copy the scenes to ./data/Matterport3D 10 | ``` 11 | SCENES="0_0b217f59904d4bdf85d35da2cab963471 1_0b724f78b3c04feeb3e744945517073d1 0_a2577698031844e7a5982c8ee0fecdeb1 0_9f2deaf4cf954d7aa43ce5dc70e7abbe1 0_7812e14df5e746388ff6cfe8b043950a1 4_0b724f78b3c04feeb3e744945517073d1 2_0b217f59904d4bdf85d35da2cab963471 1_7812e14df5e746388ff6cfe8b043950a1 47_a2577698031844e7a5982c8ee0fecdeb1 45_a2577698031844e7a5982c8ee0fecdeb1" 12 | for scene in $SCENES; do 13 | cp 3D60/Matterport3D/"$scene"_* ./data/Matterport3D 14 | done 15 | ``` 16 | 17 | Then directly run the following command. 18 | ``` 19 | sh Matterport3D.sh 20 | ``` 21 | --- 22 | ### Stanford2D3D 23 | Copy the scenes to ./data/Stanford2D3D 24 | ``` 25 | SCENES="1_area_5a1 1_area_5b1 5_area_5a1 10_area_61 207_area_41" 26 | for scene in $SCENES; do 27 | cp 3D60/Stanford2D3D/"$scene"_* ./data/Stanford2D3D 28 | done 29 | ``` 30 | Then directly run the following command. 31 | ``` 32 | sh Stanford2D3D.sh 33 | ``` 34 | --- 35 | ### Our dataset 36 | First, download our dataset from https://1drv.ms/u/s!AmmYGRQ4ky-T1N0Bv8x7Oq_qQiKmNg?e=xUlxHR. Then, unzip it into the './data' folder as follows: 37 | ``` 38 | |-- code 39 | |-- data 40 | |-- Matterport3D 41 | |-- Stanford2D3D 42 | |-- ours 43 | |-- bedroom 44 | ... 45 | ``` 46 | Finally, run the command 47 | ``` 48 | sh ours.sh 49 | ``` 50 | 51 | 52 | ## Acknowledgement 53 | The main framework is borrowed from [NeuS](https://github.com/Totoro97/NeuS). The 3D models used for rendering the dataset are from [Flavio Della Tommasa](https://download.blender.org/demo/cycles/flat-archiviz.blend), [Christophe Seux](https://download.blender.org/demo/test/classroom.zip) and [Tadeusz](https://blenderartists.org/t/free-scene-loft-interior-design/1200857). 54 | Thanks for these great works. 55 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import ipdb 5 | 6 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 7 | class Embedder: 8 | def __init__(self, **kwargs): 9 | self.kwargs = kwargs 10 | self.create_embedding_fn() 11 | 12 | def create_embedding_fn(self): 13 | embed_fns = [] 14 | d = self.kwargs['input_dims'] 15 | #d = 4 16 | out_dim = 0 17 | if self.kwargs['include_input']: 18 | embed_fns.append(lambda x: x) 19 | out_dim += d 20 | 21 | max_freq = self.kwargs['max_freq_log2'] 22 | N_freqs = self.kwargs['num_freqs'] 23 | 24 | if self.kwargs['log_sampling']: 25 | freq_bands = 2.
** torch.linspace(0., max_freq, N_freqs) 26 | else: 27 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 28 | 29 | for freq in freq_bands: 30 | for p_fn in self.kwargs['periodic_fns']: 31 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 32 | out_dim += d 33 | 34 | self.embed_fns = embed_fns 35 | self.out_dim = out_dim 36 | 37 | def embed(self, inputs): 38 | 39 | if False: 40 | dis_to_center = torch.linalg.norm(inputs, ord=2, dim=-1) 41 | #phi_ = torch.arcsin((inputs[...,1]+10e-8)/(dis_to_center+10e-8)) 42 | phi_ = torch.arctan(inputs[...,1]/(torch.sqrt(inputs[...,0]**2 + inputs[...,2]**2)+10e-10)) 43 | theta_ = torch.arctan(inputs[...,0]/((inputs[...,2]+10e-10))) 44 | theta_[inputs[...,2]<0] += np.pi 45 | theta_[theta_>np.pi] -= 2*np.pi 46 | sphere = torch.cat([theta_.unsqueeze(-1),phi_.unsqueeze(-1),1/(dis_to_center.unsqueeze(-1)+ 10e-10)],-1) 47 | #print(torch.abs(dis_to_center*torch.cos(phi_)*torch.cos(theta_) - inputs[...,2]).max(), theta_.max(),phi_.max()) 48 | results = torch.cat([inputs]+[fn(sphere) for fn in self.embed_fns], -1) 49 | # if torch.sum(torch.isnan(results)) > 0: 50 | # torch.sum(torch.isnan(phi_)) 51 | # ipdb.set_trace() 52 | if False: 53 | #ipdb.set_trace() 54 | dis_to_center = torch.linalg.norm(inputs, ord=2, dim=-1, keepdim=True) 55 | pts_color = torch.cat([inputs /dis_to_center, 1/dis_to_center], dim=-1) 56 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 57 | 58 | 59 | def get_embedder(multires, input_dims=3): 60 | embed_kwargs = { 61 | 'include_input': True, 62 | 'input_dims': input_dims, 63 | 'max_freq_log2': multires-1, 64 | 'num_freqs': multires, 65 | 'log_sampling': True, 66 | 'periodic_fns': [torch.sin, torch.cos], 67 | } 68 | 69 | embedder_obj = Embedder(**embed_kwargs) 70 | def embed(x, eo=embedder_obj): return eo.embed(x) 71 | return embed, embedder_obj.out_dim 72 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | import ipdb 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | distance, 12 | d_in, 13 | d_out, 14 | d_hidden, 15 | n_layers, 16 | skip_in=(4,), 17 | multires=0, 18 | bias=0.5, 19 | scale=1, 20 | geometric_init=True, 21 | weight_norm=True, 22 | inside_outside=False): 23 | super(SDFNetwork, self).__init__() 24 | #ipdb.set_trace() 25 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 26 | 27 | self.embed_fn_fine = None 28 | 29 | if multires > 0: 30 | embed_fn, input_ch = get_embedder(multires, input_dims=3) 31 | self.embed_fn_fine = embed_fn 32 | dims[0] = input_ch 33 | 34 | self.num_layers = len(dims) 35 | self.skip_in = skip_in 36 | self.scale = scale 37 | dim_ = 3 38 | for l in range(0, self.num_layers - 1): 39 | if l + 1 in self.skip_in: 40 | out_dim = dims[l + 1] - dims[0] 41 | else: 42 | out_dim = dims[l + 1] 43 | 44 | lin = nn.Linear(dims[l], out_dim) 45 | if distance > 0: 46 | if l == self.num_layers - 2: 47 | if not inside_outside: 48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 49 | torch.nn.init.constant_(lin.bias, -distance) 50 | else: 51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 52 | torch.nn.init.constant_(lin.bias, 
distance) 53 | elif multires > 0 and l == 0: 54 | torch.nn.init.constant_(lin.bias, 0.0) 55 | torch.nn.init.constant_(lin.weight, 0.0) 56 | #torch.nn.init.normal_(lin.weight[:, 0], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 57 | torch.nn.init.normal_(lin.weight[:, 1], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 58 | #torch.nn.init.normal_(lin.weight[:, 2], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 59 | elif multires > 0 and l in self.skip_in: 60 | #ipdb.set_trace() 61 | torch.nn.init.constant_(lin.bias, 0.0) 62 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 63 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - dim_):], 0.0) 64 | #torch.nn.init.constant_(lin.weight[:, -(dims[0] - 1)], 0.0) 65 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 0)], 0.0) 66 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 2)], 0.0) 67 | else: 68 | torch.nn.init.constant_(lin.bias, 0.0) 69 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 70 | # else: 71 | # if l == self.num_layers - 2: 72 | # if not inside_outside: 73 | # torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 74 | # torch.nn.init.constant_(lin.bias, -distance) 75 | # else: 76 | # torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 77 | # torch.nn.init.constant_(lin.bias, distance) 78 | # elif multires > 0 and l == 0: 79 | # torch.nn.init.constant_(lin.bias, 0.0) 80 | # torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 81 | # torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 82 | # elif multires > 0 and l in self.skip_in: 83 | # torch.nn.init.constant_(lin.bias, 0.0) 84 | # torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 85 | # torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 86 | # else: 87 | # torch.nn.init.constant_(lin.bias, 0.0) 88 | # torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 89 | 90 | if weight_norm: 91 | lin = nn.utils.weight_norm(lin) 92 | 93 | setattr(self, "lin" + str(l), lin) 94 | 95 | self.activation = nn.Softplus(beta=100) 96 | 97 | def forward(self, inputs): 98 | #ipdb.set_trace() 99 | dis_to_center = torch.linalg.norm(inputs, ord=2, dim=-1) 100 | # phi_ = torch.arctan(inputs[...,1]/(torch.sqrt(inputs[...,0]**2 + inputs[...,2]**2)+10e-10)) 101 | # theta_ = torch.arctan(inputs[...,0]/((inputs[...,2]+10e-10))) 102 | # theta_[inputs[...,2]<0] += np.pi 103 | # theta_[theta_>np.pi] -= 2*np.pi 104 | # sphere = torch.cat([theta_.unsqueeze(-1),phi_.unsqueeze(-1),1/(dis_to_center.unsqueeze(-1)+ 10e-10)],-1) 105 | if self.embed_fn_fine is not None: 106 | #inputs = torch.cat([inputs,self.embed_fn_fine(torch.cat([inputs /(dis_to_center.unsqueeze(-1)+ 10e-10),1 /(dis_to_center.unsqueeze(-1)+ 10e-10)], dim=-1))],-1) 107 | inputs = self.embed_fn_fine(inputs) 108 | 109 | x = inputs 110 | for l in range(0, self.num_layers - 1): 111 | lin = getattr(self, "lin" + str(l)) 112 | 113 | if l in self.skip_in: 114 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 115 | 116 | x = lin(x) 117 | 118 | if l < self.num_layers - 2: 119 | x = self.activation(x) 120 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 121 | 122 | def sdf(self, x): 123 | return self.forward(x)[:, :1] 124 | 125 | def sdf_hidden_appearance(self, x): 126 | return self.forward(x) 127 | 128 | def gradient(self, x): 129 | x.requires_grad_(True) 130 | result = self.sdf_hidden_appearance(x) 131 | y = result[:, :1] 132 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 133 | gradients 
= torch.autograd.grad( 134 | outputs=y, 135 | inputs=x, 136 | grad_outputs=d_output, 137 | create_graph=True, 138 | retain_graph=True, 139 | only_inputs=True)[0] 140 | return gradients.unsqueeze(1), result 141 | 142 | def gradient_inference(self, x): 143 | x.requires_grad_(True) 144 | result = self.sdf_hidden_appearance(x) 145 | y = result[:, :1] 146 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 147 | gradients = torch.autograd.grad( 148 | outputs=y, 149 | inputs=x, 150 | grad_outputs=d_output, 151 | only_inputs=True)[0] 152 | return gradients.unsqueeze(1), result 153 | 154 | class COLORNetwork(nn.Module): 155 | def __init__(self, 156 | D=8, 157 | W=256, 158 | d_in=3, 159 | d_in_view=3, 160 | multires=0, 161 | multires_view=0, 162 | output_ch=4, 163 | skips=[4], 164 | use_viewdirs=False): 165 | super(COLORNetwork, self).__init__() 166 | self.D = D 167 | self.W = W 168 | self.d_in = d_in 169 | self.d_in_view = d_in_view 170 | self.input_ch = 3 171 | self.input_ch_view = 3 172 | self.embed_fn = None 173 | self.embed_fn_view = None 174 | #ipdb.set_trace() 175 | if multires > 0: 176 | embed_fn, input_ch = get_embedder(multires, input_dims=6) 177 | self.embed_fn = embed_fn 178 | self.input_ch = input_ch 179 | 180 | if multires_view > 0: 181 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=3) 182 | self.embed_fn_view = embed_fn_view 183 | self.input_ch_view = input_ch_view 184 | 185 | self.skips = skips 186 | self.use_viewdirs = use_viewdirs 187 | 188 | self.pts_linears = nn.ModuleList( 189 | [nn.Linear(self.input_ch+256, W)] + 190 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 191 | 192 | ### Implementation according to the official code release 193 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 194 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 195 | 196 | ### Implementation according to the paper 197 | # self.views_linears = nn.ModuleList( 198 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 199 | 200 | if use_viewdirs: 201 | self.feature_linear = nn.Linear(W, W) 202 | #self.alpha_linear = nn.Linear(W, 1) 203 | self.rgb_linear = nn.Linear(W // 2, 3) 204 | else: 205 | self.output_linear = nn.Linear(W, output_ch) 206 | 207 | def forward(self, input_pts, input_views, feature_vectors): 208 | #ipdb.set_trace() 209 | with torch.no_grad(): 210 | dis_to_center = torch.linalg.norm(input_pts, ord=2, dim=-1).clip(10e-4,100) 211 | sin_phi = input_pts[...,1]/dis_to_center 212 | cos_theta_ = input_pts[...,0]/dis_to_center#/torch.sqrt(1 - sin_phi**2 + 10e-6) 213 | cos_theta = cos_theta_ /torch.sqrt(1 - sin_phi**2 + 10e-6) 214 | sin_theta_ = input_pts[...,2]/dis_to_center#/torch.sqrt(1 - sin_phi**2 + 10e-6) 215 | sin_theta = sin_theta_ /torch.sqrt(1 - sin_phi**2 + 10e-6) 216 | 217 | 218 | if self.embed_fn is not None: 219 | #input_pts = self.embed_fn(input_pts/16) 220 | #input_pts = self.embed_fn(sphere) 221 | #input_pts = self.embed_fn(torch.cat([input_pts/(dis_to_center.unsqueeze(-1)+ 10e-10),1/(dis_to_center.unsqueeze(-1)+ 10e-10)],-1)) 222 | input_pts = self.embed_fn(torch.cat([sin_theta_.unsqueeze(-1),cos_theta_.unsqueeze(-1),sin_phi.unsqueeze(-1),1/(dis_to_center.unsqueeze(-1)),sin_theta.unsqueeze(-1),cos_theta.unsqueeze(-1),],-1)) 223 | if self.embed_fn_view is not None: 224 | input_views = self.embed_fn_view(input_views) 225 | 226 | h = torch.cat([input_pts, feature_vectors], dim=-1) 227 | 
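        # Note: at this point `input_pts` is the positional encoding of the spherical quantities
        # computed above (direction ratios, sine of the elevation, inverse distance); it is
        # concatenated with the SDF feature vector and passed through a NeRF-style MLP that
        # re-injects `input_pts` at the skip layers listed in `self.skips`.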
for i, l in enumerate(self.pts_linears): 228 | h = self.pts_linears[i](h) 229 | h = F.relu(h) 230 | if i in self.skips: 231 | h = torch.cat([input_pts, h], -1) 232 | 233 | 234 | feature = self.feature_linear(h) 235 | h = torch.cat([feature, input_views], -1) 236 | 237 | for i, l in enumerate(self.views_linears): 238 | h = self.views_linears[i](h) 239 | h = F.relu(h) 240 | 241 | rgb = self.rgb_linear(h) 242 | return rgb 243 | 244 | class SingleVarianceNetwork(nn.Module): 245 | def __init__(self, init_val): 246 | super(SingleVarianceNetwork, self).__init__() 247 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) 248 | 249 | def forward(self, x): 250 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 251 | -------------------------------------------------------------------------------- /models/dataset_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from icecream import ic 8 | from scipy.spatial.transform import Rotation as Rot 9 | from scipy.spatial.transform import Slerp 10 | import ipdb 11 | import cv2 12 | 13 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 14 | def load_K_Rt_from_P(filename, P=None): 15 | if P is None: 16 | lines = open(filename).read().splitlines() 17 | if len(lines) == 4: 18 | lines = lines[1:] 19 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 20 | P = np.asarray(lines).astype(np.float32).squeeze() 21 | 22 | out = cv.decomposeProjectionMatrix(P) 23 | K = out[0] 24 | R = out[1] 25 | t = out[2] 26 | 27 | K = K / K[2, 2] 28 | intrinsics = np.eye(4) 29 | intrinsics[:3, :3] = K 30 | 31 | pose = np.eye(4, dtype=np.float32) 32 | pose[:3, :3] = R.transpose() 33 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 34 | 35 | return intrinsics, pose 36 | 37 | 38 | class Dataset_patch: 39 | def __init__(self, conf,numbers,patch=4,degree=0): 40 | super(Dataset_patch, self).__init__() 41 | print('Load data: Begin') 42 | 43 | self.device = torch.device('cuda') 44 | self.conf = conf 45 | self.data_dir = conf.get_string('data_dir') 46 | if numbers>0: 47 | self.position = np.load(self.data_dir+'_position.npy') 48 | 49 | images = [] 50 | depths = [] 51 | #ipdb.set_trace() 52 | img = np.array(cv2.imread(self.data_dir+'_color_0_Left_Down_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 53 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 54 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Left_Down_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 55 | images.append(img) 56 | depths.append(depth) 57 | img = np.array(cv2.imread(self.data_dir+'_color_0_Right_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 58 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 59 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Right_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 60 | images.append(img) 61 | depths.append(depth) 62 | img = np.array(cv2.imread(self.data_dir+'_color_0_Up_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 63 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 64 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Up_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 65 | images.append(img) 66 | depths.append(depth) 67 | for i in range(numbers): 68 | img = np.array(cv2.imread(self.data_dir+'_color_0_%02d_0.0.png'%(i+1), cv2.IMREAD_ANYCOLOR)) / 255. 69 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 
70 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_%02d_0.0.exr'%(i+1),cv2.IMREAD_UNCHANGED)[:,:,0]) 71 | images.append(img) 72 | depths.append(depth) 73 | 74 | self.n_images = len(images) 75 | 76 | self.images_np = np.stack(images, axis=0) 77 | self.depths_np = np.stack(depths, axis=0) 78 | 79 | self.images = torch.from_numpy(self.images_np.astype(np.float32)) # [n_images, H, W, 3] 80 | self.depths = torch.from_numpy(self.depths_np.astype(np.float32)) # [n_images, H, W] 81 | self.H, self.W = self.images.shape[1], self.images.shape[2] 82 | self.image_pixels = self.H * self.W 83 | self.patch_size = patch 84 | self.n_img_patches = self.image_pixels 85 | self.n_patches = self.n_img_patches * (self.n_images) 86 | self.patch_row = torch.arange(patch, device='cpu').repeat(patch,1).reshape(-1) 87 | self.patch_col = torch.arange(patch, device='cpu').repeat(patch,1).permute(1,0).reshape(-1) 88 | #ipdb.set_trace() 89 | cen_x = (self.W - 1) / 2.0 90 | cen_y = (self.H - 1) / 2.0 91 | theta = (2 * (np.arange(self.W) - cen_x) / self.W) * np.pi 92 | phi = (2 * (np.arange(self.H) - cen_y) / self.H) * (np.pi / 2) 93 | theta = np.tile(theta[None, :], [self.H, 1]) 94 | phi = np.tile(phi[None, :], [self.W, 1]).T 95 | 96 | x = (np.cos(phi) * np.sin(theta)).reshape([self.H, self.W, 1]) 97 | y = (np.sin(phi)).reshape([self.H, self.W, 1]) 98 | z = (np.cos(phi) * np.cos(theta)).reshape([self.H, self.W, 1]) 99 | directions = np.concatenate([x, y, z], axis=-1).reshape(-1,3) 100 | ellipsoid = np.concatenate([16*x, 16*y, 16*z], axis=-1) 101 | radius = np.linalg.norm(ellipsoid,ord=2,axis=-1) 102 | #phi_ = torch.arcsin(y) 103 | #theta_ = torch.arctan(x/z)[...,0] 104 | #theta_[z[...,0]<0] += torch.pi 105 | #theta_[theta_>np.pi] -= 2*np.pi 106 | ro = degree/180*np.pi 107 | rotation = np.array([[np.cos(ro),np.sin(ro),0],[-np.sin(ro),np.cos(ro),0],[0,0,1]]) 108 | directions = (rotation @ directions.T ).T 109 | directions = directions.reshape([self.H,self.W,-1])[None,:,:,:] 110 | 111 | directions = np.concatenate([directions, directions, directions,directions,directions,directions,directions,directions,directions,directions,directions,directions,directions], axis=0) 112 | #directions = np.concatenate([directions, directions], axis=0) 113 | origins = np.zeros(directions.shape, dtype = directions.dtype) 114 | origins[1] = origins[1] + (rotation @ np.array([0.26,0,0]).T)[None,None,:] 115 | origins[2] = origins[2] + (rotation @ np.array([0,-0.26,0]).T)[None,None,:] 116 | for i in range(numbers): 117 | origins[i+3] = origins[i+3] + np.array([self.position[i][0],-self.position[i][2],self.position[i][1]])[None,None,:] 118 | #ipdb.set_trace() 119 | self.rays_v = torch.from_numpy(directions.astype(np.float32)) 120 | self.rays_o = torch.from_numpy(origins.astype(np.float32)) 121 | self.radius = torch.from_numpy(radius.astype(np.float32)) 122 | self.all_radius = torch.from_numpy(radius[None,:].astype(np.float32).repeat(self.n_images,axis=0)) 123 | self.all_images = self.images.reshape(-1,3) 124 | self.all_depths = self.depths.reshape(-1) 125 | self.all_rays_v = self.rays_v.reshape(-1,3) 126 | self.all_rays_o = self.rays_o.reshape(-1,3) 127 | self.all_radius = self.all_radius.reshape(-1) 128 | 129 | 130 | print('Load data: End') 131 | 132 | def gen_rays_at(self, img_idx, resolution_level=1): 133 | """ 134 | Generate rays at world space from one camera. 
135 | """ 136 | l = resolution_level 137 | 138 | 139 | return self.rays_o[img_idx].cuda(), self.rays_v[img_idx].cuda(), self.depths[img_idx].cuda(), self.radius.cuda() 140 | 141 | 142 | def __len__(self): 143 | #return len(self.all_images) // self.patch_size**2 144 | return 512*256*3 // self.patch_size**2 145 | 146 | 147 | def gen_random_rays_at(self, img_idx, batch_size): 148 | """ 149 | Generate random rays at world space from one camera. 150 | """ 151 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size],device='cpu') 152 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size],device='cpu') 153 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3 154 | rays_v = self.rays_v[img_idx][(pixels_y, pixels_x)] # batch_size, 3 155 | rays_o = self.rays_o[img_idx][(pixels_y, pixels_x)] # batch_size, 3 156 | depth = self.depths[img_idx][(pixels_y, pixels_x)] 157 | #ipdb.set_trace() 158 | far = self.radius[(pixels_y, pixels_x)] 159 | return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) # batch_size, 9 160 | 161 | def __getitem__(self, index): 162 | #ipdb.set_trace() 163 | x = torch.randint(low=0, high=self.W, size=[1],device='cpu')[0].item() 164 | y = torch.randint(low=0, high=self.H, size=[1],device='cpu')[0].item() 165 | pixels_x = x + self.patch_row 166 | pixels_x[pixels_x >= self.W] = pixels_x[pixels_x >= self.W] - self.W 167 | pixels_y = y + self.patch_col 168 | pixels_y[pixels_y >= self.H] = pixels_y[pixels_y >= self.H] - self.H 169 | img_idx = torch.randint(low=0, high=self.n_images, size=[1],device='cpu')[0].item() 170 | 171 | rays_o = self.rays_o[img_idx][(pixels_y, pixels_x)] 172 | rays_v = self.rays_v[img_idx][(pixels_y, pixels_x)] 173 | color = self.images[img_idx][(pixels_y, pixels_x)] 174 | depth = self.depths[img_idx][(pixels_y, pixels_x)] 175 | far = self.radius[(pixels_y, pixels_x)] 176 | return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) 177 | 178 | # def __getitem__(self, index): 179 | # i_patch = torch.randint(high=self.n_patches, size=(1,),device='cpu')[0].item() 180 | # i_img, i_patch = i_patch // self.n_img_patches, i_patch % self.n_img_patches 181 | # row, col = i_patch // (self.W - self.patch_size + 1), i_patch % (self.W - self.patch_size + 1) 182 | # start_idx = i_img * self.W * self.H + row * self.W + col 183 | # idxs = start_idx + torch.cat([torch.arange(self.patch_size,device='cpu') + i * self.W for i in range(self.patch_size)]) 184 | 185 | # rays_o = self.all_rays_o[idxs] 186 | # rays_v = self.all_rays_v[idxs] 187 | # color = self.all_images[idxs] 188 | # depth = self.all_depths[idxs] 189 | # far = self.all_radius[idxs] 190 | # return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) 191 | 192 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 193 | """ 194 | Interpolate pose between two cameras. 
195 | """ 196 | l = resolution_level 197 | tx = torch.linspace(0, self.W - 1, self.W // l) 198 | ty = torch.linspace(0, self.H - 1, self.H // l) 199 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 200 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 201 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 202 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 203 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 204 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 205 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 206 | pose_0 = np.linalg.inv(pose_0) 207 | pose_1 = np.linalg.inv(pose_1) 208 | rot_0 = pose_0[:3, :3] 209 | rot_1 = pose_1[:3, :3] 210 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 211 | key_times = [0, 1] 212 | slerp = Slerp(key_times, rots) 213 | rot = slerp(ratio) 214 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 215 | pose = pose.astype(np.float32) 216 | pose[:3, :3] = rot.as_matrix() 217 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 218 | pose = np.linalg.inv(pose) 219 | rot = torch.from_numpy(pose[:3, :3]).cuda() 220 | trans = torch.from_numpy(pose[:3, 3]).cuda() 221 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 222 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 223 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 224 | 225 | def near_far_from_sphere(self, rays_o, rays_d): 226 | #a = torch.sum(rays_d**2, dim=-1, keepdim=True) 227 | #b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 228 | #mid = 0.5 * (-b) / a 229 | #near = mid - 1.0 230 | #far = mid + 1.0 231 | near = 0 232 | far = 16 233 | return near, far 234 | 235 | def image_at(self, idx, resolution_level): 236 | img = cv.imread(self.images_lis[idx]) 237 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 238 | 239 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from icecream import ic 8 | from scipy.spatial.transform import Rotation as Rot 9 | from scipy.spatial.transform import Slerp 10 | import ipdb 11 | import cv2 12 | 13 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 14 | def load_K_Rt_from_P(filename, P=None): 15 | if P is None: 16 | lines = open(filename).read().splitlines() 17 | if len(lines) == 4: 18 | lines = lines[1:] 19 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 20 | P = np.asarray(lines).astype(np.float32).squeeze() 21 | 22 | out = cv.decomposeProjectionMatrix(P) 23 | K = out[0] 24 | R = out[1] 25 | t = out[2] 26 | 27 | K = K / K[2, 2] 28 | intrinsics = np.eye(4) 29 | intrinsics[:3, :3] = K 30 | 31 | pose = np.eye(4, dtype=np.float32) 32 | pose[:3, :3] = R.transpose() 33 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 34 | 35 | return intrinsics, pose 36 | 37 | 38 | class Dataset: 39 | def __init__(self, conf,numbers,patch=4,degree=0): 40 | super(Dataset, self).__init__() 41 | print('Load data: Begin:Random_sampling') 42 | 43 | self.device = torch.device('cuda') 44 | self.conf = conf 45 | self.data_dir = conf.get_string('data_dir') 46 | if numbers>0: 47 | self.position = 
np.load(self.data_dir+'_position.npy') 48 | 49 | images = [] 50 | depths = [] 51 | #ipdb.set_trace() 52 | img = np.array(cv2.imread(self.data_dir+'_color_0_Left_Down_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 53 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 54 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Left_Down_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 55 | images.append(img) 56 | depths.append(depth) 57 | img = np.array(cv2.imread(self.data_dir+'_color_0_Right_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 58 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 59 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Right_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 60 | images.append(img) 61 | depths.append(depth) 62 | img = np.array(cv2.imread(self.data_dir+'_color_0_Up_0.0.png', cv2.IMREAD_ANYCOLOR)) / 255. 63 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 64 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_Up_0.0.exr',cv2.IMREAD_UNCHANGED)[:,:,0]) 65 | images.append(img) 66 | depths.append(depth) 67 | for i in range(numbers): 68 | img = np.array(cv2.imread(self.data_dir+'_color_0_%02d_0.0.png'%(i+1), cv2.IMREAD_ANYCOLOR)) / 255. 69 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255. 70 | depth = np.array(cv2.imread(self.data_dir+'_depth_0_%02d_0.0.exr'%(i+1),cv2.IMREAD_UNCHANGED)[:,:,0]) 71 | images.append(img) 72 | depths.append(depth) 73 | 74 | self.n_images = len(images) 75 | 76 | self.images_np = np.stack(images, axis=0) 77 | self.depths_np = np.stack(depths, axis=0) 78 | 79 | self.images = torch.from_numpy(self.images_np.astype(np.float32)) # [n_images, H, W, 3] 80 | self.depths = torch.from_numpy(self.depths_np.astype(np.float32)) # [n_images, H, W] 81 | self.H, self.W = self.images.shape[1], self.images.shape[2] 82 | self.image_pixels = self.H * self.W 83 | self.patch_size = patch 84 | self.n_img_patches = self.image_pixels 85 | self.n_patches = self.n_img_patches * (self.n_images) 86 | self.patch_row = torch.arange(patch, device='cpu').repeat(patch,1).reshape(-1) 87 | self.patch_col = torch.arange(patch, device='cpu').repeat(patch,1).permute(1,0).reshape(-1) 88 | #ipdb.set_trace() 89 | cen_x = (self.W - 1) / 2.0 90 | cen_y = (self.H - 1) / 2.0 91 | theta = (2 * (np.arange(self.W) - cen_x) / self.W) * np.pi 92 | phi = (2 * (np.arange(self.H) - cen_y) / self.H) * (np.pi / 2) 93 | theta = np.tile(theta[None, :], [self.H, 1]) 94 | phi = np.tile(phi[None, :], [self.W, 1]).T 95 | 96 | x = (np.cos(phi) * np.sin(theta)).reshape([self.H, self.W, 1]) 97 | y = (np.sin(phi)).reshape([self.H, self.W, 1]) 98 | z = (np.cos(phi) * np.cos(theta)).reshape([self.H, self.W, 1]) 99 | directions = np.concatenate([x, y, z], axis=-1).reshape(-1,3) 100 | ellipsoid = np.concatenate([16*x, 16*y, 16*z], axis=-1) 101 | radius = np.linalg.norm(ellipsoid,ord=2,axis=-1) 102 | #phi_ = torch.arcsin(y) 103 | #theta_ = torch.arctan(x/z)[...,0] 104 | #theta_[z[...,0]<0] += torch.pi 105 | #theta_[theta_>np.pi] -= 2*np.pi 106 | ro = degree/180*np.pi 107 | rotation = np.array([[np.cos(ro),np.sin(ro),0],[-np.sin(ro),np.cos(ro),0],[0,0,1]]) 108 | directions = (rotation @ directions.T ).T 109 | directions = directions.reshape([self.H,self.W,-1])[None,:,:,:] 110 | 111 | directions = np.concatenate([directions, directions, directions,directions,directions,directions,directions,directions,directions,directions,directions,directions,directions], axis=0) 112 | #directions = np.concatenate([directions, directions], axis=0) 113 | origins = np.zeros(directions.shape, dtype = directions.dtype) 114 | origins[1] = 
origins[1] + (rotation @ np.array([0.26,0,0]).T)[None,None,:] 115 | origins[2] = origins[2] + (rotation @ np.array([0,-0.26,0]).T)[None,None,:] 116 | for i in range(numbers): 117 | origins[i+3] = origins[i+3] + np.array([self.position[i][0],-self.position[i][2],self.position[i][1]])[None,None,:] 118 | #ipdb.set_trace() 119 | self.rays_v = torch.from_numpy(directions.astype(np.float32)) 120 | self.rays_o = torch.from_numpy(origins.astype(np.float32)) 121 | self.radius = torch.from_numpy(radius.astype(np.float32)) 122 | self.all_radius = torch.from_numpy(radius[None,:].astype(np.float32).repeat(self.n_images,axis=0)) 123 | self.all_images = self.images.reshape(-1,3) 124 | self.all_depths = self.depths.reshape(-1) 125 | self.all_rays_v = self.rays_v.reshape(-1,3) 126 | self.all_rays_o = self.rays_o.reshape(-1,3) 127 | self.all_radius = self.all_radius.reshape(-1) 128 | 129 | 130 | print('Load data: End') 131 | 132 | def gen_rays_at(self, img_idx, resolution_level=1): 133 | """ 134 | Generate rays at world space from one camera. 135 | """ 136 | l = resolution_level 137 | 138 | 139 | return self.rays_o[img_idx].cuda(), self.rays_v[img_idx].cuda(), self.depths[img_idx].cuda(), self.radius.cuda() 140 | 141 | 142 | def __len__(self): 143 | return 512*256*self.n_images // self.patch_size**2 144 | 145 | 146 | def gen_random_rays_at(self, img_idx, batch_size): 147 | """ 148 | Generate random rays at world space from one camera. 149 | """ 150 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size],device='cpu') 151 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size],device='cpu') 152 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3 153 | rays_v = self.rays_v[img_idx][(pixels_y, pixels_x)] # batch_size, 3 154 | rays_o = self.rays_o[img_idx][(pixels_y, pixels_x)] # batch_size, 3 155 | depth = self.depths[img_idx][(pixels_y, pixels_x)] 156 | #ipdb.set_trace() 157 | far = self.radius[(pixels_y, pixels_x)] 158 | return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) # batch_size, 9 159 | 160 | def __getitem__(self, index): 161 | #ipdb.set_trace() 162 | pixels_x = torch.randint(low=0, high=self.W, size=[self.patch_size*self.patch_size],device='cpu')#[0].item() 163 | pixels_y = torch.randint(low=0, high=self.H, size=[self.patch_size*self.patch_size],device='cpu')#[0].item() 164 | #pixels_x = x + self.patch_row 165 | #pixels_x[pixels_x >= self.W] = pixels_x[pixels_x >= self.W] - self.W 166 | #pixels_y = y + self.patch_col 167 | #pixels_y[pixels_y >= self.H] = pixels_y[pixels_y >= self.H] - self.H 168 | img_idx = torch.randint(low=0, high=self.n_images, size=[1],device='cpu')[0].item() 169 | 170 | rays_o = self.rays_o[img_idx][(pixels_y, pixels_x)] 171 | rays_v = self.rays_v[img_idx][(pixels_y, pixels_x)] 172 | color = self.images[img_idx][(pixels_y, pixels_x)] 173 | depth = self.depths[img_idx][(pixels_y, pixels_x)] 174 | far = self.radius[(pixels_y, pixels_x)] 175 | return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) 176 | 177 | # def __getitem__(self, index): 178 | # i_patch = torch.randint(high=self.n_patches, size=(1,),device='cpu')[0].item() 179 | # i_img, i_patch = i_patch // self.n_img_patches, i_patch % self.n_img_patches 180 | # row, col = i_patch // (self.W - self.patch_size + 1), i_patch % (self.W - self.patch_size + 1) 181 | # start_idx = i_img * self.W * self.H + row * self.W + col 182 | # idxs = start_idx + torch.cat([torch.arange(self.patch_size,device='cpu') + i * self.W for i in 
range(self.patch_size)]) 183 | 184 | # rays_o = self.all_rays_o[idxs] 185 | # rays_v = self.all_rays_v[idxs] 186 | # color = self.all_images[idxs] 187 | # depth = self.all_depths[idxs] 188 | # far = self.all_radius[idxs] 189 | # return torch.cat([rays_o, rays_v, color, depth.unsqueeze(-1),far.unsqueeze(-1)], dim=-1) 190 | 191 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 192 | """ 193 | Interpolate pose between two cameras. 194 | """ 195 | l = resolution_level 196 | tx = torch.linspace(0, self.W - 1, self.W // l) 197 | ty = torch.linspace(0, self.H - 1, self.H // l) 198 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 199 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 200 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 201 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 202 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 203 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 204 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 205 | pose_0 = np.linalg.inv(pose_0) 206 | pose_1 = np.linalg.inv(pose_1) 207 | rot_0 = pose_0[:3, :3] 208 | rot_1 = pose_1[:3, :3] 209 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 210 | key_times = [0, 1] 211 | slerp = Slerp(key_times, rots) 212 | rot = slerp(ratio) 213 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 214 | pose = pose.astype(np.float32) 215 | pose[:3, :3] = rot.as_matrix() 216 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 217 | pose = np.linalg.inv(pose) 218 | rot = torch.from_numpy(pose[:3, :3]).cuda() 219 | trans = torch.from_numpy(pose[:3, 3]).cuda() 220 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 221 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 222 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 223 | 224 | def near_far_from_sphere(self, rays_o, rays_d): 225 | #a = torch.sum(rays_d**2, dim=-1, keepdim=True) 226 | #b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 227 | #mid = 0.5 * (-b) / a 228 | #near = mid - 1.0 229 | #far = mid + 1.0 230 | near = 0 231 | far = 16 232 | return near, far 233 | 234 | def image_at(self, idx, resolution_level): 235 | img = cv.imread(self.images_lis[idx]) 236 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 237 | 238 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | import mcubes 7 | from icecream import ic 8 | import ipdb 9 | 10 | 11 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 12 | # Get pdf 13 | weights = weights + 1e-5 # prevent nans 14 | pdf = weights / torch.sum(weights, -1, keepdim=True) 15 | cdf = torch.cumsum(pdf, -1) 16 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 17 | 18 | # Take uniform samples 19 | if det: 20 | u = torch.linspace(0., 1., steps=N_samples) 21 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 22 | else: 23 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 24 | 25 | # Pytest, overwrite u with numpy's fixed random numbers 26 | if pytest: 27 | np.random.seed(0) 28 | new_shape = list(cdf.shape[:-1]) + [N_samples] 29 | if det: 30 | u = 
np.linspace(0., 1., N_samples) 31 | u = np.broadcast_to(u, new_shape) 32 | else: 33 | u = np.random.rand(*new_shape) 34 | u = torch.Tensor(u) 35 | 36 | # Invert CDF 37 | u = u.contiguous() 38 | inds = torch.searchsorted(cdf, u, right=True) 39 | below = torch.max(torch.zeros_like(inds-1), inds-1) 40 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 41 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 42 | 43 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 44 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 45 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 46 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 47 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 48 | 49 | denom = (cdf_g[...,1]-cdf_g[...,0]) 50 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 51 | t = (u-cdf_g[...,0])/denom 52 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 53 | 54 | return samples 55 | 56 | 57 | def sample_pdf(bins, weights, n_samples, det=False): 58 | #ipdb.set_trace() 59 | # This implementation is from NeRF 60 | # Get pdf 61 | weights = weights + 1e-5 # prevent nans 62 | pdf = weights / torch.sum(weights, -1, keepdim=True) 63 | cdf = torch.cumsum(pdf, -1) 64 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 65 | # Take uniform samples 66 | if det: 67 | u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples) 68 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 69 | else: 70 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 71 | 72 | # Invert CDF 73 | u = u.contiguous() 74 | inds = torch.searchsorted(cdf, u, right=True) 75 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 76 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 77 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 78 | 79 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 80 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 81 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 82 | 83 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 84 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 85 | t = (u - cdf_g[..., 0]) / denom 86 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 87 | 88 | return samples 89 | 90 | 91 | class Renderer: 92 | def __init__(self, 93 | color_network, 94 | sdf_network, 95 | deviation_network, 96 | n_samples, 97 | n_importance, 98 | perturb): 99 | self.color_network = color_network 100 | self.sdf_network = sdf_network 101 | self.deviation_network = deviation_network 102 | self.n_samples = n_samples 103 | self.n_importance = n_importance 104 | self.perturb = perturb 105 | 106 | def render_inference(self, rays_o, rays_d, z_vals, sample_dist, color_network,sdf,deviation_network,cos_anneal_ratio, background_rgb=None): 107 | """ 108 | Render background 109 | """ 110 | batch_size, n_samples = z_vals.shape 111 | 112 | dists = z_vals[..., 1:] - z_vals[..., :-1] 113 | dists = torch.cat([dists, sample_dist], -1) 114 | mid_z_vals = z_vals + dists * 0.5 115 | 116 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 117 | 118 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 119 | 120 | #pts_color = pts_color.reshape(-1, 3 + int(self.n_outside > 0)) 121 | pts = pts.reshape(-1, 3) 122 | dirs = dirs.reshape(-1, 3) 123 | 
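        # Coarse pass: query the SDF at the sampled points and turn it into NeuS-style alphas and
        # weights; render() only uses these weights to importance-sample extra depths along each
        # ray for the fine pass (render_gradient), so no color is computed here.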
124 | gradient, sdf_nn_output = sdf.gradient_inference(pts) 125 | gradients = gradient.squeeze() 126 | with torch.no_grad(): 127 | sdf = sdf_nn_output[:, :1] 128 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 129 | inv_s = inv_s.expand(batch_size * n_samples, 1) 130 | 131 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 132 | 133 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 134 | # the cos value "not dead" at the beginning training iterations, for better convergence. 135 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 136 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 137 | 138 | # Estimate signed distances at section points 139 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 140 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 141 | 142 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 143 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 144 | 145 | p = prev_cdf - next_cdf 146 | c = prev_cdf 147 | 148 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 149 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1] 150 | 151 | return { 152 | 'alpha': alpha, 153 | 'weights': weights, 154 | 'mid_z':mid_z_vals 155 | } 156 | 157 | def render_gradient(self, rays_o, rays_d, z_vals, sample_dist, color_network,sdf,deviation_network,cos_anneal_ratio, background_rgb=None): 158 | """ 159 | Render background 160 | """ 161 | batch_size, n_samples = z_vals.shape 162 | 163 | # Section length 164 | dists = z_vals[..., 1:] - z_vals[..., :-1] 165 | dists = torch.cat([dists, sample_dist], -1) 166 | mid_z_vals = z_vals + dists * 0.5 167 | #ipdb.set_trace() 168 | # Section midpoints 169 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 170 | 171 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(10e-4, 1e10) 172 | pts_color = torch.cat([pts , 1/dis_to_center], dim=-1) # batch_size, n_samples, 4 173 | 174 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 175 | 176 | pts_color = pts_color.reshape(-1, 4) 177 | pts = pts.reshape(-1, 3) 178 | dirs = dirs.reshape(-1, 3) 179 | 180 | gradient, sdf_nn_output = sdf.gradient(pts) 181 | gradients = gradient.squeeze() 182 | sdf = sdf_nn_output[:, :1] 183 | #sdf = torch.clamp(sdf, min=-16, max=16) 184 | feature_vector = sdf_nn_output[:, 1:] 185 | sampled_color = color_network(pts, dirs, feature_vector) 186 | #################SDF rendering equation#################### 187 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 188 | inv_s = inv_s.expand(batch_size * n_samples, 1) 189 | 190 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 191 | 192 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 193 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
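        # The block below is the NeuS conversion from SDF to opacity: with Phi_s the sigmoid of
        # sharpness inv_s, alpha = clip((Phi_s(sdf_prev) - Phi_s(sdf_next)) / Phi_s(sdf_prev), 0, 1),
        # where sdf_prev / sdf_next are first-order estimates of the SDF at the two ends of each
        # ray segment.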
194 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 195 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 196 | 197 | # Estimate signed distances at section points 198 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 199 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 200 | 201 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 202 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 203 | 204 | p = prev_cdf - next_cdf 205 | c = prev_cdf 206 | 207 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 208 | #############classic NeRF rendering equation############### 209 | # alpha = 1.0 - torch.exp(-F.softplus(sdf.reshape(batch_size, n_samples)) * dists) 210 | # alpha = alpha.reshape(batch_size, n_samples) 211 | ########################################################### 212 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1] 213 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3) 214 | color = (weights[:, :, None] * sampled_color).sum(dim=1) 215 | if background_rgb is not None: 216 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True)) 217 | 218 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 219 | relax_inside_sphere = (pts_norm < 10).float().detach() 220 | 221 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, 222 | dim=-1) - 1.0) ** 2 223 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5) 224 | 225 | return { 226 | 'color': color, 227 | 'sampled_color': sampled_color, 228 | 'gradients': gradients.reshape(batch_size, n_samples, 3), 229 | 'alpha': alpha, 230 | 'weights': weights, 231 | 'mid_z':mid_z_vals, 232 | 'sdf':sdf, 233 | 'gradient_error':gradient_error 234 | } 235 | 236 | 237 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 238 | batch_size, n_samples = z_vals.shape 239 | _, n_importance = new_z_vals.shape 240 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] 241 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 242 | z_vals, index = torch.sort(z_vals, dim=-1) 243 | 244 | if not last: 245 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance) 246 | sdf = torch.cat([sdf, new_sdf], dim=-1) 247 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) 248 | index = index.reshape(-1) 249 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) 250 | 251 | return z_vals, sdf 252 | 253 | def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0): 254 | #ipdb.set_trace() 255 | batch_size = len(rays_o) 256 | sample_dist = (far - near)[...,None] / self.n_samples # Assuming the region of interest is a unit sphere 257 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 258 | z_vals = near + (far - near)[...,None] * z_vals[None, :] 259 | 260 | perturb = self.perturb 261 | 262 | if perturb_overwrite >= 0: 263 | perturb = perturb_overwrite 264 | if perturb > 0: 265 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 266 | z_vals = z_vals + t_rand * (far - near)[...,None] / self.n_samples 267 | else: 268 | t_rand = torch.zeros([batch_size, 1]) 269 | z_vals = z_vals + t_rand * (far - near)[...,None] / self.n_samples 270 | 271 | 272 | ret = self.render_inference(rays_o, rays_d, 
273 |         z_samples = sample_pdf(0.5*(ret['mid_z'][..., 1:] + ret['mid_z'][..., :-1]), ret['weights'][..., 1:-1], self.n_importance, det=(perturb == 0.))
274 |         z_vals_feed = torch.cat([z_vals, z_samples], dim=-1)
275 |         z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
276 |
277 |         ret_outside = self.render_gradient(rays_o, rays_d, z_vals_feed, sample_dist, self.color_network, self.sdf_network, self.deviation_network, cos_anneal_ratio)
278 |
279 |         depth_fine = torch.sum(ret_outside['mid_z'] * ret_outside['weights'], -1)
280 |
281 |         pts = rays_o + rays_d * depth_fine[..., None]
282 |         return {
283 |             #'color_coarse': ret['color'],
284 |             'color_fine': ret_outside['color'],
285 |             'depth_fine': depth_fine,
286 |             'val_z': ret_outside['mid_z'],
287 |             'weights': ret_outside['weights'],
288 |             'sdf': ret_outside['sdf'],
289 |             'pts': pts,
290 |             'gradients_out': ret_outside['gradients'],
291 |             'gradient_error': ret_outside['gradient_error']
292 |         }
293 |
294 |
--------------------------------------------------------------------------------
/exp_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 | import ipdb
6 | import numpy as np
7 | import cv2 as cv
8 | import trimesh
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from models.dataset_patch import Dataset_patch
13 | from models.dataset import Dataset
14 | from shutil import copyfile
15 | from icecream import ic
16 | from tqdm import tqdm
17 | from pyhocon import ConfigFactory
18 | from models.fields import SDFNetwork, SingleVarianceNetwork, COLORNetwork
19 | from models.renderer import Renderer
20 | import matplotlib.pyplot as plt
21 | import random
22 | import time
23 | seed = 841232111
24 | random.seed(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.cuda.manual_seed_all(seed)
29 | def standard_metrics(input_gt_depth_image, pred_depth_image, verbose=True):
30 |     input_gt_depth = input_gt_depth_image.copy()
31 |     pred_depth = pred_depth_image.copy()
32 |
33 |     input_gt_depth[input_gt_depth > 10] = 0
34 |
35 |     n = np.sum((input_gt_depth > 1e-3))  ####valid samples
36 |
37 |     ###invalid samples - no measures
38 |     idxs = ( (input_gt_depth <= 1e-3) )
39 |     pred_depth[idxs] = 1
40 |     input_gt_depth[idxs] = 1
41 |
42 |     if(verbose):
43 |         print('valid samples:', n, 'masked samples:', np.sum(idxs))
44 |
45 |     ####STEP 1: compute delta################################################################
46 |     #######prepare mask
47 |     pred_d_gt = pred_depth / input_gt_depth
48 |     pred_d_gt[idxs] = 100
49 |     gt_d_pred = input_gt_depth / pred_depth
50 |     gt_d_pred[idxs] = 100
51 |
52 |     Threshold_1_25 = np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25) / n
53 |     Threshold_1_25_2 = np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25 * 1.25) / n
54 |     Threshold_1_25_3 = np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25 * 1.25 * 1.25) / n
55 |     ########################################################################################
56 |
57 |     #####STEP 2: compute mean error##########################################################
58 |     input_gt_depth_norm = input_gt_depth / np.max(input_gt_depth)
59 |     pred_depth_norm = pred_depth / np.max(pred_depth)
60 |     if(verbose):
61 |         print(np.max(input_gt_depth), np.max(pred_depth))
62 |     log_pred = np.log(pred_depth_norm)
63 |     log_gt = np.log(input_gt_depth_norm)
64 |
65 |     ###OmniDepth:
66 |     RMSE_linear = ((pred_depth - input_gt_depth) ** 2).mean()  # note: no square root is applied, so this is the mean squared error
67 |     RMSE_log = np.sqrt(((log_pred - log_gt) ** 2).mean())
68 |     ARD = (np.abs((pred_depth_norm - input_gt_depth_norm)) / input_gt_depth_norm).mean()
69 |     SRD = (((pred_depth_norm - input_gt_depth_norm) ** 2) / input_gt_depth_norm).mean()
70 |     MAE = np.abs((pred_depth - input_gt_depth)).mean()
71 |     REL = (np.abs((pred_depth - input_gt_depth)) / input_gt_depth).mean()
72 |     if(verbose):
73 |         print('MAE\tREL\tLog\tSRD\tARD\tRMSE\tTh1\tTh2\tTh3')
74 |         print('%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f'%(MAE,REL,RMSE_log,SRD,ARD,RMSE_linear,Threshold_1_25,Threshold_1_25_2,Threshold_1_25_3))
75 |
76 |     return [MAE, REL, RMSE_linear, Threshold_1_25, Threshold_1_25_2, Threshold_1_25_3]
77 |
78 |
79 |
80 | class Runner:
81 |     def __init__(self, conf_path, base_exp_dir, case='CASE_NAME', is_continue=False, d=0, numbers=0, random_epoch=None, ps=None):
82 |         self.device = torch.device('cuda')
83 |
84 |         # Configuration
85 |         self.conf_path = conf_path
86 |         f = open(self.conf_path)
87 |         conf_text = f.read()
88 |         conf_text = conf_text.replace('CASE_NAME', case)
89 |         f.close()
90 |         self.bs = 32
91 |         self.random_epoch = random_epoch
92 |         self.conf = ConfigFactory.parse_string(conf_text)
93 |         #ipdb.set_trace()
94 |         self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case)
95 |         self.base_exp_dir = base_exp_dir + '/%d_images/'%numbers + case
96 |         os.makedirs(self.base_exp_dir, exist_ok=True)
97 |         self.dataset = Dataset(self.conf['dataset'], numbers, patch=4, degree=args.degree)
98 |         self.dataset_patch = Dataset_patch(self.conf['dataset'], numbers, patch=ps, degree=args.degree)
99 |         print(self.dataset.n_images)
100 |         print(self.dataset_patch.n_images)
101 |         print(self.dataset_patch[0].shape)
102 |         self.dataloader = torch.utils.data.DataLoader(
103 |             self.dataset,
104 |             batch_size=self.bs,
105 |             shuffle=False,
106 |             num_workers=8,
107 |             drop_last=False,
108 |             pin_memory=True
109 |         )
110 |         self.dataloader_patch = torch.utils.data.DataLoader(
111 |             self.dataset_patch,
112 |             batch_size=self.bs,
113 |             shuffle=False,
114 |             num_workers=8,
115 |             drop_last=False,
116 |             pin_memory=True
117 |         )
118 |         self.iter_step = 0
119 |
120 |         # Training parameters
121 |         self.end_iter = self.conf.get_int('train.end_iter')
122 |         self.save_freq = self.conf.get_int('train.save_freq')
123 |         self.report_freq = self.conf.get_int('train.report_freq')
124 |         self.val_freq = self.conf.get_int('train.val_freq')
125 |         self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
126 |         self.batch_size = self.conf.get_int('train.batch_size')
127 |         self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
128 |         self.learning_rate = self.conf.get_float('train.learning_rate')
129 |         self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
130 |         self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
131 |         self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0)
132 |         self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
133 |
134 |         # Weights
135 |         self.igr_weight = self.conf.get_float('train.igr_weight')
136 |         self.mask_weight = self.conf.get_float('train.mask_weight')
137 |         self.is_continue = is_continue
138 |         self.model_list = []
139 |         self.writer = None
140 |
141 |         # Networks
142 |         params_to_train = []
143 |         self.color_network = COLORNetwork(**self.conf['model.color_network']).to(self.device)
144 |         self.sdf_network = SDFNetwork(distance=d, **self.conf['model.sdf_network']).to(self.device)
145 |         self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)
146 |         params_to_train += list(self.color_network.parameters())
147 |         params_to_train += list(self.sdf_network.parameters())
148 |         params_to_train += list(self.deviation_network.parameters())
149 |
150 |         self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
151 |
152 |         self.renderer = Renderer(self.color_network,
153 |                                  self.sdf_network,
154 |                                  self.deviation_network,
155 |                                  **self.conf['model.renderer'])
156 |         #ipdb.set_trace()
157 |         # Load checkpoint
158 |         latest_model_name = None
159 |         if is_continue is not None:
160 |             latest_model_name = 'exp/%.1f_%s/%d_images/%s/checkpoints/ckpt_%06d.pth'%(d, is_continue, numbers, case, 80)
161 |
162 |         if latest_model_name is not None:
163 |             print('Find checkpoint: {}'.format(latest_model_name))
164 |             self.load_checkpoint(latest_model_name)
165 |
166 |         # Backup codes and configs for debug
167 |         self.file_backup()
168 |
169 |     def train(self):
170 |         self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
171 |         self.update_learning_rate()
172 |         res_step = self.end_iter - self.iter_step
173 |         _, _ = self.validate_image(idx=2)
174 |         near = 0
175 |         out_average = []
176 |         #ipdb.set_trace()
177 |         if self.iter_step < self.random_epoch:
178 |             for epoch in tqdm(range(self.random_epoch - self.iter_step)):
179 |                 for i, data in enumerate(self.dataloader):
180 |                     rays_o, rays_d, true_rgb, depth_gt, far = data[..., :3], data[..., 3: 6], data[..., 6: 9], data[..., 9], data[..., 10]
181 |                     rays_o = rays_o.reshape(-1, 3).cuda()
182 |                     rays_d = rays_d.reshape(-1, 3).cuda()
183 |                     true_rgb = true_rgb.reshape(-1, 3).cuda()
184 |                     depth_gt = depth_gt.reshape(-1).cuda()
185 |                     far = far.reshape(-1).cuda()
186 |                     background_rgb = None
187 |
188 |                     mask = torch.ones_like(rays_o[:, :1])
189 |
190 |                     mask_sum = mask.sum() + 1e-5
191 |                     render_out = self.renderer.render(rays_o, rays_d, near, far,
192 |                                                       background_rgb=background_rgb,
193 |                                                       cos_anneal_ratio=self.get_cos_anneal_ratio())
194 |
195 |                     #ipdb.set_trace()
196 |                     color_fine = render_out['color_fine']
197 |                     depth_mask = depth_gt < 10.0
198 |                     depth_l1 = torch.abs((render_out['depth_fine'] - depth_gt) * depth_mask).mean()
199 |                     # Loss
200 |                     color_error_fine = (color_fine - true_rgb) * mask
201 |                     color_fine_loss = F.l1_loss(color_error_fine, torch.zeros_like(color_error_fine), reduction='sum') / mask_sum
202 |                     psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt())
203 |                     eikonal_loss = render_out['gradient_error']
204 |
205 |                     loss = color_fine_loss + eikonal_loss * self.igr_weight  #+ color_coarse_loss
206 |
207 |                     self.optimizer.zero_grad()
208 |                     loss.backward()
209 |                     self.optimizer.step()
210 |                     self.iter_step += 1
211 |                     #ipdb.set_trace()
212 |
213 |                     if self.iter_step % self.report_freq == 0:
214 |
215 |                         if self.iter_step > 0:
216 |                             self.writer.add_scalar('Loss/loss', loss, self.iter_step)
217 |                             self.writer.add_scalar('Loss/color_fine_loss', color_fine_loss, self.iter_step)
218 |                             self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step)
219 |                             self.writer.add_scalar('Statistics/depth', depth_l1, self.iter_step)
220 |                             print('iter:{:>8d} loss = {} lr = {}'.format(self.iter_step, loss, self.optimizer.param_groups[0]['lr']))
221 |                     if self.iter_step % self.val_freq == 0:
222 |                         errors_weight, weight_sdf = self.validate_image(idx=2)
223 |                         self.writer.add_scalar('Weight/MRE', errors_weight[1], self.iter_step)
224 |                         self.writer.add_scalar('Weight/RMSE', errors_weight[2], self.iter_step)
225 |                         self.writer.add_scalar('Weight/MAE', errors_weight[0], self.iter_step)
226 |                     if self.iter_step % self.save_freq == 0:
227 |                         self.save_checkpoint()
228 |                 self.update_learning_rate()
229 |         res_epochs = self.end_iter - self.iter_step
230 |         for epoch in tqdm(range(res_epochs)):
231 |             for i, data in enumerate(self.dataloader_patch):
232 |                 rays_o, rays_d, true_rgb, depth_gt, far = data[..., :3], data[..., 3: 6], data[..., 6: 9], data[..., 9], data[..., 10]
233 |                 batch = rays_o.shape[0]
234 |                 rays_o = rays_o.reshape(-1, 3).cuda()
235 |                 rays_d = rays_d.reshape(-1, 3).cuda()
236 |                 true_rgb = true_rgb.reshape(-1, 3).cuda()
237 |                 depth_gt = depth_gt.reshape(-1).cuda()
238 |                 far = far.reshape(-1).cuda()
239 |                 background_rgb = None
240 |                 if self.use_white_bkgd:
241 |                     background_rgb = torch.ones([1, 3])
242 |
243 |                 mask = torch.ones_like(rays_o[:, :1])
244 |
245 |                 mask_sum = mask.sum() + 1e-5
246 |                 render_out = self.renderer.render(rays_o, rays_d, near, far,
247 |                                                   background_rgb=background_rgb,
248 |                                                   cos_anneal_ratio=self.get_cos_anneal_ratio())
249 |                 normals = render_out['gradients_out'] * render_out['weights'][..., None]
250 |                 #ipdb.set_trace()
251 |                 normals = -F.normalize(normals.sum(dim=1), p=2, dim=-1).reshape(batch, -1, 3).mean(1)
252 |                 matrix_a = render_out['pts'].reshape(batch, -1, 3)
253 |                 matrix_a_trans = matrix_a.permute(0, 2, 1)
254 |                 matrix_b = torch.ones(matrix_a.shape[:2]).unsqueeze(-1)
255 |                 point_multi = torch.matmul(matrix_a_trans, matrix_a)
256 |                 point_multi_inverse = torch.inverse(point_multi)
257 |                 normals_from_points = torch.matmul(torch.matmul(point_multi_inverse, matrix_a_trans), matrix_b).squeeze(-1)
258 |                 normals_from_points = F.normalize(normals_from_points, p=2, dim=1)
259 |                 normal_error_fine = (1 - torch.sum(normals_from_points * normals, -1)).mean()
260 |                 color_fine = render_out['color_fine']
261 |                 depth_mask = depth_gt < 10.0
262 |                 depth_l1 = torch.abs((render_out['depth_fine'] - depth_gt) * depth_mask).mean()
263 |                 # Loss
264 |                 color_error_fine = (color_fine - true_rgb) * mask
265 |                 color_fine_loss = F.l1_loss(color_error_fine, torch.zeros_like(color_error_fine), reduction='sum') / mask_sum
266 |                 psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt())
267 |                 eikonal_loss = render_out['gradient_error']
268 |
269 |                 loss = color_fine_loss + eikonal_loss * self.igr_weight + 0.01 * normal_error_fine  #+ color_coarse_loss
270 |
271 |                 self.optimizer.zero_grad()
272 |                 loss.backward()
273 |                 self.optimizer.step()
274 |                 self.iter_step += 1
275 |
276 |                 if self.iter_step % self.report_freq == 0:
277 |                     #print(self.base_exp_dir)
278 |                     if self.iter_step > 0:
279 |                         self.writer.add_scalar('Loss/loss', loss, self.iter_step)
280 |                         self.writer.add_scalar('Loss/color_fine_loss', color_fine_loss, self.iter_step)
281 |                         self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step)
282 |                         self.writer.add_scalar('Statistics/depth', depth_l1, self.iter_step)
283 |                         print('iter:{:>8d} loss = {} normal_loss = {} lr = {}'.format(self.iter_step, loss, normal_error_fine, self.optimizer.param_groups[0]['lr']))
284 |                 if self.iter_step % self.val_freq == 0:
285 |                     errors_weight, weight_sdf = self.validate_image(idx=2)
286 |                     self.writer.add_scalar('Weight/MRE', errors_weight[1], self.iter_step)
287 |                     self.writer.add_scalar('Weight/RMSE', errors_weight[2], self.iter_step)
288 |                     self.writer.add_scalar('Weight/MAE', errors_weight[0], self.iter_step)
289 |                 if self.iter_step % self.save_freq == 0:
290 |                     self.save_checkpoint()
291 |             self.update_learning_rate()
292 |         #np.save(os.path.join(self.base_exp_dir, 'average'),np.array(out_average))
293 |
294 |     def get_image_perm(self):
295 |         return torch.randperm(self.dataset.n_images)
296 |
297 |     def get_cos_anneal_ratio(self):
298 |         if self.anneal_end == 0.0:
299 |             return 1.0
300 |         else:
301 |             return np.min([1.0, self.iter_step / self.anneal_end])
302 |
303 |     def update_learning_rate(self):
304 |         if self.iter_step < self.warm_up_end:
305 |             learning_factor = self.iter_step / self.warm_up_end
306 |         else:
307 |             alpha = self.learning_rate_alpha
308 |             progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)
309 |             learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
310 |
311 |         for g in self.optimizer.param_groups:
312 |             g['lr'] = self.learning_rate * learning_factor
313 |
314 |     def file_backup(self):
315 |         dir_lis = self.conf['general.recording']
316 |         os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
317 |         for dir_name in dir_lis:
318 |             cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
319 |             os.makedirs(cur_dir, exist_ok=True)
320 |             files = os.listdir(dir_name)
321 |             for f_name in files:
322 |                 if f_name[-3:] == '.py':
323 |                     copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
324 |
325 |         copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
326 |
327 |     def load_checkpoint(self, checkpoint_name):
328 |         #ipdb.set_trace()
329 |         checkpoint = torch.load(os.path.join(checkpoint_name), map_location=self.device)
330 |         self.color_network.load_state_dict(checkpoint['color_network'])
331 |         self.sdf_network.load_state_dict(checkpoint['sdf_network'])
332 |         self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
333 |         self.optimizer.load_state_dict(checkpoint['optimizer'])
334 |         self.iter_step = checkpoint['iter_step']
335 |         #ipdb.set_trace()
336 |         torch.cuda.set_rng_state(checkpoint['rng'].cpu())
337 |         np.random.set_state(checkpoint['rng_np'])
338 |
339 |         logging.info('End')
340 |
341 |     def save_checkpoint(self):
342 |         checkpoint = {
343 |             'color_network': self.color_network.state_dict(),
344 |             'sdf_network': self.sdf_network.state_dict(),
345 |             'variance_network_fine': self.deviation_network.state_dict(),
346 |             'optimizer': self.optimizer.state_dict(),
347 |             'iter_step': self.iter_step,
348 |             'rng': torch.cuda.get_rng_state(),
349 |             'rng_np': np.random.get_state(),
350 |         }
351 |
352 |         os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
353 |         torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
354 |
355 |     def validate_image(self, idx=-1, resolution_level=-1, verbose=True):
356 |         if idx < 0:
357 |             idx = np.random.randint(self.dataset.n_images)
358 |         if(verbose):
359 |             print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))
360 |
361 |         if resolution_level < 0:
362 |             resolution_level = self.validate_resolution_level
363 |         rays_o, rays_d, gt_depth, far = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
364 |         H, W, _ = rays_o.shape
365 |         rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
366 |         rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
367 |         far = far.reshape(-1).split(self.batch_size)
368 |         out_rgb_fine = []
369 |         out_normal_fine = []
370 |         out_depth = []
371 |         #out_sdf = []
372 |         out_val_z = []
373 |         out_weights = []
374 |         out_normal = []
375 |         out_pts = []
376 |         near = 0
377 |         for rays_o_batch, rays_d_batch, far_batch in zip(rays_o, rays_d, far):
378 |             near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
379 |             background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
380 |             #ipdb.set_trace()
381 |             render_out = self.renderer.render(rays_o_batch,
382 |                                               rays_d_batch,
383 |                                               near,
384 |                                               far_batch,
385 |                                               perturb_overwrite=0,
386 |                                               cos_anneal_ratio=self.get_cos_anneal_ratio(),
387 |                                               background_rgb=background_rgb)
388 |
389 |             def feasible(key): return (key in render_out) and (render_out[key] is not None)
390 |
391 |             if feasible('color_fine'):
392 |                 out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
393 |                 out_depth.append(render_out['depth_fine'].detach().cpu().numpy())
394 |                 #out_sdf.append(render_out['sdf'].detach().cpu().numpy())
395 |                 out_val_z.append(render_out['val_z'].detach().cpu().numpy())
396 |                 out_weights.append(render_out['weights'].detach().cpu().numpy())
397 |                 #n_samples = self.renderer.n_samples + self.renderer.n_importance
398 |                 normals = render_out['gradients_out'] * render_out['weights'][..., None]
399 |                 normals = F.normalize(normals.sum(dim=1), p=2, dim=-1).detach().cpu().numpy()
400 |                 out_normal.append(normals)
401 |                 out_pts.append(render_out['pts'].detach().cpu().numpy())
402 |
403 |             if feasible('gradients') and feasible('weights'):
404 |                 n_samples = self.renderer.n_samples + self.renderer.n_importance
405 |                 normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
406 |                 if feasible('inside_sphere'):
407 |                     normals = normals * render_out['inside_sphere'][..., None]
408 |                 normals = normals.sum(dim=1).detach().cpu().numpy()
409 |                 out_normal_fine.append(normals)
410 |             del render_out
411 |
412 |         img_fine = None
413 |         if len(out_rgb_fine) > 0:
414 |             img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
415 |             depth_fine = np.concatenate(out_depth, axis=0).reshape([H, W, 1, -1])
416 |             val_z = np.concatenate(out_val_z, axis=0).reshape([H, W, -1])
417 |             weights = np.concatenate(out_weights, axis=0).reshape([H, W, -1])
418 |             out_normal = np.concatenate(out_normal, axis=0).reshape([H, W, -1])
419 |             #sdf = np.concatenate(out_sdf, axis=0).reshape([H, W, -1])
420 |             #temp = sdf < 0
421 |             #idx = np.expand_dims(np.argmax(temp, axis=2), axis=2)
422 |             #depth_sdf = np.take_along_axis(val_z,idx,axis=-1)[:,:,0]
423 |
424 |         normal_img = None
425 |         if len(out_normal_fine) > 0:
426 |             normal_img = np.concatenate(out_normal_fine, axis=0)
427 |             rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
428 |             normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None])
429 |                           .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)
430 |
431 |         os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True)
432 |         os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine', '{:0>8d}'.format(self.iter_step)), exist_ok=True)
433 |         os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True)
434 |
435 |         for i in range(img_fine.shape[-1]):
436 |             if len(out_rgb_fine) > 0:
437 |                 cv.imwrite(os.path.join(self.base_exp_dir,
438 |                                         'validations_fine',
439 |                                         '{:0>8d}'.format(self.iter_step), 'rgb.png'),
440 |                            img_fine[..., i])
441 |             if(verbose):
442 |                 print('==========depth from weight==========')
443 |             errors_weight = standard_metrics(gt_depth.detach().cpu().numpy(), depth_fine[:, :, 0, i], verbose=verbose)
444 |             # print('==========depth from SDF==========')
445 |             # errors_sdf = standard_metrics(gt_depth.detach().cpu().numpy(),depth_sdf)
446 |             _ = plt.imshow(depth_fine[:, :, 0, i])
447 |             plt.tight_layout()
448 |             plt.savefig(os.path.join(self.base_exp_dir,
449 |                                      'validations_fine',
450 |                                      '{:0>8d}'.format(self.iter_step), 'depth.png'))
451 |             plt.close()
452 |
453 |             #ipdb.set_trace()
454 |             _ = plt.imshow((out_normal + 1) / 2)
455 |             plt.tight_layout()
456 |             plt.savefig(os.path.join(self.base_exp_dir,
457 |                                      'validations_fine',
458 |                                      '{:0>8d}'.format(self.iter_step), 'normal.png'))
459 |             plt.close()
460 |             np.save(os.path.join(self.base_exp_dir,
461 |                                  'validations_fine',
462 |                                  '{:0>8d}'.format(self.iter_step), 'depth'), depth_fine[:, :, 0, i])
463 |         return errors_weight, None
464 |
465 |
466 |
467 |
468 |
469 | if __name__ == '__main__':
470 |
471 |     torch.set_default_tensor_type('torch.cuda.FloatTensor')
472 |
473 |     parser = argparse.ArgumentParser()
474 |     parser.add_argument('--conf', type=str, default='./confs/base.conf')
475 |     parser.add_argument('--mcube_threshold', type=float, default=0.0)
476 |     parser.add_argument('--is_continue', type=str, default=None)
477 |     parser.add_argument('--gpu', type=int, default=0)
478 |     parser.add_argument('--case', type=str, default='')
479 |     parser.add_argument('--d', type=float, default=1.5)
480 |     parser.add_argument('--n', type=int, default=0)
481 |     parser.add_argument('--random', type=int, default=80)
482 |     parser.add_argument('--degree', type=int, default=0)
483 |     parser.add_argument('--ps', type=int, default=4)
484 |     parser.add_argument('--dir', type=str, default='./exp')
485 |     #ipdb.set_trace()
486 |     args = parser.parse_args()
487 |
488 |     torch.cuda.set_device(args.gpu)
489 |     runner = Runner(args.conf, args.dir, args.case, args.is_continue, args.d, args.n, args.random, args.ps)
490 |     runner.train()
491 |     temp = []
492 |     for i in range(args.n + 3):
493 |         errors_weight, _ = runner.validate_image(idx=i)
494 |         temp.append(errors_weight)
495 |     temp = np.array(temp)
496 |     average = np.mean(temp, axis=0)
497 |     np.save(os.path.join(runner.base_exp_dir, 'average'), np.array(average))
498 |
--------------------------------------------------------------------------------
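
Note on the renderer: the SDF-to-opacity conversion used in render_inference and render_gradient (models/renderer.py, lines 138-149 and 197-212 above) can be read in isolation. The snippet below is a minimal, self-contained sketch of that idea for reference only; it is not code from this repository, and the toy values (sdf_mid, delta, cos, inv_s) are made up purely for illustration.

    import torch

    # Toy inputs: one ray with 4 section midpoints (illustrative values only).
    sdf_mid = torch.tensor([[0.8, 0.3, -0.1, -0.5]])  # SDF at section midpoints
    delta = 0.1                                        # section length between samples
    cos = -1.0                                         # ray roughly anti-parallel to the SDF gradient
    inv_s = 50.0                                       # sharpness (the learned single parameter)

    # Estimate the SDF at the previous/next section boundaries from the midpoint value.
    prev_sdf = sdf_mid - cos * delta * 0.5
    next_sdf = sdf_mid + cos * delta * 0.5

    # Sigmoid CDFs and the resulting per-section opacity (alpha), as in render_gradient.
    prev_cdf = torch.sigmoid(prev_sdf * inv_s)
    next_cdf = torch.sigmoid(next_sdf * inv_s)
    alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)

    # Standard alpha compositing: weight = opacity * accumulated transmittance.
    trans = torch.cumprod(torch.cat([torch.ones_like(alpha[:, :1]), 1.0 - alpha + 1e-7], dim=-1), dim=-1)[:, :-1]
    weights = alpha * trans
    print(alpha, weights)  # weights peak where the SDF changes sign (the surface)

Under this reading, the depth map in exp_runner.py is simply the weights-averaged midpoint distance along each ray, which is why depth_fine is computed as torch.sum(mid_z * weights, -1).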