├── .gitignore ├── License ├── README.md ├── assets ├── fern.gif ├── fern.jpg ├── horn.gif ├── horn.jpg ├── room.gif └── room.jpg ├── datasets ├── __init__.py ├── blender.py ├── colmap_utils.py ├── depth_utils.py ├── llff.py └── ray_utils.py ├── eval.py ├── losses.py ├── metrics.py ├── models ├── __init__.py ├── cloud_code.py ├── nerf.py └── rendering.py ├── opt.py ├── pointnet2_ops_lib └── pointnet2_ops │ ├── __init__.py │ ├── _ext-src │ ├── include │ │ ├── ball_query.h │ │ ├── cuda_utils.h │ │ ├── group_points.h │ │ ├── interpolate.h │ │ ├── sampling.h │ │ └── utils.h │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── bindings.cpp │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── sampling.cpp │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── pointnet2_modules.py │ └── pointnet2_utils.py ├── train.py └── utils ├── __init__.py ├── save_weights_only.py ├── visualization.py └── warmup_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | logs/ 3 | ckpts/ 4 | results/ 5 | *.ply 6 | *.vol 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2022 KAIST-VICLab
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Cloud NeRF
 2 | This is the official implementation of Neural Radiance Fields with Points Cloud Latent Representation.
 3 | [[Paper]]()[[Project Page]]()
 4 | ## Instructions
 5 | - Please download and arrange the dataset as described in the Data section below.
 6 | - For the environment, we provide our Docker image for the best reproducibility.
 7 | - The scene optimization scripts are provided in the Train & Evaluation section below.
 8 | ## Data
 9 | - We evaluate our framework on the forward-facing LLFF dataset, available at [Google drive](https://drive.google.com/drive/folders/14boI-o5hGO9srnWaaogTU5_ji7wkX2S7).
10 | - The pre-trained MVS depth estimates also need to be downloaded, from [Google drive](https://drive.google.com/drive/folders/13lreojzboR7X7voJ1q8JduvWDdzyrwRe).
11 | - Our data folder structure is as follows:
12 | ```
13 |
14 | ├── datasets
15 | │ ├── nerf_llff_data
16 | │ │ │──fern
17 | │ │ │ | |──depths
18 | │ │ │ | |──images
19 | │ │ │ | |──images_4
20 | │ │ │ | |──sparse
21 | │ │ │ | |──colmap_depth.npy
22 | │ │ │ | |──poses_bounds.npy
23 | │ │ │ | |──...
24 | ```
25 | ## Docker
26 | - We provide the Docker image of our environment at [DockerHub](https://hub.docker.com/repository/docker/quan5609/cloud_nerf).
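- To fetch the image before creating the container below, pull it from DockerHub first (the exact tag is not specified in this README, so `latest` is assumed in this sketch) and reuse its name for `${IMAGE_NAME}`:
```
docker pull quan5609/cloud_nerf:latest
export IMAGE_NAME=quan5609/cloud_nerf:latest
```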
27 | - To create docker container from image, run the following command 28 | ``` 29 | docker run \ 30 | --name ${CONTAINER_NAME} \ 31 | --gpus all \ 32 | --mount type=bind,source="${PATH_TO_SOURCE}",target="/workspace/source" \ 33 | --mount type=bind,source="${PATH_TO_DATASETS}",target="/workspace/datasets/" \ 34 | --shm-size=16GB \ 35 | -it ${IMAGE_NAME} 36 | ``` 37 | ## Train & Evaluation 38 | - To train from scratch, run the following command 39 | ``` 40 | CUDA_VISIBLE_DEVICES=1 python train.py \ 41 | --dataset_name llff \ 42 | --root_dir /workspace/datasets/nerf_llff_data/${SCENE_NAME}/ \ 43 | --N_importance 64 \ 44 | --N_sample 64 \ 45 | --img_wh 1008 756 \ 46 | --num_epochs 10 \ 47 | --batch_size 4096 \ 48 | --optimizer adam \ 49 | --lr 5e-3 \ 50 | --lr_scheduler steplr \ 51 | --decay_step 2 4 6 8 \ 52 | --decay_gamma 0.5 \ 53 | --exp_name ${EXP_NAME} 54 | ``` 55 | - To evaluate a checkpoint, run the following command 56 | ``` 57 | CUDA_VISIBLE_DEVICES=1 python eval.py \ 58 | --dataset_name llff \ 59 | --root_dir /workspace/datasets/nerf_llff_data/${SCENE_NAME}/ \ 60 | --N_importance 64 \ 61 | --N_sample 64 \ 62 | --img_wh 1008 756 \ 63 | --weight_path ${PATH_TO_CHECKPOINT} \ 64 | --split val 65 | ``` 66 | ## Visualization 67 | 68 | - Visualization of Fern scene 69 | 70 | 71 | - Visualization of Horn scene 72 | 73 | 74 | - Visualization of Room scene 75 | 76 | 77 | 78 | 79 | # Acknowledgement 80 | Our repo is based on [nerf](https://github.com/bmild/nerf), [nerf_pl](https://github.com/kwea123/nerf_pl), [DCCDIF](https://github.com/lity20/DCCDIF), and [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch). 81 | -------------------------------------------------------------------------------- /assets/fern.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/fern.gif -------------------------------------------------------------------------------- /assets/fern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/fern.jpg -------------------------------------------------------------------------------- /assets/horn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/horn.gif -------------------------------------------------------------------------------- /assets/horn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/horn.jpg -------------------------------------------------------------------------------- /assets/room.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/room.gif -------------------------------------------------------------------------------- /assets/room.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/room.jpg -------------------------------------------------------------------------------- /datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .blender import BlenderDataset 2 | from .llff import LLFFDataset 3 | 4 | 5 | dataset_dict = {"blender": BlenderDataset, "llff": LLFFDataset} 6 | -------------------------------------------------------------------------------- /datasets/blender.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms as T 9 | 10 | from .ray_utils import get_ray_directions, get_rays 11 | 12 | 13 | class BlenderDataset(Dataset): 14 | def __init__(self, root_dir, split="train", img_wh=(800, 800)): 15 | self.root_dir = root_dir 16 | self.split = split 17 | assert img_wh[0] == img_wh[1], "image width must equal image height!" 18 | self.img_wh = img_wh 19 | self.define_transforms() 20 | 21 | self.read_meta() 22 | self.white_back = True 23 | 24 | def read_meta(self): 25 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), "r") as f: 26 | self.meta = json.load(f) 27 | 28 | w, h = self.img_wh 29 | # original focal length 30 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta["camera_angle_x"]) 31 | # when W=800 32 | 33 | # modify focal length to match size self.img_wh 34 | self.focal *= self.img_wh[0] / 800 35 | 36 | # bounds, common for all scenes 37 | self.near = 2.0 38 | self.far = 6.0 39 | self.bounds = np.array([self.near, self.far]) 40 | 41 | # ray directions for all pixels, same for all images (same H, W, focal) 42 | self.directions = get_ray_directions(h, w, self.focal) # (h, w, 3) 43 | 44 | if self.split == "train": # create buffer of all rays and rgb data 45 | self.image_paths = [] 46 | self.poses = [] 47 | self.all_rays = [] 48 | self.all_rgbs = [] 49 | for frame in self.meta["frames"]: 50 | pose = np.array(frame["transform_matrix"])[:3, :4] 51 | self.poses += [pose] 52 | c2w = torch.FloatTensor(pose) 53 | 54 | image_path = os.path.join( 55 | self.root_dir, f"{frame['file_path']}.png") 56 | self.image_paths += [image_path] 57 | img = Image.open(image_path) 58 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 59 | img = self.transform(img) # (4, h, w) 60 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 61 | img = img[:, :3] * img[:, -1:] + \ 62 | (1 - img[:, -1:]) # blend A to RGB 63 | self.all_rgbs += [img] 64 | 65 | rays_o, rays_d = get_rays( 66 | self.directions, c2w) # both (h*w, 3) 67 | 68 | self.all_rays += [ 69 | torch.cat( 70 | [ 71 | rays_o, 72 | rays_d, 73 | self.near * torch.ones_like(rays_o[:, :1]), 74 | self.far * torch.ones_like(rays_o[:, :1]), 75 | ], 76 | 1, 77 | ) 78 | ] # (h*w, 8) 79 | 80 | # (len(self.meta['frames])*h*w, 3) 81 | self.all_rays = torch.cat(self.all_rays, 0) 82 | # (len(self.meta['frames])*h*w, 3) 83 | self.all_rgbs = torch.cat(self.all_rgbs, 0) 84 | 85 | def define_transforms(self): 86 | self.transform = T.ToTensor() 87 | 88 | def __len__(self): 89 | if self.split == "train": 90 | return len(self.all_rays) 91 | if self.split == "val": 92 | return 8 # only validate 8 images (to support <=8 gpus) 93 | return len(self.meta["frames"]) 94 | 95 | def __getitem__(self, idx): 96 | if self.split == "train": # use data in the buffers 97 | sample = {"rays": self.all_rays[idx], "rgbs": self.all_rgbs[idx]} 98 | 99 | else: # create data for each image separately 100 | frame = self.meta["frames"][idx] 101 | c2w = torch.FloatTensor(frame["transform_matrix"])[:3, 
:4] 102 | 103 | img = Image.open(os.path.join( 104 | self.root_dir, f"{frame['file_path']}.png")) 105 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 106 | img = self.transform(img) # (4, H, W) 107 | valid_mask = (img[-1] > 0).flatten() # (H*W) valid color area 108 | img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA 109 | img = img[:, :3] * img[:, -1:] + \ 110 | (1 - img[:, -1:]) # blend A to RGB 111 | 112 | rays_o, rays_d = get_rays(self.directions, c2w) 113 | 114 | rays = torch.cat( 115 | [ 116 | rays_o, 117 | rays_d, 118 | self.near * torch.ones_like(rays_o[:, :1]), 119 | self.far * torch.ones_like(rays_o[:, :1]), 120 | ], 121 | 1, 122 | ) # (H*W, 8) 123 | 124 | sample = {"rays": rays, "rgbs": img, 125 | "c2w": c2w, "valid_mask": valid_mask} 126 | 127 | return sample 128 | -------------------------------------------------------------------------------- /datasets/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import collections 33 | import os 34 | import struct 35 | 36 | import numpy as np 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | 49 | class Image(BaseImage): 50 | def qvec2rotmat(self): 51 | return qvec2rotmat(self.qvec) 52 | 53 | 54 | CAMERA_MODELS = { 55 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 56 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 57 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 58 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 59 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 60 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 61 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 62 | CameraModel(model_id=7, model_name="FOV", num_params=5), 63 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 64 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 65 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 66 | } 67 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 68 | for camera_model in CAMERA_MODELS]) 69 | 70 | 71 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 72 | """Read and unpack the next bytes from a binary file. 73 | :param fid: 74 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 75 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 76 | :param endian_character: Any of {@, =, <, >, !} 77 | :return: Tuple of read and unpacked values. 
78 | """ 79 | data = fid.read(num_bytes) 80 | return struct.unpack(endian_character + format_char_sequence, data) 81 | 82 | 83 | def read_cameras_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::WriteCamerasText(const std::string& path) 87 | void Reconstruction::ReadCamerasText(const std::string& path) 88 | """ 89 | cameras = {} 90 | with open(path, "r") as fid: 91 | while True: 92 | line = fid.readline() 93 | if not line: 94 | break 95 | line = line.strip() 96 | if len(line) > 0 and line[0] != "#": 97 | elems = line.split() 98 | camera_id = int(elems[0]) 99 | model = elems[1] 100 | width = int(elems[2]) 101 | height = int(elems[3]) 102 | params = np.array(tuple(map(float, elems[4:]))) 103 | cameras[camera_id] = Camera( 104 | id=camera_id, model=model, width=width, height=height, params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes( 127 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params) 128 | cameras[camera_id] = Camera( 129 | id=camera_id, model=model_name, width=width, height=height, params=np.array( 130 | params) 131 | ) 132 | assert len(cameras) == num_cameras 133 | return cameras 134 | 135 | 136 | def read_images_text(path): 137 | """ 138 | see: src/base/reconstruction.cc 139 | void Reconstruction::ReadImagesText(const std::string& path) 140 | void Reconstruction::WriteImagesText(const std::string& path) 141 | """ 142 | images = {} 143 | with open(path, "r") as fid: 144 | while True: 145 | line = fid.readline() 146 | if not line: 147 | break 148 | line = line.strip() 149 | if len(line) > 0 and line[0] != "#": 150 | elems = line.split() 151 | image_id = int(elems[0]) 152 | qvec = np.array(tuple(map(float, elems[1:5]))) 153 | tvec = np.array(tuple(map(float, elems[5:8]))) 154 | camera_id = int(elems[8]) 155 | image_name = elems[9] 156 | elems = fid.readline().split() 157 | xys = np.column_stack( 158 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]) 159 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 160 | images[image_id] = Image( 161 | id=image_id, 162 | qvec=qvec, 163 | tvec=tvec, 164 | camera_id=camera_id, 165 | name=image_name, 166 | xys=xys, 167 | point3D_ids=point3D_ids, 168 | ) 169 | return images 170 | 171 | 172 | def read_images_binary(path_to_model_file): 173 | """ 174 | see: src/base/reconstruction.cc 175 | void Reconstruction::ReadImagesBinary(const std::string& path) 176 | void Reconstruction::WriteImagesBinary(const std::string& path) 177 | """ 178 | images = {} 179 | with open(path_to_model_file, "rb") as fid: 180 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 181 | for image_index in range(num_reg_images): 182 | binary_image_properties = read_next_bytes( 183 | fid, num_bytes=64, 
format_char_sequence="idddddddi") 184 | image_id = binary_image_properties[0] 185 | qvec = np.array(binary_image_properties[1:5]) 186 | tvec = np.array(binary_image_properties[5:8]) 187 | camera_id = binary_image_properties[8] 188 | image_name = "" 189 | current_char = read_next_bytes(fid, 1, "c")[0] 190 | while current_char != b"\x00": # look for the ASCII 0 entry 191 | image_name += current_char.decode("utf-8") 192 | current_char = read_next_bytes(fid, 1, "c")[0] 193 | num_points2D = read_next_bytes( 194 | fid, num_bytes=8, format_char_sequence="Q")[0] 195 | x_y_id_s = read_next_bytes( 196 | fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D) 197 | xys = np.column_stack( 198 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]) 199 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 200 | images[image_id] = Image( 201 | id=image_id, 202 | qvec=qvec, 203 | tvec=tvec, 204 | camera_id=camera_id, 205 | name=image_name, 206 | xys=xys, 207 | point3D_ids=point3D_ids, 208 | ) 209 | return images 210 | 211 | 212 | def read_points3D_text(path): 213 | """ 214 | see: src/base/reconstruction.cc 215 | void Reconstruction::ReadPoints3DText(const std::string& path) 216 | void Reconstruction::WritePoints3DText(const std::string& path) 217 | """ 218 | points3D = {} 219 | with open(path, "r") as fid: 220 | while True: 221 | line = fid.readline() 222 | if not line: 223 | break 224 | line = line.strip() 225 | if len(line) > 0 and line[0] != "#": 226 | elems = line.split() 227 | point3D_id = int(elems[0]) 228 | xyz = np.array(tuple(map(float, elems[1:4]))) 229 | rgb = np.array(tuple(map(int, elems[4:7]))) 230 | error = float(elems[7]) 231 | image_ids = np.array(tuple(map(int, elems[8::2]))) 232 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 233 | points3D[point3D_id] = Point3D( 234 | id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs 235 | ) 236 | return points3D 237 | 238 | 239 | def read_points3d_binary(path_to_model_file): 240 | """ 241 | see: src/base/reconstruction.cc 242 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 243 | void Reconstruction::WritePoints3DBinary(const std::string& path) 244 | """ 245 | points3D = {} 246 | with open(path_to_model_file, "rb") as fid: 247 | num_points = read_next_bytes(fid, 8, "Q")[0] 248 | for point_line_index in range(num_points): 249 | binary_point_line_properties = read_next_bytes( 250 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 251 | point3D_id = binary_point_line_properties[0] 252 | xyz = np.array(binary_point_line_properties[1:4]) 253 | rgb = np.array(binary_point_line_properties[4:7]) 254 | error = np.array(binary_point_line_properties[7]) 255 | track_length = read_next_bytes( 256 | fid, num_bytes=8, format_char_sequence="Q")[0] 257 | track_elems = read_next_bytes( 258 | fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length) 259 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 260 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 261 | points3D[point3D_id] = Point3D( 262 | id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs 263 | ) 264 | return points3D 265 | 266 | 267 | def read_model(path, ext): 268 | if ext == ".txt": 269 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 270 | images = read_images_text(os.path.join(path, "images" + ext)) 271 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 272 | else: 273 | 
cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 274 | images = read_images_binary(os.path.join(path, "images" + ext)) 275 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 276 | return cameras, images, points3D 277 | 278 | 279 | def qvec2rotmat(qvec): 280 | return np.array( 281 | [ 282 | [ 283 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 284 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 285 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 286 | ], 287 | [ 288 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 289 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 290 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 291 | ], 292 | [ 293 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 294 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 295 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 296 | ], 297 | ] 298 | ) 299 | 300 | 301 | def rotmat2qvec(R): 302 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 303 | K = ( 304 | np.array( 305 | [ 306 | [Rxx - Ryy - Rzz, 0, 0, 0], 307 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 308 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 309 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], 310 | ] 311 | ) 312 | / 3.0 313 | ) 314 | eigvals, eigvecs = np.linalg.eigh(K) 315 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 316 | if qvec[0] < 0: 317 | qvec *= -1 318 | return qvec 319 | -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | import numpy as np 5 | 6 | 7 | def read_pfm(filename): 8 | file = open(filename, "rb") 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().decode("utf-8").rstrip() 16 | if header == "PF": 17 | color = True 18 | elif header == "Pf": 19 | color = False 20 | else: 21 | raise Exception("Not a PFM file.") 22 | 23 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8")) 24 | if dim_match: 25 | width, height = map(int, dim_match.groups()) 26 | else: 27 | raise Exception("Malformed PFM header.") 28 | 29 | scale = float(file.readline().rstrip()) 30 | if scale < 0: # little-endian 31 | endian = "<" 32 | scale = -scale 33 | else: 34 | endian = ">" # big-endian 35 | 36 | data = np.fromfile(file, endian + "f") 37 | shape = (height, width, 3) if color else (height, width) 38 | 39 | data = np.reshape(data, shape) 40 | data = np.flipud(data) 41 | file.close() 42 | return data, scale 43 | 44 | 45 | def save_pfm(filename, image, scale=1): 46 | file = open(filename, "wb") 47 | color = None 48 | 49 | image = np.flipud(image) 50 | 51 | if image.dtype.name != "float32": 52 | raise Exception("Image dtype must be float32.") 53 | 54 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 55 | color = True 56 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 57 | color = False 58 | else: 59 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 60 | 61 | file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8")) 62 | file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8")) 63 | 64 | endian = image.dtype.byteorder 65 | 66 | if endian == "<" or endian == "=" and sys.byteorder == "little": 67 | scale = -scale 68 | 69 | file.write(("%f\n" % scale).encode("utf-8")) 70 | 71 | image.tofile(file) 72 | file.close() 73 | 
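The two helpers above are the PFM I/O routines used elsewhere in the repo (e.g. `eval.py` imports `save_pfm` to write predicted depth maps). A minimal round-trip sketch, assuming the package is importable (e.g. inside the provided Docker container); the file name `example_depth.pfm` and the 504x378 resolution are only illustrative:

```python
import numpy as np

from datasets.depth_utils import read_pfm, save_pfm

# save_pfm expects float32 data, either (H, W) greyscale or (H, W, 3) color
depth = np.random.rand(378, 504).astype(np.float32)
save_pfm("example_depth.pfm", depth)

# read_pfm returns the array (flipped back to its original orientation) and the scale factor
loaded, scale = read_pfm("example_depth.pfm")
assert loaded.shape == depth.shape and np.allclose(loaded, depth)
```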
-------------------------------------------------------------------------------- /datasets/llff.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms as T 10 | 11 | from .colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary 12 | from .ray_utils import read_gen, get_ray_directions, get_rays, get_ndc_rays 13 | 14 | 15 | def normalize(v): 16 | """Normalize a vector.""" 17 | return v / np.linalg.norm(v) 18 | 19 | 20 | def average_poses(poses): 21 | """ 22 | Calculate the average pose, which is then used to center all poses 23 | using @center_poses. Its computation is as follows: 24 | 1. Compute the center: the average of pose centers. 25 | 2. Compute the z axis: the normalized average z axis. 26 | 3. Compute axis y': the average y axis. 27 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 28 | 5. Compute the y axis: z cross product x. 29 | 30 | Note that at step 3, we cannot directly use y' as y axis since it's 31 | not necessarily orthogonal to z axis. We need to pass from x to y. 32 | 33 | Inputs: 34 | poses: (N_images, 3, 4) 35 | 36 | Outputs: 37 | pose_avg: (3, 4) the average pose 38 | """ 39 | # 1. Compute the center 40 | center = poses[..., 3].mean(0) # (3) 41 | 42 | # 2. Compute the z axis 43 | z = normalize(poses[..., 2].mean(0)) # (3) 44 | 45 | # 3. Compute axis y' (no need to normalize as it's not the final output) 46 | y_ = poses[..., 1].mean(0) # (3) 47 | 48 | # 4. Compute the x axis 49 | x = normalize(np.cross(y_, z)) # (3) 50 | 51 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 52 | y = np.cross(z, x) # (3) 53 | 54 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 55 | 56 | return pose_avg 57 | 58 | 59 | def center_poses(poses): 60 | """ 61 | Center the poses so that we can use NDC. 62 | See https://github.com/bmild/nerf/issues/34 63 | 64 | Inputs: 65 | poses: (N_images, 3, 4) 66 | 67 | Outputs: 68 | poses_centered: (N_images, 3, 4) the centered poses 69 | pose_avg: (3, 4) the average pose 70 | """ 71 | 72 | pose_avg = average_poses(poses) # (3, 4) 73 | pose_avg_homo = np.eye(4) 74 | # convert to homogeneous coordinate for faster computation 75 | pose_avg_homo[:3] = pose_avg 76 | # by simply adding 0, 0, 0, 1 as the last row 77 | last_row = np.tile(np.array([0, 0, 0, 1]), 78 | (len(poses), 1, 1)) # (N_images, 1, 4) 79 | # (N_images, 4, 4) homogeneous coordinate 80 | poses_homo = np.concatenate([poses, last_row], 1) 81 | 82 | poses_centered = np.linalg.inv( 83 | pose_avg_homo) @ poses_homo # (N_images, 4, 4) 84 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 85 | return poses_centered, pose_avg 86 | 87 | 88 | def create_spiral_poses(radii, focus_depth, n_poses=120): 89 | """ 90 | Computes poses that follow a spiral path for rendering purpose. 
91 | See https://github.com/Fyusion/LLFF/issues/19 92 | In particular, the path looks like: 93 | https://tinyurl.com/ybgtfns3 94 | 95 | Inputs: 96 | radii: (3) radii of the spiral for each axis 97 | focus_depth: float, the depth that the spiral poses look at 98 | n_poses: int, number of poses to create along the path 99 | 100 | Outputs: 101 | poses_spiral: (n_poses, 3, 4) the poses in the spiral path 102 | """ 103 | 104 | poses_spiral = [] 105 | # rotate 4pi (2 rounds) 106 | for t in np.linspace(0, 4 * np.pi, n_poses + 1)[:-1]: 107 | # the parametric function of the spiral (see the interactive web) 108 | center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5 * t)]) * radii 109 | 110 | # the viewing z axis is the vector pointing from the @focus_depth plane 111 | # to @center 112 | z = normalize(center - np.array([0, 0, -focus_depth])) 113 | 114 | # compute other axes as in @average_poses 115 | y_ = np.array([0, 1, 0]) # (3) 116 | x = normalize(np.cross(y_, z)) # (3) 117 | y = np.cross(z, x) # (3) 118 | 119 | poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4) 120 | 121 | return np.stack(poses_spiral, 0) # (n_poses, 3, 4) 122 | 123 | 124 | def create_spheric_poses(radius, n_poses=120): 125 | """ 126 | Create circular poses around z axis. 127 | Inputs: 128 | radius: the (negative) height and the radius of the circle. 129 | 130 | Outputs: 131 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 132 | """ 133 | 134 | def spheric_pose(theta, phi, radius): 135 | def trans_t(t): 136 | return np.array( 137 | [ 138 | [1, 0, 0, 0], 139 | [0, 1, 0, -0.9 * t], 140 | [0, 0, 1, t], 141 | [0, 0, 0, 1], 142 | ] 143 | ) 144 | 145 | def rot_phi(phi): 146 | return np.array( 147 | [ 148 | [1, 0, 0, 0], 149 | [0, np.cos(phi), -np.sin(phi), 0], 150 | [0, np.sin(phi), np.cos(phi), 0], 151 | [0, 0, 0, 1], 152 | ] 153 | ) 154 | 155 | def rot_theta(th): 156 | return np.array( 157 | [ 158 | [np.cos(th), 0, -np.sin(th), 0], 159 | [0, 1, 0, 0], 160 | [np.sin(th), 0, np.cos(th), 0], 161 | [0, 0, 0, 1], 162 | ] 163 | ) 164 | 165 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 166 | c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], 167 | [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w 168 | return c2w[:3] 169 | 170 | spheric_poses = [] 171 | for th in np.linspace(0, 2 * np.pi, n_poses + 1)[:-1]: 172 | # 36 degree view downwards 173 | spheric_poses += [spheric_pose(th, -np.pi / 5, radius)] 174 | return np.stack(spheric_poses, 0) 175 | 176 | 177 | class LLFFDataset(Dataset): 178 | def __init__(self, root_dir, split="train", img_wh=(504, 378), spheric_poses=False, val_num=1): 179 | """ 180 | spheric_poses: whether the images are taken in a spheric inward-facing manner 181 | default: False (forward-facing) 182 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 183 | """ 184 | self.root_dir = root_dir 185 | self.split = split 186 | self.img_wh = img_wh 187 | self.spheric_poses = spheric_poses 188 | self.val_num = max(1, val_num) # at least 1 189 | self.holdout = 8 190 | self.define_transforms() 191 | 192 | self.read_meta() 193 | # self.load_mvs_depth() 194 | self.white_back = False 195 | 196 | def load_mvs_depth(self): 197 | depth_glob = os.path.join(self.root_dir, "depths", "*.pfm") 198 | self.depth_list = sorted(glob.glob(depth_glob)) 199 | depths = [] 200 | for i in range(len(self.depth_list)): 201 | depth = read_gen(self.depth_list[i]) 202 | 203 | depths.append(depth) 204 | self.depths = np.stack(depths, 0).astype(np.float32) # N x H x W 205 | 206 | per_view_points = 
self.project_to_3d() 207 | mvs_points = self.fwd_consistency_check(per_view_points) 208 | return mvs_points 209 | 210 | def project_to_3d(self): 211 | N, H, W = self.depths.shape 212 | focal = self.origin_intrinsics[1].params[0] 213 | origin_h, origin_w = self.origin_intrinsics[1].height, self.origin_intrinsics[1].width 214 | 215 | origin_cy, origin_cx = self.origin_intrinsics[1].params[2], self.origin_intrinsics[1].params[1] 216 | 217 | origin_K = np.array([[focal, 0, origin_cx, 0], [ 218 | 0, focal, origin_cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 219 | 220 | origin_K[0, :] /= origin_w 221 | origin_K[1, :] /= origin_h 222 | 223 | self.normalized_K = origin_K 224 | 225 | mvs_K = self.normalized_K.copy() 226 | mvs_K[0, :] *= W 227 | mvs_K[1, :] *= H 228 | self.mvs_K = mvs_K 229 | 230 | inv_mvs_K = np.linalg.pinv(mvs_K) 231 | inv_mvs_K = torch.from_numpy(inv_mvs_K) 232 | 233 | # create mesh grid for mvs image 234 | meshgrid = np.meshgrid(range(W), range(H), indexing="xy") 235 | id_coords = (np.stack(meshgrid, axis=0).astype( 236 | np.float32)).reshape(2, -1) 237 | id_coords = torch.from_numpy(id_coords) 238 | 239 | ones = torch.ones(N, 1, H * W) 240 | 241 | pix_coords = torch.unsqueeze(torch.stack( 242 | [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0) 243 | pix_coords = pix_coords.repeat(N, 1, 1) 244 | pix_coords = torch.cat([pix_coords, ones], 1) 245 | 246 | # project to cam coord 247 | inv_mvs_K = inv_mvs_K[None, ...].repeat(N, 1, 1).float() 248 | cam_points = torch.matmul(inv_mvs_K[:, :3, :3], pix_coords) 249 | mvs_depth = torch.from_numpy( 250 | self.depths).float().unsqueeze(1).view(N, 1, -1) 251 | cam_points = mvs_depth * cam_points 252 | cam_points = torch.cat([cam_points, ones], 1) 253 | 254 | # project to world coord 255 | T = torch.from_numpy(self.origin_extrinsics).float() 256 | world_points = torch.matmul(T, cam_points) 257 | world_points = world_points.permute(0, 2, 1) # N, H*W, 3 258 | 259 | return world_points 260 | 261 | def fwd_consistency_check(self, per_view_points): 262 | N, H, W = self.depths.shape 263 | global_valid_points = [] 264 | for view_id in range(per_view_points.shape[0]): 265 | curr_view_points = per_view_points[view_id].transpose( 266 | 1, 0) # 3, H*W 267 | homo_view_points = torch.cat( 268 | [curr_view_points, torch.ones(1, H * W)], dim=0) # 4, H*W 269 | homo_view_points = homo_view_points.unsqueeze( 270 | 0).repeat(N, 1, 1) # N,4,H*W 271 | 272 | # project to camera space 273 | T = torch.from_numpy(self.origin_extrinsics).float() 274 | homo_T = torch.cat([T, torch.zeros(N, 1, 4)], dim=1) 275 | homo_T[:, -1, -1] = 1 276 | inv_T = torch.inverse(homo_T) 277 | cam_points = torch.matmul(inv_T[:, :3, :], homo_view_points) 278 | 279 | # project to image space 280 | mvs_K = torch.from_numpy(self.mvs_K).unsqueeze( 281 | 0).repeat(N, 1, 1).float() 282 | cam_points = torch.matmul(mvs_K[:, :3, :3], cam_points) 283 | cam_points[:, :2, :] /= cam_points[:, 2:, :] 284 | 285 | z_values = cam_points[:, 2:, :].view(N, 1, H, W) # N,1,H,W 286 | xy_coords = cam_points[:, :2, :].transpose( 287 | 2, 1).view(N, H, W, 2) # N,H,W,2 288 | 289 | xy_coords[..., 0] /= W - 1 290 | xy_coords[..., 1] /= H - 1 291 | xy_coords = (xy_coords - 0.5) * 2 292 | 293 | mvs_depth = torch.from_numpy( 294 | self.depths).float().unsqueeze(1) # N,1,H,W 295 | ref_z_values = F.grid_sample( 296 | mvs_depth, xy_coords, mode="bilinear", align_corners=False) 297 | 298 | # ! z_values >= alpha*ref_values, also invalid index is 0 so they are satisfied this condition 299 | # ! 
point must be visible in at least n views 300 | err = z_values - 0.9 * ref_z_values 301 | 302 | visible_mask = ref_z_values != 0 303 | visible_count = visible_mask.int().sum(0) 304 | valid_visible = visible_count >= 1 305 | valid_points = err >= 0 306 | valid_points = torch.all( 307 | valid_points, dim=0) & valid_visible # 1,H,W 308 | global_valid_points.append(valid_points) 309 | global_valid_points = torch.cat( 310 | global_valid_points, dim=0).view(N, H * W) # N,H,W 311 | 312 | filtered_points = per_view_points[global_valid_points, :] 313 | # np.save('assets/filtered_mvs_points.npy', filtered_points) 314 | # breakpoint() 315 | return filtered_points 316 | 317 | def read_meta(self): 318 | # Step 1: rescale focal length according to training resolution 319 | camdata = read_cameras_binary(os.path.join( 320 | self.root_dir, "sparse/0/cameras.bin")) 321 | self.origin_intrinsics = camdata 322 | W = camdata[1].width 323 | self.focal = camdata[1].params[0] * self.img_wh[0] / W 324 | 325 | # Step 2: correct poses 326 | # read extrinsics (of successfully reconstructed images) 327 | imdata = read_images_binary(os.path.join( 328 | self.root_dir, "sparse/0/images.bin")) 329 | perm = np.argsort([imdata[k].name for k in imdata]) 330 | # read successfully reconstructed images and ignore others 331 | self.image_paths = [ 332 | os.path.join(self.root_dir, "images", name) for name in sorted([imdata[k].name for k in imdata]) 333 | ] 334 | w2c_mats = [] 335 | bottom = np.array([0, 0, 0, 1.0]).reshape(1, 4) 336 | for k in imdata: 337 | im = imdata[k] 338 | R = im.qvec2rotmat() 339 | t = im.tvec.reshape(3, 1) 340 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 341 | w2c_mats = np.stack(w2c_mats, 0) 342 | # (N_images, 3, 4) cam2world matrices 343 | poses = np.linalg.inv(w2c_mats)[:, :3] 344 | self.origin_extrinsics = poses 345 | 346 | # read bounds 347 | self.bounds = np.zeros((len(poses), 2)) # (N_images, 2) 348 | pts3d = read_points3d_binary(os.path.join( 349 | self.root_dir, "sparse/0/points3D.bin")) 350 | 351 | mvs_points = self.load_mvs_depth().numpy() # ! This is mvs depth pretrained 352 | near_bound = mvs_points.min(axis=0)[-1] 353 | pts3d = {k: v for (k, v) in pts3d.items() if v.xyz[-1] > near_bound} 354 | 355 | pts_world = np.zeros((1, 3, len(pts3d))) # (1, 3, N_points) 356 | visibilities = np.zeros((len(poses), len(pts3d)) 357 | ) # (N_images, N_points) 358 | for i, k in enumerate(pts3d): 359 | pts_world[0, :, i] = pts3d[k].xyz 360 | for j in pts3d[k].image_ids: 361 | visibilities[j - 1, i] = 1 362 | # calculate each point's depth w.r.t. 
each camera 363 | # it's the dot product of "points - camera center" and "camera frontal axis" 364 | # (N_images, N_points) 365 | depths = ((pts_world - poses[..., 3:4]) * poses[..., 2:3]).sum(1) 366 | for i in range(len(poses)): 367 | visibility_i = visibilities[i] 368 | zs = depths[i][visibility_i == 1] 369 | self.bounds[i] = [np.percentile(zs, 0.1), np.percentile(zs, 99.9)] 370 | # permute the matrices to increasing order 371 | poses = poses[perm] 372 | self.bounds = self.bounds[perm] 373 | 374 | # COLMAP poses has rotation in form "right down front", change to "right up back" 375 | # See https://github.com/bmild/nerf/issues/34 376 | poses = np.concatenate( 377 | [poses[..., 0:1], -poses[..., 1:3], poses[..., 3:4]], -1) 378 | self.poses, _ = center_poses(poses) 379 | distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 380 | 381 | # choose val image same as nerf 0, 8 ,16 382 | indicies = np.arange(distances_from_center.shape[0], dtype=int) 383 | val_idx = indicies[::self.holdout] 384 | # center image 385 | 386 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 387 | # See https://github.com/bmild/nerf/issues/34 388 | near_original = self.bounds.min() 389 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 390 | # the nearest depth is at 1/0.75=1.33 391 | self.bounds /= scale_factor 392 | self.poses[..., 3] /= scale_factor 393 | 394 | # ray directions for all pixels, same for all images (same H, W, focal) 395 | self.directions = get_ray_directions( 396 | self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3) 397 | 398 | if self.split == "train": # create buffer of all rays and rgb data 399 | # use first N_images-1 to train, the LAST is val 400 | self.all_rays = [] 401 | self.all_rgbs = [] 402 | for i, image_path in enumerate(self.image_paths): 403 | if np.any(i == val_idx): # exclude the val image 404 | continue 405 | c2w = torch.FloatTensor(self.poses[i]) 406 | 407 | img = Image.open(image_path).convert("RGB") 408 | # assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \ 409 | # f'''{image_path} has different aspect ratio than img_wh, 410 | # please check your data!''' 411 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 412 | img = self.transform(img) # (3, h, w) 413 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 414 | self.all_rgbs += [img] 415 | 416 | rays_o, rays_d = get_rays( 417 | self.directions, c2w) # both (h*w, 3) 418 | viewdirs = rays_d # ! 
As get rays already normalized rays_d 419 | 420 | if not self.spheric_poses: 421 | near, far = 0, 1 422 | rays_o, rays_d = get_ndc_rays( 423 | self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d) 424 | 425 | # near plane is always at 1.0 426 | # near and far in NDC are always 0 and 1 427 | # See https://github.com/bmild/nerf/issues/34 428 | else: 429 | near = self.bounds.min() 430 | # focus on central object only 431 | far = min(8 * near, self.bounds.max()) 432 | 433 | self.all_rays += [ 434 | torch.cat( 435 | [ 436 | rays_o, 437 | rays_d, 438 | near * torch.ones_like(rays_o[:, :1]), 439 | far * torch.ones_like(rays_o[:, :1]), 440 | viewdirs, 441 | ], 442 | 1, 443 | ) 444 | ] # (h*w, 11) 445 | 446 | # ((N_images-1)*h*w, 8) 447 | self.all_rays = torch.cat(self.all_rays, 0) 448 | # ((N_images-1)*h*w, 3) 449 | self.all_rgbs = torch.cat(self.all_rgbs, 0) 450 | 451 | elif self.split == "val": 452 | print("val image is", val_idx) 453 | self.val_idx = val_idx 454 | 455 | else: # for testing, create a parametric rendering path 456 | if self.split.endswith("train"): # test on training set 457 | self.poses_test = self.poses 458 | elif not self.spheric_poses: 459 | focus_depth = 3.5 # hardcoded, this is numerically close to the formula 460 | # given in the original repo. Mathematically if near=1 461 | # and far=infinity, then this number will converge to 4 462 | radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0) 463 | self.poses_test = create_spiral_poses(radii, focus_depth) 464 | else: 465 | radius = 1.1 * self.bounds.min() 466 | self.poses_test = create_spheric_poses(radius) 467 | 468 | def define_transforms(self): 469 | self.transform = T.ToTensor() 470 | 471 | def __len__(self): 472 | if self.split == "train": 473 | return len(self.all_rays) 474 | if self.split == "val": 475 | return self.val_num 476 | if self.split == "test_train": 477 | return len(self.poses) 478 | return len(self.poses_test) 479 | 480 | def __getitem__(self, idx): 481 | if self.split == "train": # use data in the buffers 482 | sample = {"rays": self.all_rays[idx], "rgbs": self.all_rgbs[idx]} 483 | 484 | else: 485 | if self.split == "val": 486 | idx = self.val_idx[idx] 487 | c2w = torch.FloatTensor(self.poses[idx]) 488 | elif self.split == "test_train": 489 | c2w = torch.FloatTensor(self.poses[idx]) 490 | else: 491 | c2w = torch.FloatTensor(self.poses_test[idx]) 492 | 493 | rays_o, rays_d = get_rays(self.directions, c2w) 494 | viewdirs = rays_d 495 | if not self.spheric_poses: 496 | near, far = 0, 1 497 | rays_o, rays_d = get_ndc_rays( 498 | self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d) 499 | else: 500 | near = self.bounds.min() 501 | far = min(8 * near, self.bounds.max()) 502 | 503 | rays = torch.cat( 504 | [ 505 | rays_o, 506 | rays_d, 507 | near * torch.ones_like(rays_o[:, :1]), 508 | far * torch.ones_like(rays_o[:, :1]), 509 | viewdirs, 510 | ], 511 | 1, 512 | ) # (h*w, 11) 513 | 514 | sample = {"rays": rays, "c2w": c2w} 515 | 516 | if self.split in ["val", "test_train"]: 517 | # if self.split == 'val': 518 | # idx = self.val_idx 519 | img = Image.open(self.image_paths[idx]).convert("RGB") 520 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 521 | img = self.transform(img) # (3, h, w) 522 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) 523 | sample["rgbs"] = img 524 | 525 | return sample 526 | -------------------------------------------------------------------------------- /datasets/ray_utils.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import torch 5 | from kornia import create_meshgrid 6 | 7 | 8 | def get_ray_directions(H, W, focal): 9 | """ 10 | Get ray directions for all pixels in camera coordinate. 11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 12 | ray-tracing-generating-camera-rays/standard-coordinate-systems 13 | 14 | Inputs: 15 | H, W, focal: image height, width and focal length 16 | 17 | Outputs: 18 | directions: (H, W, 3), the direction of the rays in camera coordinate 19 | """ 20 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 21 | i, j = grid.unbind(-1) 22 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 23 | # see https://github.com/bmild/nerf/issues/24 24 | directions = torch.stack([(i - W / 2) / focal, -(j - H / 2) / focal, -torch.ones_like(i)], -1) # (H, W, 3) 25 | 26 | return directions 27 | 28 | 29 | def get_rays(directions, c2w): 30 | """ 31 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 32 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 33 | ray-tracing-generating-camera-rays/standard-coordinate-systems 34 | 35 | Inputs: 36 | directions: (H, W, 3) precomputed ray directions in camera coordinate 37 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 38 | 39 | Outputs: 40 | rays_o: (H*W, 3), the origin of the rays in world coordinate 41 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 42 | """ 43 | # Rotate ray directions from camera coordinate to the world coordinate 44 | rays_d = directions @ c2w[:, :3].T # (H, W, 3) 45 | rays_d /= torch.norm(rays_d, dim=-1, keepdim=True) 46 | # The origin of all rays is the camera origin in world coordinate 47 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) 48 | 49 | rays_d = rays_d.view(-1, 3) 50 | rays_o = rays_o.view(-1, 3) 51 | 52 | return rays_o, rays_d 53 | 54 | 55 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 56 | """ 57 | Transform rays from world coordinate to NDC. 58 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. 59 | For detailed derivation, please see: 60 | http://www.songho.ca/opengl/gl_projectionmatrix.html 61 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf 62 | 63 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 
64 | See https://github.com/bmild/nerf/issues/18 65 | 66 | Inputs: 67 | H, W, focal: image height, width and focal length 68 | near: (N_rays) or float, the depths of the near plane 69 | rays_o: (N_rays, 3), the origin of the rays in world coordinate 70 | rays_d: (N_rays, 3), the direction of the rays in world coordinate 71 | 72 | Outputs: 73 | rays_o: (N_rays, 3), the origin of the rays in NDC 74 | rays_d: (N_rays, 3), the direction of the rays in NDC 75 | """ 76 | # Shift ray origins to near plane 77 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 78 | rays_o = rays_o + t[..., None] * rays_d 79 | 80 | # Store some intermediate homogeneous results 81 | ox_oz = rays_o[..., 0] / rays_o[..., 2] 82 | oy_oz = rays_o[..., 1] / rays_o[..., 2] 83 | 84 | # Projection 85 | o0 = -1.0 / (W / (2.0 * focal)) * ox_oz 86 | o1 = -1.0 / (H / (2.0 * focal)) * oy_oz 87 | o2 = 1.0 + 2.0 * near / rays_o[..., 2] 88 | 89 | d0 = -1.0 / (W / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2] - ox_oz) 90 | d1 = -1.0 / (H / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2] - oy_oz) 91 | d2 = 1 - o2 92 | 93 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) 94 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) 95 | 96 | return rays_o, rays_d 97 | 98 | 99 | def get_ndc_coor(H, W, focal, near, pts): 100 | # Store some intermediate homogeneous results 101 | ox_oz = pts[..., 0] / pts[..., 2] 102 | oy_oz = pts[..., 1] / pts[..., 2] 103 | 104 | # Projection 105 | o0 = -1.0 / (W / (2.0 * focal)) * ox_oz 106 | o1 = -1.0 / (H / (2.0 * focal)) * oy_oz 107 | o2 = 1.0 + 2.0 * near / pts[..., 2] 108 | 109 | pts = torch.stack([o0, o1, o2], -1) # (B, 3) 110 | 111 | return pts 112 | 113 | 114 | # read mvs pretrained depth 115 | 116 | 117 | def readPFM(file): 118 | file = open(file, "rb") 119 | 120 | color = None 121 | width = None 122 | height = None 123 | scale = None 124 | endian = None 125 | 126 | header = file.readline().rstrip() 127 | if header == b"PF": 128 | color = True 129 | elif header == b"Pf": 130 | color = False 131 | else: 132 | raise Exception("Not a PFM file.") 133 | 134 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 135 | if dim_match: 136 | width, height = map(int, dim_match.groups()) 137 | else: 138 | raise Exception("Malformed PFM header.") 139 | 140 | scale = float(file.readline().rstrip()) 141 | if scale < 0: # little-endian 142 | endian = "<" 143 | scale = -scale 144 | else: 145 | endian = ">" # big-endian 146 | 147 | data = np.fromfile(file, endian + "f") 148 | shape = (height, width, 3) if color else (height, width) 149 | 150 | data = np.reshape(data, shape) 151 | data = np.flipud(data) 152 | return data 153 | 154 | 155 | def read_gen(file_name, pil=False): 156 | flow = readPFM(file_name).astype(np.float32) 157 | if len(flow.shape) == 2: 158 | return flow 159 | else: 160 | return flow[:, :, :-1] 161 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | 5 | import cv2 6 | import imageio 7 | import lpips 8 | import metrics 9 | import numpy as np 10 | import torch 11 | from datasets import dataset_dict 12 | from datasets.depth_utils import save_pfm 13 | from models.rendering import render_rays 14 | from tqdm import tqdm 15 | from train import NeRFSystem 16 | 17 | 18 | torch.backends.cudnn.benchmark = True 19 | 20 | 21 | def get_opts(): 22 | parser = ArgumentParser() 23 | 
parser.add_argument( 24 | "--root_dir", 25 | type=str, 26 | default="/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego", 27 | help="root directory of dataset", 28 | ) 29 | parser.add_argument( 30 | "--dataset_name", type=str, default="blender", choices=["blender", "llff"], help="which dataset to validate" 31 | ) 32 | parser.add_argument("--scene_name", type=str, default="test", 33 | help="scene name, used as output folder name") 34 | parser.add_argument("--split", type=str, default="test", 35 | help="test or test_train") 36 | parser.add_argument( 37 | "--img_wh", nargs="+", type=int, default=[800, 800], help="resolution (img_w, img_h) of the image" 38 | ) 39 | parser.add_argument( 40 | "--spheric_poses", 41 | default=False, 42 | action="store_true", 43 | help="whether images are taken in spheric poses (for llff)", 44 | ) 45 | 46 | parser.add_argument("--N_emb_xyz", type=int, default=10, 47 | help="number of frequencies in xyz positional encoding") 48 | parser.add_argument("--N_emb_dir", type=int, default=4, 49 | help="number of frequencies in dir positional encoding") 50 | parser.add_argument("--N_samples", type=int, default=64, 51 | help="number of coarse samples") 52 | parser.add_argument("--N_importance", type=int, default=128, 53 | help="number of additional fine samples") 54 | parser.add_argument("--use_disp", default=False, 55 | action="store_true", help="use disparity depth sampling") 56 | parser.add_argument("--chunk", type=int, default=32 * 1024 * 4, 57 | help="chunk size to split the input to avoid OOM") 58 | 59 | parser.add_argument("--weight_path", type=str, required=True, 60 | help="pretrained checkpoint path to load") 61 | 62 | parser.add_argument("--save_depth", default=False, 63 | action="store_true", help="whether to save depth prediction") 64 | parser.add_argument( 65 | "--depth_format", type=str, default="pfm", choices=["pfm", "bytes"], help="which format to save" 66 | ) 67 | 68 | return parser.parse_args() 69 | 70 | 71 | @torch.no_grad() 72 | def batched_inference(models, embeddings, rays, N_samples, N_importance, use_disp, chunk): 73 | """Do batched inference on rays using chunk.""" 74 | B = rays.shape[0] 75 | results = defaultdict(list) 76 | for i in range(0, B, chunk): 77 | rendered_ray_chunks = render_rays( 78 | models, 79 | embeddings, 80 | rays[i: i + chunk], 81 | N_samples, 82 | use_disp, 83 | 0, 84 | 0, 85 | N_importance, 86 | chunk, 87 | dataset.white_back, 88 | test_time=True, 89 | ) 90 | 91 | for k, v in rendered_ray_chunks.items(): 92 | results[k] += [v.cpu()] 93 | 94 | for k, v in results.items(): 95 | results[k] = torch.cat(v, 0) 96 | return results 97 | 98 | 99 | if __name__ == "__main__": 100 | args = get_opts() 101 | w, h = args.img_wh 102 | lpips_vgg = lpips.LPIPS(net="vgg").cuda() 103 | lpips_vgg = lpips_vgg.eval() 104 | 105 | kwargs = {"root_dir": args.root_dir, "split": args.split, 106 | "img_wh": tuple(args.img_wh), "val_num": 3} 107 | if args.dataset_name == "llff": 108 | kwargs["spheric_poses"] = args.spheric_poses 109 | dataset = dataset_dict[args.dataset_name](**kwargs) 110 | 111 | system = NeRFSystem(args) 112 | 113 | embedding_xyz = system.embedding_xyz 114 | embedding_dir = system.embedding_dir 115 | models = system.models 116 | for k in models.keys(): 117 | models[k] = models[k].cuda().eval() 118 | 119 | imgs, depth_maps, psnrs, mean_lpips, ssims = [], [], [], [], [] 120 | dir_name = f"results/{args.dataset_name}/{args.scene_name}" 121 | os.makedirs(dir_name, exist_ok=True) 122 | 123 | embeddings = {"xyz": embedding_xyz, "dir": 
embedding_dir} 124 | 125 | for i in tqdm(range(len(dataset))): 126 | sample = dataset[i] 127 | rays = sample["rays"].cuda() 128 | results = batched_inference( 129 | models, embeddings, rays, args.N_samples, args.N_importance, args.use_disp, args.chunk 130 | ) 131 | typ = "fine" if "rgb_fine" in results else "coarse" 132 | 133 | img_pred = np.clip(results[f"rgb_{typ}"].view( 134 | h, w, 3).cpu().numpy(), 0, 1) 135 | 136 | if args.save_depth: 137 | depth_pred = results[f"depth_{typ}"].view(h, w).cpu().numpy() 138 | depth_maps += [depth_pred] 139 | if args.depth_format == "pfm": 140 | save_pfm(os.path.join( 141 | dir_name, f"depth_{i:03d}.pfm"), depth_pred) 142 | else: 143 | with open(os.path.join(dir_name, f"depth_{i:03d}"), "wb") as f: 144 | f.write(depth_pred.tobytes()) 145 | 146 | img_pred_ = (img_pred * 255).astype(np.uint8) 147 | imgs += [img_pred_] 148 | imageio.imwrite(os.path.join(dir_name, f"{i:03d}.png"), img_pred_) 149 | 150 | img_gt_ = (sample["rgbs"].view( 151 | h, w, 3).cpu().numpy() * 255).astype(np.uint8) 152 | imageio.imwrite(os.path.join(dir_name, f"gt_{i:03d}.png"), img_gt_) 153 | 154 | if "rgbs" in sample: 155 | rgbs = sample["rgbs"] 156 | img_gt = rgbs.view(h, w, 3) 157 | 158 | # scale to compute lpips 159 | scaled_gt = img_gt * 2.0 - 1.0 160 | scaled_pred = img_pred * 2.0 - 1.0 161 | scaled_pred = torch.from_numpy(scaled_pred) 162 | lpips_val = lpips_vgg( 163 | scaled_gt[:, :, [2, 1, 0]].permute( 164 | 2, 0, 1).unsqueeze(0).cuda(), 165 | scaled_pred[:, :, [2, 1, 0]].permute( 166 | 2, 0, 1).unsqueeze(0).cuda(), 167 | ) 168 | mean_lpips.append(lpips_val.detach().squeeze().cpu().numpy()) 169 | psnrs += [metrics.psnr(img_gt, img_pred).item()] 170 | 171 | # compute ssim 172 | ssim = metrics.ssim(img_gt.permute(2, 0, 1)[None], torch.from_numpy( 173 | img_pred).permute(2, 0, 1)[None]) 174 | ssims.append(ssim) 175 | 176 | imageio.mimsave(os.path.join( 177 | dir_name, f"{args.scene_name}.gif"), imgs, fps=30) 178 | 179 | if args.save_depth: 180 | min_depth = np.min(depth_maps) 181 | max_depth = np.max(depth_maps) 182 | depth_imgs = (depth_maps - np.min(depth_maps)) / \ 183 | (max(np.max(depth_maps) - np.min(depth_maps), 1e-8)) 184 | depth_imgs_ = [cv2.applyColorMap( 185 | (img * 255).astype(np.uint8), cv2.COLORMAP_JET) for img in depth_imgs] 186 | imageio.mimsave(os.path.join( 187 | dir_name, f"{args.scene_name}_depth.gif"), depth_imgs_, fps=30) 188 | 189 | if psnrs: 190 | mean_psnr = np.mean(psnrs) 191 | print(f"Mean PSNR : {mean_psnr:.3f}") 192 | 193 | if ssims: 194 | mean_ssim = np.mean(ssims) 195 | print(f"Mean SSIM : {mean_ssim:.3f}") 196 | 197 | if mean_lpips: 198 | mean_lpips = np.mean(np.array(mean_lpips)) 199 | print(f"Mean LPIPS : {mean_lpips:.3f}") 200 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ColorLoss(nn.Module): 5 | def __init__(self, coef=1): 6 | super().__init__() 7 | self.coef = coef 8 | self.loss = nn.MSELoss(reduction="mean") 9 | 10 | def forward(self, inputs, targets): 11 | if "rgb_corase" in inputs: 12 | loss = self.loss(inputs["rgb_coarse"], targets) 13 | if "rgb_fine" in inputs: 14 | loss += self.loss(inputs["rgb_fine"], targets) 15 | else: 16 | loss = self.loss(inputs["rgb_fine"], targets) 17 | return self.coef * loss 18 | 19 | 20 | loss_dict = {"color": ColorLoss} 21 | -------------------------------------------------------------------------------- /metrics.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim_loss as dssim 3 | 4 | 5 | def mse(image_pred, image_gt, valid_mask=None, reduction="mean"): 6 | value = (image_pred - image_gt) ** 2 7 | if valid_mask is not None: 8 | value = value[valid_mask] 9 | if reduction == "mean": 10 | return torch.mean(value) 11 | return value 12 | 13 | 14 | def psnr(image_pred, image_gt, valid_mask=None, reduction="mean"): 15 | return -10 * torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 16 | 17 | 18 | def ssim(image_pred, image_gt, reduction="mean"): 19 | """ 20 | image_pred and image_gt: (1, 3, H, W) 21 | """ 22 | dssim_ = dssim(image_pred, image_gt, 3, reduction=reduction) # dissimilarity in [0, 1] 23 | return 1 - 2 * dssim_ # in [-1, 1] 24 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/models/__init__.py -------------------------------------------------------------------------------- /models/cloud_code.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import faiss # make faiss available 4 | import faiss.contrib.torch_utils 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | # ! CloudNeRF configs 10 | config = {} 11 | config["code_cloud"] = {} 12 | config["code_cloud"]["num_codes"] = 8192 # 8192 13 | config["code_cloud"]["num_neighbors"] = 32 # 32 14 | 15 | config["code_cloud"]["code_dim"] = 64 # 64 16 | config["code_cloud"]["dist_scale"] = 3.0 17 | config["code_regularization_lambda"] = 0.0 18 | config["code_position_lambda"] = 0 19 | config["distortion_lambda"] = 0 20 | 21 | 22 | @torch.no_grad() 23 | def find_knn(gpu_index, locs, current_codes, neighbor=config["code_cloud"]["num_neighbors"]): 24 | n_points = locs.shape[0] 25 | # Search with torch GPU using pre-allocated arrays 26 | new_d_square_torch_gpu = torch.zeros( 27 | n_points, neighbor, device=locs.device, dtype=torch.float32) 28 | new_i_torch_gpu = torch.zeros( 29 | n_points, neighbor, device=locs.device, dtype=torch.int64) 30 | 31 | # update current codes 32 | gpu_index.add(current_codes) 33 | 34 | gpu_index.search(locs, neighbor, new_d_square_torch_gpu, new_i_torch_gpu) 35 | gpu_index.reset() 36 | 37 | return new_d_square_torch_gpu, new_i_torch_gpu 38 | 39 | 40 | class CodeCloud(nn.Module): 41 | def __init__(self, config, num_records, keypoints, fps_keypoints): 42 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "Building CodeCloud.") 43 | super().__init__() 44 | self.config = config 45 | self.SH_basis_dim = 9 46 | self.origin_keypoints = nn.Parameter( 47 | torch.Tensor(keypoints.float())[None, ...].repeat(num_records, 1, 1), requires_grad=False 48 | ) 49 | 50 | self.codes_position = nn.Parameter(torch.Tensor(fps_keypoints.float())[ 51 | None, ...].repeat(num_records, 1, 1)) 52 | self.codes = nn.Parameter(torch.randn( 53 | num_records, config["num_codes"], config["code_dim"]) * 0.01) 54 | 55 | self.knn = self.init_knn() 56 | 57 | num_params = sum(p.data.nelement() for p in self.parameters()) 58 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 59 | "CodeCloud done(#parameters=%d)." 
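# origin_keypoints stores the full (frozen) input point cloud, codes_position the
# learnable code locations initialised from the FPS-sampled keypoints, and codes the
# per-code latent features that query() interpolates at arbitrary sample points.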
% num_params) 60 | 61 | def init_knn(self): 62 | faiss_cfg = faiss.GpuIndexFlatConfig() 63 | faiss_cfg.useFloat16 = True 64 | faiss_cfg.device = 0 65 | 66 | return faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), 3, faiss_cfg) 67 | 68 | def query(self, indices, query_points): 69 | """ 70 | Args: 71 | indices: tensor, (batch_size,) 72 | query_points: tensor, (batch_size, num_points, 3) 73 | Returns: 74 | query_codes: tensor, (batch_size, num_points, code_dim) 75 | square_dist: tensor, (batch_size, num_points, num_codes) 76 | weight: tensor, (batch_size, num_points, num_codes) 77 | """ 78 | batch_codes_position = self.codes_position[indices] 79 | 80 | # ! NO GRAD: Need to recompute the square distance for gradients 81 | _, new_i_torch_gpu = find_knn( 82 | self.knn, query_points[0], batch_codes_position[0]) 83 | 84 | square_dist = (query_points.unsqueeze(2) - self.codes_position[indices][:, new_i_torch_gpu, :]).pow(2).sum( 85 | dim=-1 86 | ) + 1e-16 87 | 88 | weight = 1.0 / (torch.sqrt(square_dist) ** self.config["dist_scale"]) 89 | weight = weight / weight.sum(dim=-1, keepdim=True) 90 | 91 | query_codes = torch.matmul(weight[0].unsqueeze( 92 | 1), self.codes[indices][:, new_i_torch_gpu, :][0]).squeeze(1) 93 | 94 | return query_codes 95 | 96 | def get_proposal(self, pts): 97 | n_rays, n_samples, _ = pts.shape 98 | new_d_torch_gpu, _ = find_knn( 99 | self.knn, pts.flatten(0, 1), self.origin_keypoints[0]) 100 | 101 | farthest_d = new_d_torch_gpu[:, -1].view(n_rays, n_samples) 102 | 103 | farthest_d[farthest_d > 0.2] = 0.2 104 | 105 | weight = 1.0 / (torch.sqrt(farthest_d) + 1e-10) 106 | weight = weight / weight.sum(dim=-1, keepdim=True) 107 | return weight 108 | 109 | 110 | class IM_Decoder(nn.Module): 111 | def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, skips=[4]): 112 | print(datetime.now().strftime( 113 | "%Y-%m-%d %H:%M:%S"), "Building IM-decoder.") 114 | super(IM_Decoder, self).__init__() 115 | self.D = D 116 | self.W = W 117 | self.in_channels_xyz = in_channels_xyz 118 | self.in_channels_dir = in_channels_dir 119 | self.skips = skips 120 | 121 | # xyz encoding layers 122 | for i in range(D): 123 | if i == 0: 124 | layer = nn.Linear(in_channels_xyz, W) 125 | elif i in skips: 126 | layer = nn.Linear(W + in_channels_xyz, W) 127 | else: 128 | layer = nn.Linear(W, W) 129 | layer = nn.Sequential(layer, nn.ReLU(True)) 130 | setattr(self, f"xyz_encoding_{i+1}", layer) 131 | self.xyz_encoding_final = nn.Linear(W, W) 132 | 133 | # direction encoding layers 134 | self.dir_encoding = nn.Sequential( 135 | nn.Linear(W + in_channels_dir, W // 2), nn.ReLU(True)) 136 | 137 | # output layers 138 | self.sigma = nn.Linear(W, 1) 139 | self.rgb = nn.Linear(W // 2, 3) 140 | 141 | num_params = sum(p.data.nelement() for p in self.parameters()) 142 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 143 | "IM decoder done(#parameters=%d)." 
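# NOTE: unlike models/nerf.py, rgb and sigma are returned here as raw linear outputs;
# rendering.py applies the sigmoid (rgb) and softplus (sigma) during compositing.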
% num_params) 144 | 145 | def forward(self, x): 146 | input_xyz, input_dir = torch.split( 147 | x, [self.in_channels_xyz, self.in_channels_dir], dim=-1) 148 | xyz_ = input_xyz 149 | for i in range(self.D): 150 | if i in self.skips: 151 | xyz_ = torch.cat([input_xyz, xyz_], -1) 152 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 153 | 154 | sigma = self.sigma(xyz_) 155 | 156 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 157 | 158 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1) 159 | dir_encoding = self.dir_encoding(dir_encoding_input) 160 | rgb = self.rgb(dir_encoding) 161 | 162 | out = torch.cat([rgb, sigma], -1) 163 | 164 | return out 165 | 166 | 167 | class CloudNeRF(nn.Module): 168 | def __init__(self, keypoints, fps_kps, input_ch, input_ch_views, num_records=1): 169 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "Building network.") 170 | super().__init__() 171 | global config 172 | self.config = config 173 | 174 | self.code_cloud = CodeCloud( 175 | config["code_cloud"], num_records, keypoints, fps_kps) 176 | self.decoder = IM_Decoder( 177 | D=4, 178 | W=128, 179 | in_channels_xyz=config["code_cloud"]["code_dim"] + input_ch, 180 | in_channels_dir=input_ch_views, 181 | skips=[2], 182 | ) 183 | 184 | num_params = sum(p.data.nelement() for p in self.parameters()) 185 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 186 | "Network done(#parameters=%d)." % num_params) 187 | 188 | def forward(self, indices, query_points, xyzdir_embedded): 189 | """ 190 | Args: 191 | indices: tensor, (batch_size,) 192 | query_points: tensor, (batch_size, num_points, 3) 193 | Returns: 194 | pred_sd: tensor, (batch_size, num_points) 195 | """ 196 | query_codes = self.code_cloud.query(indices, query_points[None, ...]) 197 | 198 | batch_input = torch.cat([query_codes, xyzdir_embedded], dim=-1) 199 | pred_sd = self.decoder(batch_input) 200 | 201 | return pred_sd 202 | 203 | 204 | ################################## SH FEATURES ############################################## 205 | class SHCodeCloud(nn.Module): 206 | def __init__(self, config, num_records, keypoints, fps_keypoints): 207 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'Building CodeCloud.') 208 | super().__init__() 209 | self.config = config 210 | self.SH_basis_dim = 9 211 | self.origin_keypoints = nn.Parameter(torch.Tensor( 212 | keypoints.float())[None, ...].repeat(num_records, 1, 1), requires_grad=False) 213 | 214 | self.codes_position = nn.Parameter(torch.Tensor( 215 | fps_keypoints.float())[None, ...].repeat(num_records, 1, 1)) 216 | self.codes = nn.Parameter(torch.randn( 217 | num_records, config['num_codes'], config['code_dim']) * 0.01) 218 | self.sh_codes = nn.Parameter(torch.randn( 219 | num_records, config['num_codes'], config['code_dim'] * self.SH_basis_dim) * 0.01) 220 | 221 | self.knn = self.init_knn() 222 | 223 | num_params = sum(p.data.nelement() for p in self.parameters()) 224 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 225 | 'CodeCloud done(#parameters=%d).' 
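# sh_codes holds, for every code, code_dim feature channels times SH_basis_dim (9)
# spherical-harmonic coefficients; query() contracts them with eval_sh_bases(viewdirs)
# to form a view-dependent feature before the inverse-distance interpolation.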
% num_params) 226 | 227 | def init_knn(self): 228 | faiss_cfg = faiss.GpuIndexFlatConfig() 229 | faiss_cfg.useFloat16 = True 230 | faiss_cfg.device = 0 231 | 232 | return faiss.GpuIndexFlatL2( 233 | faiss.StandardGpuResources(), 3, faiss_cfg) 234 | 235 | def query(self, indices, query_points, viewdirs): 236 | """ 237 | Args: 238 | indices: tensor, (batch_size,) 239 | query_points: tensor, (batch_size, num_points, 3) 240 | Returns: 241 | query_codes: tensor, (batch_size, num_points, code_dim) 242 | square_dist: tensor, (batch_size, num_points, num_codes) 243 | weight: tensor, (batch_size, num_points, num_codes) 244 | """ 245 | batch_codes_position = self.codes_position[indices] 246 | 247 | # ! NO GRAD: Need to recompute the square distance for gradients 248 | _, new_i_torch_gpu = find_knn( 249 | self.knn, query_points[0], batch_codes_position[0]) 250 | 251 | sh_feat = self.sh_codes[indices][:, new_i_torch_gpu, :][0] 252 | sh_feat = sh_feat.reshape(-1, self.config['num_neighbors'], 253 | self.config['code_dim'], self.SH_basis_dim) 254 | sh_mult = eval_sh_bases(self.SH_basis_dim, viewdirs).unsqueeze( 255 | 1).repeat(1, self.config['num_neighbors'], 1) 256 | agg_sh_feat = torch.sum(sh_mult.unsqueeze(-2) * sh_feat, dim=-1) 257 | 258 | square_dist = (query_points.unsqueeze( 259 | 2) - self.codes_position[indices][:, new_i_torch_gpu, :]).pow(2).sum(dim=-1) + 1e-16 260 | 261 | weight = 1.0 / (torch.sqrt(square_dist) ** self.config['dist_scale']) 262 | weight = weight / weight.sum(dim=-1, keepdim=True) 263 | 264 | query_codes = torch.matmul( 265 | weight[0].unsqueeze(1), self.codes[indices][:, new_i_torch_gpu, :][0]).squeeze(1) 266 | 267 | query_sh_codes = torch.matmul( 268 | weight[0].unsqueeze(1), agg_sh_feat).squeeze(1) 269 | 270 | return query_codes, query_sh_codes 271 | 272 | def get_proposal(self, pts): 273 | n_rays, n_samples, _ = pts.shape 274 | new_d_torch_gpu, _ = find_knn( 275 | self.knn, pts.flatten(0, 1), self.origin_keypoints[0]) 276 | 277 | farthest_d = new_d_torch_gpu[:, -1].view(n_rays, n_samples) 278 | 279 | farthest_d[farthest_d > 0.2] = 0.2 280 | 281 | weight = 1.0 / (torch.sqrt(farthest_d) + 1e-10) 282 | weight = weight / weight.sum(dim=-1, keepdim=True) 283 | return weight 284 | 285 | 286 | class SH_IM_Decoder(nn.Module): 287 | def __init__(self, 288 | D=8, W=256, code_dim=64, 289 | in_channels_xyz=63, in_channels_dir=27, 290 | skips=[4]): 291 | print(datetime.now().strftime( 292 | '%Y-%m-%d %H:%M:%S'), 'Building IM-decoder.') 293 | super(SH_IM_Decoder, self).__init__() 294 | self.D = D 295 | self.W = W 296 | self.in_channels_xyz = in_channels_xyz 297 | self.in_channels_dir = in_channels_dir 298 | self.skips = skips 299 | self.code_dim = code_dim 300 | 301 | # xyz encoding layers 302 | for i in range(D): 303 | if i == 0: 304 | layer = nn.Linear(in_channels_xyz + self.code_dim, W) 305 | elif i in skips: 306 | layer = nn.Linear(W + in_channels_xyz + self.code_dim, W) 307 | else: 308 | layer = nn.Linear(W, W) 309 | layer = nn.Sequential(layer, nn.ReLU(True)) 310 | setattr(self, f"xyz_encoding_{i+1}", layer) 311 | self.xyz_encoding_final = nn.Linear(W, W) 312 | 313 | # direction encoding layers 314 | self.dir_encoding = nn.Sequential( 315 | nn.Linear(W + in_channels_dir + self.code_dim, W // 2), 316 | nn.ReLU(True)) 317 | 318 | # output layers 319 | self.sigma = nn.Linear(W, 1) 320 | self.rgb = nn.Linear(W // 2, 3) 321 | 322 | num_params = sum(p.data.nelement() for p in self.parameters()) 323 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 324 | 'IM decoder 
done(#parameters=%d).' % num_params) 325 | 326 | def forward(self, codes, sh_codes, xyzdir_embed): 327 | input_xyz, input_dir = \ 328 | torch.split(xyzdir_embed, [self.in_channels_xyz, 329 | self.in_channels_dir], dim=-1) 330 | input_xyz = torch.cat([codes, input_xyz], dim=-1) 331 | xyz_ = input_xyz 332 | for i in range(self.D): 333 | if i in self.skips: 334 | xyz_ = torch.cat([input_xyz, xyz_], -1) 335 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 336 | 337 | sigma = self.sigma(xyz_) 338 | 339 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 340 | 341 | dir_encoding_input = torch.cat( 342 | [xyz_encoding_final, sh_codes, input_dir], -1) 343 | dir_encoding = self.dir_encoding(dir_encoding_input) 344 | rgb = self.rgb(dir_encoding) 345 | 346 | out = torch.cat([rgb, sigma], -1) 347 | 348 | return out 349 | 350 | 351 | class SHCloudNeRF(nn.Module): 352 | def __init__(self, keypoints, fps_kps, input_ch, input_ch_views, num_records=1): 353 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'Building network.') 354 | super().__init__() 355 | global config 356 | self.config = config 357 | 358 | self.code_cloud = SHCodeCloud( 359 | config['code_cloud'], num_records, keypoints, fps_kps) 360 | self.decoder = SH_IM_Decoder(D=4, W=128, code_dim=config['code_cloud']['code_dim'], 361 | in_channels_xyz=input_ch, in_channels_dir=input_ch_views, 362 | skips=[2]) 363 | 364 | num_params = sum(p.data.nelement() for p in self.parameters()) 365 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 366 | 'Network done(#parameters=%d).' % num_params) 367 | 368 | def forward(self, indices, query_points, viewdirs, xyzdir_embedded): 369 | """ 370 | Args: 371 | indices: tensor, (batch_size,) 372 | query_points: tensor, (batch_size, num_points, 3) 373 | Returns: 374 | pred_sd: tensor, (batch_size, num_points) 375 | """ 376 | query_codes, query_sh_codes = self.code_cloud.query( 377 | indices, query_points[None, ...], viewdirs) 378 | pred_sd = self.decoder(query_codes, query_sh_codes, xyzdir_embedded) 379 | 380 | return pred_sd 381 | 382 | 383 | SH_C0 = 0.28209479177387814 384 | SH_C1 = 0.4886025119029199 385 | SH_C2 = [ 386 | 1.0925484305920792, 387 | -1.0925484305920792, 388 | 0.31539156525252005, 389 | -1.0925484305920792, 390 | 0.5462742152960396 391 | ] 392 | SH_C3 = [ 393 | -0.5900435899266435, 394 | 2.890611442640554, 395 | -0.4570457994644658, 396 | 0.3731763325901154, 397 | -0.4570457994644658, 398 | 1.445305721320277, 399 | -0.5900435899266435 400 | ] 401 | SH_C4 = [ 402 | 2.5033429417967046, 403 | -1.7701307697799304, 404 | 0.9461746957575601, 405 | -0.6690465435572892, 406 | 0.10578554691520431, 407 | -0.6690465435572892, 408 | 0.47308734787878004, 409 | -1.7701307697799304, 410 | 0.6258357354491761, 411 | ] 412 | 413 | 414 | def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): 415 | """ 416 | Evaluate spherical harmonics bases at unit directions, 417 | without taking linear combination. 418 | At each point, the final result may the be 419 | obtained through simple multiplication. 420 | :param basis_dim: int SH basis dim. 
Currently, 1-25 square numbers supported 421 | :param dirs: torch.Tensor (..., 3) unit directions 422 | :return: torch.Tensor (..., basis_dim) 423 | """ 424 | result = torch.empty( 425 | (*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device) 426 | result[..., 0] = SH_C0 427 | if basis_dim > 1: 428 | x, y, z = dirs.unbind(-1) 429 | result[..., 1] = -SH_C1 * y 430 | result[..., 2] = SH_C1 * z 431 | result[..., 3] = -SH_C1 * x 432 | if basis_dim > 4: 433 | xx, yy, zz = x * x, y * y, z * z 434 | xy, yz, xz = x * y, y * z, x * z 435 | result[..., 4] = SH_C2[0] * xy 436 | result[..., 5] = SH_C2[1] * yz 437 | result[..., 6] = SH_C2[2] * (2.0 * zz - xx - yy) 438 | result[..., 7] = SH_C2[3] * xz 439 | result[..., 8] = SH_C2[4] * (xx - yy) 440 | 441 | if basis_dim > 9: 442 | result[..., 9] = SH_C3[0] * y * (3 * xx - yy) 443 | result[..., 10] = SH_C3[1] * xy * z 444 | result[..., 11] = SH_C3[2] * y * (4 * zz - xx - yy) 445 | result[..., 12] = SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy) 446 | result[..., 13] = SH_C3[4] * x * (4 * zz - xx - yy) 447 | result[..., 14] = SH_C3[5] * z * (xx - yy) 448 | result[..., 15] = SH_C3[6] * x * (xx - 3 * yy) 449 | 450 | if basis_dim > 16: 451 | result[..., 16] = SH_C4[0] * xy * (xx - yy) 452 | result[..., 17] = SH_C4[1] * yz * (3 * xx - yy) 453 | result[..., 18] = SH_C4[2] * xy * (7 * zz - 1) 454 | result[..., 19] = SH_C4[3] * yz * (7 * zz - 3) 455 | result[..., 20] = SH_C4[4] * (zz * (35 * zz - 30) + 3) 456 | result[..., 21] = SH_C4[5] * xz * (7 * zz - 3) 457 | result[..., 22] = SH_C4[6] * (xx - yy) * (7 * zz - 1) 458 | result[..., 23] = SH_C4[7] * xz * (xx - 3 * yy) 459 | result[..., 24] = SH_C4[8] * \ 460 | (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 461 | return result 462 | -------------------------------------------------------------------------------- /models/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Embedding(nn.Module): 6 | def __init__(self, N_freqs, logscale=True): 7 | """ 8 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 9 | in_channels: number of input channels (3 for both xyz and direction) 10 | """ 11 | super().__init__() 12 | self.N_freqs = N_freqs 13 | self.funcs = [torch.sin, torch.cos] 14 | 15 | if logscale: 16 | self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs) 17 | else: 18 | self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) 19 | 20 | def forward(self, x): 21 | """ 22 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 
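For example, with N_freqs=10 and a 3-channel xyz input the output has
3 + 3*2*10 = 63 channels, which is the in_channels_xyz=63 default used by NeRF below.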
23 | Different from the paper, "x" is also in the output 24 | See https://github.com/bmild/nerf/issues/12 25 | 26 | Inputs: 27 | x: (B, f) 28 | 29 | Outputs: 30 | out: (B, 2*f*N_freqs+f) 31 | """ 32 | out = [x] 33 | for freq in self.freq_bands: 34 | for func in self.funcs: 35 | out += [func(freq * x)] 36 | 37 | return torch.cat(out, -1) 38 | 39 | 40 | class NeRF(nn.Module): 41 | def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, skips=[4]): 42 | """ 43 | D: number of layers for density (sigma) encoder 44 | W: number of hidden units in each layer 45 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default) 46 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) 47 | skips: add skip connection in the Dth layer 48 | """ 49 | super(NeRF, self).__init__() 50 | self.D = D 51 | self.W = W 52 | self.in_channels_xyz = in_channels_xyz 53 | self.in_channels_dir = in_channels_dir 54 | self.skips = skips 55 | 56 | # xyz encoding layers 57 | for i in range(D): 58 | if i == 0: 59 | layer = nn.Linear(in_channels_xyz, W) 60 | elif i in skips: 61 | layer = nn.Linear(W + in_channels_xyz, W) 62 | else: 63 | layer = nn.Linear(W, W) 64 | layer = nn.Sequential(layer, nn.ReLU(True)) 65 | setattr(self, f"xyz_encoding_{i+1}", layer) 66 | self.xyz_encoding_final = nn.Linear(W, W) 67 | 68 | # direction encoding layers 69 | self.dir_encoding = nn.Sequential(nn.Linear(W + in_channels_dir, W // 2), nn.ReLU(True)) 70 | 71 | # output layers 72 | self.sigma = nn.Linear(W, 1) 73 | self.rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid()) 74 | 75 | def forward(self, x, sigma_only=False): 76 | """ 77 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 78 | For rendering this ray, please see rendering.py 79 | 80 | Inputs: 81 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 82 | the embedded vector of position and direction 83 | sigma_only: whether to infer sigma only. If True, 84 | x is of shape (B, self.in_channels_xyz) 85 | 86 | Outputs: 87 | if sigma_ony: 88 | sigma: (B, 1) sigma 89 | else: 90 | out: (B, 4), rgb and sigma 91 | """ 92 | if not sigma_only: 93 | input_xyz, input_dir = torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1) 94 | else: 95 | input_xyz = x 96 | 97 | xyz_ = input_xyz 98 | for i in range(self.D): 99 | if i in self.skips: 100 | xyz_ = torch.cat([input_xyz, xyz_], -1) 101 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 102 | 103 | sigma = self.sigma(xyz_) 104 | if sigma_only: 105 | return sigma 106 | 107 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 108 | 109 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1) 110 | dir_encoding = self.dir_encoding(dir_encoding_input) 111 | rgb = self.rgb(dir_encoding) 112 | 113 | out = torch.cat([rgb, sigma], -1) 114 | 115 | return out 116 | -------------------------------------------------------------------------------- /models/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, reduce, repeat 4 | 5 | 6 | __all__ = ["render_rays"] 7 | 8 | 9 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): 10 | """ 11 | Sample @N_importance samples from @bins with distribution defined by @weights. 
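This is standard inverse-transform sampling: the weights are normalised into a pdf
over the coarse bins, the cdf is built by a cumulative sum, uniform samples u in [0, 1]
are drawn (evenly spaced if det=True), and each u is mapped back to a depth by linear
interpolation inside the bin whose cdf interval contains it.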
12 | Inputs: 13 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" 14 | weights: (N_rays, N_samples_) 15 | N_importance: the number of samples to draw from the distribution 16 | det: deterministic or not 17 | eps: a small number to prevent division by zero 18 | Outputs: 19 | samples: (N_rays, N_importance) the sampled samples 20 | """ 21 | N_rays, N_samples_ = weights.shape 22 | weights = weights + eps # prevent division by zero (don't do inplace op!) 23 | pdf = weights / reduce(weights, "n1 n2 -> n1 1", 24 | "sum") # (N_rays, N_samples_) 25 | # (N_rays, N_samples), cumulative distribution function 26 | cdf = torch.cumsum(pdf, -1) 27 | # (N_rays, N_samples_+1) 28 | cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) 29 | # padded to 0~1 inclusive 30 | 31 | if det: 32 | u = torch.linspace(0, 1, N_importance, device=bins.device) 33 | u = u.expand(N_rays, N_importance) 34 | else: 35 | u = torch.rand(N_rays, N_importance, device=bins.device) 36 | u = u.contiguous() 37 | 38 | inds = torch.searchsorted(cdf, u, right=True) 39 | below = torch.clamp_min(inds - 1, 0) 40 | above = torch.clamp_max(inds, N_samples_) 41 | 42 | inds_sampled = rearrange(torch.stack( 43 | [below, above], -1), "n1 n2 c -> n1 (n2 c)", c=2) 44 | cdf_g = rearrange(torch.gather(cdf, 1, inds_sampled), 45 | "n1 (n2 c) -> n1 n2 c", c=2) 46 | bins_g = rearrange(torch.gather(bins, 1, inds_sampled), 47 | "n1 (n2 c) -> n1 n2 c", c=2) 48 | 49 | denom = cdf_g[..., 1] - cdf_g[..., 0] 50 | denom[denom < eps] = 1 # denom equals 0 means a bin has weight 0, 51 | # in which case it will not be sampled 52 | # anyway, therefore any value for it is fine (set to 1 here) 53 | 54 | samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / \ 55 | denom * (bins_g[..., 1] - bins_g[..., 0]) 56 | return samples 57 | 58 | 59 | def render_rays( 60 | models, 61 | embeddings, 62 | rays, 63 | N_samples=64, 64 | use_disp=False, 65 | perturb=0, 66 | noise_std=1, 67 | N_importance=0, 68 | chunk=1024 * 32, 69 | white_back=False, 70 | test_time=False, 71 | **kwargs, 72 | ): 73 | """ 74 | Render rays by computing the output of @model applied on @rays 75 | Inputs: 76 | models: list of NeRF models (coarse and fine) defined in nerf.py 77 | embeddings: list of embedding models of origin and direction defined in nerf.py 78 | rays: (N_rays, 3+3+2), ray origins and directions, near and far depths 79 | N_samples: number of coarse samples per ray 80 | use_disp: whether to sample in disparity space (inverse depth) 81 | perturb: factor to perturb the sampling position on the ray (for coarse model only) 82 | noise_std: factor to perturb the model's prediction of sigma 83 | N_importance: number of fine samples per ray 84 | chunk: the chunk size in batched inference 85 | white_back: whether the background is white (dataset dependent) 86 | test_time: whether it is test (inference only) or not. If True, it will not do inference 87 | on coarse rgb to save time 88 | Outputs: 89 | result: dictionary containing final rgb and depth maps for coarse and fine models 90 | """ 91 | 92 | def inference(results, model, typ, xyz, z_vals, test_time=False, **kwargs): 93 | """ 94 | Helper function that performs model inference. 
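Per-sample densities are converted to alphas with
alpha_i = 1 - exp(-delta_i * softplus(sigma_i + noise + density_bias))
and composited with weights w_i = alpha_i * prod_{j<i} (1 - alpha_j); the rgb and
depth maps are the weight-weighted sums of the sampled colours and z values.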
95 | Inputs: 96 | results: a dict storing all results 97 | model: NeRF model (coarse or fine) 98 | typ: 'coarse' or 'fine' 99 | xyz: (N_rays, N_samples_, 3) sampled positions 100 | N_samples_ is the number of sampled points in each ray; 101 | = N_samples for coarse model 102 | = N_samples+N_importance for fine model 103 | z_vals: (N_rays, N_samples_) depths of the sampled positions 104 | test_time: test time or not 105 | Outputs: 106 | if weights_only: 107 | weights: (N_rays, N_samples_): weights of each sample 108 | else: 109 | rgb_final: (N_rays, 3) the final rgb image 110 | depth_final: (N_rays) depth map 111 | weights: (N_rays, N_samples_): weights of each sample 112 | """ 113 | N_samples_ = xyz.shape[1] 114 | xyz_ = rearrange(xyz, "n1 n2 c -> (n1 n2) c") # (N_rays*N_samples_, 3) 115 | 116 | # Perform model inference to get rgb and raw sigma 117 | B = xyz_.shape[0] 118 | out_chunks = [] 119 | if typ == "coarse" and test_time and "fine" in models: 120 | for i in range(0, B, chunk): 121 | xyz_embedded = embedding_xyz(xyz_[i: i + chunk]) 122 | out_chunks += [model(xyz_embedded, sigma_only=True)] 123 | 124 | out = torch.cat(out_chunks, 0) 125 | sigmas = rearrange(out, "(n1 n2) 1 -> n1 n2", 126 | n1=N_rays, n2=N_samples_) 127 | else: # infer rgb and sigma and others 128 | dir_embedded_ = repeat( 129 | dir_embedded, "n1 c -> (n1 n2) c", n2=N_samples_) 130 | if kwargs["use_sh_feat"]: 131 | viewdirs = repeat(kwargs['viewdirs'], 'n1 c -> (n1 n2) c', n2=N_samples_) 132 | # (N_rays*N_samples_, embed_dir_channels) 133 | for i in range(0, B, chunk): 134 | xyz_embedded = embedding_xyz(xyz_[i: i + chunk]) 135 | xyzdir_embedded = torch.cat( 136 | [xyz_embedded, dir_embedded_[i: i + chunk]], 1) 137 | # out_chunks += [model(xyzdir_embedded, sigma_only=False)] 138 | # ! cloud nerf fwd 139 | indices = torch.zeros(1, dtype=torch.long).cuda() 140 | if kwargs["use_sh_feat"]: 141 | out_chunks += [model(indices, xyz_[i:i + chunk], viewdirs[i:i + chunk], 142 | xyzdir_embedded)] 143 | else: 144 | out_chunks += [model(indices, xyz_[i: i + 145 | chunk], xyzdir_embedded)] 146 | 147 | out = torch.cat(out_chunks, 0) 148 | # out = out.view(N_rays, N_samples_, 4) 149 | out = rearrange(out, "(n1 n2) c -> n1 n2 c", 150 | n1=N_rays, n2=N_samples_, c=4) 151 | # (N_rays, N_samples_, 3) # ! APPLY SIGMOID HERE NOT IN THE NETWORK 152 | rgbs = torch.sigmoid(out[..., :3]) 153 | sigmas = out[..., 3] # (N_rays, N_samples_) 154 | 155 | # Convert these values using volume rendering (Section 4) 156 | deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1) 157 | # (N_rays, 1) the last delta is infinity 158 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) 159 | deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 160 | 161 | # compute alpha by the formula (3) 162 | noise = torch.randn_like(sigmas) * noise_std 163 | # (N_rays, N_samples_) 164 | # alphas = 1 - torch.exp(-deltas * torch.relu(sigmas + noise)) 165 | density_bias = -1 166 | alphas = 1 - torch.exp(-deltas * 167 | F.softplus(sigmas + noise + density_bias)) 168 | 169 | alphas_shifted = torch.cat([torch.ones_like( 170 | alphas[:, :1]), 1 - alphas + 1e-10], -1) # [1, 1-a1, 1-a2, ...] 
171 | # (N_rays, N_samples_) 172 | weights = alphas * torch.cumprod(alphas_shifted[:, :-1], -1) 173 | # (N_rays), the accumulated opacity along the rays 174 | weights_sum = reduce(weights, "n1 n2 -> n1", "sum") 175 | # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically 176 | 177 | results[f"weights_{typ}"] = weights 178 | results[f"opacity_{typ}"] = weights_sum 179 | results[f"z_vals_{typ}"] = z_vals 180 | if test_time and typ == "coarse" and "fine" in models: 181 | return 182 | 183 | rgb_map = reduce(rearrange(weights, "n1 n2 -> n1 n2 1") 184 | * rgbs, "n1 n2 c -> n1 c", "sum") 185 | depth_map = reduce(weights * z_vals, "n1 n2 -> n1", "sum") 186 | 187 | if white_back: 188 | rgb_map += 1 - weights_sum.unsqueeze(1) 189 | 190 | results[f"rgb_{typ}"] = rgb_map 191 | results[f"depth_{typ}"] = depth_map 192 | 193 | return 194 | 195 | embedding_xyz, embedding_dir = embeddings["xyz"], embeddings["dir"] 196 | 197 | # Decompose the inputs #! Note that rays_o,d are ndc, viewdirs are normalized world space 198 | N_rays = rays.shape[0] 199 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 200 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 201 | viewdirs = rays[:, 8:] # (N_rays, 3) 202 | # Embed direction 203 | # (N_rays, embed_dir_channels) 204 | if kwargs["use_sh_feat"]: 205 | kwargs['viewdirs'] = viewdirs 206 | dir_embedded = embedding_dir(kwargs.get("view_dir", viewdirs)) 207 | 208 | rays_o = rearrange(rays_o, "n1 c -> n1 1 c") 209 | rays_d = rearrange(rays_d, "n1 c -> n1 1 c") 210 | 211 | # Sample depth points 212 | z_steps = torch.linspace( 213 | 0, 1, N_samples, device=rays.device) # (N_samples) 214 | if not use_disp: # use linear sampling in depth space 215 | z_vals = near * (1 - z_steps) + far * z_steps 216 | else: # use linear sampling in disparity space 217 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) 218 | 219 | z_vals = z_vals.expand(N_rays, N_samples) 220 | 221 | if perturb > 0: # perturb sampling depths (z_vals) 222 | # (N_rays, N_samples-1) interval mid points 223 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) 224 | # get intervals between samples 225 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1) 226 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 227 | 228 | perturb_rand = perturb * torch.rand_like(z_vals) 229 | z_vals = lower + (upper - lower) * perturb_rand 230 | 231 | xyz_coarse = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1") 232 | results = {} 233 | 234 | if N_importance == 0: # sample points for fine model 235 | inference(results, models["coarse"], "coarse", 236 | xyz_coarse, z_vals, test_time, **kwargs) 237 | else: 238 | with torch.no_grad(): 239 | results["weights_coarse"] = models["fine"].code_cloud.get_proposal( 240 | xyz_coarse) 241 | 242 | # breakpoint() 243 | # np.save('code_ndc.npy', models['fine'].code_cloud.origin_keypoints.cpu().numpy()) 244 | if N_importance > 0: # sample points for fine model 245 | # (N_rays, N_samples-1) interval mid points 246 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) 247 | z_vals_ = sample_pdf( 248 | z_vals_mid, results["weights_coarse"][:, 1:-1].detach(), N_importance, det=(perturb == 0)) 249 | # detach so that grad doesn't propogate to weights_coarse from here 250 | 251 | z_vals = torch.sort(torch.cat([z_vals, z_vals_], -1), -1)[0] 252 | # combine coarse and fine samples 253 | 254 | xyz_fine = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1") 255 | 256 | inference(results, models["fine"], "fine", 257 | xyz_fine, z_vals, test_time, **kwargs) 258 | 259 | 
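# results holds weights_coarse (the code-cloud proposal when N_importance > 0) plus,
# for each pass that ran full inference, rgb_{typ}, depth_{typ}, weights_{typ},
# opacity_{typ} and z_vals_{typ}.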
return results 260 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_opts(): 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument( 8 | "--root_dir", 9 | type=str, 10 | default="/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego", 11 | help="root directory of dataset", 12 | ) 13 | parser.add_argument( 14 | "--dataset_name", type=str, default="blender", choices=["blender", "llff"], help="which dataset to train/val" 15 | ) 16 | parser.add_argument( 17 | "--img_wh", nargs="+", type=int, default=[800, 800], help="resolution (img_w, img_h) of the image" 18 | ) 19 | parser.add_argument( 20 | "--spheric_poses", 21 | default=False, 22 | action="store_true", 23 | help="whether images are taken in spheric poses (for llff)", 24 | ) 25 | 26 | parser.add_argument( 27 | "--not_use_mvs", 28 | default=False, 29 | action="store_true", 30 | help="whether exclude the mvs points cloud", 31 | ) 32 | 33 | parser.add_argument( 34 | "--use_sh_feat", 35 | default=False, 36 | action="store_true", 37 | help="whether use spherical harmonics features", 38 | ) 39 | 40 | parser.add_argument("--N_emb_xyz", type=int, default=10, 41 | help="number of frequencies in xyz positional encoding") 42 | parser.add_argument("--N_emb_dir", type=int, default=4, 43 | help="number of frequencies in dir positional encoding") 44 | parser.add_argument("--N_samples", type=int, default=128, 45 | help="number of coarse samples") 46 | parser.add_argument("--N_importance", type=int, default=0, 47 | help="number of additional fine samples") 48 | parser.add_argument("--use_disp", default=False, 49 | action="store_true", help="use disparity depth sampling") 50 | parser.add_argument("--perturb", type=float, default=1.0, 51 | help="factor to perturb depth sampling points") 52 | parser.add_argument("--noise_std", type=float, default=1.0, 53 | help="std dev of noise added to regularize sigma") 54 | 55 | parser.add_argument("--batch_size", type=int, 56 | default=1024, help="batch size") 57 | parser.add_argument("--chunk", type=int, default=32 * 58 | 1024, help="chunk size to split the input to avoid OOM") 59 | parser.add_argument("--num_epochs", type=int, default=16, 60 | help="number of training epochs") 61 | parser.add_argument("--num_gpus", type=int, 62 | default=1, help="number of gpus") 63 | 64 | parser.add_argument( 65 | "--ckpt_path", type=str, default=None, help="pretrained checkpoint to load (including optimizers, etc)" 66 | ) 67 | parser.add_argument( 68 | "--prefixes_to_ignore", 69 | nargs="+", 70 | type=str, 71 | default=["loss"], 72 | help="the prefixes to ignore in the checkpoint state dict", 73 | ) 74 | parser.add_argument( 75 | "--weight_path", type=str, default=None, help="pretrained model weight to load (do not load optimizers, etc)" 76 | ) 77 | 78 | parser.add_argument( 79 | "--optimizer", type=str, default="adam", help="optimizer type", choices=["sgd", "adam", "radam", "ranger"] 80 | ) 81 | parser.add_argument("--lr", type=float, default=5e-4, help="learning rate") 82 | parser.add_argument("--momentum", type=float, 83 | default=0.9, help="learning rate momentum") 84 | parser.add_argument("--weight_decay", type=float, 85 | default=0, help="weight decay") 86 | parser.add_argument( 87 | "--lr_scheduler", type=str, default="steplr", help="scheduler type", choices=["steplr", "cosine", "poly"] 88 | ) 89 | # params for warmup, only applied when optimizer == 
'sgd' or 'adam' 90 | parser.add_argument( 91 | "--warmup_multiplier", type=float, default=1.0, help="lr is multiplied by this factor after --warmup_epochs" 92 | ) 93 | parser.add_argument( 94 | "--warmup_epochs", type=int, default=0, help="Gradually warm-up(increasing) learning rate in optimizer" 95 | ) 96 | ########################### 97 | # params for steplr 98 | parser.add_argument("--decay_step", nargs="+", type=int, 99 | default=[20], help="scheduler decay step") 100 | parser.add_argument("--decay_gamma", type=float, 101 | default=0.1, help="learning rate decay amount") 102 | ########################### 103 | # params for poly 104 | parser.add_argument("--poly_exp", type=float, default=0.9, 105 | help="exponent for polynomial learning rate decay") 106 | ########################### 107 | 108 | parser.add_argument("--exp_name", type=str, 109 | default="exp", help="experiment name") 110 | 111 | return parser.parse_args() 112 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 
| #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = 
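// One block per batch element; each thread strides over the m query centres and records
// up to nsample neighbour indices whose squared distance is below radius^2.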
threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | 
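// (B, C, n) gradient buffer for the original point features; the CUDA kernel
// accumulates into it with atomicAdd since several groups may reuse the same point.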
torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int 
m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 
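// Brute-force 3-nearest-neighbour search: each thread takes a strided subset of the n
// "unknown" points and keeps the three smallest squared distances (best1 <= best2 <= best3)
// over all m "known" points, writing out both the distances and the indices.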
14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = 
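// Backward of the weighted 3-point interpolation: each output gradient is scattered back
// to its three source points with atomicAdd, scaled by the interpolation weights w1..w3.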
blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | 
points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 
201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pointnet2_ops import pointnet2_utils 7 | 8 | 9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True): 10 | layers = [] 11 | for i in range(1, len(mlp_spec)): 12 | layers.append( 13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) 14 | ) 15 | if bn: 16 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 17 | layers.append(nn.ReLU(True)) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class _PointnetSAModuleBase(nn.Module): 23 | def __init__(self): 24 | super(_PointnetSAModuleBase, self).__init__() 25 | self.npoint = None 26 | self.groupers = None 27 | self.mlps = None 28 | 29 | def forward( 30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor] 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | r""" 33 | Parameters 34 | ---------- 35 | xyz : torch.Tensor 36 | (B, N, 3) tensor of the xyz coordinates of the features 37 | features : torch.Tensor 38 | (B, C, N) tensor of the descriptors of the the features 39 | 40 | Returns 41 | ------- 42 | new_xyz : torch.Tensor 43 | (B, npoint, 3) tensor of the new features' xyz 44 | new_features : torch.Tensor 45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 46 | """ 47 | 48 | new_features_list = [] 49 | 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | new_xyz = ( 52 | pointnet2_utils.gather_operation( 53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 54 | ) 55 | .transpose(1, 2) 56 | .contiguous() 57 | if self.npoint is not None 58 | else None 59 | ) 60 | 61 | for i in range(len(self.groupers)): 62 | new_features = self.groupers[i]( 63 | xyz, new_xyz, features 64 | ) # (B, C, npoint, nsample) 65 | 66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 67 | new_features = F.max_pool2d( 68 | new_features, kernel_size=[1, new_features.size(3)] 69 | ) # (B, mlp[-1], npoint, 1) 70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 71 | 72 | new_features_list.append(new_features) 73 | 74 | return new_xyz, torch.cat(new_features_list, dim=1) 75 | 76 | 77 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 78 | r"""Pointnet set abstrction layer with multiscale grouping 79 | 80 | Parameters 81 | ---------- 82 | 
npoint : int 83 | Number of features 84 | radii : list of float32 85 | list of radii to group with 86 | nsamples : list of int32 87 | Number of samples in each ball query 88 | mlps : list of list of int32 89 | Spec of the pointnet before the global max_pool for each scale 90 | bn : bool 91 | Use batchnorm 92 | """ 93 | 94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 96 | super(PointnetSAModuleMSG, self).__init__() 97 | 98 | assert len(radii) == len(nsamples) == len(mlps) 99 | 100 | self.npoint = npoint 101 | self.groupers = nn.ModuleList() 102 | self.mlps = nn.ModuleList() 103 | for i in range(len(radii)): 104 | radius = radii[i] 105 | nsample = nsamples[i] 106 | self.groupers.append( 107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 108 | if npoint is not None 109 | else pointnet2_utils.GroupAll(use_xyz) 110 | ) 111 | mlp_spec = mlps[i] 112 | if use_xyz: 113 | mlp_spec[0] += 3 114 | 115 | self.mlps.append(build_shared_mlp(mlp_spec, bn)) 116 | 117 | 118 | class PointnetSAModule(PointnetSAModuleMSG): 119 | r"""Pointnet set abstrction layer 120 | 121 | Parameters 122 | ---------- 123 | npoint : int 124 | Number of features 125 | radius : float 126 | Radius of ball 127 | nsample : int 128 | Number of samples in the ball query 129 | mlp : list 130 | Spec of the pointnet before the global max_pool 131 | bn : bool 132 | Use batchnorm 133 | """ 134 | 135 | def __init__( 136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 137 | ): 138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 139 | super(PointnetSAModule, self).__init__( 140 | mlps=[mlp], 141 | npoint=npoint, 142 | radii=[radius], 143 | nsamples=[nsample], 144 | bn=bn, 145 | use_xyz=use_xyz, 146 | ) 147 | 148 | 149 | class PointnetFPModule(nn.Module): 150 | r"""Propigates the features of one set to another 151 | 152 | Parameters 153 | ---------- 154 | mlp : list 155 | Pointnet module parameters 156 | bn : bool 157 | Use batchnorm 158 | """ 159 | 160 | def __init__(self, mlp, bn=True): 161 | # type: (PointnetFPModule, List[int], bool) -> None 162 | super(PointnetFPModule, self).__init__() 163 | self.mlp = build_shared_mlp(mlp, bn=bn) 164 | 165 | def forward(self, unknown, known, unknow_feats, known_feats): 166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 167 | r""" 168 | Parameters 169 | ---------- 170 | unknown : torch.Tensor 171 | (B, n, 3) tensor of the xyz positions of the unknown features 172 | known : torch.Tensor 173 | (B, m, 3) tensor of the xyz positions of the known features 174 | unknow_feats : torch.Tensor 175 | (B, C1, n) tensor of the features to be propigated to 176 | known_feats : torch.Tensor 177 | (B, C2, m) tensor of features to be propigated 178 | 179 | Returns 180 | ------- 181 | new_features : torch.Tensor 182 | (B, mlp[-1], n) tensor of the features of the unknown features 183 | """ 184 | 185 | if known is not None: 186 | dist, idx = pointnet2_utils.three_nn(unknown, known) 187 | dist_recip = 1.0 / (dist + 1e-8) 188 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 189 | weight = dist_recip / norm 190 | 191 | interpolated_feats = pointnet2_utils.three_interpolate( 192 | known_feats, idx, weight 193 | ) 194 | else: 195 | interpolated_feats = known_feats.expand( 196 | *(known_feats.size()[0:2] + [unknown.size(1)]) 197 | ) 198 | 199 | if unknow_feats is not 
None: 200 | new_features = torch.cat( 201 | [interpolated_feats, unknow_feats], dim=1 202 | ) # (B, C2 + C1, n) 203 | else: 204 | new_features = interpolated_feats 205 | 206 | new_features = new_features.unsqueeze(-1) 207 | new_features = self.mlp(new_features) 208 | 209 | return new_features.squeeze(-1) 210 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | from torch.autograd import Function 5 | from typing import * 6 | 7 | try: 8 | import pointnet2_ops._ext as _ext 9 | except ImportError: 10 | from torch.utils.cpp_extension import load 11 | import glob 12 | import os.path as osp 13 | import os 14 | 15 | warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.") 16 | 17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") 18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 19 | osp.join(_ext_src_root, "src", "*.cu") 20 | ) 21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 22 | 23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 24 | _ext = load( 25 | "_ext", 26 | sources=_ext_sources, 27 | extra_include_paths=[osp.join(_ext_src_root, "include")], 28 | extra_cflags=["-O3"], 29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], 30 | with_cuda=True, 31 | ) 32 | 33 | 34 | class FurthestPointSampling(Function): 35 | @staticmethod 36 | def forward(ctx, xyz, npoint): 37 | # type: (Any, torch.Tensor, int) -> torch.Tensor 38 | r""" 39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 40 | minimum distance 41 | 42 | Parameters 43 | ---------- 44 | xyz : torch.Tensor 45 | (B, N, 3) tensor where N > npoint 46 | npoint : int32 47 | number of features in the sampled set 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, npoint) tensor containing the set 53 | """ 54 | out = _ext.furthest_point_sampling(xyz, npoint) 55 | 56 | ctx.mark_non_differentiable(out) 57 | 58 | return out 59 | 60 | @staticmethod 61 | def backward(ctx, grad_out): 62 | return () 63 | 64 | 65 | furthest_point_sample = FurthestPointSampling.apply 66 | 67 | 68 | class GatherOperation(Function): 69 | @staticmethod 70 | def forward(ctx, features, idx): 71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 72 | r""" 73 | 74 | Parameters 75 | ---------- 76 | features : torch.Tensor 77 | (B, C, N) tensor 78 | 79 | idx : torch.Tensor 80 | (B, npoint) tensor of the features to gather 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | (B, C, npoint) tensor 86 | """ 87 | 88 | ctx.save_for_backward(idx, features) 89 | 90 | return _ext.gather_points(features, idx) 91 | 92 | @staticmethod 93 | def backward(ctx, grad_out): 94 | idx, features = ctx.saved_tensors 95 | N = features.size(2) 96 | 97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 98 | return grad_features, None 99 | 100 | 101 | gather_operation = GatherOperation.apply 102 | 103 | 104 | class ThreeNN(Function): 105 | @staticmethod 106 | def forward(ctx, unknown, known): 107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 108 | r""" 109 | Find the three nearest neighbors of unknown in known 110 | Parameters 111 | ---------- 112 | unknown : torch.Tensor 113 | (B, n, 3) tensor of known features 114 | known : torch.Tensor 115 | (B, 
m, 3) tensor of unknown features 116 | 117 | Returns 118 | ------- 119 | dist : torch.Tensor 120 | (B, n, 3) l2 distance to the three nearest neighbors 121 | idx : torch.Tensor 122 | (B, n, 3) index of 3 nearest neighbors 123 | """ 124 | dist2, idx = _ext.three_nn(unknown, known) 125 | dist = torch.sqrt(dist2) 126 | 127 | ctx.mark_non_differentiable(dist, idx) 128 | 129 | return dist, idx 130 | 131 | @staticmethod 132 | def backward(ctx, grad_dist, grad_idx): 133 | return () 134 | 135 | 136 | three_nn = ThreeNN.apply 137 | 138 | 139 | class ThreeInterpolate(Function): 140 | @staticmethod 141 | def forward(ctx, features, idx, weight): 142 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 143 | r""" 144 | Performs weight linear interpolation on 3 features 145 | Parameters 146 | ---------- 147 | features : torch.Tensor 148 | (B, c, m) Features descriptors to be interpolated from 149 | idx : torch.Tensor 150 | (B, n, 3) three nearest neighbors of the target features in features 151 | weight : torch.Tensor 152 | (B, n, 3) weights 153 | 154 | Returns 155 | ------- 156 | torch.Tensor 157 | (B, c, n) tensor of the interpolated features 158 | """ 159 | ctx.save_for_backward(idx, weight, features) 160 | 161 | return _ext.three_interpolate(features, idx, weight) 162 | 163 | @staticmethod 164 | def backward(ctx, grad_out): 165 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 166 | r""" 167 | Parameters 168 | ---------- 169 | grad_out : torch.Tensor 170 | (B, c, n) tensor with gradients of ouputs 171 | 172 | Returns 173 | ------- 174 | grad_features : torch.Tensor 175 | (B, c, m) tensor with gradients of features 176 | 177 | None 178 | 179 | None 180 | """ 181 | idx, weight, features = ctx.saved_tensors 182 | m = features.size(2) 183 | 184 | grad_features = _ext.three_interpolate_grad( 185 | grad_out.contiguous(), idx, weight, m 186 | ) 187 | 188 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) 189 | 190 | 191 | three_interpolate = ThreeInterpolate.apply 192 | 193 | 194 | class GroupingOperation(Function): 195 | @staticmethod 196 | def forward(ctx, features, idx): 197 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 198 | r""" 199 | 200 | Parameters 201 | ---------- 202 | features : torch.Tensor 203 | (B, C, N) tensor of features to group 204 | idx : torch.Tensor 205 | (B, npoint, nsample) tensor containing the indicies of features to group with 206 | 207 | Returns 208 | ------- 209 | torch.Tensor 210 | (B, C, npoint, nsample) tensor 211 | """ 212 | ctx.save_for_backward(idx, features) 213 | 214 | return _ext.group_points(features, idx) 215 | 216 | @staticmethod 217 | def backward(ctx, grad_out): 218 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 219 | r""" 220 | 221 | Parameters 222 | ---------- 223 | grad_out : torch.Tensor 224 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 225 | 226 | Returns 227 | ------- 228 | torch.Tensor 229 | (B, C, N) gradient of the features 230 | None 231 | """ 232 | idx, features = ctx.saved_tensors 233 | N = features.size(2) 234 | 235 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 236 | 237 | return grad_features, torch.zeros_like(idx) 238 | 239 | 240 | grouping_operation = GroupingOperation.apply 241 | 242 | 243 | class BallQuery(Function): 244 | @staticmethod 245 | def forward(ctx, radius, nsample, xyz, new_xyz): 246 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 247 | r""" 248 | 249 | 
Parameters 250 | ---------- 251 | radius : float 252 | radius of the balls 253 | nsample : int 254 | maximum number of features in the balls 255 | xyz : torch.Tensor 256 | (B, N, 3) xyz coordinates of the features 257 | new_xyz : torch.Tensor 258 | (B, npoint, 3) centers of the ball query 259 | 260 | Returns 261 | ------- 262 | torch.Tensor 263 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 264 | """ 265 | output = _ext.ball_query(new_xyz, xyz, radius, nsample) 266 | 267 | ctx.mark_non_differentiable(output) 268 | 269 | return output 270 | 271 | @staticmethod 272 | def backward(ctx, grad_out): 273 | return () 274 | 275 | 276 | ball_query = BallQuery.apply 277 | 278 | 279 | class QueryAndGroup(nn.Module): 280 | r""" 281 | Groups with a ball query of radius 282 | 283 | Parameters 284 | --------- 285 | radius : float32 286 | Radius of ball 287 | nsample : int32 288 | Maximum number of features to gather in the ball 289 | """ 290 | 291 | def __init__(self, radius, nsample, use_xyz=True): 292 | # type: (QueryAndGroup, float, int, bool) -> None 293 | super(QueryAndGroup, self).__init__() 294 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 295 | 296 | def forward(self, xyz, new_xyz, features=None): 297 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 298 | r""" 299 | Parameters 300 | ---------- 301 | xyz : torch.Tensor 302 | xyz coordinates of the features (B, N, 3) 303 | new_xyz : torch.Tensor 304 | centriods (B, npoint, 3) 305 | features : torch.Tensor 306 | Descriptors of the features (B, C, N) 307 | 308 | Returns 309 | ------- 310 | new_features : torch.Tensor 311 | (B, 3 + C, npoint, nsample) tensor 312 | """ 313 | 314 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 315 | xyz_trans = xyz.transpose(1, 2).contiguous() 316 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 317 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 318 | 319 | if features is not None: 320 | grouped_features = grouping_operation(features, idx) 321 | if self.use_xyz: 322 | new_features = torch.cat( 323 | [grouped_xyz, grouped_features], dim=1 324 | ) # (B, C + 3, npoint, nsample) 325 | else: 326 | new_features = grouped_features 327 | else: 328 | assert ( 329 | self.use_xyz 330 | ), "Cannot have not features and not use xyz as a feature!" 
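# No per-point descriptors were provided, so the centered xyz offsets of the grouped
# neighbors (grouped_xyz) are returned as the features themselves; this is why
# use_xyz must be True in this branch.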
331 | new_features = grouped_xyz 332 | 333 | return new_features 334 | 335 | 336 | class GroupAll(nn.Module): 337 | r""" 338 | Groups all features 339 | 340 | Parameters 341 | --------- 342 | """ 343 | 344 | def __init__(self, use_xyz=True): 345 | # type: (GroupAll, bool) -> None 346 | super(GroupAll, self).__init__() 347 | self.use_xyz = use_xyz 348 | 349 | def forward(self, xyz, new_xyz, features=None): 350 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 351 | r""" 352 | Parameters 353 | ---------- 354 | xyz : torch.Tensor 355 | xyz coordinates of the features (B, N, 3) 356 | new_xyz : torch.Tensor 357 | Ignored 358 | features : torch.Tensor 359 | Descriptors of the features (B, C, N) 360 | 361 | Returns 362 | ------- 363 | new_features : torch.Tensor 364 | (B, C + 3, 1, N) tensor 365 | """ 366 | 367 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 368 | if features is not None: 369 | grouped_features = features.unsqueeze(2) 370 | if self.use_xyz: 371 | new_features = torch.cat( 372 | [grouped_xyz, grouped_features], dim=1 373 | ) # (B, 3 + C, 1, N) 374 | else: 375 | new_features = grouped_features 376 | else: 377 | new_features = grouped_xyz 378 | 379 | return new_features 380 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from datasets import dataset_dict 9 | 10 | # colmap 11 | from datasets.colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary 12 | from datasets.llff import center_poses 13 | from datasets.ray_utils import get_ndc_coor, read_gen 14 | 15 | # losses 16 | from losses import loss_dict 17 | 18 | # metrics 19 | from metrics import psnr 20 | from models.cloud_code import CloudNeRF, SHCloudNeRF, config 21 | 22 | # models 23 | from models.nerf import Embedding 24 | from models.rendering import render_rays 25 | from opt import get_opts 26 | 27 | # fps 28 | from pointnet2_ops.pointnet2_utils import furthest_point_sample, gather_operation 29 | 30 | # pytorch-lightning 31 | from pytorch_lightning import LightningModule, Trainer 32 | from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar 33 | from pytorch_lightning.loggers import TensorBoardLogger 34 | from pytorch_lightning.plugins import DDPPlugin 35 | from torch.utils.data import DataLoader 36 | 37 | # optimizer, scheduler, visualization 38 | from utils import get_learning_rate, get_optimizer, get_scheduler, load_ckpt, visualize_depth 39 | 40 | 41 | @torch.no_grad() 42 | def fps(points, n_samples): 43 | points = torch.from_numpy(points).unsqueeze( 44 | 0).float().cuda() # 1, N_points, 3 45 | points_flipped = points.transpose(1, 2).contiguous() 46 | fps_index = furthest_point_sample(points, n_samples) # 1, n_samples 47 | 48 | fps_kps = gather_operation(points_flipped, fps_index).transpose( 49 | 1, 2).contiguous().squeeze(0) # n_samples, 3 50 | return fps_kps.cpu().numpy() 51 | 52 | 53 | class NeRFSystem(LightningModule): 54 | def __init__(self, hparams): 55 | super().__init__() 56 | self.save_hyperparameters(hparams) 57 | self.root_dir = hparams.root_dir 58 | 59 | self.loss = loss_dict["color"](coef=1) 60 | 61 | self.embedding_xyz = Embedding(hparams.N_emb_xyz) 62 | self.embedding_dir = Embedding(hparams.N_emb_dir) 63 | self.embeddings = {"xyz": self.embedding_xyz, 64 | 
"dir": self.embedding_dir} 65 | 66 | kps, fps_kps = self.read_colmap_meta(hparams) 67 | 68 | if hparams.N_importance == 0: 69 | if hparams.use_sh_feat: 70 | self.nerf_coarse = SHCloudNeRF( 71 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3) 72 | else: 73 | self.nerf_coarse = CloudNeRF( 74 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3) 75 | 76 | self.models = {"coarse": self.nerf_coarse} 77 | load_ckpt(self.nerf_coarse, hparams.weight_path, "nerf_coarse") 78 | 79 | else: 80 | if hparams.use_sh_feat: 81 | self.nerf_fine = SHCloudNeRF( 82 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3) 83 | else: 84 | self.nerf_fine = CloudNeRF( 85 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3) 86 | 87 | self.models = {"fine": self.nerf_fine} 88 | load_ckpt(self.nerf_fine, hparams.weight_path, "nerf_fine") 89 | 90 | def read_colmap_meta(self, hparams): 91 | camdata = read_cameras_binary(os.path.join( 92 | hparams.root_dir, "sparse/0/cameras.bin")) 93 | self.origin_intrinsics = camdata 94 | W = camdata[1].width 95 | self.focal = camdata[1].params[0] * hparams.img_wh[0] / W 96 | 97 | imdata = read_images_binary(os.path.join( 98 | hparams.root_dir, "sparse/0/images.bin")) 99 | 100 | w2c_mats = [] 101 | bottom = np.array([0, 0, 0, 1.0]).reshape(1, 4) 102 | for k in imdata: 103 | im = imdata[k] 104 | R = im.qvec2rotmat() 105 | t = im.tvec.reshape(3, 1) 106 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 107 | w2c_mats = np.stack(w2c_mats, 0) 108 | # (N_images, 3, 4) cam2world matrices 109 | poses = np.linalg.inv(w2c_mats)[:, :3] 110 | self.origin_extrinsics = poses 111 | 112 | pts3d = read_points3d_binary(os.path.join( 113 | hparams.root_dir, "sparse/0/points3D.bin")) 114 | 115 | mvs_points = self.load_mvs_depth().numpy() 116 | near_bound = mvs_points.min(axis=0)[-1] 117 | pts3d = {k: v for (k, v) in pts3d.items() if v.xyz[-1] > near_bound} 118 | 119 | pts_world = np.zeros((1, 3, len(pts3d))) # (1, 3, N_points) 120 | self.bounds = np.zeros((len(poses), 2)) # (N_images, 2) 121 | visibilities = np.zeros((len(poses), len(pts3d)) 122 | ) # (N_images, N_points) 123 | 124 | for i, k in enumerate(pts3d): 125 | pts_world[0, :, i] = pts3d[k].xyz 126 | for j in pts3d[k].image_ids: 127 | visibilities[j - 1, i] = 1 128 | 129 | depths = ((pts_world - poses[..., 3:4]) * poses[..., 2:3]).sum(1) 130 | for i in range(len(poses)): 131 | visibility_i = visibilities[i] 132 | zs = depths[i][visibility_i == 1] 133 | self.bounds[i] = [np.percentile(zs, 0.1), np.percentile(zs, 99.9)] 134 | valid_depth = (depths[i] >= self.bounds[i][0]) & ( 135 | depths[i] <= self.bounds[i][1]) 136 | visibility_i = visibility_i.astype(bool) & valid_depth 137 | visibilities[i] = visibility_i.astype(np.float64) 138 | 139 | valid_points = np.any(visibilities, axis=0) 140 | pts_world = np.transpose(pts_world[0])[valid_points] # (N_points, 3) 141 | 142 | # fps 143 | fps_kps = fps(pts_world, config["code_cloud"]["num_codes"]) 144 | global_kps = fps(mvs_points, pts_world.shape[0]) 145 | 146 | if hparams.not_use_mvs: 147 | pts_world = np.concatenate([pts_world], axis=0) 148 | else: 149 | pts_world = np.concatenate([pts_world, global_kps], axis=0) 150 | 151 | poses = np.concatenate( 152 | [poses[..., 0:1], -poses[..., 1:3], poses[..., 3:4]], -1) 153 | poses, pose_avg = center_poses(poses) 154 | pose_avg_homo = np.eye(4) 155 | pose_avg_homo[:3] = pose_avg 156 | 157 | pts_world_homo = np.concatenate( 158 | [pts_world, np.ones((pts_world.shape[0], 1))], 
axis=1) 159 | fps_kps_homo = np.concatenate( 160 | [fps_kps, np.ones((fps_kps.shape[0], 1))], axis=1) 161 | 162 | trans_pts_world = np.linalg.inv( 163 | pose_avg_homo) @ pts_world_homo[:, :, None] 164 | trans_fps_kps = np.linalg.inv(pose_avg_homo) @ fps_kps_homo[:, :, None] 165 | 166 | kps = torch.from_numpy(trans_pts_world[:, :3, 0]) 167 | fps_kps = torch.from_numpy(trans_fps_kps[:, :3, 0]) 168 | near_original = self.bounds.min() 169 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 170 | kps /= scale_factor 171 | fps_kps /= scale_factor 172 | 173 | # convert to ndc 174 | kps_ndc = get_ndc_coor( 175 | hparams.img_wh[1], hparams.img_wh[0], self.focal, 1.0, kps) 176 | fps_kps_ndc = get_ndc_coor( 177 | hparams.img_wh[1], hparams.img_wh[0], self.focal, 1.0, fps_kps) 178 | return kps_ndc, fps_kps_ndc 179 | 180 | def load_mvs_depth(self): 181 | depth_glob = os.path.join(self.root_dir, "depths", "*.pfm") 182 | self.depth_list = sorted(glob.glob(depth_glob)) 183 | depths = [] 184 | for i in range(len(self.depth_list)): 185 | depth = read_gen(self.depth_list[i]) 186 | 187 | depths.append(depth) 188 | self.depths = np.stack(depths, 0).astype(np.float32) # N x H x W 189 | 190 | per_view_points = self.project_to_3d() 191 | mvs_points = self.fwd_consistency_check(per_view_points) 192 | return mvs_points 193 | 194 | def project_to_3d(self): 195 | N, H, W = self.depths.shape 196 | focal = self.origin_intrinsics[1].params[0] 197 | origin_h, origin_w = self.origin_intrinsics[1].height, self.origin_intrinsics[1].width 198 | 199 | origin_cy, origin_cx = self.origin_intrinsics[1].params[2], self.origin_intrinsics[1].params[1] 200 | 201 | origin_K = np.array([[focal, 0, origin_cx, 0], [ 202 | 0, focal, origin_cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 203 | 204 | origin_K[0, :] /= origin_w 205 | origin_K[1, :] /= origin_h 206 | 207 | self.normalized_K = origin_K 208 | 209 | mvs_K = self.normalized_K.copy() 210 | mvs_K[0, :] *= W 211 | mvs_K[1, :] *= H 212 | self.mvs_K = mvs_K 213 | 214 | inv_mvs_K = np.linalg.pinv(mvs_K) 215 | inv_mvs_K = torch.from_numpy(inv_mvs_K) 216 | 217 | # create mesh grid for mvs image 218 | meshgrid = np.meshgrid(range(W), range(H), indexing="xy") 219 | id_coords = (np.stack(meshgrid, axis=0).astype( 220 | np.float32)).reshape(2, -1) 221 | id_coords = torch.from_numpy(id_coords) 222 | 223 | ones = torch.ones(N, 1, H * W) 224 | 225 | pix_coords = torch.unsqueeze(torch.stack( 226 | [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0) 227 | pix_coords = pix_coords.repeat(N, 1, 1) 228 | pix_coords = torch.cat([pix_coords, ones], 1) 229 | 230 | # project to cam coord 231 | inv_mvs_K = inv_mvs_K[None, ...].repeat(N, 1, 1).float() 232 | cam_points = torch.matmul(inv_mvs_K[:, :3, :3], pix_coords) 233 | mvs_depth = torch.from_numpy( 234 | self.depths).float().unsqueeze(1).view(N, 1, -1) 235 | cam_points = mvs_depth * cam_points 236 | cam_points = torch.cat([cam_points, ones], 1) 237 | 238 | # project to world coord 239 | T = torch.from_numpy(self.origin_extrinsics).float() 240 | world_points = torch.matmul(T, cam_points) 241 | world_points = world_points.permute(0, 2, 1) # N, H*W, 3 242 | 243 | return world_points 244 | 245 | def fwd_consistency_check(self, per_view_points): 246 | N, H, W = self.depths.shape 247 | global_valid_points = [] 248 | for view_id in range(per_view_points.shape[0]): 249 | curr_view_points = per_view_points[view_id].transpose( 250 | 1, 0) # 3, H*W 251 | homo_view_points = torch.cat( 252 | [curr_view_points, torch.ones(1, H * W)], dim=0) # 4, H*W 253 | 
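# Replicate this view's homogeneous world points across all N cameras: the steps below
# re-project them into every view, sample the corresponding MVS depth map, and keep a
# point only if its re-projected depth is at least 0.9x the sampled depth in every view
# (so it never floats in front of an observed surface) and the sampled depth is
# non-zero in at least one view.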
homo_view_points = homo_view_points.unsqueeze( 254 | 0).repeat(N, 1, 1) # N,4,H*W 255 | 256 | # project to camera space 257 | T = torch.from_numpy(self.origin_extrinsics).float() 258 | homo_T = torch.cat([T, torch.zeros(N, 1, 4)], dim=1) 259 | homo_T[:, -1, -1] = 1 260 | inv_T = torch.inverse(homo_T) 261 | cam_points = torch.matmul(inv_T[:, :3, :], homo_view_points) 262 | 263 | # project to image space 264 | mvs_K = torch.from_numpy(self.mvs_K).unsqueeze( 265 | 0).repeat(N, 1, 1).float() 266 | cam_points = torch.matmul(mvs_K[:, :3, :3], cam_points) 267 | cam_points[:, :2, :] /= cam_points[:, 2:, :] 268 | 269 | z_values = cam_points[:, 2:, :].view(N, 1, H, W) # N,1,H,W 270 | xy_coords = cam_points[:, :2, :].transpose( 271 | 2, 1).view(N, H, W, 2) # N,H,W,2 272 | 273 | xy_coords[..., 0] /= W - 1 274 | xy_coords[..., 1] /= H - 1 275 | xy_coords = (xy_coords - 0.5) * 2 276 | 277 | mvs_depth = torch.from_numpy( 278 | self.depths).float().unsqueeze(1) # N,1,H,W 279 | ref_z_values = F.grid_sample( 280 | mvs_depth, xy_coords, mode="bilinear", align_corners=False) 281 | err = z_values - 0.9 * ref_z_values 282 | 283 | visible_mask = ref_z_values != 0 284 | visible_count = visible_mask.int().sum(0) 285 | valid_visible = visible_count >= 1 286 | valid_points = err >= 0 287 | valid_points = torch.all( 288 | valid_points, dim=0) & valid_visible # 1,H,W 289 | global_valid_points.append(valid_points) 290 | global_valid_points = torch.cat( 291 | global_valid_points, dim=0).view(N, H * W) # N,H,W 292 | 293 | filtered_points = per_view_points[global_valid_points, :] 294 | return filtered_points 295 | 296 | def forward(self, rays): 297 | """Do batched inference on rays using chunk.""" 298 | B = rays.shape[0] 299 | results = defaultdict(list) 300 | for i in range(0, B, self.hparams.chunk): 301 | rendered_ray_chunks = render_rays( 302 | self.models, 303 | self.embeddings, 304 | rays[i: i + self.hparams.chunk], 305 | self.hparams.N_samples, 306 | self.hparams.use_disp, 307 | self.hparams.perturb, 308 | self.hparams.noise_std, 309 | self.hparams.N_importance, 310 | self.hparams.chunk, # chunk size is effective in val mode 311 | self.train_dataset.white_back, 312 | use_sh_feat = self.hparams.use_sh_feat 313 | ) 314 | 315 | for k, v in rendered_ray_chunks.items(): 316 | results[k] += [v] 317 | 318 | for k, v in results.items(): 319 | results[k] = torch.cat(v, 0) 320 | return results 321 | 322 | def setup(self, stage): 323 | dataset = dataset_dict[self.hparams.dataset_name] 324 | kwargs = {"root_dir": self.hparams.root_dir, 325 | "img_wh": tuple(self.hparams.img_wh)} 326 | if self.hparams.dataset_name == "llff": 327 | kwargs["val_num"] = 3 328 | self.train_dataset = dataset(split="train", **kwargs) 329 | self.val_dataset = dataset(split="val", **kwargs) 330 | 331 | def configure_optimizers(self): 332 | self.optimizer = get_optimizer(self.hparams, self.models) 333 | scheduler = get_scheduler(self.hparams, self.optimizer) 334 | return [self.optimizer], [scheduler] 335 | 336 | def train_dataloader(self): 337 | return DataLoader( 338 | self.train_dataset, shuffle=True, num_workers=4, batch_size=self.hparams.batch_size, pin_memory=True 339 | ) 340 | 341 | def val_dataloader(self): 342 | return DataLoader( 343 | self.val_dataset, 344 | shuffle=False, 345 | num_workers=4, 346 | # validate one image (H*W rays) at a time 347 | batch_size=1, 348 | pin_memory=True, 349 | ) 350 | 351 | def training_step(self, batch, batch_nb): 352 | rays, rgbs = batch["rays"], batch["rgbs"] 353 | results = self(rays) 354 | loss = 
self.loss(results, rgbs) 355 | 356 | with torch.no_grad(): 357 | typ = "fine" if "rgb_fine" in results else "coarse" 358 | psnr_ = psnr(results[f"rgb_{typ}"], rgbs) 359 | 360 | self.log("lr", get_learning_rate(self.optimizer)) 361 | self.log("train/loss", loss) 362 | self.log("train/psnr", psnr_, prog_bar=True) 363 | 364 | return loss 365 | 366 | def validation_step(self, batch, batch_nb): 367 | rays, rgbs = batch["rays"], batch["rgbs"] 368 | rays = rays.squeeze() # (H*W, 3) 369 | rgbs = rgbs.squeeze() # (H*W, 3) 370 | results = self(rays) 371 | log = {"val_loss": self.loss(results, rgbs)} 372 | typ = "fine" if "rgb_fine" in results else "coarse" 373 | 374 | if batch_nb == 0: 375 | W, H = self.hparams.img_wh 376 | img = results[f"rgb_{typ}"].view( 377 | H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) 378 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) 379 | depth = visualize_depth( 380 | results[f"depth_{typ}"].view(H, W)) # (3, H, W) 381 | stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W) 382 | self.logger.experiment.add_images( 383 | "val/GT_pred_depth", stack, self.global_step) 384 | 385 | psnr_ = psnr(results[f"rgb_{typ}"], rgbs) 386 | log["val_psnr"] = psnr_ 387 | 388 | return log 389 | 390 | def validation_epoch_end(self, outputs): 391 | mean_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 392 | mean_psnr = torch.stack([x["val_psnr"] for x in outputs]).mean() 393 | 394 | self.log("val/loss", mean_loss) 395 | self.log("val/psnr", mean_psnr, prog_bar=True) 396 | 397 | 398 | def main(hparams): 399 | system = NeRFSystem(hparams) 400 | ckpt_cb = ModelCheckpoint( 401 | dirpath=f"ckpts/{hparams.exp_name}", filename="{epoch:d}", monitor="val/psnr", mode="max", save_top_k=5 402 | ) 403 | pbar = TQDMProgressBar(refresh_rate=1) 404 | callbacks = [ckpt_cb, pbar] 405 | 406 | logger = TensorBoardLogger( 407 | save_dir="logs", name=hparams.exp_name, default_hp_metric=False) 408 | 409 | trainer = Trainer( 410 | max_epochs=hparams.num_epochs, 411 | callbacks=callbacks, 412 | resume_from_checkpoint=hparams.ckpt_path, 413 | logger=logger, 414 | enable_model_summary=False, 415 | accelerator="auto", 416 | devices=hparams.num_gpus, 417 | num_sanity_val_steps=1, 418 | benchmark=True, 419 | profiler="simple" if hparams.num_gpus == 1 else None, 420 | strategy=DDPPlugin( 421 | find_unused_parameters=False) if hparams.num_gpus > 1 else None, 422 | ) 423 | 424 | trainer.fit(system) 425 | 426 | 427 | if __name__ == "__main__": 428 | hparams = get_opts() 429 | main(hparams) 430 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_optimizer as optim 3 | 4 | # optimizer 5 | from torch.optim import SGD, Adam 6 | 7 | # scheduler 8 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 9 | 10 | from .warmup_scheduler import GradualWarmupScheduler 11 | from .visualization import visualize_depth 12 | 13 | 14 | def get_parameters(models): 15 | """Get all model parameters recursively.""" 16 | parameters = [] 17 | # latent_parameters = [] 18 | if isinstance(models, list): 19 | for model in models: 20 | parameters += get_parameters(model) 21 | elif isinstance(models, dict): 22 | for model in models.values(): 23 | parameters += get_parameters(model) 24 | else: # models is actually a single pytorch model 25 | parameters += list(models.parameters()) 26 | return parameters 27 | 28 | 29 | def get_optimizer(hparams, 
models): 30 | eps = 1e-8 31 | parameters = get_parameters(models) 32 | if hparams.optimizer == "sgd": 33 | optimizer = SGD(parameters, lr=hparams.lr, 34 | momentum=hparams.momentum, weight_decay=hparams.weight_decay) 35 | elif hparams.optimizer == "adam": 36 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps, 37 | weight_decay=hparams.weight_decay) 38 | elif hparams.optimizer == "radam": 39 | optimizer = optim.RAdam(parameters, lr=hparams.lr, 40 | eps=eps, weight_decay=hparams.weight_decay) 41 | elif hparams.optimizer == "ranger": 42 | optimizer = optim.Ranger( 43 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay) 44 | else: 45 | raise ValueError("optimizer not recognized!") 46 | 47 | return optimizer 48 | 49 | 50 | def get_scheduler(hparams, optimizer): 51 | eps = 1e-8 52 | if hparams.lr_scheduler == "steplr": 53 | scheduler = MultiStepLR( 54 | optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma) 55 | elif hparams.lr_scheduler == "cosine": 56 | scheduler = CosineAnnealingLR( 57 | optimizer, T_max=hparams.num_epochs, eta_min=eps) 58 | else: 59 | raise ValueError("scheduler not recognized!") 60 | 61 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]: 62 | scheduler = GradualWarmupScheduler( 63 | optimizer, 64 | multiplier=hparams.warmup_multiplier, 65 | total_epoch=hparams.warmup_epochs, 66 | after_scheduler=scheduler, 67 | ) 68 | 69 | return scheduler 70 | 71 | 72 | def get_learning_rate(optimizer): 73 | for param_group in optimizer.param_groups: 74 | return param_group["lr"] 75 | 76 | 77 | def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]): 78 | checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu")) 79 | checkpoint_ = {} 80 | if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint 81 | checkpoint = checkpoint["state_dict"] 82 | for k, v in checkpoint.items(): 83 | if not k.startswith(model_name): 84 | continue 85 | k = k[len(model_name) + 1:] 86 | for prefix in prefixes_to_ignore: 87 | if k.startswith(prefix): 88 | print("ignore", k) 89 | break 90 | else: 91 | checkpoint_[k] = v 92 | return checkpoint_ 93 | 94 | 95 | def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]): 96 | if not ckpt_path: 97 | return 98 | model_dict = model.state_dict() 99 | checkpoint_ = extract_model_state_dict( 100 | ckpt_path, model_name, prefixes_to_ignore) 101 | model_dict.update(checkpoint_) 102 | model.load_state_dict(model_dict) 103 | -------------------------------------------------------------------------------- /utils/save_weights_only.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | def get_opts(): 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument("--ckpt_path", type=str, required=True, help="checkpoint path") 10 | 11 | return parser.parse_args() 12 | 13 | 14 | if __name__ == "__main__": 15 | args = get_opts() 16 | checkpoint = torch.load(args.ckpt_path, map_location=torch.device("cpu")) 17 | torch.save(checkpoint["state_dict"], args.ckpt_path.split("/")[-2] + ".ckpt") 18 | print("Done!") 19 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torchvision.transforms as T 4 | from PIL import Image 5 | 6 | 7 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET): 8 | """ 9 
| depth: (H, W) 10 | """ 11 | x = depth.cpu().numpy() 12 | x = np.nan_to_num(x) # change nan to 0 13 | mi = np.min(x) # get minimum depth 14 | ma = np.max(x) 15 | x = (x - mi) / max(ma - mi, 1e-8) # normalize to 0~1 16 | x = (255 * x).astype(np.uint8) 17 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 18 | x_ = T.ToTensor()(x_) # (3, H, W) 19 | return x_ 20 | -------------------------------------------------------------------------------- /utils/warmup_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler 2 | 3 | 4 | class GradualWarmupScheduler(_LRScheduler): 5 | """Gradually warm-up(increasing) learning rate in optimizer. 6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 7 | Args: 8 | optimizer (Optimizer): Wrapped optimizer. 9 | multiplier: target learning rate = base lr * multiplier 10 | total_epoch: target learning rate is reached at total_epoch, gradually 11 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 12 | """ 13 | 14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 15 | self.multiplier = multiplier 16 | if self.multiplier < 1.0: 17 | raise ValueError("multiplier should be greater thant or equal to 1.") 18 | self.total_epoch = total_epoch 19 | self.after_scheduler = after_scheduler 20 | self.finished = False 21 | super().__init__(optimizer) 22 | 23 | def get_lr(self): 24 | if self.last_epoch > self.total_epoch: 25 | if self.after_scheduler: 26 | if not self.finished: 27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 28 | self.finished = True 29 | return self.after_scheduler.get_lr() 30 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 31 | 32 | return [ 33 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) for base_lr in self.base_lrs 34 | ] 35 | 36 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 37 | if epoch is None: 38 | epoch = self.last_epoch + 1 39 | self.last_epoch = ( 40 | epoch if epoch != 0 else 1 41 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 42 | if self.last_epoch <= self.total_epoch: 43 | warmup_lr = [ 44 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 45 | for base_lr in self.base_lrs 46 | ] 47 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 48 | param_group["lr"] = lr 49 | else: 50 | if epoch is None: 51 | self.after_scheduler.step(metrics, None) 52 | else: 53 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 54 | 55 | def step(self, epoch=None, metrics=None): 56 | if type(self.after_scheduler) != ReduceLROnPlateau: 57 | if self.finished and self.after_scheduler: 58 | if epoch is None: 59 | self.after_scheduler.step(None) 60 | else: 61 | self.after_scheduler.step(epoch - self.total_epoch) 62 | else: 63 | return super(GradualWarmupScheduler, self).step(epoch) 64 | else: 65 | self.step_ReduceLROnPlateau(metrics, epoch) 66 | --------------------------------------------------------------------------------
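For reference, the sketch below shows how the interpolation ops above compose into the inverse-distance feature propagation used by `PointnetFPModule.forward`. It is a minimal illustration rather than part of the repository: the shapes (`B`, `N`, `M`, `C`) and the `sparse_feats` tensor are made up for the example, and it assumes the `pointnet2_ops` CUDA extension is built and a GPU is available.

```python
import torch
from pointnet2_ops.pointnet2_utils import (
    furthest_point_sample, gather_operation, three_nn, three_interpolate,
)

B, N, M, C = 2, 4096, 512, 32                        # batch, dense points, sparse points, channels (illustrative)
xyz = torch.rand(B, N, 3, device="cuda")             # dense query positions
sparse_idx = furthest_point_sample(xyz, M)           # (B, M) FPS subset indices
sparse_xyz = (
    gather_operation(xyz.transpose(1, 2).contiguous(), sparse_idx)  # expects (B, 3, N)
    .transpose(1, 2)
    .contiguous()
)                                                    # (B, M, 3)
sparse_feats = torch.rand(B, C, M, device="cuda")    # descriptors attached to the sparse points

# Inverse-distance weighting over the 3 nearest sparse points, as in PointnetFPModule.forward.
dist, idx = three_nn(xyz, sparse_xyz)                # (B, N, 3) distances / indices
dist_recip = 1.0 / (dist + 1e-8)
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
dense_feats = three_interpolate(sparse_feats, idx, weight)  # (B, C, N)
print(dense_feats.shape)                             # torch.Size([2, 32, 4096])
```

Gradients flow back only into `sparse_feats`: `three_nn` and `furthest_point_sample` mark their outputs as non-differentiable, and `ThreeInterpolate.backward` returns zero gradients for `idx` and `weight`.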