├── .gitignore
├── DATA_CONVENTION.md
├── LICENSE
├── README.md
├── code
│   ├── confs
│   │   ├── bmvs.conf
│   │   └── dtu.conf
│   ├── datasets
│   │   └── scene_dataset.py
│   ├── evaluation
│   │   └── eval.py
│   ├── model
│   │   ├── density.py
│   │   ├── embedder.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── network_bg.py
│   │   └── ray_sampler.py
│   ├── training
│   │   ├── exp_runner.py
│   │   └── volsdf_train.py
│   └── utils
│       ├── general.py
│       ├── plots.py
│       └── rend_util.py
├── data
│   ├── download_data.sh
│   └── preprocess
│       ├── normalize_cameras.py
│       └── parse_cameras_blendedmvs.py
├── environment.yml
└── media
    └── teaser.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | exps*
132 | evals*
133 | data/DTU
134 | data/BlendedMVS
135 |
136 | code/.idea/
137 | .DS_Store
138 | ._.DS_Store
139 | .idea/
140 |
--------------------------------------------------------------------------------
/DATA_CONVENTION.md:
--------------------------------------------------------------------------------
1 | # Data Convention
2 |
3 | ### Camera information and normalization
4 | Besides multi-view RGB images, VolSDF needs camera information in order to run. For each scan that we used, we supply a file named `cameras.npz`.
5 | The `cameras.npz` file contains, for each image, its associated camera projection matrix (named "world_mat_{i}") and a normalization matrix (named "scale_mat_{i}").
6 | #### Camera projection matrix
7 | A 3x4 camera projection matrix P = K[R | t] projects points from 3D coordinates to image pixels by the formula d[x; y; 1] = P[X; Y; Z; 1], where K is a 3x3 calibration matrix, [R | t] is a 3x4 world-to-camera Euclidean transformation, [X; Y; Z] is the 3D point, [x; y] are the 2D pixel coordinates of the projected point, and d is the depth of the point.
8 | The input `cameras.npz` file contains the camera matrices, where P_i = cameras['world_mat_{i}'][:3, :] is a 3x4 matrix that projects points from 3D world coordinates to the 2D coordinates of image i (intrinsics and extrinsics, i.e. P = K[R | t]).
9 | Each "world_mat" matrix is the camera projection matrix concatenated with the row vector [0, 0, 0, 1] (which makes it a 4x4 matrix).
10 |
11 | #### Normalization matrix
12 | The `cameras.npz` file also contains a normalization matrix named "scale_mat_{i}" (identical for all i) that changes the coordinate system such that the cameras and the region of interest are located inside a sphere of radius 3 centered at the origin (more details are in the paper).
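For reference, the data loader (`code/datasets/scene_dataset.py`) combines the two matrices as follows to obtain the projection matrix in the normalized coordinate system (a minimal sketch):
```
import numpy as np

cameras = np.load("cameras.npz")
i = 0
P = (cameras["world_mat_%d" % i] @ cameras["scale_mat_%d" % i])[:3, :4]
# P now projects points given in the normalized coordinate system to image i
```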
13 |
14 |
15 | ### Preprocess new data
16 | To convert the BlendedMVS camera format to ours (not required for the supplied scans), run:
17 | ```
18 | cd data/preprocess/
19 | python parse_cameras_blendedmvs.py --blendedMVS_path [BLENDED_MVS_PATH] --output_cameras_file [OUTPUT_CAMERAS_NPZ_FILE] --scan_ind [BLENDED_MVS_SCAN_ID]
20 | ```
21 |
22 | To generate a normalization matrix for each scan, we used the input camera projection matrices. A script that demonstrates this process is provided in `data/preprocess/normalize_cameras.py`.
23 | Note: running this script is not required for the supplied scans.
24 | To normalize a given `cameras.npz` file, run:
25 | ```
26 | cd data/preprocess/
27 | python normalize_cameras.py --input_cameras_file [INPUT_CAMERAS_NPZ_FILE] --output_cameras_file [OUTPUT_NORMALIZED_CAMERAS_NPZ_FILE] [--number_of_cams [NUMBER_OF_CAMERAS_LIMIT]]
28 | ```
29 | where the last argument is optional and limits the number of cameras, so that only the first [NUMBER_OF_CAMERAS_LIMIT] cameras are considered. This is useful for the DTU dataset, where for scan_id < 80 only the first 49 of the 64 cameras are used.
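For example, to keep only the first 49 cameras of a DTU scan (the file names here are placeholders):
```
cd data/preprocess/
python normalize_cameras.py --input_cameras_file cameras.npz --output_cameras_file cameras_normalized.npz --number_of_cams 49
```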
30 |
31 |
32 | #### Parsing COLMAP cameras
33 | It is possible to convert COLMAP cameras to our camera format using Python. First, import numpy (`import numpy as np`) and the functions `read_cameras_text`, `read_images_text` and `qvec2rotmat` from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py. Then the following Python code can be used:
34 |
35 | ```
36 | cameras=read_cameras_text("output_sfm\\cameras.txt")
37 | images=read_images_text("output_sfm\\images.txt")
38 | K = np.eye(3)
39 | K[0, 0] = cameras[1].params[0]
40 | K[1, 1] = cameras[1].params[1]
41 | K[0, 2] = cameras[1].params[2]
42 | K[1, 2] = cameras[1].params[3]
43 |
44 | cameras_npz_format = {}
45 | for ii, key in enumerate(sorted(images.keys())):  # COLMAP image ids typically start at 1, so iterate over the dict keys
46 | cur_image = images[key]
47 |
48 | M=np.zeros((3,4))
49 | M[:,3]=cur_image.tvec
50 | M[:3,:3]=qvec2rotmat(cur_image.qvec)
51 |
52 | P=np.eye(4)
53 | P[:3,:] = K@M
54 | cameras_npz_format['world_mat_%d' % ii] = P
55 |
56 | np.savez(
57 | "cameras_before_normalization.npz",
58 | **cameras_npz_format)
59 |
60 | ```
61 | Note that after running this code you will still have to normalize the cameras by running `normalize_cameras.py` as described above.
62 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Lior Yariv
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Volume Rendering of Neural Implicit Surfaces
2 |
3 | ### [Project Page](https://lioryariv.github.io/volsdf/) | [Paper](https://arxiv.org/abs/2106.12052) | [Data](https://www.dropbox.com/sh/oum8dyo19jqdkwu/AAAxpIifYjjotz_fIRBj1Fyla)
4 |
5 |
6 |
7 |
8 |
9 | This repository contains an implementation for the NeurIPS 2021 paper:
10 | Volume Rendering of Neural Implicit Surfaces
11 | Lior Yariv<sup>1</sup>, Jiatao Gu<sup>2</sup>, Yoni Kasten<sup>1</sup>, Yaron Lipman<sup>1,2</sup>
12 | <sup>1</sup>Weizmann Institute of Science, <sup>2</sup>Facebook AI Research
13 |
14 | The paper introduces VolSDF: a volume rendering framework for implicit neural surfaces, which allows learning high-fidelity geometry from a sparse set of input images.
15 |
16 | ## Setup
17 | #### Installation Requirements
18 | The code is compatible with Python 3.8 and PyTorch 1.7. In addition, the following packages are required:
19 | numpy, pyhocon, plotly, scikit-image, trimesh, imageio, opencv, torchvision.
20 |
21 | You can create an anaconda environment called `volsdf` with the required dependencies by running:
22 | ```
23 | conda env create -f environment.yml
24 | conda activate volsdf
25 | ```
26 |
27 | #### Data
28 |
29 | We apply our multiview surface reconstruction model to real 2D images from two datasets: DTU and BlendedMVS.
30 | The data for the scans evaluated in the paper can be downloaded using:
31 | ```
32 | bash data/download_data.sh
33 | ```
34 | For more information on the data convention and how to run VolSDF on new data, please have a look at [DATA_CONVENTION.md](DATA_CONVENTION.md).
35 |
36 | ## Usage
37 | #### Multiview 3D reconstruction
38 |
39 | For training VolSDF run:
40 | ```
41 | cd ./code
42 | python training/exp_runner.py --conf ./confs/dtu.conf --scan_id SCAN_ID
43 | ```
44 | where SCAN_ID is the id of the scene to reconstruct.
45 |
46 | To run on the BlendedMVS dataset, which has more complex backgrounds, use `--conf ./confs/bmvs.conf`.
47 |
48 |
49 | #### Evaluation
50 |
51 | To produce the meshed surface and renderings, run:
52 | ```
53 | cd ./code
54 | python evaluation/eval.py --conf ./confs/dtu.conf --scan_id SCAN_ID --checkpoint CHECKPOINT [--eval_rendering]
55 | ```
56 | where CHECKPOINT is the epoch you wish to evaluate, or 'latest' to take the most recent epoch.
57 | Turning on `--eval_rendering` will additionally render the training images and evaluate their PSNR.
58 |
59 |
60 |
61 | ## Citation
62 | If you find our work useful in your research, please consider citing:
63 |
64 | @inproceedings{yariv2021volume,
65 | title={Volume rendering of neural implicit surfaces},
66 | author={Yariv, Lior and Gu, Jiatao and Kasten, Yoni and Lipman, Yaron},
67 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
68 | year={2021}
69 | }
70 |
71 |
--------------------------------------------------------------------------------
/code/confs/bmvs.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = bmvs
3 | dataset_class = datasets.scene_dataset.SceneDataset
4 | model_class = model.network_bg.VolSDFNetworkBG
5 | loss_class = model.loss.VolSDFLoss
6 | learning_rate = 5.0e-4
7 | num_pixels = 1024
8 | checkpoint_freq = 100
9 | plot_freq = 500
10 | split_n_pixels = 1000
11 | }
12 | plot{
13 | plot_nimgs = 1
14 | resolution = 100
15 | grid_boundary = [-1.5, 1.5]
16 | }
17 | loss{
18 | eikonal_weight = 0.1
19 | rgb_loss = torch.nn.L1Loss
20 | }
21 | dataset{
22 | data_dir = BlendedMVS
23 | img_res = [576, 768]
24 | scan_id = 1
25 | }
26 | model{
27 | feature_vector_size = 256
28 | scene_bounding_sphere = 3.0
29 | implicit_network
30 | {
31 | d_in = 3
32 | d_out = 1
33 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ]
34 | geometric_init = True
35 | bias = 0.6
36 | skip_in = [4]
37 | weight_norm = True
38 | multires = 6
39 | }
40 | rendering_network
41 | {
42 | mode = idr
43 | d_in = 9
44 | d_out = 3
45 | dims = [ 256, 256, 256, 256]
46 | weight_norm = True
47 | multires_view = 4
48 | }
49 | density
50 | {
51 | params_init{
52 | beta = 0.1
53 | }
54 | beta_min = 0.0001
55 | }
56 | ray_sampler
57 | {
58 | near = 0.0
59 | N_samples = 64
60 | N_samples_eval = 128
61 | N_samples_extra = 32
62 | eps = 0.1
63 | beta_iters = 10
64 | max_total_iters = 5
65 | N_samples_inverse_sphere = 32
66 | add_tiny = 1.0e-6
67 | }
68 | bg_network{
69 | feature_vector_size = 256
70 | implicit_network
71 | {
72 | d_in = 4
73 | d_out = 1
74 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ]
75 | geometric_init = False
76 | bias = 0.0
77 | skip_in = [4]
78 | weight_norm = False
79 | multires = 10
80 | }
81 | rendering_network
82 | {
83 | mode = nerf
84 | d_in = 3
85 | d_out = 3
86 | dims = [128]
87 | weight_norm = False
88 | multires_view = 4
89 | }
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/code/confs/dtu.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = dtu
3 | dataset_class = datasets.scene_dataset.SceneDataset
4 | model_class = model.network.VolSDFNetwork
5 | loss_class = model.loss.VolSDFLoss
6 | learning_rate = 5.0e-4
7 | num_pixels = 1024
8 | checkpoint_freq = 100
9 | plot_freq = 500
10 | split_n_pixels = 1000
11 | }
12 | plot{
13 | plot_nimgs = 1
14 | resolution = 100
15 | grid_boundary = [-1.5, 1.5]
16 | }
17 | loss{
18 | eikonal_weight = 0.1
19 | rgb_loss = torch.nn.L1Loss
20 | }
21 | dataset{
22 | data_dir = DTU
23 | img_res = [1200, 1600]
24 | scan_id = 65
25 | }
26 | model{
27 | feature_vector_size = 256
28 | scene_bounding_sphere = 3.0
29 | implicit_network
30 | {
31 | d_in = 3
32 | d_out = 1
33 | dims = [ 256, 256, 256, 256, 256, 256, 256, 256 ]
34 | geometric_init = True
35 | bias = 0.6
36 | skip_in = [4]
37 | weight_norm = True
38 | multires = 6
39 | sphere_scale = 20.0
40 | }
41 | rendering_network
42 | {
43 | mode = idr
44 | d_in = 9
45 | d_out = 3
46 | dims = [ 256, 256, 256, 256]
47 | weight_norm = True
48 | multires_view = 4
49 | }
50 | density
51 | {
52 | params_init{
53 | beta = 0.1
54 | }
55 | beta_min = 0.0001
56 | }
57 | ray_sampler
58 | {
59 | near = 0.0
60 | N_samples = 64
61 | N_samples_eval = 128
62 | N_samples_extra = 32
63 | eps = 0.1
64 | beta_iters = 10
65 | max_total_iters = 5
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/code/datasets/scene_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | import utils.general as utils
6 | from utils import rend_util
7 |
8 | class SceneDataset(torch.utils.data.Dataset):
9 |
10 | def __init__(self,
11 | data_dir,
12 | img_res,
13 | scan_id=0,
14 | ):
15 |
16 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id))
17 |
18 | self.total_pixels = img_res[0] * img_res[1]
19 | self.img_res = img_res
20 |
21 | assert os.path.exists(self.instance_dir), "Data directory does not exist"
22 |
23 | self.sampling_idx = None
24 |
25 | image_dir = '{0}/image'.format(self.instance_dir)
26 | image_paths = sorted(utils.glob_imgs(image_dir))
27 | self.n_images = len(image_paths)
28 |
29 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
30 | camera_dict = np.load(self.cam_file)
31 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
32 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
33 |
34 | self.intrinsics_all = []
35 | self.pose_all = []
36 | for scale_mat, world_mat in zip(scale_mats, world_mats):
37 | P = world_mat @ scale_mat
38 | P = P[:3, :4]
39 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
40 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
41 | self.pose_all.append(torch.from_numpy(pose).float())
42 |
43 | self.rgb_images = []
44 | for path in image_paths:
45 | rgb = rend_util.load_rgb(path)
46 | rgb = rgb.reshape(3, -1).transpose(1, 0)
47 | self.rgb_images.append(torch.from_numpy(rgb).float())
48 |
49 | def __len__(self):
50 | return self.n_images
51 |
52 | def __getitem__(self, idx):
53 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
54 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
55 | uv = uv.reshape(2, -1).transpose(1, 0)
56 |
57 | sample = {
58 | "uv": uv,
59 | "intrinsics": self.intrinsics_all[idx],
60 | "pose": self.pose_all[idx]
61 | }
62 |
63 | ground_truth = {
64 | "rgb": self.rgb_images[idx]
65 | }
66 |
67 | if self.sampling_idx is not None:
68 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
69 | sample["uv"] = uv[self.sampling_idx, :]
70 |
71 | return idx, sample, ground_truth
72 |
73 | def collate_fn(self, batch_list):
74 | # takes a list of per-sample tuples and returns input and ground_truth as batched dictionaries for all instances
75 | batch_list = zip(*batch_list)
76 |
77 | all_parsed = []
78 | for entry in batch_list:
79 | if type(entry[0]) is dict:
80 | # make them all into a new dict
81 | ret = {}
82 | for k in entry[0].keys():
83 | ret[k] = torch.stack([obj[k] for obj in entry])
84 | all_parsed.append(ret)
85 | else:
86 | all_parsed.append(torch.LongTensor(entry))
87 |
88 | return tuple(all_parsed)
89 |
90 | def change_sampling_idx(self, sampling_size):
91 | if sampling_size == -1:
92 | self.sampling_idx = None
93 | else:
94 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
95 |
96 | def get_scale_mat(self):
97 | return np.load(self.cam_file)['scale_mat_0']
98 |
--------------------------------------------------------------------------------
/code/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../code')
3 | import argparse
4 | import GPUtil
5 | import os
6 | from pyhocon import ConfigFactory
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | from tqdm import tqdm
11 | import pandas as pd
12 |
13 | import utils.general as utils
14 | import utils.plots as plt
15 | from utils import rend_util
16 |
17 | def evaluate(**kwargs):
18 | torch.set_default_dtype(torch.float32)
19 | torch.set_num_threads(1)
20 |
21 | conf = ConfigFactory.parse_file(kwargs['conf'])
22 | exps_folder_name = kwargs['exps_folder_name']
23 | evals_folder_name = kwargs['evals_folder_name']
24 | eval_rendering = kwargs['eval_rendering']
25 |
26 | expname = conf.get_string('train.expname') + kwargs['expname']
27 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int('dataset.scan_id', default=-1)
28 | if scan_id != -1:
29 | expname = expname + '_{0}'.format(scan_id)
30 | else:
31 | scan_id = conf.get_string('dataset.object', default='')
32 |
33 | if kwargs['timestamp'] == 'latest':
34 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)):
35 | timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname))
36 | if (len(timestamps)) == 0:
37 | print('WRONG EXP FOLDER')
38 | exit()
39 | # self.timestamp = sorted(timestamps)[-1]
40 | timestamp = None
41 | for t in sorted(timestamps):
42 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname, t, 'checkpoints',
43 | 'ModelParameters', str(kwargs['checkpoint']) + ".pth")):
44 | timestamp = t
45 | if timestamp is None:
46 | print('NO GOOD TIMESTAMP')
47 | exit()
48 | else:
49 | print('WRONG EXP FOLDER')
50 | exit()
51 | else:
52 | timestamp = kwargs['timestamp']
53 |
54 | utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
55 | expdir = os.path.join('../', exps_folder_name, expname)
56 | evaldir = os.path.join('../', evals_folder_name, expname)
57 | utils.mkdir_ifnotexists(evaldir)
58 |
59 | dataset_conf = conf.get_config('dataset')
60 | if kwargs['scan_id'] != -1:
61 | dataset_conf['scan_id'] = kwargs['scan_id']
62 | eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(**dataset_conf)
63 |
64 | conf_model = conf.get_config('model')
65 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model)
66 | if torch.cuda.is_available():
67 | model.cuda()
68 |
69 | # settings for camera optimization
70 | scale_mat = eval_dataset.get_scale_mat()
71 |
72 | if eval_rendering:
73 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
74 | batch_size=1,
75 | shuffle=False,
76 | collate_fn=eval_dataset.collate_fn
77 | )
78 | total_pixels = eval_dataset.total_pixels
79 | img_res = eval_dataset.img_res
80 | split_n_pixels = conf.get_int('train.split_n_pixels', 10000)
81 |
82 | old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')
83 |
84 | saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
85 | model.load_state_dict(saved_model_state["model_state_dict"])
86 | epoch = saved_model_state['epoch']
87 |
88 | ####################################################################################################################
89 | print("evaluating...")
90 |
91 | model.eval()
92 |
93 | with torch.no_grad():
94 |
95 | if scan_id < 24: # Blended MVS
96 | mesh = plt.get_surface_high_res_mesh(
97 | sdf=lambda x: model.implicit_network(x)[:, 0],
98 | resolution=kwargs['resolution'],
99 | grid_boundary=conf.get_list('plot.grid_boundary'),
100 | level=conf.get_int('plot.level', default=0),
101 | take_components = type(scan_id) is not str
102 | )
103 | else: # DTU
104 | bb_dict = np.load('../data/DTU/bbs.npz')
105 | grid_params = bb_dict[str(scan_id)]
106 |
107 | mesh = plt.get_surface_by_grid(
108 | grid_params=grid_params,
109 | sdf=lambda x: model.implicit_network(x)[:, 0],
110 | resolution=kwargs['resolution'],
111 | level=conf.get_int('plot.level', default=0),
112 | higher_res=True
113 | )
114 |
115 | # Transform to world coordinates
116 | mesh.apply_transform(scale_mat)
117 |
118 | # Taking the biggest connected component
119 | components = mesh.split(only_watertight=False)
120 | areas = np.array([c.area for c in components], dtype=np.float32)
121 | mesh_clean = components[areas.argmax()]
122 |
123 | mesh_folder = '{0}/{1}'.format(evaldir, epoch)
124 | utils.mkdir_ifnotexists(mesh_folder)
125 | mesh_clean.export('{0}/scan{1}.ply'.format(mesh_folder, scan_id), 'ply')
126 |
127 | if eval_rendering:
128 | images_dir = '{0}/rendering_{1}'.format(evaldir, epoch)
129 | utils.mkdir_ifnotexists(images_dir)
130 |
131 | psnrs = []
132 | for data_index, (indices, model_input, ground_truth) in enumerate(eval_dataloader):
133 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
134 | model_input["uv"] = model_input["uv"].cuda()
135 | model_input['pose'] = model_input['pose'].cuda()
136 |
137 | split = utils.split_input(model_input, total_pixels, n_pixels=split_n_pixels)
138 | res = []
139 | for s in tqdm(split):
140 | torch.cuda.empty_cache()
141 | out = model(s)
142 | res.append({
143 | 'rgb_values': out['rgb_values'].detach(),
144 | })
145 |
146 | batch_size = ground_truth['rgb'].shape[0]
147 | model_outputs = utils.merge_output(res, total_pixels, batch_size)
148 | rgb_eval = model_outputs['rgb_values']
149 | rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
150 |
151 | rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
152 | rgb_eval = rgb_eval.transpose(1, 2, 0)
153 | img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
154 | img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % indices[0]))
155 |
156 | psnr = rend_util.get_psnr(model_outputs['rgb_values'],
157 | ground_truth['rgb'].cuda().reshape(-1, 3)).item()
158 | psnrs.append(psnr)
159 |
160 |
161 | psnrs = np.array(psnrs).astype(np.float64)
162 | print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
163 | psnrs = np.concatenate([psnrs, psnrs.mean()[None], psnrs.std()[None]])
164 | pd.DataFrame(psnrs).to_csv('{0}/psnr_{1}.csv'.format(evaldir, epoch))
165 |
166 |
167 |
168 | if __name__ == '__main__':
169 |
170 | parser = argparse.ArgumentParser()
171 |
172 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf')
173 | parser.add_argument('--expname', type=str, default='', help='The experiment name to be evaluated.')
174 | parser.add_argument('--exps_folder', type=str, default='exps', help='The experiments folder name.')
175 | parser.add_argument('--evals_folder', type=str, default='evals', help='The evaluation folder name.')
176 | parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
177 | parser.add_argument('--timestamp', default='latest', type=str, help='The experiment timestamp to test.')
178 | parser.add_argument('--checkpoint', default='latest', type=str, help='The trained model checkpoint to test.')
179 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
180 | parser.add_argument('--resolution', default=512, type=int, help='Grid resolution for marching cube')
181 | parser.add_argument('--eval_rendering', default=False, action="store_true", help='If set, evaluate rendering quality.')
182 |
183 | opt = parser.parse_args()
184 |
185 | if opt.gpu == "auto":
186 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[])
187 | gpu = deviceIDs[0]
188 | else:
189 | gpu = opt.gpu
190 |
191 | if (not gpu == 'ignore'):
192 | os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu)
193 |
194 | evaluate(conf=opt.conf,
195 | expname=opt.expname,
196 | exps_folder_name=opt.exps_folder,
197 | evals_folder_name=opt.evals_folder,
198 | timestamp=opt.timestamp,
199 | checkpoint=opt.checkpoint,
200 | scan_id=opt.scan_id,
201 | resolution=opt.resolution,
202 | eval_rendering=opt.eval_rendering,
203 | )
204 |
--------------------------------------------------------------------------------
/code/model/density.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class Density(nn.Module):
6 | def __init__(self, params_init={}):
7 | super().__init__()
8 | for p in params_init:
9 | param = nn.Parameter(torch.tensor(params_init[p]))
10 | setattr(self, p, param)
11 |
12 | def forward(self, sdf, beta=None):
13 | return self.density_func(sdf, beta=beta)
14 |
15 |
16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
17 | def __init__(self, params_init={}, beta_min=0.0001):
18 | super().__init__(params_init=params_init)
19 | self.beta_min = torch.tensor(beta_min).cuda()
20 |
21 | def density_func(self, sdf, beta=None):
22 | if beta is None:
23 | beta = self.get_beta()
24 |
25 | alpha = 1 / beta
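# Note: equivalently, the return value below is the piecewise form of the Laplace CDF referenced above:
#   sdf >  0  ->  alpha * 0.5 * exp(-sdf / beta)
#   sdf <= 0  ->  alpha * (1 - 0.5 * exp(sdf / beta))
# so the density approaches alpha inside the surface and decays towards 0 outside it.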
26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
27 |
28 | def get_beta(self):
29 | beta = self.beta.abs() + self.beta_min
30 | return beta
31 |
32 |
33 | class AbsDensity(Density): # like NeRF++
34 | def density_func(self, sdf, beta=None):
35 | return torch.abs(sdf)
36 |
37 |
38 | class SimpleDensity(Density): # like NeRF
39 | def __init__(self, params_init={}, noise_std=1.0):
40 | super().__init__(params_init=params_init)
41 | self.noise_std = noise_std
42 |
43 | def density_func(self, sdf, beta=None):
44 | if self.training and self.noise_std > 0.0:
45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std
46 | sdf = sdf + noise
47 | return torch.relu(sdf)
48 |
--------------------------------------------------------------------------------
/code/model/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
4 |
5 | class Embedder:
6 | def __init__(self, **kwargs):
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 | embed_fns = []
12 | d = self.kwargs['input_dims']
13 | out_dim = 0
14 | if self.kwargs['include_input']:
15 | embed_fns.append(lambda x: x)
16 | out_dim += d
17 |
18 | max_freq = self.kwargs['max_freq_log2']
19 | N_freqs = self.kwargs['num_freqs']
20 |
21 | if self.kwargs['log_sampling']:
22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
25 |
26 | for freq in freq_bands:
27 | for p_fn in self.kwargs['periodic_fns']:
28 | embed_fns.append(lambda x, p_fn=p_fn,
29 | freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 | def get_embedder(multires, input_dims=3):
39 | embed_kwargs = {
40 | 'include_input': True,
41 | 'input_dims': input_dims,
42 | 'max_freq_log2': multires-1,
43 | 'num_freqs': multires,
44 | 'log_sampling': True,
45 | 'periodic_fns': [torch.sin, torch.cos],
46 | }
47 |
48 | embedder_obj = Embedder(**embed_kwargs)
49 | def embed(x, eo=embedder_obj): return eo.embed(x)
50 | return embed, embedder_obj.out_dim
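# Usage sketch (names as defined above): embed_fn, out_dim = get_embedder(6, input_dims=3)
# maps (..., 3) inputs to (..., out_dim) features with out_dim = 3 + 3 * 2 * 6 = 39
# (identity + sin/cos at 6 frequencies per input dimension).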
51 |
--------------------------------------------------------------------------------
/code/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import utils.general as utils
4 |
5 |
6 | class VolSDFLoss(nn.Module):
7 | def __init__(self, rgb_loss, eikonal_weight):
8 | super().__init__()
9 | self.eikonal_weight = eikonal_weight
10 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean')
11 |
12 | def get_rgb_loss(self,rgb_values, rgb_gt):
13 | rgb_gt = rgb_gt.reshape(-1, 3)
14 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
15 | return rgb_loss
16 |
17 | def get_eikonal_loss(self, grad_theta):
18 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
19 | return eikonal_loss
20 |
21 | def forward(self, model_outputs, ground_truth):
22 | rgb_gt = ground_truth['rgb'].cuda()
23 |
24 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
25 | if 'grad_theta' in model_outputs:
26 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
27 | else:
28 | eikonal_loss = torch.tensor(0.0).cuda().float()
29 |
30 | loss = rgb_loss + \
31 | self.eikonal_weight * eikonal_loss
32 |
33 | output = {
34 | 'loss': loss,
35 | 'rgb_loss': rgb_loss,
36 | 'eikonal_loss': eikonal_loss,
37 | }
38 |
39 | return output
40 |
--------------------------------------------------------------------------------
/code/model/network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 |
4 | from utils import rend_util
5 | from model.embedder import *
6 | from model.density import LaplaceDensity
7 | from model.ray_sampler import ErrorBoundSampler
8 |
9 | class ImplicitNetwork(nn.Module):
10 | def __init__(
11 | self,
12 | feature_vector_size,
13 | sdf_bounding_sphere,
14 | d_in,
15 | d_out,
16 | dims,
17 | geometric_init=True,
18 | bias=1.0,
19 | skip_in=(),
20 | weight_norm=True,
21 | multires=0,
22 | sphere_scale=1.0,
23 | ):
24 | super().__init__()
25 |
26 | self.sdf_bounding_sphere = sdf_bounding_sphere
27 | self.sphere_scale = sphere_scale
28 | dims = [d_in] + dims + [d_out + feature_vector_size]
29 |
30 | self.embed_fn = None
31 | if multires > 0:
32 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
33 | self.embed_fn = embed_fn
34 | dims[0] = input_ch
35 |
36 | self.num_layers = len(dims)
37 | self.skip_in = skip_in
38 |
39 | for l in range(0, self.num_layers - 1):
40 | if l + 1 in self.skip_in:
41 | out_dim = dims[l + 1] - dims[0]
42 | else:
43 | out_dim = dims[l + 1]
44 |
45 | lin = nn.Linear(dims[l], out_dim)
46 |
47 | if geometric_init:
48 | if l == self.num_layers - 2:
49 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
50 | torch.nn.init.constant_(lin.bias, -bias)
51 | elif multires > 0 and l == 0:
52 | torch.nn.init.constant_(lin.bias, 0.0)
53 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
54 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
55 | elif multires > 0 and l in self.skip_in:
56 | torch.nn.init.constant_(lin.bias, 0.0)
57 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
58 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
59 | else:
60 | torch.nn.init.constant_(lin.bias, 0.0)
61 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
62 |
63 | if weight_norm:
64 | lin = nn.utils.weight_norm(lin)
65 |
66 | setattr(self, "lin" + str(l), lin)
67 |
68 | self.softplus = nn.Softplus(beta=100)
69 |
70 | def forward(self, input):
71 | if self.embed_fn is not None:
72 | input = self.embed_fn(input)
73 |
74 | x = input
75 |
76 | for l in range(0, self.num_layers - 1):
77 | lin = getattr(self, "lin" + str(l))
78 |
79 | if l in self.skip_in:
80 | x = torch.cat([x, input], 1) / np.sqrt(2)
81 |
82 | x = lin(x)
83 |
84 | if l < self.num_layers - 2:
85 | x = self.softplus(x)
86 |
87 | return x
88 |
89 | def gradient(self, x):
90 | x.requires_grad_(True)
91 | y = self.forward(x)[:,:1]
92 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
93 | gradients = torch.autograd.grad(
94 | outputs=y,
95 | inputs=x,
96 | grad_outputs=d_output,
97 | create_graph=True,
98 | retain_graph=True,
99 | only_inputs=True)[0]
100 | return gradients
101 |
102 | def get_outputs(self, x):
103 | x.requires_grad_(True)
104 | output = self.forward(x)
105 | sdf = output[:,:1]
106 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
107 | if self.sdf_bounding_sphere > 0.0:
108 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
109 | sdf = torch.minimum(sdf, sphere_sdf)
110 | feature_vectors = output[:, 1:]
111 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
112 | gradients = torch.autograd.grad(
113 | outputs=sdf,
114 | inputs=x,
115 | grad_outputs=d_output,
116 | create_graph=True,
117 | retain_graph=True,
118 | only_inputs=True)[0]
119 |
120 | return sdf, feature_vectors, gradients
121 |
122 | def get_sdf_vals(self, x):
123 | sdf = self.forward(x)[:,:1]
124 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
125 | if self.sdf_bounding_sphere > 0.0:
126 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
127 | sdf = torch.minimum(sdf, sphere_sdf)
128 | return sdf
129 |
130 |
131 | class RenderingNetwork(nn.Module):
132 | def __init__(
133 | self,
134 | feature_vector_size,
135 | mode,
136 | d_in,
137 | d_out,
138 | dims,
139 | weight_norm=True,
140 | multires_view=0,
141 | ):
142 | super().__init__()
143 |
144 | self.mode = mode
145 | dims = [d_in + feature_vector_size] + dims + [d_out]
146 |
147 | self.embedview_fn = None
148 | if multires_view > 0:
149 | embedview_fn, input_ch = get_embedder(multires_view)
150 | self.embedview_fn = embedview_fn
151 | dims[0] += (input_ch - 3)
152 |
153 | self.num_layers = len(dims)
154 |
155 | for l in range(0, self.num_layers - 1):
156 | out_dim = dims[l + 1]
157 | lin = nn.Linear(dims[l], out_dim)
158 |
159 | if weight_norm:
160 | lin = nn.utils.weight_norm(lin)
161 |
162 | setattr(self, "lin" + str(l), lin)
163 |
164 | self.relu = nn.ReLU()
165 | self.sigmoid = torch.nn.Sigmoid()
166 |
167 | def forward(self, points, normals, view_dirs, feature_vectors):
168 | if self.embedview_fn is not None:
169 | view_dirs = self.embedview_fn(view_dirs)
170 |
171 | if self.mode == 'idr':
172 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
173 | elif self.mode == 'nerf':
174 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)
175 |
176 | x = rendering_input
177 |
178 | for l in range(0, self.num_layers - 1):
179 | lin = getattr(self, "lin" + str(l))
180 |
181 | x = lin(x)
182 |
183 | if l < self.num_layers - 2:
184 | x = self.relu(x)
185 |
186 | x = self.sigmoid(x)
187 | return x
188 |
189 | class VolSDFNetwork(nn.Module):
190 | def __init__(self, conf):
191 | super().__init__()
192 | self.feature_vector_size = conf.get_int('feature_vector_size')
193 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0)
194 | self.white_bkgd = conf.get_bool('white_bkgd', default=False)
195 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda()
196 |
197 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network'))
198 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network'))
199 |
200 | self.density = LaplaceDensity(**conf.get_config('density'))
201 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler'))
202 |
203 | def forward(self, input):
204 | # Parse model input
205 | intrinsics = input["intrinsics"]
206 | uv = input["uv"]
207 | pose = input["pose"]
208 |
209 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)
210 |
211 | batch_size, num_pixels, _ = ray_dirs.shape
212 |
213 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
214 | ray_dirs = ray_dirs.reshape(-1, 3)
215 |
216 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)
217 | N_samples = z_vals.shape[1]
218 |
219 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
220 | points_flat = points.reshape(-1, 3)
221 |
222 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
223 | dirs_flat = dirs.reshape(-1, 3)
224 |
225 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat)
226 |
227 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors)
228 | rgb = rgb_flat.reshape(-1, N_samples, 3)
229 |
230 | weights = self.volume_rendering(z_vals, sdf)
231 |
232 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
233 |
234 | # white background assumption
235 | if self.white_bkgd:
236 | acc_map = torch.sum(weights, -1)
237 | rgb_values = rgb_values + (1. - acc_map[..., None]) * self.bg_color.unsqueeze(0)
238 |
239 | output = {
240 | 'rgb_values': rgb_values,
241 | }
242 |
243 | if self.training:
244 | # Sample points for the eikonal loss
245 | n_eik_points = batch_size * num_pixels
246 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda()
247 |
248 | # add some of the near surface points
249 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
250 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)
251 |
252 | grad_theta = self.implicit_network.gradient(eikonal_points)
253 | output['grad_theta'] = grad_theta
254 |
255 | if not self.training:
256 | gradients = gradients.detach()
257 | normals = gradients / gradients.norm(2, -1, keepdim=True)
258 | normals = normals.reshape(-1, N_samples, 3)
259 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
260 |
261 | output['normal_map'] = normal_map
262 |
263 | return output
264 |
265 | def volume_rendering(self, z_vals, sdf):
266 | density_flat = self.density(sdf)
267 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
268 |
269 | dists = z_vals[:, 1:] - z_vals[:, :-1]
270 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
271 |
272 | # LOG SPACE
273 | free_energy = dists * density
274 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step
275 | alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty
276 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything up to now is empty
277 | weights = alpha * transmittance # probability that the ray hits something here
278 |
279 | return weights
280 |
--------------------------------------------------------------------------------
/code/model/network_bg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 |
5 | import utils.general as utils
6 | from utils import rend_util
7 | from model.network import ImplicitNetwork, RenderingNetwork
8 | from model.density import LaplaceDensity, AbsDensity
9 | from model.ray_sampler import ErrorBoundSampler
10 |
11 |
12 | """
13 | For modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++
14 | https://github.com/Kai-46/nerfplusplus
15 | """
16 |
17 |
18 | class VolSDFNetworkBG(nn.Module):
19 | def __init__(self, conf):
20 | super().__init__()
21 | self.feature_vector_size = conf.get_int('feature_vector_size')
22 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0)
23 |
24 | # Foreground object's networks
25 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.get_config('implicit_network'))
26 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network'))
27 |
28 | self.density = LaplaceDensity(**conf.get_config('density'))
29 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=True, **conf.get_config('ray_sampler'))
30 |
31 | # Background's networks
32 | bg_feature_vector_size = conf.get_int('bg_network.feature_vector_size')
33 | self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.get_config('bg_network.implicit_network'))
34 | self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.get_config('bg_network.rendering_network'))
35 | self.bg_density = AbsDensity(**conf.get_config('bg_network.density', default={}))
36 |
37 | def forward(self, input):
38 | # Parse model input
39 | intrinsics = input["intrinsics"]
40 | uv = input["uv"]
41 | pose = input["pose"]
42 |
43 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)
44 |
45 | batch_size, num_pixels, _ = ray_dirs.shape
46 |
47 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
48 | ray_dirs = ray_dirs.reshape(-1, 3)
49 |
50 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)
51 |
52 | z_vals, z_vals_bg = z_vals
53 | z_max = z_vals[:,-1]
54 | z_vals = z_vals[:,:-1]
55 | N_samples = z_vals.shape[1]
56 |
57 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
58 | points_flat = points.reshape(-1, 3)
59 |
60 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
61 | dirs_flat = dirs.reshape(-1, 3)
62 |
63 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat)
64 |
65 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors)
66 | rgb = rgb_flat.reshape(-1, N_samples, 3)
67 |
68 | weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf)
69 |
70 | fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
71 |
72 |
73 | # Background rendering
74 | N_bg_samples = z_vals_bg.shape[1]
75 | z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0
76 |
77 | bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
78 | bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)
79 |
80 | bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4]
81 | bg_points_flat = bg_points.reshape(-1, 4)
82 | bg_dirs_flat = bg_dirs.reshape(-1, 3)
83 |
84 | output = self.bg_implicit_network(bg_points_flat)
85 | bg_sdf = output[:,:1]
86 | bg_feature_vectors = output[:, 1:]
87 | bg_rgb_flat = self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors)
88 | bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
89 |
90 | bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
91 |
92 | bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
93 |
94 |
95 | # Composite foreground and background
96 | bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
97 | rgb_values = fg_rgb_values + bg_rgb_values
98 |
99 | output = {
100 | 'rgb_values': rgb_values,
101 | }
102 |
103 | if self.training:
104 | # Sample points for the eikonal loss
105 | n_eik_points = batch_size * num_pixels
106 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda()
107 |
108 | # add some of the near surface points
109 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
110 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)
111 |
112 | grad_theta = self.implicit_network.gradient(eikonal_points)
113 | output['grad_theta'] = grad_theta
114 |
115 | if not self.training:
116 | gradients = gradients.detach()
117 | normals = gradients / gradients.norm(2, -1, keepdim=True)
118 | normals = normals.reshape(-1, N_samples, 3)
119 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
120 |
121 | output['normal_map'] = normal_map
122 |
123 | return output
124 |
125 | def volume_rendering(self, z_vals, z_max, sdf):
126 | density_flat = self.density(sdf)
127 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
128 |
129 | # the last interval also includes the distance to the sphere intersection (z_max)
130 | dists = z_vals[:, 1:] - z_vals[:, :-1]
131 | dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)
132 |
133 | # LOG SPACE
134 | free_energy = dists * density
135 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy], dim=-1) # add 0 for transparency 1 at t_0
136 | alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty
137 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything up to now is empty
138 | fg_transmittance = transmittance[:, :-1]
139 | weights = alpha * fg_transmittance # probability that the ray hits something here
140 | bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering
141 |
142 | return weights, bg_transmittance
143 |
144 | def bg_volume_rendering(self, z_vals_bg, bg_sdf):
145 | bg_density_flat = self.bg_density(bg_sdf)
146 | bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples
147 |
148 | bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
149 | bg_dists = torch.cat([bg_dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)
150 |
151 | # LOG SPACE
152 | bg_free_energy = bg_dists * bg_density
153 | bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1).cuda(), bg_free_energy[:, :-1]], dim=-1) # shift one step
154 | bg_alpha = 1 - torch.exp(-bg_free_energy) # probability that this interval is not empty
155 | bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability that everything up to now is empty
156 | bg_weights = bg_alpha * bg_transmittance # probability that the ray hits something here
157 |
158 | return bg_weights
159 |
160 | def depth2pts_outside(self, ray_o, ray_d, depth):
161 |
162 | '''
163 | ray_o, ray_d: [..., 3]
164 | depth: [...]; inverse of distance to sphere origin
165 | '''
166 |
167 | o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
168 | under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2)
169 | d_sphere = torch.sqrt(under_sqrt) - o_dot_d
170 | p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
171 | p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
172 | p_mid_norm = torch.norm(p_mid, dim=-1)
173 |
174 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
175 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
176 | phi = torch.asin(p_mid_norm / self.scene_bounding_sphere)
177 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
178 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
179 |
180 | # now rotate p_sphere
181 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
182 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \
183 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
184 | rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
185 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
186 | pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
187 |
188 | return pts
189 |
--------------------------------------------------------------------------------
/code/model/ray_sampler.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 |
4 | from utils import rend_util
5 |
6 | class RaySampler(metaclass=abc.ABCMeta):
7 | def __init__(self,near, far):
8 | self.near = near
9 | self.far = far
10 |
11 | @abc.abstractmethod
12 | def get_z_vals(self, ray_dirs, cam_loc, model):
13 | pass
14 |
15 | class UniformSampler(RaySampler):
16 | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
17 | super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
18 | self.N_samples = N_samples
19 | self.scene_bounding_sphere = scene_bounding_sphere
20 | self.take_sphere_intersection = take_sphere_intersection
21 |
22 | def get_z_vals(self, ray_dirs, cam_loc, model):
23 | if not self.take_sphere_intersection:
24 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
25 | else:
26 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
27 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
28 | far = sphere_intersections[:,1:]
29 |
30 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
31 | z_vals = near * (1. - t_vals) + far * (t_vals)
32 |
33 | if model.training:
34 | # get intervals between samples
35 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
36 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
37 | lower = torch.cat([z_vals[..., :1], mids], -1)
38 | # stratified samples in those intervals
39 | t_rand = torch.rand(z_vals.shape).cuda()
40 |
41 | z_vals = lower + (upper - lower) * t_rand
42 |
43 | return z_vals
44 |
45 |
46 | class ErrorBoundSampler(RaySampler):
47 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
48 | eps, beta_iters, max_total_iters,
49 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
50 | super().__init__(near, 2.0 * scene_bounding_sphere)
51 | self.N_samples = N_samples
52 | self.N_samples_eval = N_samples_eval
53 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)
54 |
55 | self.N_samples_extra = N_samples_extra
56 |
57 | self.eps = eps
58 | self.beta_iters = beta_iters
59 | self.max_total_iters = max_total_iters
60 | self.scene_bounding_sphere = scene_bounding_sphere
61 | self.add_tiny = add_tiny
62 |
63 | self.inverse_sphere_bg = inverse_sphere_bg
64 | if inverse_sphere_bg:
65 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
66 |
67 | def get_z_vals(self, ray_dirs, cam_loc, model):
68 | beta0 = model.density.get_beta().detach()
69 |
70 | # Start with uniform sampling
71 | z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
72 | samples, samples_idx = z_vals, None
73 |
74 | # Get maximum beta from the upper bound (Lemma 2)
75 | dists = z_vals[:, 1:] - z_vals[:, :-1]
76 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
77 | beta = torch.sqrt(bound)
78 |
79 | total_iters, not_converge = 0, True
80 |
81 | # Algorithm 1
82 | while not_converge and total_iters < self.max_total_iters:
83 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
84 | points_flat = points.reshape(-1, 3)
85 |
86 | # Calculating the SDF only for the new sampled points
87 | with torch.no_grad():
88 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat)
89 | if samples_idx is not None:
90 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
91 | samples_sdf.reshape(-1, samples.shape[1])], -1)
92 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
93 | else:
94 | sdf = samples_sdf
95 |
96 |
97 | # Calculating the bound d* (Theorem 1)
98 | d = sdf.reshape(z_vals.shape)
99 | dists = z_vals[:, 1:] - z_vals[:, :-1]
100 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
101 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
102 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
103 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
104 | d_star[first_cond] = b[first_cond]
105 | d_star[second_cond] = c[second_cond]
106 | s = (a + b + c) / 2.0
107 | area_before_sqrt = s * (s - a) * (s - b) * (s - c)
108 | mask = ~first_cond & ~second_cond & (b + c - a > 0)
109 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
110 | d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
111 |
112 |
113 | # Updating beta using line search
114 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
115 | beta[curr_error <= self.eps] = beta0
116 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
117 | for j in range(self.beta_iters):
118 | beta_mid = (beta_min + beta_max) / 2.
119 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
120 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
121 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
122 | beta = beta_max
123 |
124 |
125 | # Upsample more points
126 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
127 |
128 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
129 | free_energy = dists * density
130 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
131 | alpha = 1 - torch.exp(-free_energy)
132 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
133 | weights = alpha * transmittance # probability that the ray hits something here
134 |
135 | # Check if we are done and this is the last sampling
136 | total_iters += 1
137 | not_converge = beta.max() > beta0
138 |
139 | if not_converge and total_iters < self.max_total_iters:
140 | ''' Sample more points proportional to the current error bound'''
141 |
142 | N = self.N_samples_eval
143 |
144 | bins = z_vals
145 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
146 | error_integral = torch.cumsum(error_per_section, dim=-1)
147 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
148 |
149 | pdf = bound_opacity + self.add_tiny
150 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
151 | cdf = torch.cumsum(pdf, -1)
152 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
153 |
154 | else:
155 | ''' Sample the final sample set to be used in the volume rendering integral '''
156 |
157 | N = self.N_samples
158 |
159 | bins = z_vals
160 | pdf = weights[..., :-1]
161 | pdf = pdf + 1e-5 # prevent nans
162 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
163 | cdf = torch.cumsum(pdf, -1)
164 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
165 |
166 |
167 | # Invert CDF
168 | if (not_converge and total_iters < self.max_total_iters) or (not model.training):
169 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
170 | else:
171 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
172 | u = u.contiguous()
173 |
174 | inds = torch.searchsorted(cdf, u, right=True)
175 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
176 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
177 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
178 |
179 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
180 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
181 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
182 |
183 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
184 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
185 | t = (u - cdf_g[..., 0]) / denom
186 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
187 |
188 |
189 |             # Adding samples if we have not converged
190 | if not_converge and total_iters < self.max_total_iters:
191 | z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
192 |
193 |
194 | z_samples = samples
195 |
196 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
197 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
198 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
199 |
200 | if self.N_samples_extra > 0:
201 | if model.training:
202 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
203 | else:
204 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
205 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
206 | else:
207 | z_vals_extra = torch.cat([near, far], -1)
208 |
209 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
210 |
211 | # add some of the near surface points
212 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
213 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
214 |
215 | if self.inverse_sphere_bg:
216 | z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
217 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
218 | z_vals = (z_vals, z_vals_inverse_sphere)
219 |
220 | return z_vals, z_samples_eik
221 |
222 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
223 | density = model.density(sdf.reshape(z_vals.shape), beta=beta)
224 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
225 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
226 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
227 | error_integral = torch.cumsum(error_per_section, dim=-1)
228 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
229 |
230 | return bound_opacity.max(-1)[0]
231 |
232 |
233 |
--------------------------------------------------------------------------------
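The refinement loop above alternates between tightening beta with a bisection line search and drawing additional depth samples by inverting the discrete CDF of either the opacity error bound or the rendering weights (the searchsorted/gather block above). A minimal, self-contained sketch of that inverse-CDF step, with hand-made bins and weights standing in for the model-derived ones:

import torch

# Toy inverse-CDF sampling over depth bins: the same searchsorted/gather pattern as in
# the sampler above, but on hand-made weights instead of the error bound / rendering weights.
def sample_from_bins(bins, weights, n_samples):
    # Turn per-interval weights into a PDF, then integrate into a CDF that starts at 0.
    pdf = weights + 1e-5                                   # avoid zero-probability intervals
    pdf = pdf / pdf.sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)        # (batch, n_bins + 1)

    # Uniform draws, then invert the CDF: find the surrounding interval and interpolate.
    u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)                          # (batch, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)  # guard degenerate intervals
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

# Two rays, eight depth intervals, mass concentrated around the middle of the ray.
bins = torch.linspace(0., 6., 9).unsqueeze(0).repeat(2, 1)
weights = torch.tensor([[0., 0., 1., 4., 4., 1., 0., 0.]]).repeat(2, 1)
print(sample_from_bins(bins, weights, n_samples=16))

Because the CDF is prepended with 0 and ends at 1, every uniform draw lands in a valid interval; the `where` on the denominator mirrors the original guard against empty intervals.
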
/code/training/exp_runner.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../code')
4 | import argparse
5 | import GPUtil
6 |
7 | from training.volsdf_train import VolSDFTrainRunner
8 |
9 | if __name__ == '__main__':
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
13 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for')
14 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf')
15 | parser.add_argument('--expname', type=str, default='')
16 | parser.add_argument("--exps_folder", type=str, default="exps")
17 | parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
18 | parser.add_argument('--is_continue', default=False, action="store_true",
19 | help='If set, indicates continuing from a previous run.')
20 | parser.add_argument('--timestamp', default='latest', type=str,
21 | help='The timestamp of the run to be used in case of continuing from a previous run.')
22 | parser.add_argument('--checkpoint', default='latest', type=str,
23 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.')
24 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
25 | parser.add_argument('--cancel_vis', default=False, action="store_true",
26 | help='If set, cancel visualization in intermediate epochs.')
27 |
28 | opt = parser.parse_args()
29 |
30 | if opt.gpu == "auto":
31 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,
32 | excludeID=[], excludeUUID=[])
33 | gpu = deviceIDs[0]
34 | else:
35 | gpu = opt.gpu
36 |
37 | trainrunner = VolSDFTrainRunner(conf=opt.conf,
38 | batch_size=opt.batch_size,
39 | nepochs=opt.nepoch,
40 | expname=opt.expname,
41 | gpu_index=gpu,
42 | exps_folder_name=opt.exps_folder,
43 | is_continue=opt.is_continue,
44 | timestamp=opt.timestamp,
45 | checkpoint=opt.checkpoint,
46 | scan_id=opt.scan_id,
47 | do_vis=not opt.cancel_vis
48 | )
49 |
50 | trainrunner.run()
51 |
--------------------------------------------------------------------------------
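A launch sketch (the conf path and scan id are only illustrative; the defaults mirror the argparse definitions above): from inside the code directory, run `python training/exp_runner.py --conf ./confs/dtu.conf --scan_id 65`. With the default `--gpu auto`, GPUtil picks an available device ordered by memory usage; `--is_continue` resumes the most recent timestamped run under the exps folder, and `--cancel_vis` skips the intermediate plotting epochs.
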
/code/training/volsdf_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from pyhocon import ConfigFactory
4 | import sys
5 | import torch
6 | from tqdm import tqdm
7 |
8 | import utils.general as utils
9 | import utils.plots as plt
10 | from utils import rend_util
11 |
12 | class VolSDFTrainRunner():
13 | def __init__(self,**kwargs):
14 | torch.set_default_dtype(torch.float32)
15 | torch.set_num_threads(1)
16 |
17 | self.conf = ConfigFactory.parse_file(kwargs['conf'])
18 | self.batch_size = kwargs['batch_size']
19 | self.nepochs = kwargs['nepochs']
20 | self.exps_folder_name = kwargs['exps_folder_name']
21 | self.GPU_INDEX = kwargs['gpu_index']
22 |
23 | self.expname = self.conf.get_string('train.expname') + kwargs['expname']
24 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
25 | if scan_id != -1:
26 | self.expname = self.expname + '_{0}'.format(scan_id)
27 |
28 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
29 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
30 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
31 | if (len(timestamps)) == 0:
32 | is_continue = False
33 | timestamp = None
34 | else:
35 | timestamp = sorted(timestamps)[-1]
36 | is_continue = True
37 | else:
38 | is_continue = False
39 | timestamp = None
40 | else:
41 | timestamp = kwargs['timestamp']
42 | is_continue = kwargs['is_continue']
43 |
44 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))
45 | self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
46 | utils.mkdir_ifnotexists(self.expdir)
47 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
48 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
49 |
50 | self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
51 | utils.mkdir_ifnotexists(self.plots_dir)
52 |
53 | # create checkpoints dirs
54 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
55 | utils.mkdir_ifnotexists(self.checkpoints_path)
56 | self.model_params_subdir = "ModelParameters"
57 | self.optimizer_params_subdir = "OptimizerParameters"
58 | self.scheduler_params_subdir = "SchedulerParameters"
59 |
60 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
61 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
62 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))
63 |
64 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))
65 |
66 | if (not self.GPU_INDEX == 'ignore'):
67 | os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
68 |
69 | print('shell command : {0}'.format(' '.join(sys.argv)))
70 |
71 | print('Loading data ...')
72 |
73 | dataset_conf = self.conf.get_config('dataset')
74 | if kwargs['scan_id'] != -1:
75 | dataset_conf['scan_id'] = kwargs['scan_id']
76 |
77 | self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf)
78 |
79 | self.ds_len = len(self.train_dataset)
80 | print('Finish loading data. Data-set size: {0}'.format(self.ds_len))
81 | if scan_id < 24 and scan_id > 0: # BlendedMVS, running for 200k iterations
82 | self.nepochs = int(200000 / self.ds_len)
83 | print('RUNNING FOR {0}'.format(self.nepochs))
84 |
85 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
86 | batch_size=self.batch_size,
87 | shuffle=True,
88 | collate_fn=self.train_dataset.collate_fn
89 | )
90 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
91 | batch_size=self.conf.get_int('plot.plot_nimgs'),
92 | shuffle=True,
93 | collate_fn=self.train_dataset.collate_fn
94 | )
95 |
96 | conf_model = self.conf.get_config('model')
97 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model)
98 | if torch.cuda.is_available():
99 | self.model.cuda()
100 |
101 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))
102 |
103 | self.lr = self.conf.get_float('train.learning_rate')
104 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
105 | # Exponential learning rate scheduler
106 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1)
107 | decay_steps = self.nepochs * len(self.train_dataset)
108 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps))
109 |
110 | self.do_vis = kwargs['do_vis']
111 |
112 | self.start_epoch = 0
113 | if is_continue:
114 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')
115 |
116 | saved_model_state = torch.load(
117 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
118 | self.model.load_state_dict(saved_model_state["model_state_dict"])
119 | self.start_epoch = saved_model_state['epoch']
120 |
121 | data = torch.load(
122 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
123 | self.optimizer.load_state_dict(data["optimizer_state_dict"])
124 |
125 | data = torch.load(
126 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
127 | self.scheduler.load_state_dict(data["scheduler_state_dict"])
128 |
129 | self.num_pixels = self.conf.get_int('train.num_pixels')
130 | self.total_pixels = self.train_dataset.total_pixels
131 | self.img_res = self.train_dataset.img_res
132 | self.n_batches = len(self.train_dataloader)
133 | self.plot_freq = self.conf.get_int('train.plot_freq')
134 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100)
135 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000)
136 | self.plot_conf = self.conf.get_config('plot')
137 |
138 | def save_checkpoints(self, epoch):
139 | torch.save(
140 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
141 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
142 | torch.save(
143 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
144 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
145 |
146 | torch.save(
147 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
148 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
149 | torch.save(
150 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
151 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
152 |
153 | torch.save(
154 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
155 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
156 | torch.save(
157 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
158 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
159 |
160 | def run(self):
161 | print("training...")
162 |
163 | for epoch in range(self.start_epoch, self.nepochs + 1):
164 |
165 | if epoch % self.checkpoint_freq == 0:
166 | self.save_checkpoints(epoch)
167 |
168 | if self.do_vis and epoch % self.plot_freq == 0:
169 | self.model.eval()
170 |
171 | self.train_dataset.change_sampling_idx(-1)
172 | indices, model_input, ground_truth = next(iter(self.plot_dataloader))
173 |
174 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
175 | model_input["uv"] = model_input["uv"].cuda()
176 | model_input['pose'] = model_input['pose'].cuda()
177 |
178 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels)
179 | res = []
180 | for s in tqdm(split):
181 | out = self.model(s)
182 | d = {'rgb_values': out['rgb_values'].detach(),
183 | 'normal_map': out['normal_map'].detach()}
184 | res.append(d)
185 |
186 | batch_size = ground_truth['rgb'].shape[0]
187 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size)
188 | plot_data = self.get_plot_data(model_outputs, model_input['pose'], ground_truth['rgb'])
189 |
190 | plt.plot(self.model.implicit_network,
191 | indices,
192 | plot_data,
193 | self.plots_dir,
194 | epoch,
195 | self.img_res,
196 | **self.plot_conf
197 | )
198 |
199 | self.model.train()
200 |
201 | self.train_dataset.change_sampling_idx(self.num_pixels)
202 |
203 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
204 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
205 | model_input["uv"] = model_input["uv"].cuda()
206 | model_input['pose'] = model_input['pose'].cuda()
207 |
208 | model_outputs = self.model(model_input)
209 | loss_output = self.loss(model_outputs, ground_truth)
210 |
211 | loss = loss_output['loss']
212 |
213 | self.optimizer.zero_grad()
214 | loss.backward()
215 | self.optimizer.step()
216 |
217 | psnr = rend_util.get_psnr(model_outputs['rgb_values'],
218 | ground_truth['rgb'].cuda().reshape(-1,3))
219 | print(
220 | '{0}_{1} [{2}] ({3}/{4}): loss = {5}, rgb_loss = {6}, eikonal_loss = {7}, psnr = {8}'
221 | .format(self.expname, self.timestamp, epoch, data_index, self.n_batches, loss.item(),
222 | loss_output['rgb_loss'].item(),
223 | loss_output['eikonal_loss'].item(),
224 | psnr.item()))
225 |
226 | self.train_dataset.change_sampling_idx(self.num_pixels)
227 | self.scheduler.step()
228 |
229 | self.save_checkpoints(epoch)
230 |
231 | def get_plot_data(self, model_outputs, pose, rgb_gt):
232 | batch_size, num_samples, _ = rgb_gt.shape
233 |
234 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)
235 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)
236 | normal_map = (normal_map + 1.) / 2.
237 |
238 | plot_data = {
239 | 'rgb_gt': rgb_gt,
240 | 'pose': pose,
241 | 'rgb_eval': rgb_eval,
242 | 'normal_map': normal_map,
243 | }
244 |
245 | return plot_data
246 |
--------------------------------------------------------------------------------
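The scheduler wiring above is easy to misread: the gamma `decay_rate ** (1/decay_steps)` is chosen so that, with one `scheduler.step()` per training iteration, the default batch size of 1, and `decay_steps = nepochs * len(dataset)`, the learning rate has decayed by exactly `decay_rate` at the end of training. A small sketch with illustrative numbers (not the values from the conf files):

import torch

# Illustrative values only; the real learning rate, decay rate and step count come from
# the conf file and the dataset size.
lr0, decay_rate, decay_steps = 5.0e-4, 0.1, 1000
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=lr0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate ** (1. / decay_steps))

for _ in range(decay_steps):
    optimizer.step()       # no gradients here, so this is a no-op; it only keeps the step order valid
    scheduler.step()

print(optimizer.param_groups[0]['lr'])   # ~5e-5, i.e. lr0 * decay_rate
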
/code/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import torch
4 |
5 | def mkdir_ifnotexists(directory):
6 | if not os.path.exists(directory):
7 | os.mkdir(directory)
8 |
9 | def get_class(kls):
10 | parts = kls.split('.')
11 | module = ".".join(parts[:-1])
12 | m = __import__(module)
13 | for comp in parts[1:]:
14 | m = getattr(m, comp)
15 | return m
16 |
17 | def glob_imgs(path):
18 | imgs = []
19 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
20 | imgs.extend(glob(os.path.join(path, ext)))
21 | return imgs
22 |
23 | def split_input(model_input, total_pixels, n_pixels=10000):
24 | '''
25 | Split the input to fit Cuda memory for large resolution.
26 | Can decrease the value of n_pixels in case of cuda out of memory error.
27 | '''
28 | split = []
29 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
30 | data = model_input.copy()
31 | data['uv'] = torch.index_select(model_input['uv'], 1, indx)
32 | if 'object_mask' in data:
33 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
34 | split.append(data)
35 | return split
36 |
37 | def merge_output(res, total_pixels, batch_size):
38 | ''' Merge the split output. '''
39 |
40 | model_outputs = {}
41 | for entry in res[0]:
42 | if res[0][entry] is None:
43 | continue
44 | if len(res[0][entry].shape) == 1:
45 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
46 | 1).reshape(batch_size * total_pixels)
47 | else:
48 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
49 | 1).reshape(batch_size * total_pixels, -1)
50 |
51 | return model_outputs
52 |
53 | def concat_home_dir(path):
54 | return os.path.join(os.environ['HOME'],'data',path)
--------------------------------------------------------------------------------
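split_input and merge_output are what let volsdf_train.py render full images at plot time without running out of GPU memory: the pixel grid is cut into chunks of at most n_pixels, each chunk goes through the model, and the chunk outputs are concatenated back into one (batch * total_pixels, C) tensor. A CPU toy of that round trip (the utilities above put indices on CUDA; the fake 'rgb_values' stand in for the model output):

import torch

batch_size, total_pixels, n_pixels = 1, 10, 4
model_input = {'uv': torch.arange(batch_size * total_pixels * 2, dtype=torch.float).reshape(batch_size, total_pixels, 2)}

# split_input: cut the pixel grid into chunks of at most n_pixels.
split = []
for indx in torch.split(torch.arange(total_pixels), n_pixels):
    data = model_input.copy()
    data['uv'] = torch.index_select(model_input['uv'], 1, indx)
    split.append(data)

# Stand-in for the model: one fake RGB value per pixel in each chunk.
res = [{'rgb_values': s['uv'].sum(-1).reshape(-1, 1).repeat(1, 3)} for s in split]

# merge_output: concatenate the chunk outputs back into (batch * total_pixels, 3).
merged = torch.cat([r['rgb_values'].reshape(batch_size, -1, 3) for r in res], 1)
merged = merged.reshape(batch_size * total_pixels, -1)
print(merged.shape)   # torch.Size([10, 3])
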
/code/utils/plots.py:
--------------------------------------------------------------------------------
1 | import plotly.graph_objs as go
2 | import plotly.offline as offline
3 | from plotly.subplots import make_subplots
4 | import numpy as np
5 | import torch
6 | from skimage import measure
7 | import torchvision
8 | import trimesh
9 | from PIL import Image
10 |
11 | from utils import rend_util
12 |
13 |
14 | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution, grid_boundary, level=0):
15 |
16 | if plot_data is not None:
17 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])
18 |
19 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res)
20 |
21 | # plot normal maps
22 | plot_normal_maps(plot_data['normal_map'], path, epoch, plot_nimgs, img_res)
23 |
24 |
25 | data = []
26 |
27 | # plot surface
28 | surface_traces = get_surface_trace(path=path,
29 | epoch=epoch,
30 | sdf=lambda x: implicit_network(x)[:, 0],
31 | resolution=resolution,
32 | grid_boundary=grid_boundary,
33 | level=level
34 | )
35 |
36 | if surface_traces is not None:
37 | data.append(surface_traces[0])
38 |
39 | # plot cameras locations
40 | if plot_data is not None:
41 | for i, loc, dir in zip(indices, cam_loc, cam_dir):
42 | data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))
43 |
44 | fig = go.Figure(data=data)
45 | scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False),
46 | yaxis=dict(range=[-6, 6], autorange=False),
47 | zaxis=dict(range=[-6, 6], autorange=False),
48 | aspectratio=dict(x=1, y=1, z=1))
49 | fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
50 | filename = '{0}/surface_{1}.html'.format(path, epoch)
51 | offline.plot(fig, filename=filename, auto_open=False)
52 |
53 |
54 | def get_3D_scatter_trace(points, name='', size=3, caption=None):
55 |     assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped "
56 |     assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped "
57 |
58 | trace = go.Scatter3d(
59 | x=points[:, 0].cpu(),
60 | y=points[:, 1].cpu(),
61 | z=points[:, 2].cpu(),
62 | mode='markers',
63 | name=name,
64 | marker=dict(
65 | size=size,
66 | line=dict(
67 | width=2,
68 | ),
69 | opacity=1.0,
70 | ), text=caption)
71 |
72 | return trace
73 |
74 |
75 | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
76 |     assert points.shape[1] == 3, "3d cone plot input points are not correctly shaped "
77 |     assert len(points.shape) == 2, "3d cone plot input points are not correctly shaped "
78 |     assert directions.shape[1] == 3, "3d cone plot input directions are not correctly shaped "
79 |     assert len(directions.shape) == 2, "3d cone plot input directions are not correctly shaped "
80 |
81 | trace = go.Cone(
82 | name=name,
83 | x=points[:, 0].cpu(),
84 | y=points[:, 1].cpu(),
85 | z=points[:, 2].cpu(),
86 | u=directions[:, 0].cpu(),
87 | v=directions[:, 1].cpu(),
88 | w=directions[:, 2].cpu(),
89 | sizemode='absolute',
90 | sizeref=0.125,
91 | showscale=False,
92 | colorscale=[[0, color], [1, color]],
93 | anchor="tail"
94 | )
95 |
96 | return trace
97 |
98 |
99 | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0):
100 | grid = get_grid_uniform(resolution, grid_boundary)
101 | points = grid['grid_points']
102 |
103 | z = []
104 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
105 | z.append(sdf(pnts).detach().cpu().numpy())
106 | z = np.concatenate(z, axis=0)
107 |
108 | if (not (np.min(z) > level or np.max(z) < level)):
109 |
110 | z = z.astype(np.float32)
111 |
112 | verts, faces, normals, values = measure.marching_cubes(
113 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
114 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
115 | level=level,
116 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
117 | grid['xyz'][0][2] - grid['xyz'][0][1],
118 | grid['xyz'][0][2] - grid['xyz'][0][1]))
119 |
120 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
121 |
122 | I, J, K = faces.transpose()
123 |
124 | traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
125 | i=I, j=J, k=K, name='implicit_surface',
126 | color='#ffffff', opacity=1.0, flatshading=False,
127 | lighting=dict(diffuse=1, ambient=0, specular=0),
128 | lightposition=dict(x=0, y=0, z=-1), showlegend=True)]
129 |
130 | meshexport = trimesh.Trimesh(verts, faces, normals)
131 | meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply')
132 |
133 | if return_mesh:
134 | return meshexport
135 | return traces
136 | return None
137 |
138 | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, take_components=True):
139 | # get low res mesh to sample point cloud
140 | grid = get_grid_uniform(100, grid_boundary)
141 | z = []
142 | points = grid['grid_points']
143 |
144 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
145 | z.append(sdf(pnts).detach().cpu().numpy())
146 | z = np.concatenate(z, axis=0)
147 |
148 | z = z.astype(np.float32)
149 |
150 | verts, faces, normals, values = measure.marching_cubes(
151 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
152 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
153 | level=level,
154 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
155 | grid['xyz'][0][2] - grid['xyz'][0][1],
156 | grid['xyz'][0][2] - grid['xyz'][0][1]))
157 |
158 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
159 |
160 | mesh_low_res = trimesh.Trimesh(verts, faces, normals)
161 | if take_components:
162 | components = mesh_low_res.split(only_watertight=False)
163 | areas = np.array([c.area for c in components], dtype=np.float)
164 | mesh_low_res = components[areas.argmax()]
165 |
166 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
167 | recon_pc = torch.from_numpy(recon_pc).float().cuda()
168 |
169 | # Center and align the recon pc
170 | s_mean = recon_pc.mean(dim=0)
171 | s_cov = recon_pc - s_mean
172 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
173 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
174 | if torch.det(vecs) < 0:
175 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
176 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
177 | (recon_pc - s_mean).unsqueeze(-1)).squeeze()
178 |
179 | grid_aligned = get_grid(helper.cpu(), resolution)
180 |
181 | grid_points = grid_aligned['grid_points']
182 |
183 | g = []
184 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
185 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
186 | pnts.unsqueeze(-1)).squeeze() + s_mean)
187 | grid_points = torch.cat(g, dim=0)
188 |
189 | # MC to new grid
190 | points = grid_points
191 | z = []
192 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
193 | z.append(sdf(pnts).detach().cpu().numpy())
194 | z = np.concatenate(z, axis=0)
195 |
196 | meshexport = None
197 | if (not (np.min(z) > level or np.max(z) < level)):
198 |
199 | z = z.astype(np.float32)
200 |
201 | verts, faces, normals, values = measure.marching_cubes(
202 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
203 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
204 | level=level,
205 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
206 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
207 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))
208 |
209 | verts = torch.from_numpy(verts).cuda().float()
210 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
211 | verts.unsqueeze(-1)).squeeze()
212 | verts = (verts + grid_points[0]).cpu().numpy()
213 |
214 | meshexport = trimesh.Trimesh(verts, faces, normals)
215 |
216 | return meshexport
217 |
218 |
219 | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False):
220 | grid_params = grid_params * [[1.5], [1.0]]
221 |
222 | # params = PLOT_DICT[scan_id]
223 | input_min = torch.tensor(grid_params[0]).float()
224 | input_max = torch.tensor(grid_params[1]).float()
225 |
226 | if higher_res:
227 | # get low res mesh to sample point cloud
228 | grid = get_grid(None, 100, input_min=input_min, input_max=input_max, eps=0.0)
229 | z = []
230 | points = grid['grid_points']
231 |
232 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
233 | z.append(sdf(pnts).detach().cpu().numpy())
234 | z = np.concatenate(z, axis=0)
235 |
236 | z = z.astype(np.float32)
237 |
238 | verts, faces, normals, values = measure.marching_cubes(
239 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
240 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
241 | level=level,
242 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
243 | grid['xyz'][0][2] - grid['xyz'][0][1],
244 | grid['xyz'][0][2] - grid['xyz'][0][1]))
245 |
246 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
247 |
248 | mesh_low_res = trimesh.Trimesh(verts, faces, normals)
249 | components = mesh_low_res.split(only_watertight=False)
250 | areas = np.array([c.area for c in components], dtype=np.float)
251 | mesh_low_res = components[areas.argmax()]
252 |
253 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
254 | recon_pc = torch.from_numpy(recon_pc).float().cuda()
255 |
256 | # Center and align the recon pc
257 | s_mean = recon_pc.mean(dim=0)
258 | s_cov = recon_pc - s_mean
259 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
260 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
261 | if torch.det(vecs) < 0:
262 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
263 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
264 | (recon_pc - s_mean).unsqueeze(-1)).squeeze()
265 |
266 | grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01)
267 | else:
268 | grid_aligned = get_grid(None, resolution, input_min=input_min, input_max=input_max, eps=0.0)
269 |
270 | grid_points = grid_aligned['grid_points']
271 |
272 | if higher_res:
273 | g = []
274 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
275 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
276 | pnts.unsqueeze(-1)).squeeze() + s_mean)
277 | grid_points = torch.cat(g, dim=0)
278 |
279 | # MC to new grid
280 | points = grid_points
281 | z = []
282 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
283 | z.append(sdf(pnts).detach().cpu().numpy())
284 | z = np.concatenate(z, axis=0)
285 |
286 | meshexport = None
287 | if (not (np.min(z) > level or np.max(z) < level)):
288 |
289 | z = z.astype(np.float32)
290 |
291 | verts, faces, normals, values = measure.marching_cubes(
292 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
293 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
294 | level=level,
295 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
296 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
297 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))
298 |
299 | if higher_res:
300 | verts = torch.from_numpy(verts).cuda().float()
301 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
302 | verts.unsqueeze(-1)).squeeze()
303 | verts = (verts + grid_points[0]).cpu().numpy()
304 | else:
305 | verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]])
306 |
307 | meshexport = trimesh.Trimesh(verts, faces, normals)
308 |
309 | # CUTTING MESH ACCORDING TO THE BOUNDING BOX
310 | if higher_res:
311 | bb = grid_params
312 | transformation = np.eye(4)
313 | transformation[:3, 3] = (bb[1,:] + bb[0,:])/2.
314 | bounding_box = trimesh.creation.box(extents=bb[1,:] - bb[0,:], transform=transformation)
315 |
316 | meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal)
317 |
318 | return meshexport
319 |
320 | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):
321 | x = np.linspace(grid_boundary[0], grid_boundary[1], resolution)
322 | y = x
323 | z = x
324 |
325 | xx, yy, zz = np.meshgrid(x, y, z)
326 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
327 |
328 | return {"grid_points": grid_points.cuda(),
329 | "shortest_axis_length": 2.0,
330 | "xyz": [x, y, z],
331 | "shortest_axis_index": 0}
332 |
333 | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):
334 | if input_min is None or input_max is None:
335 | input_min = torch.min(points, dim=0)[0].squeeze().numpy()
336 | input_max = torch.max(points, dim=0)[0].squeeze().numpy()
337 |
338 | bounding_box = input_max - input_min
339 | shortest_axis = np.argmin(bounding_box)
340 | if (shortest_axis == 0):
341 | x = np.linspace(input_min[shortest_axis] - eps,
342 | input_max[shortest_axis] + eps, resolution)
343 | length = np.max(x) - np.min(x)
344 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
345 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
346 | elif (shortest_axis == 1):
347 | y = np.linspace(input_min[shortest_axis] - eps,
348 | input_max[shortest_axis] + eps, resolution)
349 | length = np.max(y) - np.min(y)
350 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
351 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
352 | elif (shortest_axis == 2):
353 | z = np.linspace(input_min[shortest_axis] - eps,
354 | input_max[shortest_axis] + eps, resolution)
355 | length = np.max(z) - np.min(z)
356 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
357 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
358 |
359 | xx, yy, zz = np.meshgrid(x, y, z)
360 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
361 | return {"grid_points": grid_points,
362 | "shortest_axis_length": length,
363 | "xyz": [x, y, z],
364 | "shortest_axis_index": shortest_axis}
365 |
366 |
367 | def plot_normal_maps(normal_maps, path, epoch, plot_nrow, img_res):
368 | normal_maps_plot = lin2img(normal_maps, img_res)
369 |
370 | tensor = torchvision.utils.make_grid(normal_maps_plot,
371 | scale_each=False,
372 | normalize=False,
373 | nrow=plot_nrow).cpu().detach().numpy()
374 | tensor = tensor.transpose(1, 2, 0)
375 | scale_factor = 255
376 | tensor = (tensor * scale_factor).astype(np.uint8)
377 |
378 | img = Image.fromarray(tensor)
379 | img.save('{0}/normal_{1}.png'.format(path, epoch))
380 |
381 |
382 | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res):
383 | ground_true = ground_true.cuda()
384 |
385 | output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
386 | output_vs_gt_plot = lin2img(output_vs_gt, img_res)
387 |
388 | tensor = torchvision.utils.make_grid(output_vs_gt_plot,
389 | scale_each=False,
390 | normalize=False,
391 | nrow=plot_nrow).cpu().detach().numpy()
392 |
393 | tensor = tensor.transpose(1, 2, 0)
394 | scale_factor = 255
395 | tensor = (tensor * scale_factor).astype(np.uint8)
396 |
397 | img = Image.fromarray(tensor)
398 | img.save('{0}/rendering_{1}.png'.format(path, epoch))
399 |
400 |
401 | def lin2img(tensor, img_res):
402 | batch_size, num_samples, channels = tensor.shape
403 | return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])
404 |
--------------------------------------------------------------------------------
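get_surface_trace and its higher-resolution variants all follow the same recipe: evaluate the SDF on a regular grid in chunks, run marching cubes at the chosen level set, and shift the resulting vertices from grid-index space back to world coordinates using the grid spacing and origin. A compact sketch of that recipe, with an analytic unit-sphere SDF standing in for the implicit network:

import numpy as np
import torch
from skimage import measure

# Analytic SDF of a unit sphere stands in for the network's SDF head.
def sphere_sdf(pnts):                      # pnts: (N, 3)
    return pnts.norm(dim=-1) - 1.0

resolution, grid_boundary = 64, [-2.0, 2.0]
x = np.linspace(grid_boundary[0], grid_boundary[1], resolution)
xx, yy, zz = np.meshgrid(x, x, x)
points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)

# Evaluate the SDF on the grid (the code above does this in chunks of 100k points).
z = sphere_sdf(points).numpy().astype(np.float32)

# Same reshape/transpose, uniform spacing, and origin shift as get_surface_trace.
verts, faces, normals, values = measure.marching_cubes(
    volume=z.reshape(resolution, resolution, resolution).transpose([1, 0, 2]),
    level=0.0,
    spacing=(x[2] - x[1],) * 3)
verts = verts + np.array([x[0], x[0], x[0]])

print(verts.shape, faces.shape)            # vertices and triangles of the unit-sphere mesh
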
/code/utils/rend_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import skimage
4 | import cv2
5 | import torch
6 | from torch.nn import functional as F
7 |
8 |
9 | def get_psnr(img1, img2, normalize_rgb=False):
10 | if normalize_rgb: # [-1,1] --> [0,1]
11 | img1 = (img1 + 1.) / 2.
12 | img2 = (img2 + 1. ) / 2.
13 |
14 | mse = torch.mean((img1 - img2) ** 2)
15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
16 |
17 | return psnr
18 |
19 |
20 | def load_rgb(path, normalize_rgb = False):
21 | img = imageio.imread(path)
22 | img = skimage.img_as_float32(img)
23 |
24 | if normalize_rgb: # [-1,1] --> [0,1]
25 | img -= 0.5
26 | img *= 2.
27 | img = img.transpose(2, 0, 1)
28 | return img
29 |
30 |
31 | def load_K_Rt_from_P(filename, P=None):
32 | if P is None:
33 | lines = open(filename).read().splitlines()
34 | if len(lines) == 4:
35 | lines = lines[1:]
36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
37 | P = np.asarray(lines).astype(np.float32).squeeze()
38 |
39 | out = cv2.decomposeProjectionMatrix(P)
40 | K = out[0]
41 | R = out[1]
42 | t = out[2]
43 |
44 | K = K/K[2,2]
45 | intrinsics = np.eye(4)
46 | intrinsics[:3, :3] = K
47 |
48 | pose = np.eye(4, dtype=np.float32)
49 | pose[:3, :3] = R.transpose()
50 | pose[:3,3] = (t[:3] / t[3])[:,0]
51 |
52 | return intrinsics, pose
53 |
54 |
55 | def get_camera_params(uv, pose, intrinsics):
56 | if pose.shape[1] == 7: #In case of quaternion vector representation
57 | cam_loc = pose[:, 4:]
58 | R = quat_to_rot(pose[:,:4])
59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
60 | p[:, :3, :3] = R
61 | p[:, :3, 3] = cam_loc
62 | else: # In case of pose matrix representation
63 | cam_loc = pose[:, :3, 3]
64 | p = pose
65 |
66 | batch_size, num_samples, _ = uv.shape
67 |
68 | depth = torch.ones((batch_size, num_samples)).cuda()
69 | x_cam = uv[:, :, 0].view(batch_size, -1)
70 | y_cam = uv[:, :, 1].view(batch_size, -1)
71 | z_cam = depth.view(batch_size, -1)
72 |
73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
74 |
75 | # permute for batch matrix product
76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
77 |
78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
79 | ray_dirs = world_coords - cam_loc[:, None, :]
80 | ray_dirs = F.normalize(ray_dirs, dim=2)
81 |
82 | return ray_dirs, cam_loc
83 |
84 |
85 | def get_camera_for_plot(pose):
86 | if pose.shape[1] == 7: #In case of quaternion vector representation
87 | cam_loc = pose[:, 4:].detach()
88 | R = quat_to_rot(pose[:,:4].detach())
89 | else: # In case of pose matrix representation
90 | cam_loc = pose[:, :3, 3]
91 | R = pose[:, :3, :3]
92 | cam_dir = R[:, :3, 2]
93 | return cam_loc, cam_dir
94 |
95 |
96 | def lift(x, y, z, intrinsics):
97 | # parse intrinsics
98 | intrinsics = intrinsics.cuda()
99 | fx = intrinsics[:, 0, 0]
100 | fy = intrinsics[:, 1, 1]
101 | cx = intrinsics[:, 0, 2]
102 | cy = intrinsics[:, 1, 2]
103 | sk = intrinsics[:, 0, 1]
104 |
105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
107 |
108 | # homogeneous
109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
110 |
111 |
112 | def quat_to_rot(q):
113 | batch_size, _ = q.shape
114 | q = F.normalize(q, dim=1)
115 | R = torch.ones((batch_size, 3,3)).cuda()
116 | qr=q[:,0]
117 | qi = q[:, 1]
118 | qj = q[:, 2]
119 | qk = q[:, 3]
120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2)
121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr)
122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj)
123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr)
124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
125 | R[:, 1, 2] = 2*(qj*qk - qi*qr)
126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr)
127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr)
128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
129 | return R
130 |
131 |
132 | def rot_to_quat(R):
133 | batch_size, _,_ = R.shape
134 | q = torch.ones((batch_size, 4)).cuda()
135 |
136 | R00 = R[:, 0,0]
137 | R01 = R[:, 0, 1]
138 | R02 = R[:, 0, 2]
139 | R10 = R[:, 1, 0]
140 | R11 = R[:, 1, 1]
141 | R12 = R[:, 1, 2]
142 | R20 = R[:, 2, 0]
143 | R21 = R[:, 2, 1]
144 | R22 = R[:, 2, 2]
145 |
146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
147 | q[:, 1]=(R21-R12)/(4*q[:,0])
148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0])
149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0])
150 | return q
151 |
152 |
153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
154 | # Input: n_rays x 3 ; n_rays x 3
155 | # Output: n_rays x 1, n_rays x 1 (close and far)
156 |
157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
158 | cam_loc.view(-1, 3, 1)).squeeze(-1)
159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
160 |
161 | # sanity check
162 | if (under_sqrt <= 0).sum() > 0:
163 | print('BOUNDING SPHERE PROBLEM!')
164 | exit()
165 |
166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
167 | sphere_intersections = sphere_intersections.clamp_min(0.0)
168 |
169 | return sphere_intersections
170 |
--------------------------------------------------------------------------------
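get_sphere_intersections solves the usual ray/sphere quadratic: for a camera centre o and unit direction d, |o + t d|^2 = r^2 gives t = -<o, d> -/+ sqrt(<o, d>^2 - (|o|^2 - r^2)), and the code above exits if any ray misses the bounding sphere. A CPU sketch with a toy camera placed 3 units from the origin:

import torch

# CPU variant of the ray / bounding-sphere intersection above, for a single toy ray.
def sphere_intersections(cam_loc, ray_dirs, r=1.0):
    ray_cam_dot = (ray_dirs * cam_loc).sum(-1, keepdim=True)                        # <o, d>
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(dim=-1, keepdim=True) ** 2 - r ** 2)
    t = torch.sqrt(under_sqrt) * torch.tensor([-1.0, 1.0]) - ray_cam_dot             # near, far roots
    return t.clamp_min(0.0)                                                          # clip points behind the camera

# Camera 3 units from the origin, looking straight at it: the unit sphere is hit at t=2 and t=4.
cam_loc = torch.tensor([[0.0, 0.0, 3.0]])
ray_dirs = torch.tensor([[0.0, 0.0, -1.0]])
print(sphere_intersections(cam_loc, ray_dirs, r=1.0))   # tensor([[2., 4.]])
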
/data/download_data.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | echo "Downloading the DTU dataset ..."
4 | wget https://www.dropbox.com/s/s6psnh1q91m4kgo/DTU.zip
5 | echo "Start unzipping ..."
6 | unzip DTU.zip
7 | echo "DTU dataset is ready!"
8 | rm -f DTU.zip
9 | echo "Downloading the BlendedMVS dataset ..."
10 | wget https://www.dropbox.com/s/c88216wzn9t6pj8/BlendedMVS.zip
11 | echo "Start unzipping ..."
12 | unzip BlendedMVS.zip
13 | echo "BlendedMVS dataset is ready!"
14 | rm -f BlendedMVS.zip
--------------------------------------------------------------------------------
/data/preprocess/normalize_cameras.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import argparse
4 |
5 |
6 | def get_center_point(num_cams,cameras):
7 | A = np.zeros((3 * num_cams, 3 + num_cams))
8 | b = np.zeros((3 * num_cams, 1))
9 | camera_centers=np.zeros((3,num_cams))
10 | for i in range(num_cams):
11 | P0 = cameras['world_mat_%d' % i][:3, :]
12 |
13 | K = cv2.decomposeProjectionMatrix(P0)[0]
14 | R = cv2.decomposeProjectionMatrix(P0)[1]
15 | c = cv2.decomposeProjectionMatrix(P0)[2]
16 | c = c / c[3]
17 | camera_centers[:,i]=c[:3].flatten()
18 |
19 | v = np.linalg.inv(K) @ np.array([800, 600, 1])
20 | v = v / np.linalg.norm(v)
21 |
22 | v=R[2,:]
23 | A[3 * i:(3 * i + 3), :3] = np.eye(3)
24 | A[3 * i:(3 * i + 3), 3 + i] = -v
25 | b[3 * i:(3 * i + 3)] = c[:3]
26 |
27 | soll= np.linalg.pinv(A) @ b
28 |
29 | return soll,camera_centers
30 |
31 | def normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras):
32 | cameras = np.load(original_cameras_filename)
33 | if num_of_cameras==-1:
34 | all_files=cameras.files
35 | maximal_ind=0
36 | for field in all_files:
37 | maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1]))
38 | num_of_cameras=maximal_ind+1
39 | soll, camera_centers = get_center_point(num_of_cameras, cameras)
40 |
41 | center = soll[:3].flatten()
42 |
43 | max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1
44 |
45 | normalization = np.eye(4).astype(np.float32)
46 |
47 | normalization[0, 3] = center[0]
48 | normalization[1, 3] = center[1]
49 | normalization[2, 3] = center[2]
50 |
51 | normalization[0, 0] = max_radius / 3.0
52 | normalization[1, 1] = max_radius / 3.0
53 | normalization[2, 2] = max_radius / 3.0
54 |
55 | cameras_new = {}
56 | for i in range(num_of_cameras):
57 | cameras_new['scale_mat_%d' % i] = normalization
58 | cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy()
59 | np.savez(output_cameras_filename, **cameras_new)
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser(description='Normalizing cameras')
64 | parser.add_argument('--input_cameras_file', type=str, default="cameras.npz",
65 | help='the input cameras file')
66 | parser.add_argument('--output_cameras_file', type=str, default="cameras_normalize.npz",
67 | help='the output cameras file')
68 | parser.add_argument('--number_of_cams',type=int, default=-1,
69 | help='Number of cameras, if -1 use all')
70 |
71 | args = parser.parse_args()
72 | normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams)
73 |
--------------------------------------------------------------------------------
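A usage sketch with the script's own defaults: `python normalize_cameras.py --input_cameras_file cameras.npz --output_cameras_file cameras_normalize.npz`. The output file keeps every `world_mat_i` unchanged and adds the same `scale_mat_i` for every camera: its translation is the estimated scene centre (the least-squares point closest to all camera principal axes computed in `get_center_point`) and its scale is 1.1 times the largest camera-to-centre distance divided by 3, which is intended to put all cameras inside a sphere of radius roughly 3 in the normalized coordinates.
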
/data/preprocess/parse_cameras_blendedmvs.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import argparse
4 | import os
5 |
6 | def read_camera(sequence,ind):
7 | file = "%s/cams/%08d_cam.txt"%(sequence,ind)
8 | f = open(file)
9 |
10 | f.readline().strip()
11 |
12 | row1 = f.readline().strip().split()
13 | row2 = f.readline().strip().split()
14 | row3 = f.readline().strip().split()
15 |
16 | M = np.stack(
17 | (np.array(row1).astype(np.float32), np.array(row2).astype(np.float32), np.array(row3).astype(np.float32)))
18 | f.readline()
19 | f.readline()
20 | f.readline()
21 | row1 = f.readline().strip().split()
22 | row2 = f.readline().strip().split()
23 | row3 = f.readline().strip().split()
24 | K = np.stack(
25 | (np.array(row1).astype(np.float32), np.array(row2).astype(np.float32), np.array(row3).astype(np.float32)))
26 |
27 | return (K,M)
28 |
29 | def parse_scan(scan_ind,output_cameras_file,blendedMVS_path):
30 | files = os.listdir('%s/scan%d/cams' % (blendedMVS_path,scan_ind))
31 | num_cams = len(files) - 1
32 |
33 | cameras_new = {}
34 | for i in range(num_cams):
35 | Ki, Mi = read_camera("%s/scan%d" % (blendedMVS_path,scan_ind), int(files[i][:8]))
36 | curp = np.eye(4).astype(np.float32)
37 | curp[:3, :] = Ki @ Mi
38 | cameras_new['world_mat_%d' % i] = curp.copy()
39 |
40 | np.savez(
41 | output_cameras_file,
42 | **cameras_new)
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser(description='Parsing blendedMVS')
46 | parser.add_argument('--blendedMVS_path', type=str, default="BlendedMVS",
47 | help='the blendedMVS path')
48 | parser.add_argument('--output_cameras_file', type=str, default="cameras.npz",
49 | help='the output cameras file')
50 | parser.add_argument('--scan_ind',type=int,
51 | help='Scan id')
52 |
53 | args = parser.parse_args()
54 | parse_scan(args.scan_ind,args.output_cameras_file,args.blendedMVS_path)
55 |
--------------------------------------------------------------------------------
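A usage sketch (the scan index is only illustrative): `python parse_cameras_blendedmvs.py --blendedMVS_path BlendedMVS --scan_ind 1 --output_cameras_file cameras.npz`. Each `world_mat_i` written here is the intrinsics matrix times the 3x4 extrinsic rows read from the corresponding `cams/%08d_cam.txt` file, i.e. a full projection matrix padded to 4x4; the default output name matches the default input of `normalize_cameras.py` above.
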
/environment.yml:
--------------------------------------------------------------------------------
1 | name: volsdf
2 | channels:
3 | - pytorch
4 | - plotly
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=4.5=1_gnu
11 | - aadict=0.2.3=pyh9f0ad1d_0
12 | - asset=0.6.13=pyh9f0ad1d_0
13 | - blas=1.0=mkl
14 | - blosc=1.21.0=h9c3ff4c_0
15 | - bottleneck=1.3.2=py39hdd57654_1
16 | - brotli=1.0.9=he6710b0_2
17 | - brunsli=0.1=h2531618_0
18 | - bzip2=1.0.8=h7b6447c_0
19 | - ca-certificates=2021.10.8=ha878542_0
20 | - cairo=1.16.0=hcf35c78_1003
21 | - certifi=2021.10.8=py39h06a4308_0
22 | - cfitsio=3.470=hf0d0db6_6
23 | - charls=2.2.0=h2531618_0
24 | - cloudpickle=1.6.0=py_0
25 | - colorama=0.4.4=pyh9f0ad1d_0
26 | - cudatoolkit=10.2.89=hfd86e86_1
27 | - cycler=0.11.0=pyhd3eb1b0_0
28 | - cytoolz=0.11.0=py39h27cfd23_0
29 | - dask-core=2.30.0=py_0
30 | - dbus=1.13.6=hfdff14a_1
31 | - decorator=4.4.2=py_0
32 | - expat=2.4.1=h9c3ff4c_0
33 | - ffmpeg=4.3.2=hca11adc_0
34 | - fontconfig=2.13.1=hba837de_1005
35 | - fonttools=4.25.0=pyhd3eb1b0_0
36 | - freetype=2.11.0=h70c0345_0
37 | - gettext=0.19.8.1=hf34092f_1004
38 | - giflib=5.2.1=h7b6447c_0
39 | - glib=2.66.3=h58526e2_0
40 | - globre=0.1.5=pyh9f0ad1d_0
41 | - gmp=6.2.1=h2531618_2
42 | - gnutls=3.6.15=he1e5248_0
43 | - gputil=1.4.0=pyh9f0ad1d_0
44 | - graphite2=1.3.13=h58526e2_1001
45 | - gst-plugins-base=1.14.5=h0935bb2_2
46 | - gstreamer=1.14.5=h36ae1b5_2
47 | - harfbuzz=2.4.0=h9f30f68_3
48 | - hdf5=1.10.6=nompi_h7c3c948_1111
49 | - icu=64.2=he1b5a44_1
50 | - imagecodecs=2021.3.31=py39h7572904_1
51 | - imageio=2.9.0=py_0
52 | - intel-openmp=2021.4.0=h06a4308_3561
53 | - jasper=1.900.1=h07fcdf6_1006
54 | - jbig=2.1=h7f98852_2003
55 | - jpeg=9d=h7f8727e_0
56 | - jxrlib=1.1=h7b6447c_2
57 | - kiwisolver=1.3.1=py39h2531618_0
58 | - krb5=1.18.2=h173b8e3_0
59 | - lame=3.100=h7b6447c_0
60 | - lcms2=2.12=h3be6417_0
61 | - ld_impl_linux-64=2.35.1=h7274673_9
62 | - lerc=2.2.1=h9c3ff4c_0
63 | - libaec=1.0.4=he6710b0_1
64 | - libblas=3.9.0=12_linux64_mkl
65 | - libcblas=3.9.0=12_linux64_mkl
66 | - libclang=9.0.1=default_ha53f305_1
67 | - libcurl=7.71.1=h20c2e04_1
68 | - libdeflate=1.7=h7f98852_5
69 | - libedit=3.1.20191231=h14c3975_1
70 | - libffi=3.2.1=he1b5a44_1007
71 | - libgcc-ng=9.3.0=h5101ec6_17
72 | - libgfortran-ng=7.5.0=ha8ba4b0_17
73 | - libgfortran4=7.5.0=ha8ba4b0_17
74 | - libglib=2.66.3=hbe7bbb4_0
75 | - libgomp=9.3.0=h5101ec6_17
76 | - libiconv=1.16=h516909a_0
77 | - libidn2=2.3.2=h7f8727e_0
78 | - liblapack=3.9.0=12_linux64_mkl
79 | - liblapacke=3.9.0=12_linux64_mkl
80 | - libllvm9=9.0.1=hf817b99_2
81 | - libopencv=4.5.2=py39h70bf20d_1
82 | - libpng=1.6.37=hbc83047_0
83 | - libprotobuf=3.16.0=h780b84a_0
84 | - libssh2=1.9.0=h1ba5d50_1
85 | - libstdcxx-ng=9.3.0=hd4cf53a_17
86 | - libtasn1=4.16.0=h27cfd23_0
87 | - libtiff=4.3.0=hf544144_1
88 | - libunistring=0.9.10=h27cfd23_0
89 | - libuuid=2.32.1=h7f98852_1000
90 | - libuv=1.40.0=h7b6447c_0
91 | - libwebp=1.2.0=h89dd481_0
92 | - libwebp-base=1.2.0=h27cfd23_0
93 | - libxcb=1.13=h7f98852_1003
94 | - libxkbcommon=0.10.0=he1b5a44_0
95 | - libxml2=2.9.10=hee79883_0
96 | - libzopfli=1.0.3=he6710b0_0
97 | - lz4-c=1.9.3=h295c915_1
98 | - matplotlib-base=3.5.0=py39h3ed280b_0
99 | - mkl=2021.4.0=h06a4308_640
100 | - mkl-service=2.4.0=py39h7f8727e_0
101 | - mkl_fft=1.3.1=py39hd3c417c_0
102 | - mkl_random=1.2.2=py39h51133e4_0
103 | - munkres=1.1.4=py_0
104 | - ncurses=6.2=h58526e2_4
105 | - nettle=3.7.3=hbbd107a_1
106 | - networkx=2.5=py_0
107 | - nspr=4.30=h9c3ff4c_0
108 | - nss=3.67=hb5efdd6_0
109 | - numexpr=2.7.3=py39h22e1b3c_1
110 | - numpy=1.21.2=py39h20f2e39_0
111 | - numpy-base=1.21.2=py39h79a1101_0
112 | - olefile=0.46=pyhd3eb1b0_0
113 | - opencv=4.5.2=py39hf3d152e_1
114 | - openh264=2.1.1=h780b84a_0
115 | - openjpeg=2.4.0=hb52868f_1
116 | - openssl=1.1.1l=h7f8727e_0
117 | - packaging=20.4=py_0
118 | - pandas=1.3.4=py39h8c16a72_0
119 | - pcre=8.45=h9c3ff4c_0
120 | - pillow=8.4.0=py39h5aabda8_0
121 | - pip=21.2.4=py39h06a4308_0
122 | - pixman=0.38.0=h516909a_1003
123 | - plotly=5.4.0=py_0
124 | - pthread-stubs=0.4=h36c2ea0_1001
125 | - py-opencv=4.5.2=py39hef51801_1
126 | - pyhocon=0.3.59=pyhd8ed1ab_0
127 | - pyparsing=3.0.6=pyhd8ed1ab_0
128 | - python=3.9.0=h2a148a8_4_cpython
129 | - python-dateutil=2.8.1=py_0
130 | - python_abi=3.9=2_cp39
131 | - pytorch=1.10.0=py3.9_cuda10.2_cudnn7.6.5_0
132 | - pytorch-mutex=1.0=cuda
133 | - pytz=2020.1=py_0
134 | - pywavelets=1.1.1=py39h6323ea4_4
135 | - pyyaml=6.0=py39h7f8727e_1
136 | - qt=5.12.5=hd8c4c69_1
137 | - readline=8.1=h27cfd23_0
138 | - scikit-image=0.18.3=py39h51133e4_0
139 | - scipy=1.7.1=py39h292c36d_2
140 | - setuptools=58.0.4=py39h06a4308_0
141 | - six=1.16.0=pyhd3eb1b0_0
142 | - snappy=1.1.8=he6710b0_0
143 | - sqlite=3.36.0=hc218d9a_0
144 | - tenacity=8.0.1=py39h06a4308_0
145 | - tifffile=2020.10.1=py_0
146 | - tk=8.6.11=h1ccaba5_0
147 | - toolz=0.11.1=py_0
148 | - torchaudio=0.10.0=py39_cu102
149 | - torchvision=0.11.1=py39_cu102
150 | - tqdm=4.62.3=pyhd8ed1ab_0
151 | - trimesh=3.9.36=pyh6c4a22f_0
152 | - typing_extensions=3.10.0.2=pyh06a4308_0
153 | - tzdata=2021e=hda174b7_0
154 | - wheel=0.37.0=pyhd3eb1b0_1
155 | - x264=1!161.3030=h7f98852_1
156 | - xorg-kbproto=1.0.7=h7f98852_1002
157 | - xorg-libice=1.0.10=h7f98852_0
158 | - xorg-libsm=1.2.3=hd9c2040_1000
159 | - xorg-libx11=1.7.2=h7f98852_0
160 | - xorg-libxau=1.0.9=h7f98852_0
161 | - xorg-libxdmcp=1.1.3=h7f98852_0
162 | - xorg-libxext=1.3.4=h7f98852_1
163 | - xorg-libxrender=0.9.10=h7f98852_1003
164 | - xorg-renderproto=0.11.1=h7f98852_1002
165 | - xorg-xextproto=7.3.0=h7f98852_1002
166 | - xorg-xproto=7.0.31=h7f98852_1007
167 | - xz=5.2.5=h7b6447c_0
168 | - yaml=0.2.5=h7b6447c_0
169 | - zfp=0.5.5=h2531618_6
170 | - zlib=1.2.11=h7b6447c_3
171 | - zstd=1.5.0=ha95c52a_0
172 |
--------------------------------------------------------------------------------
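The pinned environment is created in the standard conda way, e.g. `conda env create -f environment.yml` followed by `conda activate volsdf`. As pinned above it targets Python 3.9 and PyTorch 1.10 with CUDA 10.2, plus plotly, trimesh and scikit-image for the plotting and meshing utilities.
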
/media/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lioryariv/volsdf/a974c883eb70af666d8b4374e771d76930c806f3/media/teaser.png
--------------------------------------------------------------------------------