├── .gitignore ├── DATA_CONVENTION.md ├── LICENSE ├── README.md ├── code ├── confs │ ├── bmvs.conf │ └── dtu.conf ├── datasets │ └── scene_dataset.py ├── evaluation │ └── eval.py ├── model │ ├── density.py │ ├── embedder.py │ ├── loss.py │ ├── network.py │ ├── network_bg.py │ └── ray_sampler.py ├── training │ ├── exp_runner.py │ └── volsdf_train.py └── utils │ ├── general.py │ ├── plots.py │ └── rend_util.py ├── data ├── download_data.sh └── preprocess │ ├── normalize_cameras.py │ └── parse_cameras_blendedmvs.py ├── environment.yml └── media └── teaser.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | exps*
132 | evals*
133 | data/DTU
134 | data/BlendedMVS
135 | 
136 | code/.idea/
137 | .DS_Store
138 | ._.DS_Store
139 | .idea/
140 | 
--------------------------------------------------------------------------------
/DATA_CONVENTION.md:
--------------------------------------------------------------------------------
1 | # Data Convention
2 | 
3 | ### Camera information and normalization
4 | Besides multi-view RGB images, VolSDF needs camera information in order to run. For each scan that we used, we supply a file named `cameras.npz`.
5 | The `cameras.npz` file contains, for each image, its associated camera projection matrix (named "world_mat_{i}") and a normalization matrix (named "scale_mat_{i}").
6 | #### Camera projection matrix
7 | A 3x4 camera projection matrix P = K[R | t] projects points from 3D coordinates to image pixels by the formula d[x; y; 1] = P[X; Y; Z; 1], where K is a 3x3 calibration matrix, [R | t] is a 3x4 world-to-camera Euclidean transformation, [X; Y; Z] is the 3D point, [x; y] are the 2D pixel coordinates of the projected point, and d is the depth of the point.
8 | The input `cameras.npz` file contains the camera matrices, where P_i = cameras['world_mat_{i}'][:3, :] is a 3x4 matrix that projects points from the 3D world coordinates to the 2D coordinates of image i (intrinsics and extrinsics, i.e. P = K[R | t]).
9 | Each "world_mat" matrix is a concatenation of the camera projection matrix with a row vector of [0,0,0,1] (which makes it a 4x4 matrix).
10 | 
11 | #### Normalization matrix
12 | The `cameras.npz` file also contains one normalization matrix, named "scale_mat_{i}" (identical for all i), for changing the coordinate system so that the cameras and the region of interest are located inside a sphere of radius 3 centered at the origin (more details are in the paper).
13 | 
14 | 
15 | ### Preprocess new data
16 | To convert the BlendedMVS camera format to ours (not required for the supplied scans), run:
17 | ```
18 | cd data/preprocess/
19 | python parse_cameras_blendedmvs.py --blendedMVS_path [BLENDED_MVS_PATH] --output_cameras_file [OUTPUT_CAMERAS_NPZ_FILE] --scan_ind [BLENDED_MVS_SCAN_ID]
20 | ```
21 | 
22 | To generate a normalization matrix for each scan, we used the input camera projection matrices. A script that demonstrates this process is provided in `data/preprocess/normalize_cameras.py`.
23 | Note: running this script is not required for the supplied scans.
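The sketch below is a quick sanity check that a `cameras.npz` file follows the convention described above; it mirrors how `code/datasets/scene_dataset.py` consumes the two matrices when loading a scan. The scan path and the 3D point `X` are placeholders used only for illustration.

```
import numpy as np

# Load the per-image matrices of a scan (the path is a placeholder).
cameras = np.load("data/DTU/scan65/cameras.npz")
i = 0  # image index

world_mat = cameras["world_mat_%d" % i]  # 4x4: K[R | t] stacked with [0, 0, 0, 1]
scale_mat = cameras["scale_mat_%d" % i]  # 4x4 normalization matrix (identical for all i)

# Combined 3x4 matrix projecting points expressed in the normalized coordinate
# system to pixels of image i, as done when a scan is loaded for training.
P = (world_mat @ scale_mat)[:3, :4]

# Project a (hypothetical) 3D point: d[x; y; 1] = P[X; Y; Z; 1]
X = np.array([0.1, 0.2, 0.3, 1.0])  # homogeneous coordinates
x = P @ X
pixel, depth = x[:2] / x[2], x[2]
print("pixel:", pixel, "depth:", depth)
```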
24 | For normalizing a given `cameras.npz` file run:
25 | ```
26 | cd data/preprocess/
27 | python normalize_cameras.py --input_cameras_file [INPUT_CAMERAS_NPZ_FILE] --output_cameras_file [OUTPUT_NORMALIZED_CAMERAS_NPZ_FILE] [--number_of_cams [NUMBER_OF_CAMERAS_LIMIT]]
28 | ```
29 | where the last argument is optional and limits the number of cameras, so that only the first [NUMBER_OF_CAMERAS_LIMIT] cameras are considered. This is useful for the DTU dataset, where for scan_id < 80 only the first 49 of the 64 cameras are used.
30 | 
31 | 
32 | #### Parsing COLMAP cameras
33 | It is possible to convert COLMAP cameras to our camera format using Python. First, the functions read_cameras_text, read_images_text and qvec2rotmat should be imported from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py (the snippet below additionally assumes `import numpy as np`). Then the following Python code can be used:
34 | 
35 | ```
36 | cameras = read_cameras_text("output_sfm\\cameras.txt")
37 | images = read_images_text("output_sfm\\images.txt")
38 | K = np.eye(3)  # assumes a single shared PINHOLE camera (fx, fy, cx, cy), COLMAP camera id 1
39 | K[0, 0] = cameras[1].params[0]
40 | K[1, 1] = cameras[1].params[1]
41 | K[0, 2] = cameras[1].params[2]
42 | K[1, 2] = cameras[1].params[3]
43 | 
44 | cameras_npz_format = {}
45 | # images is a dict keyed by COLMAP image id; sort by file name so that
46 | # world_mat_{ii} matches the sorted image order used at training time.
47 | for ii, cur_image in enumerate(sorted(images.values(), key=lambda im: im.name)):
48 |     M = np.zeros((3, 4))
49 |     M[:, 3] = cur_image.tvec
50 |     M[:3, :3] = qvec2rotmat(cur_image.qvec)
51 | 
52 |     P = np.eye(4)
53 |     P[:3, :] = K @ M
54 |     cameras_npz_format['world_mat_%d' % ii] = P
55 | 
56 | np.savez(
57 |     "cameras_before_normalization.npz",
58 |     **cameras_npz_format)
59 | 
60 | ```
61 | Note that after running this code you will still have to normalize the cameras by running normalize_cameras.py as described above.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Lior Yariv
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Volume Rendering of Neural Implicit Surfaces
2 | 
3 | ### [Project Page](https://lioryariv.github.io/volsdf/) | [Paper](https://arxiv.org/abs/2106.12052) | [Data](https://www.dropbox.com/sh/oum8dyo19jqdkwu/AAAxpIifYjjotz_fIRBj1Fyla)
4 | 
5 | 

6 | 7 |

8 | 
9 | This repository contains an implementation of the NeurIPS 2021 paper:
10 | Volume Rendering of Neural Implicit Surfaces
11 | Lior Yariv¹, Jiatao Gu², Yoni Kasten¹, Yaron Lipman¹,²
12 | ¹Weizmann Institute of Science, ²Facebook AI Research
13 | 
14 | The paper introduces VolSDF: a volume rendering framework for implicit neural surfaces that allows learning high-fidelity geometry from a sparse set of input images.
15 | 
16 | ## Setup
17 | #### Installation Requirements
18 | The code is compatible with Python 3.8 and PyTorch 1.7. In addition, the following packages are required:
19 | numpy, pyhocon, plotly, scikit-image, trimesh, imageio, opencv, torchvision.
20 | 
21 | You can create an anaconda environment called `volsdf` with the required dependencies by running:
22 | ```
23 | conda env create -f environment.yml
24 | conda activate volsdf
25 | ```
26 | 
27 | #### Data
28 | 
29 | We apply our multi-view surface reconstruction model to real 2D images from two datasets: DTU and BlendedMVS.
30 | The scans evaluated in the paper can be downloaded using:
31 | ```
32 | bash data/download_data.sh
33 | ```
34 | For more information on the data convention and how to run VolSDF on new data, please have a look at the [data convention](DATA_CONVENTION.md).
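For reference, the dataloader in `code/datasets/scene_dataset.py` expects each downloaded scan to follow the layout sketched below (DTU scan 65, the default in `confs/dtu.conf`, is used as the example):

```
data/DTU/scan65/
├── image/        # multi-view RGB images, read in sorted file-name order
└── cameras.npz   # world_mat_{i} and scale_mat_{i} for every image i
```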

35 | 36 | ## Usage 37 | #### Multiview 3D reconstruction 38 | 39 | For training VolSDF run: 40 | ``` 41 | cd ./code 42 | python training/exp_runner.py --conf ./confs/dtu.conf --scan_id SCAN_ID 43 | ``` 44 | where SCAN_ID is the id of the scene to reconstruct. 45 | 46 | To run on the BlendedMVS dataset, which have more complex background, use `--conf ./confs/bmvs.conf`. 47 | 48 | 49 | #### Evaluation 50 | 51 | To produce the meshed surface and renderings, run: 52 | ``` 53 | cd ./code 54 | python evaluation/eval.py --conf ./confs/dtu.conf --scan_id SCAN_ID --checkpoint CHECKPOINT [--eval_rendering] 55 | ``` 56 | where CHECKPOINT is the epoch you wish to evaluate or 'latest' if you wish to take the most recent epoch. 57 | Turning on `--eval_rendering` will further produce and evaluate PSNR of train image reconstructions. 58 | 59 | 60 | 61 | ## Citation 62 | If you find our work useful in your research, please consider citing: 63 | 64 | @inproceedings{yariv2021volume, 65 | title={Volume rendering of neural implicit surfaces}, 66 | author={Yariv, Lior and Gu, Jiatao and Kasten, Yoni and Lipman, Yaron}, 67 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 68 | year={2021} 69 | } 70 | 71 | -------------------------------------------------------------------------------- /code/confs/bmvs.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = bmvs 3 | dataset_class = datasets.scene_dataset.SceneDataset 4 | model_class = model.network_bg.VolSDFNetworkBG 5 | loss_class = model.loss.VolSDFLoss 6 | learning_rate = 5.0e-4 7 | num_pixels = 1024 8 | checkpoint_freq = 100 9 | plot_freq = 500 10 | split_n_pixels = 1000 11 | } 12 | plot{ 13 | plot_nimgs = 1 14 | resolution = 100 15 | grid_boundary = [-1.5, 1.5] 16 | } 17 | loss{ 18 | eikonal_weight = 0.1 19 | rgb_loss = torch.nn.L1Loss 20 | } 21 | dataset{ 22 | data_dir = BlendedMVS 23 | img_res = [576, 768] 24 | scan_id = 1 25 | } 26 | model{ 27 | feature_vector_size = 256 28 | scene_bounding_sphere = 3.0 29 | implicit_network 30 | { 31 | d_in = 3 32 | d_out = 1 33 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ] 34 | geometric_init = True 35 | bias = 0.6 36 | skip_in = [4] 37 | weight_norm = True 38 | multires = 6 39 | } 40 | rendering_network 41 | { 42 | mode = idr 43 | d_in = 9 44 | d_out = 3 45 | dims = [ 256, 256, 256, 256] 46 | weight_norm = True 47 | multires_view = 4 48 | } 49 | density 50 | { 51 | params_init{ 52 | beta = 0.1 53 | } 54 | beta_min = 0.0001 55 | } 56 | ray_sampler 57 | { 58 | near = 0.0 59 | N_samples = 64 60 | N_samples_eval = 128 61 | N_samples_extra = 32 62 | eps = 0.1 63 | beta_iters = 10 64 | max_total_iters = 5 65 | N_samples_inverse_sphere = 32 66 | add_tiny = 1.0e-6 67 | } 68 | bg_network{ 69 | feature_vector_size = 256 70 | implicit_network 71 | { 72 | d_in = 4 73 | d_out = 1 74 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ] 75 | geometric_init = False 76 | bias = 0.0 77 | skip_in = [4] 78 | weight_norm = False 79 | multires = 10 80 | } 81 | rendering_network 82 | { 83 | mode = nerf 84 | d_in = 3 85 | d_out = 3 86 | dims = [128] 87 | weight_norm = False 88 | multires_view = 4 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /code/confs/dtu.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = dtu 3 | dataset_class = datasets.scene_dataset.SceneDataset 4 | model_class = model.network.VolSDFNetwork 5 | loss_class = 
model.loss.VolSDFLoss 6 | learning_rate = 5.0e-4 7 | num_pixels = 1024 8 | checkpoint_freq = 100 9 | plot_freq = 500 10 | split_n_pixels = 1000 11 | } 12 | plot{ 13 | plot_nimgs = 1 14 | resolution = 100 15 | grid_boundary = [-1.5, 1.5] 16 | } 17 | loss{ 18 | eikonal_weight = 0.1 19 | rgb_loss = torch.nn.L1Loss 20 | } 21 | dataset{ 22 | data_dir = DTU 23 | img_res = [1200, 1600] 24 | scan_id = 65 25 | } 26 | model{ 27 | feature_vector_size = 256 28 | scene_bounding_sphere = 3.0 29 | implicit_network 30 | { 31 | d_in = 3 32 | d_out = 1 33 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ] 34 | geometric_init = True 35 | bias = 0.6 36 | skip_in = [4] 37 | weight_norm = True 38 | multires = 6 39 | sphere_scale = 20.0 40 | } 41 | rendering_network 42 | { 43 | mode = idr 44 | d_in = 9 45 | d_out = 3 46 | dims = [ 256, 256, 256, 256] 47 | weight_norm = True 48 | multires_view = 4 49 | } 50 | density 51 | { 52 | params_init{ 53 | beta = 0.1 54 | } 55 | beta_min = 0.0001 56 | } 57 | ray_sampler 58 | { 59 | near = 0.0 60 | N_samples = 64 61 | N_samples_eval = 128 62 | N_samples_extra = 32 63 | eps = 0.1 64 | beta_iters = 10 65 | max_total_iters = 5 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /code/datasets/scene_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | import utils.general as utils 6 | from utils import rend_util 7 | 8 | class SceneDataset(torch.utils.data.Dataset): 9 | 10 | def __init__(self, 11 | data_dir, 12 | img_res, 13 | scan_id=0, 14 | ): 15 | 16 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id)) 17 | 18 | self.total_pixels = img_res[0] * img_res[1] 19 | self.img_res = img_res 20 | 21 | assert os.path.exists(self.instance_dir), "Data directory is empty" 22 | 23 | self.sampling_idx = None 24 | 25 | image_dir = '{0}/image'.format(self.instance_dir) 26 | image_paths = sorted(utils.glob_imgs(image_dir)) 27 | self.n_images = len(image_paths) 28 | 29 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) 30 | camera_dict = np.load(self.cam_file) 31 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 32 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 33 | 34 | self.intrinsics_all = [] 35 | self.pose_all = [] 36 | for scale_mat, world_mat in zip(scale_mats, world_mats): 37 | P = world_mat @ scale_mat 38 | P = P[:3, :4] 39 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 40 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 41 | self.pose_all.append(torch.from_numpy(pose).float()) 42 | 43 | self.rgb_images = [] 44 | for path in image_paths: 45 | rgb = rend_util.load_rgb(path) 46 | rgb = rgb.reshape(3, -1).transpose(1, 0) 47 | self.rgb_images.append(torch.from_numpy(rgb).float()) 48 | 49 | def __len__(self): 50 | return self.n_images 51 | 52 | def __getitem__(self, idx): 53 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) 54 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() 55 | uv = uv.reshape(2, -1).transpose(1, 0) 56 | 57 | sample = { 58 | "uv": uv, 59 | "intrinsics": self.intrinsics_all[idx], 60 | "pose": self.pose_all[idx] 61 | } 62 | 63 | ground_truth = { 64 | "rgb": self.rgb_images[idx] 65 | } 66 | 67 | if self.sampling_idx is not None: 68 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] 69 | sample["uv"] = 
uv[self.sampling_idx, :] 70 | 71 | return idx, sample, ground_truth 72 | 73 | def collate_fn(self, batch_list): 74 | # get list of dictionaries and returns input, ground_true as dictionary for all batch instances 75 | batch_list = zip(*batch_list) 76 | 77 | all_parsed = [] 78 | for entry in batch_list: 79 | if type(entry[0]) is dict: 80 | # make them all into a new dict 81 | ret = {} 82 | for k in entry[0].keys(): 83 | ret[k] = torch.stack([obj[k] for obj in entry]) 84 | all_parsed.append(ret) 85 | else: 86 | all_parsed.append(torch.LongTensor(entry)) 87 | 88 | return tuple(all_parsed) 89 | 90 | def change_sampling_idx(self, sampling_size): 91 | if sampling_size == -1: 92 | self.sampling_idx = None 93 | else: 94 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 95 | 96 | def get_scale_mat(self): 97 | return np.load(self.cam_file)['scale_mat_0'] 98 | -------------------------------------------------------------------------------- /code/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../code') 3 | import argparse 4 | import GPUtil 5 | import os 6 | from pyhocon import ConfigFactory 7 | import torch 8 | import numpy as np 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import pandas as pd 12 | 13 | import utils.general as utils 14 | import utils.plots as plt 15 | from utils import rend_util 16 | 17 | def evaluate(**kwargs): 18 | torch.set_default_dtype(torch.float32) 19 | torch.set_num_threads(1) 20 | 21 | conf = ConfigFactory.parse_file(kwargs['conf']) 22 | exps_folder_name = kwargs['exps_folder_name'] 23 | evals_folder_name = kwargs['evals_folder_name'] 24 | eval_rendering = kwargs['eval_rendering'] 25 | 26 | expname = conf.get_string('train.expname') + kwargs['expname'] 27 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int('dataset.scan_id', default=-1) 28 | if scan_id != -1: 29 | expname = expname + '_{0}'.format(scan_id) 30 | else: 31 | scan_id = conf.get_string('dataset.object', default='') 32 | 33 | if kwargs['timestamp'] == 'latest': 34 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)): 35 | timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname)) 36 | if (len(timestamps)) == 0: 37 | print('WRONG EXP FOLDER') 38 | exit() 39 | # self.timestamp = sorted(timestamps)[-1] 40 | timestamp = None 41 | for t in sorted(timestamps): 42 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname, t, 'checkpoints', 43 | 'ModelParameters', str(kwargs['checkpoint']) + ".pth")): 44 | timestamp = t 45 | if timestamp is None: 46 | print('NO GOOD TIMSTAMP') 47 | exit() 48 | else: 49 | print('WRONG EXP FOLDER') 50 | exit() 51 | else: 52 | timestamp = kwargs['timestamp'] 53 | 54 | utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name)) 55 | expdir = os.path.join('../', exps_folder_name, expname) 56 | evaldir = os.path.join('../', evals_folder_name, expname) 57 | utils.mkdir_ifnotexists(evaldir) 58 | 59 | dataset_conf = conf.get_config('dataset') 60 | if kwargs['scan_id'] != -1: 61 | dataset_conf['scan_id'] = kwargs['scan_id'] 62 | eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(**dataset_conf) 63 | 64 | conf_model = conf.get_config('model') 65 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model) 66 | if torch.cuda.is_available(): 67 | model.cuda() 68 | 69 | # settings for camera optimization 70 | scale_mat = eval_dataset.get_scale_mat() 
71 | 72 | if eval_rendering: 73 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset, 74 | batch_size=1, 75 | shuffle=False, 76 | collate_fn=eval_dataset.collate_fn 77 | ) 78 | total_pixels = eval_dataset.total_pixels 79 | img_res = eval_dataset.img_res 80 | split_n_pixels = conf.get_int('train.split_n_pixels', 10000) 81 | 82 | old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints') 83 | 84 | saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 85 | model.load_state_dict(saved_model_state["model_state_dict"]) 86 | epoch = saved_model_state['epoch'] 87 | 88 | #################################################################################################################### 89 | print("evaluating...") 90 | 91 | model.eval() 92 | 93 | with torch.no_grad(): 94 | 95 | if scan_id < 24: # Blended MVS 96 | mesh = plt.get_surface_high_res_mesh( 97 | sdf=lambda x: model.implicit_network(x)[:, 0], 98 | resolution=kwargs['resolution'], 99 | grid_boundary=conf.get_list('plot.grid_boundary'), 100 | level=conf.get_int('plot.level', default=0), 101 | take_components = type(scan_id) is not str 102 | ) 103 | else: # DTU 104 | bb_dict = np.load('../data/DTU/bbs.npz') 105 | grid_params = bb_dict[str(scan_id)] 106 | 107 | mesh = plt.get_surface_by_grid( 108 | grid_params=grid_params, 109 | sdf=lambda x: model.implicit_network(x)[:, 0], 110 | resolution=kwargs['resolution'], 111 | level=conf.get_int('plot.level', default=0), 112 | higher_res=True 113 | ) 114 | 115 | # Transform to world coordinates 116 | mesh.apply_transform(scale_mat) 117 | 118 | # Taking the biggest connected component 119 | components = mesh.split(only_watertight=False) 120 | areas = np.array([c.area for c in components], dtype=np.float32) 121 | mesh_clean = components[areas.argmax()] 122 | 123 | mesh_folder = '{0}/{1}'.format(evaldir, epoch) 124 | utils.mkdir_ifnotexists(mesh_folder) 125 | mesh_clean.export('{0}/scan{1}.ply'.format(mesh_folder, scan_id), 'ply') 126 | 127 | if eval_rendering: 128 | images_dir = '{0}/rendering_{1}'.format(evaldir, epoch) 129 | utils.mkdir_ifnotexists(images_dir) 130 | 131 | psnrs = [] 132 | for data_index, (indices, model_input, ground_truth) in enumerate(eval_dataloader): 133 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 134 | model_input["uv"] = model_input["uv"].cuda() 135 | model_input['pose'] = model_input['pose'].cuda() 136 | 137 | split = utils.split_input(model_input, total_pixels, n_pixels=split_n_pixels) 138 | res = [] 139 | for s in tqdm(split): 140 | torch.cuda.empty_cache() 141 | out = model(s) 142 | res.append({ 143 | 'rgb_values': out['rgb_values'].detach(), 144 | }) 145 | 146 | batch_size = ground_truth['rgb'].shape[0] 147 | model_outputs = utils.merge_output(res, total_pixels, batch_size) 148 | rgb_eval = model_outputs['rgb_values'] 149 | rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3) 150 | 151 | rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0] 152 | rgb_eval = rgb_eval.transpose(1, 2, 0) 153 | img = Image.fromarray((rgb_eval * 255).astype(np.uint8)) 154 | img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % indices[0])) 155 | 156 | psnr = rend_util.get_psnr(model_outputs['rgb_values'], 157 | ground_truth['rgb'].cuda().reshape(-1, 3)).item() 158 | psnrs.append(psnr) 159 | 160 | 161 | psnrs = np.array(psnrs).astype(np.float64) 162 | print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id)) 
163 | psnrs = np.concatenate([psnrs, psnrs.mean()[None], psnrs.std()[None]]) 164 | pd.DataFrame(psnrs).to_csv('{0}/psnr_{1}.csv'.format(evaldir, epoch)) 165 | 166 | 167 | 168 | if __name__ == '__main__': 169 | 170 | parser = argparse.ArgumentParser() 171 | 172 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 173 | parser.add_argument('--expname', type=str, default='', help='The experiment name to be evaluated.') 174 | parser.add_argument('--exps_folder', type=str, default='exps', help='The experiments folder name.') 175 | parser.add_argument('--evals_folder', type=str, default='evals', help='The evaluation folder name.') 176 | parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') 177 | parser.add_argument('--timestamp', default='latest', type=str, help='The experiemnt timestamp to test.') 178 | parser.add_argument('--checkpoint', default='latest',type=str,help='The trained model checkpoint to test') 179 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') 180 | parser.add_argument('--resolution', default=512, type=int, help='Grid resolution for marching cube') 181 | parser.add_argument('--eval_rendering', default=False, action="store_true", help='If set, evaluate rendering quality.') 182 | 183 | opt = parser.parse_args() 184 | 185 | if opt.gpu == "auto": 186 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[]) 187 | gpu = deviceIDs[0] 188 | else: 189 | gpu = opt.gpu 190 | 191 | if (not gpu == 'ignore'): 192 | os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu) 193 | 194 | evaluate(conf=opt.conf, 195 | expname=opt.expname, 196 | exps_folder_name=opt.exps_folder, 197 | evals_folder_name=opt.evals_folder, 198 | timestamp=opt.timestamp, 199 | checkpoint=opt.checkpoint, 200 | scan_id=opt.scan_id, 201 | resolution=opt.resolution, 202 | eval_rendering=opt.eval_rendering, 203 | ) 204 | -------------------------------------------------------------------------------- /code/model/density.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Density(nn.Module): 6 | def __init__(self, params_init={}): 7 | super().__init__() 8 | for p in params_init: 9 | param = nn.Parameter(torch.tensor(params_init[p])) 10 | setattr(self, p, param) 11 | 12 | def forward(self, sdf, beta=None): 13 | return self.density_func(sdf, beta=beta) 14 | 15 | 16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) 17 | def __init__(self, params_init={}, beta_min=0.0001): 18 | super().__init__(params_init=params_init) 19 | self.beta_min = torch.tensor(beta_min).cuda() 20 | 21 | def density_func(self, sdf, beta=None): 22 | if beta is None: 23 | beta = self.get_beta() 24 | 25 | alpha = 1 / beta 26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) 27 | 28 | def get_beta(self): 29 | beta = self.beta.abs() + self.beta_min 30 | return beta 31 | 32 | 33 | class AbsDensity(Density): # like NeRF++ 34 | def density_func(self, sdf, beta=None): 35 | return torch.abs(sdf) 36 | 37 | 38 | class SimpleDensity(Density): # like NeRF 39 | def __init__(self, params_init={}, noise_std=1.0): 40 | super().__init__(params_init=params_init) 41 | self.noise_std = noise_std 42 | 43 | def density_func(self, sdf, beta=None): 44 | if self.training and self.noise_std > 0.0: 45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std 
46 | sdf = sdf + noise 47 | return torch.relu(sdf) 48 | -------------------------------------------------------------------------------- /code/model/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs['input_dims'] 13 | out_dim = 0 14 | if self.kwargs['include_input']: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs['max_freq_log2'] 19 | N_freqs = self.kwargs['num_freqs'] 20 | 21 | if self.kwargs['log_sampling']: 22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs['periodic_fns']: 28 | embed_fns.append(lambda x, p_fn=p_fn, 29 | freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | def get_embedder(multires, input_dims=3): 39 | embed_kwargs = { 40 | 'include_input': True, 41 | 'input_dims': input_dims, 42 | 'max_freq_log2': multires-1, 43 | 'num_freqs': multires, 44 | 'log_sampling': True, 45 | 'periodic_fns': [torch.sin, torch.cos], 46 | } 47 | 48 | embedder_obj = Embedder(**embed_kwargs) 49 | def embed(x, eo=embedder_obj): return eo.embed(x) 50 | return embed, embedder_obj.out_dim 51 | -------------------------------------------------------------------------------- /code/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import utils.general as utils 4 | 5 | 6 | class VolSDFLoss(nn.Module): 7 | def __init__(self, rgb_loss, eikonal_weight): 8 | super().__init__() 9 | self.eikonal_weight = eikonal_weight 10 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean') 11 | 12 | def get_rgb_loss(self,rgb_values, rgb_gt): 13 | rgb_gt = rgb_gt.reshape(-1, 3) 14 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt) 15 | return rgb_loss 16 | 17 | def get_eikonal_loss(self, grad_theta): 18 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() 19 | return eikonal_loss 20 | 21 | def forward(self, model_outputs, ground_truth): 22 | rgb_gt = ground_truth['rgb'].cuda() 23 | 24 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt) 25 | if 'grad_theta' in model_outputs: 26 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) 27 | else: 28 | eikonal_loss = torch.tensor(0.0).cuda().float() 29 | 30 | loss = rgb_loss + \ 31 | self.eikonal_weight * eikonal_loss 32 | 33 | output = { 34 | 'loss': loss, 35 | 'rgb_loss': rgb_loss, 36 | 'eikonal_loss': eikonal_loss, 37 | } 38 | 39 | return output 40 | -------------------------------------------------------------------------------- /code/model/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 4 | from utils import rend_util 5 | from model.embedder import * 6 | from model.density import LaplaceDensity 7 | from model.ray_sampler import ErrorBoundSampler 8 | 9 | class ImplicitNetwork(nn.Module): 10 | def __init__( 11 | self, 12 | feature_vector_size, 13 | 
sdf_bounding_sphere, 14 | d_in, 15 | d_out, 16 | dims, 17 | geometric_init=True, 18 | bias=1.0, 19 | skip_in=(), 20 | weight_norm=True, 21 | multires=0, 22 | sphere_scale=1.0, 23 | ): 24 | super().__init__() 25 | 26 | self.sdf_bounding_sphere = sdf_bounding_sphere 27 | self.sphere_scale = sphere_scale 28 | dims = [d_in] + dims + [d_out + feature_vector_size] 29 | 30 | self.embed_fn = None 31 | if multires > 0: 32 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 33 | self.embed_fn = embed_fn 34 | dims[0] = input_ch 35 | 36 | self.num_layers = len(dims) 37 | self.skip_in = skip_in 38 | 39 | for l in range(0, self.num_layers - 1): 40 | if l + 1 in self.skip_in: 41 | out_dim = dims[l + 1] - dims[0] 42 | else: 43 | out_dim = dims[l + 1] 44 | 45 | lin = nn.Linear(dims[l], out_dim) 46 | 47 | if geometric_init: 48 | if l == self.num_layers - 2: 49 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 50 | torch.nn.init.constant_(lin.bias, -bias) 51 | elif multires > 0 and l == 0: 52 | torch.nn.init.constant_(lin.bias, 0.0) 53 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 54 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 55 | elif multires > 0 and l in self.skip_in: 56 | torch.nn.init.constant_(lin.bias, 0.0) 57 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 58 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 59 | else: 60 | torch.nn.init.constant_(lin.bias, 0.0) 61 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 62 | 63 | if weight_norm: 64 | lin = nn.utils.weight_norm(lin) 65 | 66 | setattr(self, "lin" + str(l), lin) 67 | 68 | self.softplus = nn.Softplus(beta=100) 69 | 70 | def forward(self, input): 71 | if self.embed_fn is not None: 72 | input = self.embed_fn(input) 73 | 74 | x = input 75 | 76 | for l in range(0, self.num_layers - 1): 77 | lin = getattr(self, "lin" + str(l)) 78 | 79 | if l in self.skip_in: 80 | x = torch.cat([x, input], 1) / np.sqrt(2) 81 | 82 | x = lin(x) 83 | 84 | if l < self.num_layers - 2: 85 | x = self.softplus(x) 86 | 87 | return x 88 | 89 | def gradient(self, x): 90 | x.requires_grad_(True) 91 | y = self.forward(x)[:,:1] 92 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 93 | gradients = torch.autograd.grad( 94 | outputs=y, 95 | inputs=x, 96 | grad_outputs=d_output, 97 | create_graph=True, 98 | retain_graph=True, 99 | only_inputs=True)[0] 100 | return gradients 101 | 102 | def get_outputs(self, x): 103 | x.requires_grad_(True) 104 | output = self.forward(x) 105 | sdf = output[:,:1] 106 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 107 | if self.sdf_bounding_sphere > 0.0: 108 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 109 | sdf = torch.minimum(sdf, sphere_sdf) 110 | feature_vectors = output[:, 1:] 111 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 112 | gradients = torch.autograd.grad( 113 | outputs=sdf, 114 | inputs=x, 115 | grad_outputs=d_output, 116 | create_graph=True, 117 | retain_graph=True, 118 | only_inputs=True)[0] 119 | 120 | return sdf, feature_vectors, gradients 121 | 122 | def get_sdf_vals(self, x): 123 | sdf = self.forward(x)[:,:1] 124 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 125 | if self.sdf_bounding_sphere > 0.0: 126 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 
127 | sdf = torch.minimum(sdf, sphere_sdf) 128 | return sdf 129 | 130 | 131 | class RenderingNetwork(nn.Module): 132 | def __init__( 133 | self, 134 | feature_vector_size, 135 | mode, 136 | d_in, 137 | d_out, 138 | dims, 139 | weight_norm=True, 140 | multires_view=0, 141 | ): 142 | super().__init__() 143 | 144 | self.mode = mode 145 | dims = [d_in + feature_vector_size] + dims + [d_out] 146 | 147 | self.embedview_fn = None 148 | if multires_view > 0: 149 | embedview_fn, input_ch = get_embedder(multires_view) 150 | self.embedview_fn = embedview_fn 151 | dims[0] += (input_ch - 3) 152 | 153 | self.num_layers = len(dims) 154 | 155 | for l in range(0, self.num_layers - 1): 156 | out_dim = dims[l + 1] 157 | lin = nn.Linear(dims[l], out_dim) 158 | 159 | if weight_norm: 160 | lin = nn.utils.weight_norm(lin) 161 | 162 | setattr(self, "lin" + str(l), lin) 163 | 164 | self.relu = nn.ReLU() 165 | self.sigmoid = torch.nn.Sigmoid() 166 | 167 | def forward(self, points, normals, view_dirs, feature_vectors): 168 | if self.embedview_fn is not None: 169 | view_dirs = self.embedview_fn(view_dirs) 170 | 171 | if self.mode == 'idr': 172 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 173 | elif self.mode == 'nerf': 174 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1) 175 | 176 | x = rendering_input 177 | 178 | for l in range(0, self.num_layers - 1): 179 | lin = getattr(self, "lin" + str(l)) 180 | 181 | x = lin(x) 182 | 183 | if l < self.num_layers - 2: 184 | x = self.relu(x) 185 | 186 | x = self.sigmoid(x) 187 | return x 188 | 189 | class VolSDFNetwork(nn.Module): 190 | def __init__(self, conf): 191 | super().__init__() 192 | self.feature_vector_size = conf.get_int('feature_vector_size') 193 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0) 194 | self.white_bkgd = conf.get_bool('white_bkgd', default=False) 195 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda() 196 | 197 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network')) 198 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network')) 199 | 200 | self.density = LaplaceDensity(**conf.get_config('density')) 201 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler')) 202 | 203 | def forward(self, input): 204 | # Parse model input 205 | intrinsics = input["intrinsics"] 206 | uv = input["uv"] 207 | pose = input["pose"] 208 | 209 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) 210 | 211 | batch_size, num_pixels, _ = ray_dirs.shape 212 | 213 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) 214 | ray_dirs = ray_dirs.reshape(-1, 3) 215 | 216 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self) 217 | N_samples = z_vals.shape[1] 218 | 219 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) 220 | points_flat = points.reshape(-1, 3) 221 | 222 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) 223 | dirs_flat = dirs.reshape(-1, 3) 224 | 225 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat) 226 | 227 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors) 228 | rgb = rgb_flat.reshape(-1, N_samples, 3) 229 | 230 | weights = self.volume_rendering(z_vals, sdf) 231 | 232 | 
rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) 233 | 234 | # white background assumption 235 | if self.white_bkgd: 236 | acc_map = torch.sum(weights, -1) 237 | rgb_values = rgb_values + (1. - acc_map[..., None]) * self.bg_color.unsqueeze(0) 238 | 239 | output = { 240 | 'rgb_values': rgb_values, 241 | } 242 | 243 | if self.training: 244 | # Sample points for the eikonal loss 245 | n_eik_points = batch_size * num_pixels 246 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda() 247 | 248 | # add some of the near surface points 249 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3) 250 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0) 251 | 252 | grad_theta = self.implicit_network.gradient(eikonal_points) 253 | output['grad_theta'] = grad_theta 254 | 255 | if not self.training: 256 | gradients = gradients.detach() 257 | normals = gradients / gradients.norm(2, -1, keepdim=True) 258 | normals = normals.reshape(-1, N_samples, 3) 259 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) 260 | 261 | output['normal_map'] = normal_map 262 | 263 | return output 264 | 265 | def volume_rendering(self, z_vals, sdf): 266 | density_flat = self.density(sdf) 267 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples 268 | 269 | dists = z_vals[:, 1:] - z_vals[:, :-1] 270 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 271 | 272 | # LOG SPACE 273 | free_energy = dists * density 274 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step 275 | alpha = 1 - torch.exp(-free_energy) # probability of it is not empty here 276 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability of everything is empty up to now 277 | weights = alpha * transmittance # probability of the ray hits something here 278 | 279 | return weights 280 | -------------------------------------------------------------------------------- /code/model/network_bg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | 5 | import utils.general as utils 6 | from utils import rend_util 7 | from model.network import ImplicitNetwork, RenderingNetwork 8 | from model.density import LaplaceDensity, AbsDensity 9 | from model.ray_sampler import ErrorBoundSampler 10 | 11 | 12 | """ 13 | For modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++ 14 | https://github.com/Kai-46/nerfplusplus 15 | """ 16 | 17 | 18 | class VolSDFNetworkBG(nn.Module): 19 | def __init__(self, conf): 20 | super().__init__() 21 | self.feature_vector_size = conf.get_int('feature_vector_size') 22 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0) 23 | 24 | # Foreground object's networks 25 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.get_config('implicit_network')) 26 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network')) 27 | 28 | self.density = LaplaceDensity(**conf.get_config('density')) 29 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=True, **conf.get_config('ray_sampler')) 30 | 31 | # Background's networks 32 | bg_feature_vector_size = 
conf.get_int('bg_network.feature_vector_size') 33 | self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.get_config('bg_network.implicit_network')) 34 | self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.get_config('bg_network.rendering_network')) 35 | self.bg_density = AbsDensity(**conf.get_config('bg_network.density', default={})) 36 | 37 | def forward(self, input): 38 | # Parse model input 39 | intrinsics = input["intrinsics"] 40 | uv = input["uv"] 41 | pose = input["pose"] 42 | 43 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) 44 | 45 | batch_size, num_pixels, _ = ray_dirs.shape 46 | 47 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) 48 | ray_dirs = ray_dirs.reshape(-1, 3) 49 | 50 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self) 51 | 52 | z_vals, z_vals_bg = z_vals 53 | z_max = z_vals[:,-1] 54 | z_vals = z_vals[:,:-1] 55 | N_samples = z_vals.shape[1] 56 | 57 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) 58 | points_flat = points.reshape(-1, 3) 59 | 60 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) 61 | dirs_flat = dirs.reshape(-1, 3) 62 | 63 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat) 64 | 65 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors) 66 | rgb = rgb_flat.reshape(-1, N_samples, 3) 67 | 68 | weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf) 69 | 70 | fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) 71 | 72 | 73 | # Background rendering 74 | N_bg_samples = z_vals_bg.shape[1] 75 | z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0 76 | 77 | bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1) 78 | bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1) 79 | 80 | bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4] 81 | bg_points_flat = bg_points.reshape(-1, 4) 82 | bg_dirs_flat = bg_dirs.reshape(-1, 3) 83 | 84 | output = self.bg_implicit_network(bg_points_flat) 85 | bg_sdf = output[:,:1] 86 | bg_feature_vectors = output[:, 1:] 87 | bg_rgb_flat = self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors) 88 | bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3) 89 | 90 | bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf) 91 | 92 | bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1) 93 | 94 | 95 | # Composite foreground and background 96 | bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values 97 | rgb_values = fg_rgb_values + bg_rgb_values 98 | 99 | output = { 100 | 'rgb_values': rgb_values, 101 | } 102 | 103 | if self.training: 104 | # Sample points for the eikonal loss 105 | n_eik_points = batch_size * num_pixels 106 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda() 107 | 108 | # add some of the near surface points 109 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3) 110 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0) 111 | 112 | grad_theta = self.implicit_network.gradient(eikonal_points) 113 | output['grad_theta'] = grad_theta 114 | 115 | if not self.training: 116 | gradients = gradients.detach() 117 | normals = gradients / gradients.norm(2, -1, keepdim=True) 118 | normals = normals.reshape(-1, N_samples, 3) 119 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) 120 | 121 | 
output['normal_map'] = normal_map 122 | 123 | return output 124 | 125 | def volume_rendering(self, z_vals, z_max, sdf): 126 | density_flat = self.density(sdf) 127 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples 128 | 129 | # included also the dist from the sphere intersection 130 | dists = z_vals[:, 1:] - z_vals[:, :-1] 131 | dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1) 132 | 133 | # LOG SPACE 134 | free_energy = dists * density 135 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy], dim=-1) # add 0 for transperancy 1 at t_0 136 | alpha = 1 - torch.exp(-free_energy) # probability of it is not empty here 137 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability of everything is empty up to now 138 | fg_transmittance = transmittance[:, :-1] 139 | weights = alpha * fg_transmittance # probability of the ray hits something here 140 | bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering 141 | 142 | return weights, bg_transmittance 143 | 144 | def bg_volume_rendering(self, z_vals_bg, bg_sdf): 145 | bg_density_flat = self.bg_density(bg_sdf) 146 | bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples 147 | 148 | bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:] 149 | bg_dists = torch.cat([bg_dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1) 150 | 151 | # LOG SPACE 152 | bg_free_energy = bg_dists * bg_density 153 | bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1).cuda(), bg_free_energy[:, :-1]], dim=-1) # shift one step 154 | bg_alpha = 1 - torch.exp(-bg_free_energy) # probability of it is not empty here 155 | bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability of everything is empty up to now 156 | bg_weights = bg_alpha * bg_transmittance # probability of the ray hits something here 157 | 158 | return bg_weights 159 | 160 | def depth2pts_outside(self, ray_o, ray_d, depth): 161 | 162 | ''' 163 | ray_o, ray_d: [..., 3] 164 | depth: [...]; inverse of distance to sphere origin 165 | ''' 166 | 167 | o_dot_d = torch.sum(ray_d * ray_o, dim=-1) 168 | under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2) 169 | d_sphere = torch.sqrt(under_sqrt) - o_dot_d 170 | p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d 171 | p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d 172 | p_mid_norm = torch.norm(p_mid, dim=-1) 173 | 174 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1) 175 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) 176 | phi = torch.asin(p_mid_norm / self.scene_bounding_sphere) 177 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1] 178 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1] 179 | 180 | # now rotate p_sphere 181 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula 182 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \ 183 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \ 184 | rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. 
- torch.cos(rot_angle)) 185 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True) 186 | pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1) 187 | 188 | return pts 189 | -------------------------------------------------------------------------------- /code/model/ray_sampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | from utils import rend_util 5 | 6 | class RaySampler(metaclass=abc.ABCMeta): 7 | def __init__(self,near, far): 8 | self.near = near 9 | self.far = far 10 | 11 | @abc.abstractmethod 12 | def get_z_vals(self, ray_dirs, cam_loc, model): 13 | pass 14 | 15 | class UniformSampler(RaySampler): 16 | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1): 17 | super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R 18 | self.N_samples = N_samples 19 | self.scene_bounding_sphere = scene_bounding_sphere 20 | self.take_sphere_intersection = take_sphere_intersection 21 | 22 | def get_z_vals(self, ray_dirs, cam_loc, model): 23 | if not self.take_sphere_intersection: 24 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda() 25 | else: 26 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere) 27 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda() 28 | far = sphere_intersections[:,1:] 29 | 30 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda() 31 | z_vals = near * (1. - t_vals) + far * (t_vals) 32 | 33 | if model.training: 34 | # get intervals between samples 35 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 36 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 37 | lower = torch.cat([z_vals[..., :1], mids], -1) 38 | # stratified samples in those intervals 39 | t_rand = torch.rand(z_vals.shape).cuda() 40 | 41 | z_vals = lower + (upper - lower) * t_rand 42 | 43 | return z_vals 44 | 45 | 46 | class ErrorBoundSampler(RaySampler): 47 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra, 48 | eps, beta_iters, max_total_iters, 49 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0): 50 | super().__init__(near, 2.0 * scene_bounding_sphere) 51 | self.N_samples = N_samples 52 | self.N_samples_eval = N_samples_eval 53 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) 54 | 55 | self.N_samples_extra = N_samples_extra 56 | 57 | self.eps = eps 58 | self.beta_iters = beta_iters 59 | self.max_total_iters = max_total_iters 60 | self.scene_bounding_sphere = scene_bounding_sphere 61 | self.add_tiny = add_tiny 62 | 63 | self.inverse_sphere_bg = inverse_sphere_bg 64 | if inverse_sphere_bg: 65 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0) 66 | 67 | def get_z_vals(self, ray_dirs, cam_loc, model): 68 | beta0 = model.density.get_beta().detach() 69 | 70 | # Start with uniform sampling 71 | z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model) 72 | samples, samples_idx = z_vals, None 73 | 74 | # Get maximum beta from the upper bound (Lemma 2) 75 | dists = z_vals[:, 1:] - z_vals[:, :-1] 76 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1) 77 | beta = torch.sqrt(bound) 78 | 79 | total_iters, not_converge = 0, True 80 | 81 | # Algorithm 1 82 | while 
not_converge and total_iters < self.max_total_iters: 83 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1) 84 | points_flat = points.reshape(-1, 3) 85 | 86 | # Calculating the SDF only for the new sampled points 87 | with torch.no_grad(): 88 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat) 89 | if samples_idx is not None: 90 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]), 91 | samples_sdf.reshape(-1, samples.shape[1])], -1) 92 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1) 93 | else: 94 | sdf = samples_sdf 95 | 96 | 97 | # Calculating the bound d* (Theorem 1) 98 | d = sdf.reshape(z_vals.shape) 99 | dists = z_vals[:, 1:] - z_vals[:, :-1] 100 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs() 101 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2) 102 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2) 103 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda() 104 | d_star[first_cond] = b[first_cond] 105 | d_star[second_cond] = c[second_cond] 106 | s = (a + b + c) / 2.0 107 | area_before_sqrt = s * (s - a) * (s - b) * (s - c) 108 | mask = ~first_cond & ~second_cond & (b + c - a > 0) 109 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask]) 110 | d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign 111 | 112 | 113 | # Updating beta using line search 114 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star) 115 | beta[curr_error <= self.eps] = beta0 116 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta 117 | for j in range(self.beta_iters): 118 | beta_mid = (beta_min + beta_max) / 2. 119 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star) 120 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps] 121 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps] 122 | beta = beta_max 123 | 124 | 125 | # Upsample more points 126 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1)) 127 | 128 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 129 | free_energy = dists * density 130 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) 131 | alpha = 1 - torch.exp(-free_energy) 132 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) 133 | weights = alpha * transmittance # probability of the ray hits something here 134 | 135 | # Check if we are done and this is the last sampling 136 | total_iters += 1 137 | not_converge = beta.max() > beta0 138 | 139 | if not_converge and total_iters < self.max_total_iters: 140 | ''' Sample more points proportional to the current error bound''' 141 | 142 | N = self.N_samples_eval 143 | 144 | bins = z_vals 145 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) 
/ (4 * beta.unsqueeze(-1) ** 2) 146 | error_integral = torch.cumsum(error_per_section, dim=-1) 147 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1] 148 | 149 | pdf = bound_opacity + self.add_tiny 150 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 151 | cdf = torch.cumsum(pdf, -1) 152 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 153 | 154 | else: 155 | ''' Sample the final sample set to be used in the volume rendering integral ''' 156 | 157 | N = self.N_samples 158 | 159 | bins = z_vals 160 | pdf = weights[..., :-1] 161 | pdf = pdf + 1e-5 # prevent nans 162 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 163 | cdf = torch.cumsum(pdf, -1) 164 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 165 | 166 | 167 | # Invert CDF 168 | if (not_converge and total_iters < self.max_total_iters) or (not model.training): 169 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1) 170 | else: 171 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda() 172 | u = u.contiguous() 173 | 174 | inds = torch.searchsorted(cdf, u, right=True) 175 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 176 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 177 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 178 | 179 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 180 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 181 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 182 | 183 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 184 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 185 | t = (u - cdf_g[..., 0]) / denom 186 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 187 | 188 | 189 | # Adding samples if we not converged 190 | if not_converge and total_iters < self.max_total_iters: 191 | z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1) 192 | 193 | 194 | z_samples = samples 195 | 196 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda() 197 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection 198 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:] 199 | 200 | if self.N_samples_extra > 0: 201 | if model.training: 202 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra] 203 | else: 204 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long() 205 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1) 206 | else: 207 | z_vals_extra = torch.cat([near, far], -1) 208 | 209 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1) 210 | 211 | # add some of the near surface points 212 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda() 213 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1)) 214 | 215 | if self.inverse_sphere_bg: 216 | z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model) 217 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere) 218 | z_vals = (z_vals, z_vals_inverse_sphere) 219 | 220 | return z_vals, z_samples_eik 221 | 222 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star): 223 | density = model.density(sdf.reshape(z_vals.shape), beta=beta) 224 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 
1).cuda(), dists * density[:, :-1]], dim=-1) 225 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1) 226 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2) 227 | error_integral = torch.cumsum(error_per_section, dim=-1) 228 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1]) 229 | 230 | return bound_opacity.max(-1)[0] 231 | 232 | 233 | -------------------------------------------------------------------------------- /code/training/exp_runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../code') 4 | import argparse 5 | import GPUtil 6 | 7 | from training.volsdf_train import VolSDFTrainRunner 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 13 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for') 14 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 15 | parser.add_argument('--expname', type=str, default='') 16 | parser.add_argument("--exps_folder", type=str, default="exps") 17 | parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') 18 | parser.add_argument('--is_continue', default=False, action="store_true", 19 | help='If set, indicates continuing from a previous run.') 20 | parser.add_argument('--timestamp', default='latest', type=str, 21 | help='The timestamp of the run to be used in case of continuing from a previous run.') 22 | parser.add_argument('--checkpoint', default='latest', type=str, 23 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.') 24 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') 25 | parser.add_argument('--cancel_vis', default=False, action="store_true", 26 | help='If set, cancel visualization in intermediate epochs.') 27 | 28 | opt = parser.parse_args() 29 | 30 | if opt.gpu == "auto": 31 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, 32 | excludeID=[], excludeUUID=[]) 33 | gpu = deviceIDs[0] 34 | else: 35 | gpu = opt.gpu 36 | 37 | trainrunner = VolSDFTrainRunner(conf=opt.conf, 38 | batch_size=opt.batch_size, 39 | nepochs=opt.nepoch, 40 | expname=opt.expname, 41 | gpu_index=gpu, 42 | exps_folder_name=opt.exps_folder, 43 | is_continue=opt.is_continue, 44 | timestamp=opt.timestamp, 45 | checkpoint=opt.checkpoint, 46 | scan_id=opt.scan_id, 47 | do_vis=not opt.cancel_vis 48 | ) 49 | 50 | trainrunner.run() 51 | -------------------------------------------------------------------------------- /code/training/volsdf_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from pyhocon import ConfigFactory 4 | import sys 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import utils.general as utils 9 | import utils.plots as plt 10 | from utils import rend_util 11 | 12 | class VolSDFTrainRunner(): 13 | def __init__(self,**kwargs): 14 | torch.set_default_dtype(torch.float32) 15 | torch.set_num_threads(1) 16 | 17 | self.conf = ConfigFactory.parse_file(kwargs['conf']) 18 | self.batch_size = kwargs['batch_size'] 19 | self.nepochs = kwargs['nepochs'] 20 | self.exps_folder_name = kwargs['exps_folder_name'] 21 | self.GPU_INDEX = kwargs['gpu_index'] 22 | 
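# The run name below combines train.expname from the conf with the optional --expname suffix; when a scan id is supplied (via --scan_id or the conf), it is appended to the name as well.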
23 | self.expname = self.conf.get_string('train.expname') + kwargs['expname'] 24 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1) 25 | if scan_id != -1: 26 | self.expname = self.expname + '_{0}'.format(scan_id) 27 | 28 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest': 29 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)): 30 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname)) 31 | if (len(timestamps)) == 0: 32 | is_continue = False 33 | timestamp = None 34 | else: 35 | timestamp = sorted(timestamps)[-1] 36 | is_continue = True 37 | else: 38 | is_continue = False 39 | timestamp = None 40 | else: 41 | timestamp = kwargs['timestamp'] 42 | is_continue = kwargs['is_continue'] 43 | 44 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name)) 45 | self.expdir = os.path.join('../', self.exps_folder_name, self.expname) 46 | utils.mkdir_ifnotexists(self.expdir) 47 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) 48 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp)) 49 | 50 | self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots') 51 | utils.mkdir_ifnotexists(self.plots_dir) 52 | 53 | # create checkpoints dirs 54 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints') 55 | utils.mkdir_ifnotexists(self.checkpoints_path) 56 | self.model_params_subdir = "ModelParameters" 57 | self.optimizer_params_subdir = "OptimizerParameters" 58 | self.scheduler_params_subdir = "SchedulerParameters" 59 | 60 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir)) 61 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir)) 62 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir)) 63 | 64 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf'))) 65 | 66 | if (not self.GPU_INDEX == 'ignore'): 67 | os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX) 68 | 69 | print('shell command : {0}'.format(' '.join(sys.argv))) 70 | 71 | print('Loading data ...') 72 | 73 | dataset_conf = self.conf.get_config('dataset') 74 | if kwargs['scan_id'] != -1: 75 | dataset_conf['scan_id'] = kwargs['scan_id'] 76 | 77 | self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf) 78 | 79 | self.ds_len = len(self.train_dataset) 80 | print('Finish loading data. 
Data-set size: {0}'.format(self.ds_len)) 81 | if scan_id < 24 and scan_id > 0: # BlendedMVS, running for 200k iterations 82 | self.nepochs = int(200000 / self.ds_len) 83 | print('RUNNING FOR {0}'.format(self.nepochs)) 84 | 85 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset, 86 | batch_size=self.batch_size, 87 | shuffle=True, 88 | collate_fn=self.train_dataset.collate_fn 89 | ) 90 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset, 91 | batch_size=self.conf.get_int('plot.plot_nimgs'), 92 | shuffle=True, 93 | collate_fn=self.train_dataset.collate_fn 94 | ) 95 | 96 | conf_model = self.conf.get_config('model') 97 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model) 98 | if torch.cuda.is_available(): 99 | self.model.cuda() 100 | 101 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss')) 102 | 103 | self.lr = self.conf.get_float('train.learning_rate') 104 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) 105 | # Exponential learning rate scheduler 106 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1) 107 | decay_steps = self.nepochs * len(self.train_dataset) 108 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps)) 109 | 110 | self.do_vis = kwargs['do_vis'] 111 | 112 | self.start_epoch = 0 113 | if is_continue: 114 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints') 115 | 116 | saved_model_state = torch.load( 117 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 118 | self.model.load_state_dict(saved_model_state["model_state_dict"]) 119 | self.start_epoch = saved_model_state['epoch'] 120 | 121 | data = torch.load( 122 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth")) 123 | self.optimizer.load_state_dict(data["optimizer_state_dict"]) 124 | 125 | data = torch.load( 126 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth")) 127 | self.scheduler.load_state_dict(data["scheduler_state_dict"]) 128 | 129 | self.num_pixels = self.conf.get_int('train.num_pixels') 130 | self.total_pixels = self.train_dataset.total_pixels 131 | self.img_res = self.train_dataset.img_res 132 | self.n_batches = len(self.train_dataloader) 133 | self.plot_freq = self.conf.get_int('train.plot_freq') 134 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100) 135 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000) 136 | self.plot_conf = self.conf.get_config('plot') 137 | 138 | def save_checkpoints(self, epoch): 139 | torch.save( 140 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 141 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth")) 142 | torch.save( 143 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 144 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth")) 145 | 146 | torch.save( 147 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 148 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth")) 149 | torch.save( 150 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 151 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth")) 152 | 153 | torch.save( 154 | {"epoch": epoch, "scheduler_state_dict": 
self.scheduler.state_dict()}, 155 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth")) 156 | torch.save( 157 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()}, 158 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth")) 159 | 160 | def run(self): 161 | print("training...") 162 | 163 | for epoch in range(self.start_epoch, self.nepochs + 1): 164 | 165 | if epoch % self.checkpoint_freq == 0: 166 | self.save_checkpoints(epoch) 167 | 168 | if self.do_vis and epoch % self.plot_freq == 0: 169 | self.model.eval() 170 | 171 | self.train_dataset.change_sampling_idx(-1) 172 | indices, model_input, ground_truth = next(iter(self.plot_dataloader)) 173 | 174 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 175 | model_input["uv"] = model_input["uv"].cuda() 176 | model_input['pose'] = model_input['pose'].cuda() 177 | 178 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels) 179 | res = [] 180 | for s in tqdm(split): 181 | out = self.model(s) 182 | d = {'rgb_values': out['rgb_values'].detach(), 183 | 'normal_map': out['normal_map'].detach()} 184 | res.append(d) 185 | 186 | batch_size = ground_truth['rgb'].shape[0] 187 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size) 188 | plot_data = self.get_plot_data(model_outputs, model_input['pose'], ground_truth['rgb']) 189 | 190 | plt.plot(self.model.implicit_network, 191 | indices, 192 | plot_data, 193 | self.plots_dir, 194 | epoch, 195 | self.img_res, 196 | **self.plot_conf 197 | ) 198 | 199 | self.model.train() 200 | 201 | self.train_dataset.change_sampling_idx(self.num_pixels) 202 | 203 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader): 204 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 205 | model_input["uv"] = model_input["uv"].cuda() 206 | model_input['pose'] = model_input['pose'].cuda() 207 | 208 | model_outputs = self.model(model_input) 209 | loss_output = self.loss(model_outputs, ground_truth) 210 | 211 | loss = loss_output['loss'] 212 | 213 | self.optimizer.zero_grad() 214 | loss.backward() 215 | self.optimizer.step() 216 | 217 | psnr = rend_util.get_psnr(model_outputs['rgb_values'], 218 | ground_truth['rgb'].cuda().reshape(-1,3)) 219 | print( 220 | '{0}_{1} [{2}] ({3}/{4}): loss = {5}, rgb_loss = {6}, eikonal_loss = {7}, psnr = {8}' 221 | .format(self.expname, self.timestamp, epoch, data_index, self.n_batches, loss.item(), 222 | loss_output['rgb_loss'].item(), 223 | loss_output['eikonal_loss'].item(), 224 | psnr.item())) 225 | 226 | self.train_dataset.change_sampling_idx(self.num_pixels) 227 | self.scheduler.step() 228 | 229 | self.save_checkpoints(epoch) 230 | 231 | def get_plot_data(self, model_outputs, pose, rgb_gt): 232 | batch_size, num_samples, _ = rgb_gt.shape 233 | 234 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3) 235 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3) 236 | normal_map = (normal_map + 1.) / 2. 
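# The normals above are shifted from [-1, 1] into [0, 1]; everything the plotting utilities need (GT image, camera pose, rendered colors, normal map) is bundled below.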
237 | 238 | plot_data = { 239 | 'rgb_gt': rgb_gt, 240 | 'pose': pose, 241 | 'rgb_eval': rgb_eval, 242 | 'normal_map': normal_map, 243 | } 244 | 245 | return plot_data 246 | -------------------------------------------------------------------------------- /code/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import torch 4 | 5 | def mkdir_ifnotexists(directory): 6 | if not os.path.exists(directory): 7 | os.mkdir(directory) 8 | 9 | def get_class(kls): 10 | parts = kls.split('.') 11 | module = ".".join(parts[:-1]) 12 | m = __import__(module) 13 | for comp in parts[1:]: 14 | m = getattr(m, comp) 15 | return m 16 | 17 | def glob_imgs(path): 18 | imgs = [] 19 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']: 20 | imgs.extend(glob(os.path.join(path, ext))) 21 | return imgs 22 | 23 | def split_input(model_input, total_pixels, n_pixels=10000): 24 | ''' 25 | Split the input to fit Cuda memory for large resolution. 26 | Can decrease the value of n_pixels in case of cuda out of memory error. 27 | ''' 28 | split = [] 29 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)): 30 | data = model_input.copy() 31 | data['uv'] = torch.index_select(model_input['uv'], 1, indx) 32 | if 'object_mask' in data: 33 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx) 34 | split.append(data) 35 | return split 36 | 37 | def merge_output(res, total_pixels, batch_size): 38 | ''' Merge the split output. ''' 39 | 40 | model_outputs = {} 41 | for entry in res[0]: 42 | if res[0][entry] is None: 43 | continue 44 | if len(res[0][entry].shape) == 1: 45 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res], 46 | 1).reshape(batch_size * total_pixels) 47 | else: 48 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res], 49 | 1).reshape(batch_size * total_pixels, -1) 50 | 51 | return model_outputs 52 | 53 | def concat_home_dir(path): 54 | return os.path.join(os.environ['HOME'],'data',path) -------------------------------------------------------------------------------- /code/utils/plots.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objs as go 2 | import plotly.offline as offline 3 | from plotly.subplots import make_subplots 4 | import numpy as np 5 | import torch 6 | from skimage import measure 7 | import torchvision 8 | import trimesh 9 | from PIL import Image 10 | 11 | from utils import rend_util 12 | 13 | 14 | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution, grid_boundary, level=0): 15 | 16 | if plot_data is not None: 17 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose']) 18 | 19 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res) 20 | 21 | # plot normal maps 22 | plot_normal_maps(plot_data['normal_map'], path, epoch, plot_nimgs, img_res) 23 | 24 | 25 | data = [] 26 | 27 | # plot surface 28 | surface_traces = get_surface_trace(path=path, 29 | epoch=epoch, 30 | sdf=lambda x: implicit_network(x)[:, 0], 31 | resolution=resolution, 32 | grid_boundary=grid_boundary, 33 | level=level 34 | ) 35 | 36 | if surface_traces is not None: 37 | data.append(surface_traces[0]) 38 | 39 | # plot cameras locations 40 | if plot_data is not None: 41 | for i, loc, dir in zip(indices, cam_loc, cam_dir): 42 | 
data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i))) 43 | 44 | fig = go.Figure(data=data) 45 | scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False), 46 | yaxis=dict(range=[-6, 6], autorange=False), 47 | zaxis=dict(range=[-6, 6], autorange=False), 48 | aspectratio=dict(x=1, y=1, z=1)) 49 | fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True) 50 | filename = '{0}/surface_{1}.html'.format(path, epoch) 51 | offline.plot(fig, filename=filename, auto_open=False) 52 | 53 | 54 | def get_3D_scatter_trace(points, name='', size=3, caption=None): 55 | assert points.shape[1] == 3, "3d scatter plot input points are not correctely shaped " 56 | assert len(points.shape) == 2, "3d scatter plot input points are not correctely shaped " 57 | 58 | trace = go.Scatter3d( 59 | x=points[:, 0].cpu(), 60 | y=points[:, 1].cpu(), 61 | z=points[:, 2].cpu(), 62 | mode='markers', 63 | name=name, 64 | marker=dict( 65 | size=size, 66 | line=dict( 67 | width=2, 68 | ), 69 | opacity=1.0, 70 | ), text=caption) 71 | 72 | return trace 73 | 74 | 75 | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''): 76 | assert points.shape[1] == 3, "3d cone plot input points are not correctely shaped " 77 | assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped " 78 | assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped " 79 | assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped " 80 | 81 | trace = go.Cone( 82 | name=name, 83 | x=points[:, 0].cpu(), 84 | y=points[:, 1].cpu(), 85 | z=points[:, 2].cpu(), 86 | u=directions[:, 0].cpu(), 87 | v=directions[:, 1].cpu(), 88 | w=directions[:, 2].cpu(), 89 | sizemode='absolute', 90 | sizeref=0.125, 91 | showscale=False, 92 | colorscale=[[0, color], [1, color]], 93 | anchor="tail" 94 | ) 95 | 96 | return trace 97 | 98 | 99 | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0): 100 | grid = get_grid_uniform(resolution, grid_boundary) 101 | points = grid['grid_points'] 102 | 103 | z = [] 104 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 105 | z.append(sdf(pnts).detach().cpu().numpy()) 106 | z = np.concatenate(z, axis=0) 107 | 108 | if (not (np.min(z) > level or np.max(z) < level)): 109 | 110 | z = z.astype(np.float32) 111 | 112 | verts, faces, normals, values = measure.marching_cubes( 113 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 114 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 115 | level=level, 116 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 117 | grid['xyz'][0][2] - grid['xyz'][0][1], 118 | grid['xyz'][0][2] - grid['xyz'][0][1])) 119 | 120 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 121 | 122 | I, J, K = faces.transpose() 123 | 124 | traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], 125 | i=I, j=J, k=K, name='implicit_surface', 126 | color='#ffffff', opacity=1.0, flatshading=False, 127 | lighting=dict(diffuse=1, ambient=0, specular=0), 128 | lightposition=dict(x=0, y=0, z=-1), showlegend=True)] 129 | 130 | meshexport = trimesh.Trimesh(verts, faces, normals) 131 | meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply') 132 | 133 | if return_mesh: 134 | return meshexport 135 | return traces 136 | return None 137 | 138 | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, 
take_components=True): 139 | # get low res mesh to sample point cloud 140 | grid = get_grid_uniform(100, grid_boundary) 141 | z = [] 142 | points = grid['grid_points'] 143 | 144 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 145 | z.append(sdf(pnts).detach().cpu().numpy()) 146 | z = np.concatenate(z, axis=0) 147 | 148 | z = z.astype(np.float32) 149 | 150 | verts, faces, normals, values = measure.marching_cubes( 151 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 152 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 153 | level=level, 154 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 155 | grid['xyz'][0][2] - grid['xyz'][0][1], 156 | grid['xyz'][0][2] - grid['xyz'][0][1])) 157 | 158 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 159 | 160 | mesh_low_res = trimesh.Trimesh(verts, faces, normals) 161 | if take_components: 162 | components = mesh_low_res.split(only_watertight=False) 163 | areas = np.array([c.area for c in components], dtype=np.float) 164 | mesh_low_res = components[areas.argmax()] 165 | 166 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] 167 | recon_pc = torch.from_numpy(recon_pc).float().cuda() 168 | 169 | # Center and align the recon pc 170 | s_mean = recon_pc.mean(dim=0) 171 | s_cov = recon_pc - s_mean 172 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) 173 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] 174 | if torch.det(vecs) < 0: 175 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) 176 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), 177 | (recon_pc - s_mean).unsqueeze(-1)).squeeze() 178 | 179 | grid_aligned = get_grid(helper.cpu(), resolution) 180 | 181 | grid_points = grid_aligned['grid_points'] 182 | 183 | g = [] 184 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): 185 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), 186 | pnts.unsqueeze(-1)).squeeze() + s_mean) 187 | grid_points = torch.cat(g, dim=0) 188 | 189 | # MC to new grid 190 | points = grid_points 191 | z = [] 192 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 193 | z.append(sdf(pnts).detach().cpu().numpy()) 194 | z = np.concatenate(z, axis=0) 195 | 196 | meshexport = None 197 | if (not (np.min(z) > level or np.max(z) < level)): 198 | 199 | z = z.astype(np.float32) 200 | 201 | verts, faces, normals, values = measure.marching_cubes( 202 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], 203 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), 204 | level=level, 205 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 206 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 207 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) 208 | 209 | verts = torch.from_numpy(verts).cuda().float() 210 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), 211 | verts.unsqueeze(-1)).squeeze() 212 | verts = (verts + grid_points[0]).cpu().numpy() 213 | 214 | meshexport = trimesh.Trimesh(verts, faces, normals) 215 | 216 | return meshexport 217 | 218 | 219 | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False): 220 | grid_params = grid_params * [[1.5], [1.0]] 221 | 222 | # params = PLOT_DICT[scan_id] 223 | input_min = torch.tensor(grid_params[0]).float() 224 | input_max = torch.tensor(grid_params[1]).float() 225 | 226 | if higher_res: 227 | # 
get low res mesh to sample point cloud 228 | grid = get_grid(None, 100, input_min=input_min, input_max=input_max, eps=0.0) 229 | z = [] 230 | points = grid['grid_points'] 231 | 232 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 233 | z.append(sdf(pnts).detach().cpu().numpy()) 234 | z = np.concatenate(z, axis=0) 235 | 236 | z = z.astype(np.float32) 237 | 238 | verts, faces, normals, values = measure.marching_cubes( 239 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 240 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 241 | level=level, 242 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 243 | grid['xyz'][0][2] - grid['xyz'][0][1], 244 | grid['xyz'][0][2] - grid['xyz'][0][1])) 245 | 246 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 247 | 248 | mesh_low_res = trimesh.Trimesh(verts, faces, normals) 249 | components = mesh_low_res.split(only_watertight=False) 250 | areas = np.array([c.area for c in components], dtype=np.float) 251 | mesh_low_res = components[areas.argmax()] 252 | 253 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] 254 | recon_pc = torch.from_numpy(recon_pc).float().cuda() 255 | 256 | # Center and align the recon pc 257 | s_mean = recon_pc.mean(dim=0) 258 | s_cov = recon_pc - s_mean 259 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) 260 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] 261 | if torch.det(vecs) < 0: 262 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) 263 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), 264 | (recon_pc - s_mean).unsqueeze(-1)).squeeze() 265 | 266 | grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01) 267 | else: 268 | grid_aligned = get_grid(None, resolution, input_min=input_min, input_max=input_max, eps=0.0) 269 | 270 | grid_points = grid_aligned['grid_points'] 271 | 272 | if higher_res: 273 | g = [] 274 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): 275 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), 276 | pnts.unsqueeze(-1)).squeeze() + s_mean) 277 | grid_points = torch.cat(g, dim=0) 278 | 279 | # MC to new grid 280 | points = grid_points 281 | z = [] 282 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 283 | z.append(sdf(pnts).detach().cpu().numpy()) 284 | z = np.concatenate(z, axis=0) 285 | 286 | meshexport = None 287 | if (not (np.min(z) > level or np.max(z) < level)): 288 | 289 | z = z.astype(np.float32) 290 | 291 | verts, faces, normals, values = measure.marching_cubes( 292 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], 293 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), 294 | level=level, 295 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 296 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 297 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) 298 | 299 | if higher_res: 300 | verts = torch.from_numpy(verts).cuda().float() 301 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), 302 | verts.unsqueeze(-1)).squeeze() 303 | verts = (verts + grid_points[0]).cpu().numpy() 304 | else: 305 | verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]]) 306 | 307 | meshexport = trimesh.Trimesh(verts, faces, normals) 308 | 309 | # CUTTING MESH ACCORDING TO THE BOUNDING BOX 310 | if higher_res: 311 | bb = grid_params 312 | 
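# build an axis-aligned box spanning the grid_params bounds (centered between the min and max corners) and keep only the part of the mesh that lies inside it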
transformation = np.eye(4) 313 | transformation[:3, 3] = (bb[1,:] + bb[0,:])/2. 314 | bounding_box = trimesh.creation.box(extents=bb[1,:] - bb[0,:], transform=transformation) 315 | 316 | meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal) 317 | 318 | return meshexport 319 | 320 | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]): 321 | x = np.linspace(grid_boundary[0], grid_boundary[1], resolution) 322 | y = x 323 | z = x 324 | 325 | xx, yy, zz = np.meshgrid(x, y, z) 326 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 327 | 328 | return {"grid_points": grid_points.cuda(), 329 | "shortest_axis_length": 2.0, 330 | "xyz": [x, y, z], 331 | "shortest_axis_index": 0} 332 | 333 | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1): 334 | if input_min is None or input_max is None: 335 | input_min = torch.min(points, dim=0)[0].squeeze().numpy() 336 | input_max = torch.max(points, dim=0)[0].squeeze().numpy() 337 | 338 | bounding_box = input_max - input_min 339 | shortest_axis = np.argmin(bounding_box) 340 | if (shortest_axis == 0): 341 | x = np.linspace(input_min[shortest_axis] - eps, 342 | input_max[shortest_axis] + eps, resolution) 343 | length = np.max(x) - np.min(x) 344 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 345 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 346 | elif (shortest_axis == 1): 347 | y = np.linspace(input_min[shortest_axis] - eps, 348 | input_max[shortest_axis] + eps, resolution) 349 | length = np.max(y) - np.min(y) 350 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 351 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 352 | elif (shortest_axis == 2): 353 | z = np.linspace(input_min[shortest_axis] - eps, 354 | input_max[shortest_axis] + eps, resolution) 355 | length = np.max(z) - np.min(z) 356 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 357 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 358 | 359 | xx, yy, zz = np.meshgrid(x, y, z) 360 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 361 | return {"grid_points": grid_points, 362 | "shortest_axis_length": length, 363 | "xyz": [x, y, z], 364 | "shortest_axis_index": shortest_axis} 365 | 366 | 367 | def plot_normal_maps(normal_maps, path, epoch, plot_nrow, img_res): 368 | normal_maps_plot = lin2img(normal_maps, img_res) 369 | 370 | tensor = torchvision.utils.make_grid(normal_maps_plot, 371 | scale_each=False, 372 | normalize=False, 373 | nrow=plot_nrow).cpu().detach().numpy() 374 | tensor = tensor.transpose(1, 2, 0) 375 | scale_factor = 255 376 | tensor = (tensor * scale_factor).astype(np.uint8) 377 | 378 | img = Image.fromarray(tensor) 379 | img.save('{0}/normal_{1}.png'.format(path, epoch)) 380 | 381 | 382 | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res): 383 | ground_true = ground_true.cuda() 384 | 385 | output_vs_gt = torch.cat((rgb_points, ground_true), dim=0) 386 | output_vs_gt_plot = lin2img(output_vs_gt, img_res) 387 | 388 | tensor = torchvision.utils.make_grid(output_vs_gt_plot, 389 | scale_each=False, 390 | normalize=False, 391 | 
nrow=plot_nrow).cpu().detach().numpy() 392 | 393 | tensor = tensor.transpose(1, 2, 0) 394 | scale_factor = 255 395 | tensor = (tensor * scale_factor).astype(np.uint8) 396 | 397 | img = Image.fromarray(tensor) 398 | img.save('{0}/rendering_{1}.png'.format(path, epoch)) 399 | 400 | 401 | def lin2img(tensor, img_res): 402 | batch_size, num_samples, channels = tensor.shape 403 | return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1]) 404 | -------------------------------------------------------------------------------- /code/utils/rend_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import skimage 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def get_psnr(img1, img2, normalize_rgb=False): 10 | if normalize_rgb: # [-1,1] --> [0,1] 11 | img1 = (img1 + 1.) / 2. 12 | img2 = (img2 + 1. ) / 2. 13 | 14 | mse = torch.mean((img1 - img2) ** 2) 15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda()) 16 | 17 | return psnr 18 | 19 | 20 | def load_rgb(path, normalize_rgb = False): 21 | img = imageio.imread(path) 22 | img = skimage.img_as_float32(img) 23 | 24 | if normalize_rgb: # [-1,1] --> [0,1] 25 | img -= 0.5 26 | img *= 2. 27 | img = img.transpose(2, 0, 1) 28 | return img 29 | 30 | 31 | def load_K_Rt_from_P(filename, P=None): 32 | if P is None: 33 | lines = open(filename).read().splitlines() 34 | if len(lines) == 4: 35 | lines = lines[1:] 36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 37 | P = np.asarray(lines).astype(np.float32).squeeze() 38 | 39 | out = cv2.decomposeProjectionMatrix(P) 40 | K = out[0] 41 | R = out[1] 42 | t = out[2] 43 | 44 | K = K/K[2,2] 45 | intrinsics = np.eye(4) 46 | intrinsics[:3, :3] = K 47 | 48 | pose = np.eye(4, dtype=np.float32) 49 | pose[:3, :3] = R.transpose() 50 | pose[:3,3] = (t[:3] / t[3])[:,0] 51 | 52 | return intrinsics, pose 53 | 54 | 55 | def get_camera_params(uv, pose, intrinsics): 56 | if pose.shape[1] == 7: #In case of quaternion vector representation 57 | cam_loc = pose[:, 4:] 58 | R = quat_to_rot(pose[:,:4]) 59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() 60 | p[:, :3, :3] = R 61 | p[:, :3, 3] = cam_loc 62 | else: # In case of pose matrix representation 63 | cam_loc = pose[:, :3, 3] 64 | p = pose 65 | 66 | batch_size, num_samples, _ = uv.shape 67 | 68 | depth = torch.ones((batch_size, num_samples)).cuda() 69 | x_cam = uv[:, :, 0].view(batch_size, -1) 70 | y_cam = uv[:, :, 1].view(batch_size, -1) 71 | z_cam = depth.view(batch_size, -1) 72 | 73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) 74 | 75 | # permute for batch matrix product 76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 77 | 78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] 79 | ray_dirs = world_coords - cam_loc[:, None, :] 80 | ray_dirs = F.normalize(ray_dirs, dim=2) 81 | 82 | return ray_dirs, cam_loc 83 | 84 | 85 | def get_camera_for_plot(pose): 86 | if pose.shape[1] == 7: #In case of quaternion vector representation 87 | cam_loc = pose[:, 4:].detach() 88 | R = quat_to_rot(pose[:,:4].detach()) 89 | else: # In case of pose matrix representation 90 | cam_loc = pose[:, :3, 3] 91 | R = pose[:, :3, :3] 92 | cam_dir = R[:, :3, 2] 93 | return cam_loc, cam_dir 94 | 95 | 96 | def lift(x, y, z, intrinsics): 97 | # parse intrinsics 98 | intrinsics = intrinsics.cuda() 99 | fx = intrinsics[:, 0, 0] 100 | fy = intrinsics[:, 1, 1] 101 | cx = 
intrinsics[:, 0, 2] 102 | cy = intrinsics[:, 1, 2] 103 | sk = intrinsics[:, 0, 1] 104 | 105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 107 | 108 | # homogeneous 109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 110 | 111 | 112 | def quat_to_rot(q): 113 | batch_size, _ = q.shape 114 | q = F.normalize(q, dim=1) 115 | R = torch.ones((batch_size, 3,3)).cuda() 116 | qr=q[:,0] 117 | qi = q[:, 1] 118 | qj = q[:, 2] 119 | qk = q[:, 3] 120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2) 121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr) 122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2) 125 | R[:, 1, 2] = 2*(qj*qk - qi*qr) 126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr) 127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr) 128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2) 129 | return R 130 | 131 | 132 | def rot_to_quat(R): 133 | batch_size, _,_ = R.shape 134 | q = torch.ones((batch_size, 4)).cuda() 135 | 136 | R00 = R[:, 0,0] 137 | R01 = R[:, 0, 1] 138 | R02 = R[:, 0, 2] 139 | R10 = R[:, 1, 0] 140 | R11 = R[:, 1, 1] 141 | R12 = R[:, 1, 2] 142 | R20 = R[:, 2, 0] 143 | R21 = R[:, 2, 1] 144 | R22 = R[:, 2, 2] 145 | 146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2 147 | q[:, 1]=(R21-R12)/(4*q[:,0]) 148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0]) 149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0]) 150 | return q 151 | 152 | 153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0): 154 | # Input: n_rays x 3 ; n_rays x 3 155 | # Output: n_rays x 1, n_rays x 1 (close and far) 156 | 157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3), 158 | cam_loc.view(-1, 3, 1)).squeeze(-1) 159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2) 160 | 161 | # sanity check 162 | if (under_sqrt <= 0).sum() > 0: 163 | print('BOUNDING SPHERE PROBLEM!') 164 | exit() 165 | 166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot 167 | sphere_intersections = sphere_intersections.clamp_min(0.0) 168 | 169 | return sphere_intersections 170 | -------------------------------------------------------------------------------- /data/download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | echo "Downloading the DTU dataset ..." 4 | wget https://www.dropbox.com/s/s6psnh1q91m4kgo/DTU.zip 5 | echo "Start unzipping ..." 6 | unzip DTU.zip 7 | echo "DTU dataset is ready!" 8 | rm -f DTU.zip 9 | echo "Downloading the BlendedMVS dataset ..." 10 | wget https://www.dropbox.com/s/c88216wzn9t6pj8/BlendedMVS.zip 11 | echo "Start unzipping ..." 12 | unzip BlendedMVS.zip 13 | echo "BlendedMVS dataset is ready!"
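# remove the downloaded archive once extraction has finished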
14 | rm -f BlendedMVS.zip -------------------------------------------------------------------------------- /data/preprocess/normalize_cameras.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import argparse 4 | 5 | 6 | def get_center_point(num_cams,cameras): 7 | A = np.zeros((3 * num_cams, 3 + num_cams)) 8 | b = np.zeros((3 * num_cams, 1)) 9 | camera_centers=np.zeros((3,num_cams)) 10 | for i in range(num_cams): 11 | P0 = cameras['world_mat_%d' % i][:3, :] 12 | 13 | K = cv2.decomposeProjectionMatrix(P0)[0] 14 | R = cv2.decomposeProjectionMatrix(P0)[1] 15 | c = cv2.decomposeProjectionMatrix(P0)[2] 16 | c = c / c[3] 17 | camera_centers[:,i]=c[:3].flatten() 18 | 19 | v = np.linalg.inv(K) @ np.array([800, 600, 1]) 20 | v = v / np.linalg.norm(v) 21 | 22 | v=R[2,:] 23 | A[3 * i:(3 * i + 3), :3] = np.eye(3) 24 | A[3 * i:(3 * i + 3), 3 + i] = -v 25 | b[3 * i:(3 * i + 3)] = c[:3] 26 | 27 | soll= np.linalg.pinv(A) @ b 28 | 29 | return soll,camera_centers 30 | 31 | def normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras): 32 | cameras = np.load(original_cameras_filename) 33 | if num_of_cameras==-1: 34 | all_files=cameras.files 35 | maximal_ind=0 36 | for field in all_files: 37 | maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1])) 38 | num_of_cameras=maximal_ind+1 39 | soll, camera_centers = get_center_point(num_of_cameras, cameras) 40 | 41 | center = soll[:3].flatten() 42 | 43 | max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1 44 | 45 | normalization = np.eye(4).astype(np.float32) 46 | 47 | normalization[0, 3] = center[0] 48 | normalization[1, 3] = center[1] 49 | normalization[2, 3] = center[2] 50 | 51 | normalization[0, 0] = max_radius / 3.0 52 | normalization[1, 1] = max_radius / 3.0 53 | normalization[2, 2] = max_radius / 3.0 54 | 55 | cameras_new = {} 56 | for i in range(num_of_cameras): 57 | cameras_new['scale_mat_%d' % i] = normalization 58 | cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy() 59 | np.savez(output_cameras_filename, **cameras_new) 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser(description='Normalizing cameras') 64 | parser.add_argument('--input_cameras_file', type=str, default="cameras.npz", 65 | help='the input cameras file') 66 | parser.add_argument('--output_cameras_file', type=str, default="cameras_normalize.npz", 67 | help='the output cameras file') 68 | parser.add_argument('--number_of_cams',type=int, default=-1, 69 | help='Number of cameras, if -1 use all') 70 | 71 | args = parser.parse_args() 72 | normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams) 73 | -------------------------------------------------------------------------------- /data/preprocess/parse_cameras_blendedmvs.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import argparse 4 | import os 5 | 6 | def read_camera(sequence,ind): 7 | file = "%s/cams/%08d_cam.txt"%(sequence,ind) 8 | f = open(file) 9 | 10 | f.readline().strip() 11 | 12 | row1 = f.readline().strip().split() 13 | row2 = f.readline().strip().split() 14 | row3 = f.readline().strip().split() 15 | 16 | M = np.stack( 17 | (np.array(row1).astype(np.float32), np.array(row2).astype(np.float32), np.array(row3).astype(np.float32))) 18 | f.readline() 19 | f.readline() 20 | f.readline() 21 | row1 = f.readline().strip().split() 22 | row2 = 
f.readline().strip().split() 23 | row3 = f.readline().strip().split() 24 | K = np.stack( 25 | (np.array(row1).astype(np.float32), np.array(row2).astype(np.float32), np.array(row3).astype(np.float32))) 26 | 27 | return (K,M) 28 | 29 | def parse_scan(scan_ind,output_cameras_file,blendedMVS_path): 30 | files = os.listdir('%s/scan%d/cams' % (blendedMVS_path,scan_ind)) 31 | num_cams = len(files) - 1 32 | 33 | cameras_new = {} 34 | for i in range(num_cams): 35 | Ki, Mi = read_camera("%s/scan%d" % (blendedMVS_path,scan_ind), int(files[i][:8])) 36 | curp = np.eye(4).astype(np.float32) 37 | curp[:3, :] = Ki @ Mi 38 | cameras_new['world_mat_%d' % i] = curp.copy() 39 | 40 | np.savez( 41 | output_cameras_file, 42 | **cameras_new) 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser(description='Parsing blendedMVS') 46 | parser.add_argument('--blendedMVS_path', type=str, default="BlendedMVS", 47 | help='the blendedMVS path') 48 | parser.add_argument('--output_cameras_file', type=str, default="cameras.npz", 49 | help='the output cameras file') 50 | parser.add_argument('--scan_ind',type=int, 51 | help='Scan id') 52 | 53 | args = parser.parse_args() 54 | parse_scan(args.scan_ind,args.output_cameras_file,args.blendedMVS_path) 55 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: volsdf 2 | channels: 3 | - pytorch 4 | - plotly 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=4.5=1_gnu 11 | - aadict=0.2.3=pyh9f0ad1d_0 12 | - asset=0.6.13=pyh9f0ad1d_0 13 | - blas=1.0=mkl 14 | - blosc=1.21.0=h9c3ff4c_0 15 | - bottleneck=1.3.2=py39hdd57654_1 16 | - brotli=1.0.9=he6710b0_2 17 | - brunsli=0.1=h2531618_0 18 | - bzip2=1.0.8=h7b6447c_0 19 | - ca-certificates=2021.10.8=ha878542_0 20 | - cairo=1.16.0=hcf35c78_1003 21 | - certifi=2021.10.8=py39h06a4308_0 22 | - cfitsio=3.470=hf0d0db6_6 23 | - charls=2.2.0=h2531618_0 24 | - cloudpickle=1.6.0=py_0 25 | - colorama=0.4.4=pyh9f0ad1d_0 26 | - cudatoolkit=10.2.89=hfd86e86_1 27 | - cycler=0.11.0=pyhd3eb1b0_0 28 | - cytoolz=0.11.0=py39h27cfd23_0 29 | - dask-core=2.30.0=py_0 30 | - dbus=1.13.6=hfdff14a_1 31 | - decorator=4.4.2=py_0 32 | - expat=2.4.1=h9c3ff4c_0 33 | - ffmpeg=4.3.2=hca11adc_0 34 | - fontconfig=2.13.1=hba837de_1005 35 | - fonttools=4.25.0=pyhd3eb1b0_0 36 | - freetype=2.11.0=h70c0345_0 37 | - gettext=0.19.8.1=hf34092f_1004 38 | - giflib=5.2.1=h7b6447c_0 39 | - glib=2.66.3=h58526e2_0 40 | - globre=0.1.5=pyh9f0ad1d_0 41 | - gmp=6.2.1=h2531618_2 42 | - gnutls=3.6.15=he1e5248_0 43 | - gputil=1.4.0=pyh9f0ad1d_0 44 | - graphite2=1.3.13=h58526e2_1001 45 | - gst-plugins-base=1.14.5=h0935bb2_2 46 | - gstreamer=1.14.5=h36ae1b5_2 47 | - harfbuzz=2.4.0=h9f30f68_3 48 | - hdf5=1.10.6=nompi_h7c3c948_1111 49 | - icu=64.2=he1b5a44_1 50 | - imagecodecs=2021.3.31=py39h7572904_1 51 | - imageio=2.9.0=py_0 52 | - intel-openmp=2021.4.0=h06a4308_3561 53 | - jasper=1.900.1=h07fcdf6_1006 54 | - jbig=2.1=h7f98852_2003 55 | - jpeg=9d=h7f8727e_0 56 | - jxrlib=1.1=h7b6447c_2 57 | - kiwisolver=1.3.1=py39h2531618_0 58 | - krb5=1.18.2=h173b8e3_0 59 | - lame=3.100=h7b6447c_0 60 | - lcms2=2.12=h3be6417_0 61 | - ld_impl_linux-64=2.35.1=h7274673_9 62 | - lerc=2.2.1=h9c3ff4c_0 63 | - libaec=1.0.4=he6710b0_1 64 | - libblas=3.9.0=12_linux64_mkl 65 | - libcblas=3.9.0=12_linux64_mkl 66 | - libclang=9.0.1=default_ha53f305_1 67 | - libcurl=7.71.1=h20c2e04_1 68 | - 
libdeflate=1.7=h7f98852_5 69 | - libedit=3.1.20191231=h14c3975_1 70 | - libffi=3.2.1=he1b5a44_1007 71 | - libgcc-ng=9.3.0=h5101ec6_17 72 | - libgfortran-ng=7.5.0=ha8ba4b0_17 73 | - libgfortran4=7.5.0=ha8ba4b0_17 74 | - libglib=2.66.3=hbe7bbb4_0 75 | - libgomp=9.3.0=h5101ec6_17 76 | - libiconv=1.16=h516909a_0 77 | - libidn2=2.3.2=h7f8727e_0 78 | - liblapack=3.9.0=12_linux64_mkl 79 | - liblapacke=3.9.0=12_linux64_mkl 80 | - libllvm9=9.0.1=hf817b99_2 81 | - libopencv=4.5.2=py39h70bf20d_1 82 | - libpng=1.6.37=hbc83047_0 83 | - libprotobuf=3.16.0=h780b84a_0 84 | - libssh2=1.9.0=h1ba5d50_1 85 | - libstdcxx-ng=9.3.0=hd4cf53a_17 86 | - libtasn1=4.16.0=h27cfd23_0 87 | - libtiff=4.3.0=hf544144_1 88 | - libunistring=0.9.10=h27cfd23_0 89 | - libuuid=2.32.1=h7f98852_1000 90 | - libuv=1.40.0=h7b6447c_0 91 | - libwebp=1.2.0=h89dd481_0 92 | - libwebp-base=1.2.0=h27cfd23_0 93 | - libxcb=1.13=h7f98852_1003 94 | - libxkbcommon=0.10.0=he1b5a44_0 95 | - libxml2=2.9.10=hee79883_0 96 | - libzopfli=1.0.3=he6710b0_0 97 | - lz4-c=1.9.3=h295c915_1 98 | - matplotlib-base=3.5.0=py39h3ed280b_0 99 | - mkl=2021.4.0=h06a4308_640 100 | - mkl-service=2.4.0=py39h7f8727e_0 101 | - mkl_fft=1.3.1=py39hd3c417c_0 102 | - mkl_random=1.2.2=py39h51133e4_0 103 | - munkres=1.1.4=py_0 104 | - ncurses=6.2=h58526e2_4 105 | - nettle=3.7.3=hbbd107a_1 106 | - networkx=2.5=py_0 107 | - nspr=4.30=h9c3ff4c_0 108 | - nss=3.67=hb5efdd6_0 109 | - numexpr=2.7.3=py39h22e1b3c_1 110 | - numpy=1.21.2=py39h20f2e39_0 111 | - numpy-base=1.21.2=py39h79a1101_0 112 | - olefile=0.46=pyhd3eb1b0_0 113 | - opencv=4.5.2=py39hf3d152e_1 114 | - openh264=2.1.1=h780b84a_0 115 | - openjpeg=2.4.0=hb52868f_1 116 | - openssl=1.1.1l=h7f8727e_0 117 | - packaging=20.4=py_0 118 | - pandas=1.3.4=py39h8c16a72_0 119 | - pcre=8.45=h9c3ff4c_0 120 | - pillow=8.4.0=py39h5aabda8_0 121 | - pip=21.2.4=py39h06a4308_0 122 | - pixman=0.38.0=h516909a_1003 123 | - plotly=5.4.0=py_0 124 | - pthread-stubs=0.4=h36c2ea0_1001 125 | - py-opencv=4.5.2=py39hef51801_1 126 | - pyhocon=0.3.59=pyhd8ed1ab_0 127 | - pyparsing=3.0.6=pyhd8ed1ab_0 128 | - python=3.9.0=h2a148a8_4_cpython 129 | - python-dateutil=2.8.1=py_0 130 | - python_abi=3.9=2_cp39 131 | - pytorch=1.10.0=py3.9_cuda10.2_cudnn7.6.5_0 132 | - pytorch-mutex=1.0=cuda 133 | - pytz=2020.1=py_0 134 | - pywavelets=1.1.1=py39h6323ea4_4 135 | - pyyaml=6.0=py39h7f8727e_1 136 | - qt=5.12.5=hd8c4c69_1 137 | - readline=8.1=h27cfd23_0 138 | - scikit-image=0.18.3=py39h51133e4_0 139 | - scipy=1.7.1=py39h292c36d_2 140 | - setuptools=58.0.4=py39h06a4308_0 141 | - six=1.16.0=pyhd3eb1b0_0 142 | - snappy=1.1.8=he6710b0_0 143 | - sqlite=3.36.0=hc218d9a_0 144 | - tenacity=8.0.1=py39h06a4308_0 145 | - tifffile=2020.10.1=py_0 146 | - tk=8.6.11=h1ccaba5_0 147 | - toolz=0.11.1=py_0 148 | - torchaudio=0.10.0=py39_cu102 149 | - torchvision=0.11.1=py39_cu102 150 | - tqdm=4.62.3=pyhd8ed1ab_0 151 | - trimesh=3.9.36=pyh6c4a22f_0 152 | - typing_extensions=3.10.0.2=pyh06a4308_0 153 | - tzdata=2021e=hda174b7_0 154 | - wheel=0.37.0=pyhd3eb1b0_1 155 | - x264=1!161.3030=h7f98852_1 156 | - xorg-kbproto=1.0.7=h7f98852_1002 157 | - xorg-libice=1.0.10=h7f98852_0 158 | - xorg-libsm=1.2.3=hd9c2040_1000 159 | - xorg-libx11=1.7.2=h7f98852_0 160 | - xorg-libxau=1.0.9=h7f98852_0 161 | - xorg-libxdmcp=1.1.3=h7f98852_0 162 | - xorg-libxext=1.3.4=h7f98852_1 163 | - xorg-libxrender=0.9.10=h7f98852_1003 164 | - xorg-renderproto=0.11.1=h7f98852_1002 165 | - xorg-xextproto=7.3.0=h7f98852_1002 166 | - xorg-xproto=7.0.31=h7f98852_1007 167 | - xz=5.2.5=h7b6447c_0 168 | - yaml=0.2.5=h7b6447c_0 169 
| - zfp=0.5.5=h2531618_6 170 | - zlib=1.2.11=h7b6447c_3 171 | - zstd=1.5.0=ha95c52a_0 172 | -------------------------------------------------------------------------------- /media/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lioryariv/volsdf/a974c883eb70af666d8b4374e771d76930c806f3/media/teaser.png --------------------------------------------------------------------------------
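For reference, the following is a minimal, self-contained sketch of the inverse-CDF (importance) sampling step that code/model/ray_sampler.py performs in every iteration of its error-bound sampling loop. The function name, tensor shapes, and the deterministic linspace samples are illustrative only (they mirror the evaluation branch of the sampler), not an API of this repository:

import torch

def sample_pdf(bins, weights, n_samples):
    # bins: [R, B] sorted per-ray depth values; weights: [R, B-1] non-negative interval weights.
    pdf = weights + 1e-5                      # avoid an all-zero row
    pdf = pdf / pdf.sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # [R, B], starts at 0

    # deterministic samples in [0, 1], as in the linspace branch of the sampler
    u = torch.linspace(0., 1., steps=n_samples).expand(cdf.shape[0], n_samples).contiguous()

    # locate each u in the CDF and interpolate the corresponding depth
    inds = torch.searchsorted(cdf, u, right=True)
    below = (inds - 1).clamp(min=0)
    above = inds.clamp(max=cdf.shape[-1] - 1)

    cdf_below = torch.gather(cdf, 1, below)
    cdf_above = torch.gather(cdf, 1, above)
    bins_below = torch.gather(bins, 1, below)
    bins_above = torch.gather(bins, 1, above)

    denom = (cdf_above - cdf_below).clamp(min=1e-5)
    t = (u - cdf_below) / denom
    return bins_below + t * (bins_above - bins_below)

# example: 2 rays, 8 bins, 16 new samples per ray
z_vals = torch.sort(torch.rand(2, 8), dim=-1)[0]
new_z = sample_pdf(z_vals, torch.rand(2, 7), 16)

Given sorted depth bins of shape [R, B] and interval weights of shape [R, B-1], this returns n_samples new depths per ray, concentrated where the weights (the error bound or the rendering weights in the sampler above) are large.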