├── .gitignore
├── License
├── README.md
├── assets
│   ├── fern.gif
│   ├── fern.jpg
│   ├── horn.gif
│   ├── horn.jpg
│   ├── room.gif
│   └── room.jpg
├── datasets
│   ├── __init__.py
│   ├── blender.py
│   ├── colmap_utils.py
│   ├── depth_utils.py
│   ├── llff.py
│   └── ray_utils.py
├── eval.py
├── losses.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── cloud_code.py
│   ├── nerf.py
│   └── rendering.py
├── opt.py
├── pointnet2_ops_lib
│   └── pointnet2_ops
│       ├── __init__.py
│       ├── _ext-src
│       │   ├── include
│       │   │   ├── ball_query.h
│       │   │   ├── cuda_utils.h
│       │   │   ├── group_points.h
│       │   │   ├── interpolate.h
│       │   │   ├── sampling.h
│       │   │   └── utils.h
│       │   └── src
│       │       ├── ball_query.cpp
│       │       ├── ball_query_gpu.cu
│       │       ├── bindings.cpp
│       │       ├── group_points.cpp
│       │       ├── group_points_gpu.cu
│       │       ├── interpolate.cpp
│       │       ├── interpolate_gpu.cu
│       │       ├── sampling.cpp
│       │       └── sampling_gpu.cu
│       ├── _version.py
│       ├── pointnet2_modules.py
│       └── pointnet2_utils.py
├── train.py
└── utils
    ├── __init__.py
    ├── save_weights_only.py
    ├── visualization.py
    └── warmup_scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | logs/
3 | ckpts/
4 | results/
5 | *.ply
6 | *.vol
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 KAIST-VICLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cloud NeRF
2 | This is the official implementation of Neural Radiance Fields with Points Cloud Latent Representation.
3 | [[Paper]]()[[Project Page]]()
4 | ## Instructions
5 | - Please download and arrange the dataset as described in the Data section below.
6 | - For the environment, we provide our Docker image for the best reproducibility.
7 | - The scene optimization scripts are provided in the Train & Evaluation section below.
8 | ## Data
9 | - We evaluate our framework on the forward-facing LLFF dataset, available at [Google drive](https://drive.google.com/drive/folders/14boI-o5hGO9srnWaaogTU5_ji7wkX2S7).
10 | - We also need to download the pre-trained MVS depth estimates, available at [Google drive](https://drive.google.com/drive/folders/13lreojzboR7X7voJ1q8JduvWDdzyrwRe).
11 | - Our data folder structure is as follows:
12 | ```
13 |
14 | ├── datasets
15 | │   ├── nerf_llff_data
16 | │   │   ├── fern
17 | │   │   │   ├── depths
18 | │   │   │   ├── images
19 | │   │   │   ├── images_4
20 | │   │   │   ├── sparse
21 | │   │   │   ├── colmap_depth.npy
22 | │   │   │   ├── poses_bounds.npy
23 | │   │   │   └── ...
24 | ```
25 | ## Docker
26 | - We provide a Docker image of our environment at [DockerHub](https://hub.docker.com/repository/docker/quan5609/cloud_nerf).
27 | - To create a Docker container from the image, run the following command:
28 | ```
29 | docker run \
30 | --name ${CONTAINER_NAME} \
31 | --gpus all \
32 | --mount type=bind,source="${PATH_TO_SOURCE}",target="/workspace/source" \
33 | --mount type=bind,source="${PATH_TO_DATASETS}",target="/workspace/datasets/" \
34 | --shm-size=16GB \
35 | -it ${IMAGE_NAME}
36 | ```
37 | ## Train & Evaluation
38 | - To train from scratch, run the following command:
39 | ```
40 | CUDA_VISIBLE_DEVICES=1 python train.py \
41 | --dataset_name llff \
42 | --root_dir /workspace/datasets/nerf_llff_data/${SCENE_NAME}/ \
43 | --N_importance 64 \
44 | --N_samples 64 \
45 | --img_wh 1008 756 \
46 | --num_epochs 10 \
47 | --batch_size 4096 \
48 | --optimizer adam \
49 | --lr 5e-3 \
50 | --lr_scheduler steplr \
51 | --decay_step 2 4 6 8 \
52 | --decay_gamma 0.5 \
53 | --exp_name ${EXP_NAME}
54 | ```
55 | - To evaluate a checkpoint, run the following command:
56 | ```
57 | CUDA_VISIBLE_DEVICES=1 python eval.py \
58 | --dataset_name llff \
59 | --root_dir /workspace/datasets/nerf_llff_data/${SCENE_NAME}/ \
60 | --N_importance 64 \
61 | --N_samples 64 \
62 | --img_wh 1008 756 \
63 | --weight_path ${PATH_TO_CHECKPOINT} \
64 | --split val
65 | ```
66 | ## Visualization
67 |
68 | - Visualization of Fern scene
69 |
70 |
71 | - Visualization of Horn scene
72 |
73 |
74 | - Visualization of Room scene
75 |
76 |
77 |
78 |
79 | ## Acknowledgement
80 | Our repo is based on [nerf](https://github.com/bmild/nerf), [nerf_pl](https://github.com/kwea123/nerf_pl), [DCCDIF](https://github.com/lity20/DCCDIF), and [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch).
81 |
--------------------------------------------------------------------------------
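
The following is a minimal sanity-check sketch for the dataset layout described in the README's Data section above; the root path and the scene name `fern` are assumptions taken from the Docker and training commands, so adjust them to your setup.

```python
# Hypothetical layout check for one LLFF scene; the path below is an assumption.
from pathlib import Path

import numpy as np

scene = Path("/workspace/datasets/nerf_llff_data/fern")  # assumed mount point from the Docker command
expected = ["depths", "images", "images_4", "sparse", "colmap_depth.npy", "poses_bounds.npy"]

for name in expected:
    print(f"{name:>20s}: {'ok' if (scene / name).exists() else 'MISSING'}")

# poses_bounds.npy stores one row per image: a flattened 3x5 pose matrix plus the near/far bounds.
poses_bounds = np.load(scene / "poses_bounds.npy")
print("poses_bounds shape:", poses_bounds.shape)  # expected (N_images, 17)
```
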
/assets/fern.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/fern.gif
--------------------------------------------------------------------------------
/assets/fern.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/fern.jpg
--------------------------------------------------------------------------------
/assets/horn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/horn.gif
--------------------------------------------------------------------------------
/assets/horn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/horn.jpg
--------------------------------------------------------------------------------
/assets/room.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/room.gif
--------------------------------------------------------------------------------
/assets/room.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/assets/room.jpg
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .blender import BlenderDataset
2 | from .llff import LLFFDataset
3 |
4 |
5 | dataset_dict = {"blender": BlenderDataset, "llff": LLFFDataset}
6 |
--------------------------------------------------------------------------------
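
A short usage sketch for the `dataset_dict` registry above, mirroring how `eval.py` builds a dataset via `dataset_dict[args.dataset_name](**kwargs)`; the `root_dir` path is an assumption.

```python
# Sketch: construct an LLFF dataset through the registry; root_dir is an assumed path.
from datasets import dataset_dict

kwargs = {
    "root_dir": "/workspace/datasets/nerf_llff_data/fern",  # assumption
    "split": "val",
    "img_wh": (504, 378),
    "spheric_poses": False,  # forward-facing LLFF scenes
    "val_num": 1,
}
dataset = dataset_dict["llff"](**kwargs)

sample = dataset[0]
print(sample["rays"].shape)  # (h*w, 11): origin, direction, near, far, view direction
print(sample["rgbs"].shape)  # (h*w, 3)
```
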
/datasets/blender.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms as T
9 |
10 | from .ray_utils import get_ray_directions, get_rays
11 |
12 |
13 | class BlenderDataset(Dataset):
14 | def __init__(self, root_dir, split="train", img_wh=(800, 800)):
15 | self.root_dir = root_dir
16 | self.split = split
17 | assert img_wh[0] == img_wh[1], "image width must equal image height!"
18 | self.img_wh = img_wh
19 | self.define_transforms()
20 |
21 | self.read_meta()
22 | self.white_back = True
23 |
24 | def read_meta(self):
25 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), "r") as f:
26 | self.meta = json.load(f)
27 |
28 | w, h = self.img_wh
29 | # original focal length
30 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta["camera_angle_x"])
31 | # when W=800
32 |
33 | # modify focal length to match size self.img_wh
34 | self.focal *= self.img_wh[0] / 800
35 |
36 | # bounds, common for all scenes
37 | self.near = 2.0
38 | self.far = 6.0
39 | self.bounds = np.array([self.near, self.far])
40 |
41 | # ray directions for all pixels, same for all images (same H, W, focal)
42 | self.directions = get_ray_directions(h, w, self.focal) # (h, w, 3)
43 |
44 | if self.split == "train": # create buffer of all rays and rgb data
45 | self.image_paths = []
46 | self.poses = []
47 | self.all_rays = []
48 | self.all_rgbs = []
49 | for frame in self.meta["frames"]:
50 | pose = np.array(frame["transform_matrix"])[:3, :4]
51 | self.poses += [pose]
52 | c2w = torch.FloatTensor(pose)
53 |
54 | image_path = os.path.join(
55 | self.root_dir, f"{frame['file_path']}.png")
56 | self.image_paths += [image_path]
57 | img = Image.open(image_path)
58 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS)
59 | img = self.transform(img) # (4, h, w)
60 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
61 | img = img[:, :3] * img[:, -1:] + \
62 | (1 - img[:, -1:]) # blend A to RGB
63 | self.all_rgbs += [img]
64 |
65 | rays_o, rays_d = get_rays(
66 | self.directions, c2w) # both (h*w, 3)
67 |
68 | self.all_rays += [
69 | torch.cat(
70 | [
71 | rays_o,
72 | rays_d,
73 | self.near * torch.ones_like(rays_o[:, :1]),
74 | self.far * torch.ones_like(rays_o[:, :1]),
75 | ],
76 | 1,
77 | )
78 | ] # (h*w, 8)
79 |
80 | # (len(self.meta['frames'])*h*w, 8)
81 | self.all_rays = torch.cat(self.all_rays, 0)
82 | # (len(self.meta['frames'])*h*w, 3)
83 | self.all_rgbs = torch.cat(self.all_rgbs, 0)
84 |
85 | def define_transforms(self):
86 | self.transform = T.ToTensor()
87 |
88 | def __len__(self):
89 | if self.split == "train":
90 | return len(self.all_rays)
91 | if self.split == "val":
92 | return 8 # only validate 8 images (to support <=8 gpus)
93 | return len(self.meta["frames"])
94 |
95 | def __getitem__(self, idx):
96 | if self.split == "train": # use data in the buffers
97 | sample = {"rays": self.all_rays[idx], "rgbs": self.all_rgbs[idx]}
98 |
99 | else: # create data for each image separately
100 | frame = self.meta["frames"][idx]
101 | c2w = torch.FloatTensor(frame["transform_matrix"])[:3, :4]
102 |
103 | img = Image.open(os.path.join(
104 | self.root_dir, f"{frame['file_path']}.png"))
105 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS)
106 | img = self.transform(img) # (4, H, W)
107 | valid_mask = (img[-1] > 0).flatten() # (H*W) valid color area
108 | img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA
109 | img = img[:, :3] * img[:, -1:] + \
110 | (1 - img[:, -1:]) # blend A to RGB
111 |
112 | rays_o, rays_d = get_rays(self.directions, c2w)
113 |
114 | rays = torch.cat(
115 | [
116 | rays_o,
117 | rays_d,
118 | self.near * torch.ones_like(rays_o[:, :1]),
119 | self.far * torch.ones_like(rays_o[:, :1]),
120 | ],
121 | 1,
122 | ) # (H*W, 8)
123 |
124 | sample = {"rays": rays, "rgbs": img,
125 | "c2w": c2w, "valid_mask": valid_mask}
126 |
127 | return sample
128 |
--------------------------------------------------------------------------------
/datasets/colmap_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31 |
32 | import collections
33 | import os
34 | import struct
35 |
36 | import numpy as np
37 |
38 |
39 | CameraModel = collections.namedtuple(
40 | "CameraModel", ["model_id", "model_name", "num_params"])
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"])
43 | BaseImage = collections.namedtuple(
44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45 | Point3D = collections.namedtuple(
46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47 |
48 |
49 | class Image(BaseImage):
50 | def qvec2rotmat(self):
51 | return qvec2rotmat(self.qvec)
52 |
53 |
54 | CAMERA_MODELS = {
55 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
56 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
57 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
58 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
59 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
60 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
61 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
62 | CameraModel(model_id=7, model_name="FOV", num_params=5),
63 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
64 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
65 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
66 | }
67 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
68 | for camera_model in CAMERA_MODELS])
69 |
70 |
71 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
72 | """Read and unpack the next bytes from a binary file.
73 | :param fid:
74 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
75 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
76 | :param endian_character: Any of {@, =, <, >, !}
77 | :return: Tuple of read and unpacked values.
78 | """
79 | data = fid.read(num_bytes)
80 | return struct.unpack(endian_character + format_char_sequence, data)
81 |
82 |
83 | def read_cameras_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::WriteCamerasText(const std::string& path)
87 | void Reconstruction::ReadCamerasText(const std::string& path)
88 | """
89 | cameras = {}
90 | with open(path, "r") as fid:
91 | while True:
92 | line = fid.readline()
93 | if not line:
94 | break
95 | line = line.strip()
96 | if len(line) > 0 and line[0] != "#":
97 | elems = line.split()
98 | camera_id = int(elems[0])
99 | model = elems[1]
100 | width = int(elems[2])
101 | height = int(elems[3])
102 | params = np.array(tuple(map(float, elems[4:])))
103 | cameras[camera_id] = Camera(
104 | id=camera_id, model=model, width=width, height=height, params=params)
105 | return cameras
106 |
107 |
108 | def read_cameras_binary(path_to_model_file):
109 | """
110 | see: src/base/reconstruction.cc
111 | void Reconstruction::WriteCamerasBinary(const std::string& path)
112 | void Reconstruction::ReadCamerasBinary(const std::string& path)
113 | """
114 | cameras = {}
115 | with open(path_to_model_file, "rb") as fid:
116 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
117 | for camera_line_index in range(num_cameras):
118 | camera_properties = read_next_bytes(
119 | fid, num_bytes=24, format_char_sequence="iiQQ")
120 | camera_id = camera_properties[0]
121 | model_id = camera_properties[1]
122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
123 | width = camera_properties[2]
124 | height = camera_properties[3]
125 | num_params = CAMERA_MODEL_IDS[model_id].num_params
126 | params = read_next_bytes(
127 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params)
128 | cameras[camera_id] = Camera(
129 | id=camera_id, model=model_name, width=width, height=height, params=np.array(
130 | params)
131 | )
132 | assert len(cameras) == num_cameras
133 | return cameras
134 |
135 |
136 | def read_images_text(path):
137 | """
138 | see: src/base/reconstruction.cc
139 | void Reconstruction::ReadImagesText(const std::string& path)
140 | void Reconstruction::WriteImagesText(const std::string& path)
141 | """
142 | images = {}
143 | with open(path, "r") as fid:
144 | while True:
145 | line = fid.readline()
146 | if not line:
147 | break
148 | line = line.strip()
149 | if len(line) > 0 and line[0] != "#":
150 | elems = line.split()
151 | image_id = int(elems[0])
152 | qvec = np.array(tuple(map(float, elems[1:5])))
153 | tvec = np.array(tuple(map(float, elems[5:8])))
154 | camera_id = int(elems[8])
155 | image_name = elems[9]
156 | elems = fid.readline().split()
157 | xys = np.column_stack(
158 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))])
159 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
160 | images[image_id] = Image(
161 | id=image_id,
162 | qvec=qvec,
163 | tvec=tvec,
164 | camera_id=camera_id,
165 | name=image_name,
166 | xys=xys,
167 | point3D_ids=point3D_ids,
168 | )
169 | return images
170 |
171 |
172 | def read_images_binary(path_to_model_file):
173 | """
174 | see: src/base/reconstruction.cc
175 | void Reconstruction::ReadImagesBinary(const std::string& path)
176 | void Reconstruction::WriteImagesBinary(const std::string& path)
177 | """
178 | images = {}
179 | with open(path_to_model_file, "rb") as fid:
180 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
181 | for image_index in range(num_reg_images):
182 | binary_image_properties = read_next_bytes(
183 | fid, num_bytes=64, format_char_sequence="idddddddi")
184 | image_id = binary_image_properties[0]
185 | qvec = np.array(binary_image_properties[1:5])
186 | tvec = np.array(binary_image_properties[5:8])
187 | camera_id = binary_image_properties[8]
188 | image_name = ""
189 | current_char = read_next_bytes(fid, 1, "c")[0]
190 | while current_char != b"\x00": # look for the ASCII 0 entry
191 | image_name += current_char.decode("utf-8")
192 | current_char = read_next_bytes(fid, 1, "c")[0]
193 | num_points2D = read_next_bytes(
194 | fid, num_bytes=8, format_char_sequence="Q")[0]
195 | x_y_id_s = read_next_bytes(
196 | fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D)
197 | xys = np.column_stack(
198 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))])
199 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
200 | images[image_id] = Image(
201 | id=image_id,
202 | qvec=qvec,
203 | tvec=tvec,
204 | camera_id=camera_id,
205 | name=image_name,
206 | xys=xys,
207 | point3D_ids=point3D_ids,
208 | )
209 | return images
210 |
211 |
212 | def read_points3D_text(path):
213 | """
214 | see: src/base/reconstruction.cc
215 | void Reconstruction::ReadPoints3DText(const std::string& path)
216 | void Reconstruction::WritePoints3DText(const std::string& path)
217 | """
218 | points3D = {}
219 | with open(path, "r") as fid:
220 | while True:
221 | line = fid.readline()
222 | if not line:
223 | break
224 | line = line.strip()
225 | if len(line) > 0 and line[0] != "#":
226 | elems = line.split()
227 | point3D_id = int(elems[0])
228 | xyz = np.array(tuple(map(float, elems[1:4])))
229 | rgb = np.array(tuple(map(int, elems[4:7])))
230 | error = float(elems[7])
231 | image_ids = np.array(tuple(map(int, elems[8::2])))
232 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
233 | points3D[point3D_id] = Point3D(
234 | id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs
235 | )
236 | return points3D
237 |
238 |
239 | def read_points3d_binary(path_to_model_file):
240 | """
241 | see: src/base/reconstruction.cc
242 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
243 | void Reconstruction::WritePoints3DBinary(const std::string& path)
244 | """
245 | points3D = {}
246 | with open(path_to_model_file, "rb") as fid:
247 | num_points = read_next_bytes(fid, 8, "Q")[0]
248 | for point_line_index in range(num_points):
249 | binary_point_line_properties = read_next_bytes(
250 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
251 | point3D_id = binary_point_line_properties[0]
252 | xyz = np.array(binary_point_line_properties[1:4])
253 | rgb = np.array(binary_point_line_properties[4:7])
254 | error = np.array(binary_point_line_properties[7])
255 | track_length = read_next_bytes(
256 | fid, num_bytes=8, format_char_sequence="Q")[0]
257 | track_elems = read_next_bytes(
258 | fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length)
259 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
260 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
261 | points3D[point3D_id] = Point3D(
262 | id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs
263 | )
264 | return points3D
265 |
266 |
267 | def read_model(path, ext):
268 | if ext == ".txt":
269 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
270 | images = read_images_text(os.path.join(path, "images" + ext))
271 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
272 | else:
273 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
274 | images = read_images_binary(os.path.join(path, "images" + ext))
275 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
276 | return cameras, images, points3D
277 |
278 |
279 | def qvec2rotmat(qvec):
280 | return np.array(
281 | [
282 | [
283 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
284 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
285 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
286 | ],
287 | [
288 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
289 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
290 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
291 | ],
292 | [
293 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
294 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
295 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
296 | ],
297 | ]
298 | )
299 |
300 |
301 | def rotmat2qvec(R):
302 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
303 | K = (
304 | np.array(
305 | [
306 | [Rxx - Ryy - Rzz, 0, 0, 0],
307 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
308 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
309 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
310 | ]
311 | )
312 | / 3.0
313 | )
314 | eigvals, eigvecs = np.linalg.eigh(K)
315 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
316 | if qvec[0] < 0:
317 | qvec *= -1
318 | return qvec
319 |
--------------------------------------------------------------------------------
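
A minimal sketch of how the COLMAP readers above are typically used on a scene's `sparse/0` reconstruction (the same files `datasets/llff.py` reads); the path is an assumption.

```python
# Sketch: read a COLMAP binary reconstruction with the helpers above; the path is assumed.
import numpy as np

from datasets.colmap_utils import qvec2rotmat, read_model

cameras, images, points3D = read_model("/workspace/datasets/nerf_llff_data/fern/sparse/0", ext=".bin")
print(len(cameras), "cameras |", len(images), "images |", len(points3D), "3D points")

# Assemble a (3, 4) world-to-camera matrix for one registered image,
# the same way read_meta() in datasets/llff.py does.
im = images[sorted(images.keys())[0]]
R = qvec2rotmat(im.qvec)               # (3, 3) rotation from the quaternion
t = im.tvec.reshape(3, 1)              # (3, 1) translation
w2c = np.concatenate([R, t], axis=1)
print(w2c)
```
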
/datasets/depth_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 |
4 | import numpy as np
5 |
6 |
7 | def read_pfm(filename):
8 | file = open(filename, "rb")
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().decode("utf-8").rstrip()
16 | if header == "PF":
17 | color = True
18 | elif header == "Pf":
19 | color = False
20 | else:
21 | raise Exception("Not a PFM file.")
22 |
23 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8"))
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception("Malformed PFM header.")
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = "<"
32 | scale = -scale
33 | else:
34 | endian = ">" # big-endian
35 |
36 | data = np.fromfile(file, endian + "f")
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | file.close()
42 | return data, scale
43 |
44 |
45 | def save_pfm(filename, image, scale=1):
46 | file = open(filename, "wb")
47 | color = None
48 |
49 | image = np.flipud(image)
50 |
51 | if image.dtype.name != "float32":
52 | raise Exception("Image dtype must be float32.")
53 |
54 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
55 | color = True
56 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
57 | color = False
58 | else:
59 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
60 |
61 | file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8"))
62 | file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8"))
63 |
64 | endian = image.dtype.byteorder
65 |
66 | if endian == "<" or endian == "=" and sys.byteorder == "little":
67 | scale = -scale
68 |
69 | file.write(("%f\n" % scale).encode("utf-8"))
70 |
71 | image.tofile(file)
72 | file.close()
73 |
--------------------------------------------------------------------------------
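
A self-contained round-trip sketch for the PFM helpers above (run it from the repository root so that `datasets` is importable):

```python
# Sketch: save a random float32 depth map with save_pfm and load it back with read_pfm.
import os
import tempfile

import numpy as np

from datasets.depth_utils import read_pfm, save_pfm

depth = np.random.rand(120, 160).astype(np.float32)   # H x W greyscale ("Pf") image

path = os.path.join(tempfile.mkdtemp(), "depth.pfm")
save_pfm(path, depth)
restored, scale = read_pfm(path)

print(restored.shape, scale)           # (120, 160), scale 1.0
assert np.allclose(restored, depth)    # the vertical flip applied on save is undone on load
```
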
/datasets/llff.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from torchvision import transforms as T
10 |
11 | from .colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
12 | from .ray_utils import read_gen, get_ray_directions, get_rays, get_ndc_rays
13 |
14 |
15 | def normalize(v):
16 | """Normalize a vector."""
17 | return v / np.linalg.norm(v)
18 |
19 |
20 | def average_poses(poses):
21 | """
22 | Calculate the average pose, which is then used to center all poses
23 | using @center_poses. Its computation is as follows:
24 | 1. Compute the center: the average of pose centers.
25 | 2. Compute the z axis: the normalized average z axis.
26 | 3. Compute axis y': the average y axis.
27 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
28 | 5. Compute the y axis: z cross product x.
29 |
30 | Note that at step 3, we cannot directly use y' as y axis since it's
31 | not necessarily orthogonal to z axis. We need to pass from x to y.
32 |
33 | Inputs:
34 | poses: (N_images, 3, 4)
35 |
36 | Outputs:
37 | pose_avg: (3, 4) the average pose
38 | """
39 | # 1. Compute the center
40 | center = poses[..., 3].mean(0) # (3)
41 |
42 | # 2. Compute the z axis
43 | z = normalize(poses[..., 2].mean(0)) # (3)
44 |
45 | # 3. Compute axis y' (no need to normalize as it's not the final output)
46 | y_ = poses[..., 1].mean(0) # (3)
47 |
48 | # 4. Compute the x axis
49 | x = normalize(np.cross(y_, z)) # (3)
50 |
51 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
52 | y = np.cross(z, x) # (3)
53 |
54 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
55 |
56 | return pose_avg
57 |
58 |
59 | def center_poses(poses):
60 | """
61 | Center the poses so that we can use NDC.
62 | See https://github.com/bmild/nerf/issues/34
63 |
64 | Inputs:
65 | poses: (N_images, 3, 4)
66 |
67 | Outputs:
68 | poses_centered: (N_images, 3, 4) the centered poses
69 | pose_avg: (3, 4) the average pose
70 | """
71 |
72 | pose_avg = average_poses(poses) # (3, 4)
73 | pose_avg_homo = np.eye(4)
74 | # convert to homogeneous coordinate for faster computation
75 | pose_avg_homo[:3] = pose_avg
76 | # by simply adding 0, 0, 0, 1 as the last row
77 | last_row = np.tile(np.array([0, 0, 0, 1]),
78 | (len(poses), 1, 1)) # (N_images, 1, 4)
79 | # (N_images, 4, 4) homogeneous coordinate
80 | poses_homo = np.concatenate([poses, last_row], 1)
81 |
82 | poses_centered = np.linalg.inv(
83 | pose_avg_homo) @ poses_homo # (N_images, 4, 4)
84 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
85 | return poses_centered, pose_avg
86 |
87 |
88 | def create_spiral_poses(radii, focus_depth, n_poses=120):
89 | """
90 | Computes poses that follow a spiral path for rendering purpose.
91 | See https://github.com/Fyusion/LLFF/issues/19
92 | In particular, the path looks like:
93 | https://tinyurl.com/ybgtfns3
94 |
95 | Inputs:
96 | radii: (3) radii of the spiral for each axis
97 | focus_depth: float, the depth that the spiral poses look at
98 | n_poses: int, number of poses to create along the path
99 |
100 | Outputs:
101 | poses_spiral: (n_poses, 3, 4) the poses in the spiral path
102 | """
103 |
104 | poses_spiral = []
105 | # rotate 4pi (2 rounds)
106 | for t in np.linspace(0, 4 * np.pi, n_poses + 1)[:-1]:
107 | # the parametric function of the spiral (see the interactive web)
108 | center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5 * t)]) * radii
109 |
110 | # the viewing z axis is the vector pointing from the @focus_depth plane
111 | # to @center
112 | z = normalize(center - np.array([0, 0, -focus_depth]))
113 |
114 | # compute other axes as in @average_poses
115 | y_ = np.array([0, 1, 0]) # (3)
116 | x = normalize(np.cross(y_, z)) # (3)
117 | y = np.cross(z, x) # (3)
118 |
119 | poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4)
120 |
121 | return np.stack(poses_spiral, 0) # (n_poses, 3, 4)
122 |
123 |
124 | def create_spheric_poses(radius, n_poses=120):
125 | """
126 | Create circular poses around z axis.
127 | Inputs:
128 | radius: the (negative) height and the radius of the circle.
129 |
130 | Outputs:
131 | spheric_poses: (n_poses, 3, 4) the poses in the circular path
132 | """
133 |
134 | def spheric_pose(theta, phi, radius):
135 | def trans_t(t):
136 | return np.array(
137 | [
138 | [1, 0, 0, 0],
139 | [0, 1, 0, -0.9 * t],
140 | [0, 0, 1, t],
141 | [0, 0, 0, 1],
142 | ]
143 | )
144 |
145 | def rot_phi(phi):
146 | return np.array(
147 | [
148 | [1, 0, 0, 0],
149 | [0, np.cos(phi), -np.sin(phi), 0],
150 | [0, np.sin(phi), np.cos(phi), 0],
151 | [0, 0, 0, 1],
152 | ]
153 | )
154 |
155 | def rot_theta(th):
156 | return np.array(
157 | [
158 | [np.cos(th), 0, -np.sin(th), 0],
159 | [0, 1, 0, 0],
160 | [np.sin(th), 0, np.cos(th), 0],
161 | [0, 0, 0, 1],
162 | ]
163 | )
164 |
165 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius)
166 | c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0],
167 | [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
168 | return c2w[:3]
169 |
170 | spheric_poses = []
171 | for th in np.linspace(0, 2 * np.pi, n_poses + 1)[:-1]:
172 | # 36 degree view downwards
173 | spheric_poses += [spheric_pose(th, -np.pi / 5, radius)]
174 | return np.stack(spheric_poses, 0)
175 |
176 |
177 | class LLFFDataset(Dataset):
178 | def __init__(self, root_dir, split="train", img_wh=(504, 378), spheric_poses=False, val_num=1):
179 | """
180 | spheric_poses: whether the images are taken in a spheric inward-facing manner
181 | default: False (forward-facing)
182 | val_num: number of val images (used for multigpu training, validate same image for all gpus)
183 | """
184 | self.root_dir = root_dir
185 | self.split = split
186 | self.img_wh = img_wh
187 | self.spheric_poses = spheric_poses
188 | self.val_num = max(1, val_num) # at least 1
189 | self.holdout = 8
190 | self.define_transforms()
191 |
192 | self.read_meta()
193 | # self.load_mvs_depth()
194 | self.white_back = False
195 |
196 | def load_mvs_depth(self):
197 | depth_glob = os.path.join(self.root_dir, "depths", "*.pfm")
198 | self.depth_list = sorted(glob.glob(depth_glob))
199 | depths = []
200 | for i in range(len(self.depth_list)):
201 | depth = read_gen(self.depth_list[i])
202 |
203 | depths.append(depth)
204 | self.depths = np.stack(depths, 0).astype(np.float32) # N x H x W
205 |
206 | per_view_points = self.project_to_3d()
207 | mvs_points = self.fwd_consistency_check(per_view_points)
208 | return mvs_points
209 |
210 | def project_to_3d(self):
211 | N, H, W = self.depths.shape
212 | focal = self.origin_intrinsics[1].params[0]
213 | origin_h, origin_w = self.origin_intrinsics[1].height, self.origin_intrinsics[1].width
214 |
215 | origin_cy, origin_cx = self.origin_intrinsics[1].params[2], self.origin_intrinsics[1].params[1]
216 |
217 | origin_K = np.array([[focal, 0, origin_cx, 0], [
218 | 0, focal, origin_cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
219 |
220 | origin_K[0, :] /= origin_w
221 | origin_K[1, :] /= origin_h
222 |
223 | self.normalized_K = origin_K
224 |
225 | mvs_K = self.normalized_K.copy()
226 | mvs_K[0, :] *= W
227 | mvs_K[1, :] *= H
228 | self.mvs_K = mvs_K
229 |
230 | inv_mvs_K = np.linalg.pinv(mvs_K)
231 | inv_mvs_K = torch.from_numpy(inv_mvs_K)
232 |
233 | # create mesh grid for mvs image
234 | meshgrid = np.meshgrid(range(W), range(H), indexing="xy")
235 | id_coords = (np.stack(meshgrid, axis=0).astype(
236 | np.float32)).reshape(2, -1)
237 | id_coords = torch.from_numpy(id_coords)
238 |
239 | ones = torch.ones(N, 1, H * W)
240 |
241 | pix_coords = torch.unsqueeze(torch.stack(
242 | [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0)
243 | pix_coords = pix_coords.repeat(N, 1, 1)
244 | pix_coords = torch.cat([pix_coords, ones], 1)
245 |
246 | # project to cam coord
247 | inv_mvs_K = inv_mvs_K[None, ...].repeat(N, 1, 1).float()
248 | cam_points = torch.matmul(inv_mvs_K[:, :3, :3], pix_coords)
249 | mvs_depth = torch.from_numpy(
250 | self.depths).float().unsqueeze(1).view(N, 1, -1)
251 | cam_points = mvs_depth * cam_points
252 | cam_points = torch.cat([cam_points, ones], 1)
253 |
254 | # project to world coord
255 | T = torch.from_numpy(self.origin_extrinsics).float()
256 | world_points = torch.matmul(T, cam_points)
257 | world_points = world_points.permute(0, 2, 1) # N, H*W, 3
258 |
259 | return world_points
260 |
261 | def fwd_consistency_check(self, per_view_points):
262 | N, H, W = self.depths.shape
263 | global_valid_points = []
264 | for view_id in range(per_view_points.shape[0]):
265 | curr_view_points = per_view_points[view_id].transpose(
266 | 1, 0) # 3, H*W
267 | homo_view_points = torch.cat(
268 | [curr_view_points, torch.ones(1, H * W)], dim=0) # 4, H*W
269 | homo_view_points = homo_view_points.unsqueeze(
270 | 0).repeat(N, 1, 1) # N,4,H*W
271 |
272 | # project to camera space
273 | T = torch.from_numpy(self.origin_extrinsics).float()
274 | homo_T = torch.cat([T, torch.zeros(N, 1, 4)], dim=1)
275 | homo_T[:, -1, -1] = 1
276 | inv_T = torch.inverse(homo_T)
277 | cam_points = torch.matmul(inv_T[:, :3, :], homo_view_points)
278 |
279 | # project to image space
280 | mvs_K = torch.from_numpy(self.mvs_K).unsqueeze(
281 | 0).repeat(N, 1, 1).float()
282 | cam_points = torch.matmul(mvs_K[:, :3, :3], cam_points)
283 | cam_points[:, :2, :] /= cam_points[:, 2:, :]
284 |
285 | z_values = cam_points[:, 2:, :].view(N, 1, H, W) # N,1,H,W
286 | xy_coords = cam_points[:, :2, :].transpose(
287 | 2, 1).view(N, H, W, 2) # N,H,W,2
288 |
289 | xy_coords[..., 0] /= W - 1
290 | xy_coords[..., 1] /= H - 1
291 | xy_coords = (xy_coords - 0.5) * 2
292 |
293 | mvs_depth = torch.from_numpy(
294 | self.depths).float().unsqueeze(1) # N,1,H,W
295 | ref_z_values = F.grid_sample(
296 | mvs_depth, xy_coords, mode="bilinear", align_corners=False)
297 |
298 | # ! keep points with z_values >= alpha * ref_z_values; invalid indices have ref depth 0, so they satisfy this condition as well
299 | # ! point must be visible in at least n views
300 | err = z_values - 0.9 * ref_z_values
301 |
302 | visible_mask = ref_z_values != 0
303 | visible_count = visible_mask.int().sum(0)
304 | valid_visible = visible_count >= 1
305 | valid_points = err >= 0
306 | valid_points = torch.all(
307 | valid_points, dim=0) & valid_visible # 1,H,W
308 | global_valid_points.append(valid_points)
309 | global_valid_points = torch.cat(
310 | global_valid_points, dim=0).view(N, H * W) # N,H,W
311 |
312 | filtered_points = per_view_points[global_valid_points, :]
313 | # np.save('assets/filtered_mvs_points.npy', filtered_points)
314 | # breakpoint()
315 | return filtered_points
316 |
317 | def read_meta(self):
318 | # Step 1: rescale focal length according to training resolution
319 | camdata = read_cameras_binary(os.path.join(
320 | self.root_dir, "sparse/0/cameras.bin"))
321 | self.origin_intrinsics = camdata
322 | W = camdata[1].width
323 | self.focal = camdata[1].params[0] * self.img_wh[0] / W
324 |
325 | # Step 2: correct poses
326 | # read extrinsics (of successfully reconstructed images)
327 | imdata = read_images_binary(os.path.join(
328 | self.root_dir, "sparse/0/images.bin"))
329 | perm = np.argsort([imdata[k].name for k in imdata])
330 | # read successfully reconstructed images and ignore others
331 | self.image_paths = [
332 | os.path.join(self.root_dir, "images", name) for name in sorted([imdata[k].name for k in imdata])
333 | ]
334 | w2c_mats = []
335 | bottom = np.array([0, 0, 0, 1.0]).reshape(1, 4)
336 | for k in imdata:
337 | im = imdata[k]
338 | R = im.qvec2rotmat()
339 | t = im.tvec.reshape(3, 1)
340 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)]
341 | w2c_mats = np.stack(w2c_mats, 0)
342 | # (N_images, 3, 4) cam2world matrices
343 | poses = np.linalg.inv(w2c_mats)[:, :3]
344 | self.origin_extrinsics = poses
345 |
346 | # read bounds
347 | self.bounds = np.zeros((len(poses), 2)) # (N_images, 2)
348 | pts3d = read_points3d_binary(os.path.join(
349 | self.root_dir, "sparse/0/points3D.bin"))
350 |
351 | mvs_points = self.load_mvs_depth().numpy() # ! points unprojected from the pre-trained MVS depth maps
352 | near_bound = mvs_points.min(axis=0)[-1]
353 | pts3d = {k: v for (k, v) in pts3d.items() if v.xyz[-1] > near_bound}
354 |
355 | pts_world = np.zeros((1, 3, len(pts3d))) # (1, 3, N_points)
356 | visibilities = np.zeros((len(poses), len(pts3d))
357 | ) # (N_images, N_points)
358 | for i, k in enumerate(pts3d):
359 | pts_world[0, :, i] = pts3d[k].xyz
360 | for j in pts3d[k].image_ids:
361 | visibilities[j - 1, i] = 1
362 | # calculate each point's depth w.r.t. each camera
363 | # it's the dot product of "points - camera center" and "camera frontal axis"
364 | # (N_images, N_points)
365 | depths = ((pts_world - poses[..., 3:4]) * poses[..., 2:3]).sum(1)
366 | for i in range(len(poses)):
367 | visibility_i = visibilities[i]
368 | zs = depths[i][visibility_i == 1]
369 | self.bounds[i] = [np.percentile(zs, 0.1), np.percentile(zs, 99.9)]
370 | # permute the matrices to increasing order
371 | poses = poses[perm]
372 | self.bounds = self.bounds[perm]
373 |
374 | # COLMAP poses have rotation in the form "right down front"; change it to "right up back"
375 | # See https://github.com/bmild/nerf/issues/34
376 | poses = np.concatenate(
377 | [poses[..., 0:1], -poses[..., 1:3], poses[..., 3:4]], -1)
378 | self.poses, _ = center_poses(poses)
379 | distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
380 |
381 | # choose val images as in NeRF: every self.holdout-th index (0, 8, 16, ...)
382 | indices = np.arange(distances_from_center.shape[0], dtype=int)
383 | val_idx = indices[::self.holdout]
384 | # center image
385 |
386 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
387 | # See https://github.com/bmild/nerf/issues/34
388 | near_original = self.bounds.min()
389 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
390 | # the nearest depth is at 1/0.75=1.33
391 | self.bounds /= scale_factor
392 | self.poses[..., 3] /= scale_factor
393 |
394 | # ray directions for all pixels, same for all images (same H, W, focal)
395 | self.directions = get_ray_directions(
396 | self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3)
397 |
398 | if self.split == "train": # create buffer of all rays and rgb data
399 | # use all images except those at val_idx (held out for validation) to train
400 | self.all_rays = []
401 | self.all_rgbs = []
402 | for i, image_path in enumerate(self.image_paths):
403 | if np.any(i == val_idx): # exclude the val image
404 | continue
405 | c2w = torch.FloatTensor(self.poses[i])
406 |
407 | img = Image.open(image_path).convert("RGB")
408 | # assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \
409 | # f'''{image_path} has different aspect ratio than img_wh,
410 | # please check your data!'''
411 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS)
412 | img = self.transform(img) # (3, h, w)
413 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
414 | self.all_rgbs += [img]
415 |
416 | rays_o, rays_d = get_rays(
417 | self.directions, c2w) # both (h*w, 3)
418 | viewdirs = rays_d # ! get_rays already returns normalized ray directions
419 |
420 | if not self.spheric_poses:
421 | near, far = 0, 1
422 | rays_o, rays_d = get_ndc_rays(
423 | self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d)
424 |
425 | # near plane is always at 1.0
426 | # near and far in NDC are always 0 and 1
427 | # See https://github.com/bmild/nerf/issues/34
428 | else:
429 | near = self.bounds.min()
430 | # focus on central object only
431 | far = min(8 * near, self.bounds.max())
432 |
433 | self.all_rays += [
434 | torch.cat(
435 | [
436 | rays_o,
437 | rays_d,
438 | near * torch.ones_like(rays_o[:, :1]),
439 | far * torch.ones_like(rays_o[:, :1]),
440 | viewdirs,
441 | ],
442 | 1,
443 | )
444 | ] # (h*w, 11)
445 |
446 | # ((N_images - N_val)*h*w, 11)
447 | self.all_rays = torch.cat(self.all_rays, 0)
448 | # ((N_images - N_val)*h*w, 3)
449 | self.all_rgbs = torch.cat(self.all_rgbs, 0)
450 |
451 | elif self.split == "val":
452 | print("val image is", val_idx)
453 | self.val_idx = val_idx
454 |
455 | else: # for testing, create a parametric rendering path
456 | if self.split.endswith("train"): # test on training set
457 | self.poses_test = self.poses
458 | elif not self.spheric_poses:
459 | focus_depth = 3.5 # hardcoded, this is numerically close to the formula
460 | # given in the original repo. Mathematically if near=1
461 | # and far=infinity, then this number will converge to 4
462 | radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0)
463 | self.poses_test = create_spiral_poses(radii, focus_depth)
464 | else:
465 | radius = 1.1 * self.bounds.min()
466 | self.poses_test = create_spheric_poses(radius)
467 |
468 | def define_transforms(self):
469 | self.transform = T.ToTensor()
470 |
471 | def __len__(self):
472 | if self.split == "train":
473 | return len(self.all_rays)
474 | if self.split == "val":
475 | return self.val_num
476 | if self.split == "test_train":
477 | return len(self.poses)
478 | return len(self.poses_test)
479 |
480 | def __getitem__(self, idx):
481 | if self.split == "train": # use data in the buffers
482 | sample = {"rays": self.all_rays[idx], "rgbs": self.all_rgbs[idx]}
483 |
484 | else:
485 | if self.split == "val":
486 | idx = self.val_idx[idx]
487 | c2w = torch.FloatTensor(self.poses[idx])
488 | elif self.split == "test_train":
489 | c2w = torch.FloatTensor(self.poses[idx])
490 | else:
491 | c2w = torch.FloatTensor(self.poses_test[idx])
492 |
493 | rays_o, rays_d = get_rays(self.directions, c2w)
494 | viewdirs = rays_d
495 | if not self.spheric_poses:
496 | near, far = 0, 1
497 | rays_o, rays_d = get_ndc_rays(
498 | self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d)
499 | else:
500 | near = self.bounds.min()
501 | far = min(8 * near, self.bounds.max())
502 |
503 | rays = torch.cat(
504 | [
505 | rays_o,
506 | rays_d,
507 | near * torch.ones_like(rays_o[:, :1]),
508 | far * torch.ones_like(rays_o[:, :1]),
509 | viewdirs,
510 | ],
511 | 1,
512 | ) # (h*w, 11)
513 |
514 | sample = {"rays": rays, "c2w": c2w}
515 |
516 | if self.split in ["val", "test_train"]:
517 | # if self.split == 'val':
518 | # idx = self.val_idx
519 | img = Image.open(self.image_paths[idx]).convert("RGB")
520 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS)
521 | img = self.transform(img) # (3, h, w)
522 | img = img.view(3, -1).permute(1, 0) # (h*w, 3)
523 | sample["rgbs"] = img
524 |
525 | return sample
526 |
--------------------------------------------------------------------------------
/datasets/ray_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import torch
5 | from kornia import create_meshgrid
6 |
7 |
8 | def get_ray_directions(H, W, focal):
9 | """
10 | Get ray directions for all pixels in camera coordinate.
11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
12 | ray-tracing-generating-camera-rays/standard-coordinate-systems
13 |
14 | Inputs:
15 | H, W, focal: image height, width and focal length
16 |
17 | Outputs:
18 | directions: (H, W, 3), the direction of the rays in camera coordinate
19 | """
20 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
21 | i, j = grid.unbind(-1)
22 | # the direction here is without +0.5 pixel centering as calibration is not so accurate
23 | # see https://github.com/bmild/nerf/issues/24
24 | directions = torch.stack([(i - W / 2) / focal, -(j - H / 2) / focal, -torch.ones_like(i)], -1) # (H, W, 3)
25 |
26 | return directions
27 |
28 |
29 | def get_rays(directions, c2w):
30 | """
31 | Get ray origin and normalized directions in world coordinate for all pixels in one image.
32 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
33 | ray-tracing-generating-camera-rays/standard-coordinate-systems
34 |
35 | Inputs:
36 | directions: (H, W, 3) precomputed ray directions in camera coordinate
37 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
38 |
39 | Outputs:
40 | rays_o: (H*W, 3), the origin of the rays in world coordinate
41 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
42 | """
43 | # Rotate ray directions from camera coordinate to the world coordinate
44 | rays_d = directions @ c2w[:, :3].T # (H, W, 3)
45 | rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
46 | # The origin of all rays is the camera origin in world coordinate
47 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)
48 |
49 | rays_d = rays_d.view(-1, 3)
50 | rays_o = rays_o.view(-1, 3)
51 |
52 | return rays_o, rays_d
53 |
54 |
55 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d):
56 | """
57 | Transform rays from world coordinate to NDC.
58 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis.
59 | For detailed derivation, please see:
60 | http://www.songho.ca/opengl/gl_projectionmatrix.html
61 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf
62 |
63 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth).
64 | See https://github.com/bmild/nerf/issues/18
65 |
66 | Inputs:
67 | H, W, focal: image height, width and focal length
68 | near: (N_rays) or float, the depths of the near plane
69 | rays_o: (N_rays, 3), the origin of the rays in world coordinate
70 | rays_d: (N_rays, 3), the direction of the rays in world coordinate
71 |
72 | Outputs:
73 | rays_o: (N_rays, 3), the origin of the rays in NDC
74 | rays_d: (N_rays, 3), the direction of the rays in NDC
75 | """
76 | # Shift ray origins to near plane
77 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
78 | rays_o = rays_o + t[..., None] * rays_d
79 |
80 | # Store some intermediate homogeneous results
81 | ox_oz = rays_o[..., 0] / rays_o[..., 2]
82 | oy_oz = rays_o[..., 1] / rays_o[..., 2]
83 |
84 | # Projection
85 | o0 = -1.0 / (W / (2.0 * focal)) * ox_oz
86 | o1 = -1.0 / (H / (2.0 * focal)) * oy_oz
87 | o2 = 1.0 + 2.0 * near / rays_o[..., 2]
88 |
89 | d0 = -1.0 / (W / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2] - ox_oz)
90 | d1 = -1.0 / (H / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2] - oy_oz)
91 | d2 = 1 - o2
92 |
93 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3)
94 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3)
95 |
96 | return rays_o, rays_d
97 |
98 |
99 | def get_ndc_coor(H, W, focal, near, pts):
100 | # Store some intermediate homogeneous results
101 | ox_oz = pts[..., 0] / pts[..., 2]
102 | oy_oz = pts[..., 1] / pts[..., 2]
103 |
104 | # Projection
105 | o0 = -1.0 / (W / (2.0 * focal)) * ox_oz
106 | o1 = -1.0 / (H / (2.0 * focal)) * oy_oz
107 | o2 = 1.0 + 2.0 * near / pts[..., 2]
108 |
109 | pts = torch.stack([o0, o1, o2], -1) # (B, 3)
110 |
111 | return pts
112 |
113 |
114 | # read mvs pretrained depth
115 |
116 |
117 | def readPFM(file):
118 | file = open(file, "rb")
119 |
120 | color = None
121 | width = None
122 | height = None
123 | scale = None
124 | endian = None
125 |
126 | header = file.readline().rstrip()
127 | if header == b"PF":
128 | color = True
129 | elif header == b"Pf":
130 | color = False
131 | else:
132 | raise Exception("Not a PFM file.")
133 |
134 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
135 | if dim_match:
136 | width, height = map(int, dim_match.groups())
137 | else:
138 | raise Exception("Malformed PFM header.")
139 |
140 | scale = float(file.readline().rstrip())
141 | if scale < 0: # little-endian
142 | endian = "<"
143 | scale = -scale
144 | else:
145 | endian = ">" # big-endian
146 |
147 | data = np.fromfile(file, endian + "f")
148 | shape = (height, width, 3) if color else (height, width)
149 |
150 | data = np.reshape(data, shape)
151 | data = np.flipud(data)
152 | return data
153 |
154 |
155 | def read_gen(file_name, pil=False):
156 | flow = readPFM(file_name).astype(np.float32)
157 | if len(flow.shape) == 2:
158 | return flow
159 | else:
160 | return flow[:, :, :-1]
161 |
--------------------------------------------------------------------------------
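
A small sketch exercising the ray helpers above with a toy camera; the resolution and focal length are arbitrary, and the identity pose stands in for a real camera-to-world matrix.

```python
# Sketch: generate rays for a toy camera and warp them to NDC with the helpers above.
import torch

from datasets.ray_utils import get_ndc_rays, get_ray_directions, get_rays

H, W, focal = 378, 504, 400.0                  # arbitrary toy intrinsics
directions = get_ray_directions(H, W, focal)   # (H, W, 3), in camera coordinates

c2w = torch.eye(4)[:3]                         # identity rotation, camera at the world origin
rays_o, rays_d = get_rays(directions, c2w)     # both (H*W, 3); rays_d is normalized

# For forward-facing scenes, llff.py maps rays to NDC with the near plane at 1.0.
ndc_o, ndc_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d)
print(rays_o.shape, rays_d.shape, ndc_o.shape, ndc_d.shape)
```
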
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from collections import defaultdict
4 |
5 | import cv2
6 | import imageio
7 | import lpips
8 | import metrics
9 | import numpy as np
10 | import torch
11 | from datasets import dataset_dict
12 | from datasets.depth_utils import save_pfm
13 | from models.rendering import render_rays
14 | from tqdm import tqdm
15 | from train import NeRFSystem
16 |
17 |
18 | torch.backends.cudnn.benchmark = True
19 |
20 |
21 | def get_opts():
22 | parser = ArgumentParser()
23 | parser.add_argument(
24 | "--root_dir",
25 | type=str,
26 | default="/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego",
27 | help="root directory of dataset",
28 | )
29 | parser.add_argument(
30 | "--dataset_name", type=str, default="blender", choices=["blender", "llff"], help="which dataset to validate"
31 | )
32 | parser.add_argument("--scene_name", type=str, default="test",
33 | help="scene name, used as output folder name")
34 | parser.add_argument("--split", type=str, default="test",
35 | help="test or test_train")
36 | parser.add_argument(
37 | "--img_wh", nargs="+", type=int, default=[800, 800], help="resolution (img_w, img_h) of the image"
38 | )
39 | parser.add_argument(
40 | "--spheric_poses",
41 | default=False,
42 | action="store_true",
43 | help="whether images are taken in spheric poses (for llff)",
44 | )
45 |
46 | parser.add_argument("--N_emb_xyz", type=int, default=10,
47 | help="number of frequencies in xyz positional encoding")
48 | parser.add_argument("--N_emb_dir", type=int, default=4,
49 | help="number of frequencies in dir positional encoding")
50 | parser.add_argument("--N_samples", type=int, default=64,
51 | help="number of coarse samples")
52 | parser.add_argument("--N_importance", type=int, default=128,
53 | help="number of additional fine samples")
54 | parser.add_argument("--use_disp", default=False,
55 | action="store_true", help="use disparity depth sampling")
56 | parser.add_argument("--chunk", type=int, default=32 * 1024 * 4,
57 | help="chunk size to split the input to avoid OOM")
58 |
59 | parser.add_argument("--weight_path", type=str, required=True,
60 | help="pretrained checkpoint path to load")
61 |
62 | parser.add_argument("--save_depth", default=False,
63 | action="store_true", help="whether to save depth prediction")
64 | parser.add_argument(
65 | "--depth_format", type=str, default="pfm", choices=["pfm", "bytes"], help="which format to save"
66 | )
67 |
68 | return parser.parse_args()
69 |
70 |
71 | @torch.no_grad()
72 | def batched_inference(models, embeddings, rays, N_samples, N_importance, use_disp, chunk):
73 | """Do batched inference on rays using chunk."""
74 | B = rays.shape[0]
75 | results = defaultdict(list)
76 | for i in range(0, B, chunk):
77 | rendered_ray_chunks = render_rays(
78 | models,
79 | embeddings,
80 | rays[i: i + chunk],
81 | N_samples,
82 | use_disp,
83 | 0,
84 | 0,
85 | N_importance,
86 | chunk,
87 | dataset.white_back,
88 | test_time=True,
89 | )
90 |
91 | for k, v in rendered_ray_chunks.items():
92 | results[k] += [v.cpu()]
93 |
94 | for k, v in results.items():
95 | results[k] = torch.cat(v, 0)
96 | return results
97 |
98 |
99 | if __name__ == "__main__":
100 | args = get_opts()
101 | w, h = args.img_wh
102 | lpips_vgg = lpips.LPIPS(net="vgg").cuda()
103 | lpips_vgg = lpips_vgg.eval()
104 |
105 | kwargs = {"root_dir": args.root_dir, "split": args.split,
106 | "img_wh": tuple(args.img_wh), "val_num": 3}
107 | if args.dataset_name == "llff":
108 | kwargs["spheric_poses"] = args.spheric_poses
109 | dataset = dataset_dict[args.dataset_name](**kwargs)
110 |
111 | system = NeRFSystem(args)
112 |
113 | embedding_xyz = system.embedding_xyz
114 | embedding_dir = system.embedding_dir
115 | models = system.models
116 | for k in models.keys():
117 | models[k] = models[k].cuda().eval()
118 |
119 | imgs, depth_maps, psnrs, mean_lpips, ssims = [], [], [], [], []
120 | dir_name = f"results/{args.dataset_name}/{args.scene_name}"
121 | os.makedirs(dir_name, exist_ok=True)
122 |
123 | embeddings = {"xyz": embedding_xyz, "dir": embedding_dir}
124 |
125 | for i in tqdm(range(len(dataset))):
126 | sample = dataset[i]
127 | rays = sample["rays"].cuda()
128 | results = batched_inference(
129 | models, embeddings, rays, args.N_samples, args.N_importance, args.use_disp, args.chunk
130 | )
131 | typ = "fine" if "rgb_fine" in results else "coarse"
132 |
133 | img_pred = np.clip(results[f"rgb_{typ}"].view(
134 | h, w, 3).cpu().numpy(), 0, 1)
135 |
136 | if args.save_depth:
137 | depth_pred = results[f"depth_{typ}"].view(h, w).cpu().numpy()
138 | depth_maps += [depth_pred]
139 | if args.depth_format == "pfm":
140 | save_pfm(os.path.join(
141 | dir_name, f"depth_{i:03d}.pfm"), depth_pred)
142 | else:
143 | with open(os.path.join(dir_name, f"depth_{i:03d}"), "wb") as f:
144 | f.write(depth_pred.tobytes())
145 |
146 | img_pred_ = (img_pred * 255).astype(np.uint8)
147 | imgs += [img_pred_]
148 | imageio.imwrite(os.path.join(dir_name, f"{i:03d}.png"), img_pred_)
149 |
150 |         if "rgbs" in sample:
151 |             rgbs = sample["rgbs"]
152 |             img_gt = rgbs.view(h, w, 3)
153 |             img_gt_ = (img_gt.cpu().numpy() * 255).astype(np.uint8)
154 |             imageio.imwrite(os.path.join(
155 |                 dir_name, f"gt_{i:03d}.png"), img_gt_)
156 |
157 |
158 | # scale to compute lpips
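            # LPIPS (VGG) expects (1, 3, H, W) tensors scaled to [-1, 1]; note that the
            # same channel permutation is applied to both prediction and ground truth,
            # so the comparison remains consistent.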
159 | scaled_gt = img_gt * 2.0 - 1.0
160 | scaled_pred = img_pred * 2.0 - 1.0
161 | scaled_pred = torch.from_numpy(scaled_pred)
162 | lpips_val = lpips_vgg(
163 | scaled_gt[:, :, [2, 1, 0]].permute(
164 | 2, 0, 1).unsqueeze(0).cuda(),
165 | scaled_pred[:, :, [2, 1, 0]].permute(
166 | 2, 0, 1).unsqueeze(0).cuda(),
167 | )
168 |             lpips_vals.append(lpips_val.detach().squeeze().cpu().numpy())
169 | psnrs += [metrics.psnr(img_gt, img_pred).item()]
170 |
171 | # compute ssim
172 | ssim = metrics.ssim(img_gt.permute(2, 0, 1)[None], torch.from_numpy(
173 | img_pred).permute(2, 0, 1)[None])
174 | ssims.append(ssim)
175 |
176 | imageio.mimsave(os.path.join(
177 | dir_name, f"{args.scene_name}.gif"), imgs, fps=30)
178 |
179 | if args.save_depth:
180 | min_depth = np.min(depth_maps)
181 | max_depth = np.max(depth_maps)
182 |         depth_imgs = (np.stack(depth_maps) - min_depth) / \
183 |             max(max_depth - min_depth, 1e-8)
184 | depth_imgs_ = [cv2.applyColorMap(
185 | (img * 255).astype(np.uint8), cv2.COLORMAP_JET) for img in depth_imgs]
186 | imageio.mimsave(os.path.join(
187 | dir_name, f"{args.scene_name}_depth.gif"), depth_imgs_, fps=30)
188 |
189 | if psnrs:
190 | mean_psnr = np.mean(psnrs)
191 | print(f"Mean PSNR : {mean_psnr:.3f}")
192 |
193 | if ssims:
194 | mean_ssim = np.mean(ssims)
195 | print(f"Mean SSIM : {mean_ssim:.3f}")
196 |
197 |     if lpips_vals:
198 |         mean_lpips = np.mean(np.array(lpips_vals))
199 | print(f"Mean LPIPS : {mean_lpips:.3f}")
200 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class ColorLoss(nn.Module):
5 | def __init__(self, coef=1):
6 | super().__init__()
7 | self.coef = coef
8 | self.loss = nn.MSELoss(reduction="mean")
9 |
10 | def forward(self, inputs, targets):
11 |         if "rgb_coarse" in inputs:
12 | loss = self.loss(inputs["rgb_coarse"], targets)
13 | if "rgb_fine" in inputs:
14 | loss += self.loss(inputs["rgb_fine"], targets)
15 | else:
16 | loss = self.loss(inputs["rgb_fine"], targets)
17 | return self.coef * loss
18 |
19 |
20 | loss_dict = {"color": ColorLoss}
21 |
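# Usage sketch (assumed call site): the training loop looks the loss up by name, e.g.
#   loss_fn = loss_dict["color"](coef=1)
#   loss = loss_fn(results, rgbs)  # results carries "rgb_coarse" and/or "rgb_fine"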
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim_loss as dssim
3 |
4 |
5 | def mse(image_pred, image_gt, valid_mask=None, reduction="mean"):
6 | value = (image_pred - image_gt) ** 2
7 | if valid_mask is not None:
8 | value = value[valid_mask]
9 | if reduction == "mean":
10 | return torch.mean(value)
11 | return value
12 |
13 |
14 | def psnr(image_pred, image_gt, valid_mask=None, reduction="mean"):
15 | return -10 * torch.log10(mse(image_pred, image_gt, valid_mask, reduction))
16 |
17 |
18 | def ssim(image_pred, image_gt, reduction="mean"):
19 | """
20 | image_pred and image_gt: (1, 3, H, W)
21 | """
22 | dssim_ = dssim(image_pred, image_gt, 3, reduction=reduction) # dissimilarity in [0, 1]
23 | return 1 - 2 * dssim_ # in [-1, 1]
24 |
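# Minimal usage sketch (assumed tensor shapes): img_pred and img_gt are float
# tensors in [0, 1] of shape (H, W, 3).
#   psnr_val = psnr(img_pred, img_gt)                      # scalar tensor, in dB
#   ssim_val = ssim(img_pred.permute(2, 0, 1)[None],
#                   img_gt.permute(2, 0, 1)[None])         # scalar tensor in [-1, 1]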
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/cloud_nerf/f64aff501588ff7a2b3573acc1a4325cefa6ddf7/models/__init__.py
--------------------------------------------------------------------------------
/models/cloud_code.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import faiss # make faiss available
4 | import faiss.contrib.torch_utils
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | # ! CloudNeRF configs
10 | config = {}
11 | config["code_cloud"] = {}
12 | config["code_cloud"]["num_codes"] = 8192 # 8192
13 | config["code_cloud"]["num_neighbors"] = 32 # 32
14 |
15 | config["code_cloud"]["code_dim"] = 64 # 64
16 | config["code_cloud"]["dist_scale"] = 3.0
17 | config["code_regularization_lambda"] = 0.0
18 | config["code_position_lambda"] = 0
19 | config["distortion_lambda"] = 0
20 |
21 |
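# Note: find_knn below wraps a faiss GpuIndexFlatL2. The current code positions are
# added on every call and the index is reset afterwards, so each search reflects the
# latest (trainable) positions; a flat L2 index returns squared distances.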
22 | @torch.no_grad()
23 | def find_knn(gpu_index, locs, current_codes, neighbor=config["code_cloud"]["num_neighbors"]):
24 | n_points = locs.shape[0]
25 | # Search with torch GPU using pre-allocated arrays
26 | new_d_square_torch_gpu = torch.zeros(
27 | n_points, neighbor, device=locs.device, dtype=torch.float32)
28 | new_i_torch_gpu = torch.zeros(
29 | n_points, neighbor, device=locs.device, dtype=torch.int64)
30 |
31 | # update current codes
32 | gpu_index.add(current_codes)
33 |
34 | gpu_index.search(locs, neighbor, new_d_square_torch_gpu, new_i_torch_gpu)
35 | gpu_index.reset()
36 |
37 | return new_d_square_torch_gpu, new_i_torch_gpu
38 |
39 |
40 | class CodeCloud(nn.Module):
41 | def __init__(self, config, num_records, keypoints, fps_keypoints):
42 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "Building CodeCloud.")
43 | super().__init__()
44 | self.config = config
45 | self.SH_basis_dim = 9
46 | self.origin_keypoints = nn.Parameter(
47 | torch.Tensor(keypoints.float())[None, ...].repeat(num_records, 1, 1), requires_grad=False
48 | )
49 |
50 | self.codes_position = nn.Parameter(torch.Tensor(fps_keypoints.float())[
51 | None, ...].repeat(num_records, 1, 1))
52 | self.codes = nn.Parameter(torch.randn(
53 | num_records, config["num_codes"], config["code_dim"]) * 0.01)
54 |
55 | self.knn = self.init_knn()
56 |
57 | num_params = sum(p.data.nelement() for p in self.parameters())
58 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
59 | "CodeCloud done(#parameters=%d)." % num_params)
60 |
61 | def init_knn(self):
62 | faiss_cfg = faiss.GpuIndexFlatConfig()
63 | faiss_cfg.useFloat16 = True
64 | faiss_cfg.device = 0
65 |
66 | return faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), 3, faiss_cfg)
67 |
68 | def query(self, indices, query_points):
69 | """
70 | Args:
71 | indices: tensor, (batch_size,)
72 | query_points: tensor, (batch_size, num_points, 3)
73 | Returns:
74 |             query_codes: tensor, (num_points, code_dim), codes interpolated
75 |                 from the num_neighbors nearest code positions by
76 |                 inverse-distance weighting
77 |         """
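        # Interpolation sketch: with d_k the distance to the k-th retrieved code
        # position, the weights are w_k = d_k**(-dist_scale) / sum_j d_j**(-dist_scale),
        # and the query code is the weighted sum of the K = num_neighbors neighbor codes.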
78 | batch_codes_position = self.codes_position[indices]
79 |
80 | # ! NO GRAD: Need to recompute the square distance for gradients
81 | _, new_i_torch_gpu = find_knn(
82 | self.knn, query_points[0], batch_codes_position[0])
83 |
84 | square_dist = (query_points.unsqueeze(2) - self.codes_position[indices][:, new_i_torch_gpu, :]).pow(2).sum(
85 | dim=-1
86 | ) + 1e-16
87 |
88 | weight = 1.0 / (torch.sqrt(square_dist) ** self.config["dist_scale"])
89 | weight = weight / weight.sum(dim=-1, keepdim=True)
90 |
91 | query_codes = torch.matmul(weight[0].unsqueeze(
92 | 1), self.codes[indices][:, new_i_torch_gpu, :][0]).squeeze(1)
93 |
94 | return query_codes
95 |
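    # get_proposal turns point-cloud proximity into coarse sampling weights: each ray
    # sample takes the squared distance to its num_neighbors-th nearest original
    # keypoint (clamped at 0.2) and is weighted by the inverse square root of that
    # distance, normalized along the ray.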
96 | def get_proposal(self, pts):
97 | n_rays, n_samples, _ = pts.shape
98 | new_d_torch_gpu, _ = find_knn(
99 | self.knn, pts.flatten(0, 1), self.origin_keypoints[0])
100 |
101 | farthest_d = new_d_torch_gpu[:, -1].view(n_rays, n_samples)
102 |
103 | farthest_d[farthest_d > 0.2] = 0.2
104 |
105 | weight = 1.0 / (torch.sqrt(farthest_d) + 1e-10)
106 | weight = weight / weight.sum(dim=-1, keepdim=True)
107 | return weight
108 |
109 |
110 | class IM_Decoder(nn.Module):
111 | def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, skips=[4]):
112 | print(datetime.now().strftime(
113 | "%Y-%m-%d %H:%M:%S"), "Building IM-decoder.")
114 | super(IM_Decoder, self).__init__()
115 | self.D = D
116 | self.W = W
117 | self.in_channels_xyz = in_channels_xyz
118 | self.in_channels_dir = in_channels_dir
119 | self.skips = skips
120 |
121 | # xyz encoding layers
122 | for i in range(D):
123 | if i == 0:
124 | layer = nn.Linear(in_channels_xyz, W)
125 | elif i in skips:
126 | layer = nn.Linear(W + in_channels_xyz, W)
127 | else:
128 | layer = nn.Linear(W, W)
129 | layer = nn.Sequential(layer, nn.ReLU(True))
130 | setattr(self, f"xyz_encoding_{i+1}", layer)
131 | self.xyz_encoding_final = nn.Linear(W, W)
132 |
133 | # direction encoding layers
134 | self.dir_encoding = nn.Sequential(
135 | nn.Linear(W + in_channels_dir, W // 2), nn.ReLU(True))
136 |
137 | # output layers
138 | self.sigma = nn.Linear(W, 1)
139 | self.rgb = nn.Linear(W // 2, 3)
140 |
141 | num_params = sum(p.data.nelement() for p in self.parameters())
142 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
143 | "IM decoder done(#parameters=%d)." % num_params)
144 |
145 | def forward(self, x):
146 | input_xyz, input_dir = torch.split(
147 | x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)
148 | xyz_ = input_xyz
149 | for i in range(self.D):
150 | if i in self.skips:
151 | xyz_ = torch.cat([input_xyz, xyz_], -1)
152 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
153 |
154 | sigma = self.sigma(xyz_)
155 |
156 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
157 |
158 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1)
159 | dir_encoding = self.dir_encoding(dir_encoding_input)
160 | rgb = self.rgb(dir_encoding)
161 |
162 | out = torch.cat([rgb, sigma], -1)
163 |
164 | return out
165 |
166 |
167 | class CloudNeRF(nn.Module):
168 | def __init__(self, keypoints, fps_kps, input_ch, input_ch_views, num_records=1):
169 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "Building network.")
170 | super().__init__()
171 | global config
172 | self.config = config
173 |
174 | self.code_cloud = CodeCloud(
175 | config["code_cloud"], num_records, keypoints, fps_kps)
176 | self.decoder = IM_Decoder(
177 | D=4,
178 | W=128,
179 | in_channels_xyz=config["code_cloud"]["code_dim"] + input_ch,
180 | in_channels_dir=input_ch_views,
181 | skips=[2],
182 | )
183 |
184 | num_params = sum(p.data.nelement() for p in self.parameters())
185 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
186 | "Network done(#parameters=%d)." % num_params)
187 |
188 | def forward(self, indices, query_points, xyzdir_embedded):
189 | """
190 | Args:
191 | indices: tensor, (batch_size,)
192 |             query_points: tensor, (num_points, 3), sample positions
193 |         Returns:
194 |             pred_sd: tensor, (num_points, 4), raw rgb (pre-sigmoid) and sigma
195 | """
196 | query_codes = self.code_cloud.query(indices, query_points[None, ...])
197 |
198 | batch_input = torch.cat([query_codes, xyzdir_embedded], dim=-1)
199 | pred_sd = self.decoder(batch_input)
200 |
201 | return pred_sd
202 |
203 |
204 | ################################## SH FEATURES ##############################################
205 | class SHCodeCloud(nn.Module):
206 | def __init__(self, config, num_records, keypoints, fps_keypoints):
207 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'Building CodeCloud.')
208 | super().__init__()
209 | self.config = config
210 | self.SH_basis_dim = 9
211 | self.origin_keypoints = nn.Parameter(torch.Tensor(
212 | keypoints.float())[None, ...].repeat(num_records, 1, 1), requires_grad=False)
213 |
214 | self.codes_position = nn.Parameter(torch.Tensor(
215 | fps_keypoints.float())[None, ...].repeat(num_records, 1, 1))
216 | self.codes = nn.Parameter(torch.randn(
217 | num_records, config['num_codes'], config['code_dim']) * 0.01)
218 | self.sh_codes = nn.Parameter(torch.randn(
219 | num_records, config['num_codes'], config['code_dim'] * self.SH_basis_dim) * 0.01)
220 |
221 | self.knn = self.init_knn()
222 |
223 | num_params = sum(p.data.nelement() for p in self.parameters())
224 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
225 | 'CodeCloud done(#parameters=%d).' % num_params)
226 |
227 | def init_knn(self):
228 | faiss_cfg = faiss.GpuIndexFlatConfig()
229 | faiss_cfg.useFloat16 = True
230 | faiss_cfg.device = 0
231 |
232 | return faiss.GpuIndexFlatL2(
233 | faiss.StandardGpuResources(), 3, faiss_cfg)
234 |
235 | def query(self, indices, query_points, viewdirs):
236 | """
237 | Args:
238 | indices: tensor, (batch_size,)
239 |             query_points: tensor, (batch_size, num_points, 3)
240 |             viewdirs: tensor, (num_points, 3), unit viewing directions
241 |         Returns:
242 |             query_codes: tensor, (num_points, code_dim)
243 |             query_sh_codes: tensor, (num_points, code_dim), SH-modulated features
244 |         """
245 | batch_codes_position = self.codes_position[indices]
246 |
247 | # ! NO GRAD: Need to recompute the square distance for gradients
248 | _, new_i_torch_gpu = find_knn(
249 | self.knn, query_points[0], batch_codes_position[0])
250 |
251 | sh_feat = self.sh_codes[indices][:, new_i_torch_gpu, :][0]
252 | sh_feat = sh_feat.reshape(-1, self.config['num_neighbors'],
253 | self.config['code_dim'], self.SH_basis_dim)
254 | sh_mult = eval_sh_bases(self.SH_basis_dim, viewdirs).unsqueeze(
255 | 1).repeat(1, self.config['num_neighbors'], 1)
256 | agg_sh_feat = torch.sum(sh_mult.unsqueeze(-2) * sh_feat, dim=-1)
257 |
258 | square_dist = (query_points.unsqueeze(
259 | 2) - self.codes_position[indices][:, new_i_torch_gpu, :]).pow(2).sum(dim=-1) + 1e-16
260 |
261 | weight = 1.0 / (torch.sqrt(square_dist) ** self.config['dist_scale'])
262 | weight = weight / weight.sum(dim=-1, keepdim=True)
263 |
264 | query_codes = torch.matmul(
265 | weight[0].unsqueeze(1), self.codes[indices][:, new_i_torch_gpu, :][0]).squeeze(1)
266 |
267 | query_sh_codes = torch.matmul(
268 | weight[0].unsqueeze(1), agg_sh_feat).squeeze(1)
269 |
270 | return query_codes, query_sh_codes
271 |
272 | def get_proposal(self, pts):
273 | n_rays, n_samples, _ = pts.shape
274 | new_d_torch_gpu, _ = find_knn(
275 | self.knn, pts.flatten(0, 1), self.origin_keypoints[0])
276 |
277 | farthest_d = new_d_torch_gpu[:, -1].view(n_rays, n_samples)
278 |
279 | farthest_d[farthest_d > 0.2] = 0.2
280 |
281 | weight = 1.0 / (torch.sqrt(farthest_d) + 1e-10)
282 | weight = weight / weight.sum(dim=-1, keepdim=True)
283 | return weight
284 |
285 |
286 | class SH_IM_Decoder(nn.Module):
287 | def __init__(self,
288 | D=8, W=256, code_dim=64,
289 | in_channels_xyz=63, in_channels_dir=27,
290 | skips=[4]):
291 | print(datetime.now().strftime(
292 | '%Y-%m-%d %H:%M:%S'), 'Building IM-decoder.')
293 | super(SH_IM_Decoder, self).__init__()
294 | self.D = D
295 | self.W = W
296 | self.in_channels_xyz = in_channels_xyz
297 | self.in_channels_dir = in_channels_dir
298 | self.skips = skips
299 | self.code_dim = code_dim
300 |
301 | # xyz encoding layers
302 | for i in range(D):
303 | if i == 0:
304 | layer = nn.Linear(in_channels_xyz + self.code_dim, W)
305 | elif i in skips:
306 | layer = nn.Linear(W + in_channels_xyz + self.code_dim, W)
307 | else:
308 | layer = nn.Linear(W, W)
309 | layer = nn.Sequential(layer, nn.ReLU(True))
310 | setattr(self, f"xyz_encoding_{i+1}", layer)
311 | self.xyz_encoding_final = nn.Linear(W, W)
312 |
313 | # direction encoding layers
314 | self.dir_encoding = nn.Sequential(
315 | nn.Linear(W + in_channels_dir + self.code_dim, W // 2),
316 | nn.ReLU(True))
317 |
318 | # output layers
319 | self.sigma = nn.Linear(W, 1)
320 | self.rgb = nn.Linear(W // 2, 3)
321 |
322 | num_params = sum(p.data.nelement() for p in self.parameters())
323 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
324 | 'IM decoder done(#parameters=%d).' % num_params)
325 |
326 | def forward(self, codes, sh_codes, xyzdir_embed):
327 | input_xyz, input_dir = \
328 | torch.split(xyzdir_embed, [self.in_channels_xyz,
329 | self.in_channels_dir], dim=-1)
330 | input_xyz = torch.cat([codes, input_xyz], dim=-1)
331 | xyz_ = input_xyz
332 | for i in range(self.D):
333 | if i in self.skips:
334 | xyz_ = torch.cat([input_xyz, xyz_], -1)
335 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
336 |
337 | sigma = self.sigma(xyz_)
338 |
339 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
340 |
341 | dir_encoding_input = torch.cat(
342 | [xyz_encoding_final, sh_codes, input_dir], -1)
343 | dir_encoding = self.dir_encoding(dir_encoding_input)
344 | rgb = self.rgb(dir_encoding)
345 |
346 | out = torch.cat([rgb, sigma], -1)
347 |
348 | return out
349 |
350 |
351 | class SHCloudNeRF(nn.Module):
352 | def __init__(self, keypoints, fps_kps, input_ch, input_ch_views, num_records=1):
353 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'Building network.')
354 | super().__init__()
355 | global config
356 | self.config = config
357 |
358 | self.code_cloud = SHCodeCloud(
359 | config['code_cloud'], num_records, keypoints, fps_kps)
360 | self.decoder = SH_IM_Decoder(D=4, W=128, code_dim=config['code_cloud']['code_dim'],
361 | in_channels_xyz=input_ch, in_channels_dir=input_ch_views,
362 | skips=[2])
363 |
364 | num_params = sum(p.data.nelement() for p in self.parameters())
365 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
366 | 'Network done(#parameters=%d).' % num_params)
367 |
368 | def forward(self, indices, query_points, viewdirs, xyzdir_embedded):
369 | """
370 | Args:
371 | indices: tensor, (batch_size,)
372 |             query_points: tensor, (num_points, 3), sample positions
373 |         Returns:
374 |             pred_sd: tensor, (num_points, 4), raw rgb (pre-sigmoid) and sigma
375 | """
376 | query_codes, query_sh_codes = self.code_cloud.query(
377 | indices, query_points[None, ...], viewdirs)
378 | pred_sd = self.decoder(query_codes, query_sh_codes, xyzdir_embedded)
379 |
380 | return pred_sd
381 |
382 |
383 | SH_C0 = 0.28209479177387814
384 | SH_C1 = 0.4886025119029199
385 | SH_C2 = [
386 | 1.0925484305920792,
387 | -1.0925484305920792,
388 | 0.31539156525252005,
389 | -1.0925484305920792,
390 | 0.5462742152960396
391 | ]
392 | SH_C3 = [
393 | -0.5900435899266435,
394 | 2.890611442640554,
395 | -0.4570457994644658,
396 | 0.3731763325901154,
397 | -0.4570457994644658,
398 | 1.445305721320277,
399 | -0.5900435899266435
400 | ]
401 | SH_C4 = [
402 | 2.5033429417967046,
403 | -1.7701307697799304,
404 | 0.9461746957575601,
405 | -0.6690465435572892,
406 | 0.10578554691520431,
407 | -0.6690465435572892,
408 | 0.47308734787878004,
409 | -1.7701307697799304,
410 | 0.6258357354491761,
411 | ]
412 |
413 |
414 | def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
415 | """
416 | Evaluate spherical harmonics bases at unit directions,
417 | without taking linear combination.
418 |     At each point, the final result may then be
419 |     obtained through simple multiplication.
420 |     :param basis_dim: int SH basis dim. Currently, square numbers from 1 to 25 are supported
421 | :param dirs: torch.Tensor (..., 3) unit directions
422 | :return: torch.Tensor (..., basis_dim)
423 | """
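    # Usage sketch (hypothetical values): dirs must be unit-norm, e.g.
    #   dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
    #   basis = eval_sh_bases(9, dirs)  # (1024, 9): SH bands l = 0, 1, 2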
424 | result = torch.empty(
425 | (*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device)
426 | result[..., 0] = SH_C0
427 | if basis_dim > 1:
428 | x, y, z = dirs.unbind(-1)
429 | result[..., 1] = -SH_C1 * y
430 | result[..., 2] = SH_C1 * z
431 | result[..., 3] = -SH_C1 * x
432 | if basis_dim > 4:
433 | xx, yy, zz = x * x, y * y, z * z
434 | xy, yz, xz = x * y, y * z, x * z
435 | result[..., 4] = SH_C2[0] * xy
436 | result[..., 5] = SH_C2[1] * yz
437 | result[..., 6] = SH_C2[2] * (2.0 * zz - xx - yy)
438 | result[..., 7] = SH_C2[3] * xz
439 | result[..., 8] = SH_C2[4] * (xx - yy)
440 |
441 | if basis_dim > 9:
442 | result[..., 9] = SH_C3[0] * y * (3 * xx - yy)
443 | result[..., 10] = SH_C3[1] * xy * z
444 | result[..., 11] = SH_C3[2] * y * (4 * zz - xx - yy)
445 | result[..., 12] = SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
446 | result[..., 13] = SH_C3[4] * x * (4 * zz - xx - yy)
447 | result[..., 14] = SH_C3[5] * z * (xx - yy)
448 | result[..., 15] = SH_C3[6] * x * (xx - 3 * yy)
449 |
450 | if basis_dim > 16:
451 | result[..., 16] = SH_C4[0] * xy * (xx - yy)
452 | result[..., 17] = SH_C4[1] * yz * (3 * xx - yy)
453 | result[..., 18] = SH_C4[2] * xy * (7 * zz - 1)
454 | result[..., 19] = SH_C4[3] * yz * (7 * zz - 3)
455 | result[..., 20] = SH_C4[4] * (zz * (35 * zz - 30) + 3)
456 | result[..., 21] = SH_C4[5] * xz * (7 * zz - 3)
457 | result[..., 22] = SH_C4[6] * (xx - yy) * (7 * zz - 1)
458 | result[..., 23] = SH_C4[7] * xz * (xx - 3 * yy)
459 | result[..., 24] = SH_C4[8] * \
460 | (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
461 | return result
462 |
--------------------------------------------------------------------------------
/models/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Embedding(nn.Module):
6 | def __init__(self, N_freqs, logscale=True):
7 | """
8 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
9 |         N_freqs: number of frequency bands; the input has 3 channels (for both xyz and direction)
10 | """
11 | super().__init__()
12 | self.N_freqs = N_freqs
13 | self.funcs = [torch.sin, torch.cos]
14 |
15 | if logscale:
16 | self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
17 | else:
18 | self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)
19 |
20 | def forward(self, x):
21 | """
22 | Embeds x to (x, sin(2^k x), cos(2^k x), ...)
23 | Different from the paper, "x" is also in the output
24 | See https://github.com/bmild/nerf/issues/12
25 |
26 | Inputs:
27 | x: (B, f)
28 |
29 | Outputs:
30 | out: (B, 2*f*N_freqs+f)
31 | """
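        # For example, xyz (f=3) with N_freqs=10 gives 3 + 3*2*10 = 63 channels, and
        # directions with N_freqs=4 give 3 + 3*2*4 = 27 channels, matching the default
        # in_channels_xyz / in_channels_dir of the decoders.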
32 | out = [x]
33 | for freq in self.freq_bands:
34 | for func in self.funcs:
35 | out += [func(freq * x)]
36 |
37 | return torch.cat(out, -1)
38 |
39 |
40 | class NeRF(nn.Module):
41 | def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, skips=[4]):
42 | """
43 | D: number of layers for density (sigma) encoder
44 | W: number of hidden units in each layer
45 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default)
46 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default)
47 | skips: add skip connection in the Dth layer
48 | """
49 | super(NeRF, self).__init__()
50 | self.D = D
51 | self.W = W
52 | self.in_channels_xyz = in_channels_xyz
53 | self.in_channels_dir = in_channels_dir
54 | self.skips = skips
55 |
56 | # xyz encoding layers
57 | for i in range(D):
58 | if i == 0:
59 | layer = nn.Linear(in_channels_xyz, W)
60 | elif i in skips:
61 | layer = nn.Linear(W + in_channels_xyz, W)
62 | else:
63 | layer = nn.Linear(W, W)
64 | layer = nn.Sequential(layer, nn.ReLU(True))
65 | setattr(self, f"xyz_encoding_{i+1}", layer)
66 | self.xyz_encoding_final = nn.Linear(W, W)
67 |
68 | # direction encoding layers
69 | self.dir_encoding = nn.Sequential(nn.Linear(W + in_channels_dir, W // 2), nn.ReLU(True))
70 |
71 | # output layers
72 | self.sigma = nn.Linear(W, 1)
73 | self.rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid())
74 |
75 | def forward(self, x, sigma_only=False):
76 | """
77 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet).
78 | For rendering this ray, please see rendering.py
79 |
80 | Inputs:
81 | x: (B, self.in_channels_xyz(+self.in_channels_dir))
82 | the embedded vector of position and direction
83 | sigma_only: whether to infer sigma only. If True,
84 | x is of shape (B, self.in_channels_xyz)
85 |
86 | Outputs:
87 |             if sigma_only:
88 | sigma: (B, 1) sigma
89 | else:
90 | out: (B, 4), rgb and sigma
91 | """
92 | if not sigma_only:
93 | input_xyz, input_dir = torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)
94 | else:
95 | input_xyz = x
96 |
97 | xyz_ = input_xyz
98 | for i in range(self.D):
99 | if i in self.skips:
100 | xyz_ = torch.cat([input_xyz, xyz_], -1)
101 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
102 |
103 | sigma = self.sigma(xyz_)
104 | if sigma_only:
105 | return sigma
106 |
107 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
108 |
109 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1)
110 | dir_encoding = self.dir_encoding(dir_encoding_input)
111 | rgb = self.rgb(dir_encoding)
112 |
113 | out = torch.cat([rgb, sigma], -1)
114 |
115 | return out
116 |
--------------------------------------------------------------------------------
/models/rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, reduce, repeat
4 |
5 |
6 | __all__ = ["render_rays"]
7 |
8 |
9 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
10 | """
11 | Sample @N_importance samples from @bins with distribution defined by @weights.
12 | Inputs:
13 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
14 | weights: (N_rays, N_samples_)
15 | N_importance: the number of samples to draw from the distribution
16 | det: deterministic or not
17 | eps: a small number to prevent division by zero
18 | Outputs:
19 | samples: (N_rays, N_importance) the sampled samples
20 | """
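    # Inverse-transform sampling sketch: normalize the weights into a pdf, take the
    # cumulative sum to get a cdf, draw u ~ U[0, 1) (or an evenly spaced grid when
    # det=True), locate each u's bin via searchsorted, and place the sample by linear
    # interpolation within that bin.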
21 | N_rays, N_samples_ = weights.shape
22 | weights = weights + eps # prevent division by zero (don't do inplace op!)
23 | pdf = weights / reduce(weights, "n1 n2 -> n1 1",
24 | "sum") # (N_rays, N_samples_)
25 | # (N_rays, N_samples), cumulative distribution function
26 | cdf = torch.cumsum(pdf, -1)
27 | # (N_rays, N_samples_+1)
28 | cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)
29 | # padded to 0~1 inclusive
30 |
31 | if det:
32 | u = torch.linspace(0, 1, N_importance, device=bins.device)
33 | u = u.expand(N_rays, N_importance)
34 | else:
35 | u = torch.rand(N_rays, N_importance, device=bins.device)
36 | u = u.contiguous()
37 |
38 | inds = torch.searchsorted(cdf, u, right=True)
39 | below = torch.clamp_min(inds - 1, 0)
40 | above = torch.clamp_max(inds, N_samples_)
41 |
42 | inds_sampled = rearrange(torch.stack(
43 | [below, above], -1), "n1 n2 c -> n1 (n2 c)", c=2)
44 | cdf_g = rearrange(torch.gather(cdf, 1, inds_sampled),
45 | "n1 (n2 c) -> n1 n2 c", c=2)
46 | bins_g = rearrange(torch.gather(bins, 1, inds_sampled),
47 | "n1 (n2 c) -> n1 n2 c", c=2)
48 |
49 | denom = cdf_g[..., 1] - cdf_g[..., 0]
50 | denom[denom < eps] = 1 # denom equals 0 means a bin has weight 0,
51 | # in which case it will not be sampled
52 | # anyway, therefore any value for it is fine (set to 1 here)
53 |
54 | samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / \
55 | denom * (bins_g[..., 1] - bins_g[..., 0])
56 | return samples
57 |
58 |
59 | def render_rays(
60 | models,
61 | embeddings,
62 | rays,
63 | N_samples=64,
64 | use_disp=False,
65 | perturb=0,
66 | noise_std=1,
67 | N_importance=0,
68 | chunk=1024 * 32,
69 | white_back=False,
70 | test_time=False,
71 | **kwargs,
72 | ):
73 | """
74 | Render rays by computing the output of @model applied on @rays
75 | Inputs:
76 | models: list of NeRF models (coarse and fine) defined in nerf.py
77 | embeddings: list of embedding models of origin and direction defined in nerf.py
78 | rays: (N_rays, 3+3+2), ray origins and directions, near and far depths
79 | N_samples: number of coarse samples per ray
80 | use_disp: whether to sample in disparity space (inverse depth)
81 | perturb: factor to perturb the sampling position on the ray (for coarse model only)
82 | noise_std: factor to perturb the model's prediction of sigma
83 | N_importance: number of fine samples per ray
84 | chunk: the chunk size in batched inference
85 | white_back: whether the background is white (dataset dependent)
86 | test_time: whether it is test (inference only) or not. If True, it will not do inference
87 | on coarse rgb to save time
88 | Outputs:
89 | result: dictionary containing final rgb and depth maps for coarse and fine models
90 | """
91 |
92 | def inference(results, model, typ, xyz, z_vals, test_time=False, **kwargs):
93 | """
94 | Helper function that performs model inference.
95 | Inputs:
96 | results: a dict storing all results
97 | model: NeRF model (coarse or fine)
98 | typ: 'coarse' or 'fine'
99 | xyz: (N_rays, N_samples_, 3) sampled positions
100 | N_samples_ is the number of sampled points in each ray;
101 | = N_samples for coarse model
102 | = N_samples+N_importance for fine model
103 | z_vals: (N_rays, N_samples_) depths of the sampled positions
104 | test_time: test time or not
105 | Outputs:
106 | if weights_only:
107 | weights: (N_rays, N_samples_): weights of each sample
108 | else:
109 | rgb_final: (N_rays, 3) the final rgb image
110 | depth_final: (N_rays) depth map
111 | weights: (N_rays, N_samples_): weights of each sample
112 | """
113 | N_samples_ = xyz.shape[1]
114 | xyz_ = rearrange(xyz, "n1 n2 c -> (n1 n2) c") # (N_rays*N_samples_, 3)
115 |
116 | # Perform model inference to get rgb and raw sigma
117 | B = xyz_.shape[0]
118 | out_chunks = []
119 | if typ == "coarse" and test_time and "fine" in models:
120 | for i in range(0, B, chunk):
121 | xyz_embedded = embedding_xyz(xyz_[i: i + chunk])
122 | out_chunks += [model(xyz_embedded, sigma_only=True)]
123 |
124 | out = torch.cat(out_chunks, 0)
125 | sigmas = rearrange(out, "(n1 n2) 1 -> n1 n2",
126 | n1=N_rays, n2=N_samples_)
127 | else: # infer rgb and sigma and others
128 | dir_embedded_ = repeat(
129 | dir_embedded, "n1 c -> (n1 n2) c", n2=N_samples_)
130 | if kwargs["use_sh_feat"]:
131 | viewdirs = repeat(kwargs['viewdirs'], 'n1 c -> (n1 n2) c', n2=N_samples_)
132 | # (N_rays*N_samples_, embed_dir_channels)
133 | for i in range(0, B, chunk):
134 | xyz_embedded = embedding_xyz(xyz_[i: i + chunk])
135 | xyzdir_embedded = torch.cat(
136 | [xyz_embedded, dir_embedded_[i: i + chunk]], 1)
137 | # out_chunks += [model(xyzdir_embedded, sigma_only=False)]
138 | # ! cloud nerf fwd
139 | indices = torch.zeros(1, dtype=torch.long).cuda()
140 | if kwargs["use_sh_feat"]:
141 | out_chunks += [model(indices, xyz_[i:i + chunk], viewdirs[i:i + chunk],
142 | xyzdir_embedded)]
143 | else:
144 | out_chunks += [model(indices, xyz_[i: i +
145 | chunk], xyzdir_embedded)]
146 |
147 | out = torch.cat(out_chunks, 0)
148 | # out = out.view(N_rays, N_samples_, 4)
149 | out = rearrange(out, "(n1 n2) c -> n1 n2 c",
150 | n1=N_rays, n2=N_samples_, c=4)
151 | # (N_rays, N_samples_, 3) # ! APPLY SIGMOID HERE NOT IN THE NETWORK
152 | rgbs = torch.sigmoid(out[..., :3])
153 | sigmas = out[..., 3] # (N_rays, N_samples_)
154 |
155 | # Convert these values using volume rendering (Section 4)
156 | deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1)
157 | # (N_rays, 1) the last delta is infinity
158 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1])
159 | deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_)
160 |
161 | # compute alpha by the formula (3)
162 | noise = torch.randn_like(sigmas) * noise_std
163 | # (N_rays, N_samples_)
164 | # alphas = 1 - torch.exp(-deltas * torch.relu(sigmas + noise))
165 | density_bias = -1
166 | alphas = 1 - torch.exp(-deltas *
167 | F.softplus(sigmas + noise + density_bias))
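        # i.e. alpha_i = 1 - exp(-delta_i * softplus(sigma_i + noise - 1)); the -1 bias
        # shifts the softplus so that raw sigmas near zero yield small densities
        # (a smooth variant of the ReLU activation commented out above).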
168 |
169 | alphas_shifted = torch.cat([torch.ones_like(
170 | alphas[:, :1]), 1 - alphas + 1e-10], -1) # [1, 1-a1, 1-a2, ...]
171 | # (N_rays, N_samples_)
172 | weights = alphas * torch.cumprod(alphas_shifted[:, :-1], -1)
173 | # (N_rays), the accumulated opacity along the rays
174 | weights_sum = reduce(weights, "n1 n2 -> n1", "sum")
175 | # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically
176 |
177 | results[f"weights_{typ}"] = weights
178 | results[f"opacity_{typ}"] = weights_sum
179 | results[f"z_vals_{typ}"] = z_vals
180 | if test_time and typ == "coarse" and "fine" in models:
181 | return
182 |
183 | rgb_map = reduce(rearrange(weights, "n1 n2 -> n1 n2 1")
184 | * rgbs, "n1 n2 c -> n1 c", "sum")
185 | depth_map = reduce(weights * z_vals, "n1 n2 -> n1", "sum")
186 |
187 | if white_back:
188 | rgb_map += 1 - weights_sum.unsqueeze(1)
189 |
190 | results[f"rgb_{typ}"] = rgb_map
191 | results[f"depth_{typ}"] = depth_map
192 |
193 | return
194 |
195 | embedding_xyz, embedding_dir = embeddings["xyz"], embeddings["dir"]
196 |
197 | # Decompose the inputs #! Note that rays_o,d are ndc, viewdirs are normalized world space
198 | N_rays = rays.shape[0]
199 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
200 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
201 | viewdirs = rays[:, 8:] # (N_rays, 3)
202 | # Embed direction
203 | # (N_rays, embed_dir_channels)
204 | if kwargs["use_sh_feat"]:
205 | kwargs['viewdirs'] = viewdirs
206 | dir_embedded = embedding_dir(kwargs.get("view_dir", viewdirs))
207 |
208 | rays_o = rearrange(rays_o, "n1 c -> n1 1 c")
209 | rays_d = rearrange(rays_d, "n1 c -> n1 1 c")
210 |
211 | # Sample depth points
212 | z_steps = torch.linspace(
213 | 0, 1, N_samples, device=rays.device) # (N_samples)
214 | if not use_disp: # use linear sampling in depth space
215 | z_vals = near * (1 - z_steps) + far * z_steps
216 | else: # use linear sampling in disparity space
217 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
218 |
219 | z_vals = z_vals.expand(N_rays, N_samples)
220 |
221 | if perturb > 0: # perturb sampling depths (z_vals)
222 | # (N_rays, N_samples-1) interval mid points
223 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
224 | # get intervals between samples
225 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
226 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
227 |
228 | perturb_rand = perturb * torch.rand_like(z_vals)
229 | z_vals = lower + (upper - lower) * perturb_rand
230 |
231 | xyz_coarse = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1")
232 | results = {}
233 |
234 |     if N_importance == 0:  # no fine samples: run the coarse model on these points
235 | inference(results, models["coarse"], "coarse",
236 | xyz_coarse, z_vals, test_time, **kwargs)
237 | else:
238 | with torch.no_grad():
239 | results["weights_coarse"] = models["fine"].code_cloud.get_proposal(
240 | xyz_coarse)
241 |
242 | # breakpoint()
243 | # np.save('code_ndc.npy', models['fine'].code_cloud.origin_keypoints.cpu().numpy())
244 | if N_importance > 0: # sample points for fine model
245 | # (N_rays, N_samples-1) interval mid points
246 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
247 | z_vals_ = sample_pdf(
248 | z_vals_mid, results["weights_coarse"][:, 1:-1].detach(), N_importance, det=(perturb == 0))
249 |         # detach so that grad doesn't propagate to weights_coarse from here
250 |
251 | z_vals = torch.sort(torch.cat([z_vals, z_vals_], -1), -1)[0]
252 | # combine coarse and fine samples
253 |
254 | xyz_fine = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1")
255 |
256 | inference(results, models["fine"], "fine",
257 | xyz_fine, z_vals, test_time, **kwargs)
258 |
259 | return results
260 |
--------------------------------------------------------------------------------
/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_opts():
5 | parser = argparse.ArgumentParser()
6 |
7 | parser.add_argument(
8 | "--root_dir",
9 | type=str,
10 | default="/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego",
11 | help="root directory of dataset",
12 | )
13 | parser.add_argument(
14 | "--dataset_name", type=str, default="blender", choices=["blender", "llff"], help="which dataset to train/val"
15 | )
16 | parser.add_argument(
17 | "--img_wh", nargs="+", type=int, default=[800, 800], help="resolution (img_w, img_h) of the image"
18 | )
19 | parser.add_argument(
20 | "--spheric_poses",
21 | default=False,
22 | action="store_true",
23 | help="whether images are taken in spheric poses (for llff)",
24 | )
25 |
26 | parser.add_argument(
27 | "--not_use_mvs",
28 | default=False,
29 | action="store_true",
30 |         help="whether to exclude the MVS point cloud",
31 | )
32 |
33 | parser.add_argument(
34 | "--use_sh_feat",
35 | default=False,
36 | action="store_true",
37 |         help="whether to use spherical harmonics features",
38 | )
39 |
40 | parser.add_argument("--N_emb_xyz", type=int, default=10,
41 | help="number of frequencies in xyz positional encoding")
42 | parser.add_argument("--N_emb_dir", type=int, default=4,
43 | help="number of frequencies in dir positional encoding")
44 | parser.add_argument("--N_samples", type=int, default=128,
45 | help="number of coarse samples")
46 | parser.add_argument("--N_importance", type=int, default=0,
47 | help="number of additional fine samples")
48 | parser.add_argument("--use_disp", default=False,
49 | action="store_true", help="use disparity depth sampling")
50 | parser.add_argument("--perturb", type=float, default=1.0,
51 | help="factor to perturb depth sampling points")
52 | parser.add_argument("--noise_std", type=float, default=1.0,
53 | help="std dev of noise added to regularize sigma")
54 |
55 | parser.add_argument("--batch_size", type=int,
56 | default=1024, help="batch size")
57 | parser.add_argument("--chunk", type=int, default=32 *
58 | 1024, help="chunk size to split the input to avoid OOM")
59 | parser.add_argument("--num_epochs", type=int, default=16,
60 | help="number of training epochs")
61 | parser.add_argument("--num_gpus", type=int,
62 | default=1, help="number of gpus")
63 |
64 | parser.add_argument(
65 | "--ckpt_path", type=str, default=None, help="pretrained checkpoint to load (including optimizers, etc)"
66 | )
67 | parser.add_argument(
68 | "--prefixes_to_ignore",
69 | nargs="+",
70 | type=str,
71 | default=["loss"],
72 | help="the prefixes to ignore in the checkpoint state dict",
73 | )
74 | parser.add_argument(
75 | "--weight_path", type=str, default=None, help="pretrained model weight to load (do not load optimizers, etc)"
76 | )
77 |
78 | parser.add_argument(
79 | "--optimizer", type=str, default="adam", help="optimizer type", choices=["sgd", "adam", "radam", "ranger"]
80 | )
81 | parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
82 | parser.add_argument("--momentum", type=float,
83 | default=0.9, help="learning rate momentum")
84 | parser.add_argument("--weight_decay", type=float,
85 | default=0, help="weight decay")
86 | parser.add_argument(
87 | "--lr_scheduler", type=str, default="steplr", help="scheduler type", choices=["steplr", "cosine", "poly"]
88 | )
89 | # params for warmup, only applied when optimizer == 'sgd' or 'adam'
90 | parser.add_argument(
91 | "--warmup_multiplier", type=float, default=1.0, help="lr is multiplied by this factor after --warmup_epochs"
92 | )
93 | parser.add_argument(
94 |         "--warmup_epochs", type=int, default=0, help="number of epochs over which to gradually warm up (increase) the learning rate"
95 | )
96 | ###########################
97 | # params for steplr
98 | parser.add_argument("--decay_step", nargs="+", type=int,
99 | default=[20], help="scheduler decay step")
100 | parser.add_argument("--decay_gamma", type=float,
101 | default=0.1, help="learning rate decay amount")
102 | ###########################
103 | # params for poly
104 | parser.add_argument("--poly_exp", type=float, default=0.9,
105 | help="exponent for polynomial learning rate decay")
106 | ###########################
107 |
108 | parser.add_argument("--exp_name", type=str,
109 | default="exp", help="experiment name")
110 |
111 | return parser.parse_args()
112 |
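# Example invocation (hypothetical paths and values; adjust to your setup):
#   python train.py --dataset_name llff --root_dir /path/to/llff/fern \
#       --img_wh 504 378 --num_epochs 16 --batch_size 1024 --exp_name fern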
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/__init__.py:
--------------------------------------------------------------------------------
1 | import pointnet2_ops.pointnet2_modules
2 | import pointnet2_ops.pointnet2_utils
3 | from pointnet2_ops._version import __version__
4 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5 | const int nsample);
6 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cmath>
7 |
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 |
11 | #include <vector>
12 |
13 | #define TOTAL_THREADS 512
14 |
15 | inline int opt_n_threads(int work_size) {
16 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17 |
18 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
19 | }
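// e.g. work_size = 300 gives pow_2 = 8 and hence 256 threads; the result is clamped
// to the range [1, TOTAL_THREADS].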
20 |
21 | inline dim3 opt_block_config(int x, int y) {
22 | const int x_threads = opt_n_threads(x);
23 | const int y_threads =
24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25 | dim3 block_config(x_threads, y_threads, 1);
26 |
27 | return block_config;
28 | }
29 |
30 | #define CUDA_CHECK_ERRORS() \
31 | do { \
32 | cudaError_t err = cudaGetLastError(); \
33 | if (cudaSuccess != err) { \
34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36 | __FILE__); \
37 | exit(-1); \
38 | } \
39 | } while (0)
40 |
41 | #endif
42 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
8 | at::Tensor weight);
9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
10 | at::Tensor weight, const int m);
11 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
7 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <torch/extension.h>
4 |
5 | #define CHECK_CUDA(x) \
6 | do { \
7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
8 | } while (0)
9 |
10 | #define CHECK_CONTIGUOUS(x) \
11 | do { \
12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_IS_INT(x) \
16 | do { \
17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
18 | #x " must be an int tensor"); \
19 | } while (0)
20 |
21 | #define CHECK_IS_FLOAT(x) \
22 | do { \
23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
24 | #x " must be a float tensor"); \
25 | } while (0)
26 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "utils.h"
3 |
4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5 | int nsample, const float *new_xyz,
6 | const float *xyz, int *idx);
7 |
8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9 | const int nsample) {
10 | CHECK_CONTIGUOUS(new_xyz);
11 | CHECK_CONTIGUOUS(xyz);
12 | CHECK_IS_FLOAT(new_xyz);
13 | CHECK_IS_FLOAT(xyz);
14 |
15 | if (new_xyz.is_cuda()) {
16 | CHECK_CUDA(xyz);
17 | }
18 |
19 | at::Tensor idx =
20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
22 |
23 | if (new_xyz.is_cuda()) {
24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25 |                                     radius, nsample, new_xyz.data_ptr<float>(),
26 |                                     xyz.data_ptr<float>(), idx.data_ptr<int>());
27 | } else {
28 | AT_ASSERT(false, "CPU not supported");
29 | }
30 |
31 | return idx;
32 | }
33 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
8 | // output: idx(b, m, nsample)
9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10 | int nsample,
11 | const float *__restrict__ new_xyz,
12 | const float *__restrict__ xyz,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | xyz += batch_index * n * 3;
16 | new_xyz += batch_index * m * 3;
17 | idx += m * nsample * batch_index;
18 |
19 | int index = threadIdx.x;
20 | int stride = blockDim.x;
21 |
22 | float radius2 = radius * radius;
23 | for (int j = index; j < m; j += stride) {
24 | float new_x = new_xyz[j * 3 + 0];
25 | float new_y = new_xyz[j * 3 + 1];
26 | float new_z = new_xyz[j * 3 + 2];
27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28 | float x = xyz[k * 3 + 0];
29 | float y = xyz[k * 3 + 1];
30 | float z = xyz[k * 3 + 2];
31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32 | (new_z - z) * (new_z - z);
33 | if (d2 < radius2) {
34 | if (cnt == 0) {
35 | for (int l = 0; l < nsample; ++l) {
36 | idx[j * nsample + l] = k;
37 | }
38 | }
39 | idx[j * nsample + cnt] = k;
40 | ++cnt;
41 | }
42 | }
43 | }
44 | }
45 |
46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47 | int nsample, const float *new_xyz,
48 | const float *xyz, int *idx) {
49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51 | b, n, m, radius, nsample, new_xyz, xyz, idx);
52 |
53 | CUDA_CHECK_ERRORS();
54 | }
55 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "group_points.h"
3 | #include "interpolate.h"
4 | #include "sampling.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("gather_points", &gather_points);
8 | m.def("gather_points_grad", &gather_points_grad);
9 | m.def("furthest_point_sampling", &furthest_point_sampling);
10 |
11 | m.def("three_nn", &three_nn);
12 | m.def("three_interpolate", &three_interpolate);
13 | m.def("three_interpolate_grad", &three_interpolate_grad);
14 |
15 | m.def("ball_query", &ball_query);
16 |
17 | m.def("group_points", &group_points);
18 | m.def("group_points_grad", &group_points_grad);
19 | }
20 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include "group_points.h"
2 | #include "utils.h"
3 |
4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5 | const float *points, const int *idx,
6 | float *out);
7 |
8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9 | int nsample, const float *grad_out,
10 | const int *idx, float *grad_points);
11 |
12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13 | CHECK_CONTIGUOUS(points);
14 | CHECK_CONTIGUOUS(idx);
15 | CHECK_IS_FLOAT(points);
16 | CHECK_IS_INT(idx);
17 |
18 | if (points.is_cuda()) {
19 | CHECK_CUDA(idx);
20 | }
21 |
22 | at::Tensor output =
23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24 | at::device(points.device()).dtype(at::ScalarType::Float));
25 |
26 | if (points.is_cuda()) {
27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28 | idx.size(1), idx.size(2),
29 |                                 points.data_ptr<float>(), idx.data_ptr<int>(),
30 |                                 output.data_ptr<float>());
31 | } else {
32 | AT_ASSERT(false, "CPU not supported");
33 | }
34 |
35 | return output;
36 | }
37 |
38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
39 | CHECK_CONTIGUOUS(grad_out);
40 | CHECK_CONTIGUOUS(idx);
41 | CHECK_IS_FLOAT(grad_out);
42 | CHECK_IS_INT(idx);
43 |
44 | if (grad_out.is_cuda()) {
45 | CHECK_CUDA(idx);
46 | }
47 |
48 | at::Tensor output =
49 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
50 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
51 |
52 | if (grad_out.is_cuda()) {
53 | group_points_grad_kernel_wrapper(
54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
55 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
56 |         output.data_ptr<float>());
57 | } else {
58 | AT_ASSERT(false, "CPU not supported");
59 | }
60 |
61 | return output;
62 | }
63 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 | int nsample,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | int batch_index = blockIdx.x;
14 | points += batch_index * n * c;
15 | idx += batch_index * npoints * nsample;
16 | out += batch_index * npoints * nsample * c;
17 |
18 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 | const int stride = blockDim.y * blockDim.x;
20 | for (int i = index; i < c * npoints; i += stride) {
21 | const int l = i / npoints;
22 | const int j = i % npoints;
23 | for (int k = 0; k < nsample; ++k) {
24 | int ii = idx[j * nsample + k];
25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 | }
27 | }
28 | }
29 |
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 | const float *points, const int *idx,
32 | float *out) {
33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 |
35 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 | b, c, n, npoints, nsample, points, idx, out);
37 |
38 | CUDA_CHECK_ERRORS();
39 | }
40 |
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 | int nsample,
45 | const float *__restrict__ grad_out,
46 | const int *__restrict__ idx,
47 | float *__restrict__ grad_points) {
48 | int batch_index = blockIdx.x;
49 | grad_out += batch_index * npoints * nsample * c;
50 | idx += batch_index * npoints * nsample;
51 | grad_points += batch_index * n * c;
52 |
53 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 | const int stride = blockDim.y * blockDim.x;
55 | for (int i = index; i < c * npoints; i += stride) {
56 | const int l = i / npoints;
57 | const int j = i % npoints;
58 | for (int k = 0; k < nsample; ++k) {
59 | int ii = idx[j * nsample + k];
60 | atomicAdd(grad_points + l * n + ii,
61 | grad_out[(l * npoints + j) * nsample + k]);
62 | }
63 | }
64 | }
65 |
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 | int nsample, const float *grad_out,
68 | const int *idx, float *grad_points) {
69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 |
71 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 |
74 | CUDA_CHECK_ERRORS();
75 | }
76 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 |
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 | const float *known, float *dist2, int *idx);
6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7 | const float *points, const int *idx,
8 | const float *weight, float *out);
9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10 | const float *grad_out,
11 | const int *idx, const float *weight,
12 | float *grad_points);
13 |
14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15 | CHECK_CONTIGUOUS(unknowns);
16 | CHECK_CONTIGUOUS(knows);
17 | CHECK_IS_FLOAT(unknowns);
18 | CHECK_IS_FLOAT(knows);
19 |
20 | if (unknowns.is_cuda()) {
21 | CHECK_CUDA(knows);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
27 | at::Tensor dist2 =
28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
30 |
31 | if (unknowns.is_cuda()) {
32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33 |                             unknowns.data_ptr<float>(), knows.data_ptr<float>(),
34 |                             dist2.data_ptr<float>(), idx.data_ptr<int>());
35 | } else {
36 | AT_ASSERT(false, "CPU not supported");
37 | }
38 |
39 | return {dist2, idx};
40 | }
41 |
42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43 | at::Tensor weight) {
44 | CHECK_CONTIGUOUS(points);
45 | CHECK_CONTIGUOUS(idx);
46 | CHECK_CONTIGUOUS(weight);
47 | CHECK_IS_FLOAT(points);
48 | CHECK_IS_INT(idx);
49 | CHECK_IS_FLOAT(weight);
50 |
51 | if (points.is_cuda()) {
52 | CHECK_CUDA(idx);
53 | CHECK_CUDA(weight);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
58 | at::device(points.device()).dtype(at::ScalarType::Float));
59 |
60 | if (points.is_cuda()) {
61 | three_interpolate_kernel_wrapper(
62 | points.size(0), points.size(1), points.size(2), idx.size(1),
63 |         points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
64 |         output.data_ptr<float>());
65 | } else {
66 | AT_ASSERT(false, "CPU not supported");
67 | }
68 |
69 | return output;
70 | }
71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72 | at::Tensor weight, const int m) {
73 | CHECK_CONTIGUOUS(grad_out);
74 | CHECK_CONTIGUOUS(idx);
75 | CHECK_CONTIGUOUS(weight);
76 | CHECK_IS_FLOAT(grad_out);
77 | CHECK_IS_INT(idx);
78 | CHECK_IS_FLOAT(weight);
79 |
80 | if (grad_out.is_cuda()) {
81 | CHECK_CUDA(idx);
82 | CHECK_CUDA(weight);
83 | }
84 |
85 | at::Tensor output =
86 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
87 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
88 |
89 | if (grad_out.is_cuda()) {
90 | three_interpolate_grad_kernel_wrapper(
91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
93 |         weight.data_ptr<float>(), output.data_ptr<float>());
94 | } else {
95 | AT_ASSERT(false, "CPU not supported");
96 | }
97 |
98 | return output;
99 | }
100 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 | const float *__restrict__ unknown,
11 | const float *__restrict__ known,
12 | float *__restrict__ dist2,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | unknown += batch_index * n * 3;
16 | known += batch_index * m * 3;
17 | dist2 += batch_index * n * 3;
18 | idx += batch_index * n * 3;
19 |
20 | int index = threadIdx.x;
21 | int stride = blockDim.x;
22 | for (int j = index; j < n; j += stride) {
23 | float ux = unknown[j * 3 + 0];
24 | float uy = unknown[j * 3 + 1];
25 | float uz = unknown[j * 3 + 2];
26 |
27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 | int besti1 = 0, besti2 = 0, besti3 = 0;
29 | for (int k = 0; k < m; ++k) {
30 | float x = known[k * 3 + 0];
31 | float y = known[k * 3 + 1];
32 | float z = known[k * 3 + 2];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 | if (d < best1) {
35 | best3 = best2;
36 | besti3 = besti2;
37 | best2 = best1;
38 | besti2 = besti1;
39 | best1 = d;
40 | besti1 = k;
41 | } else if (d < best2) {
42 | best3 = best2;
43 | besti3 = besti2;
44 | best2 = d;
45 | besti2 = k;
46 | } else if (d < best3) {
47 | best3 = d;
48 | besti3 = k;
49 | }
50 | }
51 | dist2[j * 3 + 0] = best1;
52 | dist2[j * 3 + 1] = best2;
53 | dist2[j * 3 + 2] = best3;
54 |
55 | idx[j * 3 + 0] = besti1;
56 | idx[j * 3 + 1] = besti2;
57 | idx[j * 3 + 2] = besti3;
58 | }
59 | }
60 |
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 | const float *known, float *dist2, int *idx) {
63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 | dist2, idx);
66 |
67 | CUDA_CHECK_ERRORS();
68 | }
69 |
70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71 | // output: out(b, c, n)
72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73 | const float *__restrict__ points,
74 | const int *__restrict__ idx,
75 | const float *__restrict__ weight,
76 | float *__restrict__ out) {
77 | int batch_index = blockIdx.x;
78 | points += batch_index * m * c;
79 |
80 | idx += batch_index * n * 3;
81 | weight += batch_index * n * 3;
82 |
83 | out += batch_index * n * c;
84 |
85 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
86 | const int stride = blockDim.y * blockDim.x;
87 | for (int i = index; i < c * n; i += stride) {
88 | const int l = i / n;
89 | const int j = i % n;
90 | float w1 = weight[j * 3 + 0];
91 | float w2 = weight[j * 3 + 1];
92 | float w3 = weight[j * 3 + 2];
93 |
94 | int i1 = idx[j * 3 + 0];
95 | int i2 = idx[j * 3 + 1];
96 | int i3 = idx[j * 3 + 2];
97 |
98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99 | points[l * m + i3] * w3;
100 | }
101 | }
102 |
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104 | const float *points, const int *idx,
105 | const float *weight, float *out) {
106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108 | b, c, m, n, points, idx, weight, out);
109 |
110 | CUDA_CHECK_ERRORS();
111 | }
112 |
113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114 | // output: grad_points(b, c, m)
115 |
116 | __global__ void three_interpolate_grad_kernel(
117 | int b, int c, int n, int m, const float *__restrict__ grad_out,
118 | const int *__restrict__ idx, const float *__restrict__ weight,
119 | float *__restrict__ grad_points) {
120 | int batch_index = blockIdx.x;
121 | grad_out += batch_index * n * c;
122 | idx += batch_index * n * 3;
123 | weight += batch_index * n * 3;
124 | grad_points += batch_index * m * c;
125 |
126 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 | const int stride = blockDim.y * blockDim.x;
128 | for (int i = index; i < c * n; i += stride) {
129 | const int l = i / n;
130 | const int j = i % n;
131 | float w1 = weight[j * 3 + 0];
132 | float w2 = weight[j * 3 + 1];
133 | float w3 = weight[j * 3 + 2];
134 |
135 | int i1 = idx[j * 3 + 0];
136 | int i2 = idx[j * 3 + 1];
137 | int i3 = idx[j * 3 + 2];
138 |
139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142 | }
143 | }
144 |
145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146 | const float *grad_out,
147 | const int *idx, const float *weight,
148 | float *grad_points) {
149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151 | b, c, n, m, grad_out, idx, weight, grad_points);
152 |
153 | CUDA_CHECK_ERRORS();
154 | }
155 |
--------------------------------------------------------------------------------
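A minimal pure-PyTorch sketch of the same three-nearest-neighbour interpolation can be handy for checking shapes and values on small inputs, since the bindings above assert on non-CUDA tensors. It is illustrative only (the helper names three_nn_ref / three_interpolate_ref do not exist in the extension); tensor layouts follow the kernel comments: queries (B, n, 3), known points (B, m, 3), features (B, c, m).

    import torch

    def three_nn_ref(unknown, known):
        # unknown: (B, n, 3), known: (B, m, 3)
        # Returns squared distances and indices of the 3 closest known points,
        # matching what _ext.three_nn produces (ThreeNN then takes the sqrt).
        d2 = torch.cdist(unknown, known).pow(2)                  # (B, n, m)
        dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)  # (B, n, 3) each
        return dist2, idx.int()

    def three_interpolate_ref(points, idx, weight):
        # points: (B, c, m), idx/weight: (B, n, 3) -> out: (B, c, n)
        batch = torch.arange(points.shape[0]).view(-1, 1, 1)
        neighbours = points.permute(0, 2, 1)[batch, idx.long()]  # (B, n, 3, c)
        out = (neighbours * weight.unsqueeze(-1)).sum(2)         # (B, n, c)
        return out.permute(0, 2, 1)
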
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 |
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 | const float *points, const int *idx,
6 | float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 | const float *grad_out, const int *idx,
9 | float *grad_points);
10 |
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 | const float *dataset, float *temp,
13 | int *idxs);
14 |
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 | CHECK_CONTIGUOUS(points);
17 | CHECK_CONTIGUOUS(idx);
18 | CHECK_IS_FLOAT(points);
19 | CHECK_IS_INT(idx);
20 |
21 | if (points.is_cuda()) {
22 | CHECK_CUDA(idx);
23 | }
24 |
25 | at::Tensor output =
26 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 | at::device(points.device()).dtype(at::ScalarType::Float));
28 |
29 | if (points.is_cuda()) {
30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 |                                  idx.size(1), points.data_ptr<float>(),
32 |                                  idx.data_ptr<int>(), output.data_ptr<float>());
33 | } else {
34 | AT_ASSERT(false, "CPU not supported");
35 | }
36 |
37 | return output;
38 | }
39 |
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 | const int n) {
42 | CHECK_CONTIGUOUS(grad_out);
43 | CHECK_CONTIGUOUS(idx);
44 | CHECK_IS_FLOAT(grad_out);
45 | CHECK_IS_INT(idx);
46 |
47 | if (grad_out.is_cuda()) {
48 | CHECK_CUDA(idx);
49 | }
50 |
51 | at::Tensor output =
52 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 |
55 | if (grad_out.is_cuda()) {
56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 |                                       idx.size(1), grad_out.data_ptr<float>(),
58 |                                       idx.data_ptr<int>(),
59 |                                       output.data_ptr<float>());
60 | } else {
61 | AT_ASSERT(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67 | CHECK_CONTIGUOUS(points);
68 | CHECK_IS_FLOAT(points);
69 |
70 | at::Tensor output =
71 | torch::zeros({points.size(0), nsamples},
72 | at::device(points.device()).dtype(at::ScalarType::Int));
73 |
74 | at::Tensor tmp =
75 | torch::full({points.size(0), points.size(1)}, 1e10,
76 | at::device(points.device()).dtype(at::ScalarType::Float));
77 |
78 | if (points.is_cuda()) {
79 | furthest_point_sampling_kernel_wrapper(
80 |         points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81 |         tmp.data_ptr<float>(), output.data_ptr<int>());
82 | } else {
83 | AT_ASSERT(false, "CPU not supported");
84 | }
85 |
86 | return output;
87 | }
88 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 | const float *__restrict__ points,
10 | const int *__restrict__ idx,
11 | float *__restrict__ out) {
12 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 | int a = idx[i * m + j];
16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 | }
18 | }
19 | }
20 | }
21 |
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 | const float *points, const int *idx,
24 | float *out) {
25 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 |                                                              points, idx, out);
28 |
29 | CUDA_CHECK_ERRORS();
30 | }
31 |
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 | const float *__restrict__ grad_out,
36 | const int *__restrict__ idx,
37 | float *__restrict__ grad_points) {
38 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 | int a = idx[i * m + j];
42 | atomicAdd(grad_points + (i * c + l) * n + a,
43 | grad_out[(i * c + l) * m + j]);
44 | }
45 | }
46 | }
47 | }
48 |
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 | const float *grad_out, const int *idx,
51 | float *grad_points) {
52 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 |                               at::cuda::getCurrentCUDAStream()>>>(
54 |       b, c, n, npoints, grad_out, idx, grad_points);
55 |
56 | CUDA_CHECK_ERRORS();
57 | }
58 |
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 | int idx1, int idx2) {
61 | const float v1 = dists[idx1], v2 = dists[idx2];
62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 | dists[idx1] = max(v1, v2);
64 | dists_i[idx1] = v2 > v1 ? i2 : i1;
65 | }
66 |
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 | int b, int n, int m, const float *__restrict__ dataset,
72 | float *__restrict__ temp, int *__restrict__ idxs) {
73 | if (m <= 0) return;
74 | __shared__ float dists[block_size];
75 | __shared__ int dists_i[block_size];
76 |
77 | int batch_index = blockIdx.x;
78 | dataset += batch_index * n * 3;
79 | temp += batch_index * n;
80 | idxs += batch_index * m;
81 |
82 | int tid = threadIdx.x;
83 | const int stride = block_size;
84 |
85 | int old = 0;
86 | if (threadIdx.x == 0) idxs[0] = old;
87 |
88 | __syncthreads();
89 | for (int j = 1; j < m; j++) {
90 | int besti = 0;
91 | float best = -1;
92 | float x1 = dataset[old * 3 + 0];
93 | float y1 = dataset[old * 3 + 1];
94 | float z1 = dataset[old * 3 + 2];
95 | for (int k = tid; k < n; k += stride) {
96 | float x2, y2, z2;
97 | x2 = dataset[k * 3 + 0];
98 | y2 = dataset[k * 3 + 1];
99 | z2 = dataset[k * 3 + 2];
100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 | if (mag <= 1e-3) continue;
102 |
103 | float d =
104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 |
106 | float d2 = min(d, temp[k]);
107 | temp[k] = d2;
108 | besti = d2 > best ? k : besti;
109 | best = d2 > best ? d2 : best;
110 | }
111 | dists[tid] = best;
112 | dists_i[tid] = besti;
113 | __syncthreads();
114 |
115 | if (block_size >= 512) {
116 | if (tid < 256) {
117 | __update(dists, dists_i, tid, tid + 256);
118 | }
119 | __syncthreads();
120 | }
121 | if (block_size >= 256) {
122 | if (tid < 128) {
123 | __update(dists, dists_i, tid, tid + 128);
124 | }
125 | __syncthreads();
126 | }
127 | if (block_size >= 128) {
128 | if (tid < 64) {
129 | __update(dists, dists_i, tid, tid + 64);
130 | }
131 | __syncthreads();
132 | }
133 | if (block_size >= 64) {
134 | if (tid < 32) {
135 | __update(dists, dists_i, tid, tid + 32);
136 | }
137 | __syncthreads();
138 | }
139 | if (block_size >= 32) {
140 | if (tid < 16) {
141 | __update(dists, dists_i, tid, tid + 16);
142 | }
143 | __syncthreads();
144 | }
145 | if (block_size >= 16) {
146 | if (tid < 8) {
147 | __update(dists, dists_i, tid, tid + 8);
148 | }
149 | __syncthreads();
150 | }
151 | if (block_size >= 8) {
152 | if (tid < 4) {
153 | __update(dists, dists_i, tid, tid + 4);
154 | }
155 | __syncthreads();
156 | }
157 | if (block_size >= 4) {
158 | if (tid < 2) {
159 | __update(dists, dists_i, tid, tid + 2);
160 | }
161 | __syncthreads();
162 | }
163 | if (block_size >= 2) {
164 | if (tid < 1) {
165 | __update(dists, dists_i, tid, tid + 1);
166 | }
167 | __syncthreads();
168 | }
169 |
170 | old = dists_i[0];
171 | if (tid == 0) idxs[j] = old;
172 | }
173 | }
174 |
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 | const float *dataset, float *temp,
177 | int *idxs) {
178 | unsigned int n_threads = opt_n_threads(n);
179 |
180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 |
182 | switch (n_threads) {
183 | case 512:
184 |     furthest_point_sampling_kernel<512>
185 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 |     break;
187 |   case 256:
188 |     furthest_point_sampling_kernel<256>
189 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 |     break;
191 |   case 128:
192 |     furthest_point_sampling_kernel<128>
193 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 |     break;
195 |   case 64:
196 |     furthest_point_sampling_kernel<64>
197 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 |     break;
199 |   case 32:
200 |     furthest_point_sampling_kernel<32>
201 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 |     break;
203 |   case 16:
204 |     furthest_point_sampling_kernel<16>
205 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 |     break;
207 |   case 8:
208 |     furthest_point_sampling_kernel<8>
209 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 |     break;
211 |   case 4:
212 |     furthest_point_sampling_kernel<4>
213 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 |     break;
215 |   case 2:
216 |     furthest_point_sampling_kernel<2>
217 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 |     break;
219 |   case 1:
220 |     furthest_point_sampling_kernel<1>
221 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 |     break;
223 |   default:
224 |     furthest_point_sampling_kernel<512>
225 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 | }
227 |
228 | CUDA_CHECK_ERRORS();
229 | }
230 |
--------------------------------------------------------------------------------
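The templated kernel above parallelises the classic greedy farthest-point rule: seed with index 0, keep each point's minimum squared distance to the already-selected set in temp, and repeatedly pick the point where that distance is largest (the block-wide reduction over dists / dists_i). A single-batch PyTorch sketch of the same rule, for reference only (fps_ref is not part of the library, and it omits the kernel's mag <= 1e-3 skip of near-origin points):

    import torch

    def fps_ref(xyz, m):
        # xyz: (n, 3) points; returns indices of m greedily chosen points.
        n = xyz.shape[0]
        idxs = torch.zeros(m, dtype=torch.long)      # idxs[0] = 0, like the kernel
        min_d2 = torch.full((n,), float("inf"))      # running min distance to the selected set
        last = 0
        for j in range(1, m):
            min_d2 = torch.minimum(min_d2, ((xyz - xyz[last]) ** 2).sum(-1))
            last = int(torch.argmax(min_d2))
            idxs[j] = last
        return idxs
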
/pointnet2_ops_lib/pointnet2_ops/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "3.0.0"
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from pointnet2_ops import pointnet2_utils
7 |
8 |
9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
10 | layers = []
11 | for i in range(1, len(mlp_spec)):
12 | layers.append(
13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
14 | )
15 | if bn:
16 | layers.append(nn.BatchNorm2d(mlp_spec[i]))
17 | layers.append(nn.ReLU(True))
18 |
19 | return nn.Sequential(*layers)
20 |
21 |
22 | class _PointnetSAModuleBase(nn.Module):
23 | def __init__(self):
24 | super(_PointnetSAModuleBase, self).__init__()
25 | self.npoint = None
26 | self.groupers = None
27 | self.mlps = None
28 |
29 | def forward(
30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor]
31 | ) -> Tuple[torch.Tensor, torch.Tensor]:
32 | r"""
33 | Parameters
34 | ----------
35 | xyz : torch.Tensor
36 | (B, N, 3) tensor of the xyz coordinates of the features
37 | features : torch.Tensor
38 |             (B, C, N) tensor of the descriptors of the features
39 |
40 | Returns
41 | -------
42 | new_xyz : torch.Tensor
43 | (B, npoint, 3) tensor of the new features' xyz
44 | new_features : torch.Tensor
45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
46 | """
47 |
48 | new_features_list = []
49 |
50 | xyz_flipped = xyz.transpose(1, 2).contiguous()
51 | new_xyz = (
52 | pointnet2_utils.gather_operation(
53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
54 | )
55 | .transpose(1, 2)
56 | .contiguous()
57 | if self.npoint is not None
58 | else None
59 | )
60 |
61 | for i in range(len(self.groupers)):
62 | new_features = self.groupers[i](
63 | xyz, new_xyz, features
64 | ) # (B, C, npoint, nsample)
65 |
66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
67 | new_features = F.max_pool2d(
68 | new_features, kernel_size=[1, new_features.size(3)]
69 | ) # (B, mlp[-1], npoint, 1)
70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
71 |
72 | new_features_list.append(new_features)
73 |
74 | return new_xyz, torch.cat(new_features_list, dim=1)
75 |
76 |
77 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
78 |     r"""Pointnet set abstraction layer with multiscale grouping
79 |
80 | Parameters
81 | ----------
82 | npoint : int
83 | Number of features
84 | radii : list of float32
85 | list of radii to group with
86 | nsamples : list of int32
87 | Number of samples in each ball query
88 | mlps : list of list of int32
89 | Spec of the pointnet before the global max_pool for each scale
90 | bn : bool
91 | Use batchnorm
92 | """
93 |
94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
96 | super(PointnetSAModuleMSG, self).__init__()
97 |
98 | assert len(radii) == len(nsamples) == len(mlps)
99 |
100 | self.npoint = npoint
101 | self.groupers = nn.ModuleList()
102 | self.mlps = nn.ModuleList()
103 | for i in range(len(radii)):
104 | radius = radii[i]
105 | nsample = nsamples[i]
106 | self.groupers.append(
107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
108 | if npoint is not None
109 | else pointnet2_utils.GroupAll(use_xyz)
110 | )
111 | mlp_spec = mlps[i]
112 | if use_xyz:
113 | mlp_spec[0] += 3
114 |
115 | self.mlps.append(build_shared_mlp(mlp_spec, bn))
116 |
117 |
118 | class PointnetSAModule(PointnetSAModuleMSG):
119 |     r"""Pointnet set abstraction layer
120 |
121 | Parameters
122 | ----------
123 | npoint : int
124 | Number of features
125 | radius : float
126 | Radius of ball
127 | nsample : int
128 | Number of samples in the ball query
129 | mlp : list
130 | Spec of the pointnet before the global max_pool
131 | bn : bool
132 | Use batchnorm
133 | """
134 |
135 | def __init__(
136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
137 | ):
138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
139 | super(PointnetSAModule, self).__init__(
140 | mlps=[mlp],
141 | npoint=npoint,
142 | radii=[radius],
143 | nsamples=[nsample],
144 | bn=bn,
145 | use_xyz=use_xyz,
146 | )
147 |
148 |
149 | class PointnetFPModule(nn.Module):
150 |     r"""Propagates the features of one set to another
151 |
152 | Parameters
153 | ----------
154 | mlp : list
155 | Pointnet module parameters
156 | bn : bool
157 | Use batchnorm
158 | """
159 |
160 | def __init__(self, mlp, bn=True):
161 | # type: (PointnetFPModule, List[int], bool) -> None
162 | super(PointnetFPModule, self).__init__()
163 | self.mlp = build_shared_mlp(mlp, bn=bn)
164 |
165 | def forward(self, unknown, known, unknow_feats, known_feats):
166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
167 | r"""
168 | Parameters
169 | ----------
170 | unknown : torch.Tensor
171 | (B, n, 3) tensor of the xyz positions of the unknown features
172 | known : torch.Tensor
173 | (B, m, 3) tensor of the xyz positions of the known features
174 | unknow_feats : torch.Tensor
175 |             (B, C1, n) tensor of the features to be propagated to
176 |         known_feats : torch.Tensor
177 |             (B, C2, m) tensor of features to be propagated
178 |
179 | Returns
180 | -------
181 | new_features : torch.Tensor
182 | (B, mlp[-1], n) tensor of the features of the unknown features
183 | """
184 |
185 | if known is not None:
186 | dist, idx = pointnet2_utils.three_nn(unknown, known)
187 | dist_recip = 1.0 / (dist + 1e-8)
188 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
189 | weight = dist_recip / norm
190 |
191 | interpolated_feats = pointnet2_utils.three_interpolate(
192 | known_feats, idx, weight
193 | )
194 | else:
195 | interpolated_feats = known_feats.expand(
196 | *(known_feats.size()[0:2] + [unknown.size(1)])
197 | )
198 |
199 | if unknow_feats is not None:
200 | new_features = torch.cat(
201 | [interpolated_feats, unknow_feats], dim=1
202 | ) # (B, C2 + C1, n)
203 | else:
204 | new_features = interpolated_feats
205 |
206 | new_features = new_features.unsqueeze(-1)
207 | new_features = self.mlp(new_features)
208 |
209 | return new_features.squeeze(-1)
210 |
--------------------------------------------------------------------------------
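A minimal shape check of the modules above (illustrative only; the sizes are arbitrary and a CUDA device is required, since the underlying _ext kernels have no CPU path):

    import torch
    from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

    xyz = torch.rand(2, 1024, 3, device="cuda")      # (B, N, 3) point positions
    feats = torch.rand(2, 16, 1024, device="cuda")   # (B, C, N) per-point features

    # Set abstraction: FPS to 256 centroids, ball query of radius 0.2 with 32 samples.
    sa = PointnetSAModule(mlp=[16, 32, 64], npoint=256, radius=0.2, nsample=32).cuda()
    new_xyz, new_feats = sa(xyz, feats)              # (2, 256, 3), (2, 64, 256)

    # Feature propagation back to the full resolution.
    fp = PointnetFPModule(mlp=[64 + 16, 64]).cuda()
    up_feats = fp(xyz, new_xyz, feats, new_feats)    # (2, 64, 1024)
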
/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | from torch.autograd import Function
5 | from typing import *
6 |
7 | try:
8 | import pointnet2_ops._ext as _ext
9 | except ImportError:
10 | from torch.utils.cpp_extension import load
11 | import glob
12 | import os.path as osp
13 | import os
14 |
15 | warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
16 |
17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
19 | osp.join(_ext_src_root, "src", "*.cu")
20 | )
21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
22 |
23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
24 | _ext = load(
25 | "_ext",
26 | sources=_ext_sources,
27 | extra_include_paths=[osp.join(_ext_src_root, "include")],
28 | extra_cflags=["-O3"],
29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
30 | with_cuda=True,
31 | )
32 |
33 |
34 | class FurthestPointSampling(Function):
35 | @staticmethod
36 | def forward(ctx, xyz, npoint):
37 | # type: (Any, torch.Tensor, int) -> torch.Tensor
38 | r"""
39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
40 | minimum distance
41 |
42 | Parameters
43 | ----------
44 | xyz : torch.Tensor
45 | (B, N, 3) tensor where N > npoint
46 | npoint : int32
47 | number of features in the sampled set
48 |
49 | Returns
50 | -------
51 | torch.Tensor
52 | (B, npoint) tensor containing the set
53 | """
54 | out = _ext.furthest_point_sampling(xyz, npoint)
55 |
56 | ctx.mark_non_differentiable(out)
57 |
58 | return out
59 |
60 | @staticmethod
61 | def backward(ctx, grad_out):
62 | return ()
63 |
64 |
65 | furthest_point_sample = FurthestPointSampling.apply
66 |
67 |
68 | class GatherOperation(Function):
69 | @staticmethod
70 | def forward(ctx, features, idx):
71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
72 | r"""
73 |
74 | Parameters
75 | ----------
76 | features : torch.Tensor
77 | (B, C, N) tensor
78 |
79 | idx : torch.Tensor
80 | (B, npoint) tensor of the features to gather
81 |
82 | Returns
83 | -------
84 | torch.Tensor
85 | (B, C, npoint) tensor
86 | """
87 |
88 | ctx.save_for_backward(idx, features)
89 |
90 | return _ext.gather_points(features, idx)
91 |
92 | @staticmethod
93 | def backward(ctx, grad_out):
94 | idx, features = ctx.saved_tensors
95 | N = features.size(2)
96 |
97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
98 | return grad_features, None
99 |
100 |
101 | gather_operation = GatherOperation.apply
102 |
103 |
104 | class ThreeNN(Function):
105 | @staticmethod
106 | def forward(ctx, unknown, known):
107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
108 | r"""
109 | Find the three nearest neighbors of unknown in known
110 | Parameters
111 | ----------
112 | unknown : torch.Tensor
113 |             (B, n, 3) tensor of the unknown (query) features
114 |         known : torch.Tensor
115 |             (B, m, 3) tensor of the known features
116 |
117 | Returns
118 | -------
119 | dist : torch.Tensor
120 | (B, n, 3) l2 distance to the three nearest neighbors
121 | idx : torch.Tensor
122 | (B, n, 3) index of 3 nearest neighbors
123 | """
124 | dist2, idx = _ext.three_nn(unknown, known)
125 | dist = torch.sqrt(dist2)
126 |
127 | ctx.mark_non_differentiable(dist, idx)
128 |
129 | return dist, idx
130 |
131 | @staticmethod
132 | def backward(ctx, grad_dist, grad_idx):
133 | return ()
134 |
135 |
136 | three_nn = ThreeNN.apply
137 |
138 |
139 | class ThreeInterpolate(Function):
140 | @staticmethod
141 | def forward(ctx, features, idx, weight):
142 |         # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
143 | r"""
144 |         Performs weighted linear interpolation on 3 features
145 | Parameters
146 | ----------
147 | features : torch.Tensor
148 |             (B, c, m) Feature descriptors to be interpolated from
149 | idx : torch.Tensor
150 | (B, n, 3) three nearest neighbors of the target features in features
151 | weight : torch.Tensor
152 | (B, n, 3) weights
153 |
154 | Returns
155 | -------
156 | torch.Tensor
157 | (B, c, n) tensor of the interpolated features
158 | """
159 | ctx.save_for_backward(idx, weight, features)
160 |
161 | return _ext.three_interpolate(features, idx, weight)
162 |
163 | @staticmethod
164 | def backward(ctx, grad_out):
165 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
166 | r"""
167 | Parameters
168 | ----------
169 | grad_out : torch.Tensor
170 |             (B, c, n) tensor with gradients of outputs
171 |
172 | Returns
173 | -------
174 | grad_features : torch.Tensor
175 | (B, c, m) tensor with gradients of features
176 |
177 | None
178 |
179 | None
180 | """
181 | idx, weight, features = ctx.saved_tensors
182 | m = features.size(2)
183 |
184 | grad_features = _ext.three_interpolate_grad(
185 | grad_out.contiguous(), idx, weight, m
186 | )
187 |
188 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
189 |
190 |
191 | three_interpolate = ThreeInterpolate.apply
192 |
193 |
194 | class GroupingOperation(Function):
195 | @staticmethod
196 | def forward(ctx, features, idx):
197 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
198 | r"""
199 |
200 | Parameters
201 | ----------
202 | features : torch.Tensor
203 | (B, C, N) tensor of features to group
204 | idx : torch.Tensor
205 |             (B, npoint, nsample) tensor containing the indices of features to group with
206 |
207 | Returns
208 | -------
209 | torch.Tensor
210 | (B, C, npoint, nsample) tensor
211 | """
212 | ctx.save_for_backward(idx, features)
213 |
214 | return _ext.group_points(features, idx)
215 |
216 | @staticmethod
217 | def backward(ctx, grad_out):
218 |         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
219 | r"""
220 |
221 | Parameters
222 | ----------
223 | grad_out : torch.Tensor
224 | (B, C, npoint, nsample) tensor of the gradients of the output from forward
225 |
226 | Returns
227 | -------
228 | torch.Tensor
229 | (B, C, N) gradient of the features
230 | None
231 | """
232 | idx, features = ctx.saved_tensors
233 | N = features.size(2)
234 |
235 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
236 |
237 | return grad_features, torch.zeros_like(idx)
238 |
239 |
240 | grouping_operation = GroupingOperation.apply
241 |
242 |
243 | class BallQuery(Function):
244 | @staticmethod
245 | def forward(ctx, radius, nsample, xyz, new_xyz):
246 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
247 | r"""
248 |
249 | Parameters
250 | ----------
251 | radius : float
252 | radius of the balls
253 | nsample : int
254 | maximum number of features in the balls
255 | xyz : torch.Tensor
256 | (B, N, 3) xyz coordinates of the features
257 | new_xyz : torch.Tensor
258 | (B, npoint, 3) centers of the ball query
259 |
260 | Returns
261 | -------
262 | torch.Tensor
263 |             (B, npoint, nsample) tensor with the indices of the features that form the query balls
264 | """
265 | output = _ext.ball_query(new_xyz, xyz, radius, nsample)
266 |
267 | ctx.mark_non_differentiable(output)
268 |
269 | return output
270 |
271 | @staticmethod
272 | def backward(ctx, grad_out):
273 | return ()
274 |
275 |
276 | ball_query = BallQuery.apply
277 |
278 |
279 | class QueryAndGroup(nn.Module):
280 | r"""
281 | Groups with a ball query of radius
282 |
283 | Parameters
284 | ---------
285 | radius : float32
286 | Radius of ball
287 | nsample : int32
288 | Maximum number of features to gather in the ball
289 | """
290 |
291 | def __init__(self, radius, nsample, use_xyz=True):
292 | # type: (QueryAndGroup, float, int, bool) -> None
293 | super(QueryAndGroup, self).__init__()
294 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
295 |
296 | def forward(self, xyz, new_xyz, features=None):
297 |         # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
298 | r"""
299 | Parameters
300 | ----------
301 | xyz : torch.Tensor
302 | xyz coordinates of the features (B, N, 3)
303 | new_xyz : torch.Tensor
304 |             centroids (B, npoint, 3)
305 | features : torch.Tensor
306 | Descriptors of the features (B, C, N)
307 |
308 | Returns
309 | -------
310 | new_features : torch.Tensor
311 | (B, 3 + C, npoint, nsample) tensor
312 | """
313 |
314 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
315 | xyz_trans = xyz.transpose(1, 2).contiguous()
316 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
317 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
318 |
319 | if features is not None:
320 | grouped_features = grouping_operation(features, idx)
321 | if self.use_xyz:
322 | new_features = torch.cat(
323 | [grouped_xyz, grouped_features], dim=1
324 | ) # (B, C + 3, npoint, nsample)
325 | else:
326 | new_features = grouped_features
327 | else:
328 | assert (
329 | self.use_xyz
330 |             ), "Cannot have features=None and use_xyz=False at the same time!"
331 | new_features = grouped_xyz
332 |
333 | return new_features
334 |
335 |
336 | class GroupAll(nn.Module):
337 | r"""
338 | Groups all features
339 |
340 | Parameters
341 | ---------
342 | """
343 |
344 | def __init__(self, use_xyz=True):
345 | # type: (GroupAll, bool) -> None
346 | super(GroupAll, self).__init__()
347 | self.use_xyz = use_xyz
348 |
349 | def forward(self, xyz, new_xyz, features=None):
350 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
351 | r"""
352 | Parameters
353 | ----------
354 | xyz : torch.Tensor
355 | xyz coordinates of the features (B, N, 3)
356 | new_xyz : torch.Tensor
357 | Ignored
358 | features : torch.Tensor
359 | Descriptors of the features (B, C, N)
360 |
361 | Returns
362 | -------
363 | new_features : torch.Tensor
364 | (B, C + 3, 1, N) tensor
365 | """
366 |
367 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
368 | if features is not None:
369 | grouped_features = features.unsqueeze(2)
370 | if self.use_xyz:
371 | new_features = torch.cat(
372 | [grouped_xyz, grouped_features], dim=1
373 | ) # (B, 3 + C, 1, N)
374 | else:
375 | new_features = grouped_features
376 | else:
377 | new_features = grouped_xyz
378 |
379 | return new_features
380 |
--------------------------------------------------------------------------------
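The low-level ops compose exactly as QueryAndGroup does internally; a minimal CUDA-only check with arbitrary sizes (illustrative only, not part of the library):

    import torch
    from pointnet2_ops import pointnet2_utils

    xyz = torch.rand(2, 4096, 3, device="cuda")                   # (B, N, 3)
    fps_idx = pointnet2_utils.furthest_point_sample(xyz, 512)     # (B, 512) int32 indices
    new_xyz = pointnet2_utils.gather_operation(
        xyz.transpose(1, 2).contiguous(), fps_idx
    ).transpose(1, 2).contiguous()                                # (B, 512, 3) centroids

    idx = pointnet2_utils.ball_query(0.2, 32, xyz, new_xyz)       # (B, 512, 32)
    grouped = pointnet2_utils.grouping_operation(
        xyz.transpose(1, 2).contiguous(), idx)                    # (B, 3, 512, 32)
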
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from datasets import dataset_dict
9 |
10 | # colmap
11 | from datasets.colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
12 | from datasets.llff import center_poses
13 | from datasets.ray_utils import get_ndc_coor, read_gen
14 |
15 | # losses
16 | from losses import loss_dict
17 |
18 | # metrics
19 | from metrics import psnr
20 | from models.cloud_code import CloudNeRF, SHCloudNeRF, config
21 |
22 | # models
23 | from models.nerf import Embedding
24 | from models.rendering import render_rays
25 | from opt import get_opts
26 |
27 | # fps
28 | from pointnet2_ops.pointnet2_utils import furthest_point_sample, gather_operation
29 |
30 | # pytorch-lightning
31 | from pytorch_lightning import LightningModule, Trainer
32 | from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
33 | from pytorch_lightning.loggers import TensorBoardLogger
34 | from pytorch_lightning.plugins import DDPPlugin
35 | from torch.utils.data import DataLoader
36 |
37 | # optimizer, scheduler, visualization
38 | from utils import get_learning_rate, get_optimizer, get_scheduler, load_ckpt, visualize_depth
39 |
40 |
41 | @torch.no_grad()
42 | def fps(points, n_samples):
43 | points = torch.from_numpy(points).unsqueeze(
44 | 0).float().cuda() # 1, N_points, 3
45 | points_flipped = points.transpose(1, 2).contiguous()
46 | fps_index = furthest_point_sample(points, n_samples) # 1, n_samples
47 |
48 | fps_kps = gather_operation(points_flipped, fps_index).transpose(
49 | 1, 2).contiguous().squeeze(0) # n_samples, 3
50 | return fps_kps.cpu().numpy()
51 |
52 |
53 | class NeRFSystem(LightningModule):
54 | def __init__(self, hparams):
55 | super().__init__()
56 | self.save_hyperparameters(hparams)
57 | self.root_dir = hparams.root_dir
58 |
59 | self.loss = loss_dict["color"](coef=1)
60 |
61 | self.embedding_xyz = Embedding(hparams.N_emb_xyz)
62 | self.embedding_dir = Embedding(hparams.N_emb_dir)
63 | self.embeddings = {"xyz": self.embedding_xyz,
64 | "dir": self.embedding_dir}
65 |
66 | kps, fps_kps = self.read_colmap_meta(hparams)
67 |
68 | if hparams.N_importance == 0:
69 | if hparams.use_sh_feat:
70 | self.nerf_coarse = SHCloudNeRF(
71 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3)
72 | else:
73 | self.nerf_coarse = CloudNeRF(
74 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3)
75 |
76 | self.models = {"coarse": self.nerf_coarse}
77 | load_ckpt(self.nerf_coarse, hparams.weight_path, "nerf_coarse")
78 |
79 | else:
80 | if hparams.use_sh_feat:
81 | self.nerf_fine = SHCloudNeRF(
82 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3)
83 | else:
84 | self.nerf_fine = CloudNeRF(
85 | kps, fps_kps, 6 * hparams.N_emb_xyz + 3, 6 * hparams.N_emb_dir + 3)
86 |
87 | self.models = {"fine": self.nerf_fine}
88 | load_ckpt(self.nerf_fine, hparams.weight_path, "nerf_fine")
89 |
90 | def read_colmap_meta(self, hparams):
91 | camdata = read_cameras_binary(os.path.join(
92 | hparams.root_dir, "sparse/0/cameras.bin"))
93 | self.origin_intrinsics = camdata
94 | W = camdata[1].width
95 | self.focal = camdata[1].params[0] * hparams.img_wh[0] / W
96 |
97 | imdata = read_images_binary(os.path.join(
98 | hparams.root_dir, "sparse/0/images.bin"))
99 |
100 | w2c_mats = []
101 | bottom = np.array([0, 0, 0, 1.0]).reshape(1, 4)
102 | for k in imdata:
103 | im = imdata[k]
104 | R = im.qvec2rotmat()
105 | t = im.tvec.reshape(3, 1)
106 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)]
107 | w2c_mats = np.stack(w2c_mats, 0)
108 | # (N_images, 3, 4) cam2world matrices
109 | poses = np.linalg.inv(w2c_mats)[:, :3]
110 | self.origin_extrinsics = poses
111 |
112 | pts3d = read_points3d_binary(os.path.join(
113 | hparams.root_dir, "sparse/0/points3D.bin"))
114 |
115 | mvs_points = self.load_mvs_depth().numpy()
116 | near_bound = mvs_points.min(axis=0)[-1]
117 | pts3d = {k: v for (k, v) in pts3d.items() if v.xyz[-1] > near_bound}
118 |
119 | pts_world = np.zeros((1, 3, len(pts3d))) # (1, 3, N_points)
120 | self.bounds = np.zeros((len(poses), 2)) # (N_images, 2)
121 | visibilities = np.zeros((len(poses), len(pts3d))
122 | ) # (N_images, N_points)
123 |
124 | for i, k in enumerate(pts3d):
125 | pts_world[0, :, i] = pts3d[k].xyz
126 | for j in pts3d[k].image_ids:
127 | visibilities[j - 1, i] = 1
128 |
129 | depths = ((pts_world - poses[..., 3:4]) * poses[..., 2:3]).sum(1)
130 | for i in range(len(poses)):
131 | visibility_i = visibilities[i]
132 | zs = depths[i][visibility_i == 1]
133 | self.bounds[i] = [np.percentile(zs, 0.1), np.percentile(zs, 99.9)]
134 | valid_depth = (depths[i] >= self.bounds[i][0]) & (
135 | depths[i] <= self.bounds[i][1])
136 | visibility_i = visibility_i.astype(bool) & valid_depth
137 | visibilities[i] = visibility_i.astype(np.float64)
138 |
139 | valid_points = np.any(visibilities, axis=0)
140 | pts_world = np.transpose(pts_world[0])[valid_points] # (N_points, 3)
141 |
142 | # fps
143 | fps_kps = fps(pts_world, config["code_cloud"]["num_codes"])
144 | global_kps = fps(mvs_points, pts_world.shape[0])
145 |
146 | if hparams.not_use_mvs:
147 | pts_world = np.concatenate([pts_world], axis=0)
148 | else:
149 | pts_world = np.concatenate([pts_world, global_kps], axis=0)
150 |
151 | poses = np.concatenate(
152 | [poses[..., 0:1], -poses[..., 1:3], poses[..., 3:4]], -1)
153 | poses, pose_avg = center_poses(poses)
154 | pose_avg_homo = np.eye(4)
155 | pose_avg_homo[:3] = pose_avg
156 |
157 | pts_world_homo = np.concatenate(
158 | [pts_world, np.ones((pts_world.shape[0], 1))], axis=1)
159 | fps_kps_homo = np.concatenate(
160 | [fps_kps, np.ones((fps_kps.shape[0], 1))], axis=1)
161 |
162 | trans_pts_world = np.linalg.inv(
163 | pose_avg_homo) @ pts_world_homo[:, :, None]
164 | trans_fps_kps = np.linalg.inv(pose_avg_homo) @ fps_kps_homo[:, :, None]
165 |
166 | kps = torch.from_numpy(trans_pts_world[:, :3, 0])
167 | fps_kps = torch.from_numpy(trans_fps_kps[:, :3, 0])
168 | near_original = self.bounds.min()
169 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
170 | kps /= scale_factor
171 | fps_kps /= scale_factor
172 |
173 | # convert to ndc
174 | kps_ndc = get_ndc_coor(
175 | hparams.img_wh[1], hparams.img_wh[0], self.focal, 1.0, kps)
176 | fps_kps_ndc = get_ndc_coor(
177 | hparams.img_wh[1], hparams.img_wh[0], self.focal, 1.0, fps_kps)
178 | return kps_ndc, fps_kps_ndc
179 |
180 | def load_mvs_depth(self):
181 | depth_glob = os.path.join(self.root_dir, "depths", "*.pfm")
182 | self.depth_list = sorted(glob.glob(depth_glob))
183 | depths = []
184 | for i in range(len(self.depth_list)):
185 | depth = read_gen(self.depth_list[i])
186 |
187 | depths.append(depth)
188 | self.depths = np.stack(depths, 0).astype(np.float32) # N x H x W
189 |
190 | per_view_points = self.project_to_3d()
191 | mvs_points = self.fwd_consistency_check(per_view_points)
192 | return mvs_points
193 |
194 | def project_to_3d(self):
195 | N, H, W = self.depths.shape
196 | focal = self.origin_intrinsics[1].params[0]
197 | origin_h, origin_w = self.origin_intrinsics[1].height, self.origin_intrinsics[1].width
198 |
199 | origin_cy, origin_cx = self.origin_intrinsics[1].params[2], self.origin_intrinsics[1].params[1]
200 |
201 | origin_K = np.array([[focal, 0, origin_cx, 0], [
202 | 0, focal, origin_cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
203 |
204 | origin_K[0, :] /= origin_w
205 | origin_K[1, :] /= origin_h
206 |
207 | self.normalized_K = origin_K
208 |
209 | mvs_K = self.normalized_K.copy()
210 | mvs_K[0, :] *= W
211 | mvs_K[1, :] *= H
212 | self.mvs_K = mvs_K
213 |
214 | inv_mvs_K = np.linalg.pinv(mvs_K)
215 | inv_mvs_K = torch.from_numpy(inv_mvs_K)
216 |
217 | # create mesh grid for mvs image
218 | meshgrid = np.meshgrid(range(W), range(H), indexing="xy")
219 | id_coords = (np.stack(meshgrid, axis=0).astype(
220 | np.float32)).reshape(2, -1)
221 | id_coords = torch.from_numpy(id_coords)
222 |
223 | ones = torch.ones(N, 1, H * W)
224 |
225 | pix_coords = torch.unsqueeze(torch.stack(
226 | [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0)
227 | pix_coords = pix_coords.repeat(N, 1, 1)
228 | pix_coords = torch.cat([pix_coords, ones], 1)
229 |
230 | # project to cam coord
231 | inv_mvs_K = inv_mvs_K[None, ...].repeat(N, 1, 1).float()
232 | cam_points = torch.matmul(inv_mvs_K[:, :3, :3], pix_coords)
233 | mvs_depth = torch.from_numpy(
234 | self.depths).float().unsqueeze(1).view(N, 1, -1)
235 | cam_points = mvs_depth * cam_points
236 | cam_points = torch.cat([cam_points, ones], 1)
237 |
238 | # project to world coord
239 | T = torch.from_numpy(self.origin_extrinsics).float()
240 | world_points = torch.matmul(T, cam_points)
241 | world_points = world_points.permute(0, 2, 1) # N, H*W, 3
242 |
243 | return world_points
244 |
245 | def fwd_consistency_check(self, per_view_points):
246 | N, H, W = self.depths.shape
247 | global_valid_points = []
248 | for view_id in range(per_view_points.shape[0]):
249 | curr_view_points = per_view_points[view_id].transpose(
250 | 1, 0) # 3, H*W
251 | homo_view_points = torch.cat(
252 | [curr_view_points, torch.ones(1, H * W)], dim=0) # 4, H*W
253 | homo_view_points = homo_view_points.unsqueeze(
254 | 0).repeat(N, 1, 1) # N,4,H*W
255 |
256 | # project to camera space
257 | T = torch.from_numpy(self.origin_extrinsics).float()
258 | homo_T = torch.cat([T, torch.zeros(N, 1, 4)], dim=1)
259 | homo_T[:, -1, -1] = 1
260 | inv_T = torch.inverse(homo_T)
261 | cam_points = torch.matmul(inv_T[:, :3, :], homo_view_points)
262 |
263 | # project to image space
264 | mvs_K = torch.from_numpy(self.mvs_K).unsqueeze(
265 | 0).repeat(N, 1, 1).float()
266 | cam_points = torch.matmul(mvs_K[:, :3, :3], cam_points)
267 | cam_points[:, :2, :] /= cam_points[:, 2:, :]
268 |
269 | z_values = cam_points[:, 2:, :].view(N, 1, H, W) # N,1,H,W
270 | xy_coords = cam_points[:, :2, :].transpose(
271 | 2, 1).view(N, H, W, 2) # N,H,W,2
272 |
273 | xy_coords[..., 0] /= W - 1
274 | xy_coords[..., 1] /= H - 1
275 | xy_coords = (xy_coords - 0.5) * 2
276 |
277 | mvs_depth = torch.from_numpy(
278 | self.depths).float().unsqueeze(1) # N,1,H,W
279 | ref_z_values = F.grid_sample(
280 | mvs_depth, xy_coords, mode="bilinear", align_corners=False)
281 | err = z_values - 0.9 * ref_z_values
282 |
283 | visible_mask = ref_z_values != 0
284 | visible_count = visible_mask.int().sum(0)
285 | valid_visible = visible_count >= 1
286 | valid_points = err >= 0
287 | valid_points = torch.all(
288 | valid_points, dim=0) & valid_visible # 1,H,W
289 | global_valid_points.append(valid_points)
290 | global_valid_points = torch.cat(
291 | global_valid_points, dim=0).view(N, H * W) # N,H,W
292 |
293 | filtered_points = per_view_points[global_valid_points, :]
294 | return filtered_points
295 |
296 | def forward(self, rays):
297 | """Do batched inference on rays using chunk."""
298 | B = rays.shape[0]
299 | results = defaultdict(list)
300 | for i in range(0, B, self.hparams.chunk):
301 | rendered_ray_chunks = render_rays(
302 | self.models,
303 | self.embeddings,
304 | rays[i: i + self.hparams.chunk],
305 | self.hparams.N_samples,
306 | self.hparams.use_disp,
307 | self.hparams.perturb,
308 | self.hparams.noise_std,
309 | self.hparams.N_importance,
310 | self.hparams.chunk, # chunk size is effective in val mode
311 | self.train_dataset.white_back,
312 |                 use_sh_feat=self.hparams.use_sh_feat,
313 | )
314 |
315 | for k, v in rendered_ray_chunks.items():
316 | results[k] += [v]
317 |
318 | for k, v in results.items():
319 | results[k] = torch.cat(v, 0)
320 | return results
321 |
322 | def setup(self, stage):
323 | dataset = dataset_dict[self.hparams.dataset_name]
324 | kwargs = {"root_dir": self.hparams.root_dir,
325 | "img_wh": tuple(self.hparams.img_wh)}
326 | if self.hparams.dataset_name == "llff":
327 | kwargs["val_num"] = 3
328 | self.train_dataset = dataset(split="train", **kwargs)
329 | self.val_dataset = dataset(split="val", **kwargs)
330 |
331 | def configure_optimizers(self):
332 | self.optimizer = get_optimizer(self.hparams, self.models)
333 | scheduler = get_scheduler(self.hparams, self.optimizer)
334 | return [self.optimizer], [scheduler]
335 |
336 | def train_dataloader(self):
337 | return DataLoader(
338 | self.train_dataset, shuffle=True, num_workers=4, batch_size=self.hparams.batch_size, pin_memory=True
339 | )
340 |
341 | def val_dataloader(self):
342 | return DataLoader(
343 | self.val_dataset,
344 | shuffle=False,
345 | num_workers=4,
346 | # validate one image (H*W rays) at a time
347 | batch_size=1,
348 | pin_memory=True,
349 | )
350 |
351 | def training_step(self, batch, batch_nb):
352 | rays, rgbs = batch["rays"], batch["rgbs"]
353 | results = self(rays)
354 | loss = self.loss(results, rgbs)
355 |
356 | with torch.no_grad():
357 | typ = "fine" if "rgb_fine" in results else "coarse"
358 | psnr_ = psnr(results[f"rgb_{typ}"], rgbs)
359 |
360 | self.log("lr", get_learning_rate(self.optimizer))
361 | self.log("train/loss", loss)
362 | self.log("train/psnr", psnr_, prog_bar=True)
363 |
364 | return loss
365 |
366 | def validation_step(self, batch, batch_nb):
367 | rays, rgbs = batch["rays"], batch["rgbs"]
368 | rays = rays.squeeze() # (H*W, 3)
369 | rgbs = rgbs.squeeze() # (H*W, 3)
370 | results = self(rays)
371 | log = {"val_loss": self.loss(results, rgbs)}
372 | typ = "fine" if "rgb_fine" in results else "coarse"
373 |
374 | if batch_nb == 0:
375 | W, H = self.hparams.img_wh
376 | img = results[f"rgb_{typ}"].view(
377 | H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
378 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
379 | depth = visualize_depth(
380 | results[f"depth_{typ}"].view(H, W)) # (3, H, W)
381 | stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W)
382 | self.logger.experiment.add_images(
383 | "val/GT_pred_depth", stack, self.global_step)
384 |
385 | psnr_ = psnr(results[f"rgb_{typ}"], rgbs)
386 | log["val_psnr"] = psnr_
387 |
388 | return log
389 |
390 | def validation_epoch_end(self, outputs):
391 | mean_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
392 | mean_psnr = torch.stack([x["val_psnr"] for x in outputs]).mean()
393 |
394 | self.log("val/loss", mean_loss)
395 | self.log("val/psnr", mean_psnr, prog_bar=True)
396 |
397 |
398 | def main(hparams):
399 | system = NeRFSystem(hparams)
400 | ckpt_cb = ModelCheckpoint(
401 | dirpath=f"ckpts/{hparams.exp_name}", filename="{epoch:d}", monitor="val/psnr", mode="max", save_top_k=5
402 | )
403 | pbar = TQDMProgressBar(refresh_rate=1)
404 | callbacks = [ckpt_cb, pbar]
405 |
406 | logger = TensorBoardLogger(
407 | save_dir="logs", name=hparams.exp_name, default_hp_metric=False)
408 |
409 | trainer = Trainer(
410 | max_epochs=hparams.num_epochs,
411 | callbacks=callbacks,
412 | resume_from_checkpoint=hparams.ckpt_path,
413 | logger=logger,
414 | enable_model_summary=False,
415 | accelerator="auto",
416 | devices=hparams.num_gpus,
417 | num_sanity_val_steps=1,
418 | benchmark=True,
419 | profiler="simple" if hparams.num_gpus == 1 else None,
420 | strategy=DDPPlugin(
421 | find_unused_parameters=False) if hparams.num_gpus > 1 else None,
422 | )
423 |
424 | trainer.fit(system)
425 |
426 |
427 | if __name__ == "__main__":
428 | hparams = get_opts()
429 | main(hparams)
430 |
--------------------------------------------------------------------------------
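The fps() helper at the top of train.py is what read_colmap_meta uses to reduce the filtered COLMAP/MVS point cloud to config["code_cloud"]["num_codes"] anchor points. A minimal call with stand-in data (illustrative only; a CUDA device is required, and the point counts below are arbitrary):

    import numpy as np
    from train import fps

    points = np.random.rand(50_000, 3).astype(np.float32)  # stand-in for pts_world / mvs_points
    anchors = fps(points, 4096)                             # (4096, 3) farthest-point subset
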
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_optimizer as optim
3 |
4 | # optimizer
5 | from torch.optim import SGD, Adam
6 |
7 | # scheduler
8 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
9 |
10 | from .warmup_scheduler import GradualWarmupScheduler
11 | from .visualization import visualize_depth
12 |
13 |
14 | def get_parameters(models):
15 | """Get all model parameters recursively."""
16 | parameters = []
17 | # latent_parameters = []
18 | if isinstance(models, list):
19 | for model in models:
20 | parameters += get_parameters(model)
21 | elif isinstance(models, dict):
22 | for model in models.values():
23 | parameters += get_parameters(model)
24 | else: # models is actually a single pytorch model
25 | parameters += list(models.parameters())
26 | return parameters
27 |
28 |
29 | def get_optimizer(hparams, models):
30 | eps = 1e-8
31 | parameters = get_parameters(models)
32 | if hparams.optimizer == "sgd":
33 | optimizer = SGD(parameters, lr=hparams.lr,
34 | momentum=hparams.momentum, weight_decay=hparams.weight_decay)
35 | elif hparams.optimizer == "adam":
36 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps,
37 | weight_decay=hparams.weight_decay)
38 | elif hparams.optimizer == "radam":
39 | optimizer = optim.RAdam(parameters, lr=hparams.lr,
40 | eps=eps, weight_decay=hparams.weight_decay)
41 | elif hparams.optimizer == "ranger":
42 | optimizer = optim.Ranger(
43 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay)
44 | else:
45 | raise ValueError("optimizer not recognized!")
46 |
47 | return optimizer
48 |
49 |
50 | def get_scheduler(hparams, optimizer):
51 | eps = 1e-8
52 | if hparams.lr_scheduler == "steplr":
53 | scheduler = MultiStepLR(
54 | optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma)
55 | elif hparams.lr_scheduler == "cosine":
56 | scheduler = CosineAnnealingLR(
57 | optimizer, T_max=hparams.num_epochs, eta_min=eps)
58 | else:
59 | raise ValueError("scheduler not recognized!")
60 |
61 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]:
62 | scheduler = GradualWarmupScheduler(
63 | optimizer,
64 | multiplier=hparams.warmup_multiplier,
65 | total_epoch=hparams.warmup_epochs,
66 | after_scheduler=scheduler,
67 | )
68 |
69 | return scheduler
70 |
71 |
72 | def get_learning_rate(optimizer):
73 | for param_group in optimizer.param_groups:
74 | return param_group["lr"]
75 |
76 |
77 | def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]):
78 | checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
79 | checkpoint_ = {}
80 | if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint
81 | checkpoint = checkpoint["state_dict"]
82 | for k, v in checkpoint.items():
83 | if not k.startswith(model_name):
84 | continue
85 | k = k[len(model_name) + 1:]
86 | for prefix in prefixes_to_ignore:
87 | if k.startswith(prefix):
88 | print("ignore", k)
89 | break
90 | else:
91 | checkpoint_[k] = v
92 | return checkpoint_
93 |
94 |
95 | def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]):
96 | if not ckpt_path:
97 | return
98 | model_dict = model.state_dict()
99 | checkpoint_ = extract_model_state_dict(
100 | ckpt_path, model_name, prefixes_to_ignore)
101 | model_dict.update(checkpoint_)
102 | model.load_state_dict(model_dict)
103 |
--------------------------------------------------------------------------------
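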
/utils/save_weights_only.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 |
6 | def get_opts():
7 | parser = argparse.ArgumentParser()
8 |
9 | parser.add_argument("--ckpt_path", type=str, required=True, help="checkpoint path")
10 |
11 | return parser.parse_args()
12 |
13 |
14 | if __name__ == "__main__":
15 | args = get_opts()
16 | checkpoint = torch.load(args.ckpt_path, map_location=torch.device("cpu"))
17 | torch.save(checkpoint["state_dict"], args.ckpt_path.split("/")[-2] + ".ckpt")
18 | print("Done!")
19 |
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torchvision.transforms as T
4 | from PIL import Image
5 |
6 |
7 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET):
8 | """
9 | depth: (H, W)
10 | """
11 | x = depth.cpu().numpy()
12 | x = np.nan_to_num(x) # change nan to 0
13 | mi = np.min(x) # get minimum depth
14 | ma = np.max(x)
15 | x = (x - mi) / max(ma - mi, 1e-8) # normalize to 0~1
16 | x = (255 * x).astype(np.uint8)
17 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
18 | x_ = T.ToTensor()(x_) # (3, H, W)
19 | return x_
20 |
--------------------------------------------------------------------------------
/utils/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
2 |
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 |     """Gradually warm up (increase) the learning rate in the optimizer.
6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 | Args:
8 | optimizer (Optimizer): Wrapped optimizer.
9 | multiplier: target learning rate = base lr * multiplier
10 | total_epoch: target learning rate is reached at total_epoch, gradually
11 |         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
12 | """
13 |
14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
15 | self.multiplier = multiplier
16 | if self.multiplier < 1.0:
17 |             raise ValueError("multiplier should be greater than or equal to 1.")
18 | self.total_epoch = total_epoch
19 | self.after_scheduler = after_scheduler
20 | self.finished = False
21 | super().__init__(optimizer)
22 |
23 | def get_lr(self):
24 | if self.last_epoch > self.total_epoch:
25 | if self.after_scheduler:
26 | if not self.finished:
27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
28 | self.finished = True
29 | return self.after_scheduler.get_lr()
30 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
31 |
32 | return [
33 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) for base_lr in self.base_lrs
34 | ]
35 |
36 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
37 | if epoch is None:
38 | epoch = self.last_epoch + 1
39 | self.last_epoch = (
40 | epoch if epoch != 0 else 1
41 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
42 | if self.last_epoch <= self.total_epoch:
43 | warmup_lr = [
44 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
45 | for base_lr in self.base_lrs
46 | ]
47 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
48 | param_group["lr"] = lr
49 | else:
50 | if epoch is None:
51 | self.after_scheduler.step(metrics, None)
52 | else:
53 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
54 |
55 | def step(self, epoch=None, metrics=None):
56 | if type(self.after_scheduler) != ReduceLROnPlateau:
57 | if self.finished and self.after_scheduler:
58 | if epoch is None:
59 | self.after_scheduler.step(None)
60 | else:
61 | self.after_scheduler.step(epoch - self.total_epoch)
62 | else:
63 | return super(GradualWarmupScheduler, self).step(epoch)
64 | else:
65 | self.step_ReduceLROnPlateau(metrics, epoch)
66 |
--------------------------------------------------------------------------------
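A minimal sketch of how GradualWarmupScheduler is chained with a cosine schedule, mirroring what utils.get_scheduler builds when warmup_epochs > 0 (the model, learning rate, multiplier, and epoch counts below are placeholders):

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from utils.warmup_scheduler import GradualWarmupScheduler

    model = torch.nn.Linear(8, 8)                    # placeholder model
    optimizer = Adam(model.parameters(), lr=5e-4)
    cosine = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-8)

    # Ramp the lr from 5e-4 up to multiplier * 5e-4 over the first 2 epochs,
    # then hand control to the cosine schedule.
    scheduler = GradualWarmupScheduler(optimizer, multiplier=2.0, total_epoch=2,
                                       after_scheduler=cosine)

    for epoch in range(30):
        # ... run one training epoch ...
        scheduler.step()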