├── .gitignore ├── LICENSE ├── README.md ├── camera_inspector ├── inspect_epipolar_geometry.py ├── screenshot_lowres.png └── train │ ├── cam_dict_norm.json │ └── rgb │ ├── 000001.png │ └── 000005.png ├── camera_visualizer ├── camera_path │ └── cam_dict_norm.json ├── mesh_norm.ply ├── screenshot_lowres.png ├── test │ └── cam_dict_norm.json ├── train │ └── cam_dict_norm.json └── visualize_cameras.py ├── colmap_runner ├── database.py ├── extract_sfm.py ├── normalize_cam_dict.py ├── read_write_model.py ├── run_colmap.py └── run_colmap_posed.py ├── configs ├── lf_data │ └── lf_africa.txt └── tanks_and_temples │ └── tat_training_truck.txt ├── data_loader_split.py ├── ddp_model.py ├── ddp_test_nerf.py ├── ddp_train_nerf.py ├── demo ├── tat_Playground.gif └── tat_Truck.gif ├── environment.yml ├── nerf_network.py ├── nerf_sample_ray_split.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # scripts 2 | *.sh 3 | 4 | # mac 5 | .DS_Store 6 | 7 | # pycharm 8 | .idea/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, the NeRF++ authors 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeRF++ 2 | Codebase for arXiv preprint ["NeRF++: Analyzing and Improving Neural Radiance Fields"](http://arxiv.org/abs/2010.07492) 3 | * Work with 360 capture of large-scale unbounded scenes. 4 | * Support multi-gpu training and inference with PyTorch DistributedDataParallel (DDP). 5 | * Optimize per-image autoexposure (**experimental feature**). 6 | 7 | ## Demo 8 | ![](demo/tat_Truck.gif) ![](demo/tat_Playground.gif) 9 | 10 | ## Data 11 | * Download our preprocessed data from [tanks_and_temples](https://drive.google.com/file/d/11KRfN91W1AxAW6lOFs4EeYDbeoQZCi87/view?usp=sharing), [lf_data](https://drive.google.com/file/d/1gsjDjkbTh4GAR9fFqlIDZ__qR9NYTURQ/view?usp=sharing). 12 | * Put the data in the sub-folder data/ of this code directory. 13 | * Data format. 14 | * Each scene consists of 3 splits: train/test/validation. 15 | * Intrinsics and poses are stored as flattened 4x4 matrices (row-major). 16 | * Pixel coordinate of an image's upper-left corner is (column, row)=(0, 0), lower-right corner is (width-1, height-1). 
17 | * Poses are camera-to-world transformations, not world-to-camera. 18 | * The OpenCV camera coordinate system is adopted, i.e., x--->right, y--->down, z--->into the scene. The intrinsic matrix also follows the OpenCV convention (see the loading snippet at the end of this section). 19 | * To convert camera poses between the OpenCV and OpenGL conventions, the following code snippet can be used for both OpenGL-to-OpenCV and OpenCV-to-OpenGL conversions, since the flip matrix is its own inverse. 20 | ```python 21 | import numpy as np 22 | def convert_pose(C2W): 23 | flip_yz = np.eye(4) 24 | flip_yz[1, 1] = -1 25 | flip_yz[2, 2] = -1 26 | C2W = np.matmul(C2W, flip_yz) 27 | return C2W 28 | ``` 29 | * Scene normalization: move the average camera center to the origin, and scale the scene so that all camera centers lie inside the unit sphere. Check [normalize_cam_dict.py](https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/normalize_cam_dict.py) for details. 30 |
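The snippet below is a minimal, illustrative sketch (not part of the codebase) of how an entry of a `cam_dict_norm.json` file can be turned into a camera center and a per-pixel viewing ray under the conventions above; the helper names and the example path are hypothetical.
```python
import json
import numpy as np

def load_camera(cam_dict, img_name):
    # 'K' and 'W2C' are stored as flattened 4x4 row-major matrices (see the bullets above).
    K = np.array(cam_dict[img_name]['K']).reshape((4, 4))
    W2C = np.array(cam_dict[img_name]['W2C']).reshape((4, 4))
    C2W = np.linalg.inv(W2C)  # camera-to-world pose
    return K, W2C, C2W

def pixel_to_ray(K, C2W, col, row):
    # OpenCV convention: x--->right, y--->down, z--->into the scene;
    # pixel (col, row) = (0, 0) is the upper-left corner of the image.
    dir_cam = np.linalg.inv(K[:3, :3]) @ np.array([col, row, 1.])
    ray_o = C2W[:3, 3]               # camera center in world coordinates
    ray_d = C2W[:3, :3] @ dir_cam    # ray direction in world coordinates
    return ray_o, ray_d / np.linalg.norm(ray_d)

# example path is hypothetical; point it at one of the downloaded scenes
cam_dict = json.load(open('data/some_scene/train/cam_dict_norm.json'))
img_name = sorted(cam_dict.keys())[0]
K, W2C, C2W = load_camera(cam_dict, img_name)
ray_o, ray_d = pixel_to_ray(K, C2W, col=0, row=0)
```
If the cameras have been normalized correctly, all camera centers `C2W[:3, 3]` should lie inside the unit sphere (see the camera visualizer below).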
31 | ## Create environment 32 | ```bash 33 | conda env create --file environment.yml 34 | conda activate nerfplusplus 35 | ``` 36 | 37 | ## Training (uses all available GPUs by default) 38 | ```bash 39 | python ddp_train_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt 40 | ``` 41 | 42 | **Note: In the paper, we train NeRF++ on a node with 4 RTX 2080 Ti GPUs, which took ∼24 hours.** 43 | 44 | ## Testing (uses all available GPUs by default) 45 | ```bash 46 | python ddp_test_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt \ 47 | --render_splits test,camera_path 48 | ``` 49 | 50 | **Note**: due to a restriction imposed by the torch.distributed.gather function, please make sure the number of pixels in each image is divisible by the number of GPUs if you render images in parallel. 51 | 52 | ## Pretrained weights 53 | I recently re-trained NeRF++ on the Tanks and Temples data for another project. Here are the checkpoints ([google drive](https://drive.google.com/drive/folders/15P0vCeDiLULzktvFByE5mb6fucJomvHr?usp=sharing)) in case you find them useful. 54 | 55 | ## Citation 56 | Please cite our work if you use the code. 57 | ```bibtex 58 | @article{kaizhang2020, 59 | author = {Kai Zhang and Gernot Riegler and Noah Snavely and Vladlen Koltun}, 60 | title = {NeRF++: Analyzing and Improving Neural Radiance Fields}, 61 | journal = {arXiv:2010.07492}, 62 | year = {2020}, 63 | } 64 | ``` 65 | 66 | ## Generate camera parameters (intrinsics and poses) with [COLMAP SfM](https://colmap.github.io/) 67 | You can use the scripts inside `colmap_runner` to generate camera parameters from images with COLMAP SfM. 68 | * Specify `img_dir` and `out_dir` in `colmap_runner/run_colmap.py`. 69 | * Inside `colmap_runner/`, execute the command `python run_colmap.py`. 70 | * After the program finishes, you will find the posed images in the folder `out_dir/posed_images`. 71 | * Distortion-free images are inside `out_dir/posed_images/images`. 72 | * Raw COLMAP intrinsics and poses are stored as a json file `out_dir/posed_images/kai_cameras.json`. 73 | * Normalized cameras are stored in `out_dir/posed_images/kai_cameras_normalized.json`. See **Scene normalization** in the **Data** section. 74 | * Split the distortion-free images and `kai_cameras_normalized.json` according to your needs. You might find the self-explanatory script `data_loader_split.py` helpful when converting the json file to the data format expected by NeRF++. 75 | 76 | ## Visualize cameras in 3D 77 | Check `camera_visualizer/visualize_cameras.py` for visualizing cameras in 3D. It creates an interactive viewer that lets you inspect whether your cameras have been normalized to be compatible with this codebase. Below is a screenshot of the viewer: green cameras are used for training, blue ones are for testing, and yellow ones denote a novel camera path to be synthesized; the red sphere is the unit sphere. 78 | 79 | 80 | 81 | 82 | 83 | ## Inspect camera parameters 84 | You can use `camera_inspector/inspect_epipolar_geometry.py` to check whether the camera parameters are correct and follow the OpenCV convention assumed by this codebase. The script creates a viewer for visually inspecting two-view epipolar geometry, as shown below: for keypoints in the left image, it plots their corresponding epipolar lines in the right image. If the epipolar geometry does not look correct in this visualization, there are likely issues with the camera parameters. 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /camera_inspector/inspect_epipolar_geometry.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import json 5 | 6 | 7 | def skew(x): 8 | return np.array([[0, -x[2], x[1]], 9 | [x[2], 0, -x[0]], 10 | [-x[1], x[0], 0]]) 11 | 12 | 13 | def two_view_geometry(intrinsics1, extrinsics1, intrinsics2, extrinsics2): 14 | relative_pose = extrinsics2.dot(np.linalg.inv(extrinsics1)) 15 | R = relative_pose[:3, :3] 16 | T = relative_pose[:3, 3] 17 | tx = skew(T) 18 | E = np.dot(tx, R) 19 | F = np.linalg.inv(intrinsics2[:3, :3]).T.dot(E).dot(np.linalg.inv(intrinsics1[:3, :3])) 20 | 21 | return E, F, relative_pose 22 | 23 | 24 | def drawpointslines(img1, pts1, img2, lines2, colors): 25 | h, w = img2.shape[:2] 26 | # img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR) 27 | # img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR) 28 | print(pts1.shape, lines2.shape, colors.shape) 29 | for p, l, c in zip(pts1, lines2, colors): 30 | c = tuple(c.tolist()) 31 | img1 = cv2.circle(img1, tuple(p), 5, c, -1) 32 | 33 | x0, y0 = map(int, [0, -l[2]/l[1]]) 34 | x1, y1 = map(int, [w, -(l[2]+l[0]*w)/l[1]]) 35 | img2 = cv2.line(img2, (x0, y0), (x1, y1), c, 1, lineType=cv2.LINE_AA) 36 | return img1, img2 37 | 38 | 39 | def inspect(img1, K1, W2C1, img2, K2, W2C2): 40 | E, F, relative_pose = two_view_geometry(K1, W2C1, K2, W2C2) 41 | 42 | orb = cv2.ORB_create() 43 | kp1 = orb.detect(cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY), None)[:20] 44 | pts1 = np.array([[int(kp.pt[0]), int(kp.pt[1])] for kp in kp1]) 45 | 46 | lines2 = cv2.computeCorrespondEpilines(pts1.reshape(-1,1,2), 1, F) 47 | lines2 = lines2.reshape(-1, 3) 48 | 49 | colors = np.random.randint(0, high=255, size=(len(pts1), 3)) 50 | 51 | img1, img2 = drawpointslines(img1, pts1, img2, lines2, colors) 52 | 53 | im_to_show = np.concatenate((img1, img2), axis=1) 54 | # down sample to fit screen 55 | h, w = im_to_show.shape[:2] 56 | im_to_show = cv2.resize(im_to_show, (int(0.5*w), int(0.5*h)), interpolation=cv2.INTER_AREA) 57 | cv2.imshow('epipolar geometry', im_to_show) 58 | cv2.waitKey(0) 59 | cv2.destroyAllWindows() 60 | 61 | 62 | if __name__ == '__main__': 63 | base_dir = './' 64 | 65 | img_dir = os.path.join(base_dir, 'train/rgb') 66 | cam_dict_file = os.path.join(base_dir, 'train/cam_dict_norm.json') 67 | img_name1 = '000001.png' 68 | img_name2 = '000005.png' 69 | 70 | cam_dict = json.load(open(cam_dict_file)) 71 | img1 = cv2.imread(os.path.join(img_dir, img_name1)) 72 | K1 = np.array(cam_dict[img_name1]['K']).reshape((4,
4)) 73 | W2C1 = np.array(cam_dict[img_name1]['W2C']).reshape((4, 4)) 74 | 75 | img2 = cv2.imread(os.path.join(img_dir, img_name2)) 76 | K2 = np.array(cam_dict[img_name2]['K']).reshape((4, 4)) 77 | W2C2 = np.array(cam_dict[img_name2]['W2C']).reshape((4, 4)) 78 | 79 | inspect(img1, K1, W2C1, img2, K2, W2C2) 80 | 81 | -------------------------------------------------------------------------------- /camera_inspector/screenshot_lowres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/screenshot_lowres.png -------------------------------------------------------------------------------- /camera_inspector/train/rgb/000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/train/rgb/000001.png -------------------------------------------------------------------------------- /camera_inspector/train/rgb/000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_inspector/train/rgb/000005.png -------------------------------------------------------------------------------- /camera_visualizer/mesh_norm.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_visualizer/mesh_norm.ply -------------------------------------------------------------------------------- /camera_visualizer/screenshot_lowres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/camera_visualizer/screenshot_lowres.png -------------------------------------------------------------------------------- /camera_visualizer/test/cam_dict_norm.json: -------------------------------------------------------------------------------- 1 | { 2 | "000173.png": { 3 | "K": [ 4 | 581.7877197265625, 5 | 0.0, 6 | 490.25, 7 | 0.0, 8 | 0.0, 9 | 581.7877197265625, 10 | 272.75, 11 | 0.0, 12 | 0.0, 13 | 0.0, 14 | 1.0, 15 | 0.0, 16 | 0.0, 17 | 0.0, 18 | 0.0, 19 | 1.0 20 | ], 21 | "W2C": [ 22 | -0.6749785542488097, 23 | -0.10284826159477237, 24 | 0.730634093284607, 25 | -0.1415396263044849, 26 | 0.1326861083507538, 27 | 0.9571743607521058, 28 | 0.2573162019252777, 29 | -0.11108331785106688, 30 | -0.7258087396621704, 31 | 0.2706279158592224, 32 | -0.6324256062507629, 33 | 0.5185927744415963, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 1.0 38 | ], 39 | "img_size": [ 40 | 980, 41 | 546 42 | ] 43 | }, 44 | "000174.png": { 45 | "K": [ 46 | 581.7877197265625, 47 | 0.0, 48 | 490.25, 49 | 0.0, 50 | 0.0, 51 | 581.7877197265625, 52 | 272.75, 53 | 0.0, 54 | 0.0, 55 | 0.0, 56 | 1.0, 57 | 0.0, 58 | 0.0, 59 | 0.0, 60 | 0.0, 61 | 1.0 62 | ], 63 | "W2C": [ 64 | -0.6433882117271422, 65 | -0.10309550166130069, 66 | 0.758566379547119, 67 | -0.10336116912543912, 68 | 0.142945721745491, 69 | 0.9572839736938477, 70 | 0.2513442039489746, 71 | -0.10921025065909233, 72 | -0.7520759105682373, 73 | 0.27014571428298956, 74 | -0.601168155670166, 75 | 0.5231507731285338, 76 | 0.0, 77 | 0.0, 78 | 0.0, 79 | 1.0 80 | ], 81 | "img_size": [ 82 | 980, 83 | 546 84 | ] 85 | }, 86 | "000175.png": { 87 | "K": [ 88 | 
581.7877197265625, 89 | 0.0, 90 | 490.25, 91 | 0.0, 92 | 0.0, 93 | 581.7877197265625, 94 | 272.75, 95 | 0.0, 96 | 0.0, 97 | 0.0, 98 | 1.0, 99 | 0.0, 100 | 0.0, 101 | 0.0, 102 | 0.0, 103 | 1.0 104 | ], 105 | "W2C": [ 106 | -0.6209585666656491, 107 | -0.10511306673288354, 108 | 0.776763617992401, 109 | -0.06012172693562135, 110 | 0.1500456780195236, 111 | 0.9567026495933535, 112 | 0.24941192567348483, 113 | -0.11759666418072104, 114 | -0.7693482637405393, 115 | 0.27142450213432306, 116 | -0.578300952911377, 117 | 0.558755932058601, 118 | 0.0, 119 | 0.0, 120 | 0.0, 121 | 1.0 122 | ], 123 | "img_size": [ 124 | 980, 125 | 546 126 | ] 127 | }, 128 | "000176.png": { 129 | "K": [ 130 | 581.7877197265625, 131 | 0.0, 132 | 490.25, 133 | 0.0, 134 | 0.0, 135 | 581.7877197265625, 136 | 272.75, 137 | 0.0, 138 | 0.0, 139 | 0.0, 140 | 1.0, 141 | 0.0, 142 | 0.0, 143 | 0.0, 144 | 0.0, 145 | 1.0 146 | ], 147 | "W2C": [ 148 | -0.6149386763572693, 149 | -0.10690607875585548, 150 | 0.7812947630882263, 151 | -0.0004664353357674536, 152 | 0.15467639267444613, 153 | 0.9551697969436644, 154 | 0.2524398863315582, 155 | -0.12307422643556815, 156 | -0.7732564806938171, 157 | 0.27608290314674383, 158 | -0.5708349943161011, 159 | 0.5910430066979042, 160 | 0.0, 161 | 0.0, 162 | 0.0, 163 | 1.0 164 | ], 165 | "img_size": [ 166 | 980, 167 | 546 168 | ] 169 | }, 170 | "000177.png": { 171 | "K": [ 172 | 581.7877197265625, 173 | 0.0, 174 | 490.25, 175 | 0.0, 176 | 0.0, 177 | 581.7877197265625, 178 | 272.75, 179 | 0.0, 180 | 0.0, 181 | 0.0, 182 | 1.0, 183 | 0.0, 184 | 0.0, 185 | 0.0, 186 | 0.0, 187 | 1.0 188 | ], 189 | "W2C": [ 190 | -0.5901945233345031, 191 | -0.11034809798002247, 192 | 0.7996835112571716, 193 | 0.04891851798678256, 194 | 0.16408699750900269, 195 | 0.9535346627235411, 196 | 0.2526799142360687, 197 | -0.13504251807839815, 198 | -0.790408730506897, 199 | 0.28034797310829157, 200 | -0.544664204120636, 201 | 0.6147589251643815, 202 | 0.0, 203 | 0.0, 204 | 0.0, 205 | 1.0 206 | ], 207 | "img_size": [ 208 | 980, 209 | 546 210 | ] 211 | }, 212 | "000178.png": { 213 | "K": [ 214 | 581.7877197265625, 215 | 0.0, 216 | 490.25, 217 | 0.0, 218 | 0.0, 219 | 581.7877197265625, 220 | 272.75, 221 | 0.0, 222 | 0.0, 223 | 0.0, 224 | 1.0, 225 | 0.0, 226 | 0.0, 227 | 0.0, 228 | 0.0, 229 | 1.0 230 | ], 231 | "W2C": [ 232 | -0.5680983662605286, 233 | -0.11199841648340227, 234 | 0.8153039813041688, 235 | 0.11406040174495254, 236 | 0.17730242013931277, 237 | 0.9507739543914796, 238 | 0.2541510760784149, 239 | -0.13489467795265128, 240 | -0.8036343455314636, 241 | 0.28893816471099854, 242 | -0.5202755331993103, 243 | 0.6091760132826988, 244 | 0.0, 245 | 0.0, 246 | 0.0, 247 | 1.0 248 | ], 249 | "img_size": [ 250 | 980, 251 | 546 252 | ] 253 | }, 254 | "000179.png": { 255 | "K": [ 256 | 581.7877197265625, 257 | 0.0, 258 | 490.25, 259 | 0.0, 260 | 0.0, 261 | 581.7877197265625, 262 | 272.75, 263 | 0.0, 264 | 0.0, 265 | 0.0, 266 | 1.0, 267 | 0.0, 268 | 0.0, 269 | 0.0, 270 | 0.0, 271 | 1.0 272 | ], 273 | "W2C": [ 274 | -0.5251939892768861, 275 | -0.11293330788612364, 276 | 0.8434556126594543, 277 | 0.15233507599020302, 278 | 0.19423152506351476, 279 | 0.9490842819213869, 280 | 0.24801832437515262, 281 | -0.13645353989132347, 282 | -0.8285199999809266, 283 | 0.2940833866596223, 284 | -0.47651815414428705, 285 | 0.6077169856718815, 286 | 0.0, 287 | 0.0, 288 | 0.0, 289 | 1.0 290 | ], 291 | "img_size": [ 292 | 980, 293 | 546 294 | ] 295 | }, 296 | "000180.png": { 297 | "K": [ 298 | 581.7877197265625, 299 | 0.0, 300 | 490.25, 301 | 0.0, 302 | 0.0, 
303 | 581.7877197265625, 304 | 272.75, 305 | 0.0, 306 | 0.0, 307 | 0.0, 308 | 1.0, 309 | 0.0, 310 | 0.0, 311 | 0.0, 312 | 0.0, 313 | 1.0 314 | ], 315 | "W2C": [ 316 | -0.4368840157985687, 317 | -0.12017646431922917, 318 | 0.8914538621902465, 319 | 0.1584227767276896, 320 | 0.21023185551166537, 321 | 0.9499467611312866, 322 | 0.23109236359596255, 323 | -0.14294346814956563, 324 | -0.8746055960655214, 325 | 0.28837254643440247, 326 | -0.3897516429424286, 327 | 0.6279825376657556, 328 | 0.0, 329 | 0.0, 330 | 0.0, 331 | 1.0 332 | ], 333 | "img_size": [ 334 | 980, 335 | 546 336 | ] 337 | }, 338 | "000181.png": { 339 | "K": [ 340 | 581.7877197265625, 341 | 0.0, 342 | 490.25, 343 | 0.0, 344 | 0.0, 345 | 581.7877197265625, 346 | 272.75, 347 | 0.0, 348 | 0.0, 349 | 0.0, 350 | 1.0, 351 | 0.0, 352 | 0.0, 353 | 0.0, 354 | 0.0, 355 | 1.0 356 | ], 357 | "W2C": [ 358 | -0.3019417822360992, 359 | -0.12455697357654576, 360 | 0.9451543092727661, 361 | 0.12391406697080755, 362 | 0.23592309653758997, 363 | 0.9508262872695923, 364 | 0.2006731331348419, 365 | -0.15344331535910236, 366 | -0.9236727952957151, 367 | 0.2835753560066222, 368 | -0.2577083110809326, 369 | 0.649313956576192, 370 | 0.0, 371 | 0.0, 372 | 0.0, 373 | 1.0 374 | ], 375 | "img_size": [ 376 | 980, 377 | 546 378 | ] 379 | }, 380 | "000182.png": { 381 | "K": [ 382 | 581.7877197265625, 383 | 0.0, 384 | 490.25, 385 | 0.0, 386 | 0.0, 387 | 581.7877197265625, 388 | 272.75, 389 | 0.0, 390 | 0.0, 391 | 0.0, 392 | 1.0, 393 | 0.0, 394 | 0.0, 395 | 0.0, 396 | 0.0, 397 | 1.0 398 | ], 399 | "W2C": [ 400 | -0.05365315079689027, 401 | -0.13223105669021606, 402 | 0.9897657632827759, 403 | 0.008633672276731267, 404 | 0.256687343120575, 405 | 0.9560590982437135, 406 | 0.14164239168167117, 407 | -0.1679039701011501, 408 | -0.9650041460990905, 409 | 0.26165989041328425, 410 | -0.01735354587435723, 411 | 0.6778822326079519, 412 | 0.0, 413 | 0.0, 414 | 0.0, 415 | 1.0 416 | ], 417 | "img_size": [ 418 | 980, 419 | 546 420 | ] 421 | }, 422 | "000183.png": { 423 | "K": [ 424 | 581.7877197265625, 425 | 0.0, 426 | 490.25, 427 | 0.0, 428 | 0.0, 429 | 581.7877197265625, 430 | 272.75, 431 | 0.0, 432 | 0.0, 433 | 0.0, 434 | 1.0, 435 | 0.0, 436 | 0.0, 437 | 0.0, 438 | 0.0, 439 | 1.0 440 | ], 441 | "W2C": [ 442 | 0.08435853570699692, 443 | -0.11229652166366576, 444 | 0.9900874495506288, 445 | -0.02667809703294818, 446 | 0.2744568288326264, 447 | 0.9578129649162292, 448 | 0.08525134623050688, 449 | -0.18123123987972412, 450 | -0.9578920006752015, 451 | 0.2645445764064789, 452 | 0.1116202473640442, 453 | 0.68498322092808, 454 | 0.0, 455 | 0.0, 456 | 0.0, 457 | 1.0 458 | ], 459 | "img_size": [ 460 | 980, 461 | 546 462 | ] 463 | }, 464 | "000184.png": { 465 | "K": [ 466 | 581.7877197265625, 467 | 0.0, 468 | 490.25, 469 | 0.0, 470 | 0.0, 471 | 581.7877197265625, 472 | 272.75, 473 | 0.0, 474 | 0.0, 475 | 0.0, 476 | 1.0, 477 | 0.0, 478 | 0.0, 479 | 0.0, 480 | 0.0, 481 | 1.0 482 | ], 483 | "W2C": [ 484 | 0.173485592007637, 485 | -0.09174104034900667, 486 | 0.9805541038513182, 487 | -0.029180471150004907, 488 | 0.29095843434333796, 489 | 0.955982208251953, 490 | 0.03796394541859627, 491 | -0.2051112120771908, 492 | -0.9408751130104065, 493 | 0.278714269399643, 494 | 0.1925419718027115, 495 | 0.7113401071184143, 496 | 0.0, 497 | 0.0, 498 | 0.0, 499 | 1.0 500 | ], 501 | "img_size": [ 502 | 980, 503 | 546 504 | ] 505 | }, 506 | "000185.png": { 507 | "K": [ 508 | 581.7877197265625, 509 | 0.0, 510 | 490.25, 511 | 0.0, 512 | 0.0, 513 | 581.7877197265625, 514 | 272.75, 515 | 0.0, 516 | 0.0, 
517 | 0.0, 518 | 1.0, 519 | 0.0, 520 | 0.0, 521 | 0.0, 522 | 0.0, 523 | 1.0 524 | ], 525 | "W2C": [ 526 | 0.23342822492122645, 527 | -0.07284945249557495, 528 | 0.9696412682533264, 529 | -0.005887109452774881, 530 | 0.29702088236808777, 531 | 0.954870939254761, 532 | 0.0002359295467613364, 533 | -0.21843709467707623, 534 | -0.9258995056152343, 535 | 0.28794863820075983, 536 | 0.24453164637088773, 537 | 0.7118756075295797, 538 | 0.0, 539 | 0.0, 540 | 0.0, 541 | 1.0 542 | ], 543 | "img_size": [ 544 | 980, 545 | 546 546 | ] 547 | }, 548 | "000186.png": { 549 | "K": [ 550 | 581.7877197265625, 551 | 0.0, 552 | 490.25, 553 | 0.0, 554 | 0.0, 555 | 581.7877197265625, 556 | 272.75, 557 | 0.0, 558 | 0.0, 559 | 0.0, 560 | 1.0, 561 | 0.0, 562 | 0.0, 563 | 0.0, 564 | 0.0, 565 | 1.0 566 | ], 567 | "W2C": [ 568 | 0.3044866621494293, 569 | -0.05772994086146357, 570 | 0.9507655501365662, 571 | 0.005409785790216546, 572 | 0.29629379510879517, 573 | 0.9543822407722473, 574 | -0.0369397886097431, 575 | -0.2284336673437035, 576 | -0.905261218547821, 577 | 0.2929536104202271, 578 | 0.30770167708396906, 579 | 0.7186774152057739, 580 | 0.0, 581 | 0.0, 582 | 0.0, 583 | 1.0 584 | ], 585 | "img_size": [ 586 | 980, 587 | 546 588 | ] 589 | }, 590 | "000187.png": { 591 | "K": [ 592 | 581.7877197265625, 593 | 0.0, 594 | 490.25, 595 | 0.0, 596 | 0.0, 597 | 581.7877197265625, 598 | 272.75, 599 | 0.0, 600 | 0.0, 601 | 0.0, 602 | 1.0, 603 | 0.0, 604 | 0.0, 605 | 0.0, 606 | 0.0, 607 | 1.0 608 | ], 609 | "W2C": [ 610 | 0.37352681159973145, 611 | -0.042923863977193846, 612 | 0.9266257286071778, 613 | 0.013849648484600553, 614 | 0.2894437611103058, 615 | 0.9544482231140137, 616 | -0.07246334105730054, 617 | -0.2333756901439718, 618 | -0.8813058733940125, 619 | 0.29527303576469427, 620 | 0.3689360618591308, 621 | 0.7126951784083984, 622 | 0.0, 623 | 0.0, 624 | 0.0, 625 | 1.0 626 | ], 627 | "img_size": [ 628 | 980, 629 | 546 630 | ] 631 | }, 632 | "000188.png": { 633 | "K": [ 634 | 581.7877197265625, 635 | 0.0, 636 | 490.25, 637 | 0.0, 638 | 0.0, 639 | 581.7877197265625, 640 | 272.75, 641 | 0.0, 642 | 0.0, 643 | 0.0, 644 | 1.0, 645 | 0.0, 646 | 0.0, 647 | 0.0, 648 | 0.0, 649 | 1.0 650 | ], 651 | "W2C": [ 652 | 0.49691322445869446, 653 | -0.03459263965487483, 654 | 0.8671104907989503, 655 | -0.02804523469799382, 656 | 0.27165842056274414, 657 | 0.9551849961280822, 658 | -0.11757244169712067, 659 | -0.23797951926387803, 660 | -0.8241838216781616, 661 | 0.2939811646938324, 662 | 0.48404145240783697, 663 | 0.7082496879378362, 664 | 0.0, 665 | 0.0, 666 | 0.0, 667 | 1.0 668 | ], 669 | "img_size": [ 670 | 980, 671 | 546 672 | ] 673 | }, 674 | "000189.png": { 675 | "K": [ 676 | 581.7877197265625, 677 | 0.0, 678 | 490.25, 679 | 0.0, 680 | 0.0, 681 | 581.7877197265625, 682 | 272.75, 683 | 0.0, 684 | 0.0, 685 | 0.0, 686 | 1.0, 687 | 0.0, 688 | 0.0, 689 | 0.0, 690 | 0.0, 691 | 1.0 692 | ], 693 | "W2C": [ 694 | 0.656022787094116, 695 | -0.028436420485377232, 696 | 0.7542051672935486, 697 | -0.11063643763864509, 698 | 0.2370839267969131, 699 | 0.9564713835716248, 700 | -0.17015771567821505, 701 | -0.23631046162031968, 702 | -0.7165369987487792, 703 | 0.29043728113174433, 704 | 0.6342088580131531, 705 | 0.6869845554740425, 706 | 0.0, 707 | 0.0, 708 | 0.0, 709 | 1.0 710 | ], 711 | "img_size": [ 712 | 980, 713 | 546 714 | ] 715 | }, 716 | "000190.png": { 717 | "K": [ 718 | 581.7877197265625, 719 | 0.0, 720 | 490.25, 721 | 0.0, 722 | 0.0, 723 | 581.7877197265625, 724 | 272.75, 725 | 0.0, 726 | 0.0, 727 | 0.0, 728 | 1.0, 729 | 0.0, 730 | 0.0, 731 | 
0.0, 732 | 0.0, 733 | 1.0 734 | ], 735 | "W2C": [ 736 | 0.7924461960792542, 737 | -0.01046304591000084, 738 | 0.6098520755767822, 739 | -0.19199604683373434, 740 | 0.18565021455287933, 741 | 0.956550121307373, 742 | -0.22482399642467502, 743 | -0.23064418498320746, 744 | -0.5810017585754395, 745 | 0.2913800776004791, 746 | 0.759956955909729, 747 | 0.6486130335907647, 748 | 0.0, 749 | 0.0, 750 | 0.0, 751 | 1.0 752 | ], 753 | "img_size": [ 754 | 980, 755 | 546 756 | ] 757 | }, 758 | "000191.png": { 759 | "K": [ 760 | 581.7877197265625, 761 | 0.0, 762 | 490.25, 763 | 0.0, 764 | 0.0, 765 | 581.7877197265625, 766 | 272.75, 767 | 0.0, 768 | 0.0, 769 | 0.0, 770 | 1.0, 771 | 0.0, 772 | 0.0, 773 | 0.0, 774 | 0.0, 775 | 1.0 776 | ], 777 | "W2C": [ 778 | 0.8854711055755615, 779 | 0.010563611052930369, 780 | 0.46457436680793757, 781 | -0.24282484537054266, 782 | 0.1252961456775665, 783 | 0.9572874307632447, 784 | -0.26057946681976313, 785 | -0.21866898323124923, 786 | -0.44748386740684487, 787 | 0.288944959640503, 788 | 0.8463267683982846, 789 | 0.6151146596650793, 790 | 0.0, 791 | 0.0, 792 | 0.0, 793 | 1.0 794 | ], 795 | "img_size": [ 796 | 980, 797 | 546 798 | ] 799 | }, 800 | "000192.png": { 801 | "K": [ 802 | 581.7877197265625, 803 | 0.0, 804 | 490.25, 805 | 0.0, 806 | 0.0, 807 | 581.7877197265625, 808 | 272.75, 809 | 0.0, 810 | 0.0, 811 | 0.0, 812 | 1.0, 813 | 0.0, 814 | 0.0, 815 | 0.0, 816 | 0.0, 817 | 1.0 818 | ], 819 | "W2C": [ 820 | 0.9388862252235413, 821 | 0.03694019094109534, 822 | 0.3422397673130036, 823 | -0.27164514724995453, 824 | 0.06416944414377213, 825 | 0.9580151438713074, 826 | -0.27944466471672064, 827 | -0.20499586789544427, 828 | -0.33819359540939326, 829 | 0.28432807326316833, 830 | 0.8970967531204224, 831 | 0.5788967097431258, 832 | 0.0, 833 | 0.0, 834 | 0.0, 835 | 1.0 836 | ], 837 | "img_size": [ 838 | 980, 839 | 546 840 | ] 841 | }, 842 | "000193.png": { 843 | "K": [ 844 | 581.7877197265625, 845 | 0.0, 846 | 490.25, 847 | 0.0, 848 | 0.0, 849 | 581.7877197265625, 850 | 272.75, 851 | 0.0, 852 | 0.0, 853 | 0.0, 854 | 1.0, 855 | 0.0, 856 | 0.0, 857 | 0.0, 858 | 0.0, 859 | 1.0 860 | ], 861 | "W2C": [ 862 | 0.9651142954826355, 863 | 0.05770149827003481, 864 | 0.2553918063640595, 865 | -0.272934838794894, 866 | 0.015767628327012072, 867 | 0.9608356356620789, 868 | -0.27666985988616943, 869 | -0.19339081801285132, 870 | -0.2613538205623627, 871 | 0.2710449695587158, 872 | 0.9264063835144044, 873 | 0.5628077025323545, 874 | 0.0, 875 | 0.0, 876 | 0.0, 877 | 1.0 878 | ], 879 | "img_size": [ 880 | 980, 881 | 546 882 | ] 883 | }, 884 | "000194.png": { 885 | "K": [ 886 | 581.7877197265625, 887 | 0.0, 888 | 490.25, 889 | 0.0, 890 | 0.0, 891 | 581.7877197265625, 892 | 272.75, 893 | 0.0, 894 | 0.0, 895 | 0.0, 896 | 1.0, 897 | 0.0, 898 | 0.0, 899 | 0.0, 900 | 0.0, 901 | 1.0 902 | ], 903 | "W2C": [ 904 | 0.9572978019714354, 905 | 0.06780177354812619, 906 | 0.28104069828987116, 907 | -0.20122927411147834, 908 | 0.007655009161680944, 909 | 0.9658248424530029, 910 | -0.25908261537551885, 911 | -0.19321847581034227, 912 | -0.2890023589134216, 913 | 0.2501705884933471, 914 | 0.924062967300415, 915 | 0.563575894892584, 916 | 0.0, 917 | 0.0, 918 | 0.0, 919 | 1.0 920 | ], 921 | "img_size": [ 922 | 980, 923 | 546 924 | ] 925 | }, 926 | "000195.png": { 927 | "K": [ 928 | 581.7877197265625, 929 | 0.0, 930 | 490.25, 931 | 0.0, 932 | 0.0, 933 | 581.7877197265625, 934 | 272.75, 935 | 0.0, 936 | 0.0, 937 | 0.0, 938 | 1.0, 939 | 0.0, 940 | 0.0, 941 | 0.0, 942 | 0.0, 943 | 1.0 944 | ], 945 | "W2C": [ 946 | 
0.9394192695617677, 947 | 0.040031708776950836, 948 | 0.3404245376586914, 949 | -0.10805127469870195, 950 | 0.03222170472145081, 951 | 0.9784454107284546, 952 | -0.20397628843784332, 953 | -0.17995865052818397, 954 | -0.34125232696533203, 955 | 0.20258831977844238, 956 | 0.917880594730377, 957 | 0.5956252623341279, 958 | 0.0, 959 | 0.0, 960 | 0.0, 961 | 1.0 962 | ], 963 | "img_size": [ 964 | 980, 965 | 546 966 | ] 967 | }, 968 | "000196.png": { 969 | "K": [ 970 | 581.7877197265625, 971 | 0.0, 972 | 490.25, 973 | 0.0, 974 | 0.0, 975 | 581.7877197265625, 976 | 272.75, 977 | 0.0, 978 | 0.0, 979 | 0.0, 980 | 1.0, 981 | 0.0, 982 | 0.0, 983 | 0.0, 984 | 0.0, 985 | 1.0 986 | ], 987 | "W2C": [ 988 | 0.9188615679740906, 989 | 0.01963340304791927, 990 | 0.3940912783145905, 991 | -0.02124089688202119, 992 | 0.0335158184170723, 993 | 0.9912682771682739, 994 | -0.1275297701358795, 995 | -0.13773930218650604, 996 | -0.39315405488014227, 997 | 0.1303904950618744, 998 | 0.9101803302764893, 999 | 0.590018327758047, 1000 | 0.0, 1001 | 0.0, 1002 | 0.0, 1003 | 1.0 1004 | ], 1005 | "img_size": [ 1006 | 980, 1007 | 546 1008 | ] 1009 | }, 1010 | "000197.png": { 1011 | "K": [ 1012 | 581.7877197265625, 1013 | 0.0, 1014 | 490.25, 1015 | 0.0, 1016 | 0.0, 1017 | 581.7877197265625, 1018 | 272.75, 1019 | 0.0, 1020 | 0.0, 1021 | 0.0, 1022 | 1.0, 1023 | 0.0, 1024 | 0.0, 1025 | 0.0, 1026 | 0.0, 1027 | 1.0 1028 | ], 1029 | "W2C": [ 1030 | 0.9298350214958191, 1031 | -0.002477371366694553, 1032 | 0.3679683208465576, 1033 | 0.023792225073484205, 1034 | 0.04552911967039108, 1035 | 0.9930682182312012, 1036 | -0.10836359858512877, 1037 | -0.13308144421741533, 1038 | -0.3651491701602936, 1039 | 0.11751354485750198, 1040 | 0.9235023856163025, 1041 | 0.5847770507172574, 1042 | 0.0, 1043 | 0.0, 1044 | 0.0, 1045 | 1.0 1046 | ], 1047 | "img_size": [ 1048 | 980, 1049 | 546 1050 | ] 1051 | } 1052 | } -------------------------------------------------------------------------------- /camera_visualizer/visualize_cameras.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import json 3 | import numpy as np 4 | 5 | 6 | def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]): 7 | W, H = img_size 8 | hfov = np.rad2deg(np.arctan(W / 2. / K[0, 0]) * 2.) 9 | vfov = np.rad2deg(np.arctan(H / 2. / K[1, 1]) * 2.) 
10 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 11 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 12 | 13 | # build view frustum for camera (I, 0) 14 | frustum_points = np.array([[0., 0., 0.], # frustum origin 15 | [-half_w, -half_h, frustum_length], # top-left image corner 16 | [half_w, -half_h, frustum_length], # top-right image corner 17 | [half_w, half_h, frustum_length], # bottom-right image corner 18 | [-half_w, half_h, frustum_length]]) # bottom-left image corner 19 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) 20 | frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1)) 21 | 22 | # frustum_colors = np.vstack((np.tile(np.array([[1., 0., 0.]]), (4, 1)), 23 | # np.tile(np.array([[0., 1., 0.]]), (4, 1)))) 24 | 25 | # transform view frustum from (I, 0) to (R, t) 26 | C2W = np.linalg.inv(W2C) 27 | frustum_points = np.dot(np.hstack((frustum_points, np.ones_like(frustum_points[:, 0:1]))), C2W.T) 28 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] 29 | 30 | return frustum_points, frustum_lines, frustum_colors 31 | 32 | 33 | def frustums2lineset(frustums): 34 | N = len(frustums) 35 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 36 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 37 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 38 | 39 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 40 | merged_points[i*5:(i+1)*5, :] = frustum_points 41 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 42 | merged_colors[i*8:(i+1)*8, :] = frustum_colors 43 | 44 | lineset = o3d.geometry.LineSet() 45 | lineset.points = o3d.utility.Vector3dVector(merged_points) 46 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 47 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 48 | 49 | return lineset 50 | 51 | def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh'): 52 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10) 53 | sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere) 54 | sphere.paint_uniform_color((1, 0, 0)) 55 | 56 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.]) 57 | things_to_draw = [sphere, coord_frame] 58 | 59 | idx = 0 60 | for color, camera_dict in colored_camera_dicts: 61 | idx += 1 62 | 63 | cnt = 0 64 | frustums = [] 65 | for img_name in sorted(camera_dict.keys()): 66 | K = np.array(camera_dict[img_name]['K']).reshape((4, 4)) 67 | W2C = np.array(camera_dict[img_name]['W2C']).reshape((4, 4)) 68 | C2W = np.linalg.inv(W2C) 69 | img_size = camera_dict[img_name]['img_size'] 70 | frustums.append(get_camera_frustum(img_size, K, W2C, frustum_length=camera_size, color=color)) 71 | cnt += 1 72 | cameras = frustums2lineset(frustums) 73 | things_to_draw.append(cameras) 74 | 75 | if geometry_file is not None: 76 | if geometry_type == 'mesh': 77 | geometry = o3d.io.read_triangle_mesh(geometry_file) 78 | geometry.compute_vertex_normals() 79 | elif geometry_type == 'pointcloud': 80 | geometry = o3d.io.read_point_cloud(geometry_file) 81 | else: 82 | raise Exception('Unknown geometry_type: ', geometry_type) 83 | 84 | things_to_draw.append(geometry) 85 | 86 | o3d.visualization.draw_geometries(things_to_draw) 87 | 88 | 89 | if __name__ == '__main__': 90 | import os 91 | 92 | base_dir = './' 93 | 94 | sphere_radius = 1. 
95 | train_cam_dict = json.load(open(os.path.join(base_dir, 'train/cam_dict_norm.json'))) 96 | test_cam_dict = json.load(open(os.path.join(base_dir, 'test/cam_dict_norm.json'))) 97 | path_cam_dict = json.load(open(os.path.join(base_dir, 'camera_path/cam_dict_norm.json'))) 98 | camera_size = 0.1 99 | colored_camera_dicts = [([0, 1, 0], train_cam_dict), 100 | ([0, 0, 1], test_cam_dict), 101 | ([1, 1, 0], path_cam_dict) 102 | ] 103 | 104 | geometry_file = os.path.join(base_dir, 'mesh_norm.ply') 105 | geometry_type = 'mesh' 106 | 107 | visualize_cameras(colored_camera_dicts, sphere_radius, 108 | camera_size=camera_size, geometry_file=geometry_file, geometry_type=geometry_type) 109 | -------------------------------------------------------------------------------- /colmap_runner/database.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script is based on an original implementation by True Price. 
33 | 34 | import sys 35 | import sqlite3 36 | import numpy as np 37 | 38 | 39 | IS_PYTHON3 = sys.version_info[0] >= 3 40 | 41 | MAX_IMAGE_ID = 2**31 - 1 42 | 43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( 44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 45 | model INTEGER NOT NULL, 46 | width INTEGER NOT NULL, 47 | height INTEGER NOT NULL, 48 | params BLOB, 49 | prior_focal_length INTEGER NOT NULL)""" 50 | 51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( 52 | image_id INTEGER PRIMARY KEY NOT NULL, 53 | rows INTEGER NOT NULL, 54 | cols INTEGER NOT NULL, 55 | data BLOB, 56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 57 | 58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( 59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 60 | name TEXT NOT NULL UNIQUE, 61 | camera_id INTEGER NOT NULL, 62 | prior_qw REAL, 63 | prior_qx REAL, 64 | prior_qy REAL, 65 | prior_qz REAL, 66 | prior_tx REAL, 67 | prior_ty REAL, 68 | prior_tz REAL, 69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), 70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) 71 | """.format(MAX_IMAGE_ID) 72 | 73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ 74 | CREATE TABLE IF NOT EXISTS two_view_geometries ( 75 | pair_id INTEGER PRIMARY KEY NOT NULL, 76 | rows INTEGER NOT NULL, 77 | cols INTEGER NOT NULL, 78 | data BLOB, 79 | config INTEGER NOT NULL, 80 | F BLOB, 81 | E BLOB, 82 | H BLOB) 83 | """ 84 | 85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( 86 | image_id INTEGER PRIMARY KEY NOT NULL, 87 | rows INTEGER NOT NULL, 88 | cols INTEGER NOT NULL, 89 | data BLOB, 90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) 91 | """ 92 | 93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( 94 | pair_id INTEGER PRIMARY KEY NOT NULL, 95 | rows INTEGER NOT NULL, 96 | cols INTEGER NOT NULL, 97 | data BLOB)""" 98 | 99 | CREATE_NAME_INDEX = \ 100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" 101 | 102 | CREATE_ALL = "; ".join([ 103 | CREATE_CAMERAS_TABLE, 104 | CREATE_IMAGES_TABLE, 105 | CREATE_KEYPOINTS_TABLE, 106 | CREATE_DESCRIPTORS_TABLE, 107 | CREATE_MATCHES_TABLE, 108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE, 109 | CREATE_NAME_INDEX 110 | ]) 111 | 112 | 113 | def image_ids_to_pair_id(image_id1, image_id2): 114 | if image_id1 > image_id2: 115 | image_id1, image_id2 = image_id2, image_id1 116 | return image_id1 * MAX_IMAGE_ID + image_id2 117 | 118 | 119 | def pair_id_to_image_ids(pair_id): 120 | image_id2 = pair_id % MAX_IMAGE_ID 121 | image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID 122 | return image_id1, image_id2 123 | 124 | 125 | def array_to_blob(array): 126 | if IS_PYTHON3: 127 | return array.tostring() 128 | else: 129 | return np.getbuffer(array) 130 | 131 | 132 | def blob_to_array(blob, dtype, shape=(-1,)): 133 | if IS_PYTHON3: 134 | return np.fromstring(blob, dtype=dtype).reshape(*shape) 135 | else: 136 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 137 | 138 | 139 | class COLMAPDatabase(sqlite3.Connection): 140 | 141 | @staticmethod 142 | def connect(database_path): 143 | return sqlite3.connect(database_path, factory=COLMAPDatabase) 144 | 145 | 146 | def __init__(self, *args, **kwargs): 147 | super(COLMAPDatabase, self).__init__(*args, **kwargs) 148 | 149 | self.create_tables = lambda: self.executescript(CREATE_ALL) 150 | self.create_cameras_table = \ 151 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 152 |
self.create_descriptors_table = \ 153 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 154 | self.create_images_table = \ 155 | lambda: self.executescript(CREATE_IMAGES_TABLE) 156 | self.create_two_view_geometries_table = \ 157 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) 158 | self.create_keypoints_table = \ 159 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 160 | self.create_matches_table = \ 161 | lambda: self.executescript(CREATE_MATCHES_TABLE) 162 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 163 | 164 | def add_camera(self, model, width, height, params, 165 | prior_focal_length=False, camera_id=None): 166 | params = np.asarray(params, np.float64) 167 | cursor = self.execute( 168 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", 169 | (camera_id, model, width, height, array_to_blob(params), 170 | prior_focal_length)) 171 | return cursor.lastrowid 172 | 173 | def add_image(self, name, camera_id, 174 | prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None): 175 | cursor = self.execute( 176 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 177 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], 178 | prior_q[3], prior_t[0], prior_t[1], prior_t[2])) 179 | return cursor.lastrowid 180 | 181 | def add_keypoints(self, image_id, keypoints): 182 | assert(len(keypoints.shape) == 2) 183 | assert(keypoints.shape[1] in [2, 4, 6]) 184 | 185 | keypoints = np.asarray(keypoints, np.float32) 186 | self.execute( 187 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)", 188 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) 189 | 190 | def add_descriptors(self, image_id, descriptors): 191 | descriptors = np.ascontiguousarray(descriptors, np.uint8) 192 | self.execute( 193 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)", 194 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) 195 | 196 | def add_matches(self, image_id1, image_id2, matches): 197 | assert(len(matches.shape) == 2) 198 | assert(matches.shape[1] == 2) 199 | 200 | if image_id1 > image_id2: 201 | matches = matches[:,::-1] 202 | 203 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 204 | matches = np.asarray(matches, np.uint32) 205 | self.execute( 206 | "INSERT INTO matches VALUES (?, ?, ?, ?)", 207 | (pair_id,) + matches.shape + (array_to_blob(matches),)) 208 | 209 | def add_two_view_geometry(self, image_id1, image_id2, matches, 210 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): 211 | assert(len(matches.shape) == 2) 212 | assert(matches.shape[1] == 2) 213 | 214 | if image_id1 > image_id2: 215 | matches = matches[:,::-1] 216 | 217 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 218 | matches = np.asarray(matches, np.uint32) 219 | F = np.asarray(F, dtype=np.float64) 220 | E = np.asarray(E, dtype=np.float64) 221 | H = np.asarray(H, dtype=np.float64) 222 | self.execute( 223 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", 224 | (pair_id,) + matches.shape + (array_to_blob(matches), config, 225 | array_to_blob(F), array_to_blob(E), array_to_blob(H))) 226 | 227 | 228 | def example_usage(): 229 | import os 230 | import argparse 231 | 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument("--database_path", default="database.db") 234 | args = parser.parse_args() 235 | 236 | if os.path.exists(args.database_path): 237 | print("ERROR: database path already exists -- will not modify it.") 238 | return 239 | 240 | # Open the database.
241 | 242 | db = COLMAPDatabase.connect(args.database_path) 243 | 244 | # For convenience, try creating all the tables upfront. 245 | 246 | db.create_tables() 247 | 248 | # Create dummy cameras. 249 | 250 | model1, width1, height1, params1 = \ 251 | 0, 1024, 768, np.array((1024., 512., 384.)) 252 | model2, width2, height2, params2 = \ 253 | 2, 1024, 768, np.array((1024., 512., 384., 0.1)) 254 | 255 | camera_id1 = db.add_camera(model1, width1, height1, params1) 256 | camera_id2 = db.add_camera(model2, width2, height2, params2) 257 | 258 | # Create dummy images. 259 | 260 | image_id1 = db.add_image("image1.png", camera_id1) 261 | image_id2 = db.add_image("image2.png", camera_id1) 262 | image_id3 = db.add_image("image3.png", camera_id2) 263 | image_id4 = db.add_image("image4.png", camera_id2) 264 | 265 | # Create dummy keypoints. 266 | # 267 | # Note that COLMAP supports: 268 | # - 2D keypoints: (x, y) 269 | # - 4D keypoints: (x, y, theta, scale) 270 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) 271 | 272 | num_keypoints = 1000 273 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) 274 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) 275 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) 276 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) 277 | 278 | db.add_keypoints(image_id1, keypoints1) 279 | db.add_keypoints(image_id2, keypoints2) 280 | db.add_keypoints(image_id3, keypoints3) 281 | db.add_keypoints(image_id4, keypoints4) 282 | 283 | # Create dummy matches. 284 | 285 | M = 50 286 | matches12 = np.random.randint(num_keypoints, size=(M, 2)) 287 | matches23 = np.random.randint(num_keypoints, size=(M, 2)) 288 | matches34 = np.random.randint(num_keypoints, size=(M, 2)) 289 | 290 | db.add_matches(image_id1, image_id2, matches12) 291 | db.add_matches(image_id2, image_id3, matches23) 292 | db.add_matches(image_id3, image_id4, matches34) 293 | 294 | # Commit the data to the file. 295 | 296 | db.commit() 297 | 298 | # Read and check cameras. 299 | 300 | rows = db.execute("SELECT * FROM cameras") 301 | 302 | camera_id, model, width, height, params, prior = next(rows) 303 | params = blob_to_array(params, np.float64) 304 | assert camera_id == camera_id1 305 | assert model == model1 and width == width1 and height == height1 306 | assert np.allclose(params, params1) 307 | 308 | camera_id, model, width, height, params, prior = next(rows) 309 | params = blob_to_array(params, np.float64) 310 | assert camera_id == camera_id2 311 | assert model == model2 and width == width2 and height == height2 312 | assert np.allclose(params, params2) 313 | 314 | # Read and check keypoints. 315 | 316 | keypoints = dict( 317 | (image_id, blob_to_array(data, np.float32, (-1, 2))) 318 | for image_id, data in db.execute( 319 | "SELECT image_id, data FROM keypoints")) 320 | 321 | assert np.allclose(keypoints[image_id1], keypoints1) 322 | assert np.allclose(keypoints[image_id2], keypoints2) 323 | assert np.allclose(keypoints[image_id3], keypoints3) 324 | assert np.allclose(keypoints[image_id4], keypoints4) 325 | 326 | # Read and check matches.
327 | 328 | pair_ids = [image_ids_to_pair_id(*pair) for pair in 329 | ((image_id1, image_id2), 330 | (image_id2, image_id3), 331 | (image_id3, image_id4))] 332 | 333 | matches = dict( 334 | (pair_id_to_image_ids(pair_id), 335 | blob_to_array(data, np.uint32, (-1, 2))) 336 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches") 337 | ) 338 | 339 | assert np.all(matches[(image_id1, image_id2)] == matches12) 340 | assert np.all(matches[(image_id2, image_id3)] == matches23) 341 | assert np.all(matches[(image_id3, image_id4)] == matches34) 342 | 343 | # Clean up. 344 | 345 | db.close() 346 | 347 | if os.path.exists(args.database_path): 348 | os.remove(args.database_path) 349 | 350 | 351 | if __name__ == "__main__": 352 | example_usage() 353 | -------------------------------------------------------------------------------- /colmap_runner/extract_sfm.py: -------------------------------------------------------------------------------- 1 | from read_write_model import read_model 2 | import numpy as np 3 | import json 4 | import os 5 | from pyquaternion import Quaternion 6 | import trimesh 7 | 8 | 9 | def parse_tracks(colmap_images, colmap_points3D): 10 | all_tracks = [] # list of dicts; each dict represents a track 11 | all_points = [] # list of all 3D points 12 | view_keypoints = {} # dict of lists; each list represents the triangulated key points of a view 13 | 14 | 15 | for point3D_id in colmap_points3D: 16 | point3D = colmap_points3D[point3D_id] 17 | image_ids = point3D.image_ids 18 | point2D_idxs = point3D.point2D_idxs 19 | 20 | cur_track = {} 21 | cur_track['xyz'] = (point3D.xyz[0], point3D.xyz[1], point3D.xyz[2]) 22 | cur_track['err'] = point3D.error.item() 23 | 24 | cur_track_len = len(image_ids) 25 | assert (cur_track_len == len(point2D_idxs)) 26 | all_points.append(list(cur_track['xyz'] + (cur_track['err'], cur_track_len) + tuple(point3D.rgb))) 27 | 28 | pixels = [] 29 | for i in range(cur_track_len): 30 | image = colmap_images[image_ids[i]] 31 | img_name = image.name 32 | point2D_idx = point2D_idxs[i] 33 | point2D = image.xys[point2D_idx] 34 | assert (image.point3D_ids[point2D_idx] == point3D_id) 35 | pixels.append((img_name, point2D[0], point2D[1])) 36 | 37 | if img_name not in view_keypoints: 38 | view_keypoints[img_name] = [(point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, ), ] 39 | else: 40 | view_keypoints[img_name].append((point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, )) 41 | 42 | cur_track['pixels'] = sorted(pixels, key=lambda x: x[0]) # sort pixels by the img_name 43 | all_tracks.append(cur_track) 44 | 45 | return all_tracks, all_points, view_keypoints 46 | 47 | 48 | def parse_camera_dict(colmap_cameras, colmap_images): 49 | camera_dict = {} 50 | for image_id in colmap_images: 51 | image = colmap_images[image_id] 52 | 53 | img_name = image.name 54 | cam = colmap_cameras[image.camera_id] 55 | 56 | # print(cam) 57 | assert(cam.model == 'PINHOLE') 58 | 59 | img_size = [cam.width, cam.height] 60 | params = list(cam.params) 61 | qvec = list(image.qvec) 62 | tvec = list(image.tvec) 63 | 64 | # w, h, fx, fy, cx, cy, qvec, tvec 65 | # camera_dict[img_name] = img_size + params + qvec + tvec 66 | camera_dict[img_name] = {} 67 | camera_dict[img_name]['img_size'] = img_size 68 | 69 | fx, fy, cx, cy = params 70 | K = np.eye(4) 71 | K[0, 0] = fx 72 | K[1, 1] = fy 73 | K[0, 2] = cx 74 | K[1, 2] = cy 75 | camera_dict[img_name]['K'] = list(K.flatten()) 76 | 77 | rot = Quaternion(qvec[0], qvec[1], qvec[2], qvec[3]).rotation_matrix 78 | W2C = np.eye(4)
79 | W2C[:3, :3] = rot 80 | W2C[:3, 3] = np.array(tvec) 81 | camera_dict[img_name]['W2C'] = list(W2C.flatten()) 82 | 83 | return camera_dict 84 | 85 | 86 | def extract_all_to_dir(sparse_dir, out_dir, ext='.bin'): 87 | if not os.path.exists(out_dir): 88 | os.mkdir(out_dir) 89 | 90 | camera_dict_file = os.path.join(out_dir, 'kai_cameras.json') 91 | xyz_file = os.path.join(out_dir, 'kai_points.txt') 92 | track_file = os.path.join(out_dir, 'kai_tracks.json') 93 | keypoints_file = os.path.join(out_dir, 'kai_keypoints.json') 94 | 95 | colmap_cameras, colmap_images, colmap_points3D = read_model(sparse_dir, ext) 96 | 97 | camera_dict = parse_camera_dict(colmap_cameras, colmap_images) 98 | with open(camera_dict_file, 'w') as fp: 99 | json.dump(camera_dict, fp, indent=2, sort_keys=True) 100 | 101 | all_tracks, all_points, view_keypoints = parse_tracks(colmap_images, colmap_points3D) 102 | all_points = np.array(all_points) 103 | np.savetxt(xyz_file, all_points, header='# format: x, y, z, reproj_err, track_len, color(RGB)', fmt='%.6f') 104 | 105 | mesh = trimesh.Trimesh(vertices=all_points[:, :3].astype(np.float32), 106 | vertex_colors=all_points[:, -3:].astype(np.uint8)) 107 | mesh.export(os.path.join(out_dir, 'kai_points.ply')) 108 | 109 | with open(track_file, 'w') as fp: 110 | json.dump(all_tracks, fp) 111 | 112 | with open(keypoints_file, 'w') as fp: 113 | json.dump(view_keypoints, fp) 114 | 115 | 116 | if __name__ == '__main__': 117 | mvs_dir = '/home/zhangka2/sg_render/run_mvs/scan114_train_5/colmap_mvs/mvs' 118 | sparse_dir = os.path.join(mvs_dir, 'sparse') 119 | out_dir = os.path.join(mvs_dir, 'sparse_inspect') 120 | extract_all_to_dir(sparse_dir, out_dir) 121 | 122 | xyz_file = os.path.join(out_dir, 'kai_points.txt') 123 | reproj_errs = np.loadtxt(xyz_file)[:, 3] 124 | with open(os.path.join(out_dir, 'stats.txt'), 'w') as fp: 125 | fp.write('reprojection errors (px) in SfM:\n') 126 | fp.write(' percentile value\n') 127 | for a in [50, 70, 90, 99]: 128 | fp.write(' {} {:.3f}\n'.format(a, np.percentile(reproj_errs, a))) 129 | 130 | print('reprojection errors (px) in SfM:') 131 | print(' percentile value') 132 | for a in [50, 70, 90, 99]: 133 | print(' {} {:.3f}'.format(a, np.percentile(reproj_errs, a))) -------------------------------------------------------------------------------- /colmap_runner/normalize_cam_dict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import copy 4 | import open3d as o3d 5 | 6 | 7 | def get_tf_cams(cam_dict, target_radius=1.): 8 | cam_centers = [] 9 | for im_name in cam_dict: 10 | W2C = np.array(cam_dict[im_name]['W2C']).reshape((4, 4)) 11 | C2W = np.linalg.inv(W2C) 12 | cam_centers.append(C2W[:3, 3:4]) 13 | 14 | def get_center_and_diag(cam_centers): 15 | cam_centers = np.hstack(cam_centers) 16 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 17 | center = avg_cam_center 18 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 19 | diagonal = np.max(dist) 20 | return center.flatten(), diagonal 21 | 22 | center, diagonal = get_center_and_diag(cam_centers) 23 | radius = diagonal * 1.1 24 | 25 | translate = -center 26 | scale = target_radius / radius 27 | 28 | return translate, scale 29 | 30 | 31 | def normalize_cam_dict(in_cam_dict_file, out_cam_dict_file, target_radius=1., in_geometry_file=None, out_geometry_file=None): 32 | with open(in_cam_dict_file) as fp: 33 | in_cam_dict = json.load(fp) 34 | 35 | translate, scale = get_tf_cams(in_cam_dict, 
target_radius=target_radius) 36 | 37 | if in_geometry_file is not None and out_geometry_file is not None: 38 | # check this page if you encounter issue in file io: http://www.open3d.org/docs/0.9.0/tutorial/Basic/file_io.html 39 | geometry = o3d.io.read_triangle_mesh(in_geometry_file) 40 | 41 | tf_translate = np.eye(4) 42 | tf_translate[:3, 3:4] = translate 43 | tf_scale = np.eye(4) 44 | tf_scale[:3, :3] *= scale 45 | tf = np.matmul(tf_scale, tf_translate) 46 | 47 | geometry_norm = geometry.transform(tf) 48 | o3d.io.write_triangle_mesh(out_geometry_file, geometry_norm) 49 | 50 | def transform_pose(W2C, translate, scale): 51 | C2W = np.linalg.inv(W2C) 52 | cam_center = C2W[:3, 3] 53 | cam_center = (cam_center + translate) * scale 54 | C2W[:3, 3] = cam_center 55 | return np.linalg.inv(C2W) 56 | 57 | out_cam_dict = copy.deepcopy(in_cam_dict) 58 | for img_name in out_cam_dict: 59 | W2C = np.array(out_cam_dict[img_name]['W2C']).reshape((4, 4)) 60 | W2C = transform_pose(W2C, translate, scale) 61 | assert(np.isclose(np.linalg.det(W2C[:3, :3]), 1.)) 62 | out_cam_dict[img_name]['W2C'] = list(W2C.flatten()) 63 | 64 | with open(out_cam_dict_file, 'w') as fp: 65 | json.dump(out_cam_dict, fp, indent=2, sort_keys=True) 66 | 67 | 68 | if __name__ == '__main__': 69 | in_cam_dict_file = '' 70 | out_cam_dict_file = '' 71 | normalize_cam_dict(in_cam_dict_file, out_cam_dict_file, target_radius=1.) 72 | -------------------------------------------------------------------------------- /colmap_runner/read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | import argparse 38 | 39 | 40 | CameraModel = collections.namedtuple( 41 | "CameraModel", ["model_id", "model_name", "num_params"]) 42 | Camera = collections.namedtuple( 43 | "Camera", ["id", "model", "width", "height", "params"]) 44 | BaseImage = collections.namedtuple( 45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 46 | Point3D = collections.namedtuple( 47 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 48 | 49 | 50 | class Image(BaseImage): 51 | def qvec2rotmat(self): 52 | return qvec2rotmat(self.qvec) 53 | 54 | 55 | CAMERA_MODELS = { 56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 63 | CameraModel(model_id=7, model_name="FOV", num_params=5), 64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 67 | } 68 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 69 | for camera_model in CAMERA_MODELS]) 70 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 71 | for camera_model in CAMERA_MODELS]) 72 | 73 | 74 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 75 | """Read and unpack the next bytes from a binary file. 76 | :param fid: 77 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 78 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 79 | :param endian_character: Any of {@, =, <, >, !} 80 | :return: Tuple of read and unpacked values. 81 | """ 82 | data = fid.read(num_bytes) 83 | return struct.unpack(endian_character + format_char_sequence, data) 84 | 85 | 86 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 87 | """pack and write to a binary file. 88 | :param fid: 89 | :param data: data to send, if multiple elements are sent at the same time, 90 | they should be encapsuled either in a list or a tuple 91 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
92 | should be the same length as the data list or tuple 93 | :param endian_character: Any of {@, =, <, >, !} 94 | """ 95 | if isinstance(data, (list, tuple)): 96 | bytes = struct.pack(endian_character + format_char_sequence, *data) 97 | else: 98 | bytes = struct.pack(endian_character + format_char_sequence, data) 99 | fid.write(bytes) 100 | 101 | 102 | def read_cameras_text(path): 103 | """ 104 | see: src/base/reconstruction.cc 105 | void Reconstruction::WriteCamerasText(const std::string& path) 106 | void Reconstruction::ReadCamerasText(const std::string& path) 107 | """ 108 | cameras = {} 109 | with open(path, "r") as fid: 110 | while True: 111 | line = fid.readline() 112 | if not line: 113 | break 114 | line = line.strip() 115 | if len(line) > 0 and line[0] != "#": 116 | elems = line.split() 117 | camera_id = int(elems[0]) 118 | model = elems[1] 119 | width = int(elems[2]) 120 | height = int(elems[3]) 121 | params = np.array(tuple(map(float, elems[4:]))) 122 | cameras[camera_id] = Camera(id=camera_id, model=model, 123 | width=width, height=height, 124 | params=params) 125 | return cameras 126 | 127 | 128 | def read_cameras_binary(path_to_model_file): 129 | """ 130 | see: src/base/reconstruction.cc 131 | void Reconstruction::WriteCamerasBinary(const std::string& path) 132 | void Reconstruction::ReadCamerasBinary(const std::string& path) 133 | """ 134 | cameras = {} 135 | with open(path_to_model_file, "rb") as fid: 136 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 137 | for camera_line_index in range(num_cameras): 138 | camera_properties = read_next_bytes( 139 | fid, num_bytes=24, format_char_sequence="iiQQ") 140 | camera_id = camera_properties[0] 141 | model_id = camera_properties[1] 142 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 143 | width = camera_properties[2] 144 | height = camera_properties[3] 145 | num_params = CAMERA_MODEL_IDS[model_id].num_params 146 | params = read_next_bytes(fid, num_bytes=8*num_params, 147 | format_char_sequence="d"*num_params) 148 | cameras[camera_id] = Camera(id=camera_id, 149 | model=model_name, 150 | width=width, 151 | height=height, 152 | params=np.array(params)) 153 | assert len(cameras) == num_cameras 154 | return cameras 155 | 156 | 157 | def write_cameras_text(cameras, path): 158 | """ 159 | see: src/base/reconstruction.cc 160 | void Reconstruction::WriteCamerasText(const std::string& path) 161 | void Reconstruction::ReadCamerasText(const std::string& path) 162 | """ 163 | HEADER = '# Camera list with one line of data per camera:\n' 164 | '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' 165 | '# Number of cameras: {}\n'.format(len(cameras)) 166 | with open(path, "w") as fid: 167 | fid.write(HEADER) 168 | for _, cam in cameras.items(): 169 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 170 | line = " ".join([str(elem) for elem in to_write]) 171 | fid.write(line + "\n") 172 | 173 | 174 | def write_cameras_binary(cameras, path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::WriteCamerasBinary(const std::string& path) 178 | void Reconstruction::ReadCamerasBinary(const std::string& path) 179 | """ 180 | with open(path_to_model_file, "wb") as fid: 181 | write_next_bytes(fid, len(cameras), "Q") 182 | for _, cam in cameras.items(): 183 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 184 | camera_properties = [cam.id, 185 | model_id, 186 | cam.width, 187 | cam.height] 188 | write_next_bytes(fid, camera_properties, "iiQQ") 189 | for p in cam.params: 190 | 
write_next_bytes(fid, float(p), "d") 191 | return cameras 192 | 193 | 194 | def read_images_text(path): 195 | """ 196 | see: src/base/reconstruction.cc 197 | void Reconstruction::ReadImagesText(const std::string& path) 198 | void Reconstruction::WriteImagesText(const std::string& path) 199 | """ 200 | images = {} 201 | with open(path, "r") as fid: 202 | while True: 203 | line = fid.readline() 204 | if not line: 205 | break 206 | line = line.strip() 207 | if len(line) > 0 and line[0] != "#": 208 | elems = line.split() 209 | image_id = int(elems[0]) 210 | qvec = np.array(tuple(map(float, elems[1:5]))) 211 | tvec = np.array(tuple(map(float, elems[5:8]))) 212 | camera_id = int(elems[8]) 213 | image_name = elems[9] 214 | elems = fid.readline().split() 215 | xys = np.column_stack([tuple(map(float, elems[0::3])), 216 | tuple(map(float, elems[1::3]))]) 217 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 218 | images[image_id] = Image( 219 | id=image_id, qvec=qvec, tvec=tvec, 220 | camera_id=camera_id, name=image_name, 221 | xys=xys, point3D_ids=point3D_ids) 222 | return images 223 | 224 | 225 | def read_images_binary(path_to_model_file): 226 | """ 227 | see: src/base/reconstruction.cc 228 | void Reconstruction::ReadImagesBinary(const std::string& path) 229 | void Reconstruction::WriteImagesBinary(const std::string& path) 230 | """ 231 | images = {} 232 | with open(path_to_model_file, "rb") as fid: 233 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 234 | for image_index in range(num_reg_images): 235 | binary_image_properties = read_next_bytes( 236 | fid, num_bytes=64, format_char_sequence="idddddddi") 237 | image_id = binary_image_properties[0] 238 | qvec = np.array(binary_image_properties[1:5]) 239 | tvec = np.array(binary_image_properties[5:8]) 240 | camera_id = binary_image_properties[8] 241 | image_name = "" 242 | current_char = read_next_bytes(fid, 1, "c")[0] 243 | while current_char != b"\x00": # look for the ASCII 0 entry 244 | image_name += current_char.decode("utf-8") 245 | current_char = read_next_bytes(fid, 1, "c")[0] 246 | num_points2D = read_next_bytes(fid, num_bytes=8, 247 | format_char_sequence="Q")[0] 248 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 249 | format_char_sequence="ddq"*num_points2D) 250 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 251 | tuple(map(float, x_y_id_s[1::3]))]) 252 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 253 | images[image_id] = Image( 254 | id=image_id, qvec=qvec, tvec=tvec, 255 | camera_id=camera_id, name=image_name, 256 | xys=xys, point3D_ids=point3D_ids) 257 | return images 258 | 259 | 260 | def write_images_text(images, path): 261 | """ 262 | see: src/base/reconstruction.cc 263 | void Reconstruction::ReadImagesText(const std::string& path) 264 | void Reconstruction::WriteImagesText(const std::string& path) 265 | """ 266 | if len(images) == 0: 267 | mean_observations = 0 268 | else: 269 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 270 | HEADER = '# Image list with two lines of data per image:\n' 271 | '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' 272 | '# POINTS2D[] as (X, Y, POINT3D_ID)\n' 273 | '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations) 274 | 275 | with open(path, "w") as fid: 276 | fid.write(HEADER) 277 | for _, img in images.items(): 278 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 279 | first_line = " ".join(map(str, image_header)) 280 | 
fid.write(first_line + "\n") 281 | 282 | points_strings = [] 283 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 284 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 285 | fid.write(" ".join(points_strings) + "\n") 286 | 287 | 288 | def write_images_binary(images, path_to_model_file): 289 | """ 290 | see: src/base/reconstruction.cc 291 | void Reconstruction::ReadImagesBinary(const std::string& path) 292 | void Reconstruction::WriteImagesBinary(const std::string& path) 293 | """ 294 | with open(path_to_model_file, "wb") as fid: 295 | write_next_bytes(fid, len(images), "Q") 296 | for _, img in images.items(): 297 | write_next_bytes(fid, img.id, "i") 298 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 299 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 300 | write_next_bytes(fid, img.camera_id, "i") 301 | for char in img.name: 302 | write_next_bytes(fid, char.encode("utf-8"), "c") 303 | write_next_bytes(fid, b"\x00", "c") 304 | write_next_bytes(fid, len(img.point3D_ids), "Q") 305 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 306 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 307 | 308 | 309 | def read_points3D_text(path): 310 | """ 311 | see: src/base/reconstruction.cc 312 | void Reconstruction::ReadPoints3DText(const std::string& path) 313 | void Reconstruction::WritePoints3DText(const std::string& path) 314 | """ 315 | points3D = {} 316 | with open(path, "r") as fid: 317 | while True: 318 | line = fid.readline() 319 | if not line: 320 | break 321 | line = line.strip() 322 | if len(line) > 0 and line[0] != "#": 323 | elems = line.split() 324 | point3D_id = int(elems[0]) 325 | xyz = np.array(tuple(map(float, elems[1:4]))) 326 | rgb = np.array(tuple(map(int, elems[4:7]))) 327 | error = float(elems[7]) 328 | image_ids = np.array(tuple(map(int, elems[8::2]))) 329 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 330 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 331 | error=error, image_ids=image_ids, 332 | point2D_idxs=point2D_idxs) 333 | return points3D 334 | 335 | 336 | def read_points3d_binary(path_to_model_file): 337 | """ 338 | see: src/base/reconstruction.cc 339 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 340 | void Reconstruction::WritePoints3DBinary(const std::string& path) 341 | """ 342 | points3D = {} 343 | with open(path_to_model_file, "rb") as fid: 344 | num_points = read_next_bytes(fid, 8, "Q")[0] 345 | for point_line_index in range(num_points): 346 | binary_point_line_properties = read_next_bytes( 347 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 348 | point3D_id = binary_point_line_properties[0] 349 | xyz = np.array(binary_point_line_properties[1:4]) 350 | rgb = np.array(binary_point_line_properties[4:7]) 351 | error = np.array(binary_point_line_properties[7]) 352 | track_length = read_next_bytes( 353 | fid, num_bytes=8, format_char_sequence="Q")[0] 354 | track_elems = read_next_bytes( 355 | fid, num_bytes=8*track_length, 356 | format_char_sequence="ii"*track_length) 357 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 358 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 359 | points3D[point3D_id] = Point3D( 360 | id=point3D_id, xyz=xyz, rgb=rgb, 361 | error=error, image_ids=image_ids, 362 | point2D_idxs=point2D_idxs) 363 | return points3D 364 | 365 | 366 | def write_points3D_text(points3D, path): 367 | """ 368 | see: src/base/reconstruction.cc 369 | void Reconstruction::ReadPoints3DText(const std::string& path) 370 | void Reconstruction::WritePoints3DText(const 
std::string& path) 371 | """ 372 | if len(points3D) == 0: 373 | mean_track_length = 0 374 | else: 375 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) 376 | HEADER = '# 3D point list with one line of data per point:\n' 377 | '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' 378 | '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length) 379 | 380 | with open(path, "w") as fid: 381 | fid.write(HEADER) 382 | for _, pt in points3D.items(): 383 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 384 | fid.write(" ".join(map(str, point_header)) + " ") 385 | track_strings = [] 386 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): 387 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 388 | fid.write(" ".join(track_strings) + "\n") 389 | 390 | 391 | def write_points3d_binary(points3D, path_to_model_file): 392 | """ 393 | see: src/base/reconstruction.cc 394 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 395 | void Reconstruction::WritePoints3DBinary(const std::string& path) 396 | """ 397 | with open(path_to_model_file, "wb") as fid: 398 | write_next_bytes(fid, len(points3D), "Q") 399 | for _, pt in points3D.items(): 400 | write_next_bytes(fid, pt.id, "Q") 401 | write_next_bytes(fid, pt.xyz.tolist(), "ddd") 402 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 403 | write_next_bytes(fid, pt.error, "d") 404 | track_length = pt.image_ids.shape[0] 405 | write_next_bytes(fid, track_length, "Q") 406 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 407 | write_next_bytes(fid, [image_id, point2D_id], "ii") 408 | 409 | 410 | def read_model(path, ext): 411 | if ext == ".txt": 412 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 413 | images = read_images_text(os.path.join(path, "images" + ext)) 414 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 415 | else: 416 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 417 | images = read_images_binary(os.path.join(path, "images" + ext)) 418 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 419 | return cameras, images, points3D 420 | 421 | 422 | def write_model(cameras, images, points3D, path, ext): 423 | if ext == ".txt": 424 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 425 | write_images_text(images, os.path.join(path, "images" + ext)) 426 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 427 | else: 428 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 429 | write_images_binary(images, os.path.join(path, "images" + ext)) 430 | write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) 431 | return cameras, images, points3D 432 | 433 | 434 | def qvec2rotmat(qvec): 435 | return np.array([ 436 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 437 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 438 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 439 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 440 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 441 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 442 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 443 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 444 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 445 | 446 | 447 | def rotmat2qvec(R): 448 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 449 | K = np.array([ 450 | [Rxx - Ryy - Rzz, 0, 0, 0], 451 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 452 | [Rzx + Rxz, Rzy + Ryz, 
Rzz - Rxx - Ryy, 0], 453 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 454 | eigvals, eigvecs = np.linalg.eigh(K) 455 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 456 | if qvec[0] < 0: 457 | qvec *= -1 458 | return qvec 459 | 460 | 461 | def main(): 462 | parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models') 463 | parser.add_argument('input_model', help='path to input model folder') 464 | parser.add_argument('input_format', choices=['.bin', '.txt'], 465 | help='input model format') 466 | parser.add_argument('--output_model', metavar='PATH', 467 | help='path to output model folder') 468 | parser.add_argument('--output_format', choices=['.bin', '.txt'], 469 | help='outut model format', default='.txt') 470 | args = parser.parse_args() 471 | 472 | cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) 473 | 474 | print("num_cameras:", len(cameras)) 475 | print("num_images:", len(images)) 476 | print("num_points3D:", len(points3D)) 477 | 478 | if args.output_model is not None: 479 | write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) 480 | 481 | 482 | if __name__ == "__main__": 483 | main() 484 | -------------------------------------------------------------------------------- /colmap_runner/run_colmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from extract_sfm import extract_all_to_dir 4 | from normalize_cam_dict import normalize_cam_dict 5 | 6 | ######################################################################### 7 | # Note: configure the colmap_bin to the colmap executable on your machine 8 | ######################################################################### 9 | 10 | def bash_run(cmd): 11 | colmap_bin = '/home/zhangka2/code/colmap/build/__install__/bin/colmap' 12 | cmd = colmap_bin + ' ' + cmd 13 | print('\nRunning cmd: ', cmd) 14 | 15 | subprocess.check_call(['/bin/bash', '-c', cmd]) 16 | 17 | 18 | gpu_index = '-1' 19 | 20 | 21 | def run_sift_matching(img_dir, db_file, remove_exist=False): 22 | print('Running sift matching...') 23 | 24 | if remove_exist and os.path.exists(db_file): 25 | os.remove(db_file) # otherwise colmap will skip sift matching 26 | 27 | # feature extraction 28 | # if there's no attached display, cannot use feature extractor with GPU 29 | cmd = ' feature_extractor --database_path {} \ 30 | --image_path {} \ 31 | --ImageReader.single_camera 1 \ 32 | --ImageReader.camera_model SIMPLE_RADIAL \ 33 | --SiftExtraction.max_image_size 5000 \ 34 | --SiftExtraction.estimate_affine_shape 0 \ 35 | --SiftExtraction.domain_size_pooling 1 \ 36 | --SiftExtraction.use_gpu 1 \ 37 | --SiftExtraction.max_num_features 16384 \ 38 | --SiftExtraction.gpu_index {}'.format(db_file, img_dir, gpu_index) 39 | bash_run(cmd) 40 | 41 | # feature matching 42 | cmd = ' exhaustive_matcher --database_path {} \ 43 | --SiftMatching.guided_matching 1 \ 44 | --SiftMatching.use_gpu 1 \ 45 | --SiftMatching.max_num_matches 65536 \ 46 | --SiftMatching.max_error 3 \ 47 | --SiftMatching.gpu_index {}'.format(db_file, gpu_index) 48 | 49 | bash_run(cmd) 50 | 51 | 52 | def run_sfm(img_dir, db_file, out_dir): 53 | print('Running SfM...') 54 | 55 | cmd = ' mapper \ 56 | --database_path {} \ 57 | --image_path {} \ 58 | --output_path {} \ 59 | --Mapper.tri_min_angle 3.0 \ 60 | --Mapper.filter_min_tri_angle 3.0'.format(db_file, img_dir, out_dir) 61 | 62 | bash_run(cmd) 63 | 64 | 65 | def prepare_mvs(img_dir, 
sparse_dir, mvs_dir): 66 | print('Preparing for MVS...') 67 | 68 | cmd = ' image_undistorter \ 69 | --image_path {} \ 70 | --input_path {} \ 71 | --output_path {} \ 72 | --output_type COLMAP \ 73 | --max_image_size 2000'.format(img_dir, sparse_dir, mvs_dir) 74 | 75 | bash_run(cmd) 76 | 77 | 78 | def run_photometric_mvs(mvs_dir, window_radius): 79 | print('Running photometric MVS...') 80 | 81 | cmd = ' patch_match_stereo --workspace_path {} \ 82 | --PatchMatchStereo.window_radius {} \ 83 | --PatchMatchStereo.min_triangulation_angle 3.0 \ 84 | --PatchMatchStereo.filter 1 \ 85 | --PatchMatchStereo.geom_consistency 1 \ 86 | --PatchMatchStereo.gpu_index={} \ 87 | --PatchMatchStereo.num_samples 15 \ 88 | --PatchMatchStereo.num_iterations 12'.format(mvs_dir, 89 | window_radius, gpu_index) 90 | 91 | bash_run(cmd) 92 | 93 | 94 | def run_fuse(mvs_dir, out_ply): 95 | print('Running depth fusion...') 96 | 97 | cmd = ' stereo_fusion --workspace_path {} \ 98 | --output_path {} \ 99 | --input_type geometric'.format(mvs_dir, out_ply) 100 | 101 | bash_run(cmd) 102 | 103 | 104 | def run_possion_mesher(in_ply, out_ply, trim): 105 | print('Running possion mesher...') 106 | 107 | cmd = ' poisson_mesher \ 108 | --input_path {} \ 109 | --output_path {} \ 110 | --PoissonMeshing.trim {}'.format(in_ply, out_ply, trim) 111 | 112 | bash_run(cmd) 113 | 114 | 115 | def main(img_dir, out_dir, run_mvs=False): 116 | os.makedirs(out_dir, exist_ok=True) 117 | 118 | #### run sfm 119 | sfm_dir = os.path.join(out_dir, 'sfm') 120 | os.makedirs(sfm_dir, exist_ok=True) 121 | 122 | img_dir_link = os.path.join(sfm_dir, 'images') 123 | if os.path.exists(img_dir_link): 124 | os.remove(img_dir_link) 125 | os.symlink(img_dir, img_dir_link) 126 | 127 | db_file = os.path.join(sfm_dir, 'database.db') 128 | run_sift_matching(img_dir, db_file, remove_exist=False) 129 | sparse_dir = os.path.join(sfm_dir, 'sparse') 130 | os.makedirs(sparse_dir, exist_ok=True) 131 | run_sfm(img_dir, db_file, sparse_dir) 132 | 133 | # undistort images 134 | mvs_dir = os.path.join(out_dir, 'mvs') 135 | os.makedirs(mvs_dir, exist_ok=True) 136 | prepare_mvs(img_dir, sparse_dir, mvs_dir) 137 | 138 | # extract camera parameters and undistorted images 139 | os.makedirs(os.path.join(out_dir, 'posed_images'), exist_ok=True) 140 | extract_all_to_dir(os.path.join(mvs_dir, 'sparse'), os.path.join(out_dir, 'posed_images')) 141 | undistorted_img_dir = os.path.join(mvs_dir, 'images') 142 | posed_img_dir_link = os.path.join(out_dir, 'posed_images/images') 143 | if os.path.exists(posed_img_dir_link): 144 | os.remove(posed_img_dir_link) 145 | os.symlink(undistorted_img_dir, posed_img_dir_link) 146 | # normalize average camera center to origin, and put all cameras inside the unit sphere 147 | normalize_cam_dict(os.path.join(out_dir, 'posed_images/kai_cameras.json'), 148 | os.path.join(out_dir, 'posed_images/kai_cameras_normalized.json')) 149 | 150 | if run_mvs: 151 | # run mvs 152 | run_photometric_mvs(mvs_dir, window_radius=7) 153 | 154 | out_ply = os.path.join(out_dir, 'mvs/fused.ply') 155 | run_fuse(mvs_dir, out_ply) 156 | 157 | out_mesh_ply = os.path.join(out_dir, 'mvs/meshed_trim_3.ply') 158 | run_possion_mesher(out_ply, out_mesh_ply, trim=3) 159 | 160 | 161 | if __name__ == '__main__': 162 | ### note: this script is intended for the case where all images are taken by the same camera, i.e., intrinisics are shared. 
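    # A minimal usage sketch; the paths below are placeholders for your own capture,
    # and colmap_bin in bash_run() must already point at a local COLMAP executable:
    #
    #     img_dir = './data/my_scene/images'    # folder of input photos
    #     out_dir = './data/my_scene/colmap'    # created by main() if missing
    #     main(img_dir, out_dir, run_mvs=False) # run_mvs=True also runs patch-match
    #                                           # stereo, fusion and Poisson meshing
    #
    # Afterwards, out_dir/posed_images/kai_cameras_normalized.json holds the camera
    # dictionary rescaled so the average camera center sits at the origin and all
    # cameras fit inside the unit sphere, and out_dir/posed_images/images links to
    # the undistorted images that those poses correspond to.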
163 | 164 | img_dir = '' 165 | out_dir = '' 166 | run_mvs = False 167 | main(img_dir, out_dir, run_mvs=run_mvs) 168 | 169 | -------------------------------------------------------------------------------- /colmap_runner/run_colmap_posed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from database import COLMAPDatabase 4 | from pyquaternion import Quaternion 5 | import numpy as np 6 | import imageio 7 | import subprocess 8 | 9 | 10 | def bash_run(cmd): 11 | # local install of colmap 12 | env = os.environ.copy() 13 | env['LD_LIBRARY_PATH'] = '/home/zhangka2/code/colmap/build/__install__/lib' 14 | 15 | colmap_bin = '/home/zhangka2/code/colmap/build/__install__/bin/colmap' 16 | cmd = colmap_bin + ' ' + cmd 17 | print('\nRunning cmd: ', cmd) 18 | 19 | subprocess.check_call(['/bin/bash', '-c', cmd], env=env) 20 | 21 | 22 | gpu_index = '-1' 23 | 24 | 25 | def run_sift_matching(img_dir, db_file): 26 | print('Running sift matching...') 27 | 28 | # if os.path.exists(db_file): # otherwise colmap will skip sift matching 29 | # os.remove(db_file) 30 | 31 | # feature extraction 32 | # if there's no attached display, cannot use feature extractor with GPU 33 | cmd = ' feature_extractor --database_path {} \ 34 | --image_path {} \ 35 | --ImageReader.camera_model PINHOLE \ 36 | --SiftExtraction.max_image_size 5000 \ 37 | --SiftExtraction.estimate_affine_shape 0 \ 38 | --SiftExtraction.domain_size_pooling 1 \ 39 | --SiftExtraction.num_threads 32 \ 40 | --SiftExtraction.use_gpu 1 \ 41 | --SiftExtraction.gpu_index {}'.format(db_file, img_dir, gpu_index) 42 | bash_run(cmd) 43 | 44 | # feature matching 45 | cmd = ' exhaustive_matcher --database_path {} \ 46 | --SiftMatching.guided_matching 1 \ 47 | --SiftMatching.use_gpu 1 \ 48 | --SiftMatching.gpu_index {}'.format(db_file, gpu_index) 49 | 50 | bash_run(cmd) 51 | 52 | 53 | def create_init_files(pinhole_dict_file, db_file, out_dir): 54 | if not os.path.exists(out_dir): 55 | os.mkdir(out_dir) 56 | 57 | # create template 58 | with open(pinhole_dict_file) as fp: 59 | pinhole_dict = json.load(fp) 60 | 61 | template = {} 62 | cameras_line_template = '{camera_id} PINHOLE {width} {height} {fx} {fy} {cx} {cy}\n' 63 | images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\n\n' 64 | 65 | for img_name in pinhole_dict: 66 | # w, h, fx, fy, cx, cy, qvec, t 67 | params = pinhole_dict[img_name] 68 | w = params[0] 69 | h = params[1] 70 | fx = params[2] 71 | fy = params[3] 72 | cx = params[4] 73 | cy = params[5] 74 | qvec = params[6:10] 75 | tvec = params[10:13] 76 | 77 | cam_line = cameras_line_template.format(camera_id="{camera_id}", width=w, height=h, fx=fx, fy=fy, cx=cx, cy=cy) 78 | img_line = images_line_template.format(image_id="{image_id}", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3], 79 | tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id="{camera_id}", image_name=img_name) 80 | template[img_name] = (cam_line, img_line) 81 | 82 | # read database 83 | db = COLMAPDatabase.connect(db_file) 84 | table_images = db.execute("SELECT * FROM images") 85 | img_name2id_dict = {} 86 | for row in table_images: 87 | img_name2id_dict[row[1]] = row[0] 88 | 89 | cameras_txt_lines = [] 90 | images_txt_lines = [] 91 | for img_name, img_id in img_name2id_dict.items(): 92 | camera_line = template[img_name][0].format(camera_id=img_id) 93 | cameras_txt_lines.append(camera_line) 94 | 95 | image_line = template[img_name][1].format(image_id=img_id, camera_id=img_id) 96 | 
images_txt_lines.append(image_line) 97 | 98 | with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp: 99 | fp.writelines(cameras_txt_lines) 100 | 101 | with open(os.path.join(out_dir, 'images.txt'), 'w') as fp: 102 | fp.writelines(images_txt_lines) 103 | fp.write('\n') 104 | 105 | # create an empty points3D.txt 106 | fp = open(os.path.join(out_dir, 'points3D.txt'), 'w') 107 | fp.close() 108 | 109 | 110 | def run_point_triangulation(img_dir, db_file, out_dir): 111 | print('Running point triangulation...') 112 | 113 | # triangulate points 114 | cmd = ' point_triangulator --database_path {} \ 115 | --image_path {} \ 116 | --input_path {} \ 117 | --output_path {} \ 118 | --Mapper.tri_ignore_two_view_tracks 1'.format(db_file, img_dir, out_dir, out_dir) 119 | bash_run(cmd) 120 | 121 | 122 | # this step is optional 123 | def run_global_ba(in_dir, out_dir): 124 | print('Running global BA...') 125 | if not os.path.exists(out_dir): 126 | os.mkdir(out_dir) 127 | 128 | cmd = ' bundle_adjuster --input_path {in_dir} --output_path {out_dir}'.format(in_dir=in_dir, out_dir=out_dir) 129 | 130 | bash_run(cmd) 131 | 132 | 133 | def prepare_mvs(img_dir, sfm_dir, mvs_dir): 134 | if not os.path.exists(mvs_dir): 135 | os.mkdir(mvs_dir) 136 | 137 | images_symlink = os.path.join(mvs_dir, 'images') 138 | if os.path.exists(images_symlink): 139 | os.unlink(images_symlink) 140 | os.symlink(os.path.relpath(img_dir, mvs_dir), 141 | images_symlink) 142 | 143 | sparse_symlink = os.path.join(mvs_dir, 'sparse') 144 | if os.path.exists(sparse_symlink): 145 | os.unlink(sparse_symlink) 146 | os.symlink(os.path.relpath(sfm_dir, mvs_dir), 147 | sparse_symlink) 148 | 149 | # prepare stereo directory 150 | stereo_dir = os.path.join(mvs_dir, 'stereo') 151 | for subdir in [stereo_dir, 152 | os.path.join(stereo_dir, 'depth_maps'), 153 | os.path.join(stereo_dir, 'normal_maps'), 154 | os.path.join(stereo_dir, 'consistency_graphs')]: 155 | if not os.path.exists(subdir): 156 | os.mkdir(subdir) 157 | 158 | # write patch-match.cfg and fusion.cfg 159 | image_names = sorted(os.listdir(os.path.join(mvs_dir, 'images'))) 160 | 161 | with open(os.path.join(stereo_dir, 'patch-match.cfg'), 'w') as fp: 162 | for img_name in image_names: 163 | fp.write(img_name + '\n__auto__, 20\n') 164 | 165 | # use all images 166 | # fp.write(img_name + '\n__all__\n') 167 | 168 | # randomly choose 20 images 169 | # from random import shuffle 170 | # candi_src_images = [x for x in image_names if x != img_name] 171 | # shuffle(candi_src_images) 172 | # max_src_images = 10 173 | # fp.write(img_name + '\n' + ', '.join(candi_src_images[:max_src_images]) + '\n') 174 | 175 | with open(os.path.join(stereo_dir, 'fusion.cfg'), 'w') as fp: 176 | for img_name in image_names: 177 | fp.write(img_name + '\n') 178 | 179 | 180 | def run_photometric_mvs(mvs_dir, window_radius): 181 | print('Running photometric MVS...') 182 | 183 | cmd = ' patch_match_stereo --workspace_path {} \ 184 | --PatchMatchStereo.window_radius {} \ 185 | --PatchMatchStereo.min_triangulation_angle 3.0 \ 186 | --PatchMatchStereo.filter 1 \ 187 | --PatchMatchStereo.geom_consistency 1 \ 188 | --PatchMatchStereo.gpu_index={} \ 189 | --PatchMatchStereo.num_samples 15 \ 190 | --PatchMatchStereo.num_iterations 12'.format(mvs_dir, 191 | window_radius, gpu_index) 192 | 193 | bash_run(cmd) 194 | 195 | 196 | def run_fuse(mvs_dir, out_ply): 197 | print('Running depth fusion...') 198 | 199 | cmd = ' stereo_fusion --workspace_path {} \ 200 | --output_path {} \ 201 | --input_type geometric'.format(mvs_dir, out_ply) 202 | 
bash_run(cmd) 203 | 204 | 205 | def run_possion_mesher(in_ply, out_ply, trim): 206 | print('Running possion mesher...') 207 | 208 | cmd = ' poisson_mesher \ 209 | --input_path {} \ 210 | --output_path {} \ 211 | --PoissonMeshing.trim {}'.format(in_ply, out_ply, trim) 212 | 213 | bash_run(cmd) 214 | 215 | 216 | def main(img_dir, pinhole_dict_file, out_dir): 217 | if not os.path.exists(out_dir): 218 | os.mkdir(out_dir) 219 | 220 | db_file = os.path.join(out_dir, 'database.db') 221 | run_sift_matching(img_dir, db_file) 222 | 223 | sfm_dir = os.path.join(out_dir, 'sfm') 224 | create_init_files(pinhole_dict_file, db_file, sfm_dir) 225 | run_point_triangulation(img_dir, db_file, sfm_dir) 226 | 227 | # # optional 228 | # run_global_ba(sfm_dir, sfm_dir) 229 | 230 | mvs_dir = os.path.join(out_dir, 'mvs') 231 | prepare_mvs(img_dir, sfm_dir, mvs_dir) 232 | run_photometric_mvs(mvs_dir, window_radius=5) 233 | 234 | out_ply = os.path.join(out_dir, 'fused.ply') 235 | run_fuse(mvs_dir, out_ply) 236 | 237 | out_mesh_ply = os.path.join(out_dir, 'meshed_trim_3.ply') 238 | run_possion_mesher(out_ply, out_mesh_ply, trim=3) 239 | 240 | 241 | def convert_cam_dict_to_pinhole_dict(cam_dict_file, pinhole_dict_file, img_dir): 242 | print('Writing pinhole_dict to: ', pinhole_dict_file) 243 | 244 | with open(cam_dict_file) as fp: 245 | cam_dict = json.load(fp) 246 | 247 | pinhole_dict = {} 248 | for img_name in cam_dict: 249 | data_item = cam_dict[img_name] 250 | if 'img_size' in data_item: 251 | w, h = data_item['img_size'] 252 | else: 253 | im = imageio.imread(os.path.join(img_dir, img_name)) 254 | h, w = im.shape[:2] 255 | 256 | K = np.array(data_item['K']).reshape((4, 4)) 257 | W2C = np.array(data_item['W2C']).reshape((4, 4)) 258 | 259 | # params 260 | fx = K[0, 0] 261 | fy = K[1, 1] 262 | assert(np.isclose(K[0, 1], 0.)) 263 | cx = K[0, 2] 264 | cy = K[1, 2] 265 | 266 | print(img_name) 267 | R = W2C[:3, :3] 268 | print(R) 269 | u, s_old, vh = np.linalg.svd(R, full_matrices=False) 270 | s = np.round(s_old) 271 | print('s: {} ---> {}'.format(s_old, s)) 272 | R = np.dot(u * s, vh) 273 | 274 | qvec = Quaternion(matrix=R) 275 | tvec = W2C[:3, 3] 276 | 277 | params = [w, h, fx, fy, cx, cy, 278 | qvec[0], qvec[1], qvec[2], qvec[3], 279 | tvec[0], tvec[1], tvec[2]] 280 | pinhole_dict[img_name] = params 281 | 282 | with open(pinhole_dict_file, 'w') as fp: 283 | json.dump(pinhole_dict, fp, indent=2, sort_keys=True) 284 | 285 | 286 | if __name__ == '__main__': 287 | img_dir = '' 288 | cam_dict_file = '' 289 | out_dir = '' 290 | 291 | os.makedirs(out_dir, exist_ok=True) 292 | pinhole_dict_file = os.path.join(out_dir, 'pinhole_dict.json') 293 | convert_cam_dict_to_pinhole_dict(cam_dict_file, pinhole_dict_file, img_dir) 294 | 295 | main(img_dir, pinhole_dict_file, out_dir) 296 | -------------------------------------------------------------------------------- /configs/lf_data/lf_africa.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | datadir = ./data/lf_data 3 | scene = africa 4 | expname = africa 5 | basedir = ./logs 6 | config = None 7 | ckpt_path = None 8 | no_reload = False 9 | testskip = 1 10 | 11 | ### TRAINING 12 | N_iters = 500001 13 | N_rand = 1024 14 | lrate = 0.0005 15 | lrate_decay_factor = 0.1 16 | lrate_decay_steps = 50000000 17 | 18 | ### CASCADE 19 | cascade_level = 2 20 | cascade_samples = 64,128 21 | 22 | ### TESTING 23 | chunk_size = 8192 24 | 25 | ### RENDERING 26 | det = False 27 | max_freq_log2 = 10 28 | max_freq_log2_viewdirs = 4 29 | netdepth = 8 30 | 
netwidth = 256 31 | use_viewdirs = True 32 | 33 | ### CONSOLE AND TENSORBOARD 34 | i_img = 2000 35 | i_print = 100 36 | i_weights = 5000 37 | -------------------------------------------------------------------------------- /configs/tanks_and_temples/tat_training_truck.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | datadir = ./data/tanks_and_temples 3 | scene = tat_training_Truck 4 | expname = tat_training_Truck 5 | basedir = ./logs 6 | config = None 7 | ckpt_path = None 8 | no_reload = False 9 | testskip = 1 10 | 11 | ### TRAINING 12 | N_iters = 500001 13 | N_rand = 1024 14 | lrate = 0.0005 15 | lrate_decay_factor = 0.1 16 | lrate_decay_steps = 50000000 17 | 18 | ### CASCADE 19 | cascade_level = 2 20 | cascade_samples = 64,128 21 | 22 | ### TESTING 23 | chunk_size = 8192 24 | 25 | ### RENDERING 26 | det = False 27 | max_freq_log2 = 10 28 | max_freq_log2_viewdirs = 4 29 | netdepth = 8 30 | netwidth = 256 31 | use_viewdirs = True 32 | 33 | ### CONSOLE AND TENSORBOARD 34 | i_img = 2000 35 | i_print = 100 36 | i_weights = 5000 37 | -------------------------------------------------------------------------------- /data_loader_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | import logging 5 | from nerf_sample_ray_split import RaySamplerSingleImage 6 | import glob 7 | 8 | logger = logging.getLogger(__package__) 9 | 10 | ######################################################################################################################## 11 | # camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention) 12 | # poses is camera-to-world 13 | ######################################################################################################################## 14 | def find_files(dir, exts): 15 | if os.path.isdir(dir): 16 | # types should be ['*.png', '*.jpg'] 17 | files_grabbed = [] 18 | for ext in exts: 19 | files_grabbed.extend(glob.glob(os.path.join(dir, ext))) 20 | if len(files_grabbed) > 0: 21 | files_grabbed = sorted(files_grabbed) 22 | return files_grabbed 23 | else: 24 | return [] 25 | 26 | 27 | def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True, only_img_files=False): 28 | 29 | def parse_txt(filename): 30 | assert os.path.isfile(filename) 31 | nums = open(filename).read().split() 32 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32) 33 | 34 | if basedir[-1] == '/': # remove trailing '/' 35 | basedir = basedir[:-1] 36 | 37 | split_dir = '{}/{}/{}'.format(basedir, scene, split) 38 | 39 | if only_img_files: 40 | img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg']) 41 | return img_files 42 | 43 | # camera parameters files 44 | intrinsics_files = find_files('{}/intrinsics'.format(split_dir), exts=['*.txt']) 45 | pose_files = find_files('{}/pose'.format(split_dir), exts=['*.txt']) 46 | logger.info('raw intrinsics_files: {}'.format(len(intrinsics_files))) 47 | logger.info('raw pose_files: {}'.format(len(pose_files))) 48 | 49 | intrinsics_files = intrinsics_files[::skip] 50 | pose_files = pose_files[::skip] 51 | cam_cnt = len(pose_files) 52 | 53 | # img files 54 | img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg']) 55 | if len(img_files) > 0: 56 | logger.info('raw img_files: {}'.format(len(img_files))) 57 | img_files = img_files[::skip] 58 | assert(len(img_files) == cam_cnt) 59 | else: 60 | img_files = [None, ] * 
cam_cnt 61 | 62 | # mask files 63 | mask_files = find_files('{}/mask'.format(split_dir), exts=['*.png', '*.jpg']) 64 | if len(mask_files) > 0: 65 | logger.info('raw mask_files: {}'.format(len(mask_files))) 66 | mask_files = mask_files[::skip] 67 | assert(len(mask_files) == cam_cnt) 68 | else: 69 | mask_files = [None, ] * cam_cnt 70 | 71 | # min depth files 72 | mindepth_files = find_files('{}/min_depth'.format(split_dir), exts=['*.png', '*.jpg']) 73 | if try_load_min_depth and len(mindepth_files) > 0: 74 | logger.info('raw mindepth_files: {}'.format(len(mindepth_files))) 75 | mindepth_files = mindepth_files[::skip] 76 | assert(len(mindepth_files) == cam_cnt) 77 | else: 78 | mindepth_files = [None, ] * cam_cnt 79 | 80 | # assume all images have the same size as training image 81 | train_imgfile = find_files('{}/{}/train/rgb'.format(basedir, scene), exts=['*.png', '*.jpg'])[0] 82 | train_im = imageio.imread(train_imgfile) 83 | H, W = train_im.shape[:2] 84 | 85 | # create ray samplers 86 | ray_samplers = [] 87 | for i in range(cam_cnt): 88 | intrinsics = parse_txt(intrinsics_files[i]) 89 | pose = parse_txt(pose_files[i]) 90 | 91 | # read max depth 92 | try: 93 | max_depth = float(open('{}/max_depth.txt'.format(split_dir)).readline().strip()) 94 | except: 95 | max_depth = None 96 | 97 | ray_samplers.append(RaySamplerSingleImage(H=H, W=W, intrinsics=intrinsics, c2w=pose, 98 | img_path=img_files[i], 99 | mask_path=mask_files[i], 100 | min_depth_path=mindepth_files[i], 101 | max_depth=max_depth)) 102 | 103 | logger.info('Split {}, # views: {}'.format(split, cam_cnt)) 104 | 105 | return ray_samplers 106 | -------------------------------------------------------------------------------- /ddp_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # import torch.nn.functional as F 4 | # import numpy as np 5 | from utils import TINY_NUMBER, HUGE_NUMBER 6 | from collections import OrderedDict 7 | from nerf_network import Embedder, MLPNet 8 | import os 9 | import logging 10 | logger = logging.getLogger(__package__) 11 | 12 | 13 | ###################################################################################### 14 | # wrapper to simplify the use of nerfnet 15 | ###################################################################################### 16 | def depth2pts_outside(ray_o, ray_d, depth): 17 | ''' 18 | ray_o, ray_d: [..., 3] 19 | depth: [...]; inverse of distance to sphere origin 20 | ''' 21 | # note: d1 becomes negative if this mid point is behind camera 22 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) 23 | p_mid = ray_o + d1.unsqueeze(-1) * ray_d 24 | p_mid_norm = torch.norm(p_mid, dim=-1) 25 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1) 26 | d2 = torch.sqrt(1. 
- p_mid_norm * p_mid_norm) * ray_d_cos 27 | p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d 28 | 29 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1) 30 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) 31 | phi = torch.asin(p_mid_norm) 32 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1] 33 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1] 34 | 35 | # now rotate p_sphere 36 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula 37 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \ 38 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \ 39 | rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle)) 40 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True) 41 | pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1) 42 | 43 | # now calculate conventional depth 44 | depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1 45 | return pts, depth_real 46 | 47 | 48 | class NerfNet(nn.Module): 49 | def __init__(self, args): 50 | super().__init__() 51 | # foreground 52 | self.fg_embedder_position = Embedder(input_dim=3, 53 | max_freq_log2=args.max_freq_log2 - 1, 54 | N_freqs=args.max_freq_log2) 55 | self.fg_embedder_viewdir = Embedder(input_dim=3, 56 | max_freq_log2=args.max_freq_log2_viewdirs - 1, 57 | N_freqs=args.max_freq_log2_viewdirs) 58 | self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth, 59 | input_ch=self.fg_embedder_position.out_dim, 60 | input_ch_viewdirs=self.fg_embedder_viewdir.out_dim, 61 | use_viewdirs=args.use_viewdirs) 62 | # background; bg_pt is (x, y, z, 1/r) 63 | self.bg_embedder_position = Embedder(input_dim=4, 64 | max_freq_log2=args.max_freq_log2 - 1, 65 | N_freqs=args.max_freq_log2) 66 | self.bg_embedder_viewdir = Embedder(input_dim=3, 67 | max_freq_log2=args.max_freq_log2_viewdirs - 1, 68 | N_freqs=args.max_freq_log2_viewdirs) 69 | self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth, 70 | input_ch=self.bg_embedder_position.out_dim, 71 | input_ch_viewdirs=self.bg_embedder_viewdir.out_dim, 72 | use_viewdirs=args.use_viewdirs) 73 | 74 | def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals): 75 | ''' 76 | :param ray_o, ray_d: [..., 3] 77 | :param fg_z_max: [...,] 78 | :param fg_z_vals, bg_z_vals: [..., N_samples] 79 | :return 80 | ''' 81 | # print(ray_o.shape, ray_d.shape, fg_z_max.shape, fg_z_vals.shape, bg_z_vals.shape) 82 | ray_d_norm = torch.norm(ray_d, dim=-1, keepdim=True) # [..., 1] 83 | viewdirs = ray_d / ray_d_norm # [..., 3] 84 | dots_sh = list(ray_d.shape[:-1]) 85 | 86 | ######### render foreground 87 | N_samples = fg_z_vals.shape[-1] 88 | fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 89 | fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 90 | fg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 91 | fg_pts = fg_ray_o + fg_z_vals.unsqueeze(-1) * fg_ray_d 92 | input = torch.cat((self.fg_embedder_position(fg_pts), 93 | self.fg_embedder_viewdir(fg_viewdirs)), dim=-1) 94 | fg_raw = self.fg_net(input) 95 | # alpha blending 96 | fg_dists = fg_z_vals[..., 1:] - fg_z_vals[..., :-1] 97 | # account for view directions 98 | fg_dists = ray_d_norm * torch.cat((fg_dists, fg_z_max.unsqueeze(-1) - fg_z_vals[..., -1:]), dim=-1) # [..., N_samples] 99 | fg_alpha = 1. - torch.exp(-fg_raw['sigma'] * fg_dists) # [..., N_samples] 100 | T = torch.cumprod(1. 
- fg_alpha + TINY_NUMBER, dim=-1) # [..., N_samples] 101 | bg_lambda = T[..., -1] 102 | T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1) # [..., N_samples] 103 | fg_weights = fg_alpha * T # [..., N_samples] 104 | fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2) # [..., 3] 105 | fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1) # [...,] 106 | 107 | # render background 108 | N_samples = bg_z_vals.shape[-1] 109 | bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 110 | bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 111 | bg_viewdirs = viewdirs.unsqueeze(-2).expand(dots_sh + [N_samples, 3]) 112 | bg_pts, _ = depth2pts_outside(bg_ray_o, bg_ray_d, bg_z_vals) # [..., N_samples, 4] 113 | input = torch.cat((self.bg_embedder_position(bg_pts), 114 | self.bg_embedder_viewdir(bg_viewdirs)), dim=-1) 115 | # near_depth: physical far; far_depth: physical near 116 | input = torch.flip(input, dims=[-2,]) 117 | bg_z_vals = torch.flip(bg_z_vals, dims=[-1,]) # 1--->0 118 | bg_dists = bg_z_vals[..., :-1] - bg_z_vals[..., 1:] 119 | bg_dists = torch.cat((bg_dists, HUGE_NUMBER * torch.ones_like(bg_dists[..., 0:1])), dim=-1) # [..., N_samples] 120 | bg_raw = self.bg_net(input) 121 | bg_alpha = 1. - torch.exp(-bg_raw['sigma'] * bg_dists) # [..., N_samples] 122 | # Eq. (3): T 123 | # maths show weights, and summation of weights along a ray, are always inside [0, 1] 124 | T = torch.cumprod(1. - bg_alpha + TINY_NUMBER, dim=-1)[..., :-1] # [..., N_samples-1] 125 | T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1) # [..., N_samples] 126 | bg_weights = bg_alpha * T # [..., N_samples] 127 | bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2) # [..., 3] 128 | bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1) # [...,] 129 | 130 | # composite foreground and background 131 | bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map 132 | bg_depth_map = bg_lambda * bg_depth_map 133 | rgb_map = fg_rgb_map + bg_rgb_map 134 | 135 | ret = OrderedDict([('rgb', rgb_map), # loss 136 | ('fg_weights', fg_weights), # importance sampling 137 | ('bg_weights', bg_weights), # importance sampling 138 | ('fg_rgb', fg_rgb_map), # below are for logging 139 | ('fg_depth', fg_depth_map), 140 | ('bg_rgb', bg_rgb_map), 141 | ('bg_depth', bg_depth_map), 142 | ('bg_lambda', bg_lambda)]) 143 | return ret 144 | 145 | 146 | def remap_name(name): 147 | name = name.replace('.', '-') # dot is not allowed by pytorch 148 | if name[-1] == '/': 149 | name = name[:-1] 150 | idx = name.rfind('/') 151 | for i in range(2): 152 | if idx >= 0: 153 | idx = name[:idx].rfind('/') 154 | return name[idx + 1:] 155 | 156 | 157 | class NerfNetWithAutoExpo(nn.Module): 158 | def __init__(self, args, optim_autoexpo=False, img_names=None): 159 | super().__init__() 160 | self.nerf_net = NerfNet(args) 161 | 162 | self.optim_autoexpo = optim_autoexpo 163 | if self.optim_autoexpo: 164 | assert(img_names is not None) 165 | logger.info('Optimizing autoexposure!') 166 | 167 | self.img_names = [remap_name(x) for x in img_names] 168 | logger.info('\n'.join(self.img_names)) 169 | self.autoexpo_params = nn.ParameterDict(OrderedDict([(x, nn.Parameter(torch.Tensor([0.5, 0.]))) for x in self.img_names])) 170 | 171 | def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals, img_name=None): 172 | ''' 173 | :param ray_o, ray_d: [..., 3] 174 | :param fg_z_max: [...,] 175 | :param fg_z_vals, bg_z_vals: [..., N_samples] 176 | :return 177 | ''' 178 | ret = self.nerf_net(ray_o, ray_d, 
fg_z_max, fg_z_vals, bg_z_vals) 179 | 180 | if img_name is not None: 181 | img_name = remap_name(img_name) 182 | if self.optim_autoexpo and (img_name in self.autoexpo_params): 183 | autoexpo = self.autoexpo_params[img_name] 184 | scale = torch.abs(autoexpo[0]) + 0.5 # make sure scale is always positive 185 | shift = autoexpo[1] 186 | ret['autoexpo'] = (scale, shift) 187 | 188 | return ret 189 | -------------------------------------------------------------------------------- /ddp_test_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import torch.nn as nn 3 | import torch.optim 4 | import torch.distributed 5 | # from torch.nn.parallel import DistributedDataParallel as DDP 6 | import torch.multiprocessing 7 | import numpy as np 8 | import os 9 | # from collections import OrderedDict 10 | # from ddp_model import NerfNet 11 | import time 12 | from data_loader_split import load_data_split 13 | from utils import mse2psnr, colorize_np, to8b 14 | import imageio 15 | from ddp_train_nerf import config_parser, setup_logger, setup, cleanup, render_single_image, create_nerf 16 | import logging 17 | 18 | 19 | logger = logging.getLogger(__package__) 20 | 21 | 22 | def ddp_test_nerf(rank, args): 23 | ###### set up multi-processing 24 | setup(rank, args.world_size) 25 | ###### set up logger 26 | logger = logging.getLogger(__package__) 27 | setup_logger() 28 | 29 | ###### decide chunk size according to gpu memory 30 | if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14: 31 | logger.info('setting batch size according to 24G gpu') 32 | args.N_rand = 1024 33 | args.chunk_size = 8192 34 | else: 35 | logger.info('setting batch size according to 12G gpu') 36 | args.N_rand = 512 37 | args.chunk_size = 4096 38 | 39 | ###### create network and wrap in ddp; each process should do this 40 | start, models = create_nerf(rank, args) 41 | 42 | render_splits = [x.strip() for x in args.render_splits.strip().split(',')] 43 | # start testing 44 | for split in render_splits: 45 | out_dir = os.path.join(args.basedir, args.expname, 46 | 'render_{}_{:06d}'.format(split, start)) 47 | if rank == 0: 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | ###### load data and create ray samplers; each process should do this 51 | ray_samplers = load_data_split(args.datadir, args.scene, split, try_load_min_depth=args.load_min_depth) 52 | for idx in range(len(ray_samplers)): 53 | ### each process should do this; but only main process merges the results 54 | fname = '{:06d}.png'.format(idx) 55 | if ray_samplers[idx].img_path is not None: 56 | fname = os.path.basename(ray_samplers[idx].img_path) 57 | 58 | if os.path.isfile(os.path.join(out_dir, fname)): 59 | logger.info('Skipping {}'.format(fname)) 60 | continue 61 | 62 | time0 = time.time() 63 | ret = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size) 64 | dt = time.time() - time0 65 | if rank == 0: # only main process should do this 66 | logger.info('Rendered {} in {} seconds'.format(fname, dt)) 67 | 68 | # only save last level 69 | im = ret[-1]['rgb'].numpy() 70 | # compute psnr if ground-truth is available 71 | if ray_samplers[idx].img_path is not None: 72 | gt_im = ray_samplers[idx].get_img() 73 | psnr = mse2psnr(np.mean((gt_im - im) * (gt_im - im))) 74 | logger.info('{}: psnr={}'.format(fname, psnr)) 75 | 76 | im = to8b(im) 77 | imageio.imwrite(os.path.join(out_dir, fname), im) 78 | 79 | im = ret[-1]['fg_rgb'].numpy() 80 | im = to8b(im) 81 | 
imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), im) 82 | 83 | im = ret[-1]['bg_rgb'].numpy() 84 | im = to8b(im) 85 | imageio.imwrite(os.path.join(out_dir, 'bg_' + fname), im) 86 | 87 | im = ret[-1]['fg_depth'].numpy() 88 | im = colorize_np(im, cmap_name='jet', append_cbar=True) 89 | im = to8b(im) 90 | imageio.imwrite(os.path.join(out_dir, 'fg_depth_' + fname), im) 91 | 92 | im = ret[-1]['bg_depth'].numpy() 93 | im = colorize_np(im, cmap_name='jet', append_cbar=True) 94 | im = to8b(im) 95 | imageio.imwrite(os.path.join(out_dir, 'bg_depth_' + fname), im) 96 | 97 | torch.cuda.empty_cache() 98 | 99 | # clean up for multi-processing 100 | cleanup() 101 | 102 | 103 | def test(): 104 | parser = config_parser() 105 | args = parser.parse_args() 106 | logger.info(parser.format_values()) 107 | 108 | if args.world_size == -1: 109 | args.world_size = torch.cuda.device_count() 110 | logger.info('Using # gpus: {}'.format(args.world_size)) 111 | torch.multiprocessing.spawn(ddp_test_nerf, 112 | args=(args,), 113 | nprocs=args.world_size, 114 | join=True) 115 | 116 | 117 | if __name__ == '__main__': 118 | setup_logger() 119 | test() 120 | 121 | -------------------------------------------------------------------------------- /ddp_train_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | import torch.distributed 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | import torch.multiprocessing 7 | import os 8 | from collections import OrderedDict 9 | from ddp_model import NerfNetWithAutoExpo 10 | import time 11 | from data_loader_split import load_data_split 12 | import numpy as np 13 | from tensorboardX import SummaryWriter 14 | from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER 15 | import logging 16 | import json 17 | 18 | 19 | logger = logging.getLogger(__package__) 20 | 21 | 22 | def setup_logger(): 23 | # create logger 24 | logger = logging.getLogger(__package__) 25 | # logger.setLevel(logging.DEBUG) 26 | logger.setLevel(logging.INFO) 27 | 28 | # create console handler and set level to debug 29 | ch = logging.StreamHandler() 30 | ch.setLevel(logging.DEBUG) 31 | 32 | # create formatter 33 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') 34 | 35 | # add formatter to ch 36 | ch.setFormatter(formatter) 37 | 38 | # add ch to logger 39 | logger.addHandler(ch) 40 | 41 | 42 | def intersect_sphere(ray_o, ray_d): 43 | ''' 44 | ray_o, ray_d: [..., 3] 45 | compute the depth of the intersection point between this ray and unit sphere 46 | ''' 47 | # note: d1 becomes negative if this mid point is behind camera 48 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) 49 | p = ray_o + d1.unsqueeze(-1) * ray_d 50 | # consider the case where the ray does not intersect the sphere 51 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1) 52 | p_norm_sq = torch.sum(p * p, dim=-1) 53 | if (p_norm_sq >= 1.).any(): 54 | raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!') 55 | d2 = torch.sqrt(1. 
- p_norm_sq) * ray_d_cos 56 | 57 | return d1 + d2 58 | 59 | 60 | def perturb_samples(z_vals): 61 | # get intervals between samples 62 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 63 | upper = torch.cat([mids, z_vals[..., -1:]], dim=-1) 64 | lower = torch.cat([z_vals[..., 0:1], mids], dim=-1) 65 | # uniform samples in those intervals 66 | t_rand = torch.rand_like(z_vals) 67 | z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples] 68 | 69 | return z_vals 70 | 71 | 72 | def sample_pdf(bins, weights, N_samples, det=False): 73 | ''' 74 | :param bins: tensor of shape [..., M+1], M is the number of bins 75 | :param weights: tensor of shape [..., M] 76 | :param N_samples: number of samples along each ray 77 | :param det: if True, will perform deterministic sampling 78 | :return: [..., N_samples] 79 | ''' 80 | # Get pdf 81 | weights = weights + TINY_NUMBER # prevent nans 82 | pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M] 83 | cdf = torch.cumsum(pdf, dim=-1) # [..., M] 84 | cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [..., M+1] 85 | 86 | # Take uniform samples 87 | dots_sh = list(weights.shape[:-1]) 88 | M = weights.shape[-1] 89 | 90 | min_cdf = 0.00 91 | max_cdf = 1.00 # prevent outlier samples 92 | 93 | if det: 94 | u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device) 95 | u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,]) # [..., N_samples] 96 | else: 97 | sh = dots_sh + [N_samples] 98 | u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf # [..., N_samples] 99 | 100 | # Invert CDF 101 | # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,] 102 | above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long() 103 | 104 | # random sample inside each bin 105 | below_inds = torch.clamp(above_inds-1, min=0) 106 | inds_g = torch.stack((below_inds, above_inds), dim=-1) # [..., N_samples, 2] 107 | 108 | cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1] 109 | cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [..., N_samples, 2] 110 | 111 | bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1] 112 | bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [..., N_samples, 2] 113 | 114 | # fix numeric issue 115 | denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples] 116 | denom = torch.where(denom<TINY_NUMBER, torch.ones_like(denom), denom) [... lines 117-249 are not captured in this dump: the remainder of sample_pdf(), render_single_image(), and the beginning of log_view_to_tb() ...] 250 | writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step) 251 | 252 | rgb_im = img_HWC2CHW(log_data[m]['fg_rgb']) 253 | rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1 254 | writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step) 255 | depth = log_data[m]['fg_depth'] 256 | depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, 257 | mask=mask)) 258 | writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step) 259 | 260 | rgb_im = img_HWC2CHW(log_data[m]['bg_rgb']) 261 | rgb_im = torch.clamp(rgb_im, min=0., max=1.) 
# just in case diffuse+specular>1 262 | writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step) 263 | depth = log_data[m]['bg_depth'] 264 | depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, 265 | mask=mask)) 266 | writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step) 267 | bg_lambda = log_data[m]['bg_lambda'] 268 | bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True, 269 | mask=mask)) 270 | writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step) 271 | 272 | 273 | def setup(rank, world_size): 274 | os.environ['MASTER_ADDR'] = 'localhost' 275 | # port = np.random.randint(12355, 12399) 276 | # os.environ['MASTER_PORT'] = '{}'.format(port) 277 | os.environ['MASTER_PORT'] = '12355' 278 | # initialize the process group 279 | torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) 280 | 281 | 282 | def cleanup(): 283 | torch.distributed.destroy_process_group() 284 | 285 | 286 | def create_nerf(rank, args): 287 | ###### create network and wrap in ddp; each process should do this 288 | # fix random seed just to make sure the network is initialized with same weights at different processes 289 | torch.manual_seed(777) 290 | # very important!!! otherwise it might introduce extra memory in rank=0 gpu 291 | torch.cuda.set_device(rank) 292 | 293 | models = OrderedDict() 294 | models['cascade_level'] = args.cascade_level 295 | models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')] 296 | for m in range(models['cascade_level']): 297 | img_names = None 298 | if args.optim_autoexpo: 299 | # load training image names for autoexposure 300 | f = os.path.join(args.basedir, args.expname, 'train_images.json') 301 | with open(f) as file: 302 | img_names = json.load(file) 303 | net = NerfNetWithAutoExpo(args, optim_autoexpo=args.optim_autoexpo, img_names=img_names).to(rank) 304 | net = DDP(net, device_ids=[rank], output_device=rank, find_unused_parameters=True) 305 | # net = DDP(net, device_ids=[rank], output_device=rank) 306 | optim = torch.optim.Adam(net.parameters(), lr=args.lrate) 307 | models['net_{}'.format(m)] = net 308 | models['optim_{}'.format(m)] = optim 309 | 310 | start = -1 311 | 312 | ###### load pretrained weights; each process should do this 313 | if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)): 314 | ckpts = [args.ckpt_path] 315 | else: 316 | ckpts = [os.path.join(args.basedir, args.expname, f) 317 | for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')] 318 | def path2iter(path): 319 | tmp = os.path.basename(path)[:-4] 320 | idx = tmp.rfind('_') 321 | return int(tmp[idx + 1:]) 322 | ckpts = sorted(ckpts, key=path2iter) 323 | logger.info('Found ckpts: {}'.format(ckpts)) 324 | if len(ckpts) > 0 and not args.no_reload: 325 | fpath = ckpts[-1] 326 | logger.info('Reloading from: {}'.format(fpath)) 327 | start = path2iter(fpath) 328 | # configure map_location properly for different processes 329 | map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} 330 | to_load = torch.load(fpath, map_location=map_location) 331 | for m in range(models['cascade_level']): 332 | for name in ['net_{}'.format(m), 'optim_{}'.format(m)]: 333 | models[name].load_state_dict(to_load[name]) 334 | 335 | return start, models 336 | 337 | 338 | def ddp_train_nerf(rank, args): 339 | ###### set up multi-processing 340 | setup(rank, args.world_size) 341 | ###### set up logger 342 | logger = 
logging.getLogger(__package__) 343 | setup_logger() 344 | 345 | ###### decide chunk size according to gpu memory 346 | logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory)) 347 | if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14: 348 | logger.info('setting batch size according to 24G gpu') 349 | args.N_rand = 1024 350 | args.chunk_size = 8192 351 | else: 352 | logger.info('setting batch size according to 12G gpu') 353 | args.N_rand = 512 354 | args.chunk_size = 4096 355 | 356 | ###### Create log dir and copy the config file 357 | if rank == 0: 358 | os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True) 359 | f = os.path.join(args.basedir, args.expname, 'args.txt') 360 | with open(f, 'w') as file: 361 | for arg in sorted(vars(args)): 362 | attr = getattr(args, arg) 363 | file.write('{} = {}\n'.format(arg, attr)) 364 | if args.config is not None: 365 | f = os.path.join(args.basedir, args.expname, 'config.txt') 366 | with open(f, 'w') as file: 367 | file.write(open(args.config, 'r').read()) 368 | torch.distributed.barrier() 369 | 370 | ray_samplers = load_data_split(args.datadir, args.scene, split='train', 371 | try_load_min_depth=args.load_min_depth) 372 | val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation', 373 | try_load_min_depth=args.load_min_depth, skip=args.testskip) 374 | 375 | # write training image names for autoexposure 376 | if args.optim_autoexpo: 377 | f = os.path.join(args.basedir, args.expname, 'train_images.json') 378 | with open(f, 'w') as file: 379 | img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))] 380 | json.dump(img_names, file, indent=2) 381 | 382 | ###### create network and wrap in ddp; each process should do this 383 | start, models = create_nerf(rank, args) 384 | 385 | ##### important!!! 
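# (illustrative aside, kept as comments; the snippet below is hypothetical and not part of the training loop)
# Every rank runs this same function, so if all processes kept identical RNG state they would
# draw the same image index and the same pixel subset at every step; since DDP averages gradients
# across ranks, world_size GPUs would then deliver only the effective batch of a single GPU.
# Seeding by rank, as done just below, keeps each process's random stream reproducible but distinct.
# A quick sanity check could look like:
#
#     np.random.seed((rank + 1) * 777)
#     logger.info('rank {} draws {}'.format(rank, np.random.randint(0, 100, size=3)))
#     # prints a different triple on every rank, and the same triple on re-runs of that rank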
386 | # make sure different processes sample different rays 387 | np.random.seed((rank + 1) * 777) 388 | # make sure different processes have different perturbations in depth samples 389 | torch.manual_seed((rank + 1) * 777) 390 | 391 | ##### only main process should do the logging 392 | if rank == 0: 393 | writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname)) 394 | 395 | # start training 396 | what_val_to_log = 0 # helper variable for parallel rendering of a image 397 | what_train_to_log = 0 398 | for global_step in range(start+1, start+1+args.N_iters): 399 | time0 = time.time() 400 | scalars_to_log = OrderedDict() 401 | ### Start of core optimization loop 402 | scalars_to_log['resolution'] = ray_samplers[0].resolution_level 403 | # randomly sample rays and move to device 404 | i = np.random.randint(low=0, high=len(ray_samplers)) 405 | ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False) 406 | for key in ray_batch: 407 | if torch.is_tensor(ray_batch[key]): 408 | ray_batch[key] = ray_batch[key].to(rank) 409 | 410 | # forward and backward 411 | dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays 412 | all_rets = [] # results on different cascade levels 413 | for m in range(models['cascade_level']): 414 | optim = models['optim_{}'.format(m)] 415 | net = models['net_{}'.format(m)] 416 | 417 | # sample depths 418 | N_samples = models['cascade_samples'][m] 419 | if m == 0: 420 | # foreground depth 421 | fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,] 422 | fg_near_depth = ray_batch['min_depth'] # [..., ] 423 | step = (fg_far_depth - fg_near_depth) / (N_samples - 1) 424 | fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples] 425 | fg_depth = perturb_samples(fg_depth) # random perturbation during training 426 | 427 | # background depth 428 | bg_depth = torch.linspace(0., 1., N_samples).view( 429 | [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank) 430 | bg_depth = perturb_samples(bg_depth) # random perturbation during training 431 | else: 432 | # sample pdf and concat with earlier samples 433 | fg_weights = ret['fg_weights'].clone().detach() 434 | fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1] 435 | fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] 436 | fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights, 437 | N_samples=N_samples, det=False) # [..., N_samples] 438 | fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1)) 439 | 440 | # sample pdf and concat with earlier samples 441 | bg_weights = ret['bg_weights'].clone().detach() 442 | bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) 443 | bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] 444 | bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights, 445 | N_samples=N_samples, det=False) # [..., N_samples] 446 | bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1)) 447 | 448 | optim.zero_grad() 449 | ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch['img_name']) 450 | all_rets.append(ret) 451 | 452 | rgb_gt = ray_batch['rgb'].to(rank) 453 | if 'autoexpo' in ret: 454 | scale, shift = ret['autoexpo'] 455 | scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item() 456 | scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item() 457 | # rgb_gt = scale * rgb_gt + shift 458 | rgb_pred = 
(ret['rgb'] - shift) / scale 459 | rgb_loss = img2mse(rgb_pred, rgb_gt) 460 | loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale-1.)+torch.abs(shift)) 461 | else: 462 | rgb_loss = img2mse(ret['rgb'], rgb_gt) 463 | loss = rgb_loss 464 | scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item() 465 | scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(rgb_loss.item()) 466 | loss.backward() 467 | optim.step() 468 | 469 | # # clean unused memory 470 | # torch.cuda.empty_cache() 471 | 472 | ### end of core optimization loop 473 | dt = time.time() - time0 474 | scalars_to_log['iter_time'] = dt 475 | 476 | ### only main process should do the logging 477 | if rank == 0 and (global_step % args.i_print == 0 or global_step < 10): 478 | logstr = '{} step: {} '.format(args.expname, global_step) 479 | for k in scalars_to_log: 480 | logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k]) 481 | writer.add_scalar(k, scalars_to_log[k], global_step) 482 | logger.info(logstr) 483 | 484 | ### each process should do this; but only main process merges the results 485 | if global_step % args.i_img == 0 or global_step == start+1: 486 | #### critical: make sure each process is working on the same random image 487 | time0 = time.time() 488 | idx = what_val_to_log % len(val_ray_samplers) 489 | log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size) 490 | what_val_to_log += 1 491 | dt = time.time() - time0 492 | if rank == 0: # only main process should do this 493 | logger.info('Logged a random validation view in {} seconds'.format(dt)) 494 | log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix='val/') 495 | 496 | time0 = time.time() 497 | idx = what_train_to_log % len(ray_samplers) 498 | log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size) 499 | what_train_to_log += 1 500 | dt = time.time() - time0 501 | if rank == 0: # only main process should do this 502 | logger.info('Logged a random training view in {} seconds'.format(dt)) 503 | log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/') 504 | 505 | del log_data 506 | torch.cuda.empty_cache() 507 | 508 | if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0): 509 | # saving checkpoints and logging 510 | fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step)) 511 | to_save = OrderedDict() 512 | for m in range(models['cascade_level']): 513 | name = 'net_{}'.format(m) 514 | to_save[name] = models[name].state_dict() 515 | 516 | name = 'optim_{}'.format(m) 517 | to_save[name] = models[name].state_dict() 518 | torch.save(to_save, fpath) 519 | 520 | # clean up for multi-processing 521 | cleanup() 522 | 523 | 524 | def config_parser(): 525 | import configargparse 526 | parser = configargparse.ArgumentParser() 527 | parser.add_argument('--config', is_config_file=True, help='config file path') 528 | parser.add_argument("--expname", type=str, help='experiment name') 529 | parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs') 530 | # dataset options 531 | parser.add_argument("--datadir", type=str, default=None, help='input data directory') 532 | parser.add_argument("--scene", type=str, default=None, help='scene name') 533 | parser.add_argument("--testskip", type=int, default=8, 534 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 535 | # 
model size 536 | parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network') 537 | parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network') 538 | parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D') 539 | # checkpoints 540 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') 541 | parser.add_argument("--ckpt_path", type=str, default=None, 542 | help='specific weights npy file to reload for coarse network') 543 | # batch size 544 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 2, 545 | help='batch size (number of random rays per gradient step)') 546 | parser.add_argument("--chunk_size", type=int, default=1024 * 8, 547 | help='number of rays processed in parallel, decrease if running out of memory') 548 | # iterations 549 | parser.add_argument("--N_iters", type=int, default=250001, 550 | help='number of iterations') 551 | # render only 552 | parser.add_argument("--render_splits", type=str, default='test', 553 | help='splits to render') 554 | # cascade training 555 | parser.add_argument("--cascade_level", type=int, default=2, 556 | help='number of cascade levels') 557 | parser.add_argument("--cascade_samples", type=str, default='64,64', 558 | help='samples at each level') 559 | # multiprocess learning 560 | parser.add_argument("--world_size", type=int, default='-1', 561 | help='number of processes') 562 | # optimize autoexposure 563 | parser.add_argument("--optim_autoexpo", action='store_true', 564 | help='optimize autoexposure parameters') 565 | parser.add_argument("--lambda_autoexpo", type=float, default=1., help='regularization weight for autoexposure') 566 | 567 | # learning rate options 568 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') 569 | parser.add_argument("--lrate_decay_factor", type=float, default=0.1, 570 | help='decay learning rate by a factor every specified number of steps') 571 | parser.add_argument("--lrate_decay_steps", type=int, default=5000, 572 | help='decay learning rate by a factor every specified number of steps') 573 | # rendering options 574 | parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples') 575 | parser.add_argument("--max_freq_log2", type=int, default=10, 576 | help='log2 of max freq for positional encoding (3D location)') 577 | parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4, 578 | help='log2 of max freq for positional encoding (2D direction)') 579 | parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth') 580 | # logging/saving options 581 | parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin') 582 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') 583 | parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') 584 | 585 | return parser 586 | 587 | 588 | def train(): 589 | parser = config_parser() 590 | args = parser.parse_args() 591 | logger.info(parser.format_values()) 592 | 593 | if args.world_size == -1: 594 | args.world_size = torch.cuda.device_count() 595 | logger.info('Using # gpus: {}'.format(args.world_size)) 596 | torch.multiprocessing.spawn(ddp_train_nerf, 597 | args=(args,), 598 | nprocs=args.world_size, 599 | join=True) 600 | 601 | 602 | if __name__ == '__main__': 603 | 
setup_logger() 604 | train() 605 | 606 | 607 | -------------------------------------------------------------------------------- /demo/tat_Playground.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/demo/tat_Playground.gif -------------------------------------------------------------------------------- /demo/tat_Truck.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/nerfplusplus/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/demo/tat_Truck.gif -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nerfplusplus 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - ca-certificates=2020.7.22=0 7 | - certifi=2020.6.20=py36_0 8 | - ld_impl_linux-64=2.33.1=h53a641e_7 9 | - libedit=3.1.20191231=h14c3975_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - openssl=1.1.1g=h7b6447c_0 15 | - pip=20.2.2=py36_0 16 | - python=3.6.12=hcff3b4d_2 17 | - readline=8.0=h7b6447c_0 18 | - setuptools=49.6.0=py36_0 19 | - sqlite=3.33.0=h62c20be_0 20 | - tk=8.6.10=hbc83047_0 21 | - wheel=0.35.1=py_0 22 | - xz=5.2.5=h7b6447c_0 23 | - zlib=1.2.11=h7b6447c_3 24 | - pip: 25 | - absl-py==0.10.0 26 | - astunparse==1.6.3 27 | - backcall==0.2.0 28 | - cachetools==4.1.1 29 | - chardet==3.0.4 30 | - configargparse==1.2.3 31 | - cycler==0.10.0 32 | - decorator==4.4.2 33 | - future==0.18.2 34 | - gast==0.3.3 35 | - google-auth==1.21.2 36 | - google-auth-oauthlib==0.4.1 37 | - google-pasta==0.2.0 38 | - grpcio==1.32.0 39 | - h5py==2.10.0 40 | - idna==2.10 41 | - imageio==2.9.0 42 | - importlib-metadata==1.7.0 43 | - ipython==7.16.1 44 | - ipython-genutils==0.2.0 45 | - jedi==0.17.2 46 | - keras-preprocessing==1.1.2 47 | - kiwisolver==1.2.0 48 | - lpips==0.1.1 49 | - markdown==3.2.2 50 | - matplotlib==3.3.2 51 | - networkx==2.5 52 | - numpy==1.18.0 53 | - oauthlib==3.1.0 54 | - opencv-python==4.4.0.42 55 | - opt-einsum==3.3.0 56 | - parso==0.7.1 57 | - pexpect==4.8.0 58 | - pickleshare==0.7.5 59 | - pillow==7.2.0 60 | - prompt-toolkit==3.0.7 61 | - protobuf==3.13.0 62 | - ptyprocess==0.6.0 63 | - pyasn1==0.4.8 64 | - pyasn1-modules==0.2.8 65 | - pygments==2.7.1 66 | - pymcubes==0.1.2 67 | - pyparsing==2.4.7 68 | - python-dateutil==2.8.1 69 | - pywavelets==1.1.1 70 | - requests==2.24.0 71 | - requests-oauthlib==1.3.0 72 | - rsa==4.6 73 | - scikit-image==0.17.2 74 | - scipy==1.4.1 75 | - six==1.15.0 76 | - tensorboard==2.3.0 77 | - tensorboard-plugin-wit==1.7.0 78 | - tensorboardx==2.1 79 | - tensorflow==2.3.0 80 | - tensorflow-estimator==2.3.0 81 | - termcolor==1.1.0 82 | - tifffile==2020.9.3 83 | - torch==1.6.0 84 | - torchvision==0.7.0 85 | - tqdm==4.49.0 86 | - traitlets==4.3.3 87 | - trimesh==3.8.10 88 | - urllib3==1.25.10 89 | - wcwidth==0.2.5 90 | - werkzeug==1.0.1 91 | - wrapt==1.12.1 92 | - zipp==3.1.0 93 | prefix: /home/zhangka2/anaconda3/envs/nerfplusplus 94 | 95 | -------------------------------------------------------------------------------- /nerf_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # import torch.nn.functional as F 4 | # import numpy as np 5 | from collections import 
OrderedDict 6 | 7 | import logging 8 | logger = logging.getLogger(__package__) 9 | 10 | 11 | class Embedder(nn.Module): 12 | def __init__(self, input_dim, max_freq_log2, N_freqs, 13 | log_sampling=True, include_input=True, 14 | periodic_fns=(torch.sin, torch.cos)): 15 | ''' 16 | :param input_dim: dimension of input to be embedded 17 | :param max_freq_log2: log2 of max freq; min freq is 1 by default 18 | :param N_freqs: number of frequency bands 19 | :param log_sampling: if True, frequency bands are linerly sampled in log-space 20 | :param include_input: if True, raw input is included in the embedding 21 | :param periodic_fns: periodic functions used to embed input 22 | ''' 23 | super().__init__() 24 | 25 | self.input_dim = input_dim 26 | self.include_input = include_input 27 | self.periodic_fns = periodic_fns 28 | 29 | self.out_dim = 0 30 | if self.include_input: 31 | self.out_dim += self.input_dim 32 | 33 | self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) 34 | 35 | if log_sampling: 36 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 37 | else: 38 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs) 39 | 40 | self.freq_bands = self.freq_bands.numpy().tolist() 41 | 42 | def forward(self, input): 43 | ''' 44 | :param input: tensor of shape [..., self.input_dim] 45 | :return: tensor of shape [..., self.out_dim] 46 | ''' 47 | assert (input.shape[-1] == self.input_dim) 48 | 49 | out = [] 50 | if self.include_input: 51 | out.append(input) 52 | 53 | for i in range(len(self.freq_bands)): 54 | freq = self.freq_bands[i] 55 | for p_fn in self.periodic_fns: 56 | out.append(p_fn(input * freq)) 57 | out = torch.cat(out, dim=-1) 58 | 59 | assert (out.shape[-1] == self.out_dim) 60 | return out 61 | 62 | # default tensorflow initialization of linear layers 63 | def weights_init(m): 64 | if isinstance(m, nn.Linear): 65 | nn.init.xavier_uniform_(m.weight.data) 66 | if m.bias is not None: 67 | nn.init.zeros_(m.bias.data) 68 | 69 | 70 | class MLPNet(nn.Module): 71 | def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3, 72 | skips=[4], use_viewdirs=False): 73 | ''' 74 | :param D: network depth 75 | :param W: network width 76 | :param input_ch: input channels for encodings of (x, y, z) 77 | :param input_ch_viewdirs: input channels for encodings of view directions 78 | :param skips: skip connection in network 79 | :param use_viewdirs: if True, will use the view directions as input 80 | ''' 81 | super().__init__() 82 | 83 | self.input_ch = input_ch 84 | self.input_ch_viewdirs = input_ch_viewdirs 85 | self.use_viewdirs = use_viewdirs 86 | self.skips = skips 87 | 88 | self.base_layers = [] 89 | dim = self.input_ch 90 | for i in range(D): 91 | self.base_layers.append( 92 | nn.Sequential(nn.Linear(dim, W), nn.ReLU()) 93 | ) 94 | dim = W 95 | if i in self.skips and i != (D-1): # skip connection after i^th layer 96 | dim += input_ch 97 | self.base_layers = nn.ModuleList(self.base_layers) 98 | # self.base_layers.apply(weights_init) # xavier init 99 | 100 | sigma_layers = [nn.Linear(dim, 1), ] # sigma must be positive 101 | self.sigma_layers = nn.Sequential(*sigma_layers) 102 | # self.sigma_layers.apply(weights_init) # xavier init 103 | 104 | # rgb color 105 | rgb_layers = [] 106 | base_remap_layers = [nn.Linear(dim, 256), ] 107 | self.base_remap_layers = nn.Sequential(*base_remap_layers) 108 | # self.base_remap_layers.apply(weights_init) 109 | 110 | dim = 256 + self.input_ch_viewdirs 111 | for i in range(1): 112 | rgb_layers.append(nn.Linear(dim, W // 
2)) 113 | rgb_layers.append(nn.ReLU()) 114 | dim = W // 2 115 | rgb_layers.append(nn.Linear(dim, 3)) 116 | rgb_layers.append(nn.Sigmoid()) # rgb values are normalized to [0, 1] 117 | self.rgb_layers = nn.Sequential(*rgb_layers) 118 | # self.rgb_layers.apply(weights_init) 119 | 120 | def forward(self, input): 121 | ''' 122 | :param input: [..., input_ch+input_ch_viewdirs] 123 | :return [..., 4] 124 | ''' 125 | input_pts = input[..., :self.input_ch] 126 | 127 | base = self.base_layers[0](input_pts) 128 | for i in range(len(self.base_layers)-1): 129 | if i in self.skips: 130 | base = torch.cat((input_pts, base), dim=-1) 131 | base = self.base_layers[i+1](base) 132 | 133 | sigma = self.sigma_layers(base) 134 | sigma = torch.abs(sigma) 135 | 136 | base_remap = self.base_remap_layers(base) 137 | input_viewdirs = input[..., -self.input_ch_viewdirs:] 138 | rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1)) 139 | 140 | ret = OrderedDict([('rgb', rgb), 141 | ('sigma', sigma.squeeze(-1))]) 142 | return ret 143 | -------------------------------------------------------------------------------- /nerf_sample_ray_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | import torch 4 | import cv2 5 | import imageio 6 | 7 | ######################################################################################################################## 8 | # ray batch sampling 9 | ######################################################################################################################## 10 | def get_rays_single_image(H, W, intrinsics, c2w): 11 | ''' 12 | :param H: image height 13 | :param W: image width 14 | :param intrinsics: 4 by 4 intrinsic matrix 15 | :param c2w: 4 by 4 camera to world extrinsic matrix 16 | :return: 17 | ''' 18 | u, v = np.meshgrid(np.arange(W), np.arange(H)) 19 | 20 | u = u.reshape(-1).astype(dtype=np.float32) + 0.5 # add half pixel 21 | v = v.reshape(-1).astype(dtype=np.float32) + 0.5 22 | pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W) 23 | 24 | rays_d = np.dot(np.linalg.inv(intrinsics[:3, :3]), pixels) 25 | rays_d = np.dot(c2w[:3, :3], rays_d) # (3, H*W) 26 | rays_d = rays_d.transpose((1, 0)) # (H*W, 3) 27 | 28 | rays_o = c2w[:3, 3].reshape((1, 3)) 29 | rays_o = np.tile(rays_o, (rays_d.shape[0], 1)) # (H*W, 3) 30 | 31 | depth = np.linalg.inv(c2w)[2, 3] 32 | depth = depth * np.ones((rays_o.shape[0],), dtype=np.float32) # (H*W,) 33 | 34 | return rays_o, rays_d, depth 35 | 36 | 37 | class RaySamplerSingleImage(object): 38 | def __init__(self, H, W, intrinsics, c2w, 39 | img_path=None, 40 | resolution_level=1, 41 | mask_path=None, 42 | min_depth_path=None, 43 | max_depth=None): 44 | super().__init__() 45 | self.W_orig = W 46 | self.H_orig = H 47 | self.intrinsics_orig = intrinsics 48 | self.c2w_mat = c2w 49 | 50 | self.img_path = img_path 51 | self.mask_path = mask_path 52 | self.min_depth_path = min_depth_path 53 | self.max_depth = max_depth 54 | 55 | self.resolution_level = -1 56 | self.set_resolution_level(resolution_level) 57 | 58 | def set_resolution_level(self, resolution_level): 59 | if resolution_level != self.resolution_level: 60 | self.resolution_level = resolution_level 61 | self.W = self.W_orig // resolution_level 62 | self.H = self.H_orig // resolution_level 63 | self.intrinsics = np.copy(self.intrinsics_orig) 64 | self.intrinsics[:2, :3] /= resolution_level 65 | # only load image at this time 66 | if self.img_path is not None: 67 | 
self.img = imageio.imread(self.img_path).astype(np.float32) / 255. 68 | self.img = cv2.resize(self.img, (self.W, self.H), interpolation=cv2.INTER_AREA) 69 | self.img = self.img.reshape((-1, 3)) 70 | else: 71 | self.img = None 72 | 73 | if self.mask_path is not None: 74 | self.mask = imageio.imread(self.mask_path).astype(np.float32) / 255. 75 | self.mask = cv2.resize(self.mask, (self.W, self.H), interpolation=cv2.INTER_NEAREST) 76 | self.mask = self.mask.reshape((-1)) 77 | else: 78 | self.mask = None 79 | 80 | if self.min_depth_path is not None: 81 | self.min_depth = imageio.imread(self.min_depth_path).astype(np.float32) / 255. * self.max_depth + 1e-4 82 | self.min_depth = cv2.resize(self.min_depth, (self.W, self.H), interpolation=cv2.INTER_LINEAR) 83 | self.min_depth = self.min_depth.reshape((-1)) 84 | else: 85 | self.min_depth = None 86 | 87 | self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W, 88 | self.intrinsics, self.c2w_mat) 89 | 90 | def get_img(self): 91 | if self.img is not None: 92 | return self.img.reshape((self.H, self.W, 3)) 93 | else: 94 | return None 95 | 96 | def get_all(self): 97 | if self.min_depth is not None: 98 | min_depth = self.min_depth 99 | else: 100 | min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0]) 101 | 102 | ret = OrderedDict([ 103 | ('ray_o', self.rays_o), 104 | ('ray_d', self.rays_d), 105 | ('depth', self.depth), 106 | ('rgb', self.img), 107 | ('mask', self.mask), 108 | ('min_depth', min_depth) 109 | ]) 110 | # return torch tensors 111 | for k in ret: 112 | if ret[k] is not None: 113 | ret[k] = torch.from_numpy(ret[k]) 114 | return ret 115 | 116 | def random_sample(self, N_rand, center_crop=False): 117 | ''' 118 | :param N_rand: number of rays to be casted 119 | :return: 120 | ''' 121 | if center_crop: 122 | half_H = self.H // 2 123 | half_W = self.W // 2 124 | quad_H = half_H // 2 125 | quad_W = half_W // 2 126 | 127 | # pixel coordinates 128 | u, v = np.meshgrid(np.arange(half_W-quad_W, half_W+quad_W), 129 | np.arange(half_H-quad_H, half_H+quad_H)) 130 | u = u.reshape(-1) 131 | v = v.reshape(-1) 132 | 133 | select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False) 134 | 135 | # Convert back to original image 136 | select_inds = v[select_inds] * self.W + u[select_inds] 137 | else: 138 | # Random from one image 139 | select_inds = np.random.choice(self.H*self.W, size=(N_rand,), replace=False) 140 | 141 | rays_o = self.rays_o[select_inds, :] # [N_rand, 3] 142 | rays_d = self.rays_d[select_inds, :] # [N_rand, 3] 143 | depth = self.depth[select_inds] # [N_rand, ] 144 | 145 | if self.img is not None: 146 | rgb = self.img[select_inds, :] # [N_rand, 3] 147 | else: 148 | rgb = None 149 | 150 | if self.mask is not None: 151 | mask = self.mask[select_inds] 152 | else: 153 | mask = None 154 | 155 | if self.min_depth is not None: 156 | min_depth = self.min_depth[select_inds] 157 | else: 158 | min_depth = 1e-4 * np.ones_like(rays_d[..., 0]) 159 | 160 | ret = OrderedDict([ 161 | ('ray_o', rays_o), 162 | ('ray_d', rays_d), 163 | ('depth', depth), 164 | ('rgb', rgb), 165 | ('mask', mask), 166 | ('min_depth', min_depth), 167 | ('img_name', self.img_path) 168 | ]) 169 | # return torch tensors 170 | for k in ret: 171 | if isinstance(ret[k], np.ndarray): 172 | ret[k] = torch.from_numpy(ret[k]) 173 | 174 | return ret 175 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import torch.nn as 
nn 3 | # import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | HUGE_NUMBER = 1e10 8 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 9 | 10 | 11 | # misc utils 12 | def img2mse(x, y, mask=None): 13 | if mask is None: 14 | return torch.mean((x - y) * (x - y)) 15 | else: 16 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER) 17 | 18 | img_HWC2CHW = lambda x: x.permute(2, 0, 1) 19 | gray2rgb = lambda x: x.unsqueeze(2).repeat(1, 1, 3) 20 | 21 | 22 | def normalize(x): 23 | min = x.min() 24 | max = x.max() 25 | 26 | return (x - min) / ((max - min) + TINY_NUMBER) 27 | 28 | 29 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 30 | # gray2rgb = lambda x: np.tile(x[:,:,np.newaxis], (1, 1, 3)) 31 | mse2psnr = lambda x: -10. * np.log(x+TINY_NUMBER) / np.log(10.) 32 | 33 | 34 | ######################################################################################################################## 35 | # 36 | ######################################################################################################################## 37 | from matplotlib.backends.backend_agg import FigureCanvasAgg 38 | from matplotlib.figure import Figure 39 | import matplotlib as mpl 40 | from matplotlib import cm 41 | import cv2 42 | 43 | 44 | def get_vertical_colorbar(h, vmin, vmax, cmap_name='jet', label=None): 45 | fig = Figure(figsize=(1.2, 8), dpi=100) 46 | fig.subplots_adjust(right=1.5) 47 | canvas = FigureCanvasAgg(fig) 48 | 49 | # Do some plotting. 50 | ax = fig.add_subplot(111) 51 | cmap = cm.get_cmap(cmap_name) 52 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) 53 | 54 | tick_cnt = 6 55 | tick_loc = np.linspace(vmin, vmax, tick_cnt) 56 | cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap, 57 | norm=norm, 58 | ticks=tick_loc, 59 | orientation='vertical') 60 | 61 | tick_label = ['{:3.2f}'.format(x) for x in tick_loc] 62 | cb1.set_ticklabels(tick_label) 63 | 64 | cb1.ax.tick_params(labelsize=18, rotation=0) 65 | 66 | if label is not None: 67 | cb1.set_label(label) 68 | 69 | fig.tight_layout() 70 | 71 | canvas.draw() 72 | s, (width, height) = canvas.print_to_buffer() 73 | 74 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) 75 | 76 | im = im[:, :, :3].astype(np.float32) / 255. 77 | if h != im.shape[0]: 78 | w = int(im.shape[1] / im.shape[0] * h) 79 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) 80 | 81 | return im 82 | 83 | 84 | def colorize_np(x, cmap_name='jet', mask=None, append_cbar=False): 85 | if mask is not None: 86 | # vmin, vmax = np.percentile(x[mask], (1, 99)) 87 | vmin = np.min(x[mask]) 88 | vmax = np.max(x[mask]) 89 | vmin = vmin - np.abs(vmin) * 0.01 90 | x[np.logical_not(mask)] = vmin 91 | x = np.clip(x, vmin, vmax) 92 | # print(vmin, vmax) 93 | else: 94 | vmin = x.min() 95 | vmax = x.max() + TINY_NUMBER 96 | 97 | x = (x - vmin) / (vmax - vmin) 98 | # x = np.clip(x, 0., 1.) 99 | 100 | cmap = cm.get_cmap(cmap_name) 101 | x_new = cmap(x)[:, :, :3] 102 | 103 | if mask is not None: 104 | mask = np.float32(mask[:, :, np.newaxis]) 105 | x_new = x_new * mask + np.zeros_like(x_new) * (1. 
- mask) 106 | 107 | cbar = get_vertical_colorbar(h=x.shape[0], vmin=vmin, vmax=vmax, cmap_name=cmap_name) 108 | 109 | if append_cbar: 110 | x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1) 111 | return x_new 112 | else: 113 | return x_new, cbar 114 | 115 | 116 | # tensor 117 | def colorize(x, cmap_name='jet', append_cbar=False, mask=None): 118 | x = x.numpy() 119 | if mask is not None: 120 | mask = mask.numpy().astype(dtype=np.bool) 121 | x, cbar = colorize_np(x, cmap_name, mask) 122 | 123 | if append_cbar: 124 | x = np.concatenate((x, np.zeros_like(x[:, :5, :]), cbar), axis=1) 125 | 126 | x = torch.from_numpy(x) 127 | return x 128 | --------------------------------------------------------------------------------
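A minimal usage sketch of the helpers defined in utils.py above, mirroring how ddp_test_nerf.py scores and saves a rendered view; the arrays and file names below are made-up placeholders rather than outputs of the actual pipeline:

    import numpy as np
    import imageio
    from utils import mse2psnr, to8b, colorize_np

    # hypothetical rendering and ground truth, shape (H, W, 3), values in [0, 1]
    pred = np.random.rand(32, 32, 3).astype(np.float32)
    gt = np.random.rand(32, 32, 3).astype(np.float32)

    # per-image PSNR from MSE, same formula used in ddp_test_nerf.py
    psnr = mse2psnr(np.mean((gt - pred) * (gt - pred)))
    print('psnr = {:.2f}'.format(psnr))

    # 8-bit export of the rendering
    imageio.imwrite('pred.png', to8b(pred))

    # colorize a hypothetical depth map and append a vertical colorbar on the right
    depth = np.random.rand(32, 32).astype(np.float32)
    imageio.imwrite('fg_depth.png', to8b(colorize_np(depth, cmap_name='jet', append_cbar=True)))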