├── .gitignore
├── LICENSE
├── README.md
├── camera_inspector
│   ├── inspect_epipolar_geometry.py
│   ├── screenshot_lowres.png
│   └── train
│       ├── cam_dict_norm.json
│       └── rgb
│           ├── 000001.png
│           └── 000005.png
├── camera_visualizer
│   ├── camera_path
│   │   └── cam_dict_norm.json
│   ├── mesh_norm.ply
│   ├── screenshot_lowres.png
│   ├── test
│   │   └── cam_dict_norm.json
│   ├── train
│   │   └── cam_dict_norm.json
│   └── visualize_cameras.py
├── colmap_runner
│   ├── database.py
│   ├── extract_sfm.py
│   ├── normalize_cam_dict.py
│   ├── read_write_model.py
│   ├── run_colmap.py
│   └── run_colmap_posed.py
├── configs
│   ├── lf_data
│   │   └── lf_africa.txt
│   └── tanks_and_temples
│       └── tat_training_truck.txt
├── data_loader_split.py
├── ddp_model.py
├── ddp_test_nerf.py
├── ddp_train_nerf.py
├── demo
│   ├── tat_Playground.gif
│   └── tat_Truck.gif
├── environment.yml
├── nerf_network.py
├── nerf_sample_ray_split.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # scripts
2 | *.sh
3 |
4 | # mac
5 | .DS_Store
6 |
7 | # pycharm
8 | .idea/
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
148 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, the NeRF++ authors
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeRF++
2 | Codebase for arXiv preprint ["NeRF++: Analyzing and Improving Neural Radiance Fields"](http://arxiv.org/abs/2010.07492)
3 | * Work with 360 capture of large-scale unbounded scenes.
4 | * Support multi-gpu training and inference with PyTorch DistributedDataParallel (DDP).
5 | * Optimize per-image autoexposure (**experimental feature**).
6 |
7 | ## Demo
8 | ![](demo/tat_Playground.gif) ![](demo/tat_Truck.gif)
9 |
10 | ## Data
11 | * Download our preprocessed data from [tanks_and_temples](https://drive.google.com/file/d/11KRfN91W1AxAW6lOFs4EeYDbeoQZCi87/view?usp=sharing), [lf_data](https://drive.google.com/file/d/1gsjDjkbTh4GAR9fFqlIDZ__qR9NYTURQ/view?usp=sharing).
12 | * Put the data in the sub-folder data/ of this code directory.
13 | * Data format.
14 |     * Each scene consists of 3 splits: train/test/validation.
15 |     * Intrinsics and poses are stored as flattened 4x4 matrices (row-major).
16 |     * The pixel coordinate of an image's upper-left corner is (column, row)=(0, 0); the lower-right corner is (width-1, height-1).
17 |     * Poses are camera-to-world, not world-to-camera, transformations.
18 |     * The OpenCV camera coordinate system is adopted, i.e., x--->right, y--->down, z--->scene. The intrinsic matrix also follows the OpenCV convention.
19 |     * To convert camera poses between the OpenCV and OpenGL conventions, the following code snippet can be used in both directions (OpenGL-to-OpenCV and OpenCV-to-OpenGL).
20 | ```python
21 | import numpy as np
22 | def convert_pose(C2W):
23 |     flip_yz = np.eye(4)
24 |     flip_yz[1, 1] = -1
25 |     flip_yz[2, 2] = -1
26 |     C2W = np.matmul(C2W, flip_yz)
27 |     return C2W
28 | ```
29 | * Scene normalization: move the average camera center to the origin, and put all the camera centers inside the unit sphere; a minimal sketch is shown below. Check [normalize_cam_dict.py](https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/normalize_cam_dict.py) for details.
30 |
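A minimal sketch of this normalization, mirroring `get_tf_cams`/`transform_pose` in `colmap_runner/normalize_cam_dict.py` (the helper name `normalize_w2c_dict` is only for illustration; the `K`/`W2C`/`img_size` keys are the ones described above):
```python
import json
import numpy as np

def normalize_w2c_dict(cam_dict, target_radius=1.0):
    # camera centers are the translation columns of the camera-to-world matrices
    centers = np.stack([np.linalg.inv(np.array(m['W2C']).reshape(4, 4))[:3, 3]
                        for m in cam_dict.values()])
    translate = -centers.mean(axis=0)                                   # move the average center to the origin
    radius = 1.1 * np.max(np.linalg.norm(centers + translate, axis=1))  # pad the farthest center by 10%
    scale = target_radius / radius                                      # shrink all centers into the unit sphere

    out = {}
    for name, meta in cam_dict.items():
        C2W = np.linalg.inv(np.array(meta['W2C']).reshape(4, 4))
        C2W[:3, 3] = (C2W[:3, 3] + translate) * scale                   # only the camera center changes
        out[name] = dict(meta, W2C=list(np.linalg.inv(C2W).flatten()))
    return out

# usage:
# cam_dict = json.load(open('kai_cameras.json'))
# json.dump(normalize_w2c_dict(cam_dict), open('kai_cameras_normalized.json', 'w'), indent=2, sort_keys=True)
```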
31 | ## Create environment
32 | ```bash
33 | conda env create --file environment.yml
34 | conda activate nerfplusplus
35 | ```
36 |
37 | ## Training (Use all available GPUs by default)
38 | ```bash
39 | python ddp_train_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt
40 | ```
41 |
42 | **Note: In the paper, we train NeRF++ on a node with 4 RTX 2080 Ti GPUs, which took ∼24 hours.**
43 |
44 | ## Testing (Use all available GPUs by default)
45 | ```bash
46 | python ddp_test_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt \
47 | --render_splits test,camera_path
48 | ```
49 |
50 | **Note**: due to a restriction imposed by the torch.distributed.gather function, please make sure the number of pixels in each image is divisible by the number of GPUs if you render images in parallel (a quick check is sketched below).
51 |
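A quick sanity check you can run before parallel rendering (the helper and the example image size are illustrative; 980x546 matches the Tanks and Temples images used in this repo):
```python
import torch

def check_render_divisibility(width, height, n_gpus=None):
    # torch.distributed.gather needs equal-sized chunks, so every GPU must
    # receive the same number of pixels of a given image
    n_gpus = n_gpus if n_gpus is not None else torch.cuda.device_count()
    assert (width * height) % n_gpus == 0, \
        f'{width}x{height} image ({width * height} pixels) is not divisible across {n_gpus} GPUs'

check_render_divisibility(980, 546, n_gpus=4)
```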
52 | ## Pretrained weights
53 | I recently re-trained NeRF++ on the tanks and temples data for another project. Here are the checkpoints ([google drive](https://drive.google.com/drive/folders/15P0vCeDiLULzktvFByE5mb6fucJomvHr?usp=sharing)) just in case you might find them useful.
54 |
55 | ## Citation
56 | Please cite our work if you use the code.
57 | ```bibtex
58 | @article{kaizhang2020,
59 |     author = {Kai Zhang and Gernot Riegler and Noah Snavely and Vladlen Koltun},
60 |     title = {NeRF++: Analyzing and Improving Neural Radiance Fields},
61 |     journal = {arXiv:2010.07492},
62 |     year = {2020},
63 | }
64 | ```
65 |
66 | ## Generate camera parameters (intrinsics and poses) with [COLMAP SfM](https://colmap.github.io/)
67 | You can use the scripts inside `colmap_runner` to generate camera parameters from images with COLMAP SfM.
68 | * Specify `img_dir` and `out_dir` in `colmap_runner/run_colmap.py`.
69 | * Inside `colmap_runner/`, execute command `python run_colmap.py`.
70 | * After the program finishes, you will see the posed images in the folder `out_dir/posed_images`.
71 | * Distortion-free images are inside `out_dir/posed_images/images`.
72 | * Raw COLMAP intrinsics and poses are stored as a json file `out_dir/posed_images/kai_cameras.json`.
73 | * Normalized cameras are stored in `out_dir/posed_images/kai_cameras_normalized.json`. See the **Scene normalization method** in the **Data** section.
74 | * Split the distortion-free images and `kai_cameras_normalized.json` according to your needs; a minimal splitting sketch is shown below. You might find the self-explanatory script `data_loader_split.py` helpful when converting the json file to a data format compatible with NeRF++.
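For reference, a minimal splitting sketch that gathers a subset of images and their normalized cameras into the `rgb/` + `cam_dict_norm.json` layout used by `camera_inspector` and `camera_visualizer` in this repo (the split lists and the `data/my_scene` output path are placeholders you choose yourself):
```python
import json
import os
import shutil

def make_split(cam_dict_file, image_dir, out_dir, split_name, image_names):
    # copy the selected images and their camera entries into <out_dir>/<split_name>/
    cam_dict = json.load(open(cam_dict_file))
    split_dir = os.path.join(out_dir, split_name)
    os.makedirs(os.path.join(split_dir, 'rgb'), exist_ok=True)
    with open(os.path.join(split_dir, 'cam_dict_norm.json'), 'w') as fp:
        json.dump({name: cam_dict[name] for name in image_names}, fp, indent=2, sort_keys=True)
    for name in image_names:
        shutil.copy(os.path.join(image_dir, name), os.path.join(split_dir, 'rgb', name))

# e.g. hold out every 8th image for testing and train on the rest
# cam_json = 'out_dir/posed_images/kai_cameras_normalized.json'
# names = sorted(json.load(open(cam_json)).keys())
# make_split(cam_json, 'out_dir/posed_images/images', 'data/my_scene', 'test', names[::8])
# make_split(cam_json, 'out_dir/posed_images/images', 'data/my_scene', 'train', [n for n in names if n not in names[::8]])
```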
75 |
76 | ## Visualize cameras in 3D
77 | Check `camera_visualizer/visualize_cameras.py` for visualizing cameras in 3D. It creates an interactive viewer for you to inspect whether your cameras have been normalized to be compatible with this codebase. Below is a screenshot of the viewer: green cameras are used for training, blue ones for testing, and yellow ones denote a novel camera path to be synthesized; the red sphere is the unit sphere. A minimal usage sketch follows the screenshot.
78 |
79 | ![](camera_visualizer/screenshot_lowres.png)
80 |
81 |
82 |
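A minimal usage sketch, assuming you run it from inside `camera_visualizer/` and have placed your own `train/cam_dict_norm.json` and `test/cam_dict_norm.json` next to the script (this mirrors the `__main__` block of `visualize_cameras.py`):
```python
import json
import os
from visualize_cameras import visualize_cameras

base_dir = './'
train_cams = json.load(open(os.path.join(base_dir, 'train/cam_dict_norm.json')))
test_cams = json.load(open(os.path.join(base_dir, 'test/cam_dict_norm.json')))

# (color, camera_dict) pairs: green = train, blue = test; the red unit sphere is drawn automatically
visualize_cameras([([0, 1, 0], train_cams), ([0, 0, 1], test_cams)],
                  sphere_radius=1.0, camera_size=0.1)
```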
83 | ## Inspect camera parameters
84 | You can use `camera_inspector/inspect_epipolar_geometry.py` to inspect whether the camera parameters are correct and follow the OpenCV convention assumed by this codebase. The script creates a viewer for visually inspecting two-view epipolar geometry like below: for key points in the left image, it plots their corresponding epipolar lines in the right image. If the epipolar geometry does not look correct in this visualization, it's likely that there are some issues with the camera parameters. A sketch of the underlying two-view relation is included after the screenshot.
85 |
86 | ![](camera_inspector/screenshot_lowres.png)
87 |
88 |
89 |
90 |
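For reference, a sketch of the two-view relation the inspector relies on (it mirrors `two_view_geometry` in `camera_inspector/inspect_epipolar_geometry.py`; `K` and `W2C` are the 4x4 matrices from the camera json described in the Data section):
```python
import numpy as np

def fundamental_matrix(K1, W2C1, K2, W2C2):
    rel = W2C2 @ np.linalg.inv(W2C1)            # maps camera-1 coordinates to camera-2 coordinates
    R, t = rel[:3, :3], rel[:3, 3]
    tx = np.array([[0, -t[2], t[1]],
                   [t[2], 0, -t[0]],
                   [-t[1], t[0], 0]])           # skew-symmetric matrix [t]_x
    E = tx @ R                                  # essential matrix
    return np.linalg.inv(K2[:3, :3]).T @ E @ np.linalg.inv(K1[:3, :3])

# a pixel (u, v) in image 1 maps to the epipolar line l = F @ [u, v, 1] in image 2;
# any correct correspondence x2 (in homogeneous pixel coordinates) satisfies x2 . l == 0
```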
--------------------------------------------------------------------------------
/camera_inspector/inspect_epipolar_geometry.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import json
5 |
6 |
7 | def skew(x):
8 | return np.array([[0, -x[2], x[1]],
9 | [x[2], 0, -x[0]],
10 | [-x[1], x[0], 0]])
11 |
12 |
13 | def two_view_geometry(intrinsics1, extrinsics1, intrinsics2, extrinsics2):
14 | relative_pose = extrinsics2.dot(np.linalg.inv(extrinsics1))
15 | R = relative_pose[:3, :3]
16 | T = relative_pose[:3, 3]
17 | tx = skew(T)
18 | E = np.dot(tx, R)
19 | F = np.linalg.inv(intrinsics2[:3, :3]).T.dot(E).dot(np.linalg.inv(intrinsics1[:3, :3]))
20 |
21 | return E, F, relative_pose
22 |
23 |
24 | def drawpointslines(img1, pts1, img2, lines2, colors):
25 | h, w = img2.shape[:2]
26 | # img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
27 | # img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
28 | print(pts1.shape, lines2.shape, colors.shape)
29 | for p, l, c in zip(pts1, lines2, colors):
30 | c = tuple(c.tolist())
31 | img1 = cv2.circle(img1, tuple(p), 5, c, -1)
32 |
33 | x0, y0 = map(int, [0, -l[2]/l[1]])
34 | x1, y1 = map(int, [w, -(l[2]+l[0]*w)/l[1]])
35 | img2 = cv2.line(img2, (x0, y0), (x1, y1), c, 1, lineType=cv2.LINE_AA)
36 | return img1, img2
37 |
38 |
39 | def inspect(img1, K1, W2C1, img2, K2, W2C2):
40 | E, F, relative_pose = two_view_geometry(K1, W2C1, K2, W2C2)
41 |
42 | orb = cv2.ORB_create()
43 | kp1 = orb.detect(cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY), None)[:20]
44 | pts1 = np.array([[int(kp.pt[0]), int(kp.pt[1])] for kp in kp1])
45 |
46 | lines2 = cv2.computeCorrespondEpilines(pts1.reshape(-1,1,2), 1, F)
47 | lines2 = lines2.reshape(-1, 3)
48 |
49 | colors = np.random.randint(0, high=255, size=(len(pts1), 3))
50 |
51 | img1, img2 = drawpointslines(img1, pts1, img2, lines2, colors)
52 |
53 | im_to_show = np.concatenate((img1, img2), axis=1)
54 | # down sample to fit screen
55 | h, w = im_to_show.shape[:2]
56 | im_to_show = cv2.resize(im_to_show, (int(0.5*w), int(0.5*h)), interpolation=cv2.INTER_AREA)
57 | cv2.imshow('epipolar geometry', im_to_show)
58 | cv2.waitKey(0)
59 | cv2.destroyAllWindows()
60 |
61 |
62 | if __name__ == '__main__':
63 | base_dir = './'
64 |
65 | img_dir = os.path.join(base_dir, 'train/rgb')
66 | cam_dict_file = os.path.join(base_dir, 'train/cam_dict_norm.json')
67 | img_name1 = '000001.png'
68 | img_name2 = '000005.png'
69 |
70 | cam_dict = json.load(open(cam_dict_file))
71 | img1 = cv2.imread(os.path.join(img_dir, img_name1))
72 | K1 = np.array(cam_dict[img_name1]['K']).reshape((4, 4))
73 | W2C1 = np.array(cam_dict[img_name1]['W2C']).reshape((4, 4))
74 |
75 | img2 = cv2.imread(os.path.join(img_dir, img_name2))
76 | K2 = np.array(cam_dict[img_name2]['K']).reshape((4, 4))
77 | W2C2 = np.array(cam_dict[img_name2]['W2C']).reshape((4, 4))
78 |
79 | inspect(img1, K1, W2C1, img2, K2, W2C2)
80 |
81 |
--------------------------------------------------------------------------------
/camera_inspector/screenshot_lowres.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/screenshot_lowres.png
--------------------------------------------------------------------------------
/camera_inspector/train/rgb/000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/train/rgb/000001.png
--------------------------------------------------------------------------------
/camera_inspector/train/rgb/000005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/train/rgb/000005.png
--------------------------------------------------------------------------------
/camera_visualizer/mesh_norm.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_visualizer/mesh_norm.ply
--------------------------------------------------------------------------------
/camera_visualizer/screenshot_lowres.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_visualizer/screenshot_lowres.png
--------------------------------------------------------------------------------
/camera_visualizer/test/cam_dict_norm.json:
--------------------------------------------------------------------------------
1 | {
2 | "000173.png": {
3 | "K": [
4 | 581.7877197265625,
5 | 0.0,
6 | 490.25,
7 | 0.0,
8 | 0.0,
9 | 581.7877197265625,
10 | 272.75,
11 | 0.0,
12 | 0.0,
13 | 0.0,
14 | 1.0,
15 | 0.0,
16 | 0.0,
17 | 0.0,
18 | 0.0,
19 | 1.0
20 | ],
21 | "W2C": [
22 | -0.6749785542488097,
23 | -0.10284826159477237,
24 | 0.730634093284607,
25 | -0.1415396263044849,
26 | 0.1326861083507538,
27 | 0.9571743607521058,
28 | 0.2573162019252777,
29 | -0.11108331785106688,
30 | -0.7258087396621704,
31 | 0.2706279158592224,
32 | -0.6324256062507629,
33 | 0.5185927744415963,
34 | 0.0,
35 | 0.0,
36 | 0.0,
37 | 1.0
38 | ],
39 | "img_size": [
40 | 980,
41 | 546
42 | ]
43 | },
44 | "000174.png": {
45 | "K": [
46 | 581.7877197265625,
47 | 0.0,
48 | 490.25,
49 | 0.0,
50 | 0.0,
51 | 581.7877197265625,
52 | 272.75,
53 | 0.0,
54 | 0.0,
55 | 0.0,
56 | 1.0,
57 | 0.0,
58 | 0.0,
59 | 0.0,
60 | 0.0,
61 | 1.0
62 | ],
63 | "W2C": [
64 | -0.6433882117271422,
65 | -0.10309550166130069,
66 | 0.758566379547119,
67 | -0.10336116912543912,
68 | 0.142945721745491,
69 | 0.9572839736938477,
70 | 0.2513442039489746,
71 | -0.10921025065909233,
72 | -0.7520759105682373,
73 | 0.27014571428298956,
74 | -0.601168155670166,
75 | 0.5231507731285338,
76 | 0.0,
77 | 0.0,
78 | 0.0,
79 | 1.0
80 | ],
81 | "img_size": [
82 | 980,
83 | 546
84 | ]
85 | },
86 | "000175.png": {
87 | "K": [
88 | 581.7877197265625,
89 | 0.0,
90 | 490.25,
91 | 0.0,
92 | 0.0,
93 | 581.7877197265625,
94 | 272.75,
95 | 0.0,
96 | 0.0,
97 | 0.0,
98 | 1.0,
99 | 0.0,
100 | 0.0,
101 | 0.0,
102 | 0.0,
103 | 1.0
104 | ],
105 | "W2C": [
106 | -0.6209585666656491,
107 | -0.10511306673288354,
108 | 0.776763617992401,
109 | -0.06012172693562135,
110 | 0.1500456780195236,
111 | 0.9567026495933535,
112 | 0.24941192567348483,
113 | -0.11759666418072104,
114 | -0.7693482637405393,
115 | 0.27142450213432306,
116 | -0.578300952911377,
117 | 0.558755932058601,
118 | 0.0,
119 | 0.0,
120 | 0.0,
121 | 1.0
122 | ],
123 | "img_size": [
124 | 980,
125 | 546
126 | ]
127 | },
128 | "000176.png": {
129 | "K": [
130 | 581.7877197265625,
131 | 0.0,
132 | 490.25,
133 | 0.0,
134 | 0.0,
135 | 581.7877197265625,
136 | 272.75,
137 | 0.0,
138 | 0.0,
139 | 0.0,
140 | 1.0,
141 | 0.0,
142 | 0.0,
143 | 0.0,
144 | 0.0,
145 | 1.0
146 | ],
147 | "W2C": [
148 | -0.6149386763572693,
149 | -0.10690607875585548,
150 | 0.7812947630882263,
151 | -0.0004664353357674536,
152 | 0.15467639267444613,
153 | 0.9551697969436644,
154 | 0.2524398863315582,
155 | -0.12307422643556815,
156 | -0.7732564806938171,
157 | 0.27608290314674383,
158 | -0.5708349943161011,
159 | 0.5910430066979042,
160 | 0.0,
161 | 0.0,
162 | 0.0,
163 | 1.0
164 | ],
165 | "img_size": [
166 | 980,
167 | 546
168 | ]
169 | },
170 | "000177.png": {
171 | "K": [
172 | 581.7877197265625,
173 | 0.0,
174 | 490.25,
175 | 0.0,
176 | 0.0,
177 | 581.7877197265625,
178 | 272.75,
179 | 0.0,
180 | 0.0,
181 | 0.0,
182 | 1.0,
183 | 0.0,
184 | 0.0,
185 | 0.0,
186 | 0.0,
187 | 1.0
188 | ],
189 | "W2C": [
190 | -0.5901945233345031,
191 | -0.11034809798002247,
192 | 0.7996835112571716,
193 | 0.04891851798678256,
194 | 0.16408699750900269,
195 | 0.9535346627235411,
196 | 0.2526799142360687,
197 | -0.13504251807839815,
198 | -0.790408730506897,
199 | 0.28034797310829157,
200 | -0.544664204120636,
201 | 0.6147589251643815,
202 | 0.0,
203 | 0.0,
204 | 0.0,
205 | 1.0
206 | ],
207 | "img_size": [
208 | 980,
209 | 546
210 | ]
211 | },
212 | "000178.png": {
213 | "K": [
214 | 581.7877197265625,
215 | 0.0,
216 | 490.25,
217 | 0.0,
218 | 0.0,
219 | 581.7877197265625,
220 | 272.75,
221 | 0.0,
222 | 0.0,
223 | 0.0,
224 | 1.0,
225 | 0.0,
226 | 0.0,
227 | 0.0,
228 | 0.0,
229 | 1.0
230 | ],
231 | "W2C": [
232 | -0.5680983662605286,
233 | -0.11199841648340227,
234 | 0.8153039813041688,
235 | 0.11406040174495254,
236 | 0.17730242013931277,
237 | 0.9507739543914796,
238 | 0.2541510760784149,
239 | -0.13489467795265128,
240 | -0.8036343455314636,
241 | 0.28893816471099854,
242 | -0.5202755331993103,
243 | 0.6091760132826988,
244 | 0.0,
245 | 0.0,
246 | 0.0,
247 | 1.0
248 | ],
249 | "img_size": [
250 | 980,
251 | 546
252 | ]
253 | },
254 | "000179.png": {
255 | "K": [
256 | 581.7877197265625,
257 | 0.0,
258 | 490.25,
259 | 0.0,
260 | 0.0,
261 | 581.7877197265625,
262 | 272.75,
263 | 0.0,
264 | 0.0,
265 | 0.0,
266 | 1.0,
267 | 0.0,
268 | 0.0,
269 | 0.0,
270 | 0.0,
271 | 1.0
272 | ],
273 | "W2C": [
274 | -0.5251939892768861,
275 | -0.11293330788612364,
276 | 0.8434556126594543,
277 | 0.15233507599020302,
278 | 0.19423152506351476,
279 | 0.9490842819213869,
280 | 0.24801832437515262,
281 | -0.13645353989132347,
282 | -0.8285199999809266,
283 | 0.2940833866596223,
284 | -0.47651815414428705,
285 | 0.6077169856718815,
286 | 0.0,
287 | 0.0,
288 | 0.0,
289 | 1.0
290 | ],
291 | "img_size": [
292 | 980,
293 | 546
294 | ]
295 | },
296 | "000180.png": {
297 | "K": [
298 | 581.7877197265625,
299 | 0.0,
300 | 490.25,
301 | 0.0,
302 | 0.0,
303 | 581.7877197265625,
304 | 272.75,
305 | 0.0,
306 | 0.0,
307 | 0.0,
308 | 1.0,
309 | 0.0,
310 | 0.0,
311 | 0.0,
312 | 0.0,
313 | 1.0
314 | ],
315 | "W2C": [
316 | -0.4368840157985687,
317 | -0.12017646431922917,
318 | 0.8914538621902465,
319 | 0.1584227767276896,
320 | 0.21023185551166537,
321 | 0.9499467611312866,
322 | 0.23109236359596255,
323 | -0.14294346814956563,
324 | -0.8746055960655214,
325 | 0.28837254643440247,
326 | -0.3897516429424286,
327 | 0.6279825376657556,
328 | 0.0,
329 | 0.0,
330 | 0.0,
331 | 1.0
332 | ],
333 | "img_size": [
334 | 980,
335 | 546
336 | ]
337 | },
338 | "000181.png": {
339 | "K": [
340 | 581.7877197265625,
341 | 0.0,
342 | 490.25,
343 | 0.0,
344 | 0.0,
345 | 581.7877197265625,
346 | 272.75,
347 | 0.0,
348 | 0.0,
349 | 0.0,
350 | 1.0,
351 | 0.0,
352 | 0.0,
353 | 0.0,
354 | 0.0,
355 | 1.0
356 | ],
357 | "W2C": [
358 | -0.3019417822360992,
359 | -0.12455697357654576,
360 | 0.9451543092727661,
361 | 0.12391406697080755,
362 | 0.23592309653758997,
363 | 0.9508262872695923,
364 | 0.2006731331348419,
365 | -0.15344331535910236,
366 | -0.9236727952957151,
367 | 0.2835753560066222,
368 | -0.2577083110809326,
369 | 0.649313956576192,
370 | 0.0,
371 | 0.0,
372 | 0.0,
373 | 1.0
374 | ],
375 | "img_size": [
376 | 980,
377 | 546
378 | ]
379 | },
380 | "000182.png": {
381 | "K": [
382 | 581.7877197265625,
383 | 0.0,
384 | 490.25,
385 | 0.0,
386 | 0.0,
387 | 581.7877197265625,
388 | 272.75,
389 | 0.0,
390 | 0.0,
391 | 0.0,
392 | 1.0,
393 | 0.0,
394 | 0.0,
395 | 0.0,
396 | 0.0,
397 | 1.0
398 | ],
399 | "W2C": [
400 | -0.05365315079689027,
401 | -0.13223105669021606,
402 | 0.9897657632827759,
403 | 0.008633672276731267,
404 | 0.256687343120575,
405 | 0.9560590982437135,
406 | 0.14164239168167117,
407 | -0.1679039701011501,
408 | -0.9650041460990905,
409 | 0.26165989041328425,
410 | -0.01735354587435723,
411 | 0.6778822326079519,
412 | 0.0,
413 | 0.0,
414 | 0.0,
415 | 1.0
416 | ],
417 | "img_size": [
418 | 980,
419 | 546
420 | ]
421 | },
422 | "000183.png": {
423 | "K": [
424 | 581.7877197265625,
425 | 0.0,
426 | 490.25,
427 | 0.0,
428 | 0.0,
429 | 581.7877197265625,
430 | 272.75,
431 | 0.0,
432 | 0.0,
433 | 0.0,
434 | 1.0,
435 | 0.0,
436 | 0.0,
437 | 0.0,
438 | 0.0,
439 | 1.0
440 | ],
441 | "W2C": [
442 | 0.08435853570699692,
443 | -0.11229652166366576,
444 | 0.9900874495506288,
445 | -0.02667809703294818,
446 | 0.2744568288326264,
447 | 0.9578129649162292,
448 | 0.08525134623050688,
449 | -0.18123123987972412,
450 | -0.9578920006752015,
451 | 0.2645445764064789,
452 | 0.1116202473640442,
453 | 0.68498322092808,
454 | 0.0,
455 | 0.0,
456 | 0.0,
457 | 1.0
458 | ],
459 | "img_size": [
460 | 980,
461 | 546
462 | ]
463 | },
464 | "000184.png": {
465 | "K": [
466 | 581.7877197265625,
467 | 0.0,
468 | 490.25,
469 | 0.0,
470 | 0.0,
471 | 581.7877197265625,
472 | 272.75,
473 | 0.0,
474 | 0.0,
475 | 0.0,
476 | 1.0,
477 | 0.0,
478 | 0.0,
479 | 0.0,
480 | 0.0,
481 | 1.0
482 | ],
483 | "W2C": [
484 | 0.173485592007637,
485 | -0.09174104034900667,
486 | 0.9805541038513182,
487 | -0.029180471150004907,
488 | 0.29095843434333796,
489 | 0.955982208251953,
490 | 0.03796394541859627,
491 | -0.2051112120771908,
492 | -0.9408751130104065,
493 | 0.278714269399643,
494 | 0.1925419718027115,
495 | 0.7113401071184143,
496 | 0.0,
497 | 0.0,
498 | 0.0,
499 | 1.0
500 | ],
501 | "img_size": [
502 | 980,
503 | 546
504 | ]
505 | },
506 | "000185.png": {
507 | "K": [
508 | 581.7877197265625,
509 | 0.0,
510 | 490.25,
511 | 0.0,
512 | 0.0,
513 | 581.7877197265625,
514 | 272.75,
515 | 0.0,
516 | 0.0,
517 | 0.0,
518 | 1.0,
519 | 0.0,
520 | 0.0,
521 | 0.0,
522 | 0.0,
523 | 1.0
524 | ],
525 | "W2C": [
526 | 0.23342822492122645,
527 | -0.07284945249557495,
528 | 0.9696412682533264,
529 | -0.005887109452774881,
530 | 0.29702088236808777,
531 | 0.954870939254761,
532 | 0.0002359295467613364,
533 | -0.21843709467707623,
534 | -0.9258995056152343,
535 | 0.28794863820075983,
536 | 0.24453164637088773,
537 | 0.7118756075295797,
538 | 0.0,
539 | 0.0,
540 | 0.0,
541 | 1.0
542 | ],
543 | "img_size": [
544 | 980,
545 | 546
546 | ]
547 | },
548 | "000186.png": {
549 | "K": [
550 | 581.7877197265625,
551 | 0.0,
552 | 490.25,
553 | 0.0,
554 | 0.0,
555 | 581.7877197265625,
556 | 272.75,
557 | 0.0,
558 | 0.0,
559 | 0.0,
560 | 1.0,
561 | 0.0,
562 | 0.0,
563 | 0.0,
564 | 0.0,
565 | 1.0
566 | ],
567 | "W2C": [
568 | 0.3044866621494293,
569 | -0.05772994086146357,
570 | 0.9507655501365662,
571 | 0.005409785790216546,
572 | 0.29629379510879517,
573 | 0.9543822407722473,
574 | -0.0369397886097431,
575 | -0.2284336673437035,
576 | -0.905261218547821,
577 | 0.2929536104202271,
578 | 0.30770167708396906,
579 | 0.7186774152057739,
580 | 0.0,
581 | 0.0,
582 | 0.0,
583 | 1.0
584 | ],
585 | "img_size": [
586 | 980,
587 | 546
588 | ]
589 | },
590 | "000187.png": {
591 | "K": [
592 | 581.7877197265625,
593 | 0.0,
594 | 490.25,
595 | 0.0,
596 | 0.0,
597 | 581.7877197265625,
598 | 272.75,
599 | 0.0,
600 | 0.0,
601 | 0.0,
602 | 1.0,
603 | 0.0,
604 | 0.0,
605 | 0.0,
606 | 0.0,
607 | 1.0
608 | ],
609 | "W2C": [
610 | 0.37352681159973145,
611 | -0.042923863977193846,
612 | 0.9266257286071778,
613 | 0.013849648484600553,
614 | 0.2894437611103058,
615 | 0.9544482231140137,
616 | -0.07246334105730054,
617 | -0.2333756901439718,
618 | -0.8813058733940125,
619 | 0.29527303576469427,
620 | 0.3689360618591308,
621 | 0.7126951784083984,
622 | 0.0,
623 | 0.0,
624 | 0.0,
625 | 1.0
626 | ],
627 | "img_size": [
628 | 980,
629 | 546
630 | ]
631 | },
632 | "000188.png": {
633 | "K": [
634 | 581.7877197265625,
635 | 0.0,
636 | 490.25,
637 | 0.0,
638 | 0.0,
639 | 581.7877197265625,
640 | 272.75,
641 | 0.0,
642 | 0.0,
643 | 0.0,
644 | 1.0,
645 | 0.0,
646 | 0.0,
647 | 0.0,
648 | 0.0,
649 | 1.0
650 | ],
651 | "W2C": [
652 | 0.49691322445869446,
653 | -0.03459263965487483,
654 | 0.8671104907989503,
655 | -0.02804523469799382,
656 | 0.27165842056274414,
657 | 0.9551849961280822,
658 | -0.11757244169712067,
659 | -0.23797951926387803,
660 | -0.8241838216781616,
661 | 0.2939811646938324,
662 | 0.48404145240783697,
663 | 0.7082496879378362,
664 | 0.0,
665 | 0.0,
666 | 0.0,
667 | 1.0
668 | ],
669 | "img_size": [
670 | 980,
671 | 546
672 | ]
673 | },
674 | "000189.png": {
675 | "K": [
676 | 581.7877197265625,
677 | 0.0,
678 | 490.25,
679 | 0.0,
680 | 0.0,
681 | 581.7877197265625,
682 | 272.75,
683 | 0.0,
684 | 0.0,
685 | 0.0,
686 | 1.0,
687 | 0.0,
688 | 0.0,
689 | 0.0,
690 | 0.0,
691 | 1.0
692 | ],
693 | "W2C": [
694 | 0.656022787094116,
695 | -0.028436420485377232,
696 | 0.7542051672935486,
697 | -0.11063643763864509,
698 | 0.2370839267969131,
699 | 0.9564713835716248,
700 | -0.17015771567821505,
701 | -0.23631046162031968,
702 | -0.7165369987487792,
703 | 0.29043728113174433,
704 | 0.6342088580131531,
705 | 0.6869845554740425,
706 | 0.0,
707 | 0.0,
708 | 0.0,
709 | 1.0
710 | ],
711 | "img_size": [
712 | 980,
713 | 546
714 | ]
715 | },
716 | "000190.png": {
717 | "K": [
718 | 581.7877197265625,
719 | 0.0,
720 | 490.25,
721 | 0.0,
722 | 0.0,
723 | 581.7877197265625,
724 | 272.75,
725 | 0.0,
726 | 0.0,
727 | 0.0,
728 | 1.0,
729 | 0.0,
730 | 0.0,
731 | 0.0,
732 | 0.0,
733 | 1.0
734 | ],
735 | "W2C": [
736 | 0.7924461960792542,
737 | -0.01046304591000084,
738 | 0.6098520755767822,
739 | -0.19199604683373434,
740 | 0.18565021455287933,
741 | 0.956550121307373,
742 | -0.22482399642467502,
743 | -0.23064418498320746,
744 | -0.5810017585754395,
745 | 0.2913800776004791,
746 | 0.759956955909729,
747 | 0.6486130335907647,
748 | 0.0,
749 | 0.0,
750 | 0.0,
751 | 1.0
752 | ],
753 | "img_size": [
754 | 980,
755 | 546
756 | ]
757 | },
758 | "000191.png": {
759 | "K": [
760 | 581.7877197265625,
761 | 0.0,
762 | 490.25,
763 | 0.0,
764 | 0.0,
765 | 581.7877197265625,
766 | 272.75,
767 | 0.0,
768 | 0.0,
769 | 0.0,
770 | 1.0,
771 | 0.0,
772 | 0.0,
773 | 0.0,
774 | 0.0,
775 | 1.0
776 | ],
777 | "W2C": [
778 | 0.8854711055755615,
779 | 0.010563611052930369,
780 | 0.46457436680793757,
781 | -0.24282484537054266,
782 | 0.1252961456775665,
783 | 0.9572874307632447,
784 | -0.26057946681976313,
785 | -0.21866898323124923,
786 | -0.44748386740684487,
787 | 0.288944959640503,
788 | 0.8463267683982846,
789 | 0.6151146596650793,
790 | 0.0,
791 | 0.0,
792 | 0.0,
793 | 1.0
794 | ],
795 | "img_size": [
796 | 980,
797 | 546
798 | ]
799 | },
800 | "000192.png": {
801 | "K": [
802 | 581.7877197265625,
803 | 0.0,
804 | 490.25,
805 | 0.0,
806 | 0.0,
807 | 581.7877197265625,
808 | 272.75,
809 | 0.0,
810 | 0.0,
811 | 0.0,
812 | 1.0,
813 | 0.0,
814 | 0.0,
815 | 0.0,
816 | 0.0,
817 | 1.0
818 | ],
819 | "W2C": [
820 | 0.9388862252235413,
821 | 0.03694019094109534,
822 | 0.3422397673130036,
823 | -0.27164514724995453,
824 | 0.06416944414377213,
825 | 0.9580151438713074,
826 | -0.27944466471672064,
827 | -0.20499586789544427,
828 | -0.33819359540939326,
829 | 0.28432807326316833,
830 | 0.8970967531204224,
831 | 0.5788967097431258,
832 | 0.0,
833 | 0.0,
834 | 0.0,
835 | 1.0
836 | ],
837 | "img_size": [
838 | 980,
839 | 546
840 | ]
841 | },
842 | "000193.png": {
843 | "K": [
844 | 581.7877197265625,
845 | 0.0,
846 | 490.25,
847 | 0.0,
848 | 0.0,
849 | 581.7877197265625,
850 | 272.75,
851 | 0.0,
852 | 0.0,
853 | 0.0,
854 | 1.0,
855 | 0.0,
856 | 0.0,
857 | 0.0,
858 | 0.0,
859 | 1.0
860 | ],
861 | "W2C": [
862 | 0.9651142954826355,
863 | 0.05770149827003481,
864 | 0.2553918063640595,
865 | -0.272934838794894,
866 | 0.015767628327012072,
867 | 0.9608356356620789,
868 | -0.27666985988616943,
869 | -0.19339081801285132,
870 | -0.2613538205623627,
871 | 0.2710449695587158,
872 | 0.9264063835144044,
873 | 0.5628077025323545,
874 | 0.0,
875 | 0.0,
876 | 0.0,
877 | 1.0
878 | ],
879 | "img_size": [
880 | 980,
881 | 546
882 | ]
883 | },
884 | "000194.png": {
885 | "K": [
886 | 581.7877197265625,
887 | 0.0,
888 | 490.25,
889 | 0.0,
890 | 0.0,
891 | 581.7877197265625,
892 | 272.75,
893 | 0.0,
894 | 0.0,
895 | 0.0,
896 | 1.0,
897 | 0.0,
898 | 0.0,
899 | 0.0,
900 | 0.0,
901 | 1.0
902 | ],
903 | "W2C": [
904 | 0.9572978019714354,
905 | 0.06780177354812619,
906 | 0.28104069828987116,
907 | -0.20122927411147834,
908 | 0.007655009161680944,
909 | 0.9658248424530029,
910 | -0.25908261537551885,
911 | -0.19321847581034227,
912 | -0.2890023589134216,
913 | 0.2501705884933471,
914 | 0.924062967300415,
915 | 0.563575894892584,
916 | 0.0,
917 | 0.0,
918 | 0.0,
919 | 1.0
920 | ],
921 | "img_size": [
922 | 980,
923 | 546
924 | ]
925 | },
926 | "000195.png": {
927 | "K": [
928 | 581.7877197265625,
929 | 0.0,
930 | 490.25,
931 | 0.0,
932 | 0.0,
933 | 581.7877197265625,
934 | 272.75,
935 | 0.0,
936 | 0.0,
937 | 0.0,
938 | 1.0,
939 | 0.0,
940 | 0.0,
941 | 0.0,
942 | 0.0,
943 | 1.0
944 | ],
945 | "W2C": [
946 | 0.9394192695617677,
947 | 0.040031708776950836,
948 | 0.3404245376586914,
949 | -0.10805127469870195,
950 | 0.03222170472145081,
951 | 0.9784454107284546,
952 | -0.20397628843784332,
953 | -0.17995865052818397,
954 | -0.34125232696533203,
955 | 0.20258831977844238,
956 | 0.917880594730377,
957 | 0.5956252623341279,
958 | 0.0,
959 | 0.0,
960 | 0.0,
961 | 1.0
962 | ],
963 | "img_size": [
964 | 980,
965 | 546
966 | ]
967 | },
968 | "000196.png": {
969 | "K": [
970 | 581.7877197265625,
971 | 0.0,
972 | 490.25,
973 | 0.0,
974 | 0.0,
975 | 581.7877197265625,
976 | 272.75,
977 | 0.0,
978 | 0.0,
979 | 0.0,
980 | 1.0,
981 | 0.0,
982 | 0.0,
983 | 0.0,
984 | 0.0,
985 | 1.0
986 | ],
987 | "W2C": [
988 | 0.9188615679740906,
989 | 0.01963340304791927,
990 | 0.3940912783145905,
991 | -0.02124089688202119,
992 | 0.0335158184170723,
993 | 0.9912682771682739,
994 | -0.1275297701358795,
995 | -0.13773930218650604,
996 | -0.39315405488014227,
997 | 0.1303904950618744,
998 | 0.9101803302764893,
999 | 0.590018327758047,
1000 | 0.0,
1001 | 0.0,
1002 | 0.0,
1003 | 1.0
1004 | ],
1005 | "img_size": [
1006 | 980,
1007 | 546
1008 | ]
1009 | },
1010 | "000197.png": {
1011 | "K": [
1012 | 581.7877197265625,
1013 | 0.0,
1014 | 490.25,
1015 | 0.0,
1016 | 0.0,
1017 | 581.7877197265625,
1018 | 272.75,
1019 | 0.0,
1020 | 0.0,
1021 | 0.0,
1022 | 1.0,
1023 | 0.0,
1024 | 0.0,
1025 | 0.0,
1026 | 0.0,
1027 | 1.0
1028 | ],
1029 | "W2C": [
1030 | 0.9298350214958191,
1031 | -0.002477371366694553,
1032 | 0.3679683208465576,
1033 | 0.023792225073484205,
1034 | 0.04552911967039108,
1035 | 0.9930682182312012,
1036 | -0.10836359858512877,
1037 | -0.13308144421741533,
1038 | -0.3651491701602936,
1039 | 0.11751354485750198,
1040 | 0.9235023856163025,
1041 | 0.5847770507172574,
1042 | 0.0,
1043 | 0.0,
1044 | 0.0,
1045 | 1.0
1046 | ],
1047 | "img_size": [
1048 | 980,
1049 | 546
1050 | ]
1051 | }
1052 | }
--------------------------------------------------------------------------------
/camera_visualizer/visualize_cameras.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import json
3 | import numpy as np
4 |
5 |
6 | def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]):
7 | W, H = img_size
8 | hfov = np.rad2deg(np.arctan(W / 2. / K[0, 0]) * 2.)
9 | vfov = np.rad2deg(np.arctan(H / 2. / K[1, 1]) * 2.)
10 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
11 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
12 |
13 | # build view frustum for camera (I, 0)
14 | frustum_points = np.array([[0., 0., 0.], # frustum origin
15 | [-half_w, -half_h, frustum_length], # top-left image corner
16 | [half_w, -half_h, frustum_length], # top-right image corner
17 | [half_w, half_h, frustum_length], # bottom-right image corner
18 | [-half_w, half_h, frustum_length]]) # bottom-left image corner
19 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]])
20 | frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1))
21 |
22 | # frustum_colors = np.vstack((np.tile(np.array([[1., 0., 0.]]), (4, 1)),
23 | # np.tile(np.array([[0., 1., 0.]]), (4, 1))))
24 |
25 | # transform view frustum from (I, 0) to (R, t)
26 | C2W = np.linalg.inv(W2C)
27 | frustum_points = np.dot(np.hstack((frustum_points, np.ones_like(frustum_points[:, 0:1]))), C2W.T)
28 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4]
29 |
30 | return frustum_points, frustum_lines, frustum_colors
31 |
32 |
33 | def frustums2lineset(frustums):
34 | N = len(frustums)
35 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum
36 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum
37 | merged_colors = np.zeros((N*8, 3)) # each line gets a color
38 |
39 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums):
40 | merged_points[i*5:(i+1)*5, :] = frustum_points
41 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5
42 | merged_colors[i*8:(i+1)*8, :] = frustum_colors
43 |
44 | lineset = o3d.geometry.LineSet()
45 | lineset.points = o3d.utility.Vector3dVector(merged_points)
46 | lineset.lines = o3d.utility.Vector2iVector(merged_lines)
47 | lineset.colors = o3d.utility.Vector3dVector(merged_colors)
48 |
49 | return lineset
50 |
51 | def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh'):
52 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10)
53 | sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere)
54 | sphere.paint_uniform_color((1, 0, 0))
55 |
56 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.])
57 | things_to_draw = [sphere, coord_frame]
58 |
59 | idx = 0
60 | for color, camera_dict in colored_camera_dicts:
61 | idx += 1
62 |
63 | cnt = 0
64 | frustums = []
65 | for img_name in sorted(camera_dict.keys()):
66 | K = np.array(camera_dict[img_name]['K']).reshape((4, 4))
67 | W2C = np.array(camera_dict[img_name]['W2C']).reshape((4, 4))
68 | C2W = np.linalg.inv(W2C)
69 | img_size = camera_dict[img_name]['img_size']
70 | frustums.append(get_camera_frustum(img_size, K, W2C, frustum_length=camera_size, color=color))
71 | cnt += 1
72 | cameras = frustums2lineset(frustums)
73 | things_to_draw.append(cameras)
74 |
75 | if geometry_file is not None:
76 | if geometry_type == 'mesh':
77 | geometry = o3d.io.read_triangle_mesh(geometry_file)
78 | geometry.compute_vertex_normals()
79 | elif geometry_type == 'pointcloud':
80 | geometry = o3d.io.read_point_cloud(geometry_file)
81 | else:
82 | raise Exception('Unknown geometry_type: ', geometry_type)
83 |
84 | things_to_draw.append(geometry)
85 |
86 | o3d.visualization.draw_geometries(things_to_draw)
87 |
88 |
89 | if __name__ == '__main__':
90 | import os
91 |
92 | base_dir = './'
93 |
94 | sphere_radius = 1.
95 | train_cam_dict = json.load(open(os.path.join(base_dir, 'train/cam_dict_norm.json')))
96 | test_cam_dict = json.load(open(os.path.join(base_dir, 'test/cam_dict_norm.json')))
97 | path_cam_dict = json.load(open(os.path.join(base_dir, 'camera_path/cam_dict_norm.json')))
98 | camera_size = 0.1
99 | colored_camera_dicts = [([0, 1, 0], train_cam_dict),
100 | ([0, 0, 1], test_cam_dict),
101 | ([1, 1, 0], path_cam_dict)
102 | ]
103 |
104 | geometry_file = os.path.join(base_dir, 'mesh_norm.ply')
105 | geometry_type = 'mesh'
106 |
107 | visualize_cameras(colored_camera_dicts, sphere_radius,
108 | camera_size=camera_size, geometry_file=geometry_file, geometry_type=geometry_type)
109 |
--------------------------------------------------------------------------------
/colmap_runner/database.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31 |
32 | # This script is based on an original implementation by True Price.
33 |
34 | import sys
35 | import sqlite3
36 | import numpy as np
37 |
38 |
39 | IS_PYTHON3 = sys.version_info[0] >= 3
40 |
41 | MAX_IMAGE_ID = 2**31 - 1
42 |
43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
45 | model INTEGER NOT NULL,
46 | width INTEGER NOT NULL,
47 | height INTEGER NOT NULL,
48 | params BLOB,
49 | prior_focal_length INTEGER NOT NULL)"""
50 |
51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
52 | image_id INTEGER PRIMARY KEY NOT NULL,
53 | rows INTEGER NOT NULL,
54 | cols INTEGER NOT NULL,
55 | data BLOB,
56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
57 |
58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
60 | name TEXT NOT NULL UNIQUE,
61 | camera_id INTEGER NOT NULL,
62 | prior_qw REAL,
63 | prior_qx REAL,
64 | prior_qy REAL,
65 | prior_qz REAL,
66 | prior_tx REAL,
67 | prior_ty REAL,
68 | prior_tz REAL,
69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
71 | """.format(MAX_IMAGE_ID)
72 |
73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
74 | CREATE TABLE IF NOT EXISTS two_view_geometries (
75 | pair_id INTEGER PRIMARY KEY NOT NULL,
76 | rows INTEGER NOT NULL,
77 | cols INTEGER NOT NULL,
78 | data BLOB,
79 | config INTEGER NOT NULL,
80 | F BLOB,
81 | E BLOB,
82 | H BLOB)
83 | """
84 |
85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
86 | image_id INTEGER PRIMARY KEY NOT NULL,
87 | rows INTEGER NOT NULL,
88 | cols INTEGER NOT NULL,
89 | data BLOB,
90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
91 | """
92 |
93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
94 | pair_id INTEGER PRIMARY KEY NOT NULL,
95 | rows INTEGER NOT NULL,
96 | cols INTEGER NOT NULL,
97 | data BLOB)"""
98 |
99 | CREATE_NAME_INDEX = \
100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
101 |
102 | CREATE_ALL = "; ".join([
103 | CREATE_CAMERAS_TABLE,
104 | CREATE_IMAGES_TABLE,
105 | CREATE_KEYPOINTS_TABLE,
106 | CREATE_DESCRIPTORS_TABLE,
107 | CREATE_MATCHES_TABLE,
108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE,
109 | CREATE_NAME_INDEX
110 | ])
111 |
112 |
113 | def image_ids_to_pair_id(image_id1, image_id2):
114 | if image_id1 > image_id2:
115 | image_id1, image_id2 = image_id2, image_id1
116 | return image_id1 * MAX_IMAGE_ID + image_id2
117 |
118 |
119 | def pair_id_to_image_ids(pair_id):
120 | image_id2 = pair_id % MAX_IMAGE_ID
121 | image_id1 = (pair_id - image_id2) // MAX_IMAGE_ID  # integer division keeps image ids integral in Python 3
122 | return image_id1, image_id2
123 |
124 |
125 | def array_to_blob(array):
126 | if IS_PYTHON3:
127 | return array.tobytes()  # tostring() is deprecated/removed in recent numpy; tobytes() is equivalent
128 | else:
129 | return np.getbuffer(array)
130 |
131 |
132 | def blob_to_array(blob, dtype, shape=(-1,)):
133 | if IS_PYTHON3:
134 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
135 | else:
136 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
137 |
138 |
139 | class COLMAPDatabase(sqlite3.Connection):
140 |
141 | @staticmethod
142 | def connect(database_path):
143 | return sqlite3.connect(database_path, factory=COLMAPDatabase)
144 |
145 |
146 | def __init__(self, *args, **kwargs):
147 | super(COLMAPDatabase, self).__init__(*args, **kwargs)
148 |
149 | self.create_tables = lambda: self.executescript(CREATE_ALL)
150 | self.create_cameras_table = \
151 | lambda: self.executescript(CREATE_CAMERAS_TABLE)
152 | self.create_descriptors_table = \
153 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
154 | self.create_images_table = \
155 | lambda: self.executescript(CREATE_IMAGES_TABLE)
156 | self.create_two_view_geometries_table = \
157 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
158 | self.create_keypoints_table = \
159 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
160 | self.create_matches_table = \
161 | lambda: self.executescript(CREATE_MATCHES_TABLE)
162 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
163 |
164 | def add_camera(self, model, width, height, params,
165 | prior_focal_length=False, camera_id=None):
166 | params = np.asarray(params, np.float64)
167 | cursor = self.execute(
168 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
169 | (camera_id, model, width, height, array_to_blob(params),
170 | prior_focal_length))
171 | return cursor.lastrowid
172 |
173 | def add_image(self, name, camera_id,
174 | prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
175 | cursor = self.execute(
176 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
177 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
178 | prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
179 | return cursor.lastrowid
180 |
181 | def add_keypoints(self, image_id, keypoints):
182 | assert(len(keypoints.shape) == 2)
183 | assert(keypoints.shape[1] in [2, 4, 6])
184 |
185 | keypoints = np.asarray(keypoints, np.float32)
186 | self.execute(
187 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
188 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
189 |
190 | def add_descriptors(self, image_id, descriptors):
191 | descriptors = np.ascontiguousarray(descriptors, np.uint8)
192 | self.execute(
193 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
194 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
195 |
196 | def add_matches(self, image_id1, image_id2, matches):
197 | assert(len(matches.shape) == 2)
198 | assert(matches.shape[1] == 2)
199 |
200 | if image_id1 > image_id2:
201 | matches = matches[:,::-1]
202 |
203 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
204 | matches = np.asarray(matches, np.uint32)
205 | self.execute(
206 | "INSERT INTO matches VALUES (?, ?, ?, ?)",
207 | (pair_id,) + matches.shape + (array_to_blob(matches),))
208 |
209 | def add_two_view_geometry(self, image_id1, image_id2, matches,
210 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
211 | assert(len(matches.shape) == 2)
212 | assert(matches.shape[1] == 2)
213 |
214 | if image_id1 > image_id2:
215 | matches = matches[:,::-1]
216 |
217 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
218 | matches = np.asarray(matches, np.uint32)
219 | F = np.asarray(F, dtype=np.float64)
220 | E = np.asarray(E, dtype=np.float64)
221 | H = np.asarray(H, dtype=np.float64)
222 | self.execute(
223 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
224 | (pair_id,) + matches.shape + (array_to_blob(matches), config,
225 | array_to_blob(F), array_to_blob(E), array_to_blob(H)))
226 |
227 |
228 | def example_usage():
229 | import os
230 | import argparse
231 |
232 | parser = argparse.ArgumentParser()
233 | parser.add_argument("--database_path", default="database.db")
234 | args = parser.parse_args()
235 |
236 | if os.path.exists(args.database_path):
237 | print("ERROR: database path already exists -- will not modify it.")
238 | return
239 |
240 | # Open the database.
241 |
242 | db = COLMAPDatabase.connect(args.database_path)
243 |
244 | # For convenience, try creating all the tables upfront.
245 |
246 | db.create_tables()
247 |
248 | # Create dummy cameras.
249 |
250 | model1, width1, height1, params1 = \
251 | 0, 1024, 768, np.array((1024., 512., 384.))
252 | model2, width2, height2, params2 = \
253 | 2, 1024, 768, np.array((1024., 512., 384., 0.1))
254 |
255 | camera_id1 = db.add_camera(model1, width1, height1, params1)
256 | camera_id2 = db.add_camera(model2, width2, height2, params2)
257 |
258 | # Create dummy images.
259 |
260 | image_id1 = db.add_image("image1.png", camera_id1)
261 | image_id2 = db.add_image("image2.png", camera_id1)
262 | image_id3 = db.add_image("image3.png", camera_id2)
263 | image_id4 = db.add_image("image4.png", camera_id2)
264 |
265 | # Create dummy keypoints.
266 | #
267 | # Note that COLMAP supports:
268 | # - 2D keypoints: (x, y)
269 | # - 4D keypoints: (x, y, theta, scale)
270 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
271 |
272 | num_keypoints = 1000
273 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
274 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
275 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
276 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
277 |
278 | db.add_keypoints(image_id1, keypoints1)
279 | db.add_keypoints(image_id2, keypoints2)
280 | db.add_keypoints(image_id3, keypoints3)
281 | db.add_keypoints(image_id4, keypoints4)
282 |
283 | # Create dummy matches.
284 |
285 | M = 50
286 | matches12 = np.random.randint(num_keypoints, size=(M, 2))
287 | matches23 = np.random.randint(num_keypoints, size=(M, 2))
288 | matches34 = np.random.randint(num_keypoints, size=(M, 2))
289 |
290 | db.add_matches(image_id1, image_id2, matches12)
291 | db.add_matches(image_id2, image_id3, matches23)
292 | db.add_matches(image_id3, image_id4, matches34)
293 |
294 | # Commit the data to the file.
295 |
296 | db.commit()
297 |
298 | # Read and check cameras.
299 |
300 | rows = db.execute("SELECT * FROM cameras")
301 |
302 | camera_id, model, width, height, params, prior = next(rows)
303 | params = blob_to_array(params, np.float64)
304 | assert camera_id == camera_id1
305 | assert model == model1 and width == width1 and height == height1
306 | assert np.allclose(params, params1)
307 |
308 | camera_id, model, width, height, params, prior = next(rows)
309 | params = blob_to_array(params, np.float64)
310 | assert camera_id == camera_id2
311 | assert model == model2 and width == width2 and height == height2
312 | assert np.allclose(params, params2)
313 |
314 | # Read and check keypoints.
315 |
316 | keypoints = dict(
317 | (image_id, blob_to_array(data, np.float32, (-1, 2)))
318 | for image_id, data in db.execute(
319 | "SELECT image_id, data_500 FROM keypoints"))
320 |
321 | assert np.allclose(keypoints[image_id1], keypoints1)
322 | assert np.allclose(keypoints[image_id2], keypoints2)
323 | assert np.allclose(keypoints[image_id3], keypoints3)
324 | assert np.allclose(keypoints[image_id4], keypoints4)
325 |
326 | # Read and check matches.
327 |
328 | pair_ids = [image_ids_to_pair_id(*pair) for pair in
329 | ((image_id1, image_id2),
330 | (image_id2, image_id3),
331 | (image_id3, image_id4))]
332 |
333 | matches = dict(
334 | (pair_id_to_image_ids(pair_id),
335 | blob_to_array(data, np.uint32, (-1, 2)))
336 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
337 | )
338 |
339 | assert np.all(matches[(image_id1, image_id2)] == matches12)
340 | assert np.all(matches[(image_id2, image_id3)] == matches23)
341 | assert np.all(matches[(image_id3, image_id4)] == matches34)
342 |
343 | # Clean up.
344 |
345 | db.close()
346 |
347 | if os.path.exists(args.database_path):
348 | os.remove(args.database_path)
349 |
350 |
351 | if __name__ == "__main__":
352 | example_usage()
353 |
--------------------------------------------------------------------------------
/colmap_runner/extract_sfm.py:
--------------------------------------------------------------------------------
1 | from read_write_model import read_model
2 | import numpy as np
3 | import json
4 | import os
5 | from pyquaternion import Quaternion
6 | import trimesh
7 |
8 |
9 | def parse_tracks(colmap_images, colmap_points3D):
10 | all_tracks = [] # list of dicts; each dict represents a track
11 | all_points = [] # list of all 3D points
12 | view_keypoints = {} # dict of lists; each list represents the triangulated key points of a view
13 |
14 |
15 | for point3D_id in colmap_points3D:
16 | point3D = colmap_points3D[point3D_id]
17 | image_ids = point3D.image_ids
18 | point2D_idxs = point3D.point2D_idxs
19 |
20 | cur_track = {}
21 | cur_track['xyz'] = (point3D.xyz[0], point3D.xyz[1], point3D.xyz[2])
22 | cur_track['err'] = point3D.error.item()
23 |
24 | cur_track_len = len(image_ids)
25 | assert (cur_track_len == len(point2D_idxs))
26 | all_points.append(list(cur_track['xyz'] + (cur_track['err'], cur_track_len) + tuple(point3D.rgb)))
27 |
28 | pixels = []
29 | for i in range(cur_track_len):
30 | image = colmap_images[image_ids[i]]
31 | img_name = image.name
32 | point2D_idx = point2D_idxs[i]
33 | point2D = image.xys[point2D_idx]
34 | assert (image.point3D_ids[point2D_idx] == point3D_id)
35 | pixels.append((img_name, point2D[0], point2D[1]))
36 |
37 | if img_name not in view_keypoints:
38 | view_keypoints[img_name] = [(point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, ), ]
39 | else:
40 | view_keypoints[img_name].append((point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, ))
41 |
42 | cur_track['pixels'] = sorted(pixels, key=lambda x: x[0]) # sort pixels by the img_name
43 | all_tracks.append(cur_track)
44 |
45 | return all_tracks, all_points, view_keypoints
46 |
47 |
48 | def parse_camera_dict(colmap_cameras, colmap_images):
49 | camera_dict = {}
50 | for image_id in colmap_images:
51 | image = colmap_images[image_id]
52 |
53 | img_name = image.name
54 | cam = colmap_cameras[image.camera_id]
55 |
56 | # print(cam)
57 | assert(cam.model == 'PINHOLE')
58 |
59 | img_size = [cam.width, cam.height]
60 | params = list(cam.params)
61 | qvec = list(image.qvec)
62 | tvec = list(image.tvec)
63 |
64 | # w, h, fx, fy, cx, cy, qvec, tvec
65 | # camera_dict[img_name] = img_size + params + qvec + tvec
66 | camera_dict[img_name] = {}
67 | camera_dict[img_name]['img_size'] = img_size
68 |
69 | fx, fy, cx, cy = params
70 | K = np.eye(4)
71 | K[0, 0] = fx
72 | K[1, 1] = fy
73 | K[0, 2] = cx
74 | K[1, 2] = cy
75 | camera_dict[img_name]['K'] = list(K.flatten())
76 |
77 | rot = Quaternion(qvec[0], qvec[1], qvec[2], qvec[3]).rotation_matrix
78 | W2C = np.eye(4)
79 | W2C[:3, :3] = rot
80 | W2C[:3, 3] = np.array(tvec)
81 | camera_dict[img_name]['W2C'] = list(W2C.flatten())
82 |
83 | return camera_dict
84 |
85 |
86 | def extract_all_to_dir(sparse_dir, out_dir, ext='.bin'):
87 | if not os.path.exists(out_dir):
88 | os.mkdir(out_dir)
89 |
90 | camera_dict_file = os.path.join(out_dir, 'kai_cameras.json')
91 | xyz_file = os.path.join(out_dir, 'kai_points.txt')
92 | track_file = os.path.join(out_dir, 'kai_tracks.json')
93 | keypoints_file = os.path.join(out_dir, 'kai_keypoints.json')
94 |
95 | colmap_cameras, colmap_images, colmap_points3D = read_model(sparse_dir, ext)
96 |
97 | camera_dict = parse_camera_dict(colmap_cameras, colmap_images)
98 | with open(camera_dict_file, 'w') as fp:
99 | json.dump(camera_dict, fp, indent=2, sort_keys=True)
100 |
101 | all_tracks, all_points, view_keypoints = parse_tracks(colmap_images, colmap_points3D)
102 | all_points = np.array(all_points)
103 | np.savetxt(xyz_file, all_points, header='format: x, y, z, reproj_err, track_len, color(RGB)', fmt='%.6f')  # np.savetxt prepends '# ' to the header
104 |
105 | mesh = trimesh.Trimesh(vertices=all_points[:, :3].astype(np.float32),
106 | vertex_colors=all_points[:, -3:].astype(np.uint8))
107 | mesh.export(os.path.join(out_dir, 'kai_points.ply'))
108 |
109 | with open(track_file, 'w') as fp:
110 | json.dump(all_tracks, fp)
111 |
112 | with open(keypoints_file, 'w') as fp:
113 | json.dump(view_keypoints, fp)
114 |
115 |
116 | if __name__ == '__main__':
117 | mvs_dir = '/home/zhangka2/sg_render/run_mvs/scan114_train_5/colmap_mvs/mvs'
118 | sparse_dir = os.path.join(mvs_dir, 'sparse')
119 | out_dir = os.path.join(mvs_dir, 'sparse_inspect')
120 | extract_all_to_dir(sparse_dir, out_dir)
121 |
122 | xyz_file = os.path.join(out_dir, 'kai_points.txt')
123 | reproj_errs = np.loadtxt(xyz_file)[:, 3]
124 | with open(os.path.join(out_dir, 'stats.txt'), 'w') as fp:
125 | fp.write('reprojection errors (px) in SfM:\n')
126 | fp.write(' percentile value\n')
127 | for a in [50, 70, 90, 99]:
128 | fp.write(' {} {:.3f}\n'.format(a, np.percentile(reproj_errs, a)))
129 |
130 | print('reprojection errors (px) in SfM:')
131 | print(' percentile value')
132 | for a in [50, 70, 90, 99]:
133 | print(' {} {:.3f}'.format(a, np.percentile(reproj_errs, a)))
--------------------------------------------------------------------------------
/colmap_runner/normalize_cam_dict.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import copy
4 | import open3d as o3d
5 |
6 |
7 | def get_tf_cams(cam_dict, target_radius=1.):
8 | cam_centers = []
9 | for im_name in cam_dict:
10 | W2C = np.array(cam_dict[im_name]['W2C']).reshape((4, 4))
11 | C2W = np.linalg.inv(W2C)
12 | cam_centers.append(C2W[:3, 3:4])
13 |
14 | def get_center_and_diag(cam_centers):
15 | cam_centers = np.hstack(cam_centers)
16 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
17 | center = avg_cam_center
18 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
19 | diagonal = np.max(dist)
20 | return center.flatten(), diagonal
21 |
22 | center, diagonal = get_center_and_diag(cam_centers)
23 | radius = diagonal * 1.1
24 |
25 | translate = -center
26 | scale = target_radius / radius
27 |
28 | return translate, scale
29 |
30 |
31 | def normalize_cam_dict(in_cam_dict_file, out_cam_dict_file, target_radius=1., in_geometry_file=None, out_geometry_file=None):
32 | with open(in_cam_dict_file) as fp:
33 | in_cam_dict = json.load(fp)
34 |
35 | translate, scale = get_tf_cams(in_cam_dict, target_radius=target_radius)
36 |
37 | if in_geometry_file is not None and out_geometry_file is not None:
38 |         # see this page if you encounter issues with file I/O: http://www.open3d.org/docs/0.9.0/tutorial/Basic/file_io.html
39 | geometry = o3d.io.read_triangle_mesh(in_geometry_file)
40 |
41 | tf_translate = np.eye(4)
42 |         tf_translate[:3, 3] = translate  # translate has shape (3,), so assign it to the length-3 slice
43 | tf_scale = np.eye(4)
44 | tf_scale[:3, :3] *= scale
45 | tf = np.matmul(tf_scale, tf_translate)
46 |
47 | geometry_norm = geometry.transform(tf)
48 | o3d.io.write_triangle_mesh(out_geometry_file, geometry_norm)
49 |
50 | def transform_pose(W2C, translate, scale):
51 | C2W = np.linalg.inv(W2C)
52 | cam_center = C2W[:3, 3]
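   |         # only the camera center is remapped as (c + translate) * scale; the rotation part of
   |         # W2C is untouched, which is why the determinant check below still holds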
53 | cam_center = (cam_center + translate) * scale
54 | C2W[:3, 3] = cam_center
55 | return np.linalg.inv(C2W)
56 |
57 | out_cam_dict = copy.deepcopy(in_cam_dict)
58 | for img_name in out_cam_dict:
59 | W2C = np.array(out_cam_dict[img_name]['W2C']).reshape((4, 4))
60 | W2C = transform_pose(W2C, translate, scale)
61 | assert(np.isclose(np.linalg.det(W2C[:3, :3]), 1.))
62 | out_cam_dict[img_name]['W2C'] = list(W2C.flatten())
63 |
64 | with open(out_cam_dict_file, 'w') as fp:
65 | json.dump(out_cam_dict, fp, indent=2, sort_keys=True)
66 |
67 |
68 | if __name__ == '__main__':
69 | in_cam_dict_file = ''
70 | out_cam_dict_file = ''
71 | normalize_cam_dict(in_cam_dict_file, out_cam_dict_file, target_radius=1.)
72 |
--------------------------------------------------------------------------------
/colmap_runner/read_write_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31 |
32 | import os
33 | import sys
34 | import collections
35 | import numpy as np
36 | import struct
37 | import argparse
38 |
39 |
40 | CameraModel = collections.namedtuple(
41 | "CameraModel", ["model_id", "model_name", "num_params"])
42 | Camera = collections.namedtuple(
43 | "Camera", ["id", "model", "width", "height", "params"])
44 | BaseImage = collections.namedtuple(
45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
46 | Point3D = collections.namedtuple(
47 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
48 |
49 |
50 | class Image(BaseImage):
51 | def qvec2rotmat(self):
52 | return qvec2rotmat(self.qvec)
53 |
54 |
55 | CAMERA_MODELS = {
56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
63 | CameraModel(model_id=7, model_name="FOV", num_params=5),
64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
67 | }
68 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
69 | for camera_model in CAMERA_MODELS])
70 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
71 | for camera_model in CAMERA_MODELS])
72 |
73 |
74 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
75 | """Read and unpack the next bytes from a binary file.
76 | :param fid:
77 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
78 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
79 | :param endian_character: Any of {@, =, <, >, !}
80 | :return: Tuple of read and unpacked values.
81 | """
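   |     # e.g. read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ") reads two 4-byte
   |     # ints followed by two 8-byte unsigned ints (4 + 4 + 8 + 8 = 24 bytes)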
82 | data = fid.read(num_bytes)
83 | return struct.unpack(endian_character + format_char_sequence, data)
84 |
85 |
86 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
87 | """pack and write to a binary file.
88 | :param fid:
89 | :param data: data to send, if multiple elements are sent at the same time,
90 |     they should be encapsulated either in a list or a tuple
91 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
92 | should be the same length as the data list or tuple
93 | :param endian_character: Any of {@, =, <, >, !}
94 | """
95 |     if isinstance(data, (list, tuple)):
96 |         packed = struct.pack(endian_character + format_char_sequence, *data)
97 |     else:
98 |         packed = struct.pack(endian_character + format_char_sequence, data)
99 |     fid.write(packed)
100 |
101 |
102 | def read_cameras_text(path):
103 | """
104 | see: src/base/reconstruction.cc
105 | void Reconstruction::WriteCamerasText(const std::string& path)
106 | void Reconstruction::ReadCamerasText(const std::string& path)
107 | """
108 | cameras = {}
109 | with open(path, "r") as fid:
110 | while True:
111 | line = fid.readline()
112 | if not line:
113 | break
114 | line = line.strip()
115 | if len(line) > 0 and line[0] != "#":
116 | elems = line.split()
117 | camera_id = int(elems[0])
118 | model = elems[1]
119 | width = int(elems[2])
120 | height = int(elems[3])
121 | params = np.array(tuple(map(float, elems[4:])))
122 | cameras[camera_id] = Camera(id=camera_id, model=model,
123 | width=width, height=height,
124 | params=params)
125 | return cameras
126 |
127 |
128 | def read_cameras_binary(path_to_model_file):
129 | """
130 | see: src/base/reconstruction.cc
131 | void Reconstruction::WriteCamerasBinary(const std::string& path)
132 | void Reconstruction::ReadCamerasBinary(const std::string& path)
133 | """
134 | cameras = {}
135 | with open(path_to_model_file, "rb") as fid:
136 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
137 | for camera_line_index in range(num_cameras):
138 | camera_properties = read_next_bytes(
139 | fid, num_bytes=24, format_char_sequence="iiQQ")
140 | camera_id = camera_properties[0]
141 | model_id = camera_properties[1]
142 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
143 | width = camera_properties[2]
144 | height = camera_properties[3]
145 | num_params = CAMERA_MODEL_IDS[model_id].num_params
146 | params = read_next_bytes(fid, num_bytes=8*num_params,
147 | format_char_sequence="d"*num_params)
148 | cameras[camera_id] = Camera(id=camera_id,
149 | model=model_name,
150 | width=width,
151 | height=height,
152 | params=np.array(params))
153 | assert len(cameras) == num_cameras
154 | return cameras
155 |
156 |
157 | def write_cameras_text(cameras, path):
158 | """
159 | see: src/base/reconstruction.cc
160 | void Reconstruction::WriteCamerasText(const std::string& path)
161 | void Reconstruction::ReadCamerasText(const std::string& path)
162 | """
163 |     HEADER = '# Camera list with one line of data per camera:\n' + \
164 |         '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' + \
165 |         '# Number of cameras: {}\n'.format(len(cameras))
166 | with open(path, "w") as fid:
167 | fid.write(HEADER)
168 | for _, cam in cameras.items():
169 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
170 | line = " ".join([str(elem) for elem in to_write])
171 | fid.write(line + "\n")
172 |
173 |
174 | def write_cameras_binary(cameras, path_to_model_file):
175 | """
176 | see: src/base/reconstruction.cc
177 | void Reconstruction::WriteCamerasBinary(const std::string& path)
178 | void Reconstruction::ReadCamerasBinary(const std::string& path)
179 | """
180 | with open(path_to_model_file, "wb") as fid:
181 | write_next_bytes(fid, len(cameras), "Q")
182 | for _, cam in cameras.items():
183 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id
184 | camera_properties = [cam.id,
185 | model_id,
186 | cam.width,
187 | cam.height]
188 | write_next_bytes(fid, camera_properties, "iiQQ")
189 | for p in cam.params:
190 | write_next_bytes(fid, float(p), "d")
191 | return cameras
192 |
193 |
194 | def read_images_text(path):
195 | """
196 | see: src/base/reconstruction.cc
197 | void Reconstruction::ReadImagesText(const std::string& path)
198 | void Reconstruction::WriteImagesText(const std::string& path)
199 | """
200 | images = {}
201 | with open(path, "r") as fid:
202 | while True:
203 | line = fid.readline()
204 | if not line:
205 | break
206 | line = line.strip()
207 | if len(line) > 0 and line[0] != "#":
208 | elems = line.split()
209 | image_id = int(elems[0])
210 | qvec = np.array(tuple(map(float, elems[1:5])))
211 | tvec = np.array(tuple(map(float, elems[5:8])))
212 | camera_id = int(elems[8])
213 | image_name = elems[9]
214 | elems = fid.readline().split()
215 | xys = np.column_stack([tuple(map(float, elems[0::3])),
216 | tuple(map(float, elems[1::3]))])
217 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
218 | images[image_id] = Image(
219 | id=image_id, qvec=qvec, tvec=tvec,
220 | camera_id=camera_id, name=image_name,
221 | xys=xys, point3D_ids=point3D_ids)
222 | return images
223 |
224 |
225 | def read_images_binary(path_to_model_file):
226 | """
227 | see: src/base/reconstruction.cc
228 | void Reconstruction::ReadImagesBinary(const std::string& path)
229 | void Reconstruction::WriteImagesBinary(const std::string& path)
230 | """
231 | images = {}
232 | with open(path_to_model_file, "rb") as fid:
233 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
234 | for image_index in range(num_reg_images):
235 | binary_image_properties = read_next_bytes(
236 | fid, num_bytes=64, format_char_sequence="idddddddi")
237 | image_id = binary_image_properties[0]
238 | qvec = np.array(binary_image_properties[1:5])
239 | tvec = np.array(binary_image_properties[5:8])
240 | camera_id = binary_image_properties[8]
241 | image_name = ""
242 | current_char = read_next_bytes(fid, 1, "c")[0]
243 | while current_char != b"\x00": # look for the ASCII 0 entry
244 | image_name += current_char.decode("utf-8")
245 | current_char = read_next_bytes(fid, 1, "c")[0]
246 | num_points2D = read_next_bytes(fid, num_bytes=8,
247 | format_char_sequence="Q")[0]
248 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
249 | format_char_sequence="ddq"*num_points2D)
250 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
251 | tuple(map(float, x_y_id_s[1::3]))])
252 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
253 | images[image_id] = Image(
254 | id=image_id, qvec=qvec, tvec=tvec,
255 | camera_id=camera_id, name=image_name,
256 | xys=xys, point3D_ids=point3D_ids)
257 | return images
258 |
259 |
260 | def write_images_text(images, path):
261 | """
262 | see: src/base/reconstruction.cc
263 | void Reconstruction::ReadImagesText(const std::string& path)
264 | void Reconstruction::WriteImagesText(const std::string& path)
265 | """
266 | if len(images) == 0:
267 | mean_observations = 0
268 | else:
269 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
270 |     HEADER = '# Image list with two lines of data per image:\n' + \
271 |         '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + \
272 |         '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + \
273 |         '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations)
274 |
275 | with open(path, "w") as fid:
276 | fid.write(HEADER)
277 | for _, img in images.items():
278 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
279 | first_line = " ".join(map(str, image_header))
280 | fid.write(first_line + "\n")
281 |
282 | points_strings = []
283 | for xy, point3D_id in zip(img.xys, img.point3D_ids):
284 | points_strings.append(" ".join(map(str, [*xy, point3D_id])))
285 | fid.write(" ".join(points_strings) + "\n")
286 |
287 |
288 | def write_images_binary(images, path_to_model_file):
289 | """
290 | see: src/base/reconstruction.cc
291 | void Reconstruction::ReadImagesBinary(const std::string& path)
292 | void Reconstruction::WriteImagesBinary(const std::string& path)
293 | """
294 | with open(path_to_model_file, "wb") as fid:
295 | write_next_bytes(fid, len(images), "Q")
296 | for _, img in images.items():
297 | write_next_bytes(fid, img.id, "i")
298 | write_next_bytes(fid, img.qvec.tolist(), "dddd")
299 | write_next_bytes(fid, img.tvec.tolist(), "ddd")
300 | write_next_bytes(fid, img.camera_id, "i")
301 | for char in img.name:
302 | write_next_bytes(fid, char.encode("utf-8"), "c")
303 | write_next_bytes(fid, b"\x00", "c")
304 | write_next_bytes(fid, len(img.point3D_ids), "Q")
305 | for xy, p3d_id in zip(img.xys, img.point3D_ids):
306 | write_next_bytes(fid, [*xy, p3d_id], "ddq")
307 |
308 |
309 | def read_points3D_text(path):
310 | """
311 | see: src/base/reconstruction.cc
312 | void Reconstruction::ReadPoints3DText(const std::string& path)
313 | void Reconstruction::WritePoints3DText(const std::string& path)
314 | """
315 | points3D = {}
316 | with open(path, "r") as fid:
317 | while True:
318 | line = fid.readline()
319 | if not line:
320 | break
321 | line = line.strip()
322 | if len(line) > 0 and line[0] != "#":
323 | elems = line.split()
324 | point3D_id = int(elems[0])
325 | xyz = np.array(tuple(map(float, elems[1:4])))
326 | rgb = np.array(tuple(map(int, elems[4:7])))
327 | error = float(elems[7])
328 | image_ids = np.array(tuple(map(int, elems[8::2])))
329 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
330 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
331 | error=error, image_ids=image_ids,
332 | point2D_idxs=point2D_idxs)
333 | return points3D
334 |
335 |
336 | def read_points3d_binary(path_to_model_file):
337 | """
338 | see: src/base/reconstruction.cc
339 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
340 | void Reconstruction::WritePoints3DBinary(const std::string& path)
341 | """
342 | points3D = {}
343 | with open(path_to_model_file, "rb") as fid:
344 | num_points = read_next_bytes(fid, 8, "Q")[0]
345 | for point_line_index in range(num_points):
346 | binary_point_line_properties = read_next_bytes(
347 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
348 | point3D_id = binary_point_line_properties[0]
349 | xyz = np.array(binary_point_line_properties[1:4])
350 | rgb = np.array(binary_point_line_properties[4:7])
351 | error = np.array(binary_point_line_properties[7])
352 | track_length = read_next_bytes(
353 | fid, num_bytes=8, format_char_sequence="Q")[0]
354 | track_elems = read_next_bytes(
355 | fid, num_bytes=8*track_length,
356 | format_char_sequence="ii"*track_length)
357 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
358 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
359 | points3D[point3D_id] = Point3D(
360 | id=point3D_id, xyz=xyz, rgb=rgb,
361 | error=error, image_ids=image_ids,
362 | point2D_idxs=point2D_idxs)
363 | return points3D
364 |
365 |
366 | def write_points3D_text(points3D, path):
367 | """
368 | see: src/base/reconstruction.cc
369 | void Reconstruction::ReadPoints3DText(const std::string& path)
370 | void Reconstruction::WritePoints3DText(const std::string& path)
371 | """
372 | if len(points3D) == 0:
373 | mean_track_length = 0
374 | else:
375 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
376 |     HEADER = '# 3D point list with one line of data per point:\n' + \
377 |         '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' + \
378 |         '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length)
379 |
380 | with open(path, "w") as fid:
381 | fid.write(HEADER)
382 | for _, pt in points3D.items():
383 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
384 | fid.write(" ".join(map(str, point_header)) + " ")
385 | track_strings = []
386 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
387 | track_strings.append(" ".join(map(str, [image_id, point2D])))
388 | fid.write(" ".join(track_strings) + "\n")
389 |
390 |
391 | def write_points3d_binary(points3D, path_to_model_file):
392 | """
393 | see: src/base/reconstruction.cc
394 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
395 | void Reconstruction::WritePoints3DBinary(const std::string& path)
396 | """
397 | with open(path_to_model_file, "wb") as fid:
398 | write_next_bytes(fid, len(points3D), "Q")
399 | for _, pt in points3D.items():
400 | write_next_bytes(fid, pt.id, "Q")
401 | write_next_bytes(fid, pt.xyz.tolist(), "ddd")
402 | write_next_bytes(fid, pt.rgb.tolist(), "BBB")
403 | write_next_bytes(fid, pt.error, "d")
404 | track_length = pt.image_ids.shape[0]
405 | write_next_bytes(fid, track_length, "Q")
406 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
407 | write_next_bytes(fid, [image_id, point2D_id], "ii")
408 |
409 |
410 | def read_model(path, ext):
411 | if ext == ".txt":
412 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
413 | images = read_images_text(os.path.join(path, "images" + ext))
414 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
415 | else:
416 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
417 | images = read_images_binary(os.path.join(path, "images" + ext))
418 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
419 | return cameras, images, points3D
420 |
421 |
422 | def write_model(cameras, images, points3D, path, ext):
423 | if ext == ".txt":
424 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
425 | write_images_text(images, os.path.join(path, "images" + ext))
426 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
427 | else:
428 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
429 | write_images_binary(images, os.path.join(path, "images" + ext))
430 | write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
431 | return cameras, images, points3D
432 |
433 |
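   | # COLMAP stores rotations as quaternions in (qw, qx, qy, qz) order; qvec2rotmat converts such
   | # a quaternion to a 3x3 rotation matrix, and rotmat2qvec is its inverse (returning qw >= 0)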
434 | def qvec2rotmat(qvec):
435 | return np.array([
436 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
437 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
438 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
439 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
440 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
441 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
442 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
443 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
444 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
445 |
446 |
447 | def rotmat2qvec(R):
448 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
449 | K = np.array([
450 | [Rxx - Ryy - Rzz, 0, 0, 0],
451 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
452 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
453 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
454 | eigvals, eigvecs = np.linalg.eigh(K)
455 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
456 | if qvec[0] < 0:
457 | qvec *= -1
458 | return qvec
459 |
460 |
461 | def main():
462 | parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models')
463 | parser.add_argument('input_model', help='path to input model folder')
464 | parser.add_argument('input_format', choices=['.bin', '.txt'],
465 | help='input model format')
466 | parser.add_argument('--output_model', metavar='PATH',
467 | help='path to output model folder')
468 | parser.add_argument('--output_format', choices=['.bin', '.txt'],
469 |                         help='output model format', default='.txt')
470 | args = parser.parse_args()
471 |
472 | cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
473 |
474 | print("num_cameras:", len(cameras))
475 | print("num_images:", len(images))
476 | print("num_points3D:", len(points3D))
477 |
478 | if args.output_model is not None:
479 | write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
480 |
481 |
482 | if __name__ == "__main__":
483 | main()
484 |
--------------------------------------------------------------------------------
/colmap_runner/run_colmap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from extract_sfm import extract_all_to_dir
4 | from normalize_cam_dict import normalize_cam_dict
5 |
6 | #########################################################################
7 | # Note: set colmap_bin below to the path of the colmap executable on your machine
8 | #########################################################################
9 |
10 | def bash_run(cmd):
11 | colmap_bin = '/home/zhangka2/code/colmap/build/__install__/bin/colmap'
12 | cmd = colmap_bin + ' ' + cmd
13 | print('\nRunning cmd: ', cmd)
14 |
15 | subprocess.check_call(['/bin/bash', '-c', cmd])
16 |
17 |
18 | gpu_index = '-1'
19 |
20 |
21 | def run_sift_matching(img_dir, db_file, remove_exist=False):
22 | print('Running sift matching...')
23 |
24 | if remove_exist and os.path.exists(db_file):
25 | os.remove(db_file) # otherwise colmap will skip sift matching
26 |
27 | # feature extraction
28 |     # if there is no attached display, the feature extractor cannot use the GPU
29 | cmd = ' feature_extractor --database_path {} \
30 | --image_path {} \
31 | --ImageReader.single_camera 1 \
32 | --ImageReader.camera_model SIMPLE_RADIAL \
33 | --SiftExtraction.max_image_size 5000 \
34 | --SiftExtraction.estimate_affine_shape 0 \
35 | --SiftExtraction.domain_size_pooling 1 \
36 | --SiftExtraction.use_gpu 1 \
37 | --SiftExtraction.max_num_features 16384 \
38 | --SiftExtraction.gpu_index {}'.format(db_file, img_dir, gpu_index)
39 | bash_run(cmd)
40 |
41 | # feature matching
42 | cmd = ' exhaustive_matcher --database_path {} \
43 | --SiftMatching.guided_matching 1 \
44 | --SiftMatching.use_gpu 1 \
45 | --SiftMatching.max_num_matches 65536 \
46 | --SiftMatching.max_error 3 \
47 | --SiftMatching.gpu_index {}'.format(db_file, gpu_index)
48 |
49 | bash_run(cmd)
50 |
51 |
52 | def run_sfm(img_dir, db_file, out_dir):
53 | print('Running SfM...')
54 |
55 | cmd = ' mapper \
56 | --database_path {} \
57 | --image_path {} \
58 | --output_path {} \
59 | --Mapper.tri_min_angle 3.0 \
60 | --Mapper.filter_min_tri_angle 3.0'.format(db_file, img_dir, out_dir)
61 |
62 | bash_run(cmd)
63 |
64 |
65 | def prepare_mvs(img_dir, sparse_dir, mvs_dir):
66 | print('Preparing for MVS...')
67 |
68 | cmd = ' image_undistorter \
69 | --image_path {} \
70 | --input_path {} \
71 | --output_path {} \
72 | --output_type COLMAP \
73 | --max_image_size 2000'.format(img_dir, sparse_dir, mvs_dir)
74 |
75 | bash_run(cmd)
76 |
77 |
78 | def run_photometric_mvs(mvs_dir, window_radius):
79 | print('Running photometric MVS...')
80 |
81 | cmd = ' patch_match_stereo --workspace_path {} \
82 | --PatchMatchStereo.window_radius {} \
83 | --PatchMatchStereo.min_triangulation_angle 3.0 \
84 | --PatchMatchStereo.filter 1 \
85 | --PatchMatchStereo.geom_consistency 1 \
86 | --PatchMatchStereo.gpu_index={} \
87 | --PatchMatchStereo.num_samples 15 \
88 | --PatchMatchStereo.num_iterations 12'.format(mvs_dir,
89 | window_radius, gpu_index)
90 |
91 | bash_run(cmd)
92 |
93 |
94 | def run_fuse(mvs_dir, out_ply):
95 | print('Running depth fusion...')
96 |
97 | cmd = ' stereo_fusion --workspace_path {} \
98 | --output_path {} \
99 | --input_type geometric'.format(mvs_dir, out_ply)
100 |
101 | bash_run(cmd)
102 |
103 |
104 | def run_possion_mesher(in_ply, out_ply, trim):
105 |     print('Running Poisson mesher...')
106 |
107 | cmd = ' poisson_mesher \
108 | --input_path {} \
109 | --output_path {} \
110 | --PoissonMeshing.trim {}'.format(in_ply, out_ply, trim)
111 |
112 | bash_run(cmd)
113 |
114 |
115 | def main(img_dir, out_dir, run_mvs=False):
116 | os.makedirs(out_dir, exist_ok=True)
117 |
118 | #### run sfm
119 | sfm_dir = os.path.join(out_dir, 'sfm')
120 | os.makedirs(sfm_dir, exist_ok=True)
121 |
122 | img_dir_link = os.path.join(sfm_dir, 'images')
123 | if os.path.exists(img_dir_link):
124 | os.remove(img_dir_link)
125 | os.symlink(img_dir, img_dir_link)
126 |
127 | db_file = os.path.join(sfm_dir, 'database.db')
128 | run_sift_matching(img_dir, db_file, remove_exist=False)
129 | sparse_dir = os.path.join(sfm_dir, 'sparse')
130 | os.makedirs(sparse_dir, exist_ok=True)
131 | run_sfm(img_dir, db_file, sparse_dir)
132 |
133 | # undistort images
134 | mvs_dir = os.path.join(out_dir, 'mvs')
135 | os.makedirs(mvs_dir, exist_ok=True)
136 | prepare_mvs(img_dir, sparse_dir, mvs_dir)
137 |
138 | # extract camera parameters and undistorted images
139 | os.makedirs(os.path.join(out_dir, 'posed_images'), exist_ok=True)
140 | extract_all_to_dir(os.path.join(mvs_dir, 'sparse'), os.path.join(out_dir, 'posed_images'))
141 | undistorted_img_dir = os.path.join(mvs_dir, 'images')
142 | posed_img_dir_link = os.path.join(out_dir, 'posed_images/images')
143 | if os.path.exists(posed_img_dir_link):
144 | os.remove(posed_img_dir_link)
145 | os.symlink(undistorted_img_dir, posed_img_dir_link)
146 | # normalize average camera center to origin, and put all cameras inside the unit sphere
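   |     # this matches the assumption made by the training code: intersect_sphere() in
   |     # ddp_train_nerf.py raises an error if the cameras are not bounded by the unit sphere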
147 | normalize_cam_dict(os.path.join(out_dir, 'posed_images/kai_cameras.json'),
148 | os.path.join(out_dir, 'posed_images/kai_cameras_normalized.json'))
149 |
150 | if run_mvs:
151 | # run mvs
152 | run_photometric_mvs(mvs_dir, window_radius=7)
153 |
154 | out_ply = os.path.join(out_dir, 'mvs/fused.ply')
155 | run_fuse(mvs_dir, out_ply)
156 |
157 | out_mesh_ply = os.path.join(out_dir, 'mvs/meshed_trim_3.ply')
158 | run_possion_mesher(out_ply, out_mesh_ply, trim=3)
159 |
160 |
161 | if __name__ == '__main__':
162 |     ### note: this script is intended for the case where all images are taken by the same camera, i.e., intrinsics are shared.
163 |
164 | img_dir = ''
165 | out_dir = ''
166 | run_mvs = False
167 | main(img_dir, out_dir, run_mvs=run_mvs)
168 |
169 |
--------------------------------------------------------------------------------
/colmap_runner/run_colmap_posed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from database import COLMAPDatabase
4 | from pyquaternion import Quaternion
5 | import numpy as np
6 | import imageio
7 | import subprocess
8 |
9 |
10 | def bash_run(cmd):
11 | # local install of colmap
12 | env = os.environ.copy()
13 | env['LD_LIBRARY_PATH'] = '/home/zhangka2/code/colmap/build/__install__/lib'
14 |
15 | colmap_bin = '/home/zhangka2/code/colmap/build/__install__/bin/colmap'
16 | cmd = colmap_bin + ' ' + cmd
17 | print('\nRunning cmd: ', cmd)
18 |
19 | subprocess.check_call(['/bin/bash', '-c', cmd], env=env)
20 |
21 |
22 | gpu_index = '-1'
23 |
24 |
25 | def run_sift_matching(img_dir, db_file):
26 | print('Running sift matching...')
27 |
28 | # if os.path.exists(db_file): # otherwise colmap will skip sift matching
29 | # os.remove(db_file)
30 |
31 | # feature extraction
32 |     # if there is no attached display, the feature extractor cannot use the GPU
33 | cmd = ' feature_extractor --database_path {} \
34 | --image_path {} \
35 | --ImageReader.camera_model PINHOLE \
36 | --SiftExtraction.max_image_size 5000 \
37 | --SiftExtraction.estimate_affine_shape 0 \
38 | --SiftExtraction.domain_size_pooling 1 \
39 | --SiftExtraction.num_threads 32 \
40 | --SiftExtraction.use_gpu 1 \
41 | --SiftExtraction.gpu_index {}'.format(db_file, img_dir, gpu_index)
42 | bash_run(cmd)
43 |
44 | # feature matching
45 | cmd = ' exhaustive_matcher --database_path {} \
46 | --SiftMatching.guided_matching 1 \
47 | --SiftMatching.use_gpu 1 \
48 | --SiftMatching.gpu_index {}'.format(db_file, gpu_index)
49 |
50 | bash_run(cmd)
51 |
52 |
53 | def create_init_files(pinhole_dict_file, db_file, out_dir):
54 | if not os.path.exists(out_dir):
55 | os.mkdir(out_dir)
56 |
57 | # create template
58 | with open(pinhole_dict_file) as fp:
59 | pinhole_dict = json.load(fp)
60 |
61 | template = {}
62 | cameras_line_template = '{camera_id} PINHOLE {width} {height} {fx} {fy} {cx} {cy}\n'
63 | images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\n\n'
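   |     # {camera_id} and {image_id} are intentionally left as literal placeholders here; they are
   |     # filled in below with the ids that colmap assigned to each image in its database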
64 |
65 | for img_name in pinhole_dict:
66 | # w, h, fx, fy, cx, cy, qvec, t
67 | params = pinhole_dict[img_name]
68 | w = params[0]
69 | h = params[1]
70 | fx = params[2]
71 | fy = params[3]
72 | cx = params[4]
73 | cy = params[5]
74 | qvec = params[6:10]
75 | tvec = params[10:13]
76 |
77 | cam_line = cameras_line_template.format(camera_id="{camera_id}", width=w, height=h, fx=fx, fy=fy, cx=cx, cy=cy)
78 | img_line = images_line_template.format(image_id="{image_id}", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],
79 | tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id="{camera_id}", image_name=img_name)
80 | template[img_name] = (cam_line, img_line)
81 |
82 | # read database
83 | db = COLMAPDatabase.connect(db_file)
84 | table_images = db.execute("SELECT * FROM images")
85 | img_name2id_dict = {}
86 | for row in table_images:
87 | img_name2id_dict[row[1]] = row[0]
88 |
89 | cameras_txt_lines = []
90 | images_txt_lines = []
91 | for img_name, img_id in img_name2id_dict.items():
92 | camera_line = template[img_name][0].format(camera_id=img_id)
93 | cameras_txt_lines.append(camera_line)
94 |
95 | image_line = template[img_name][1].format(image_id=img_id, camera_id=img_id)
96 | images_txt_lines.append(image_line)
97 |
98 | with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:
99 | fp.writelines(cameras_txt_lines)
100 |
101 | with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:
102 | fp.writelines(images_txt_lines)
103 | fp.write('\n')
104 |
105 | # create an empty points3D.txt
106 | fp = open(os.path.join(out_dir, 'points3D.txt'), 'w')
107 | fp.close()
108 |
109 |
110 | def run_point_triangulation(img_dir, db_file, out_dir):
111 | print('Running point triangulation...')
112 |
113 | # triangulate points
114 | cmd = ' point_triangulator --database_path {} \
115 | --image_path {} \
116 | --input_path {} \
117 | --output_path {} \
118 | --Mapper.tri_ignore_two_view_tracks 1'.format(db_file, img_dir, out_dir, out_dir)
119 | bash_run(cmd)
120 |
121 |
122 | # this step is optional
123 | def run_global_ba(in_dir, out_dir):
124 | print('Running global BA...')
125 | if not os.path.exists(out_dir):
126 | os.mkdir(out_dir)
127 |
128 | cmd = ' bundle_adjuster --input_path {in_dir} --output_path {out_dir}'.format(in_dir=in_dir, out_dir=out_dir)
129 |
130 | bash_run(cmd)
131 |
132 |
133 | def prepare_mvs(img_dir, sfm_dir, mvs_dir):
134 | if not os.path.exists(mvs_dir):
135 | os.mkdir(mvs_dir)
136 |
137 | images_symlink = os.path.join(mvs_dir, 'images')
138 | if os.path.exists(images_symlink):
139 | os.unlink(images_symlink)
140 | os.symlink(os.path.relpath(img_dir, mvs_dir),
141 | images_symlink)
142 |
143 | sparse_symlink = os.path.join(mvs_dir, 'sparse')
144 | if os.path.exists(sparse_symlink):
145 | os.unlink(sparse_symlink)
146 | os.symlink(os.path.relpath(sfm_dir, mvs_dir),
147 | sparse_symlink)
148 |
149 | # prepare stereo directory
150 | stereo_dir = os.path.join(mvs_dir, 'stereo')
151 | for subdir in [stereo_dir,
152 | os.path.join(stereo_dir, 'depth_maps'),
153 | os.path.join(stereo_dir, 'normal_maps'),
154 | os.path.join(stereo_dir, 'consistency_graphs')]:
155 | if not os.path.exists(subdir):
156 | os.mkdir(subdir)
157 |
158 | # write patch-match.cfg and fusion.cfg
159 | image_names = sorted(os.listdir(os.path.join(mvs_dir, 'images')))
160 |
161 | with open(os.path.join(stereo_dir, 'patch-match.cfg'), 'w') as fp:
162 | for img_name in image_names:
163 | fp.write(img_name + '\n__auto__, 20\n')
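   |             # '__auto__, 20' asks colmap to automatically select the 20 best source images
   |             # for each reference image during patch-match stereo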
164 |
165 | # use all images
166 | # fp.write(img_name + '\n__all__\n')
167 |
168 |             # randomly choose up to max_src_images source images
169 | # from random import shuffle
170 | # candi_src_images = [x for x in image_names if x != img_name]
171 | # shuffle(candi_src_images)
172 | # max_src_images = 10
173 | # fp.write(img_name + '\n' + ', '.join(candi_src_images[:max_src_images]) + '\n')
174 |
175 | with open(os.path.join(stereo_dir, 'fusion.cfg'), 'w') as fp:
176 | for img_name in image_names:
177 | fp.write(img_name + '\n')
178 |
179 |
180 | def run_photometric_mvs(mvs_dir, window_radius):
181 | print('Running photometric MVS...')
182 |
183 | cmd = ' patch_match_stereo --workspace_path {} \
184 | --PatchMatchStereo.window_radius {} \
185 | --PatchMatchStereo.min_triangulation_angle 3.0 \
186 | --PatchMatchStereo.filter 1 \
187 | --PatchMatchStereo.geom_consistency 1 \
188 | --PatchMatchStereo.gpu_index={} \
189 | --PatchMatchStereo.num_samples 15 \
190 | --PatchMatchStereo.num_iterations 12'.format(mvs_dir,
191 | window_radius, gpu_index)
192 |
193 | bash_run(cmd)
194 |
195 |
196 | def run_fuse(mvs_dir, out_ply):
197 | print('Running depth fusion...')
198 |
199 | cmd = ' stereo_fusion --workspace_path {} \
200 | --output_path {} \
201 | --input_type geometric'.format(mvs_dir, out_ply)
202 | bash_run(cmd)
203 |
204 |
205 | def run_possion_mesher(in_ply, out_ply, trim):
206 |     print('Running Poisson mesher...')
207 |
208 | cmd = ' poisson_mesher \
209 | --input_path {} \
210 | --output_path {} \
211 | --PoissonMeshing.trim {}'.format(in_ply, out_ply, trim)
212 |
213 | bash_run(cmd)
214 |
215 |
216 | def main(img_dir, pinhole_dict_file, out_dir):
217 | if not os.path.exists(out_dir):
218 | os.mkdir(out_dir)
219 |
220 | db_file = os.path.join(out_dir, 'database.db')
221 | run_sift_matching(img_dir, db_file)
222 |
223 | sfm_dir = os.path.join(out_dir, 'sfm')
224 | create_init_files(pinhole_dict_file, db_file, sfm_dir)
225 | run_point_triangulation(img_dir, db_file, sfm_dir)
226 |
227 | # # optional
228 | # run_global_ba(sfm_dir, sfm_dir)
229 |
230 | mvs_dir = os.path.join(out_dir, 'mvs')
231 | prepare_mvs(img_dir, sfm_dir, mvs_dir)
232 | run_photometric_mvs(mvs_dir, window_radius=5)
233 |
234 | out_ply = os.path.join(out_dir, 'fused.ply')
235 | run_fuse(mvs_dir, out_ply)
236 |
237 | out_mesh_ply = os.path.join(out_dir, 'meshed_trim_3.ply')
238 | run_possion_mesher(out_ply, out_mesh_ply, trim=3)
239 |
240 |
241 | def convert_cam_dict_to_pinhole_dict(cam_dict_file, pinhole_dict_file, img_dir):
242 | print('Writing pinhole_dict to: ', pinhole_dict_file)
243 |
244 | with open(cam_dict_file) as fp:
245 | cam_dict = json.load(fp)
246 |
247 | pinhole_dict = {}
248 | for img_name in cam_dict:
249 | data_item = cam_dict[img_name]
250 | if 'img_size' in data_item:
251 | w, h = data_item['img_size']
252 | else:
253 | im = imageio.imread(os.path.join(img_dir, img_name))
254 | h, w = im.shape[:2]
255 |
256 | K = np.array(data_item['K']).reshape((4, 4))
257 | W2C = np.array(data_item['W2C']).reshape((4, 4))
258 |
259 | # params
260 | fx = K[0, 0]
261 | fy = K[1, 1]
262 | assert(np.isclose(K[0, 1], 0.))
263 | cx = K[0, 2]
264 | cy = K[1, 2]
265 |
266 | print(img_name)
267 | R = W2C[:3, :3]
268 | print(R)
269 | u, s_old, vh = np.linalg.svd(R, full_matrices=False)
270 | s = np.round(s_old)
271 | print('s: {} ---> {}'.format(s_old, s))
272 | R = np.dot(u * s, vh)
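   |         # re-orthonormalize R via SVD with rounded singular values, so that small numerical
   |         # drift in the input pose does not make Quaternion(matrix=R) below reject the matrix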
273 |
274 | qvec = Quaternion(matrix=R)
275 | tvec = W2C[:3, 3]
276 |
277 | params = [w, h, fx, fy, cx, cy,
278 | qvec[0], qvec[1], qvec[2], qvec[3],
279 | tvec[0], tvec[1], tvec[2]]
280 | pinhole_dict[img_name] = params
281 |
282 | with open(pinhole_dict_file, 'w') as fp:
283 | json.dump(pinhole_dict, fp, indent=2, sort_keys=True)
284 |
285 |
286 | if __name__ == '__main__':
287 | img_dir = ''
288 | cam_dict_file = ''
289 | out_dir = ''
290 |
291 | os.makedirs(out_dir, exist_ok=True)
292 | pinhole_dict_file = os.path.join(out_dir, 'pinhole_dict.json')
293 | convert_cam_dict_to_pinhole_dict(cam_dict_file, pinhole_dict_file, img_dir)
294 |
295 | main(img_dir, pinhole_dict_file, out_dir)
296 |
--------------------------------------------------------------------------------
/configs/lf_data/lf_africa.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | datadir = ./data/lf_data
3 | scene = africa
4 | expname = africa
5 | basedir = ./logs
6 | config = None
7 | ckpt_path = None
8 | no_reload = False
9 | testskip = 1
10 |
11 | ### TRAINING
12 | N_iters = 500001
13 | N_rand = 1024
14 | lrate = 0.0005
15 | lrate_decay_factor = 0.1
16 | lrate_decay_steps = 50000000
17 |
18 | ### CASCADE
19 | cascade_level = 2
20 | cascade_samples = 64,128
21 |
22 | ### TESTING
23 | chunk_size = 8192
24 |
25 | ### RENDERING
26 | det = False
27 | max_freq_log2 = 10
28 | max_freq_log2_viewdirs = 4
29 | netdepth = 8
30 | netwidth = 256
31 | use_viewdirs = True
32 |
33 | ### CONSOLE AND TENSORBOARD
34 | i_img = 2000
35 | i_print = 100
36 | i_weights = 5000
37 |
--------------------------------------------------------------------------------
/configs/tanks_and_temples/tat_training_truck.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | datadir = ./data/tanks_and_temples
3 | scene = tat_training_Truck
4 | expname = tat_training_Truck
5 | basedir = ./logs
6 | config = None
7 | ckpt_path = None
8 | no_reload = False
9 | testskip = 1
10 |
11 | ### TRAINING
12 | N_iters = 500001
13 | N_rand = 1024
14 | lrate = 0.0005
15 | lrate_decay_factor = 0.1
16 | lrate_decay_steps = 50000000
17 |
18 | ### CASCADE
19 | cascade_level = 2
20 | cascade_samples = 64,128
21 |
22 | ### TESTING
23 | chunk_size = 8192
24 |
25 | ### RENDERING
26 | det = False
27 | max_freq_log2 = 10
28 | max_freq_log2_viewdirs = 4
29 | netdepth = 8
30 | netwidth = 256
31 | use_viewdirs = True
32 |
33 | ### CONSOLE AND TENSORBOARD
34 | i_img = 2000
35 | i_print = 100
36 | i_weights = 5000
37 |
--------------------------------------------------------------------------------
/data_loader_split.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import logging
5 | from nerf_sample_ray_split import RaySamplerSingleImage
6 | import glob
7 |
8 | logger = logging.getLogger(__package__)
9 |
10 | ########################################################################################################################
11 | # camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention)
12 | # poses is camera-to-world
13 | ########################################################################################################################
14 | def find_files(dir, exts):
15 | if os.path.isdir(dir):
16 | # types should be ['*.png', '*.jpg']
17 | files_grabbed = []
18 | for ext in exts:
19 | files_grabbed.extend(glob.glob(os.path.join(dir, ext)))
20 | if len(files_grabbed) > 0:
21 | files_grabbed = sorted(files_grabbed)
22 | return files_grabbed
23 | else:
24 | return []
25 |
26 |
27 | def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True, only_img_files=False):
28 |
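   |     # each intrinsics/pose file is expected to hold 16 whitespace-separated numbers, i.e. a
   |     # flattened 4x4 matrix (the intrinsics K, or the camera-to-world pose, respectively)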
29 | def parse_txt(filename):
30 | assert os.path.isfile(filename)
31 | nums = open(filename).read().split()
32 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)
33 |
34 | if basedir[-1] == '/': # remove trailing '/'
35 | basedir = basedir[:-1]
36 |
37 | split_dir = '{}/{}/{}'.format(basedir, scene, split)
38 |
39 | if only_img_files:
40 | img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg'])
41 | return img_files
42 |
43 | # camera parameters files
44 | intrinsics_files = find_files('{}/intrinsics'.format(split_dir), exts=['*.txt'])
45 | pose_files = find_files('{}/pose'.format(split_dir), exts=['*.txt'])
46 | logger.info('raw intrinsics_files: {}'.format(len(intrinsics_files)))
47 | logger.info('raw pose_files: {}'.format(len(pose_files)))
48 |
49 | intrinsics_files = intrinsics_files[::skip]
50 | pose_files = pose_files[::skip]
51 | cam_cnt = len(pose_files)
52 |
53 | # img files
54 | img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg'])
55 | if len(img_files) > 0:
56 | logger.info('raw img_files: {}'.format(len(img_files)))
57 | img_files = img_files[::skip]
58 | assert(len(img_files) == cam_cnt)
59 | else:
60 | img_files = [None, ] * cam_cnt
61 |
62 | # mask files
63 | mask_files = find_files('{}/mask'.format(split_dir), exts=['*.png', '*.jpg'])
64 | if len(mask_files) > 0:
65 | logger.info('raw mask_files: {}'.format(len(mask_files)))
66 | mask_files = mask_files[::skip]
67 | assert(len(mask_files) == cam_cnt)
68 | else:
69 | mask_files = [None, ] * cam_cnt
70 |
71 | # min depth files
72 | mindepth_files = find_files('{}/min_depth'.format(split_dir), exts=['*.png', '*.jpg'])
73 | if try_load_min_depth and len(mindepth_files) > 0:
74 | logger.info('raw mindepth_files: {}'.format(len(mindepth_files)))
75 | mindepth_files = mindepth_files[::skip]
76 | assert(len(mindepth_files) == cam_cnt)
77 | else:
78 | mindepth_files = [None, ] * cam_cnt
79 |
80 | # assume all images have the same size as training image
81 | train_imgfile = find_files('{}/{}/train/rgb'.format(basedir, scene), exts=['*.png', '*.jpg'])[0]
82 | train_im = imageio.imread(train_imgfile)
83 | H, W = train_im.shape[:2]
84 |
85 | # create ray samplers
86 | ray_samplers = []
87 | for i in range(cam_cnt):
88 | intrinsics = parse_txt(intrinsics_files[i])
89 | pose = parse_txt(pose_files[i])
90 |
91 | # read max depth
92 | try:
93 | max_depth = float(open('{}/max_depth.txt'.format(split_dir)).readline().strip())
94 | except:
95 | max_depth = None
96 |
97 | ray_samplers.append(RaySamplerSingleImage(H=H, W=W, intrinsics=intrinsics, c2w=pose,
98 | img_path=img_files[i],
99 | mask_path=mask_files[i],
100 | min_depth_path=mindepth_files[i],
101 | max_depth=max_depth))
102 |
103 | logger.info('Split {}, # views: {}'.format(split, cam_cnt))
104 |
105 | return ray_samplers
106 |
--------------------------------------------------------------------------------
/ddp_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # import torch.nn.functional as F
4 | # import numpy as np
5 | from utils import TINY_NUMBER, HUGE_NUMBER
6 | from collections import OrderedDict
7 | from nerf_network import Embedder, MLPNet
8 | import os
9 | import logging
10 | logger = logging.getLogger(__package__)
11 |
12 |
13 | ######################################################################################
14 | # wrapper to simplify the use of nerfnet
15 | ######################################################################################
16 | def depth2pts_outside(ray_o, ray_d, depth):
17 | '''
18 | ray_o, ray_d: [..., 3]
19 | depth: [...]; inverse of distance to sphere origin
20 | '''
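   |     # NeRF++ background parameterization: a background sample is represented as a 4-vector
   |     # (x', y', z', 1/r), where (x', y', z') is the unit direction of the point and 1/r its
   |     # inverse distance from the origin. This function intersects the ray with the unit sphere
   |     # and rotates that intersection (Rodrigues' formula below) to the direction of the point
   |     # at inverse distance `depth`; it also returns the conventional depth along the ray.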
21 | # note: d1 becomes negative if this mid point is behind camera
22 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
23 | p_mid = ray_o + d1.unsqueeze(-1) * ray_d
24 | p_mid_norm = torch.norm(p_mid, dim=-1)
25 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
26 | d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
27 | p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d
28 |
29 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
30 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
31 | phi = torch.asin(p_mid_norm)
32 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
33 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
34 |
35 | # now rotate p_sphere
36 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
37 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \
38 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
39 | rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))
40 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
41 | pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
42 |
43 | # now calculate conventional depth
44 | depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
45 | return pts, depth_real
46 |
47 |
48 | class NerfNet(nn.Module):
49 | def __init__(self, args):
50 | super().__init__()
51 | # foreground
52 | self.fg_embedder_position = Embedder(input_dim=3,
53 | max_freq_log2=args.max_freq_log2 - 1,
54 | N_freqs=args.max_freq_log2)
55 | self.fg_embedder_viewdir = Embedder(input_dim=3,
56 | max_freq_log2=args.max_freq_log2_viewdirs - 1,
57 | N_freqs=args.max_freq_log2_viewdirs)
58 | self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
59 | input_ch=self.fg_embedder_position.out_dim,
60 | input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
61 | use_viewdirs=args.use_viewdirs)
62 | # background; bg_pt is (x, y, z, 1/r)
63 | self.bg_embedder_position = Embedder(input_dim=4,
64 | max_freq_log2=args.max_freq_log2 - 1,
65 | N_freqs=args.max_freq_log2)
66 | self.bg_embedder_viewdir = Embedder(input_dim=3,
67 | max_freq_log2=args.max_freq_log2_viewdirs - 1,
68 | N_freqs=args.max_freq_log2_viewdirs)
69 | self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
70 | input_ch=self.bg_embedder_position.out_dim,
71 | input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
72 | use_viewdirs=args.use_viewdirs)
73 |
74 | def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
75 | '''
76 | :param ray_o, ray_d: [..., 3]
77 | :param fg_z_max: [...,]
78 | :param fg_z_vals, bg_z_vals: [..., N_samples]
79 | :return
80 | '''
81 | # print(ray_o.shape, ray_d.shape, fg_z_max.shape, fg_z_vals.shape, bg_z_vals.shape)
82 | ray_d_norm = torch.norm(ray_d, dim=-1, keepdim=True) # [..., 1]
83 | viewdirs = ray_d / ray_d_norm # [..., 3]
84 | dots_sh = list(ray_d.shape[:-1])
85 |
86 | ######### render foreground
87 | N_samples = fg_z_vals.shape[-1]
88 | fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
89 | fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
90 | fg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
91 | fg_pts = fg_ray_o + fg_z_vals.unsqueeze(-1) * fg_ray_d
92 | input = torch.cat((self.fg_embedder_position(fg_pts),
93 | self.fg_embedder_viewdir(fg_viewdirs)), dim=-1)
94 | fg_raw = self.fg_net(input)
95 | # alpha blending
96 | fg_dists = fg_z_vals[..., 1:] - fg_z_vals[..., :-1]
97 | # account for view directions
98 | fg_dists = ray_d_norm * torch.cat((fg_dists, fg_z_max.unsqueeze(-1) - fg_z_vals[..., -1:]), dim=-1) # [..., N_samples]
99 | fg_alpha = 1. - torch.exp(-fg_raw['sigma'] * fg_dists) # [..., N_samples]
100 | T = torch.cumprod(1. - fg_alpha + TINY_NUMBER, dim=-1) # [..., N_samples]
101 | bg_lambda = T[..., -1]
102 | T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1) # [..., N_samples]
103 | fg_weights = fg_alpha * T # [..., N_samples]
104 | fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2) # [..., 3]
105 | fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1) # [...,]
106 |
107 | # render background
108 | N_samples = bg_z_vals.shape[-1]
109 | bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
110 | bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
111 | bg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
112 | bg_pts, _ = depth2pts_outside(bg_ray_o, bg_ray_d, bg_z_vals) # [..., N_samples, 4]
113 | input = torch.cat((self.bg_embedder_position(bg_pts),
114 | self.bg_embedder_viewdir(bg_viewdirs)), dim=-1)
115 | # near_depth: physical far; far_depth: physical near
116 | input = torch.flip(input, dims=[-2,])
117 | bg_z_vals = torch.flip(bg_z_vals, dims=[-1,]) # 1--->0
118 | bg_dists = bg_z_vals[..., :-1] - bg_z_vals[..., 1:]
119 | bg_dists = torch.cat((bg_dists, HUGE_NUMBER * torch.ones_like(bg_dists[..., 0:1])), dim=-1) # [..., N_samples]
120 | bg_raw = self.bg_net(input)
121 | bg_alpha = 1. - torch.exp(-bg_raw['sigma'] * bg_dists) # [..., N_samples]
122 | # Eq. (3): T
123 |         # the math guarantees that the weights, and their sum along a ray, always lie inside [0, 1]
124 | T = torch.cumprod(1. - bg_alpha + TINY_NUMBER, dim=-1)[..., :-1] # [..., N_samples-1]
125 | T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1) # [..., N_samples]
126 | bg_weights = bg_alpha * T # [..., N_samples]
127 | bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2) # [..., 3]
128 | bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1) # [...,]
129 |
130 | # composite foreground and background
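   |         # bg_lambda is the transmittance left over after the foreground pass, so the background
   |         # color/depth only show through where the foreground is (nearly) transparent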
131 | bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
132 | bg_depth_map = bg_lambda * bg_depth_map
133 | rgb_map = fg_rgb_map + bg_rgb_map
134 |
135 | ret = OrderedDict([('rgb', rgb_map), # loss
136 | ('fg_weights', fg_weights), # importance sampling
137 | ('bg_weights', bg_weights), # importance sampling
138 | ('fg_rgb', fg_rgb_map), # below are for logging
139 | ('fg_depth', fg_depth_map),
140 | ('bg_rgb', bg_rgb_map),
141 | ('bg_depth', bg_depth_map),
142 | ('bg_lambda', bg_lambda)])
143 | return ret
144 |
145 |
146 | def remap_name(name):
147 | name = name.replace('.', '-') # dot is not allowed by pytorch
148 | if name[-1] == '/':
149 | name = name[:-1]
150 | idx = name.rfind('/')
151 | for i in range(2):
152 | if idx >= 0:
153 | idx = name[:idx].rfind('/')
154 | return name[idx + 1:]
155 |
156 |
157 | class NerfNetWithAutoExpo(nn.Module):
158 | def __init__(self, args, optim_autoexpo=False, img_names=None):
159 | super().__init__()
160 | self.nerf_net = NerfNet(args)
161 |
162 | self.optim_autoexpo = optim_autoexpo
163 | if self.optim_autoexpo:
164 | assert(img_names is not None)
165 | logger.info('Optimizing autoexposure!')
166 |
167 | self.img_names = [remap_name(x) for x in img_names]
168 | logger.info('\n'.join(self.img_names))
169 | self.autoexpo_params = nn.ParameterDict(OrderedDict([(x, nn.Parameter(torch.Tensor([0.5, 0.]))) for x in self.img_names]))
170 |
171 | def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals, img_name=None):
172 | '''
173 | :param ray_o, ray_d: [..., 3]
174 | :param fg_z_max: [...,]
175 | :param fg_z_vals, bg_z_vals: [..., N_samples]
176 | :return
177 | '''
178 | ret = self.nerf_net(ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals)
179 |
180 | if img_name is not None:
181 | img_name = remap_name(img_name)
182 | if self.optim_autoexpo and (img_name in self.autoexpo_params):
183 | autoexpo = self.autoexpo_params[img_name]
184 | scale = torch.abs(autoexpo[0]) + 0.5 # make sure scale is always positive
185 | shift = autoexpo[1]
186 | ret['autoexpo'] = (scale, shift)
187 |
188 | return ret
189 |
--------------------------------------------------------------------------------
/ddp_test_nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import torch.nn as nn
3 | import torch.optim
4 | import torch.distributed
5 | # from torch.nn.parallel import DistributedDataParallel as DDP
6 | import torch.multiprocessing
7 | import numpy as np
8 | import os
9 | # from collections import OrderedDict
10 | # from ddp_model import NerfNet
11 | import time
12 | from data_loader_split import load_data_split
13 | from utils import mse2psnr, colorize_np, to8b
14 | import imageio
15 | from ddp_train_nerf import config_parser, setup_logger, setup, cleanup, render_single_image, create_nerf
16 | import logging
17 |
18 |
19 | logger = logging.getLogger(__package__)
20 |
21 |
22 | def ddp_test_nerf(rank, args):
23 | ###### set up multi-processing
24 | setup(rank, args.world_size)
25 | ###### set up logger
26 | logger = logging.getLogger(__package__)
27 | setup_logger()
28 |
29 | ###### decide chunk size according to gpu memory
30 | if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
31 | logger.info('setting batch size according to 24G gpu')
32 | args.N_rand = 1024
33 | args.chunk_size = 8192
34 | else:
35 | logger.info('setting batch size according to 12G gpu')
36 | args.N_rand = 512
37 | args.chunk_size = 4096
38 |
39 | ###### create network and wrap in ddp; each process should do this
40 | start, models = create_nerf(rank, args)
41 |
42 | render_splits = [x.strip() for x in args.render_splits.strip().split(',')]
43 | # start testing
44 | for split in render_splits:
45 | out_dir = os.path.join(args.basedir, args.expname,
46 | 'render_{}_{:06d}'.format(split, start))
47 | if rank == 0:
48 | os.makedirs(out_dir, exist_ok=True)
49 |
50 | ###### load data and create ray samplers; each process should do this
51 | ray_samplers = load_data_split(args.datadir, args.scene, split, try_load_min_depth=args.load_min_depth)
52 | for idx in range(len(ray_samplers)):
53 | ### each process should do this; but only main process merges the results
54 | fname = '{:06d}.png'.format(idx)
55 | if ray_samplers[idx].img_path is not None:
56 | fname = os.path.basename(ray_samplers[idx].img_path)
57 |
58 | if os.path.isfile(os.path.join(out_dir, fname)):
59 | logger.info('Skipping {}'.format(fname))
60 | continue
61 |
62 | time0 = time.time()
63 | ret = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
64 | dt = time.time() - time0
65 | if rank == 0: # only main process should do this
66 | logger.info('Rendered {} in {} seconds'.format(fname, dt))
67 |
68 | # only save last level
69 | im = ret[-1]['rgb'].numpy()
70 | # compute psnr if ground-truth is available
71 | if ray_samplers[idx].img_path is not None:
72 | gt_im = ray_samplers[idx].get_img()
73 | psnr = mse2psnr(np.mean((gt_im - im) * (gt_im - im)))
74 | logger.info('{}: psnr={}'.format(fname, psnr))
75 |
76 | im = to8b(im)
77 | imageio.imwrite(os.path.join(out_dir, fname), im)
78 |
79 | im = ret[-1]['fg_rgb'].numpy()
80 | im = to8b(im)
81 | imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), im)
82 |
83 | im = ret[-1]['bg_rgb'].numpy()
84 | im = to8b(im)
85 | imageio.imwrite(os.path.join(out_dir, 'bg_' + fname), im)
86 |
87 | im = ret[-1]['fg_depth'].numpy()
88 | im = colorize_np(im, cmap_name='jet', append_cbar=True)
89 | im = to8b(im)
90 | imageio.imwrite(os.path.join(out_dir, 'fg_depth_' + fname), im)
91 |
92 | im = ret[-1]['bg_depth'].numpy()
93 | im = colorize_np(im, cmap_name='jet', append_cbar=True)
94 | im = to8b(im)
95 | imageio.imwrite(os.path.join(out_dir, 'bg_depth_' + fname), im)
96 |
97 | torch.cuda.empty_cache()
98 |
99 | # clean up for multi-processing
100 | cleanup()
101 |
102 |
103 | def test():
104 | parser = config_parser()
105 | args = parser.parse_args()
106 | logger.info(parser.format_values())
107 |
108 | if args.world_size == -1:
109 | args.world_size = torch.cuda.device_count()
110 | logger.info('Using # gpus: {}'.format(args.world_size))
111 | torch.multiprocessing.spawn(ddp_test_nerf,
112 | args=(args,),
113 | nprocs=args.world_size,
114 | join=True)
115 |
116 |
117 | if __name__ == '__main__':
118 | setup_logger()
119 | test()
120 |
121 |
--------------------------------------------------------------------------------
/ddp_train_nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim
4 | import torch.distributed
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 | import torch.multiprocessing
7 | import os
8 | from collections import OrderedDict
9 | from ddp_model import NerfNetWithAutoExpo
10 | import time
11 | from data_loader_split import load_data_split
12 | import numpy as np
13 | from tensorboardX import SummaryWriter
14 | from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
15 | import logging
16 | import json
17 |
18 |
19 | logger = logging.getLogger(__package__)
20 |
21 |
22 | def setup_logger():
23 | # create logger
24 | logger = logging.getLogger(__package__)
25 | # logger.setLevel(logging.DEBUG)
26 | logger.setLevel(logging.INFO)
27 |
28 | # create console handler and set level to debug
29 | ch = logging.StreamHandler()
30 | ch.setLevel(logging.DEBUG)
31 |
32 | # create formatter
33 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
34 |
35 | # add formatter to ch
36 | ch.setFormatter(formatter)
37 |
38 | # add ch to logger
39 | logger.addHandler(ch)
40 |
41 |
42 | def intersect_sphere(ray_o, ray_d):
43 | '''
44 | ray_o, ray_d: [..., 3]
45 | compute the depth of the intersection point between this ray and the unit sphere
46 | '''
47 | # note: d1 becomes negative if this mid point is behind camera
48 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
49 | p = ray_o + d1.unsqueeze(-1) * ray_d
50 | # consider the case where the ray does not intersect the sphere
51 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
52 | p_norm_sq = torch.sum(p * p, dim=-1)
53 | if (p_norm_sq >= 1.).any():
54 | raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')
55 | d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos
56 |
57 | return d1 + d2
58 |
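# Geometry note for intersect_sphere: writing points on the ray as o + t * d, the
# point closest to the origin is reached at t = d1 = -<o, d> / <d, d> and has squared
# distance p_norm_sq from the origin. The remaining half-chord inside the unit sphere
# has world-space length sqrt(1 - p_norm_sq); multiplying by ray_d_cos = 1 / |d|
# converts it back to ray-parameter units, so the far intersection lies at t = d1 + d2.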
59 |
60 | def perturb_samples(z_vals):
61 | # get intervals between samples
62 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
63 | upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
64 | lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
65 | # uniform samples in those intervals
66 | t_rand = torch.rand_like(z_vals)
67 | z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]
68 |
69 | return z_vals
70 |
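# perturb_samples is stratified sampling: each input depth is replaced by a uniform
# draw from the interval between its neighbouring midpoints (the first and last
# intervals are clamped to the original endpoints), so every stratum along the ray
# receives exactly one sample per call.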
71 |
72 | def sample_pdf(bins, weights, N_samples, det=False):
73 | '''
74 | :param bins: tensor of shape [..., M+1], M is the number of bins
75 | :param weights: tensor of shape [..., M]
76 | :param N_samples: number of samples along each ray
77 | :param det: if True, will perform deterministic sampling
78 | :return: [..., N_samples]
79 | '''
80 | # Get pdf
81 | weights = weights + TINY_NUMBER # prevent nans
82 | pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M]
83 | cdf = torch.cumsum(pdf, dim=-1) # [..., M]
84 | cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [..., M+1]
85 |
86 | # Take uniform samples
87 | dots_sh = list(weights.shape[:-1])
88 | M = weights.shape[-1]
89 |
90 | min_cdf = 0.00
91 | max_cdf = 1.00 # prevent outlier samples
92 |
93 | if det:
94 | u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
95 | u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,]) # [..., N_samples]
96 | else:
97 | sh = dots_sh + [N_samples]
98 | u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf # [..., N_samples]
99 |
100 | # Invert CDF
101 | # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
102 | above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
103 |
104 | # random sample inside each bin
105 | below_inds = torch.clamp(above_inds-1, min=0)
106 | inds_g = torch.stack((below_inds, above_inds), dim=-1) # [..., N_samples, 2]
107 |
108 | cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1]
109 | cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [..., N_samples, 2]
110 |
111 | bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1]
112 | bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [..., N_samples, 2]
113 |
114 | # fix numeric issue
115 | denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples]
116 | denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom)
117 | t = (u - cdf_g[..., 0]) / denom
118 |
119 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
120 |
121 | return samples
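# Illustrative usage of sample_pdf (a hypothetical sketch, not taken from the training
# loop below): resample 64 depths per ray from a piecewise-constant pdf over 63 bins.
#   bins = torch.linspace(0., 1., 64).expand(1024, 64)   # [N_rays, M+1]
#   weights = torch.rand(1024, 63)                        # [N_rays, M]
#   samples = sample_pdf(bins, weights, N_samples=64)     # [N_rays, 64]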
250 | writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step)
251 |
252 | rgb_im = img_HWC2CHW(log_data[m]['fg_rgb'])
253 | rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1
254 | writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step)
255 | depth = log_data[m]['fg_depth']
256 | depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True,
257 | mask=mask))
258 | writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step)
259 |
260 | rgb_im = img_HWC2CHW(log_data[m]['bg_rgb'])
261 | rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1
262 | writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step)
263 | depth = log_data[m]['bg_depth']
264 | depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True,
265 | mask=mask))
266 | writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step)
267 | bg_lambda = log_data[m]['bg_lambda']
268 | bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True,
269 | mask=mask))
270 | writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step)
271 |
272 |
273 | def setup(rank, world_size):
274 | os.environ['MASTER_ADDR'] = 'localhost'
275 | # port = np.random.randint(12355, 12399)
276 | # os.environ['MASTER_PORT'] = '{}'.format(port)
277 | os.environ['MASTER_PORT'] = '12355'
278 | # initialize the process group
279 | torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
280 |
281 |
282 | def cleanup():
283 | torch.distributed.destroy_process_group()
284 |
285 |
286 | def create_nerf(rank, args):
287 | ###### create network and wrap in ddp; each process should do this
288 | # fix random seed to make sure the network is initialized with the same weights in every process
289 | torch.manual_seed(777)
290 | # very important!!! otherwise it might introduce extra memory in rank=0 gpu
291 | torch.cuda.set_device(rank)
292 |
293 | models = OrderedDict()
294 | models['cascade_level'] = args.cascade_level
295 | models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
296 | for m in range(models['cascade_level']):
297 | img_names = None
298 | if args.optim_autoexpo:
299 | # load training image names for autoexposure
300 | f = os.path.join(args.basedir, args.expname, 'train_images.json')
301 | with open(f) as file:
302 | img_names = json.load(file)
303 | net = NerfNetWithAutoExpo(args, optim_autoexpo=args.optim_autoexpo, img_names=img_names).to(rank)
304 | net = DDP(net, device_ids=[rank], output_device=rank, find_unused_parameters=True)
305 | # net = DDP(net, device_ids=[rank], output_device=rank)
306 | optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
307 | models['net_{}'.format(m)] = net
308 | models['optim_{}'.format(m)] = optim
309 |
310 | start = -1
311 |
312 | ###### load pretrained weights; each process should do this
313 | if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
314 | ckpts = [args.ckpt_path]
315 | else:
316 | ckpts = [os.path.join(args.basedir, args.expname, f)
317 | for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]
318 | def path2iter(path):
319 | tmp = os.path.basename(path)[:-4]
320 | idx = tmp.rfind('_')
321 | return int(tmp[idx + 1:])
322 | ckpts = sorted(ckpts, key=path2iter)
323 | logger.info('Found ckpts: {}'.format(ckpts))
324 | if len(ckpts) > 0 and not args.no_reload:
325 | fpath = ckpts[-1]
326 | logger.info('Reloading from: {}'.format(fpath))
327 | start = path2iter(fpath)
328 | # configure map_location properly for different processes
329 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
330 | to_load = torch.load(fpath, map_location=map_location)
331 | for m in range(models['cascade_level']):
332 | for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
333 | models[name].load_state_dict(to_load[name])
334 |
335 | return start, models
336 |
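# The 'models' dict returned by create_nerf holds 'cascade_level' (int),
# 'cascade_samples' (list of ints), and one DDP-wrapped network plus Adam optimizer
# per cascade level under the keys 'net_0'/'optim_0', 'net_1'/'optim_1', ...;
# 'start' is -1 for a fresh run, or the iteration of the checkpoint that was reloaded.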
337 |
338 | def ddp_train_nerf(rank, args):
339 | ###### set up multi-processing
340 | setup(rank, args.world_size)
341 | ###### set up logger
342 | logger = logging.getLogger(__package__)
343 | setup_logger()
344 |
345 | ###### decide chunk size according to gpu memory
346 | logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory))
347 | if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
348 | logger.info('setting batch size according to 24G gpu')
349 | args.N_rand = 1024
350 | args.chunk_size = 8192
351 | else:
352 | logger.info('setting batch size according to 12G gpu')
353 | args.N_rand = 512
354 | args.chunk_size = 4096
355 |
356 | ###### Create log dir and copy the config file
357 | if rank == 0:
358 | os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
359 | f = os.path.join(args.basedir, args.expname, 'args.txt')
360 | with open(f, 'w') as file:
361 | for arg in sorted(vars(args)):
362 | attr = getattr(args, arg)
363 | file.write('{} = {}\n'.format(arg, attr))
364 | if args.config is not None:
365 | f = os.path.join(args.basedir, args.expname, 'config.txt')
366 | with open(f, 'w') as file:
367 | file.write(open(args.config, 'r').read())
368 | torch.distributed.barrier()
369 |
370 | ray_samplers = load_data_split(args.datadir, args.scene, split='train',
371 | try_load_min_depth=args.load_min_depth)
372 | val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation',
373 | try_load_min_depth=args.load_min_depth, skip=args.testskip)
374 |
375 | # write training image names for autoexposure
376 | if args.optim_autoexpo:
377 | f = os.path.join(args.basedir, args.expname, 'train_images.json')
378 | with open(f, 'w') as file:
379 | img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))]
380 | json.dump(img_names, file, indent=2)
381 |
382 | ###### create network and wrap in ddp; each process should do this
383 | start, models = create_nerf(rank, args)
384 |
385 | ##### important!!!
386 | # make sure different processes sample different rays
387 | np.random.seed((rank + 1) * 777)
388 | # make sure different processes have different perturbations in depth samples
389 | torch.manual_seed((rank + 1) * 777)
390 |
391 | ##### only main process should do the logging
392 | if rank == 0:
393 | writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))
394 |
395 | # start training
396 | what_val_to_log = 0  # helper variable for parallel rendering of an image
397 | what_train_to_log = 0
398 | for global_step in range(start+1, start+1+args.N_iters):
399 | time0 = time.time()
400 | scalars_to_log = OrderedDict()
401 | ### Start of core optimization loop
402 | scalars_to_log['resolution'] = ray_samplers[0].resolution_level
403 | # randomly sample rays and move to device
404 | i = np.random.randint(low=0, high=len(ray_samplers))
405 | ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
406 | for key in ray_batch:
407 | if torch.is_tensor(ray_batch[key]):
408 | ray_batch[key] = ray_batch[key].to(rank)
409 |
410 | # forward and backward
411 | dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays
412 | all_rets = [] # results on different cascade levels
413 | for m in range(models['cascade_level']):
414 | optim = models['optim_{}'.format(m)]
415 | net = models['net_{}'.format(m)]
416 |
417 | # sample depths
418 | N_samples = models['cascade_samples'][m]
419 | if m == 0:
420 | # foreground depth
421 | fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,]
422 | fg_near_depth = ray_batch['min_depth'] # [..., ]
423 | step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
424 | fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]
425 | fg_depth = perturb_samples(fg_depth) # random perturbation during training
426 |
427 | # background depth
428 | bg_depth = torch.linspace(0., 1., N_samples).view(
429 | [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
430 | bg_depth = perturb_samples(bg_depth) # random perturbation during training
431 | else:
432 | # sample pdf and concat with earlier samples
433 | fg_weights = ret['fg_weights'].clone().detach()
434 | fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1]
435 | fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
436 | fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
437 | N_samples=N_samples, det=False) # [..., N_samples]
438 | fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
439 |
440 | # sample pdf and concat with earlier samples
441 | bg_weights = ret['bg_weights'].clone().detach()
442 | bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
443 | bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
444 | bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
445 | N_samples=N_samples, det=False) # [..., N_samples]
446 | bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
447 |
448 | optim.zero_grad()
449 | ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch['img_name'])
450 | all_rets.append(ret)
451 |
452 | rgb_gt = ray_batch['rgb'].to(rank)
453 | if 'autoexpo' in ret:
454 | scale, shift = ret['autoexpo']
455 | scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item()
456 | scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item()
457 | # rgb_gt = scale * rgb_gt + shift
458 | rgb_pred = (ret['rgb'] - shift) / scale
459 | rgb_loss = img2mse(rgb_pred, rgb_gt)
460 | loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale-1.)+torch.abs(shift))
461 | else:
462 | rgb_loss = img2mse(ret['rgb'], rgb_gt)
463 | loss = rgb_loss
464 | scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item()
465 | scalars_to_log['level_{}/psnr'.format(m)] = mse2psnr(rgb_loss.item())
466 | loss.backward()
467 | optim.step()
468 |
469 | # # clean unused memory
470 | # torch.cuda.empty_cache()
471 |
472 | ### end of core optimization loop
473 | dt = time.time() - time0
474 | scalars_to_log['iter_time'] = dt
475 |
476 | ### only main process should do the logging
477 | if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
478 | logstr = '{} step: {} '.format(args.expname, global_step)
479 | for k in scalars_to_log:
480 | logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
481 | writer.add_scalar(k, scalars_to_log[k], global_step)
482 | logger.info(logstr)
483 |
484 | ### each process should do this; but only main process merges the results
485 | if global_step % args.i_img == 0 or global_step == start+1:
486 | #### critical: make sure each process is working on the same random image
487 | time0 = time.time()
488 | idx = what_val_to_log % len(val_ray_samplers)
489 | log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
490 | what_val_to_log += 1
491 | dt = time.time() - time0
492 | if rank == 0: # only main process should do this
493 | logger.info('Logged a random validation view in {} seconds'.format(dt))
494 | log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix='val/')
495 |
496 | time0 = time.time()
497 | idx = what_train_to_log % len(ray_samplers)
498 | log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
499 | what_train_to_log += 1
500 | dt = time.time() - time0
501 | if rank == 0: # only main process should do this
502 | logger.info('Logged a random training view in {} seconds'.format(dt))
503 | log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/')
504 |
505 | del log_data
506 | torch.cuda.empty_cache()
507 |
508 | if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
509 | # saving checkpoints and logging
510 | fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
511 | to_save = OrderedDict()
512 | for m in range(models['cascade_level']):
513 | name = 'net_{}'.format(m)
514 | to_save[name] = models[name].state_dict()
515 |
516 | name = 'optim_{}'.format(m)
517 | to_save[name] = models[name].state_dict()
518 | torch.save(to_save, fpath)
519 |
520 | # clean up for multi-processing
521 | cleanup()
522 |
523 |
524 | def config_parser():
525 | import configargparse
526 | parser = configargparse.ArgumentParser()
527 | parser.add_argument('--config', is_config_file=True, help='config file path')
528 | parser.add_argument("--expname", type=str, help='experiment name')
529 | parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
530 | # dataset options
531 | parser.add_argument("--datadir", type=str, default=None, help='input data directory')
532 | parser.add_argument("--scene", type=str, default=None, help='scene name')
533 | parser.add_argument("--testskip", type=int, default=8,
534 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
535 | # model size
536 | parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
537 | parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
538 | parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
539 | # checkpoints
540 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
541 | parser.add_argument("--ckpt_path", type=str, default=None,
542 | help='specific weights npy file to reload for coarse network')
543 | # batch size
544 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
545 | help='batch size (number of random rays per gradient step)')
546 | parser.add_argument("--chunk_size", type=int, default=1024 * 8,
547 | help='number of rays processed in parallel, decrease if running out of memory')
548 | # iterations
549 | parser.add_argument("--N_iters", type=int, default=250001,
550 | help='number of iterations')
551 | # render only
552 | parser.add_argument("--render_splits", type=str, default='test',
553 | help='splits to render')
554 | # cascade training
555 | parser.add_argument("--cascade_level", type=int, default=2,
556 | help='number of cascade levels')
557 | parser.add_argument("--cascade_samples", type=str, default='64,64',
558 | help='samples at each level')
559 | # multiprocess learning
560 | parser.add_argument("--world_size", type=int, default=-1,
561 | help='number of processes')
562 | # optimize autoexposure
563 | parser.add_argument("--optim_autoexpo", action='store_true',
564 | help='optimize autoexposure parameters')
565 | parser.add_argument("--lambda_autoexpo", type=float, default=1., help='regularization weight for autoexposure')
566 |
567 | # learning rate options
568 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
569 | parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
570 | help='decay learning rate by a factor every specified number of steps')
571 | parser.add_argument("--lrate_decay_steps", type=int, default=5000,
572 | help='decay learning rate by a factor every specified number of steps')
573 | # rendering options
574 | parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
575 | parser.add_argument("--max_freq_log2", type=int, default=10,
576 | help='log2 of max freq for positional encoding (3D location)')
577 | parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
578 | help='log2 of max freq for positional encoding (2D direction)')
579 | parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')
580 | # logging/saving options
581 | parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric logging')
582 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
583 | parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
584 |
585 | return parser
586 |
587 |
588 | def train():
589 | parser = config_parser()
590 | args = parser.parse_args()
591 | logger.info(parser.format_values())
592 |
593 | if args.world_size == -1:
594 | args.world_size = torch.cuda.device_count()
595 | logger.info('Using # gpus: {}'.format(args.world_size))
596 | torch.multiprocessing.spawn(ddp_train_nerf,
597 | args=(args,),
598 | nprocs=args.world_size,
599 | join=True)
600 |
601 |
602 | if __name__ == '__main__':
603 | setup_logger()
604 | train()
605 |
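# Example invocation (illustrative sketch; point --config at one of the shipped
# config files and make sure its datadir/scene entries match where the data lives):
#   python ddp_train_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt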
606 |
607 |
--------------------------------------------------------------------------------
/demo/tat_Playground.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/demo/tat_Playground.gif
--------------------------------------------------------------------------------
/demo/tat_Truck.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/demo/tat_Truck.gif
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: nerfplusplus
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - ca-certificates=2020.7.22=0
7 | - certifi=2020.6.20=py36_0
8 | - ld_impl_linux-64=2.33.1=h53a641e_7
9 | - libedit=3.1.20191231=h14c3975_1
10 | - libffi=3.3=he6710b0_2
11 | - libgcc-ng=9.1.0=hdf63c60_0
12 | - libstdcxx-ng=9.1.0=hdf63c60_0
13 | - ncurses=6.2=he6710b0_1
14 | - openssl=1.1.1g=h7b6447c_0
15 | - pip=20.2.2=py36_0
16 | - python=3.6.12=hcff3b4d_2
17 | - readline=8.0=h7b6447c_0
18 | - setuptools=49.6.0=py36_0
19 | - sqlite=3.33.0=h62c20be_0
20 | - tk=8.6.10=hbc83047_0
21 | - wheel=0.35.1=py_0
22 | - xz=5.2.5=h7b6447c_0
23 | - zlib=1.2.11=h7b6447c_3
24 | - pip:
25 | - absl-py==0.10.0
26 | - astunparse==1.6.3
27 | - backcall==0.2.0
28 | - cachetools==4.1.1
29 | - chardet==3.0.4
30 | - configargparse==1.2.3
31 | - cycler==0.10.0
32 | - decorator==4.4.2
33 | - future==0.18.2
34 | - gast==0.3.3
35 | - google-auth==1.21.2
36 | - google-auth-oauthlib==0.4.1
37 | - google-pasta==0.2.0
38 | - grpcio==1.32.0
39 | - h5py==2.10.0
40 | - idna==2.10
41 | - imageio==2.9.0
42 | - importlib-metadata==1.7.0
43 | - ipython==7.16.1
44 | - ipython-genutils==0.2.0
45 | - jedi==0.17.2
46 | - keras-preprocessing==1.1.2
47 | - kiwisolver==1.2.0
48 | - lpips==0.1.1
49 | - markdown==3.2.2
50 | - matplotlib==3.3.2
51 | - networkx==2.5
52 | - numpy==1.18.0
53 | - oauthlib==3.1.0
54 | - opencv-python==4.4.0.42
55 | - opt-einsum==3.3.0
56 | - parso==0.7.1
57 | - pexpect==4.8.0
58 | - pickleshare==0.7.5
59 | - pillow==7.2.0
60 | - prompt-toolkit==3.0.7
61 | - protobuf==3.13.0
62 | - ptyprocess==0.6.0
63 | - pyasn1==0.4.8
64 | - pyasn1-modules==0.2.8
65 | - pygments==2.7.1
66 | - pymcubes==0.1.2
67 | - pyparsing==2.4.7
68 | - python-dateutil==2.8.1
69 | - pywavelets==1.1.1
70 | - requests==2.24.0
71 | - requests-oauthlib==1.3.0
72 | - rsa==4.6
73 | - scikit-image==0.17.2
74 | - scipy==1.4.1
75 | - six==1.15.0
76 | - tensorboard==2.3.0
77 | - tensorboard-plugin-wit==1.7.0
78 | - tensorboardx==2.1
79 | - tensorflow==2.3.0
80 | - tensorflow-estimator==2.3.0
81 | - termcolor==1.1.0
82 | - tifffile==2020.9.3
83 | - torch==1.6.0
84 | - torchvision==0.7.0
85 | - tqdm==4.49.0
86 | - traitlets==4.3.3
87 | - trimesh==3.8.10
88 | - urllib3==1.25.10
89 | - wcwidth==0.2.5
90 | - werkzeug==1.0.1
91 | - wrapt==1.12.1
92 | - zipp==3.1.0
93 | prefix: /home/zhangka2/anaconda3/envs/nerfplusplus
94 |
95 |
--------------------------------------------------------------------------------
/nerf_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # import torch.nn.functional as F
4 | # import numpy as np
5 | from collections import OrderedDict
6 |
7 | import logging
8 | logger = logging.getLogger(__package__)
9 |
10 |
11 | class Embedder(nn.Module):
12 | def __init__(self, input_dim, max_freq_log2, N_freqs,
13 | log_sampling=True, include_input=True,
14 | periodic_fns=(torch.sin, torch.cos)):
15 | '''
16 | :param input_dim: dimension of input to be embedded
17 | :param max_freq_log2: log2 of max freq; min freq is 1 by default
18 | :param N_freqs: number of frequency bands
19 | :param log_sampling: if True, frequency bands are linearly sampled in log-space
20 | :param include_input: if True, raw input is included in the embedding
21 | :param periodic_fns: periodic functions used to embed input
22 | '''
23 | super().__init__()
24 |
25 | self.input_dim = input_dim
26 | self.include_input = include_input
27 | self.periodic_fns = periodic_fns
28 |
29 | self.out_dim = 0
30 | if self.include_input:
31 | self.out_dim += self.input_dim
32 |
33 | self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
34 |
35 | if log_sampling:
36 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
37 | else:
38 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
39 |
40 | self.freq_bands = self.freq_bands.numpy().tolist()
41 |
42 | def forward(self, input):
43 | '''
44 | :param input: tensor of shape [..., self.input_dim]
45 | :return: tensor of shape [..., self.out_dim]
46 | '''
47 | assert (input.shape[-1] == self.input_dim)
48 |
49 | out = []
50 | if self.include_input:
51 | out.append(input)
52 |
53 | for i in range(len(self.freq_bands)):
54 | freq = self.freq_bands[i]
55 | for p_fn in self.periodic_fns:
56 | out.append(p_fn(input * freq))
57 | out = torch.cat(out, dim=-1)
58 |
59 | assert (out.shape[-1] == self.out_dim)
60 | return out
61 |
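# Worked example for Embedder: with input_dim=3, max_freq_log2=9, N_freqs=10,
# include_input=True and the default (sin, cos) periodic_fns,
# out_dim = 3 + 3 * 10 * 2 = 63, i.e. the positional encoding size used for 3D
# locations in the original NeRF.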
62 | # default tensorflow initialization of linear layers
63 | def weights_init(m):
64 | if isinstance(m, nn.Linear):
65 | nn.init.xavier_uniform_(m.weight.data)
66 | if m.bias is not None:
67 | nn.init.zeros_(m.bias.data)
68 |
69 |
70 | class MLPNet(nn.Module):
71 | def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3,
72 | skips=[4], use_viewdirs=False):
73 | '''
74 | :param D: network depth
75 | :param W: network width
76 | :param input_ch: input channels for encodings of (x, y, z)
77 | :param input_ch_viewdirs: input channels for encodings of view directions
78 | :param skips: skip connection in network
79 | :param use_viewdirs: if True, will use the view directions as input
80 | '''
81 | super().__init__()
82 |
83 | self.input_ch = input_ch
84 | self.input_ch_viewdirs = input_ch_viewdirs
85 | self.use_viewdirs = use_viewdirs
86 | self.skips = skips
87 |
88 | self.base_layers = []
89 | dim = self.input_ch
90 | for i in range(D):
91 | self.base_layers.append(
92 | nn.Sequential(nn.Linear(dim, W), nn.ReLU())
93 | )
94 | dim = W
95 | if i in self.skips and i != (D-1): # skip connection after i^th layer
96 | dim += input_ch
97 | self.base_layers = nn.ModuleList(self.base_layers)
98 | # self.base_layers.apply(weights_init) # xavier init
99 |
100 | sigma_layers = [nn.Linear(dim, 1), ] # sigma must be positive
101 | self.sigma_layers = nn.Sequential(*sigma_layers)
102 | # self.sigma_layers.apply(weights_init) # xavier init
103 |
104 | # rgb color
105 | rgb_layers = []
106 | base_remap_layers = [nn.Linear(dim, 256), ]
107 | self.base_remap_layers = nn.Sequential(*base_remap_layers)
108 | # self.base_remap_layers.apply(weights_init)
109 |
110 | dim = 256 + self.input_ch_viewdirs
111 | for i in range(1):
112 | rgb_layers.append(nn.Linear(dim, W // 2))
113 | rgb_layers.append(nn.ReLU())
114 | dim = W // 2
115 | rgb_layers.append(nn.Linear(dim, 3))
116 | rgb_layers.append(nn.Sigmoid()) # rgb values are normalized to [0, 1]
117 | self.rgb_layers = nn.Sequential(*rgb_layers)
118 | # self.rgb_layers.apply(weights_init)
119 |
120 | def forward(self, input):
121 | '''
122 | :param input: [..., input_ch+input_ch_viewdirs]
123 | :return [..., 4]
124 | '''
125 | input_pts = input[..., :self.input_ch]
126 |
127 | base = self.base_layers[0](input_pts)
128 | for i in range(len(self.base_layers)-1):
129 | if i in self.skips:
130 | base = torch.cat((input_pts, base), dim=-1)
131 | base = self.base_layers[i+1](base)
132 |
133 | sigma = self.sigma_layers(base)
134 | sigma = torch.abs(sigma)
135 |
136 | base_remap = self.base_remap_layers(base)
137 | input_viewdirs = input[..., -self.input_ch_viewdirs:]
138 | rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))
139 |
140 | ret = OrderedDict([('rgb', rgb),
141 | ('sigma', sigma.squeeze(-1))])
142 | return ret
143 |
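# Minimal wiring sketch (hypothetical, for illustration; the real composition of these
# modules lives in ddp_model.py):
#   embed_pos = Embedder(input_dim=3, max_freq_log2=9, N_freqs=10)
#   embed_dir = Embedder(input_dim=3, max_freq_log2=3, N_freqs=4)
#   net = MLPNet(D=8, W=256, input_ch=embed_pos.out_dim,
#                input_ch_viewdirs=embed_dir.out_dim, use_viewdirs=True)
#   out = net(torch.cat((embed_pos(pts), embed_dir(dirs)), dim=-1))  # keys: 'rgb', 'sigma'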
--------------------------------------------------------------------------------
/nerf_sample_ray_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import OrderedDict
3 | import torch
4 | import cv2
5 | import imageio
6 |
7 | ########################################################################################################################
8 | # ray batch sampling
9 | ########################################################################################################################
10 | def get_rays_single_image(H, W, intrinsics, c2w):
11 | '''
12 | :param H: image height
13 | :param W: image width
14 | :param intrinsics: 4 by 4 intrinsic matrix
15 | :param c2w: 4 by 4 camera to world extrinsic matrix
16 | :return:
17 | '''
18 | u, v = np.meshgrid(np.arange(W), np.arange(H))
19 |
20 | u = u.reshape(-1).astype(dtype=np.float32) + 0.5 # add half pixel
21 | v = v.reshape(-1).astype(dtype=np.float32) + 0.5
22 | pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W)
23 |
24 | rays_d = np.dot(np.linalg.inv(intrinsics[:3, :3]), pixels)
25 | rays_d = np.dot(c2w[:3, :3], rays_d) # (3, H*W)
26 | rays_d = rays_d.transpose((1, 0)) # (H*W, 3)
27 |
28 | rays_o = c2w[:3, 3].reshape((1, 3))
29 | rays_o = np.tile(rays_o, (rays_d.shape[0], 1)) # (H*W, 3)
30 |
31 | depth = np.linalg.inv(c2w)[2, 3]
32 | depth = depth * np.ones((rays_o.shape[0],), dtype=np.float32) # (H*W,)
33 |
34 | return rays_o, rays_d, depth
35 |
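# Note on the math above: homogeneous pixel coordinates (u, v, 1) are back-projected to
# camera-space directions via the inverse intrinsics and rotated into world space by the
# rotation block of c2w; all rays share the camera center c2w[:3, 3] as origin. 'depth'
# is inv(c2w)[2, 3], i.e. the z-coordinate of the world origin expressed in the camera
# frame, returned as a per-ray constant.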
36 |
37 | class RaySamplerSingleImage(object):
38 | def __init__(self, H, W, intrinsics, c2w,
39 | img_path=None,
40 | resolution_level=1,
41 | mask_path=None,
42 | min_depth_path=None,
43 | max_depth=None):
44 | super().__init__()
45 | self.W_orig = W
46 | self.H_orig = H
47 | self.intrinsics_orig = intrinsics
48 | self.c2w_mat = c2w
49 |
50 | self.img_path = img_path
51 | self.mask_path = mask_path
52 | self.min_depth_path = min_depth_path
53 | self.max_depth = max_depth
54 |
55 | self.resolution_level = -1
56 | self.set_resolution_level(resolution_level)
57 |
58 | def set_resolution_level(self, resolution_level):
59 | if resolution_level != self.resolution_level:
60 | self.resolution_level = resolution_level
61 | self.W = self.W_orig // resolution_level
62 | self.H = self.H_orig // resolution_level
63 | self.intrinsics = np.copy(self.intrinsics_orig)
64 | self.intrinsics[:2, :3] /= resolution_level
65 | # only load image at this time
66 | if self.img_path is not None:
67 | self.img = imageio.imread(self.img_path).astype(np.float32) / 255.
68 | self.img = cv2.resize(self.img, (self.W, self.H), interpolation=cv2.INTER_AREA)
69 | self.img = self.img.reshape((-1, 3))
70 | else:
71 | self.img = None
72 |
73 | if self.mask_path is not None:
74 | self.mask = imageio.imread(self.mask_path).astype(np.float32) / 255.
75 | self.mask = cv2.resize(self.mask, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
76 | self.mask = self.mask.reshape((-1))
77 | else:
78 | self.mask = None
79 |
80 | if self.min_depth_path is not None:
81 | self.min_depth = imageio.imread(self.min_depth_path).astype(np.float32) / 255. * self.max_depth + 1e-4
82 | self.min_depth = cv2.resize(self.min_depth, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
83 | self.min_depth = self.min_depth.reshape((-1))
84 | else:
85 | self.min_depth = None
86 |
87 | self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W,
88 | self.intrinsics, self.c2w_mat)
89 |
90 | def get_img(self):
91 | if self.img is not None:
92 | return self.img.reshape((self.H, self.W, 3))
93 | else:
94 | return None
95 |
96 | def get_all(self):
97 | if self.min_depth is not None:
98 | min_depth = self.min_depth
99 | else:
100 | min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0])
101 |
102 | ret = OrderedDict([
103 | ('ray_o', self.rays_o),
104 | ('ray_d', self.rays_d),
105 | ('depth', self.depth),
106 | ('rgb', self.img),
107 | ('mask', self.mask),
108 | ('min_depth', min_depth)
109 | ])
110 | # return torch tensors
111 | for k in ret:
112 | if ret[k] is not None:
113 | ret[k] = torch.from_numpy(ret[k])
114 | return ret
115 |
116 | def random_sample(self, N_rand, center_crop=False):
117 | '''
118 | :param N_rand: number of rays to be cast
119 | :return:
120 | '''
121 | if center_crop:
122 | half_H = self.H // 2
123 | half_W = self.W // 2
124 | quad_H = half_H // 2
125 | quad_W = half_W // 2
126 |
127 | # pixel coordinates
128 | u, v = np.meshgrid(np.arange(half_W-quad_W, half_W+quad_W),
129 | np.arange(half_H-quad_H, half_H+quad_H))
130 | u = u.reshape(-1)
131 | v = v.reshape(-1)
132 |
133 | select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False)
134 |
135 | # Convert back to original image
136 | select_inds = v[select_inds] * self.W + u[select_inds]
137 | else:
138 | # Random from one image
139 | select_inds = np.random.choice(self.H*self.W, size=(N_rand,), replace=False)
140 |
141 | rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
142 | rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
143 | depth = self.depth[select_inds] # [N_rand, ]
144 |
145 | if self.img is not None:
146 | rgb = self.img[select_inds, :] # [N_rand, 3]
147 | else:
148 | rgb = None
149 |
150 | if self.mask is not None:
151 | mask = self.mask[select_inds]
152 | else:
153 | mask = None
154 |
155 | if self.min_depth is not None:
156 | min_depth = self.min_depth[select_inds]
157 | else:
158 | min_depth = 1e-4 * np.ones_like(rays_d[..., 0])
159 |
160 | ret = OrderedDict([
161 | ('ray_o', rays_o),
162 | ('ray_d', rays_d),
163 | ('depth', depth),
164 | ('rgb', rgb),
165 | ('mask', mask),
166 | ('min_depth', min_depth),
167 | ('img_name', self.img_path)
168 | ])
169 | # return torch tensors
170 | for k in ret:
171 | if isinstance(ret[k], np.ndarray):
172 | ret[k] = torch.from_numpy(ret[k])
173 |
174 | return ret
175 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import torch.nn as nn
3 | # import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | HUGE_NUMBER = 1e10
8 | TINY_NUMBER = 1e-6      # float32 only has about 7 decimal digits of precision
9 |
10 |
11 | # misc utils
12 | def img2mse(x, y, mask=None):
13 | if mask is None:
14 | return torch.mean((x - y) * (x - y))
15 | else:
16 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER)
17 |
18 | img_HWC2CHW = lambda x: x.permute(2, 0, 1)
19 | gray2rgb = lambda x: x.unsqueeze(2).repeat(1, 1, 3)
20 |
21 |
22 | def normalize(x):
23 | min = x.min()
24 | max = x.max()
25 |
26 | return (x - min) / ((max - min) + TINY_NUMBER)
27 |
28 |
29 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
30 | # gray2rgb = lambda x: np.tile(x[:,:,np.newaxis], (1, 1, 3))
31 | mse2psnr = lambda x: -10. * np.log(x+TINY_NUMBER) / np.log(10.)
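# e.g. mse2psnr(0.01) ~= -10 * log10(0.01 + 1e-6) = 20 dB; the TINY_NUMBER term only
# guards against log(0) and barely shifts the result.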
32 |
33 |
34 | ########################################################################################################################
35 | #
36 | ########################################################################################################################
37 | from matplotlib.backends.backend_agg import FigureCanvasAgg
38 | from matplotlib.figure import Figure
39 | import matplotlib as mpl
40 | from matplotlib import cm
41 | import cv2
42 |
43 |
44 | def get_vertical_colorbar(h, vmin, vmax, cmap_name='jet', label=None):
45 | fig = Figure(figsize=(1.2, 8), dpi=100)
46 | fig.subplots_adjust(right=1.5)
47 | canvas = FigureCanvasAgg(fig)
48 |
49 | # Do some plotting.
50 | ax = fig.add_subplot(111)
51 | cmap = cm.get_cmap(cmap_name)
52 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
53 |
54 | tick_cnt = 6
55 | tick_loc = np.linspace(vmin, vmax, tick_cnt)
56 | cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap,
57 | norm=norm,
58 | ticks=tick_loc,
59 | orientation='vertical')
60 |
61 | tick_label = ['{:3.2f}'.format(x) for x in tick_loc]
62 | cb1.set_ticklabels(tick_label)
63 |
64 | cb1.ax.tick_params(labelsize=18, rotation=0)
65 |
66 | if label is not None:
67 | cb1.set_label(label)
68 |
69 | fig.tight_layout()
70 |
71 | canvas.draw()
72 | s, (width, height) = canvas.print_to_buffer()
73 |
74 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
75 |
76 | im = im[:, :, :3].astype(np.float32) / 255.
77 | if h != im.shape[0]:
78 | w = int(im.shape[1] / im.shape[0] * h)
79 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
80 |
81 | return im
82 |
83 |
84 | def colorize_np(x, cmap_name='jet', mask=None, append_cbar=False):
85 | if mask is not None:
86 | # vmin, vmax = np.percentile(x[mask], (1, 99))
87 | vmin = np.min(x[mask])
88 | vmax = np.max(x[mask])
89 | vmin = vmin - np.abs(vmin) * 0.01
90 | x[np.logical_not(mask)] = vmin
91 | x = np.clip(x, vmin, vmax)
92 | # print(vmin, vmax)
93 | else:
94 | vmin = x.min()
95 | vmax = x.max() + TINY_NUMBER
96 |
97 | x = (x - vmin) / (vmax - vmin)
98 | # x = np.clip(x, 0., 1.)
99 |
100 | cmap = cm.get_cmap(cmap_name)
101 | x_new = cmap(x)[:, :, :3]
102 |
103 | if mask is not None:
104 | mask = np.float32(mask[:, :, np.newaxis])
105 | x_new = x_new * mask + np.zeros_like(x_new) * (1. - mask)
106 |
107 | cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name)
108 |
109 | if append_cbar:
110 | x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
111 | return x_new
112 | else:
113 | return x_new, cbar
114 |
115 |
116 | # tensor
117 | def colorize(x, cmap_name='jet', append_cbar=False, mask=None):
118 | x = x.numpy()
119 | if mask is not None:
120 | mask = mask.numpy().astype(dtype=np.bool)
121 | x, cbar = colorize_np(x, cmap_name, mask)
122 |
123 | if append_cbar:
124 | x = np.concatenate((x, np.zeros_like(x[:, :5, :]), cbar), axis=1)
125 |
126 | x = torch.from_numpy(x)
127 | return x
128 |
--------------------------------------------------------------------------------