├── LICENSE ├── README.md ├── fairnr ├── __init__.py ├── clib │ ├── __init__.py │ ├── include │ │ ├── cuda_utils.h │ │ ├── cutil_math.h │ │ ├── intersect.h │ │ ├── octree.h │ │ └── utils.h │ └── src │ │ ├── binding.cpp │ │ ├── intersect.cpp │ │ ├── intersect_gpu.cu │ │ └── octree.cpp ├── criterions │ ├── __init__.py │ ├── gan_loss.py │ ├── perceptual_loss.py │ ├── rendering_loss.py │ └── utils.py ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── geometry.py │ ├── shape_dataset.py │ └── trajectory.py ├── models │ ├── __init__.py │ ├── fairnr_model.py │ └── nsvf.py ├── modules │ ├── __init__.py │ ├── discriminator.py │ ├── encoder.py │ ├── field.py │ ├── hyper.py │ ├── implicit.py │ ├── linear.py │ ├── reader.py │ └── renderer.py ├── options.py ├── renderer.py └── tasks │ ├── __init__.py │ └── neural_rendering.py ├── fairnr_cli ├── __init__.py ├── extract.py ├── myrender.py └── mytrain.py ├── img └── tpami.jpg ├── render.py ├── render.sh ├── requirements.txt ├── run_scan_plant.sh ├── setup.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 IGLICT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RGBDNeRF: Neural Radiance Fields from Sparse RGB-D Images for High-Quality View Synthesis (IEEE TPAMI) 2 | 3 | ![Teaser image](./img/tpami.jpg) 4 | 5 | ## Abstract 6 | The recently proposed neural radiance fields (NeRF) use a continuous function formulated as a multi-layer perceptron (MLP) to model the appearance and geometry of a 3D scene. This enables realistic synthesis of novel views, even for scenes with view dependent appearance. Many follow-up works have since extended NeRFs in different ways. However, a fundamental restriction of the method remains that it requires a large number of images captured from densely placed viewpoints for high-quality synthesis and the quality of the results quickly degrades when the number of captured views is insufficient. To address this problem, we propose a novel NeRF-based framework capable of high-quality view synthesis using only a sparse set of RGB-D images, which can be easily captured using cameras and LiDAR sensors on current consumer devices. First, a geometric proxy of the scene is reconstructed from the captured RGB-D images. 
Renderings of the reconstructed scene along with precise camera parameters can then be used to pre-train a network. Finally, the network is fine-tuned with a small number of real captured images. We further introduce a patch discriminator to supervise the network under novel views during fine-tuning, as well as a 3D color prior to improve synthesis quality. We demonstrate that our method can generate arbitrary novel views of a 3D scene from as few as 6 RGB-D images. Extensive experiments show the improvements of our method compared with existing NeRF-based methods, including approaches that also aim to reduce the number of input images. 7 | 8 | ## Requirements and Installation 9 | 10 | The code has been tested with the following environment: 11 | 12 | * python 3.7 13 | * pytorch >= 1.7.1 14 | * CUDA 11.0 15 | 16 | Other dependencies can be installed via 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | Then, run 23 | 24 | ```bash 25 | pip install --editable ./ 26 | ``` 27 | 28 | Or, to build the CUDA extensions in place without installing the package, run: 29 | 30 | ```bash 31 | python setup.py build_ext --inplace 32 | ``` 33 | 34 | ## Data 35 | 36 | We provide processed data for the scene 'plant' [here](https://drive.google.com/drive/folders/1u_KSUJOROzg0Vx8jqZEeZaFyugeBI57g?usp=sharing). 37 | 38 | ## Training and Inference 39 | 40 | * For training, please refer to the example script `run_scan_plant.sh`. 41 | 42 | * For rendering, please refer to the example script `render.sh`. 43 | 44 | ## Acknowledgement 45 | This code borrows heavily from [NSVF](https://github.com/facebookresearch/NSVF). 46 | 47 | ## Citation 48 | 49 | If you find this code useful, please cite our work as: 50 | 51 | ``` 52 | @article{yuan2022neural, 53 | title={Neural radiance fields from sparse RGB-D images for high-quality view synthesis}, 54 | author={Yuan, Yu-Jie and Lai, Yu-Kun and Huang, Yi-Hua and Kobbelt, Leif and Gao, Lin}, 55 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 56 | year={2022}, 57 | publisher={IEEE} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /fairnr/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | class ResetTrainerException(Exception): 3 | pass 4 | 5 | 6 | from . import data, tasks, models, modules, criterions 7 | -------------------------------------------------------------------------------- /fairnr/clib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
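#
# This module wraps the compiled CUDA extension `fairnr.clib._ext` with autograd
# Functions for ball / axis-aligned-box / sparse-voxel-octree / triangle ray
# intersection and uniform ray sampling; a pure-Python fallback,
# parallel_ray_sampling, is also provided at the bottom of the file. All kernel
# outputs are marked non-differentiable, so every backward() below simply
# returns None.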
5 | 6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | from __future__ import ( 8 | division, 9 | absolute_import, 10 | with_statement, 11 | print_function, 12 | unicode_literals, 13 | ) 14 | import os, sys 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.autograd import Function 18 | import torch.nn as nn 19 | import sys 20 | import numpy as np 21 | 22 | try: 23 | import builtins 24 | except: 25 | import __builtin__ as builtins 26 | 27 | try: 28 | import fairnr.clib._ext as _ext 29 | except ImportError: 30 | raise ImportError( 31 | "Could not import _ext module.\n" 32 | "Please see the setup instructions in the README" 33 | ) 34 | 35 | MAX_DEPTH = 10000.0 36 | 37 | class BallRayIntersect(Function): 38 | @staticmethod 39 | def forward(ctx, radius, n_max, points, ray_start, ray_dir): 40 | inds, min_depth, max_depth = _ext.ball_intersect( 41 | ray_start.float(), ray_dir.float(), points.float(), radius, n_max) 42 | min_depth = min_depth.type_as(ray_start) 43 | max_depth = max_depth.type_as(ray_start) 44 | 45 | ctx.mark_non_differentiable(inds) 46 | ctx.mark_non_differentiable(min_depth) 47 | ctx.mark_non_differentiable(max_depth) 48 | return inds, min_depth, max_depth 49 | 50 | @staticmethod 51 | def backward(ctx, a, b, c): 52 | return None, None, None, None, None 53 | 54 | ball_ray_intersect = BallRayIntersect.apply 55 | 56 | 57 | class AABBRayIntersect(Function): 58 | @staticmethod 59 | def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir): 60 | # HACK: speed-up ray-voxel intersection by batching... 61 | G = min(2048, int(2 * 10 ** 9 / points.numel())) # HACK: avoid out-of-memory 62 | S, N = ray_start.shape[:2] 63 | K = int(np.ceil(N / G)) 64 | H = K * G 65 | if H > N: 66 | ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) 67 | ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) 68 | ray_start = ray_start.reshape(S * G, K, 3) 69 | ray_dir = ray_dir.reshape(S * G, K, 3) 70 | points = points.expand(S * G, *points.size()[1:]).contiguous() 71 | 72 | inds, min_depth, max_depth = _ext.aabb_intersect( 73 | ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max) 74 | min_depth = min_depth.type_as(ray_start) 75 | max_depth = max_depth.type_as(ray_start) 76 | 77 | inds = inds.reshape(S, H, -1) 78 | min_depth = min_depth.reshape(S, H, -1) 79 | max_depth = max_depth.reshape(S, H, -1) 80 | if H > N: 81 | inds = inds[:, :N] 82 | min_depth = min_depth[:, :N] 83 | max_depth = max_depth[:, :N] 84 | 85 | ctx.mark_non_differentiable(inds) 86 | ctx.mark_non_differentiable(min_depth) 87 | ctx.mark_non_differentiable(max_depth) 88 | return inds, min_depth, max_depth 89 | 90 | @staticmethod 91 | def backward(ctx, a, b, c): 92 | return None, None, None, None, None 93 | 94 | aabb_ray_intersect = AABBRayIntersect.apply 95 | 96 | 97 | class SparseVoxelOctreeRayIntersect(Function): 98 | @staticmethod 99 | def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir): 100 | G = min(2048, int(2 * 10 ** 9 / (points.numel() + children.numel()))) # HACK: avoid out-of-memory 101 | S, N = ray_start.shape[:2] 102 | K = int(np.ceil(N / G)) 103 | H = K * G 104 | if H > N: 105 | ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) 106 | ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) 107 | ray_start = ray_start.reshape(S * G, K, 3) 108 | ray_dir = ray_dir.reshape(S * G, K, 3) 109 | points = points.expand(S * G, *points.size()[1:]).contiguous() 110 | children = children.expand(S * G, *children.size()[1:]).contiguous() 111 | 
inds, min_depth, max_depth = _ext.svo_intersect( 112 | ray_start.float(), ray_dir.float(), points.float(), children.int(), voxelsize, n_max) 113 | 114 | min_depth = min_depth.type_as(ray_start) 115 | max_depth = max_depth.type_as(ray_start) 116 | 117 | inds = inds.reshape(S, H, -1) 118 | min_depth = min_depth.reshape(S, H, -1) 119 | max_depth = max_depth.reshape(S, H, -1) 120 | if H > N: 121 | inds = inds[:, :N] 122 | min_depth = min_depth[:, :N] 123 | max_depth = max_depth[:, :N] 124 | 125 | ctx.mark_non_differentiable(inds) 126 | ctx.mark_non_differentiable(min_depth) 127 | ctx.mark_non_differentiable(max_depth) 128 | return inds, min_depth, max_depth 129 | 130 | @staticmethod 131 | def backward(ctx, a, b, c): 132 | return None, None, None, None, None 133 | 134 | svo_ray_intersect = SparseVoxelOctreeRayIntersect.apply 135 | 136 | 137 | class TriangleRayIntersect(Function): 138 | @staticmethod 139 | def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir): 140 | # HACK: speed-up ray-voxel intersection by batching... 141 | G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel()))) # HACK: avoid out-of-memory 142 | S, N = ray_start.shape[:2] 143 | K = int(np.ceil(N / G)) 144 | H = K * G 145 | if H > N: 146 | ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) 147 | ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) 148 | ray_start = ray_start.reshape(S * G, K, 3) 149 | ray_dir = ray_dir.reshape(S * G, K, 3) 150 | face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3)) 151 | face_points = face_points.unsqueeze(0).expand(S * G, *face_points.size()).contiguous() 152 | inds, depth, uv = _ext.triangle_intersect( 153 | ray_start.float(), ray_dir.float(), face_points.float(), cagesize, blur_ratio, n_max) 154 | depth = depth.type_as(ray_start) 155 | uv = uv.type_as(ray_start) 156 | 157 | inds = inds.reshape(S, H, -1) 158 | depth = depth.reshape(S, H, -1, 3) 159 | uv = uv.reshape(S, H, -1) 160 | if H > N: 161 | inds = inds[:, :N] 162 | depth = depth[:, :N] 163 | uv = uv[:, :N] 164 | 165 | ctx.mark_non_differentiable(inds) 166 | ctx.mark_non_differentiable(depth) 167 | ctx.mark_non_differentiable(uv) 168 | return inds, depth, uv 169 | 170 | @staticmethod 171 | def backward(ctx, a, b, c): 172 | return None, None, None, None, None, None 173 | 174 | triangle_ray_intersect = TriangleRayIntersect.apply 175 | 176 | 177 | class UniformRaySampling(Function): 178 | @staticmethod 179 | def forward(ctx, pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic=False): 180 | G, N, P = 256, pts_idx.size(0), pts_idx.size(1) 181 | H = int(np.ceil(N / G)) * G 182 | if H > N: 183 | pts_idx = torch.cat([pts_idx, pts_idx[:H-N]], 0) 184 | min_depth = torch.cat([min_depth, min_depth[:H-N]], 0) 185 | max_depth = torch.cat([max_depth, max_depth[:H-N]], 0) 186 | pts_idx = pts_idx.reshape(G, -1, P) 187 | min_depth = min_depth.reshape(G, -1, P) 188 | max_depth = max_depth.reshape(G, -1, P) 189 | 190 | # pre-generate noise 191 | max_steps = int(max_ray_length / step_size) 192 | max_steps = max_steps + min_depth.size(-1) * 2 193 | noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) 194 | if deterministic: 195 | noise += 0.5 196 | else: 197 | noise = noise.uniform_() 198 | 199 | # call cuda function 200 | sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling( 201 | pts_idx, min_depth.float(), max_depth.float(), noise.float(), step_size, max_steps) 202 | sampled_depth = sampled_depth.type_as(min_depth) 203 | sampled_dists = 
sampled_dists.type_as(min_depth) 204 | 205 | sampled_idx = sampled_idx.reshape(H, -1) 206 | sampled_depth = sampled_depth.reshape(H, -1) 207 | sampled_dists = sampled_dists.reshape(H, -1) 208 | if H > N: 209 | sampled_idx = sampled_idx[: N] 210 | sampled_depth = sampled_depth[: N] 211 | sampled_dists = sampled_dists[: N] 212 | 213 | max_len = sampled_idx.ne(-1).sum(-1).max() 214 | sampled_idx = sampled_idx[:, :max_len] 215 | sampled_depth = sampled_depth[:, :max_len] 216 | sampled_dists = sampled_dists[:, :max_len] 217 | 218 | ctx.mark_non_differentiable(sampled_idx) 219 | ctx.mark_non_differentiable(sampled_depth) 220 | ctx.mark_non_differentiable(sampled_dists) 221 | return sampled_idx, sampled_depth, sampled_dists 222 | 223 | @staticmethod 224 | def backward(ctx, a, b, c): 225 | return None, None, None, None, None, None 226 | 227 | uniform_ray_sampling = UniformRaySampling.apply 228 | 229 | 230 | # back-up for ray point sampling 231 | @torch.no_grad() 232 | def _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): 233 | # uniform sampling 234 | _min_depth = min_depth.min(1)[0] 235 | _max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0] 236 | max_ray_length = (_max_depth - _min_depth).max() 237 | 238 | delta = torch.arange(int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype) 239 | delta = delta[None, :].expand(min_depth.size(0), delta.size(-1)) 240 | if deterministic: 241 | delta = delta + 0.5 242 | else: 243 | delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99) 244 | delta = delta * MARCH_SIZE 245 | sampled_depth = min_depth[:, :1] + delta 246 | sampled_idx = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 247 | sampled_idx = pts_idx.gather(1, sampled_idx) 248 | 249 | # include all boundary points 250 | sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1) 251 | sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1) 252 | 253 | # reorder 254 | sampled_depth, ordered_index = sampled_depth.sort(-1) 255 | sampled_idx = sampled_idx.gather(1, ordered_index) 256 | sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1] # distances 257 | sampled_depth = .5 * (sampled_depth[:, 1:] + sampled_depth[:, :-1]) # mid-points 258 | 259 | # remove all invalid depths 260 | min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 261 | max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1) 262 | 263 | sampled_depth.masked_fill_( 264 | (max_ids.ne(min_ids)) | 265 | (sampled_depth > _max_depth[:, None]) | 266 | (sampled_dists == 0.0) 267 | , MAX_DEPTH) 268 | sampled_depth, ordered_index = sampled_depth.sort(-1) # sort again 269 | sampled_masks = sampled_depth.eq(MAX_DEPTH) 270 | num_max_steps = (~sampled_masks).sum(-1).max() 271 | 272 | sampled_depth = sampled_depth[:, :num_max_steps] 273 | sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(sampled_masks, 0.0)[:, :num_max_steps] 274 | sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[:, :num_max_steps] 275 | 276 | return sampled_idx, sampled_depth, sampled_dists 277 | 278 | 279 | @torch.no_grad() 280 | def parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): 281 | chunk_size=4096 282 | full_size = min_depth.shape[0] 283 | if full_size <= chunk_size: 284 | return _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic) 285 | 286 | outputs = zip(*[ 287 | 
_parallel_ray_sampling( 288 | MARCH_SIZE, 289 | pts_idx[i:i+chunk_size], min_depth[i:i+chunk_size], max_depth[i:i+chunk_size], 290 | deterministic=deterministic) 291 | for i in range(0, full_size, chunk_size)]) 292 | sampled_idx, sampled_depth, sampled_dists = outputs 293 | 294 | def padding_points(xs, pad): 295 | if len(xs) == 1: 296 | return xs[0] 297 | 298 | maxlen = max([x.size(1) for x in xs]) 299 | full_size = sum([x.size(0) for x in xs]) 300 | xt = xs[0].new_ones(full_size, maxlen).fill_(pad) 301 | st = 0 302 | for i in range(len(xs)): 303 | xt[st: st + xs[i].size(0), :xs[i].size(1)] = xs[i] 304 | st += xs[i].size(0) 305 | return xt 306 | 307 | sampled_idx = padding_points(sampled_idx, -1) 308 | sampled_depth = padding_points(sampled_depth, MAX_DEPTH) 309 | sampled_dists = padding_points(sampled_dists, 0.0) 310 | return sampled_idx, sampled_depth, sampled_dists 311 | 312 | -------------------------------------------------------------------------------- /fairnr/clib/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /fairnr/clib/include/intersect.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
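//
// Declares the ray-primitive intersection entry points (ball, axis-aligned box,
// sparse voxel octree, triangle) and the uniform ray sampler. The CPU-side
// wrappers live in src/intersect.cpp, the CUDA kernels in src/intersect_gpu.cu,
// and src/binding.cpp exposes these functions to Python.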
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | std::tuple ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 11 | const float radius, const int n_max); 12 | std::tuple aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 13 | const float voxelsize, const int n_max); 14 | std::tuple svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children, 15 | const float voxelsize, const int n_max); 16 | std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 17 | const float cagesize, const float blur, const int n_max); 18 | std::tuple uniform_ray_sampling(at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 19 | const float step_size, const int max_steps); -------------------------------------------------------------------------------- /fairnr/clib/include/octree.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | std::tuple build_octree(at::Tensor center, at::Tensor points, int depth); -------------------------------------------------------------------------------- /fairnr/clib/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /fairnr/clib/src/binding.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "intersect.h" 7 | #include "octree.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_intersect", &ball_intersect); 12 | m.def("aabb_intersect", &aabb_intersect); 13 | m.def("svo_intersect", &svo_intersect); 14 | m.def("triangle_intersect", &triangle_intersect); 15 | m.def("uniform_ray_sampling", &uniform_ray_sampling); 16 | 17 | m.def("build_octree", &build_octree); 18 | } -------------------------------------------------------------------------------- /fairnr/clib/src/intersect.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "intersect.h" 7 | #include "utils.h" 8 | #include 9 | 10 | void ball_intersect_point_kernel_wrapper( 11 | int b, int n, int m, float radius, int n_max, 12 | const float *ray_start, const float *ray_dir, const float *points, 13 | int *idx, float *min_depth, float *max_depth); 14 | 15 | std::tuple< at::Tensor, at::Tensor, at::Tensor > ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 16 | const float radius, const int n_max){ 17 | CHECK_CONTIGUOUS(ray_start); 18 | CHECK_CONTIGUOUS(ray_dir); 19 | CHECK_CONTIGUOUS(points); 20 | CHECK_IS_FLOAT(ray_start); 21 | CHECK_IS_FLOAT(ray_dir); 22 | CHECK_IS_FLOAT(points); 23 | CHECK_CUDA(ray_start); 24 | CHECK_CUDA(ray_dir); 25 | CHECK_CUDA(points); 26 | 27 | at::Tensor idx = 28 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 29 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 30 | at::Tensor min_depth = 31 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 32 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 33 | at::Tensor max_depth = 34 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 35 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 36 | ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 37 | radius, n_max, 38 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 39 | idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 40 | return std::make_tuple(idx, min_depth, max_depth); 41 | } 42 | 43 | 44 | void aabb_intersect_point_kernel_wrapper( 45 | int b, int n, int m, float voxelsize, int n_max, 46 | const float *ray_start, const float *ray_dir, const float *points, 47 | int *idx, float *min_depth, float *max_depth); 48 | 49 | std::tuple< at::Tensor, at::Tensor, at::Tensor > aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 50 | const float voxelsize, const int n_max){ 51 | CHECK_CONTIGUOUS(ray_start); 52 | CHECK_CONTIGUOUS(ray_dir); 53 | CHECK_CONTIGUOUS(points); 54 | CHECK_IS_FLOAT(ray_start); 55 | CHECK_IS_FLOAT(ray_dir); 56 | CHECK_IS_FLOAT(points); 57 | CHECK_CUDA(ray_start); 58 | CHECK_CUDA(ray_dir); 59 | CHECK_CUDA(points); 60 | 61 | at::Tensor idx = 62 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 63 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 64 | at::Tensor min_depth = 65 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 66 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 67 | at::Tensor max_depth = 68 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 69 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 70 | aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 71 | voxelsize, n_max, 72 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 73 | idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 74 | return std::make_tuple(idx, min_depth, max_depth); 75 | } 76 | 77 | 78 | void svo_intersect_point_kernel_wrapper( 79 | int b, int n, int m, float voxelsize, int n_max, 80 | const float *ray_start, const float *ray_dir, const float *points, const int *children, 81 | int *idx, float *min_depth, float *max_depth); 82 | 83 | 84 | std::tuple< at::Tensor, at::Tensor, at::Tensor > svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 85 | at::Tensor 
children, const float voxelsize, const int n_max){ 86 | CHECK_CONTIGUOUS(ray_start); 87 | CHECK_CONTIGUOUS(ray_dir); 88 | CHECK_CONTIGUOUS(points); 89 | CHECK_CONTIGUOUS(children); 90 | CHECK_IS_FLOAT(ray_start); 91 | CHECK_IS_FLOAT(ray_dir); 92 | CHECK_IS_FLOAT(points); 93 | CHECK_CUDA(ray_start); 94 | CHECK_CUDA(ray_dir); 95 | CHECK_CUDA(points); 96 | CHECK_CUDA(children); 97 | 98 | at::Tensor idx = 99 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 100 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 101 | at::Tensor min_depth = 102 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 103 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 104 | at::Tensor max_depth = 105 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 106 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 107 | svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 108 | voxelsize, n_max, 109 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 110 | children.data_ptr (), idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 111 | return std::make_tuple(idx, min_depth, max_depth); 112 | } 113 | 114 | 115 | void triangle_intersect_point_kernel_wrapper( 116 | int b, int n, int m, float cagesize, float blur, int n_max, 117 | const float *ray_start, const float *ray_dir, const float *face_points, 118 | int *idx, float *depth, float *uv); 119 | 120 | std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 121 | const float cagesize, const float blur, const int n_max){ 122 | CHECK_CONTIGUOUS(ray_start); 123 | CHECK_CONTIGUOUS(ray_dir); 124 | CHECK_CONTIGUOUS(face_points); 125 | CHECK_IS_FLOAT(ray_start); 126 | CHECK_IS_FLOAT(ray_dir); 127 | CHECK_IS_FLOAT(face_points); 128 | CHECK_CUDA(ray_start); 129 | CHECK_CUDA(ray_dir); 130 | CHECK_CUDA(face_points); 131 | 132 | at::Tensor idx = 133 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 134 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 135 | at::Tensor depth = 136 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3}, 137 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 138 | at::Tensor uv = 139 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2}, 140 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 141 | triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1), 142 | cagesize, blur, n_max, 143 | ray_start.data_ptr (), ray_dir.data_ptr (), face_points.data_ptr (), 144 | idx.data_ptr (), depth.data_ptr (), uv.data_ptr ()); 145 | return std::make_tuple(idx, depth, uv); 146 | } 147 | 148 | 149 | void uniform_ray_sampling_kernel_wrapper( 150 | int b, int num_rays, int max_hits, int max_steps, float step_size, 151 | const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, 152 | int *sampled_idx, float *sampled_depth, float *sampled_dists); 153 | 154 | 155 | std::tuple< at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling( 156 | at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 157 | const float step_size, const int max_steps){ 158 | 159 | CHECK_CONTIGUOUS(pts_idx); 160 | CHECK_CONTIGUOUS(min_depth); 161 | CHECK_CONTIGUOUS(max_depth); 162 | CHECK_CONTIGUOUS(uniform_noise); 163 | CHECK_IS_FLOAT(min_depth); 164 | CHECK_IS_FLOAT(max_depth); 165 | CHECK_IS_FLOAT(uniform_noise); 
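// pts_idx holds integer voxel indices, so it only gets the device/contiguity
// checks below. The sampled_idx output allocated further down is initialised to
// -1, which both the CUDA kernel and the Python wrapper treat as "no sample"
// padding.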
166 | CHECK_CUDA(pts_idx); 167 | CHECK_CUDA(min_depth); 168 | CHECK_CUDA(max_depth); 169 | CHECK_CUDA(uniform_noise); 170 | 171 | at::Tensor sampled_idx = 172 | -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, 173 | at::device(pts_idx.device()).dtype(at::ScalarType::Int)); 174 | at::Tensor sampled_depth = 175 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 176 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 177 | at::Tensor sampled_dists = 178 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 179 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 180 | uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), 181 | step_size, 182 | pts_idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr (), 183 | uniform_noise.data_ptr (), sampled_idx.data_ptr (), 184 | sampled_depth.data_ptr (), sampled_dists.data_ptr ()); 185 | return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); 186 | } -------------------------------------------------------------------------------- /fairnr/clib/src/intersect_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "cuda_utils.h" 12 | #include "cutil_math.h" // required for float3 vector math 13 | 14 | 15 | __global__ void ball_intersect_point_kernel( 16 | int b, int n, int m, float radius, 17 | int n_max, 18 | const float *__restrict__ ray_start, 19 | const float *__restrict__ ray_dir, 20 | const float *__restrict__ points, 21 | int *__restrict__ idx, 22 | float *__restrict__ min_depth, 23 | float *__restrict__ max_depth) { 24 | 25 | int batch_index = blockIdx.x; 26 | points += batch_index * n * 3; 27 | ray_start += batch_index * m * 3; 28 | ray_dir += batch_index * m * 3; 29 | idx += batch_index * m * n_max; 30 | min_depth += batch_index * m * n_max; 31 | max_depth += batch_index * m * n_max; 32 | 33 | int index = threadIdx.x; 34 | int stride = blockDim.x; 35 | float radius2 = radius * radius; 36 | 37 | for (int j = index; j < m; j += stride) { 38 | 39 | float x0 = ray_start[j * 3 + 0]; 40 | float y0 = ray_start[j * 3 + 1]; 41 | float z0 = ray_start[j * 3 + 2]; 42 | float xw = ray_dir[j * 3 + 0]; 43 | float yw = ray_dir[j * 3 + 1]; 44 | float zw = ray_dir[j * 3 + 2]; 45 | 46 | for (int l = 0; l < n_max; ++l) { 47 | idx[j * n_max + l] = -1; 48 | } 49 | 50 | for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) { 51 | float x = points[k * 3 + 0] - x0; 52 | float y = points[k * 3 + 1] - y0; 53 | float z = points[k * 3 + 2] - z0; 54 | float d2 = x * x + y * y + z * z; 55 | float d2_proj = pow(x * xw + y * yw + z * zw, 2); 56 | float r2 = d2 - d2_proj; 57 | 58 | if (r2 < radius2) { 59 | idx[j * n_max + cnt] = k; 60 | 61 | float depth = sqrt(d2_proj); 62 | float depth_blur = sqrt(radius2 - r2); 63 | 64 | min_depth[j * n_max + cnt] = depth - depth_blur; 65 | max_depth[j * n_max + cnt] = depth + depth_blur; 66 | ++cnt; 67 | } 68 | } 69 | } 70 | } 71 | 72 | 73 | __device__ float2 RayAABBIntersection( 74 | const float3 &ori, 75 | const float3 &dir, 76 | const float3 ¢er, 77 | float half_voxel) { 78 | 79 | float f_low = 0; 80 | float f_high = 100000.; 81 | float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb; 82 | 83 | for (int d = 0; 
d < 3; ++d) { 84 | switch (d) { 85 | case 0: 86 | inv_ray_dir = __fdividef(1.0f, dir.x); start = ori.x; aabb = center.x; break; 87 | case 1: 88 | inv_ray_dir = __fdividef(1.0f, dir.y); start = ori.y; aabb = center.y; break; 89 | case 2: 90 | inv_ray_dir = __fdividef(1.0f, dir.z); start = ori.z; aabb = center.z; break; 91 | } 92 | 93 | f_dim_low = (aabb - half_voxel - start) * inv_ray_dir; 94 | f_dim_high = (aabb + half_voxel - start) * inv_ray_dir; 95 | 96 | // Make sure low is less than high 97 | if (f_dim_high < f_dim_low) { 98 | temp = f_dim_low; 99 | f_dim_low = f_dim_high; 100 | f_dim_high = temp; 101 | } 102 | 103 | // If this dimension's high is less than the low we got then we definitely missed. 104 | if (f_dim_high < f_low) { 105 | return make_float2(-1.0f, -1.0f); 106 | } 107 | 108 | // Likewise if the low is less than the high. 109 | if (f_dim_low > f_high) { 110 | return make_float2(-1.0f, -1.0f); 111 | } 112 | 113 | // Add the clip from this dimension to the previous results 114 | f_low = (f_dim_low > f_low) ? f_dim_low : f_low; 115 | f_high = (f_dim_high < f_high) ? f_dim_high : f_high; 116 | 117 | if (f_low > f_high) { 118 | return make_float2(-1.0f, -1.0f); 119 | } 120 | } 121 | return make_float2(f_low, f_high); 122 | } 123 | 124 | 125 | __global__ void aabb_intersect_point_kernel( 126 | int b, int n, int m, float voxelsize, 127 | int n_max, 128 | const float *__restrict__ ray_start, 129 | const float *__restrict__ ray_dir, 130 | const float *__restrict__ points, 131 | int *__restrict__ idx, 132 | float *__restrict__ min_depth, 133 | float *__restrict__ max_depth) { 134 | 135 | int batch_index = blockIdx.x; 136 | points += batch_index * n * 3; 137 | ray_start += batch_index * m * 3; 138 | ray_dir += batch_index * m * 3; 139 | idx += batch_index * m * n_max; 140 | min_depth += batch_index * m * n_max; 141 | max_depth += batch_index * m * n_max; 142 | 143 | int index = threadIdx.x; 144 | int stride = blockDim.x; 145 | float half_voxel = voxelsize * 0.5; 146 | 147 | for (int j = index; j < m; j += stride) { 148 | for (int l = 0; l < n_max; ++l) { 149 | idx[j * n_max + l] = -1; 150 | } 151 | 152 | for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) { 153 | float2 depths = RayAABBIntersection( 154 | make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), 155 | make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), 156 | make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), 157 | half_voxel); 158 | 159 | if (depths.x > -1.0f){ 160 | idx[j * n_max + cnt] = k; 161 | min_depth[j * n_max + cnt] = depths.x; 162 | max_depth[j * n_max + cnt] = depths.y; 163 | ++cnt; 164 | } 165 | } 166 | } 167 | } 168 | 169 | 170 | __global__ void svo_intersect_point_kernel( 171 | int b, int n, int m, float voxelsize, 172 | int n_max, 173 | const float *__restrict__ ray_start, 174 | const float *__restrict__ ray_dir, 175 | const float *__restrict__ points, 176 | const int *__restrict__ children, 177 | int *__restrict__ idx, 178 | float *__restrict__ min_depth, 179 | float *__restrict__ max_depth) { 180 | /* 181 | TODO: this is an inefficient implementation of the 182 | navie Ray -- Sparse Voxel Octree Intersection. 183 | It can be further improved using: 184 | 185 | Revelles, Jorge, Carlos Urena, and Miguel Lastra. 186 | "An efficient parametric algorithm for octree traversal." (2000). 
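       The implementation below instead performs an explicit per-ray depth-first
       traversal with a fixed 256-entry stack: the root voxel is stored as the
       last node, children[k * 9 + 0..7] hold child indices (-1 if absent), and
       children[k * 9 + 8] stores the node size, where 1 marks a terminal leaf
       whose AABB hit interval is written out directly.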
187 | */ 188 | int batch_index = blockIdx.x; 189 | points += batch_index * n * 3; 190 | children += batch_index * n * 9; 191 | ray_start += batch_index * m * 3; 192 | ray_dir += batch_index * m * 3; 193 | idx += batch_index * m * n_max; 194 | min_depth += batch_index * m * n_max; 195 | max_depth += batch_index * m * n_max; 196 | 197 | int index = threadIdx.x; 198 | int stride = blockDim.x; 199 | float half_voxel = voxelsize * 0.5; 200 | 201 | for (int j = index; j < m; j += stride) { 202 | for (int l = 0; l < n_max; ++l) { 203 | idx[j * n_max + l] = -1; 204 | } 205 | int stack[256] = {-1}; // DFS, initialize the stack 206 | int ptr = 0, cnt = 0, k = -1; 207 | stack[ptr] = n - 1; // ROOT node is always the last 208 | while (ptr > -1 && cnt < n_max) { 209 | assert((ptr < 256)); 210 | 211 | // evaluate the current node 212 | k = stack[ptr]; 213 | float2 depths = RayAABBIntersection( 214 | make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), 215 | make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), 216 | make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), 217 | half_voxel * float(children[k * 9 + 8])); 218 | stack[ptr] = -1; ptr--; 219 | 220 | if (depths.x > -1.0f) { // ray did not miss the voxel 221 | // TODO: here it should be able to know which children is ok, further optimize the code 222 | if (children[k * 9 + 8] == 1) { // this is a terminal node 223 | idx[j * n_max + cnt] = k; 224 | min_depth[j * n_max + cnt] = depths.x; 225 | max_depth[j * n_max + cnt] = depths.y; 226 | ++cnt; continue; 227 | } 228 | 229 | for (int u = 0; u < 8; u++) { 230 | if (children[k * 9 + u] > -1) { 231 | ptr++; stack[ptr] = children[k * 9 + u]; // push child to the stack 232 | } 233 | } 234 | } 235 | } 236 | } 237 | } 238 | 239 | 240 | __device__ float3 RayTriangleIntersection( 241 | const float3 &ori, 242 | const float3 &dir, 243 | const float3 &v0, 244 | const float3 &v1, 245 | const float3 &v2, 246 | float blur) { 247 | 248 | float3 v0v1 = v1 - v0; 249 | float3 v0v2 = v2 - v0; 250 | float3 v0O = ori - v0; 251 | float3 dir_crs_v0v2 = cross(dir, v0v2); 252 | 253 | float det = dot(v0v1, dir_crs_v0v2); 254 | det = __fdividef(1.0f, det); // CUDA intrinsic function 255 | 256 | float u = dot(v0O, dir_crs_v0v2) * det; 257 | if (u < 0.0f - blur || u > 1.0f + blur) 258 | return make_float3(-1.0f, 0.0f, 0.0f); 259 | 260 | float3 v0O_crs_v0v1 = cross(v0O, v0v1); 261 | float v = dot(dir, v0O_crs_v0v1) * det; 262 | if (v < 0.0f - blur || v > 1.0f + blur) 263 | return make_float3(-1.0f, 0.0f, 0.0f); 264 | 265 | if ((u + v) < 0.0f - blur || (u + v) > 1.0f + blur) 266 | return make_float3(-1.0f, 0.0f, 0.0f); 267 | 268 | float t = dot(v0v2, v0O_crs_v0v1) * det; 269 | return make_float3(t, u, v); 270 | } 271 | 272 | 273 | __global__ void triangle_intersect_point_kernel( 274 | int b, int n, int m, float cagesize, 275 | float blur, int n_max, 276 | const float *__restrict__ ray_start, 277 | const float *__restrict__ ray_dir, 278 | const float *__restrict__ face_points, 279 | int *__restrict__ idx, 280 | float *__restrict__ depth, 281 | float *__restrict__ uv) { 282 | 283 | int batch_index = blockIdx.x; 284 | face_points += batch_index * n * 9; 285 | ray_start += batch_index * m * 3; 286 | ray_dir += batch_index * m * 3; 287 | idx += batch_index * m * n_max; 288 | depth += batch_index * m * n_max * 3; 289 | uv += batch_index * m * n_max * 2; 290 | 291 | int index = threadIdx.x; 292 | int stride = blockDim.x; 293 | for (int j = index; j < m; j += stride) { 294 | // 
go over rays 295 | for (int l = 0; l < n_max; ++l) { 296 | idx[j * n_max + l] = -1; 297 | } 298 | 299 | int cnt = 0; 300 | for (int k = 0; k < n && cnt < n_max; ++k) { 301 | // go over triangles 302 | float3 tuv = RayTriangleIntersection( 303 | make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), 304 | make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), 305 | make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], face_points[k * 9 + 2]), 306 | make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]), 307 | make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]), 308 | blur); 309 | 310 | if (tuv.x > 0) { 311 | int ki = k; 312 | float d = tuv.x, u = tuv.y, v = tuv.z; 313 | 314 | // sort 315 | for (int l = 0; l < cnt; l++) { 316 | if (d < depth[j * n_max * 3 + l * 3]) { 317 | swap(ki, idx[j * n_max + l]); 318 | swap(d, depth[j * n_max * 3 + l * 3]); 319 | swap(u, uv[j * n_max * 2 + l * 2]); 320 | swap(v, uv[j * n_max * 2 + l * 2 + 1]); 321 | } 322 | } 323 | idx[j * n_max + cnt] = ki; 324 | depth[j * n_max * 3 + cnt * 3] = d; 325 | uv[j * n_max * 2 + cnt * 2] = u; 326 | uv[j * n_max * 2 + cnt * 2 + 1] = v; 327 | cnt++; 328 | } 329 | } 330 | 331 | for (int l = 0; l < cnt; l++) { 332 | // compute min_depth 333 | if (l == 0) 334 | depth[j * n_max * 3 + l * 3 + 1] = -cagesize; 335 | else 336 | depth[j * n_max * 3 + l * 3 + 1] = -fminf(cagesize, 337 | .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3])); 338 | 339 | // compute max_depth 340 | if (l == cnt - 1) 341 | depth[j * n_max * 3 + l * 3 + 2] = cagesize; 342 | else 343 | depth[j * n_max * 3 + l * 3 + 2] = fminf(cagesize, 344 | .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3])); 345 | } 346 | } 347 | } 348 | 349 | void ball_intersect_point_kernel_wrapper( 350 | int b, int n, int m, float radius, int n_max, 351 | const float *ray_start, const float *ray_dir, const float *points, 352 | int *idx, float *min_depth, float *max_depth) { 353 | 354 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 355 | ball_intersect_point_kernel<<>>( 356 | b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); 357 | 358 | CUDA_CHECK_ERRORS(); 359 | } 360 | 361 | 362 | void aabb_intersect_point_kernel_wrapper( 363 | int b, int n, int m, float voxelsize, int n_max, 364 | const float *ray_start, const float *ray_dir, const float *points, 365 | int *idx, float *min_depth, float *max_depth) { 366 | 367 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 368 | aabb_intersect_point_kernel<<>>( 369 | b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); 370 | 371 | CUDA_CHECK_ERRORS(); 372 | } 373 | 374 | 375 | void svo_intersect_point_kernel_wrapper( 376 | int b, int n, int m, float voxelsize, int n_max, 377 | const float *ray_start, const float *ray_dir, const float *points, const int *children, 378 | int *idx, float *min_depth, float *max_depth) { 379 | 380 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 381 | svo_intersect_point_kernel<<>>( 382 | b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth); 383 | 384 | CUDA_CHECK_ERRORS(); 385 | } 386 | 387 | 388 | void triangle_intersect_point_kernel_wrapper( 389 | int b, int n, int m, float cagesize, float blur, int n_max, 390 | const float *ray_start, const float *ray_dir, const float *face_points, 391 | int *idx, float *depth, float *uv) { 392 | 393 | cudaStream_t 
stream = at::cuda::getCurrentCUDAStream(); 394 | triangle_intersect_point_kernel<<>>( 395 | b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv); 396 | 397 | CUDA_CHECK_ERRORS(); 398 | } 399 | 400 | 401 | __global__ void uniform_ray_sampling_kernel( 402 | int b, int num_rays, 403 | int max_hits, 404 | int max_steps, 405 | float step_size, 406 | const int *__restrict__ pts_idx, 407 | const float *__restrict__ min_depth, 408 | const float *__restrict__ max_depth, 409 | const float *__restrict__ uniform_noise, 410 | int *__restrict__ sampled_idx, 411 | float *__restrict__ sampled_depth, 412 | float *__restrict__ sampled_dists) { 413 | 414 | int batch_index = blockIdx.x; 415 | int index = threadIdx.x; 416 | int stride = blockDim.x; 417 | 418 | pts_idx += batch_index * num_rays * max_hits; 419 | min_depth += batch_index * num_rays * max_hits; 420 | max_depth += batch_index * num_rays * max_hits; 421 | 422 | uniform_noise += batch_index * num_rays * max_steps; 423 | sampled_idx += batch_index * num_rays * max_steps; 424 | sampled_depth += batch_index * num_rays * max_steps; 425 | sampled_dists += batch_index * num_rays * max_steps; 426 | 427 | // loop over all rays 428 | for (int j = index; j < num_rays; j += stride) { 429 | int H = j * max_hits, K = j * max_steps; 430 | int s = 0, ucur = 0, umin = 0, umax = 0; 431 | float last_min_depth, last_max_depth, curr_depth; 432 | 433 | // sort all depths 434 | while (true) { 435 | if (umax == max_hits || ucur == max_steps || pts_idx[H + umax] == -1) { 436 | break; // reach the maximum 437 | } 438 | if (umin < max_hits) { 439 | last_min_depth = min_depth[H + umin]; 440 | } else { 441 | last_min_depth = 10000.0; 442 | } 443 | if (umax < max_hits) { 444 | last_max_depth = max_depth[H + umax]; 445 | } else { 446 | last_max_depth = 10000.0; 447 | } 448 | if (ucur < max_steps) { 449 | curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size; 450 | } 451 | 452 | if (last_max_depth <= curr_depth && last_max_depth <= last_min_depth) { 453 | sampled_depth[K + s] = last_max_depth; 454 | sampled_idx[K + s] = pts_idx[H + umax]; 455 | umax++; s++; continue; 456 | } 457 | if (curr_depth <= last_min_depth && curr_depth <= last_max_depth) { 458 | sampled_depth[K + s] = curr_depth; 459 | sampled_idx[K + s] = pts_idx[H + umin - 1]; 460 | ucur++; s++; continue; 461 | } 462 | if (last_min_depth <= curr_depth && last_min_depth <= last_max_depth) { 463 | sampled_depth[K + s] = last_min_depth; 464 | sampled_idx[K + s] = pts_idx[H + umin]; 465 | umin++; s++; continue; 466 | } 467 | } 468 | 469 | float l_depth, r_depth; 470 | int step = 0; 471 | for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++) { 472 | if (sampled_idx[K + ucur + 1] == -1) break; 473 | l_depth = sampled_depth[K + ucur]; 474 | r_depth = sampled_depth[K + ucur + 1]; 475 | sampled_depth[K + ucur] = (l_depth + r_depth) * .5; 476 | sampled_dists[K + ucur] = (r_depth - l_depth); 477 | if (umin < max_hits && sampled_depth[K + ucur] >= min_depth[H + umin] && pts_idx[H + umin] > -1) umin++; 478 | if (umax < max_hits && sampled_depth[K + ucur] >= max_depth[H + umax] && pts_idx[H + umax] > -1) umax++; 479 | if (umax == max_hits || pts_idx[H + umax] == -1) break; 480 | if (umin - 1 == umax && sampled_dists[K + ucur] > 0) { 481 | sampled_depth[K + step] = sampled_depth[K + ucur]; 482 | sampled_dists[K + step] = sampled_dists[K + ucur]; 483 | sampled_idx[K + step] = sampled_idx[K + ucur]; 484 | step++; 485 | } 486 | } 487 | 488 | for (int s = step; s < 
max_steps; s++) { 489 | sampled_idx[K + s] = -1; 490 | } 491 | } 492 | } 493 | 494 | 495 | void uniform_ray_sampling_kernel_wrapper( 496 | int b, int num_rays, int max_hits, int max_steps, float step_size, 497 | const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, 498 | int *sampled_idx, float *sampled_depth, float *sampled_dists) { 499 | 500 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 501 | uniform_ray_sampling_kernel<<>>( 502 | b, num_rays, max_hits, max_steps, step_size, pts_idx, 503 | min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists); 504 | 505 | CUDA_CHECK_ERRORS(); 506 | } 507 | 508 | -------------------------------------------------------------------------------- /fairnr/clib/src/octree.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "octree.h" 7 | #include "utils.h" 8 | #include 9 | #include 10 | using namespace std::chrono; 11 | 12 | 13 | typedef struct OcTree 14 | { 15 | int depth; 16 | int index; 17 | at::Tensor center; 18 | struct OcTree *children[8]; 19 | void init(at::Tensor center, int d, int i) { 20 | this->center = center; 21 | this->depth = d; 22 | this->index = i; 23 | for (int i=0; i<8; i++) this->children[i] = nullptr; 24 | } 25 | }OcTree; 26 | 27 | class EasyOctree { 28 | public: 29 | OcTree *root; 30 | int total; 31 | int terminal; 32 | 33 | at::Tensor all_centers; 34 | at::Tensor all_children; 35 | 36 | EasyOctree(at::Tensor center, int depth) { 37 | root = new OcTree; 38 | root->init(center, depth, -1); 39 | total = -1; 40 | terminal = -1; 41 | } 42 | ~EasyOctree() { 43 | OcTree *p = root; 44 | destory(p); 45 | } 46 | void destory(OcTree * &p); 47 | void insert(OcTree * &p, at::Tensor point, int index); 48 | void finalize(); 49 | std::pair count(OcTree * &p); 50 | }; 51 | 52 | void EasyOctree::destory(OcTree * &p){ 53 | if (p != nullptr) { 54 | for (int i=0; i<8; i++) { 55 | if (p->children[i] != nullptr) destory(p->children[i]); 56 | } 57 | delete p; 58 | p = nullptr; 59 | } 60 | } 61 | 62 | void EasyOctree::insert(OcTree * &p, at::Tensor point, int index) { 63 | at::Tensor diff = (point > p->center).to(at::kInt); 64 | int idx = diff[0].item() + 2 * diff[1].item() + 4 * diff[2].item(); 65 | if (p->depth == 0) { 66 | p->children[idx] = new OcTree; 67 | p->children[idx]->init(point, -1, index); 68 | } else { 69 | if (p->children[idx] == nullptr) { 70 | int length = 1 << (p->depth - 1); 71 | at::Tensor new_center = p->center + (2 * diff - 1) * length; 72 | p->children[idx] = new OcTree; 73 | p->children[idx]->init(new_center, p->depth-1, -1); 74 | } 75 | insert(p->children[idx], point, index); 76 | } 77 | } 78 | 79 | std::pair EasyOctree::count(OcTree * &p) { 80 | int total = 0, terminal = 0; 81 | for (int i=0; i<8; i++) { 82 | if (p->children[i] != nullptr) { 83 | std::pair sub = count(p->children[i]); 84 | total += sub.first; 85 | terminal += sub.second; 86 | } 87 | } 88 | total += 1; 89 | if (p->depth == -1) terminal += 1; 90 | return std::make_pair(total, terminal); 91 | } 92 | 93 | void EasyOctree::finalize() { 94 | std::pair outs = count(root); 95 | total = outs.first; terminal = outs.second; 96 | 97 | all_centers = 98 | torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int)); 99 | 
all_children = 100 | -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int)); 101 | 102 | int node_idx = outs.first - 1; 103 | root->index = node_idx; 104 | 105 | std::queue all_leaves; all_leaves.push(root); 106 | while (!all_leaves.empty()) { 107 | OcTree* node_ptr = all_leaves.front(); 108 | all_leaves.pop(); 109 | for (int i=0; i<8; i++) { 110 | if (node_ptr->children[i] != nullptr) { 111 | if (node_ptr->children[i]->depth > -1) { 112 | node_idx--; 113 | node_ptr->children[i]->index = node_idx; 114 | } 115 | all_leaves.push(node_ptr->children[i]); 116 | all_children[node_ptr->index][i] = node_ptr->children[i]->index; 117 | } 118 | } 119 | all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1); 120 | all_centers[node_ptr->index] = node_ptr->center; 121 | } 122 | assert (node_idx == outs.second); 123 | }; 124 | 125 | std::tuple build_octree(at::Tensor center, at::Tensor points, int depth) { 126 | auto start = high_resolution_clock::now(); 127 | EasyOctree tree(center, depth); 128 | for (int k=0; k(stop - start); 133 | printf("Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\n", 134 | tree.total, tree.terminal, float(duration.count()) / 1000000.); 135 | return std::make_tuple(tree.all_centers, tree.all_children); 136 | } -------------------------------------------------------------------------------- /fairnr/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import os 4 | 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | criterion_name = file[: file.find(".py")] 8 | importlib.import_module( 9 | "fairnr.criterions." + criterion_name 10 | ) 11 | -------------------------------------------------------------------------------- /fairnr/criterions/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GANLoss(nn.Module): 5 | """Define different GAN objectives. 6 | The GANLoss class abstracts away the need to create the target label tensor 7 | that has the same size as the input. 8 | """ 9 | 10 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 11 | """ Initialize the GANLoss class. 12 | Parameters: 13 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 14 | target_real_label (bool) - - label for a real image 15 | target_fake_label (bool) - - label of a fake image 16 | Note: Do not use sigmoid as the last layer of Discriminator. 17 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 18 | """ 19 | super(GANLoss, self).__init__() 20 | self.register_buffer('real_label', torch.tensor(target_real_label).cuda()) 21 | self.register_buffer('fake_label', torch.tensor(target_fake_label).cuda()) 22 | self.gan_mode = gan_mode 23 | if gan_mode == 'lsgan': 24 | self.loss = nn.MSELoss() 25 | elif gan_mode == 'vanilla': 26 | self.loss = nn.BCEWithLogitsLoss() 27 | elif gan_mode in ['wgangp']: 28 | self.loss = None 29 | else: 30 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 31 | 32 | def get_target_tensor(self, prediction, target_is_real): 33 | """Create label tensors with the same size as the input. 
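        The stored real/fake label is broadcast via expand_as to match the
        discriminator output, so no extra memory is allocated per call.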
34 | Parameters: 35 | prediction (tensor) - - tpyically the prediction from a discriminator 36 | target_is_real (bool) - - if the ground truth label is for real images or fake images 37 | Returns: 38 | A label tensor filled with ground truth label, and with the size of the input 39 | """ 40 | 41 | if target_is_real: 42 | target_tensor = self.real_label 43 | else: 44 | target_tensor = self.fake_label 45 | return target_tensor.expand_as(prediction) 46 | 47 | def __call__(self, prediction, target_is_real): 48 | """Calculate loss given Discriminator's output and grount truth labels. 49 | Parameters: 50 | prediction (tensor) - - tpyically the prediction output from a discriminator 51 | target_is_real (bool) - - if the ground truth label is for real images or fake images 52 | Returns: 53 | the calculated loss. 54 | """ 55 | if self.gan_mode in ['lsgan', 'vanilla']: 56 | target_tensor = self.get_target_tensor(prediction, target_is_real) 57 | loss = self.loss(prediction, target_tensor) 58 | elif self.gan_mode == 'wgangp': 59 | if target_is_real: 60 | loss = -prediction.mean() 61 | else: 62 | loss = prediction.mean() 63 | return loss -------------------------------------------------------------------------------- /fairnr/criterions/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchvision 4 | 5 | class VGGPerceptualLoss(torch.nn.Module): 6 | def __init__(self, resize=False): 7 | super(VGGPerceptualLoss, self).__init__() 8 | blocks = [] 9 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 10 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 11 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 12 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 13 | self.blocks = torch.nn.ModuleList(blocks) 14 | self.transform = torch.nn.functional.interpolate 15 | self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 16 | self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 17 | self.resize = resize 18 | 19 | # NO GRADIENT! 20 | for param in self.parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, input, target, level=2): 24 | # print(input.device, input.dtype, self.mean.device, self.mean.dtype, self.std, self.std.dtype) 25 | if input.shape[1] != 3: 26 | input = input.repeat(1, 3, 1, 1) 27 | target = target.repeat(1, 3, 1, 1) 28 | input = (input-self.mean) / self.std 29 | target = (target-self.mean) / self.std 30 | 31 | if self.resize: 32 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) 33 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) 34 | 35 | loss = 0.0 36 | x = input 37 | y = target 38 | for i, block in enumerate(self.blocks): 39 | x = block(x) 40 | y = block(y) 41 | if i < level: 42 | loss += torch.nn.functional.mse_loss(x, y) 43 | else: 44 | break 45 | return loss 46 | -------------------------------------------------------------------------------- /fairnr/criterions/rendering_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | 4 | import torch.nn.functional as F 5 | import torch 6 | from torch import Tensor 7 | 8 | import fairnr.criterions.utils as utils 9 | 10 | def item(tensor): 11 | # tpu-comment: making this a no-op for xla devices. 
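    # (calling .item() on an XLA tensor forces a device-to-host sync, so for XLA
    # we just detach and hand the tensor back)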
12 | if torch.is_tensor(tensor) and tensor.device.type == 'xla': 13 | return tensor.detach() 14 | if hasattr(tensor, "item"): 15 | return tensor.item() 16 | if hasattr(tensor, "__getitem__"): 17 | return tensor[0] 18 | return tensor 19 | 20 | class RenderingCriterion(object): 21 | 22 | def __init__(self, args): 23 | super().__init__() 24 | self.args = args 25 | 26 | @classmethod 27 | def build_criterion(cls, args): 28 | """Construct a criterion from command-line args.""" 29 | return cls(args) 30 | 31 | @staticmethod 32 | def add_args(parser): 33 | """Add criterion-specific arguments to the parser.""" 34 | pass 35 | 36 | def forward(self, model, sample, reduce=True): 37 | """Compute the loss for the given sample. 38 | 39 | Returns a tuple with three elements: 40 | 1) the loss 41 | 2) the sample size, which is used as the denominator for the gradient 42 | 3) logging outputs to display while training 43 | """ 44 | net_output = model(**sample) 45 | loss, loss_output, loss_D = self.compute_loss(model, net_output, sample, reduce=reduce) 46 | sample_size = 1 47 | 48 | logging_output = { 49 | 'loss': loss.data.item() if reduce else loss.data, 50 | 'nsentences': sample['alpha'].size(0), 51 | 'ntokens': sample['alpha'].size(1), 52 | 'npixels': sample['alpha'].size(2), 53 | 'sample_size': sample_size, 54 | } 55 | for w in loss_output: 56 | logging_output[w] = loss_output[w] 57 | 58 | return loss, sample_size, logging_output, loss_D 59 | 60 | def compute_loss(self, model, net_output, sample, reduce=True): 61 | raise NotImplementedError 62 | 63 | @staticmethod 64 | def logging_outputs_can_be_summed() -> bool: 65 | """ 66 | Whether the logging outputs returned by `forward` can be summed 67 | across workers prior to calling `reduce_metrics`. Setting this 68 | to True will improves distributed training speed. 69 | """ 70 | return True 71 | 72 | 73 | class SRNLossCriterion(RenderingCriterion): 74 | 75 | def __init__(self, args): 76 | super().__init__(args) 77 | # HACK: to avoid warnings in c10d 78 | self.dummy_loss = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32), requires_grad=True) 79 | if args.vgg_weight > 0: 80 | from fairnr.criterions.perceptual_loss import VGGPerceptualLoss 81 | self.vgg = VGGPerceptualLoss(resize=False) 82 | 83 | if args.eval_lpips: # not use??? 84 | from lpips import LPIPS 85 | self.lpips = LPIPS(net='alex') 86 | 87 | if args.gan_weight > 0: 88 | from fairnr.criterions.gan_loss import GANLoss 89 | self.criterionGAN = GANLoss(args.gan_mode) 90 | 91 | @staticmethod 92 | def add_args(parser): 93 | """Add criterion-specific arguments to the parser.""" 94 | parser.add_argument('--L1', action='store_true', 95 | help='if enabled, use L1 instead of L2 for RGB loss') 96 | parser.add_argument('--color-weight', type=float, default=256.0) 97 | parser.add_argument('--depth-weight', type=float, default=0.0) 98 | parser.add_argument('--depth-weight-decay', type=str, default=None, 99 | help="""if set, use tuple to set (final_ratio, steps). 
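                                 The depth weight decays linearly, reaching final_ratio of its
                                 initial value after the given number of steps.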
100 | For instance, (0, 30000) 101 | """) 102 | parser.add_argument('--alpha-weight', type=float, default=0.0) 103 | parser.add_argument('--vgg-weight', type=float, default=0.0) 104 | parser.add_argument('--vgg-level', type=int, choices=[1,2,3,4], default=2) 105 | parser.add_argument('--eval-lpips', action='store_true', 106 | help="evaluate LPIPS scores in validation") 107 | parser.add_argument('--no-background-loss', action='store_true') 108 | parser.add_argument('--gan-weight', type=float, default=0.0) 109 | parser.add_argument('--gan-mode', type=str, default='lsgan', 110 | help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 111 | 112 | def compute_loss(self, model, net_output, sample, reduce=True): 113 | losses, other_logs = {}, {} 114 | 115 | # prepare data before computing loss 116 | sampled_uv = net_output['sampled_uv'] # S, V, 2, N, P, P (patch-size) 117 | S, V, _, N, P1, P2 = sampled_uv.size() 118 | # H, W, h, w = sample['size'][0, 0].long().cpu().tolist() 119 | H, W, h, w = sample['size'][0, 0].cpu().tolist() 120 | L = N * P1 * P2 121 | flatten_uv = sampled_uv.view(S, V, 2, L) 122 | flatten_index = (flatten_uv[:,:,0] // h + flatten_uv[:,:,1] // w * W).long() 123 | 124 | assert 'colors' in sample and sample['colors'] is not None, "ground-truth colors not provided" 125 | target_colors = sample['colors'] 126 | masks = (sample['alpha'] > 0) if self.args.no_background_loss else None 127 | if L < target_colors.size(2): 128 | target_colors = target_colors.gather(2, flatten_index.unsqueeze(-1).repeat(1,1,1,3)) 129 | masks = masks.gather(2, flatten_uv) if masks is not None else None 130 | 131 | if 'other_logs' in net_output: 132 | other_logs.update(net_output['other_logs']) 133 | 134 | # computing loss 135 | if self.args.color_weight > 0: 136 | color_loss = utils.rgb_loss( 137 | net_output['colors'], target_colors, 138 | masks, self.args.L1) 139 | losses['color_loss'] = (color_loss, self.args.color_weight) 140 | 141 | if self.args.alpha_weight > 0: 142 | _alpha = net_output['missed'].reshape(-1) 143 | alpha_loss = torch.log1p( 144 | 1. 
/ 0.11 * _alpha.float() * (1 - _alpha.float()) 145 | ).mean().type_as(_alpha) 146 | losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight) 147 | 148 | if self.args.depth_weight > 0: 149 | if sample['depths'] is not None: 150 | target_depths = sample['depths'] 151 | target_depths = target_depths.gather(2, flatten_index) 152 | depth_mask = masks & (target_depths > 0) if masks is not None else None 153 | depth_loss = utils.depth_loss(net_output['depths'], target_depths, depth_mask) 154 | 155 | else: 156 | # no depth map is provided, depth loss only applied on background based on masks 157 | max_depth_target = self.args.max_depth * torch.ones_like(net_output['depths']) 158 | if sample['mask'] is not None: 159 | depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, (1 - sample['mask']).bool()) 160 | else: 161 | depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, ~masks) 162 | 163 | depth_weight = self.args.depth_weight 164 | if self.args.depth_weight_decay is not None: 165 | final_factor, final_steps = eval(self.args.depth_weight_decay) 166 | depth_weight *= max(0, 1 - (1 - final_factor) * self.task._num_updates / final_steps) 167 | other_logs['depth_weight'] = depth_weight 168 | 169 | losses['depth_loss'] = (depth_loss, depth_weight) 170 | 171 | 172 | if self.args.vgg_weight > 0: 173 | assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss" 174 | target_colors = target_colors.reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 175 | output_colors = net_output['colors'].reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 176 | vgg_loss = self.vgg(output_colors, target_colors) 177 | losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight) 178 | 179 | if self.args.gan_weight > 0 and 'pred_fake' in net_output: 180 | pred_fake, pred_real = net_output['pred_fake'], net_output['pred_real'] 181 | loss_G = self.criterionGAN(pred_fake, True) 182 | loss_D_fake = self.criterionGAN(pred_fake.detach(), False) 183 | loss_D_real = self.criterionGAN(pred_real, True) 184 | loss_D = (loss_D_fake + loss_D_real) * 0.5 185 | losses['gen_loss'] = (loss_G, self.args.gan_weight) 186 | else: 187 | loss_D = None 188 | 189 | loss = sum(losses[key][0] * losses[key][1] for key in losses) 190 | 191 | logging_outputs = {key: item(losses[key][0]) for key in losses} 192 | logging_outputs.update(other_logs) 193 | return loss, logging_outputs, loss_D 194 | -------------------------------------------------------------------------------- /fairnr/criterions/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | TINY = 1e-7 6 | 7 | 8 | def rgb_loss(predicts, rgbs, masks=None, L1=False, sum=False): 9 | if masks is not None: 10 | if masks.sum() == 0: 11 | return predicts.new_zeros(1).mean() 12 | predicts = predicts[masks] 13 | rgbs = rgbs[masks] 14 | 15 | if L1: 16 | loss = torch.abs(predicts - rgbs).sum(-1) 17 | else: 18 | loss = ((predicts - rgbs) ** 2).sum(-1) 19 | 20 | return loss.mean() if not sum else loss.sum() 21 | 22 | 23 | def depth_loss(depths, depth_gt, masks=None, sum=False): 24 | if masks is not None: 25 | if masks.sum() == 0: 26 | return depths.new_zeros(1).mean() 27 | depth_gt = depth_gt[masks] 28 | depths = depths[masks] 29 | 30 | loss = (depths[masks] - depth_gt[masks]) ** 2 31 | return loss.mean() if not sum else loss.sum() -------------------------------------------------------------------------------- /fairnr/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | from .shape_dataset import ( 3 | ShapeDataset, ShapeViewDataset, ShapeViewStreamDataset, 4 | SampledPixelDataset, WorldCoordDataset, 5 | InfiniteDataset 6 | ) 7 | 8 | __all__ = [ 9 | 'ShapeDataset', 10 | 'ShapeViewDataset', 11 | 'ShapeViewStreamDataset', 12 | 'SampledPixelDataset', 13 | 'WorldCoordDataset', 14 | ] 15 | -------------------------------------------------------------------------------- /fairnr/data/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import functools 4 | import cv2 5 | import math 6 | import numpy as np 7 | import imageio 8 | from glob import glob 9 | import os 10 | import copy 11 | import shutil 12 | import skimage.metrics 13 | import pandas as pd 14 | import pylab as plt 15 | 16 | 17 | def parse_views(view_args): 18 | output = [] 19 | try: 20 | xx = view_args.split(':') 21 | ids = xx[0].split(',') 22 | for id in ids: 23 | if '..' in id: 24 | a, b = id.split('..') 25 | output += list(range(int(a), int(b))) 26 | else: 27 | output += [int(id)] 28 | if len(xx) > 1: 29 | output = output[::int(xx[-1])] 30 | except Exception as e: 31 | raise Exception("parse view args error: {}".format(e)) 32 | 33 | return output 34 | 35 | 36 | def get_uv(H, W, h, w): 37 | """ 38 | H, W: real image (intrinsics) 39 | h, w: resized image 40 | """ 41 | uv = np.flip(np.mgrid[0: h, 0: w], axis=0).astype(np.float32) 42 | uv[0] = uv[0] * float(W / w) 43 | uv[1] = uv[1] * float(H / h) 44 | return uv, [float(H / h), float(W / w)] 45 | 46 | 47 | def load_rgb( 48 | path, 49 | resolution=None, 50 | with_alpha=True, 51 | load_ori=False, 52 | bg_color=[1.0, 1.0, 1.0], 53 | min_rgb=-1, 54 | interpolation='AREA'): 55 | if with_alpha: 56 | img = imageio.imread(path) # RGB-ALPHA 57 | else: 58 | img = imageio.imread(path)[:, :, :3] 59 | 60 | img = skimage.img_as_float32(img).astype('float32') 61 | H, W, D = img.shape 62 | h, w = resolution 63 | 64 | if load_ori: 65 | ori_img = img.copy() 66 | ori_img = ori_img.transpose(2, 0, 1) 67 | else: 68 | ori_img = None 69 | 70 | if D == 3: 71 | img = np.concatenate([img, np.ones((img.shape[0], img.shape[1], 1))], -1).astype('float32') 72 | 73 | uv, ratio = get_uv(H, W, h, w) 74 | if (h < H) or (w < W): 75 | # img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 76 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA).astype('float32') 77 | 78 | if min_rgb == -1: # 0, 1 --> -1, 1 79 | img[:, :, :3] -= 0.5 80 | img[:, :, :3] *= 2. 
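    # RGB values are rescaled from [0, 1] to [-1, 1] here; the two lines below
    # then composite the foreground over `bg_color` using the alpha channel and
    # zero out the alpha of pixels that exactly match the background colour.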
81 | 82 | img[:, :, :3] = img[:, :, :3] * img[:, :, 3:] + np.asarray(bg_color)[None, None, :] * (1 - img[:, :, 3:]) 83 | img[:, :, 3] = img[:, :, 3] * (img[:, :, :3] != np.asarray(bg_color)[None, None, :]).any(-1) 84 | img = img.transpose(2, 0, 1) 85 | 86 | return img, uv, ratio, ori_img 87 | 88 | 89 | def load_depth(path, resolution=None, depth_plane=5): 90 | if path is None: 91 | return None 92 | 93 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) 94 | # ret, img = cv2.threshold(img, depth_plane, depth_plane, cv2.THRESH_TRUNC) 95 | 96 | H, W = img.shape[:2] 97 | h, w = resolution 98 | if (h != H) or (w != W): 99 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 100 | #img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR) 101 | 102 | if len(img.shape) ==3: 103 | img = img[:,:,:1] 104 | img = img.transpose(2,0,1) 105 | else: 106 | img = img[None,:,:] 107 | return img 108 | 109 | 110 | def load_mask(path, resolution=None): 111 | if path is None: 112 | return None 113 | 114 | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) 115 | h, w = resolution 116 | H, W = img.shape[:2] 117 | if (h < H) or (w < W): 118 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 119 | img = img / (img.max() + 1e-7) 120 | return img 121 | 122 | 123 | def load_matrix(path): 124 | return np.array([[float(w) for w in line.strip().split()] for line in open(path)]).astype(np.float32) 125 | 126 | 127 | def load_intrinsics(filepath, resized_width=None, invert_y=False): 128 | try: 129 | intrinsics = load_matrix(filepath) 130 | if intrinsics.shape[0] == 3 and intrinsics.shape[1] == 3: 131 | _intrinsics = np.zeros((4, 4), np.float32) 132 | _intrinsics[:3, :3] = intrinsics 133 | _intrinsics[3, 3] = 1 134 | intrinsics = _intrinsics 135 | return intrinsics 136 | except ValueError: 137 | pass 138 | 139 | # Get camera intrinsics 140 | with open(filepath, 'r') as file: 141 | f, cx, cy, _ = map(float, file.readline().split()) 142 | fx = f 143 | if invert_y: 144 | fy = -f 145 | else: 146 | fy = f 147 | 148 | # Build the intrinsic matrices 149 | full_intrinsic = np.array([[fx, 0., cx, 0.], 150 | [0., fy, cy, 0], 151 | [0., 0, 1, 0], 152 | [0, 0, 0, 1]]) 153 | return full_intrinsic 154 | 155 | 156 | def unflatten_img(img, width=512): 157 | sizes = img.size() 158 | height = sizes[-1] // width 159 | return img.reshape(*sizes[:-1], height, width) 160 | 161 | 162 | def square_crop_img(img): 163 | if img.shape[0] == img.shape[1]: 164 | return img # already square 165 | 166 | min_dim = np.amin(img.shape[:2]) 167 | center_coord = np.array(img.shape[:2]) // 2 168 | img = img[center_coord[0] - min_dim // 2:center_coord[0] + min_dim // 2, 169 | center_coord[1] - min_dim // 2:center_coord[1] + min_dim // 2] 170 | return img 171 | 172 | 173 | def sample_pixel_from_image( 174 | num_pixel, num_sample, 175 | mask=None, ratio=1.0, 176 | use_bbox=False, 177 | center_ratio=1.0, 178 | width=512, 179 | patch_size=1): 180 | 181 | if patch_size > 1: 182 | assert (num_pixel % (patch_size * patch_size) == 0) \ 183 | and (num_sample % (patch_size * patch_size) == 0), "size must match" 184 | _num_pixel = num_pixel // (patch_size * patch_size) 185 | _num_sample = num_sample // (patch_size * patch_size) 186 | height = num_pixel // width 187 | 188 | _mask = None if mask is None else \ 189 | mask.reshape(height, width).reshape( 190 | height//patch_size, patch_size, width//patch_size, patch_size 191 | ).any(1).any(-1).reshape(-1) 192 | _width = width // 
patch_size 193 | _out = sample_pixel_from_image(_num_pixel, _num_sample, _mask, ratio, use_bbox, _width) 194 | _x, _y = _out % _width, _out // _width 195 | x, y = _x * patch_size, _y * patch_size 196 | x = x[:, None, None] + np.arange(patch_size)[None, :, None] 197 | y = y[:, None, None] + np.arange(patch_size)[None, None, :] 198 | out = x + y * width 199 | return out.reshape(-1) 200 | 201 | if center_ratio < 1.0: 202 | r = (1 - center_ratio) / 2.0 203 | H, W = num_pixel // width, width 204 | mask0 = np.zeros((H, W)) 205 | mask0[int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 206 | mask0 = mask0.reshape(-1) 207 | 208 | if mask is None: 209 | mask = mask0 210 | else: 211 | mask = mask * mask0 212 | 213 | if mask is not None: 214 | mask = (mask > 0.0).astype('float32') 215 | 216 | if (mask is None) or \ 217 | (ratio <= 0.0) or \ 218 | (mask.sum() == 0) or \ 219 | ((1 - mask).sum() == 0): 220 | return np.random.choice(num_pixel, num_sample) 221 | 222 | if use_bbox: 223 | mask = mask.reshape(-1, width) 224 | x, y = np.where(mask == 1) 225 | mask = np.zeros_like(mask) 226 | mask[x.min(): x.max()+1, y.min(): y.max()+1] = 1.0 227 | mask = mask.reshape(-1) 228 | 229 | try: 230 | probs = mask * ratio / (mask.sum()) + (1 - mask) / (num_pixel - mask.sum()) * (1 - ratio) 231 | # x = np.random.choice(num_pixel, num_sample, True, p=probs) 232 | return np.random.choice(num_pixel, num_sample, True, p=probs) 233 | 234 | except Exception: 235 | return np.random.choice(num_pixel, num_sample) 236 | 237 | 238 | def colormap(dz): 239 | # return plt.cm.jet(dz) 240 | # return plt.cm.viridis(dz) 241 | return plt.cm.gray(dz) 242 | 243 | 244 | def recover_image(img, min_val=-1, max_val=1, width=512, bg=None, weight=None): 245 | sizes = img.size() 246 | height = sizes[0] // width 247 | img = img.float().to('cpu') 248 | 249 | if len(sizes) == 1 and (bg is not None): 250 | bg_mask = img.eq(bg)[:, None].type_as(img) 251 | 252 | if isinstance(min_val, torch.Tensor): 253 | min_val = min_val.to('cpu') 254 | if isinstance(max_val, torch.Tensor): 255 | max_val = max_val.to('cpu') 256 | img = ((img - min_val) / (max_val - min_val)).clamp(min=0, max=1) 257 | if len(sizes) == 1: 258 | img = torch.from_numpy(colormap(img.numpy())[:, :3]) 259 | if weight is not None: 260 | weight = weight.float().to('cpu') 261 | img = img * weight[:, None] 262 | 263 | if bg is not None: 264 | img = img * (1 - bg_mask) + bg_mask 265 | img = img.reshape(height, width, -1) 266 | return img 267 | 268 | 269 | def write_images(writer, images, updates): 270 | for tag in images: 271 | img = images[tag] 272 | tag, dataform = tag.split(':') 273 | writer.add_image(tag, img, updates, dataformats=dataform) 274 | 275 | 276 | def compute_psnr(p, t): 277 | """Compute PSNR of model image predictions. 278 | :param prediction: Return value of forward pass. 279 | :param ground_truth: Ground truth. 
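    Note: as implemented below, the values are actually returned in (ssim, psnr) order.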
280 | :return: (psnr, ssim): tuple of floats 281 | """ 282 | ssim = skimage.metrics.structural_similarity(p, t, multichannel=True, data_range=1) 283 | psnr = skimage.metrics.peak_signal_noise_ratio(p, t, data_range=1) 284 | return ssim, psnr 285 | 286 | 287 | class InfIndex(object): 288 | 289 | def __init__(self, index_list, shuffle=False): 290 | self.index_list = index_list 291 | self.size = len(index_list) 292 | self.shuffle = shuffle 293 | self.reset_permutation() 294 | 295 | def reset_permutation(self): 296 | if self.shuffle: 297 | self._perm = np.random.permutation(self.index_list).tolist() 298 | else: 299 | self._perm = copy.deepcopy(self.index_list) 300 | 301 | def __iter__(self): 302 | return self 303 | 304 | def __next__(self): 305 | if len(self._perm) == 0: 306 | self.reset_permutation() 307 | return self._perm.pop() 308 | 309 | def __len__(self): 310 | return self.size 311 | 312 | 313 | class GPUTimer(object): 314 | def __enter__(self): 315 | """Start a new timer as a context manager""" 316 | self.start = torch.cuda.Event(enable_timing=True) 317 | self.end = torch.cuda.Event(enable_timing=True) 318 | self.start.record() 319 | self.sum = 0 320 | return self 321 | 322 | def __exit__(self, *exc_info): 323 | """Stop the context manager timer""" 324 | self.end.record() 325 | torch.cuda.synchronize() 326 | self.sum = self.start.elapsed_time(self.end) / 1000. 327 | -------------------------------------------------------------------------------- /fairnr/data/geometry.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from fairnr.data import data_utils as D 7 | from fairnr.clib._ext import build_octree 8 | import time 9 | 10 | INF = 1000.0 11 | 12 | 13 | def ones_like(x): 14 | T = torch if isinstance(x, torch.Tensor) else np 15 | return T.ones_like(x) 16 | 17 | 18 | def stack(x): 19 | T = torch if isinstance(x[0], torch.Tensor) else np 20 | return T.stack(x) 21 | 22 | 23 | def matmul(x, y): 24 | T = torch if isinstance(x, torch.Tensor) else np 25 | return T.matmul(x, y) 26 | 27 | 28 | def cross(x, y, axis=0): 29 | T = torch if isinstance(x, torch.Tensor) else np 30 | return T.cross(x, y, axis) 31 | 32 | 33 | def cat(x, axis=1): 34 | if isinstance(x[0], torch.Tensor): 35 | return torch.cat(x, dim=axis) 36 | return np.concatenate(x, axis=axis) 37 | 38 | 39 | def normalize(x, axis=-1, order=2): 40 | if isinstance(x, torch.Tensor): 41 | l2 = x.norm(p=order, dim=axis, keepdim=True) 42 | return x / (l2 + 1e-8), l2 43 | 44 | else: 45 | l2 = np.linalg.norm(x, order, axis) 46 | l2 = np.expand_dims(l2, axis) 47 | l2[l2==0] = 1 48 | return x / l2, l2 49 | 50 | 51 | def parse_extrinsics(extrinsics, world2camera=True): 52 | """ this function is only for numpy for now""" 53 | if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4: 54 | extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])]) 55 | if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16: 56 | extrinsics = extrinsics.reshape(4, 4) 57 | if world2camera: 58 | extrinsics = np.linalg.inv(extrinsics).astype(np.float32) 59 | return extrinsics 60 | 61 | 62 | def parse_intrinsics(intrinsics): 63 | fx = intrinsics[0, 0] 64 | fy = intrinsics[1, 1] 65 | cx = intrinsics[0, 2] 66 | cy = intrinsics[1, 2] 67 | return fx, fy, cx, cy 68 | 69 | 70 | def uv2cam(uv, z, intrinsics, homogeneous=False): 71 | fx, fy, cx, cy = parse_intrinsics(intrinsics) 72 | x_lift = (uv[0] - cx) / fx * z 73 | y_lift = (uv[1] - cy) / fy * z 74 | 
z_lift = ones_like(x_lift) * z 75 | 76 | if homogeneous: 77 | return stack([x_lift, y_lift, z_lift, ones_like(z_lift)]) 78 | else: 79 | return stack([x_lift, y_lift, z_lift]) 80 | 81 | 82 | def cam2world(xyz_cam, inv_RT): 83 | return matmul(inv_RT, xyz_cam)[:3] 84 | 85 | 86 | def r6d2mat(d6: torch.Tensor) -> torch.Tensor: 87 | """ 88 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 89 | using Gram--Schmidt orthogonalisation per Section B of [1]. 90 | Args: 91 | d6: 6D rotation representation, of size (*, 6) 92 | 93 | Returns: 94 | batch of rotation matrices of size (*, 3, 3) 95 | 96 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 97 | On the Continuity of Rotation Representations in Neural Networks. 98 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 99 | Retrieved from http://arxiv.org/abs/1812.07035 100 | """ 101 | 102 | a1, a2 = d6[..., :3], d6[..., 3:] 103 | b1 = F.normalize(a1, dim=-1) 104 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 105 | b2 = F.normalize(b2, dim=-1) 106 | b3 = torch.cross(b1, b2, dim=-1) 107 | return torch.stack((b1, b2, b3), dim=-2) 108 | 109 | 110 | def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None): 111 | if depths is None: 112 | depths = 1 113 | rt_cam = uv2cam(uv, depths, intrinsics, True) 114 | rt = cam2world(rt_cam, inv_RT) 115 | ray_dir, _ = normalize(rt - ray_start[:, None], axis=0) 116 | return ray_dir 117 | 118 | def transfer_uv(uv, size): 119 | h, w = int(size[0,0,0]*size[0,0,2]), int(size[0,0,1]*size[0,0,3]) 120 | h_rate = (2. / h) 121 | w_rate = (2. / w) 122 | new_uv = uv.clone() 123 | new_uv[0] = w_rate * uv[0] - 1 + w_rate / 2 # Image 0 -> Coor2d -1 + 1/w, Image w-1 -> Coor2d 1 - 1/w 124 | new_uv[1] = h_rate * (h-1-uv[1]) - 1 + h_rate / 2 # Image 0 -> Coor2d -1 + 1/h, Image h-1 -> Coor2d 1 - 1/h 125 | return new_uv 126 | 127 | def homogeneous(coor): 128 | if isinstance(coor, torch.Tensor): 129 | coor = torch.cat([coor, torch.ones([coor.shape[0], 1]).type_as(coor)], dim=1) 130 | else: 131 | coor = np.concatenate([coor, np.ones([coor.shape[0], 1])], 1) 132 | return coor 133 | 134 | def uv2world_proj(projectionMatrix, cameraPose, xy, depth): 135 | 136 | xy = xy.T 137 | if isinstance(depth, torch.Tensor): 138 | depth = depth.T / 1000 139 | coor_cvv = torch.cat([xy, depth], dim=1) 140 | else: 141 | coor_cvv = torch.cat([xy, torch.ones([xy.shape[0], 1]).type_as(xy)*depth], dim=1) 142 | point_num = coor_cvv.shape[0] 143 | p01 = projectionMatrix[:2].unsqueeze(0).expand([point_num, 2, 4]) 144 | p3 = projectionMatrix[-1:].unsqueeze(0).expand([point_num, 1, 4]) 145 | cvv_xy = coor_cvv[:, :2].unsqueeze(-1).expand([point_num, 2, 1]) 146 | M = p01 - torch.einsum('nto,nof->ntf', cvv_xy, p3) 147 | z = torch.stack([- coor_cvv[:, 2], torch.ones_like(coor_cvv[:, 2])], dim=-1).unsqueeze(-1) 148 | b = torch.einsum('nst,ntk->nsk', M[:, :, 2:], z) 149 | camera_xy = - torch.einsum('nts,nsk->ntk', torch.inverse(M[:, :, :2]), b).squeeze(-1) 150 | coor_camera = torch.cat([camera_xy, - coor_cvv[:, 2].unsqueeze(-1)], dim=-1) 151 | 152 | coor_camera = homogeneous(coor_camera) 153 | # cameraPose[:,1] *= -1 154 | # cameraPose[:,2] *= -1 155 | coor3d = torch.mm(cameraPose, coor_camera.T).T 156 | 157 | return coor3d[:, :3].T 158 | 159 | def get_ray_direction_proj(ray_start, uv, inv_RT, proj, size, depths=None): 160 | if depths is None: 161 | depths = 1 162 | new_uv = transfer_uv(uv, size) 163 | rt = uv2world_proj(proj, inv_RT, new_uv, depths) 164 | ray_dir, _ = normalize(rt - ray_start[:, None], axis=0) 165 | 
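    # Summary of the projection-based path: transfer_uv maps pixel coordinates to
    # [-1, 1] normalised device coordinates, uv2world_proj un-projects them through
    # the projection matrix and camera pose into world-space points, and the ray
    # direction is the unit vector from the camera centre (ray_start) towards
    # those points.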
return ray_dir 166 | 167 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): 168 | """ 169 | This function takes a vector 'camera_position' which specifies the location 170 | of the camera in world coordinates and two vectors `at` and `up` which 171 | indicate the position of the object and the up directions of the world 172 | coordinate system respectively. The object is assumed to be centered at 173 | the origin. 174 | 175 | The output is a rotation matrix representing the transformation 176 | from world coordinates -> view coordinates. 177 | 178 | Input: 179 | camera_position: 3 180 | at: 1 x 3 or N x 3 (0, 0, 0) in default 181 | up: 1 x 3 or N x 3 (0, 1, 0) in default 182 | """ 183 | 184 | if at is None: 185 | at = torch.zeros_like(camera_position) 186 | else: 187 | at = torch.tensor(at).type_as(camera_position) 188 | if up is None: 189 | up = torch.zeros_like(camera_position) 190 | up[2] = -1 191 | else: 192 | up = torch.tensor(up).type_as(camera_position) 193 | 194 | z_axis = normalize(at - camera_position)[0] 195 | x_axis = normalize(cross(up, z_axis))[0] 196 | y_axis = normalize(cross(z_axis, x_axis))[0] 197 | 198 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 199 | return R 200 | 201 | 202 | def ray(ray_start, ray_dir, depths): 203 | return ray_start + ray_dir * depths 204 | 205 | 206 | def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False): 207 | # TODO: 208 | # this function is pytorch-only (for not) 209 | wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1) 210 | cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1) 211 | cam_coords = D.unflatten_img(cam_coords, width) 212 | 213 | # estimate local normal 214 | shift_l = cam_coords[:, 2:, :] 215 | shift_r = cam_coords[:, :-2, :] 216 | shift_u = cam_coords[:, :, 2: ] 217 | shift_d = cam_coords[:, :, :-2] 218 | diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1] 219 | diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :] 220 | normal = cross(diff_hor, diff_ver) 221 | _normal = normal.new_zeros(*cam_coords.size()) 222 | _normal[:, 1:-1, 1:-1] = normal 223 | _normal = _normal.reshape(3, -1).transpose(0, 1) 224 | 225 | # compute the projected color 226 | if proj: 227 | _normal = normalize(_normal, axis=1)[0] 228 | wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1) 229 | cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1) 230 | cam_coords0 = D.unflatten_img(cam_coords0, width) 231 | cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1) 232 | proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2 233 | return proj_factor 234 | return _normal 235 | 236 | 237 | def trilinear_interp(p, q, point_feats): 238 | weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True) 239 | if point_feats.dim() == 2: 240 | point_feats = point_feats.view(point_feats.size(0), 8, -1) 241 | point_feats = (weights * point_feats).sum(1) 242 | return point_feats 243 | 244 | 245 | # helper functions for encoder 246 | 247 | def padding_points(xs, pad): 248 | if len(xs) == 1: 249 | return xs[0].unsqueeze(0) 250 | 251 | maxlen = max([x.size(0) for x in xs]) 252 | xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad) 253 | for i in range(len(xs)): 254 | xt[i, :xs[i].size(0)] = xs[i] 255 | return xt 256 | 257 | 258 | def pruning_points(feats, points, scores, depth=0, th=0.5): 259 | if depth > 0: 260 | g = int(8 ** depth) 261 | scores = 
scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True) 262 | scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1) 263 | alpha = (1 - torch.exp(-scores)) > th 264 | feats = [feats[i][alpha[i]] for i in range(alpha.size(0))] 265 | points = [points[i][alpha[i]] for i in range(alpha.size(0))] 266 | points = padding_points(points, INF) 267 | feats = padding_points(feats, 0) 268 | return feats, points 269 | 270 | 271 | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2): 272 | c = torch.arange(1, 2 * bits, 2, device=point_xyz.device) 273 | ox, oy, oz = torch.meshgrid([c, c, c]) 274 | offset = (torch.cat([ 275 | ox.reshape(-1, 1), 276 | oy.reshape(-1, 1), 277 | oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1) 278 | if not offset_only: 279 | return point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel 280 | return offset.type_as(point_xyz) * quarter_voxel 281 | 282 | 283 | def discretize_points(voxel_points, voxel_size): 284 | # this function turns voxel centers/corners into integer indeices 285 | # we assume all points are alreay put as voxels (real numbers) 286 | minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0] 287 | voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long() # float 288 | residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(0, keepdim=True) 289 | return voxel_indices, residual 290 | 291 | 292 | def splitting_points(point_xyz, point_feats, values, half_voxel): 293 | # generate new centers 294 | quarter_voxel = half_voxel * .5 295 | new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3) 296 | old_coords = discretize_points(point_xyz, quarter_voxel)[0] 297 | new_coords = offset_points(old_coords).reshape(-1, 3) 298 | new_keys0 = offset_points(new_coords).reshape(-1, 3) 299 | 300 | # get unique keys and inverse indices (for original key0, where it maps to in keys) 301 | new_keys, new_feats = torch.unique(new_keys0, dim=0, sorted=True, return_inverse=True) 302 | new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_( 303 | 0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64) 304 | 305 | # recompute key vectors using trilinear interpolation 306 | new_feats = new_feats.reshape(-1, 8) 307 | 308 | if values is not None: 309 | p = (new_keys - old_coords[new_keys_idx]).type_as(point_xyz).unsqueeze(1) * .25 + 0.5 # (1/4 voxel size) 310 | q = offset_points(p, .5, offset_only=True).unsqueeze(0) + 0.5 # BUG? 
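        # p and q are the coordinates fed to trilinear_interp below, which forms
        # per-corner weights via (p * q + (1 - p) * (1 - q)).prod(-1) and blends
        # the eight parent-corner embeddings (gathered through F.embedding) into
        # one feature vector per new key.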
311 | point_feats = point_feats[new_keys_idx] 312 | point_feats = F.embedding(point_feats, values).view(point_feats.size(0), -1) 313 | new_values = trilinear_interp(p, q, point_feats) 314 | else: 315 | new_values = None 316 | return new_points, new_feats, new_values, new_keys 317 | 318 | 319 | def expand_points(voxel_points, voxel_size): 320 | _voxel_size = min([ 321 | torch.sqrt(((voxel_points[j:j+1] - voxel_points[j+1:]) ** 2).sum(-1).min()) 322 | for j in range(100)]) 323 | depth = int(np.round(torch.log2(_voxel_size / voxel_size))) 324 | if depth > 0: 325 | half_voxel = _voxel_size / 2.0 326 | for _ in range(depth): 327 | voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3) 328 | half_voxel = half_voxel / 2.0 329 | 330 | return voxel_points, depth 331 | 332 | 333 | def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05): 334 | voxel_pts = offset_points(voxel_pts, voxel_size / 2.0) 335 | diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2) 336 | ab = diff_pts.sort(dim=1)[0][:, :2] 337 | a, b = ab[:, 0], ab[:, 1] 338 | c = voxel_size 339 | p = (ab.sum(-1) + c) / 2.0 340 | h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c 341 | return h < (th * voxel_size) 342 | 343 | 344 | # fill-in image 345 | def fill_in(shape, hits, input, initial=1.0): 346 | if isinstance(initial, torch.Tensor): 347 | output = initial.expand(*shape) 348 | else: 349 | output = input.new_ones(*shape) * initial 350 | if input is not None: 351 | if len(shape) == 1: 352 | return output.masked_scatter(hits, input) 353 | return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input) 354 | return output 355 | 356 | 357 | def build_easy_octree(points, half_voxel): 358 | coords, residual = discretize_points(points, half_voxel) 359 | ranges = coords.max(0)[0] - coords.min(0)[0] 360 | depths = torch.log2(ranges.max().float()).ceil_().long() - 1 361 | center = (coords.max(0)[0] + coords.min(0)[0]) / 2 362 | centers, children = build_octree(center, coords, int(depths)) 363 | centers = centers.float() * half_voxel + residual # transform back to float 364 | return centers, children -------------------------------------------------------------------------------- /fairnr/data/trajectory.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | TRAJECTORY_REGISTRY = {} 6 | 7 | 8 | def register_traj(name): 9 | def register_traj_fn(fn): 10 | if name in TRAJECTORY_REGISTRY: 11 | raise ValueError('Cannot register duplicate trajectory ({})'.format(name)) 12 | TRAJECTORY_REGISTRY[name] = fn 13 | return fn 14 | return register_traj_fn 15 | 16 | 17 | def get_trajectory(name): 18 | return TRAJECTORY_REGISTRY.get(name, None) 19 | 20 | 21 | @register_traj('circle') 22 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 23 | if axis == 'z': 24 | return lambda t: [radius * np.cos(r * t+t0), radius * np.sin(r * t+t0), h] 25 | elif axis == 'y': 26 | return lambda t: [radius * np.cos(r * t+t0), h, radius * np.sin(r * t+t0)] 27 | else: 28 | return lambda t: [h, radius * np.cos(r * t+t0), radius * np.sin(r * t+t0)] 29 | 30 | 31 | @register_traj('zoomin_circle') 32 | def zoomin_circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 33 | ra = lambda t: 0.1 + abs(4.0 - t * 2 / np.pi) 34 | 35 | if axis == 'z': 36 | return lambda t: [radius * ra(t) * np.cos(r * t+t0), radius * ra(t) * np.sin(r * t+t0), h] 37 | elif axis == 'y': 38 | return lambda t: [radius * ra(t) * np.cos(r * t+t0), h, radius * ra(t) * np.sin(r * t+t0)] 39 | else: 40 | return 
lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] 41 | 42 | 43 | @register_traj('zoomin_line') 44 | def zoomin_line(radius=3.5, h=0.0, axis='z', t0=0, r=1, min_r=0.0001, max_r=10, step_r=10): 45 | ra = lambda t: min_r + (max_r - min_r) * t * 180 / np.pi / step_r 46 | 47 | if axis == 'z': 48 | return lambda t: [radius * ra(t) * np.cos(t0), radius * ra(t) * np.sin(t0), h * ra(t)] 49 | elif axis == 'y': 50 | return lambda t: [radius * ra(t) * np.cos(t0), h, radius * ra(t) * np.sin(t0)] 51 | else: 52 | return lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] 53 | -------------------------------------------------------------------------------- /fairnr/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import os 4 | 5 | # automatically import any Python files in the models/ directory 6 | models_dir = os.path.dirname(__file__) 7 | for file in os.listdir(models_dir): 8 | path = os.path.join(models_dir, file) 9 | if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): 10 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 11 | module = importlib.import_module('fairnr.models.' + model_name) 12 | -------------------------------------------------------------------------------- /fairnr/models/fairnr_model.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | import skimage.metrics 6 | import imageio, os 7 | import numpy as np 8 | 9 | from fairnr.modules.encoder import get_encoder 10 | from fairnr.modules.field import get_field 11 | from fairnr.modules.renderer import get_renderer 12 | from fairnr.modules.reader import get_reader 13 | from fairnr.modules.discriminator import get_discriminator 14 | from fairnr.data.geometry import ray, compute_normal_map, compute_normal_map 15 | from fairnr.data.data_utils import recover_image 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class BaseModel(nn.Module): 21 | """Base class""" 22 | 23 | ENCODER = 'abstract_encoder' 24 | FIELD = 'abstract_field' 25 | RAYMARCHER = 'abstract_renderer' 26 | READER = 'abstract_reader' 27 | DISCRIMINATOR = 'abstract_discriminator' 28 | 29 | def __init__(self, args, reader, encoder, field, raymarcher, discriminator=None): 30 | super().__init__() 31 | self.args = args 32 | self.reader = reader 33 | self.encoder = encoder 34 | self.field = field 35 | self.raymarcher = raymarcher 36 | self.discriminator = discriminator 37 | self.cache = None 38 | 39 | @classmethod 40 | def build_model(cls, args): 41 | """Build a new model instance.""" 42 | reader = get_reader(cls.READER)(args) 43 | encoder = get_encoder(cls.ENCODER)(args) 44 | field = get_field(cls.FIELD)(args) 45 | raymarcher = get_renderer(cls.RAYMARCHER)(args) 46 | if args.dis: 47 | discriminator = get_discriminator(cls.DISCRIMINATOR)(args) 48 | return cls(args, reader, encoder, field, raymarcher, discriminator) 49 | else: 50 | return cls(args, reader, encoder, field, raymarcher) 51 | 52 | @classmethod 53 | def add_args(cls, parser): 54 | get_reader(cls.READER).add_args(parser) 55 | get_renderer(cls.RAYMARCHER).add_args(parser) 56 | get_encoder(cls.ENCODER).add_args(parser) 57 | get_field(cls.FIELD).add_args(parser) 58 | get_discriminator(cls.DISCRIMINATOR).add_args(parser) 59 | 60 | # def 
forward(self, ray_start, ray_dir, ray_split=1, **kwargs): 61 | def forward(self, ray_split=1, **kwargs): 62 | # ray_start, ray_dir, uv = self.reader(**kwargs) 63 | ray_start, ray_dir, uv, dis_ray_start, dis_ray_dir = self.reader(**kwargs) 64 | kwargs.update({ 65 | 'field_fn': self.field.forward, 66 | 'input_fn': self.encoder.forward}) 67 | 68 | if ray_split == 1: 69 | results = self._forward(ray_start, ray_dir, dis_ray_start, dis_ray_dir, **kwargs) 70 | else: 71 | total_rays = ray_dir.shape[2] 72 | chunk_size = total_rays // ray_split 73 | results = [ 74 | self._forward( 75 | ray_start, ray_dir[:, :, i: i+chunk_size], dis_ray_start, dis_ray_dir[:, :, i: i+chunk_size], **kwargs) 76 | for i in range(0, total_rays, chunk_size) 77 | ] 78 | results = self.merge_outputs(results) 79 | 80 | if results.get('sampled_uv', None) is None: 81 | results['sampled_uv'] = uv 82 | results['ray_start'] = ray_start 83 | results['ray_dir'] = ray_dir 84 | 85 | # caching the prediction 86 | self.cache = { 87 | w: results[w].detach() 88 | if isinstance(w, torch.Tensor) 89 | else results[w] 90 | for w in results 91 | } 92 | return results 93 | 94 | def _forward(self, ray_start, ray_dir, dis_ray_start, dis_ray_dir, **kwargs): 95 | raise NotImplementedError 96 | 97 | def merge_outputs(self, outputs): 98 | new_output = {} 99 | for key in outputs[0]: 100 | if isinstance(outputs[0][key], torch.Tensor) and outputs[0][key].dim() > 2: 101 | new_output[key] = torch.cat([o[key] for o in outputs], 2) 102 | else: 103 | new_output[key] = outputs[0][key] 104 | return new_output 105 | 106 | @torch.no_grad() 107 | def visualize(self, sample, output=None, shape=0, view=0, raw_depth=False, **kwargs): 108 | width = int(sample['size'][shape, view][1].item()) 109 | img_id = '{}_{}'.format(sample['shape'][shape], sample['view'][shape, view]) 110 | 111 | if output is None: 112 | assert self.cache is not None, "need to run forward-pass" 113 | output = self.cache # make sure to run forward-pass. 
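            # self.cache holds the outputs of the most recent forward pass, so
            # visualization can reuse them without another rendering pass.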
114 | 115 | images = {} 116 | images = self._visualize(images, sample, output, [img_id, shape, view, width, 'render']) 117 | images = self._visualize(images, sample, sample, [img_id, shape, view, width, 'target']) 118 | if raw_depth: 119 | for tag in images: 120 | images[tag] = recover_image(width=width, **images[tag]) if (images[tag] is not None and tag.split('/')[0].split('_')[1] != 'voxeldepth') else images[tag]['img'] 121 | else: 122 | images = { 123 | tag: recover_image(width=width, **images[tag]) 124 | for tag in images if images[tag] is not None 125 | } 126 | return images 127 | 128 | def _visualize(self, images, sample, output, state, **kwargs): 129 | img_id, shape, view, width, name = state 130 | if 'colors' in output and output['colors'] is not None: 131 | images['{}_color/{}:HWC'.format(name, img_id)] ={ 132 | 'img': output['colors'][shape, view]} 133 | 134 | if 'depths' in output and output['depths'] is not None and 'ray_start' in output: 135 | min_depth, max_depth = output['depths'].min(), output['depths'].max() 136 | images['{}_depth/{}:HWC'.format(name, img_id)] = { 137 | 'img': output['depths'][shape, view], 138 | 'min_val': min_depth, 139 | 'max_val': max_depth} 140 | normals = compute_normal_map( 141 | output['ray_start'][shape, view].float(), 142 | output['ray_dir'][shape, view].float(), 143 | output['depths'][shape, view].float(), 144 | sample['extrinsics'][shape, view].float().inverse(), width) 145 | images['{}_normal/{}:HWC'.format(name, img_id)] = { 146 | 'img': normals, 'min_val': -1, 'max_val': 1} 147 | 148 | if 'voxel_depth' in output and output['voxel_depth'] is not None: 149 | min_depth, max_depth = output['voxel_depth'].min(), output['voxel_depth'].max() 150 | images['{}_voxeldepth/{}:HWC'.format(name, img_id)] = { 151 | 'img': output['voxel_depth'][shape, view], 152 | 'min_val': min_depth, 153 | 'max_val': max_depth} 154 | 155 | return images 156 | 157 | def add_eval_scores(self, logging_output, sample, output, criterion, scores=['ssim', 'psnr', 'lpips'], outdir=None): 158 | predicts, targets = output['colors'], sample['colors'] 159 | ssims, psnrs, lpips, rmses = [], [], [], [] 160 | 161 | for s in range(predicts.size(0)): 162 | for v in range(predicts.size(1)): 163 | width = int(sample['size'][s, v][1]) 164 | p = recover_image(predicts[s, v], width=width) 165 | t = recover_image(targets[s, v], width=width) 166 | pn, tn = p.numpy(), t.numpy() 167 | p, t = p.to(predicts.device), t.to(targets.device) 168 | 169 | if 'ssim' in scores: 170 | ssims += [skimage.metrics.structural_similarity(pn, tn, multichannel=True, data_range=1)] 171 | if 'psnr' in scores: 172 | psnrs += [skimage.metrics.peak_signal_noise_ratio(pn, tn, data_range=1)] 173 | if 'lpips' in scores and hasattr(criterion, 'lpips'): 174 | with torch.no_grad(): 175 | lpips += [criterion.lpips( 176 | 2 * p.unsqueeze(-1).permute(3,2,0,1) - 1, 177 | 2 * t.unsqueeze(-1).permute(3,2,0,1) - 1).item()] 178 | if 'depths' in sample: 179 | td = sample['depths'][sample['depths'] > 0] 180 | pd = output['depths'][sample['depths'] > 0] 181 | rmses += [torch.sqrt(((td - pd) ** 2).mean()).item()] 182 | 183 | if outdir is not None: 184 | def imsave(filename, image): 185 | imageio.imsave(os.path.join(outdir, filename), (image * 255).astype('uint8')) 186 | 187 | figname = '-{:03d}_{:03d}.png'.format(sample['id'][s], sample['view'][s, v]) 188 | imsave('output' + figname, pn) 189 | imsave('target' + figname, tn) 190 | imsave('normal' + figname, recover_image(compute_normal_map( 191 | output['ray_start'][s, v].float(), 
output['ray_dir'][s, v].float(), 192 | output['depths'][s, v].float(), sample['extrinsics'][s, v].float().inverse(), width=width), 193 | min_val=-1, max_val=1, width=width).numpy()) 194 | 195 | if len(ssims) > 0: 196 | logging_output['ssim_loss'] = np.mean(ssims) 197 | if len(psnrs) > 0: 198 | logging_output['psnr_loss'] = np.mean(psnrs) 199 | if len(lpips) > 0: 200 | logging_output['lpips_loss'] = np.mean(lpips) 201 | if len(rmses) > 0: 202 | logging_output['rmses_loss'] = np.mean(rmses) 203 | 204 | 205 | -------------------------------------------------------------------------------- /fairnr/models/nsvf.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | import cv2, math, time 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from fairnr.data.data_utils import GPUTimer 14 | from fairnr.data.geometry import compute_normal_map, fill_in 15 | from fairnr.models.fairnr_model import BaseModel 16 | 17 | MAX_DEPTH = 10000.0 18 | 19 | 20 | class NSVFModel(BaseModel): 21 | 22 | READER = 'image_reader' 23 | ENCODER = 'sparsevoxel_encoder' 24 | FIELD = 'radiance_field' 25 | RAYMARCHER = 'volume_rendering' 26 | DISCRIMINATOR = 'patch_discriminator' 27 | 28 | def _forward(self, ray_start, ray_dir, dis_ray_start, dis_ray_dir, **kwargs): 29 | S, V, P, _ = ray_dir.size() 30 | assert S == 1, "naive NeRF only supports single object." 31 | 32 | # voxel encoder (precompute for each voxel if needed) 33 | encoder_states = self.encoder.precompute(**kwargs) 34 | if 'move' in kwargs: 35 | encoder_states['voxel_octree_center_xyz'] += kwargs['move'] 36 | encoder_states['voxel_center_xyz'] += kwargs['move'] 37 | 38 | # ray-voxel intersection 39 | with GPUTimer() as timer0: 40 | ray_start, ray_dir, intersection_outputs, hits = \ 41 | self.encoder.ray_intersect(ray_start, ray_dir, encoder_states) 42 | 43 | sampled_hits = torch.zeros(2,2) 44 | if self.args.dis and dis_ray_start is not None: 45 | dis_ray_start, dis_ray_dir, dis_intersection_outputs, dis_hits = \ 46 | self.encoder.ray_intersect(dis_ray_start, dis_ray_dir, encoder_states) 47 | 48 | uv, size = kwargs['uv'], kwargs['size'] 49 | h, w = int(size[0,0,0]), int(size[0,0,1]) 50 | patch_size = self.args.patch_size 51 | while sampled_hits.sum() < 256: 52 | sampled_masks = torch.zeros((S, V, h, w)).to(uv.device) 53 | h_rand, w_rand = np.random.randint(0, h-patch_size), np.random.randint(0, w-patch_size) 54 | h_range, w_range = np.arange(h_rand, h_rand+patch_size), np.arange(w_rand, w_rand+patch_size) 55 | uu, vv = np.meshgrid(h_range, w_range) 56 | # dis_uv = torch.cat((torch.from_numpy(uu).unsqueeze(0), torch.from_numpy(vv).unsqueeze(0)), dim=0).reshape(S, V, 2, patch_size*patch_size).to(uv.device) 57 | hits_reshape = dis_hits.reshape(h, w) 58 | sampled_masks = sampled_masks.bool() 59 | sampled_masks[0, 0, uu, vv] = torch.ones_like(hits_reshape[uu, vv]) 60 | sampled_masks = sampled_masks.reshape(uv.size(0), -1).bool() 61 | 62 | sampled_hits = dis_hits[sampled_masks].reshape(S, -1) 63 | 64 | sampled_masks = sampled_masks.unsqueeze(-1) 65 | sampled_intersection_outputs = {name: outs[sampled_masks.expand_as(outs)].reshape(S, -1, outs.size(-1)) 66 | for name, outs in dis_intersection_outputs.items()} 67 | sampled_ray_start = dis_ray_start[sampled_masks.expand_as(dis_ray_start)].reshape(S, -1, 3) 68 | sampled_ray_dir = 
dis_ray_dir[sampled_masks.expand_as(dis_ray_dir)].reshape(S, -1, 3) 69 | sampled_P = sampled_hits.size(-1) // V # the number of pixels per image 70 | 71 | if self.reader.no_sampling and self.training: # sample points after ray-voxel intersection 72 | uv, size = kwargs['uv'], kwargs['size'] 73 | mask = hits.reshape(*uv.size()[:2], uv.size(-1)) 74 | 75 | # sample rays based on voxel intersections 76 | sampled_uv, sampled_masks = self.reader.sample_pixels( 77 | uv, size, mask=mask, return_mask=True) 78 | sampled_masks = sampled_masks.reshape(uv.size(0), -1).bool() 79 | hits, sampled_masks = hits[sampled_masks].reshape(S, -1), sampled_masks.unsqueeze(-1) 80 | intersection_outputs = {name: outs[sampled_masks.expand_as(outs)].reshape(S, -1, outs.size(-1)) 81 | for name, outs in intersection_outputs.items()} 82 | ray_start = ray_start[sampled_masks.expand_as(ray_start)].reshape(S, -1, 3) 83 | ray_dir = ray_dir[sampled_masks.expand_as(ray_dir)].reshape(S, -1, 3) 84 | P = hits.size(-1) // V # the number of pixels per image 85 | else: 86 | sampled_uv = None 87 | 88 | if self.args.dis and dis_ray_start is not None: 89 | # neural ray-marching 90 | sampled_fullsize = S * V * sampled_P 91 | 92 | BG_DEPTH = self.field.bg_color.depth 93 | bg_color = self.field.bg_color(sampled_ray_dir) 94 | 95 | sampled_all_results = defaultdict(lambda: None) 96 | if sampled_hits.sum() > 0: # check if ray missed everything 97 | sampled_intersection_outputs = {name: outs[sampled_hits] for name, outs in sampled_intersection_outputs.items()} 98 | sampled_ray_start, sampled_ray_dir = sampled_ray_start[sampled_hits], sampled_ray_dir[sampled_hits] 99 | 100 | # sample evalution points along the ray 101 | samples = self.encoder.ray_sample(sampled_intersection_outputs) 102 | encoder_states = {name: s.reshape(-1, s.size(-1)) if s is not None else None 103 | for name, s in encoder_states.items()} 104 | 105 | # rendering 106 | sampled_all_results = self.raymarcher( 107 | self.encoder, self.field, sampled_ray_start, sampled_ray_dir, samples, encoder_states) 108 | sampled_all_results['depths'] = sampled_all_results['depths'] + BG_DEPTH * sampled_all_results['missed'] 109 | sampled_all_results['voxel_edges'] = self.encoder.get_edge(sampled_ray_start, sampled_ray_dir, samples, encoder_states) 110 | sampled_all_results['voxel_depth'] = samples['sampled_point_depth'][:, 0] 111 | 112 | # fill out the full size 113 | sampled_hits = sampled_hits.reshape(sampled_fullsize) 114 | sampled_all_results['missed'] = fill_in((sampled_fullsize, ), sampled_hits, sampled_all_results['missed'], 1.0).view(S, V, sampled_P) 115 | sampled_all_results['colors'] = fill_in((sampled_fullsize, 3), sampled_hits, sampled_all_results['colors'], 0.0).view(S, V, sampled_P, 3) 116 | sampled_all_results['bg_color'] = bg_color.reshape(sampled_fullsize, 3).view(S, V, sampled_P, 3) 117 | sampled_all_results['colors'] += sampled_all_results['missed'].unsqueeze(-1) * sampled_all_results['bg_color'] 118 | 119 | # neural ray-marching 120 | fullsize = S * V * P 121 | 122 | BG_DEPTH = self.field.bg_color.depth 123 | bg_color = self.field.bg_color(ray_dir) 124 | 125 | all_results = defaultdict(lambda: None) 126 | if hits.sum() > 0: # check if ray missed everything 127 | intersection_outputs = {name: outs[hits] for name, outs in intersection_outputs.items()} 128 | ray_start, ray_dir = ray_start[hits], ray_dir[hits] 129 | 130 | # sample evalution points along the ray 131 | samples = self.encoder.ray_sample(intersection_outputs) 132 | encoder_states = {name: s.reshape(-1, s.size(-1)) 
if s is not None else None 133 | for name, s in encoder_states.items()} 134 | 135 | # rendering 136 | all_results = self.raymarcher( 137 | self.encoder, self.field, ray_start, ray_dir, samples, encoder_states) 138 | all_results['depths'] = all_results['depths'] + BG_DEPTH * all_results['missed'] 139 | all_results['voxel_edges'] = self.encoder.get_edge(ray_start, ray_dir, samples, encoder_states) 140 | all_results['voxel_depth'] = samples['sampled_point_depth'][:, 0] 141 | 142 | # fill out the full size 143 | hits = hits.reshape(fullsize) 144 | all_results['missed'] = fill_in((fullsize, ), hits, all_results['missed'], 1.0).view(S, V, P) 145 | all_results['depths'] = fill_in((fullsize, ), hits, all_results['depths'], BG_DEPTH).view(S, V, P) 146 | all_results['voxel_depth'] = fill_in((fullsize, ), hits, all_results['voxel_depth'], BG_DEPTH).view(S, V, P) 147 | all_results['voxel_edges'] = fill_in((fullsize, 3), hits, all_results['voxel_edges'], 1.0).view(S, V, P, 3) 148 | all_results['colors'] = fill_in((fullsize, 3), hits, all_results['colors'], 0.0).view(S, V, P, 3) 149 | all_results['bg_color'] = bg_color.reshape(fullsize, 3).view(S, V, P, 3) 150 | all_results['colors'] += all_results['missed'].unsqueeze(-1) * all_results['bg_color'] 151 | if 'normal' in all_results: 152 | all_results['normal'] = fill_in((fullsize, 3), hits, all_results['normal'], 0.0).view(S, V, P, 3) 153 | 154 | # discriminator 155 | if self.args.dis and dis_ray_start is not None: 156 | gen_img = sampled_all_results['colors'].view(patch_size, patch_size, 3) * .5 + .5 157 | H, W, h, w = kwargs['size'][0, 0].long().cpu().tolist() 158 | L = patch_size * patch_size 159 | tar_colors = kwargs['colors'][0, 0].view(H, W, 3) 160 | tar_img = tar_colors[uu, vv] * .5 + .5 161 | gen_img = gen_img.view(1, 3, patch_size, patch_size) 162 | tar_img = tar_img.view(1, 3, patch_size, patch_size) 163 | 164 | pred_fake = self.discriminator(gen_img) 165 | pred_real = self.discriminator(tar_img) 166 | all_results.update({'pred_fake':pred_fake, 'pred_real':pred_real}) 167 | 168 | # other logs 169 | all_results['other_logs'] = { 170 | 'voxs_log': self.encoder.voxel_size.item(), 171 | 'stps_log': self.encoder.step_size.item(), 172 | 'tvox_log': timer0.sum, 173 | 'asf_log': (all_results['ae'].float() / fullsize).item(), 174 | 'ash_log': (all_results['ae'].float() / hits.sum()).item(), 175 | 'nvox_log': self.encoder.num_voxels, 176 | } 177 | all_results['sampled_uv'] = sampled_uv 178 | return all_results 179 | 180 | def _visualize(self, images, sample, output, state, **kwargs): 181 | img_id, shape, view, width, name = state 182 | images = super()._visualize(images, sample, output, state, **kwargs) 183 | if 'voxel_edges' in output and output['voxel_edges'] is not None: 184 | # voxel hitting visualization 185 | images['{}_voxel/{}:HWC'.format(name, img_id)] = { 186 | 'img': output['voxel_edges'][shape, view].float(), 187 | 'min_val': 0, 188 | 'max_val': 1, 189 | 'weight': 190 | compute_normal_map( 191 | output['ray_start'][shape, view].float(), 192 | output['ray_dir'][shape, view].float(), 193 | output['voxel_depth'][shape, view].float(), 194 | sample['extrinsics'][shape, view].float().inverse(), 195 | width, proj=True) 196 | } 197 | if 'normal' in output and output['normal'] is not None: 198 | images['{}_predn/{}:HWC'.format(name, img_id)] = { 199 | 'img': output['normal'][shape, view], 'min_val': -1, 'max_val': 1} 200 | return images 201 | 202 | @torch.no_grad() 203 | def prune_voxels(self, th=0.5, train_stats=False): 204 | 
self.encoder.pruning(self.field, th, train_stats=train_stats) 205 | self.clean_caches() 206 | 207 | @torch.no_grad() 208 | def split_voxels(self): 209 | logger.info("half the global voxel size {:.4f} -> {:.4f}".format( 210 | self.encoder.voxel_size.item(), self.encoder.voxel_size.item() * .5)) 211 | self.encoder.splitting() 212 | self.encoder.voxel_size *= .5 213 | self.encoder.max_hits *= 1.5 214 | self.clean_caches() 215 | 216 | @torch.no_grad() 217 | def reduce_stepsize(self): 218 | logger.info("reduce the raymarching step size {:.4f} -> {:.4f}".format( 219 | self.encoder.step_size.item(), self.encoder.step_size.item() * .5)) 220 | self.encoder.step_size *= .5 221 | 222 | @torch.no_grad() 223 | def reduce_pixels_num(self, times): 224 | old_pixels_num = self.reader.num_pixels 225 | new_pixels_num = self.reader.downsample_pixels_num(times) 226 | logger.info("reduce the pixels per view {:.4f} -> {:.4f}".format( 227 | old_pixels_num, new_pixels_num)) 228 | 229 | def clean_caches(self, reset=False): 230 | self.encoder.clean_runtime_caches() 231 | if reset: 232 | self.encoder.reset_runtime_caches() 233 | torch.cuda.empty_cache() # cache release after Model do all things 234 | 235 | def base_architecture(args): 236 | # parameter needs to be changed 237 | args.voxel_size = getattr(args, "voxel_size", 0.25) 238 | args.max_hits = getattr(args, "max_hits", 60) 239 | args.raymarching_stepsize = getattr(args, "raymarching_stepsize", 0.01) 240 | args.raymarching_stepsize_ratio = getattr(args, "raymarching_stepsize_ratio", 0.0) 241 | 242 | # encoder default parameter 243 | args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 32) 244 | args.voxel_path = getattr(args, "voxel_path", None) 245 | args.initial_boundingbox = getattr(args, "initial_boundingbox", None) 246 | 247 | # field 248 | args.inputs_to_density = getattr(args, "inputs_to_density", "emb:6:32") 249 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4") 250 | args.feature_embed_dim = getattr(args, "feature_embed_dim", 256) 251 | args.density_embed_dim = getattr(args, "density_embed_dim", 128) 252 | args.texture_embed_dim = getattr(args, "texture_embed_dim", 256) 253 | 254 | args.feature_layers = getattr(args, "feature_layers", 1) 255 | args.texture_layers = getattr(args, "texture_layers", 3) 256 | 257 | args.background_stop_gradient = getattr(args, "background_stop_gradient", False) 258 | args.background_depth = getattr(args, "background_depth", 5.0) 259 | 260 | # raymarcher 261 | args.discrete_regularization = getattr(args, "discrete_regularization", False) 262 | args.deterministic_step = getattr(args, "deterministic_step", False) 263 | args.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0) 264 | args.use_octree = getattr(args, "use_octree", False) 265 | 266 | # reader 267 | args.pixel_per_view = getattr(args, "pixel_per_view", 2048) 268 | args.sampling_on_mask = getattr(args, "sampling_on_mask", 0.0) 269 | args.sampling_at_center = getattr(args, "sampling_at_center", 1.0) 270 | args.sampling_on_bbox = getattr(args, "sampling_on_bbox", False) 271 | args.sampling_patch_size = getattr(args, "sampling_patch_size", 1) 272 | args.sampling_skipping_size = getattr(args, "sampling_skipping_size", 1) 273 | 274 | # others 275 | args.chunk_size = getattr(args, "chunk_size", 64) 276 | args.valid_chunk_size = getattr(args, "valid_chunk_size", 64) 277 | 278 | def my_base_architecture(args): 279 | # parameter needs to be changed 280 | def set_default_value(args, name, value): 281 | if hasattr(args, name): 
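            # An attribute that exists but is still None counts as unset and
            # receives the default; values that were set explicitly are left
            # untouched.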
282 | if getattr(args, name) is None: 283 | setattr(args, name, value) 284 | else: 285 | setattr(args, name, value) 286 | set_default_value(args, "voxel_size", 0.25) 287 | set_default_value(args, "max_hits", 60) 288 | set_default_value(args, "raymarching_stepsize", 0.01) 289 | set_default_value(args, "raymarching_stepsize_ratio", 0.0) 290 | 291 | # encoder default parameter 292 | set_default_value(args, "voxel_embed_dim", 32) 293 | set_default_value(args, "voxel_path", None) 294 | set_default_value(args, "initial_boundingbox", None) 295 | 296 | # field 297 | set_default_value(args, "inputs_to_density", "emb:6:32") 298 | set_default_value(args, "inputs_to_texture", "feat:0:256, ray:4") 299 | set_default_value(args, "feature_embed_dim", 256) 300 | set_default_value(args, "density_embed_dim", 128) 301 | set_default_value(args, "texture_embed_dim", 256) 302 | 303 | set_default_value(args, "feature_layers", 1) 304 | set_default_value(args, "texture_layers", 3) 305 | 306 | set_default_value(args, "background_stop_gradient", False) 307 | set_default_value(args, "background_depth", 5.0) 308 | 309 | # raymarcher 310 | set_default_value(args, "discrete_regularization", False) 311 | set_default_value(args, "deterministic_step", False) 312 | set_default_value(args, "raymarching_tolerance", 0) 313 | set_default_value(args, "use_octree", False) 314 | 315 | # reader 316 | set_default_value(args, "pixel_per_view", 2048) 317 | set_default_value(args, "sampling_on_mask", 0.0) 318 | set_default_value(args, "sampling_at_center", 1.0) 319 | set_default_value(args, "sampling_on_bbox", False) 320 | set_default_value(args, "sampling_patch_size", 1) 321 | set_default_value(args, "sampling_skipping_size", 1) 322 | 323 | # others 324 | set_default_value(args, "chunk_size", 64) 325 | set_default_value(args, "valid_chunk_size", 64) 326 | set_default_value(args, "ray_chunk_size", 64) 327 | 328 | 329 | class PixelNSVFModel(NSVFModel): 330 | 331 | DISCRIMINATOR = "pixel_discriminator" 332 | -------------------------------------------------------------------------------- /fairnr/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import os 4 | 5 | # automatically import any Python files in the models/ directory 6 | models_dir = os.path.dirname(__file__) 7 | for file in os.listdir(models_dir): 8 | path = os.path.join(models_dir, file) 9 | if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): 10 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 11 | module = importlib.import_module('fairnr.modules.' 
+ model_name) -------------------------------------------------------------------------------- /fairnr/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | DISCRIMINATOR_REGISTRY = {} 6 | 7 | def register_discriminator(name): 8 | def register_discriminator_cls(cls): 9 | if name in DISCRIMINATOR_REGISTRY: 10 | raise ValueError('Cannot register duplicate module ({})'.format(name)) 11 | DISCRIMINATOR_REGISTRY[name] = cls 12 | return cls 13 | return register_discriminator_cls 14 | 15 | 16 | def get_discriminator(name): 17 | if name not in DISCRIMINATOR_REGISTRY: 18 | raise ValueError('Cannot find module {}'.format(name)) 19 | return DISCRIMINATOR_REGISTRY[name] 20 | 21 | 22 | @register_discriminator('abstract_discriminator') 23 | class Discriminator(nn.Module): 24 | """ 25 | backbone network 26 | """ 27 | def __init__(self, args): 28 | super().__init__() 29 | self.args = args 30 | 31 | def forward(self, **kwargs): 32 | raise NotImplementedError 33 | 34 | @staticmethod 35 | def add_args(parser): 36 | pass 37 | 38 | @register_discriminator('patch_discriminator') 39 | class NLayerDiscriminator(Discriminator): 40 | """Defines a PatchGAN discriminator""" 41 | 42 | def __init__(self, args): 43 | """Construct a PatchGAN discriminator 44 | Parameters: 45 | input_nc (int) -- the number of channels in input images 46 | ndf (int) -- the number of filters in the last conv layer 47 | n_layers (int) -- the number of conv layers in the discriminator 48 | norm_layer -- normalization layer 49 | """ 50 | super().__init__(args) 51 | input_nc, ndf, n_layers = args.input_nc, args.ndf, args.n_layers 52 | if args.gan_norm_layer == 'batch': 53 | norm_layer = nn.BatchNorm2d 54 | use_bias = False 55 | elif args.gan_norm_layer == 'instance': 56 | norm_layer = nn.InstanceNorm2d 57 | use_bias = True 58 | 59 | kw = 4 60 | padw = 1 61 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 62 | nf_mult = 1 63 | nf_mult_prev = 1 64 | for n in range(1, n_layers): # gradually increase the number of filters 65 | nf_mult_prev = nf_mult 66 | nf_mult = min(2 ** n, 8) 67 | sequence += [ 68 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 69 | norm_layer(ndf * nf_mult), 70 | nn.LeakyReLU(0.2, True) 71 | ] 72 | 73 | nf_mult_prev = nf_mult 74 | nf_mult = min(2 ** n_layers, 8) 75 | sequence += [ 76 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True) 79 | ] 80 | 81 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 82 | self.model = nn.Sequential(*sequence) 83 | 84 | @staticmethod 85 | def add_args(parser): 86 | parser.add_argument('--input-nc', type=int, default=3, 87 | help='image input channel') 88 | parser.add_argument('--ndf', type=int, default=64, 89 | help='the number of filters in the last conv layer') 90 | parser.add_argument('--n-layers', type=int, default=3, 91 | help='the number of conv layers in the discriminator') 92 | parser.add_argument('--patch-size', type=int, default=32, 93 | help='image patch size that inputs to the discriminator') 94 | parser.add_argument('--gan-norm-layer', type=str, default='instance', 95 | help='batch normalization or instance normalization') 96 | 97 | def forward(self, input): 98 | """Standard 
forward.""" 99 | return self.model(input) 100 | 101 | 102 | @register_discriminator('pixel_discriminator') 103 | class PixelDiscriminator(Discriminator): 104 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 105 | 106 | def __init__(self, args): 107 | """Construct a 1x1 PatchGAN discriminator 108 | Parameters: 109 | input_nc (int) -- the number of channels in input images 110 | ndf (int) -- the number of filters in the last conv layer 111 | norm_layer -- normalization layer 112 | """ 113 | super().__init__(args) 114 | input_nc, ndf = args.input_nc, args.ndf 115 | if args.gan_norm_layer == 'batch': 116 | norm_layer = nn.BatchNorm2d 117 | use_bias = False 118 | elif args.gan_norm_layer == 'instance': 119 | norm_layer = nn.InstanceNorm2d 120 | use_bias = True 121 | 122 | self.net = [ 123 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 124 | nn.LeakyReLU(0.2, True), 125 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 126 | norm_layer(ndf * 2), 127 | nn.LeakyReLU(0.2, True), 128 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 129 | 130 | self.net = nn.Sequential(*self.net) 131 | 132 | @staticmethod 133 | def add_args(parser): 134 | parser.add_argument('--input-nc', type=int, default=3, 135 | help='image input channel') 136 | parser.add_argument('--ndf', type=int, default=64, 137 | help='the number of filters in the last conv layer') 138 | parser.add_argument('--patch-size', type=int, default=32, 139 | help='image patch size that inputs to the discriminator') 140 | parser.add_argument('--gan-norm-layer', type=str, default='batch', 141 | help='batch normalization or instance normalization') 142 | 143 | def forward(self, input): 144 | """Standard forward.""" 145 | return self.net(input) -------------------------------------------------------------------------------- /fairnr/modules/field.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import grad 7 | from collections import OrderedDict 8 | from fairnr.modules.implicit import ( 9 | ImplicitField, SignedDistanceField, 10 | TextureField, HyperImplicitField, BackgroundField 11 | ) 12 | from fairnr.modules.linear import NeRFPosEmbLinear 13 | 14 | FIELD_REGISTRY = {} 15 | 16 | def register_field(name): 17 | def register_field_cls(cls): 18 | if name in FIELD_REGISTRY: 19 | raise ValueError('Cannot register duplicate module ({})'.format(name)) 20 | FIELD_REGISTRY[name] = cls 21 | return cls 22 | return register_field_cls 23 | 24 | 25 | def get_field(name): 26 | if name not in FIELD_REGISTRY: 27 | raise ValueError('Cannot find module {}'.format(name)) 28 | return FIELD_REGISTRY[name] 29 | 30 | 31 | @register_field('abstract_field') 32 | class Field(nn.Module): 33 | """ 34 | Abstract class for implicit functions 35 | """ 36 | def __init__(self, args): 37 | super().__init__() 38 | self.args = args 39 | 40 | def forward(self, **kwargs): 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def add_args(parser): 45 | pass 46 | 47 | 48 | @register_field('radiance_field') 49 | class RaidanceField(Field): 50 | 51 | def __init__(self, args): 52 | super().__init__(args) 53 | 54 | # additional arguments 55 | self.chunk_size = getattr(args, "chunk_size", 256) * 256 56 | self.deterministic_step = getattr(args, "deterministic_step", False) 57 | 58 | # background field 59 | self.min_color = getattr(args, "min_color", -1) 60 | self.trans_bg 
= getattr(args, "transparent_background", "1.0,1.0,1.0") 61 | self.sgbg = getattr(args, "background_stop_gradient", False) 62 | self.bg_color = BackgroundField(bg_color=self.trans_bg, min_color=self.min_color, stop_grad=self.sgbg) 63 | self.den_filters, self.den_ori_dims, self.den_input_dims = self.parse_inputs(args.inputs_to_density) 64 | self.tex_filters, self.tex_ori_dims, self.tex_input_dims = self.parse_inputs(args.inputs_to_texture) 65 | self.den_filters, self.tex_filters = nn.ModuleDict(self.den_filters), nn.ModuleDict(self.tex_filters) 66 | den_input_dim, tex_input_dim = sum(self.den_input_dims), sum(self.tex_input_dims) 67 | den_feat_dim = self.tex_input_dims[0] 68 | 69 | # build networks 70 | if not getattr(args, "hypernetwork", False): 71 | self.feature_field = ImplicitField(den_input_dim, den_feat_dim, 72 | args.feature_embed_dim, args.feature_layers) 73 | else: 74 | den_contxt_dim = self.den_input_dims[-1] 75 | self.feature_field = HyperImplicitField(den_contxt_dim, den_input_dim - den_contxt_dim, 76 | den_feat_dim, args.feature_embed_dim, args.feature_layers) 77 | self.load_pc = getattr(args, "load_pc", False) 78 | if self.load_pc: 79 | self.color_func = NeRFPosEmbLinear( 80 | 3, 3 * self.args.pc_pose_dim * 2, 81 | angular=False, no_linear=True, cat_input=True) 82 | render_input_dim = tex_input_dim + 3 + self.color_func.out_dim 83 | else: 84 | render_input_dim = tex_input_dim 85 | self.predictor = SignedDistanceField(den_feat_dim, args.density_embed_dim, recurrent=False) 86 | self.renderer = TextureField(render_input_dim, args.texture_embed_dim, args.texture_layers) 87 | 88 | def parse_inputs(self, arguments): 89 | def fillup(p): 90 | assert len(p) > 0 91 | default = 'b' if (p[0] != 'ray') and (p[0] != 'normal') else 'a' 92 | 93 | if len(p) == 1: 94 | return [p[0], 0, 3, default] 95 | elif len(p) == 2: 96 | return [p[0], int(p[1]), 3, default] 97 | elif len(p) == 3: 98 | return [p[0], int(p[1]), int(p[2]), default] 99 | return [p[0], int(p[1]), int(p[2]), p[4]] 100 | 101 | filters, input_dims, output_dims = OrderedDict(), [], [] 102 | for p in arguments.split(','): 103 | name, pos_dim, base_dim, pos_type = fillup([a.strip() for a in p.strip().split(':')]) 104 | 105 | if pos_dim > 0: # use positional embedding 106 | func = NeRFPosEmbLinear( 107 | base_dim, base_dim * pos_dim * 2, 108 | angular=(pos_type == 'a'), 109 | no_linear=True, 110 | cat_input=(pos_type != 'a')) 111 | odim = func.out_dim + func.in_dim if func.cat_input else func.out_dim 112 | 113 | else: 114 | func = nn.Identity() 115 | odim = base_dim 116 | 117 | input_dims += [base_dim] 118 | output_dims += [odim] 119 | filters[name] = func 120 | return filters, input_dims, output_dims 121 | 122 | @staticmethod 123 | def add_args(parser): 124 | parser.add_argument('--inputs-to-density', type=str, 125 | help=""" 126 | Types of inputs to predict the density. 127 | Choices of types are emb or pos. 128 | use first . to assign sinsudoal frequency. 129 | use second : to assign the input dimension (in default 3). 130 | use third : to set the type -> basic, angular or gaussian 131 | Size must match 132 | e.g. --inputs-to-density emb:6:32,pos:4 133 | """) 134 | parser.add_argument('--inputs-to-texture', type=str, 135 | help=""" 136 | Types of inputs to predict the texture. 137 | Choices of types are feat, emb, ray, pos or normal. 
138 | """) 139 | 140 | parser.add_argument('--feature-embed-dim', type=int, metavar='N', 141 | help='field hidden dimension for FFN') 142 | parser.add_argument('--density-embed-dim', type=int, metavar='N', 143 | help='hidden dimension of density prediction'), 144 | parser.add_argument('--texture-embed-dim', type=int, metavar='N', 145 | help='hidden dimension of texture prediction') 146 | 147 | parser.add_argument('--input-embed-dim', type=int, metavar='N', 148 | help='number of features for query (in default 3, xyz)') 149 | parser.add_argument('--output-embed-dim', type=int, metavar='N', 150 | help='number of features the field returns') 151 | parser.add_argument('--raydir-embed-dim', type=int, metavar='N', 152 | help='the number of dimension to encode the ray directions') 153 | parser.add_argument('--disable-raydir', action='store_true', 154 | help='if set, not use view direction as additional inputs') 155 | parser.add_argument('--add-pos-embed', type=int, metavar='N', 156 | help='using periodic activation augmentation') 157 | parser.add_argument('--feature-layers', type=int, metavar='N', 158 | help='number of FC layers used to encode') 159 | parser.add_argument('--texture-layers', type=int, metavar='N', 160 | help='number of FC layers used to predict colors') 161 | 162 | # specific parameters (hypernetwork does not work right now) 163 | parser.add_argument('--hypernetwork', action='store_true', 164 | help='use hypernetwork to model feature') 165 | parser.add_argument('--hyper-feature-embed-dim', type=int, metavar='N', 166 | help='feature dimension used to predict the hypernetwork. consistent with context embedding') 167 | 168 | # backgound parameters 169 | parser.add_argument('--background-depth', type=float, 170 | help='the depth of background. used for depth visualization') 171 | parser.add_argument('--background-stop-gradient', action='store_true', 172 | help='do not optimize the background color') 173 | 174 | @torch.enable_grad() # tracking the gradient in case we need to have normal at testing time. 175 | def forward(self, inputs, outputs=['sigma', 'texture'], color_feat=None): 176 | filtered_inputs, context = [], None 177 | if 'feat' not in inputs: 178 | for i, name in enumerate(self.den_filters): 179 | d_in, func = self.den_ori_dims[i], self.den_filters[name] 180 | assert (name in inputs), "the encoder must contain target inputs" 181 | assert inputs[name].size(-1) == d_in, "{} dimension must match {} v.s. 
{}".format( 182 | name, inputs[name].size(-1), d_in) 183 | if name == 'context': 184 | assert (i == (len(self.den_filters) - 1)), "we force context as the last input" 185 | assert inputs[name].size(0) == 1, "context is object level" 186 | context = func(inputs[name]) 187 | else: 188 | filtered_inputs += [func(inputs[name])] 189 | 190 | filtered_inputs = torch.cat(filtered_inputs, -1) 191 | if context is not None: 192 | if getattr(self.args, "hypernetwork", False): 193 | filtered_inputs = (filtered_inputs, context) 194 | else: 195 | filtered_inputs = (torch.cat([filtered_inputs, context.repeat(filtered_inputs.size(0), 1)], -1),) 196 | else: 197 | filtered_inputs = (filtered_inputs, ) 198 | inputs['feat'] = self.feature_field(*filtered_inputs) 199 | 200 | if 'sigma' in outputs: 201 | assert 'feat' in inputs, "feature must be pre-computed" 202 | inputs['sigma'] = self.predictor(inputs['feat'])[0] 203 | 204 | if (('texture' in outputs) and ("normal" in self.tex_filters)) or ("normal" in outputs): 205 | assert 'sigma' in inputs, "sigma must be pre-computed" 206 | assert 'pos' in inputs, "position is used to compute sigma" 207 | grad_pos, = grad( 208 | outputs=inputs['sigma'], inputs=inputs['pos'], 209 | grad_outputs=torch.ones_like(inputs['sigma']), 210 | retain_graph=True) 211 | inputs['normal'] = F.normalize(-grad_pos, p=2, dim=1) # BUG: gradient direction reversed. 212 | 213 | if color_feat is not None: 214 | inputs['color_feat'] = self.color_func(color_feat) 215 | 216 | if 'texture' in outputs: 217 | filtered_inputs = [] 218 | for i, name in enumerate(self.tex_filters): 219 | d_in, func = self.tex_ori_dims[i], self.tex_filters[name] 220 | assert (name in inputs), "the encoder must contain target inputs" 221 | assert inputs[name].size(-1) == d_in, "dimension must match" 222 | 223 | filtered_inputs += [func(inputs[name])] 224 | 225 | filtered_inputs = torch.cat(filtered_inputs, -1) 226 | if self.load_pc and 'color_feat' in inputs: 227 | filtered_inputs = torch.cat((filtered_inputs, inputs['color_feat']), dim=-1) 228 | inputs['texture'] = self.renderer(filtered_inputs) 229 | 230 | return inputs 231 | -------------------------------------------------------------------------------- /fairnr/modules/hyper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import functools 5 | 6 | from fairnr.modules.linear import FCBlock 7 | 8 | 9 | def partialclass(cls, *args, **kwds): 10 | 11 | class NewCls(cls): 12 | __init__ = functools.partialmethod(cls.__init__, *args, **kwds) 13 | 14 | return NewCls 15 | 16 | 17 | class LookupLayer(nn.Module): 18 | def __init__(self, in_ch, out_ch, num_objects): 19 | super().__init__() 20 | 21 | self.out_ch = out_ch 22 | self.lookup_lin = LookupLinear(in_ch, 23 | out_ch, 24 | num_objects=num_objects) 25 | self.norm_nl = nn.Sequential( 26 | nn.LayerNorm([self.out_ch], elementwise_affine=False), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | def forward(self, obj_idx): 31 | net = nn.Sequential( 32 | self.lookup_lin(obj_idx), 33 | self.norm_nl 34 | ) 35 | return net 36 | 37 | 38 | class LookupFC(nn.Module): 39 | def __init__(self, 40 | hidden_ch, 41 | num_hidden_layers, 42 | num_objects, 43 | in_ch, 44 | out_ch, 45 | outermost_linear=False): 46 | super().__init__() 47 | self.layers = nn.ModuleList() 48 | self.layers.append(LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)) 49 | 50 | for i in range(num_hidden_layers): 51 | self.layers.append(LookupLayer(in_ch=hidden_ch, 
out_ch=hidden_ch, num_objects=num_objects)) 52 | 53 | if outermost_linear: 54 | self.layers.append(LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 55 | else: 56 | self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 57 | 58 | def forward(self, obj_idx): 59 | net = [] 60 | for i in range(len(self.layers)): 61 | net.append(self.layers[i](obj_idx)) 62 | 63 | return nn.Sequential(*net) 64 | 65 | 66 | class LookupLinear(nn.Module): 67 | def __init__(self, 68 | in_ch, 69 | out_ch, 70 | num_objects): 71 | super().__init__() 72 | self.in_ch = in_ch 73 | self.out_ch = out_ch 74 | 75 | self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch) 76 | 77 | for i in range(num_objects): 78 | nn.init.kaiming_normal_(self.hypo_params.weight.data[i, :self.in_ch * self.out_ch].view(self.out_ch, self.in_ch), 79 | a=0.0, 80 | nonlinearity='relu', 81 | mode='fan_in') 82 | self.hypo_params.weight.data[i, self.in_ch * self.out_ch:].fill_(0.) 83 | 84 | def forward(self, obj_idx): 85 | hypo_params = self.hypo_params(obj_idx) 86 | 87 | # Indices explicit to catch erros in shape of output layer 88 | weights = hypo_params[..., :self.in_ch * self.out_ch] 89 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 90 | 91 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 92 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 93 | 94 | return BatchLinear(weights=weights, biases=biases) 95 | 96 | 97 | class HyperLayer(nn.Module): 98 | '''A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.''' 99 | def __init__(self, 100 | in_ch, 101 | out_ch, 102 | hyper_in_ch, 103 | hyper_num_hidden_layers, 104 | hyper_hidden_ch): 105 | super().__init__() 106 | 107 | self.hyper_linear = HyperLinear(in_ch=in_ch, 108 | out_ch=out_ch, 109 | hyper_in_ch=hyper_in_ch, 110 | hyper_num_hidden_layers=hyper_num_hidden_layers, 111 | hyper_hidden_ch=hyper_hidden_ch) 112 | self.norm_nl = nn.Sequential( 113 | nn.LayerNorm([out_ch], elementwise_affine=False), 114 | nn.ReLU(inplace=True) 115 | ) 116 | 117 | def forward(self, hyper_input): 118 | ''' 119 | :param hyper_input: input to hypernetwork. 120 | :return: nn.Module; predicted fully connected network. 121 | ''' 122 | return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) 123 | 124 | 125 | class HyperFC(nn.Module): 126 | '''Builds a hypernetwork that predicts a fully connected neural network. 
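    Calling forward(hyper_input) returns an nn.Sequential of predicted layers: each entry is a BatchLinear whose weights and biases are generated from hyper_input, followed by LayerNorm and ReLU except (optionally) the outermost layer.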
127 | ''' 128 | def __init__(self, 129 | hyper_in_ch, 130 | hyper_num_hidden_layers, 131 | hyper_hidden_ch, 132 | hidden_ch, 133 | num_hidden_layers, 134 | in_ch, 135 | out_ch, 136 | outermost_linear=False): 137 | super().__init__() 138 | 139 | PreconfHyperLinear = partialclass(HyperLinear, 140 | hyper_in_ch=hyper_in_ch, 141 | hyper_num_hidden_layers=hyper_num_hidden_layers, 142 | hyper_hidden_ch=hyper_hidden_ch) 143 | PreconfHyperLayer = partialclass(HyperLayer, 144 | hyper_in_ch=hyper_in_ch, 145 | hyper_num_hidden_layers=hyper_num_hidden_layers, 146 | hyper_hidden_ch=hyper_hidden_ch) 147 | 148 | self.layers = nn.ModuleList() 149 | self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch)) 150 | 151 | for i in range(num_hidden_layers): 152 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch)) 153 | 154 | if outermost_linear: 155 | self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch)) 156 | else: 157 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch)) 158 | 159 | 160 | def forward(self, hyper_input): 161 | ''' 162 | :param hyper_input: Input to hypernetwork. 163 | :return: nn.Module; Predicted fully connected neural network. 164 | ''' 165 | net = [] 166 | for i in range(len(self.layers)): 167 | net.append(self.layers[i](hyper_input)) 168 | 169 | return nn.Sequential(*net) 170 | 171 | 172 | class BatchLinear(nn.Module): 173 | def __init__(self, 174 | weights, 175 | biases): 176 | '''Implements a batch linear layer. 177 | 178 | :param weights: Shape: (batch, out_ch, in_ch) 179 | :param biases: Shape: (batch, 1, out_ch) 180 | ''' 181 | super().__init__() 182 | 183 | self.weights = weights 184 | self.biases = biases 185 | 186 | def __repr__(self): 187 | return "BatchLinear(batch=%d, in_ch=%d, out_ch=%d)"%( 188 | self.weights.shape[0], self.weights.shape[-1], self.weights.shape[-2]) 189 | 190 | def forward(self, input): 191 | output = input.matmul(self.weights.permute(*[i for i in range(len(self.weights.shape)-2)], -1, -2)) 192 | output += self.biases 193 | return output 194 | 195 | 196 | def last_hyper_layer_init(m): 197 | if type(m) == nn.Linear: 198 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 199 | m.weight.data *= 1e-1 200 | 201 | 202 | class HyperLinear(nn.Module): 203 | '''A hypernetwork that predicts a single linear layer (weights & biases).''' 204 | def __init__(self, 205 | in_ch, 206 | out_ch, 207 | hyper_in_ch, 208 | hyper_num_hidden_layers, 209 | hyper_hidden_ch): 210 | 211 | super().__init__() 212 | self.in_ch = in_ch 213 | self.out_ch = out_ch 214 | 215 | self.hypo_params = FCBlock( 216 | in_features=hyper_in_ch, 217 | hidden_ch=hyper_hidden_ch, 218 | num_hidden_layers=hyper_num_hidden_layers, 219 | out_features=(in_ch * out_ch) + out_ch, 220 | outermost_linear=True) 221 | self.hypo_params[-1].apply(last_hyper_layer_init) 222 | 223 | def forward(self, hyper_input): 224 | hypo_params = self.hypo_params(hyper_input.cuda()) 225 | 226 | # Indices explicit to catch erros in shape of output layer 227 | weights = hypo_params[..., :self.in_ch * self.out_ch] 228 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 229 | 230 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 231 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 232 | 233 | return BatchLinear(weights=weights, biases=biases) 234 | -------------------------------------------------------------------------------- /fairnr/modules/implicit.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from fairnr.modules.hyper import HyperFC 7 | from fairnr.modules.linear import ( 8 | NeRFPosEmbLinear, FCLayer, ResFCLayer 9 | ) 10 | 11 | 12 | class BackgroundField(nn.Module): 13 | """ 14 | Background (we assume a uniform color) 15 | """ 16 | def __init__(self, out_dim=3, bg_color="1.0,1.0,1.0", min_color=-1, stop_grad=False, background_depth=5.0): 17 | super().__init__() 18 | 19 | if out_dim == 3: # directly model RGB 20 | bg_color = [float(b) for b in bg_color.split(',')] if isinstance(bg_color, str) else [bg_color] 21 | if min_color == -1: 22 | bg_color = [b * 2 - 1 for b in bg_color] 23 | if len(bg_color) == 1: 24 | bg_color = bg_color + bg_color + bg_color 25 | bg_color = torch.tensor(bg_color) 26 | else: 27 | bg_color = torch.ones(out_dim).uniform_() 28 | if min_color == -1: 29 | bg_color = bg_color * 2 - 1 30 | self.out_dim = out_dim 31 | self.bg_color = nn.Parameter(bg_color, requires_grad=not stop_grad) 32 | self.depth = background_depth 33 | 34 | def forward(self, x, **kwargs): 35 | return self.bg_color.unsqueeze(0).expand( 36 | *x.size()[:-1], self.out_dim) 37 | 38 | 39 | class ImplicitField(nn.Module): 40 | 41 | """ 42 | An implicit field is a neural network that outputs a vector given any query point. 43 | """ 44 | def __init__(self, in_dim, out_dim, hidden_dim, num_layers, outmost_linear=False, pos_proj=0): 45 | super().__init__() 46 | if pos_proj > 0: 47 | new_in_dim = in_dim * 2 * pos_proj 48 | self.nerfpos = NeRFPosEmbLinear(in_dim, new_in_dim, no_linear=True) 49 | in_dim = new_in_dim + in_dim 50 | else: 51 | self.nerfpos = None 52 | 53 | self.net = [] 54 | self.net.append(FCLayer(in_dim, hidden_dim)) 55 | for _ in range(num_layers): 56 | self.net.append(FCLayer(hidden_dim, hidden_dim)) 57 | 58 | if not outmost_linear: 59 | self.net.append(FCLayer(hidden_dim, out_dim)) 60 | else: 61 | self.net.append(nn.Linear(hidden_dim, out_dim)) 62 | 63 | self.net = nn.Sequential(*self.net) 64 | self.net.apply(self.init_weights) 65 | 66 | def init_weights(self, m): 67 | if type(m) == nn.Linear: 68 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 69 | 70 | def forward(self, x): 71 | if self.nerfpos is not None: 72 | x = torch.cat([x, self.nerfpos(x)], -1) 73 | return self.net(x) 74 | 75 | 76 | class HyperImplicitField(nn.Module): 77 | 78 | def __init__(self, hyper_in_dim, in_dim, out_dim, hidden_dim, num_layers, outmost_linear=False, pos_proj=0): 79 | super().__init__() 80 | 81 | self.hyper_in_dim = hyper_in_dim 82 | self.in_dim = in_dim 83 | 84 | if pos_proj > 0: 85 | new_in_dim = in_dim * 2 * pos_proj 86 | self.nerfpos = NeRFPosEmbLinear(in_dim, new_in_dim, no_linear=True) 87 | in_dim = new_in_dim + in_dim 88 | else: 89 | self.nerfpos = None 90 | 91 | self.net = HyperFC( 92 | hyper_in_dim, 93 | 1, 256, 94 | hidden_dim, 95 | num_layers, 96 | in_dim, 97 | out_dim, 98 | outermost_linear=outmost_linear 99 | ) 100 | 101 | def forward(self, x, c): 102 | assert (x.size(-1) == self.in_dim) and (c.size(-1) == self.hyper_in_dim) 103 | if self.nerfpos is not None: 104 | x = torch.cat([x, self.nerfpos(x)], -1) 105 | return self.net(c)(x.unsqueeze(0)).squeeze(0) 106 | 107 | 108 | class SignedDistanceField(nn.Module): 109 | 110 | def __init__(self, in_dim, hidden_dim, recurrent=False): 111 | super().__init__() 112 | self.recurrent = recurrent 113 | 114 | if recurrent: 115 | self.hidden_layer = 
nn.LSTMCell(input_size=in_dim, hidden_size=hidden_dim) 116 | self.hidden_layer.apply(init_recurrent_weights) 117 | lstm_forget_gate_init(self.hidden_layer) 118 | else: 119 | self.hidden_layer = FCLayer(in_dim, hidden_dim) 120 | 121 | self.output_layer = nn.Linear(hidden_dim, 1) 122 | 123 | def forward(self, x, state=None): 124 | if self.recurrent: 125 | shape = x.size() 126 | state = self.hidden_layer(x.view(-1, shape[-1]), state) 127 | if state[0].requires_grad: 128 | state[0].register_hook(lambda x: x.clamp(min=-5, max=5)) 129 | 130 | return self.output_layer(state[0].view(*shape[:-1], -1)).squeeze(-1), state 131 | 132 | else: 133 | 134 | return self.output_layer(self.hidden_layer(x)).squeeze(-1), None 135 | 136 | 137 | class TextureField(ImplicitField): 138 | """ 139 | Pixel generator based on 1x1 conv networks 140 | """ 141 | def __init__(self, in_dim, hidden_dim, num_layers, with_alpha=False): 142 | out_dim = 3 if not with_alpha else 4 143 | super().__init__(in_dim, out_dim, hidden_dim, num_layers, outmost_linear=True) 144 | 145 | 146 | class OccupancyField(ImplicitField): 147 | """ 148 | Occupancy Network which predicts 0~1 at every space 149 | """ 150 | def __init__(self, in_dim, hidden_dim, num_layers): 151 | super().__init__(in_dim, 1, hidden_dim, num_layers, outmost_linear=True) 152 | 153 | def forward(self, x): 154 | return torch.sigmoid(super().forward(x)).squeeze(-1) 155 | 156 | 157 | # ------------------ # 158 | # helper functions # 159 | # ------------------ # 160 | def init_recurrent_weights(self): 161 | for m in self.modules(): 162 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 163 | for name, param in m.named_parameters(): 164 | if 'weight_ih' in name: 165 | nn.init.kaiming_normal_(param.data) 166 | elif 'weight_hh' in name: 167 | nn.init.orthogonal_(param.data) 168 | elif 'bias' in name: 169 | param.data.fill_(0) 170 | 171 | 172 | def lstm_forget_gate_init(lstm_layer): 173 | for name, parameter in lstm_layer.named_parameters(): 174 | if not "bias" in name: continue 175 | n = parameter.size(0) 176 | start, end = n // 4, n // 2 177 | parameter.data[start:end].fill_(1.) 178 | 179 | 180 | def clip_grad_norm_hook(x, max_norm=10): 181 | total_norm = x.norm() 182 | total_norm = total_norm ** (1 / 2.) 
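    # scale the gradient down when its norm exceeds max_norm (the hook returns the rescaled gradient; returning None leaves it unchanged)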
183 | clip_coef = max_norm / (total_norm + 1e-6) 184 | if clip_coef < 1: 185 | return x * clip_coef -------------------------------------------------------------------------------- /fairnr/modules/linear.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def Linear(in_features, out_features, bias=True): 9 | m = nn.Linear(in_features, out_features, bias) 10 | nn.init.xavier_uniform_(m.weight) 11 | if bias: 12 | nn.init.constant_(m.bias, 0.0) 13 | return m 14 | 15 | 16 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 17 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 18 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 19 | return m 20 | 21 | 22 | class PosEmbLinear(nn.Module): 23 | 24 | def __init__(self, in_dim, out_dim, no_linear=False, scale=1024): 25 | super().__init__() 26 | assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" 27 | half_dim = out_dim // 2 // in_dim 28 | emb = math.log(10000) / (half_dim - 1) 29 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 30 | 31 | self.emb = nn.Parameter(emb, requires_grad=False) 32 | self.linear = Linear(out_dim, out_dim) if not no_linear else None 33 | self.scale = scale 34 | self.in_dim = in_dim 35 | self.out_dim = out_dim 36 | 37 | def forward(self, x): 38 | assert x.size(-1) == self.in_dim, "size must match" 39 | sizes = x.size() 40 | x = self.scale * x.unsqueeze(-1) @ self.emb.unsqueeze(0) 41 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 42 | x = x.view(*sizes[:-1], self.out_dim) 43 | if self.linear is not None: 44 | return self.linear(x) 45 | return x 46 | 47 | 48 | class NeRFPosEmbLinear(nn.Module): 49 | 50 | def __init__(self, in_dim, out_dim, angular=False, no_linear=False, cat_input=False): 51 | super().__init__() 52 | assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" 53 | L = out_dim // 2 // in_dim 54 | emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.)) 55 | if not angular: 56 | emb = emb * math.pi 57 | 58 | self.emb = nn.Parameter(emb, requires_grad=False) 59 | self.angular = angular 60 | self.linear = Linear(out_dim, out_dim) if not no_linear else None 61 | self.in_dim = in_dim 62 | self.out_dim = out_dim 63 | self.cat_input = cat_input 64 | 65 | def forward(self, x): 66 | assert x.size(-1) == self.in_dim, "size must match" 67 | sizes = x.size() 68 | inputs = x.clone() 69 | 70 | if self.angular: 71 | x = torch.acos(x.clamp(-1 + 1e-6, 1 - 1e-6)) 72 | x = x.unsqueeze(-1) @ self.emb.unsqueeze(0) 73 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 74 | x = x.view(*sizes[:-1], self.out_dim) 75 | if self.linear is not None: 76 | x = self.linear(x) 77 | if self.cat_input: 78 | x = torch.cat([x, inputs], -1) 79 | return x 80 | 81 | def extra_repr(self) -> str: 82 | outstr = 'Sinusoidal (in={}, out={}, angular={})'.format( 83 | self.in_dim, self.out_dim, self.angular) 84 | if self.cat_input: 85 | outstr = 'Cat({}, {})'.format(self.in_dim, outstr) 86 | return outstr 87 | 88 | 89 | class FCLayer(nn.Module): 90 | """ 91 | Reference: 92 | https://github.com/vsitzmann/pytorch_prototyping/blob/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py 93 | """ 94 | def __init__(self, in_dim, out_dim): 95 | super().__init__() 96 | 97 | self.net = nn.Sequential( 98 | nn.Linear(in_dim, out_dim), 99 | nn.LayerNorm([out_dim]), 100 | nn.ReLU(inplace=True)) 101 | 102 | def 
forward(self, x): 103 | return self.net(x) 104 | 105 | 106 | class FCBlock(nn.Module): 107 | def __init__(self, 108 | hidden_ch, 109 | num_hidden_layers, 110 | in_features, 111 | out_features, 112 | outermost_linear=False): 113 | super().__init__() 114 | 115 | self.net = [] 116 | self.net.append(FCLayer(in_features, hidden_ch)) 117 | 118 | for i in range(num_hidden_layers): 119 | self.net.append(FCLayer(hidden_ch, hidden_ch)) 120 | 121 | if outermost_linear: 122 | self.net.append(Linear(hidden_ch, out_features)) 123 | else: 124 | self.net.append(FCLayer(hidden_ch, out_features)) 125 | 126 | self.net = nn.Sequential(*self.net) 127 | self.net.apply(self.init_weights) 128 | 129 | def __getitem__(self,item): 130 | return self.net[item] 131 | 132 | def init_weights(self, m): 133 | if type(m) == nn.Linear: 134 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 135 | 136 | def forward(self, input): 137 | return self.net(input) 138 | 139 | 140 | class ResFCLayer(nn.Module): 141 | """ 142 | Reference: 143 | https://github.com/autonomousvision/occupancy_networks/blob/master/im2mesh/layers.py 144 | """ 145 | def __init__(self, in_dim, out_dim, hidden_dim, act='relu', dropout=0.0): 146 | super().__init__() 147 | 148 | self.fc1 = nn.Linear(in_dim, hidden_dim) 149 | self.fc2 = nn.Linear(hidden_dim, out_dim) 150 | self.nonlinear = nn.ReLU() 151 | self.dropout = dropout 152 | 153 | # Initialization (?) 154 | nn.init.zeros_(self.fc2.weight) 155 | 156 | def forward(self, x): 157 | residual = x 158 | x = self.fc1(self.nonlinear(x)) 159 | x = self.fc2(self.nonlinear(x)) 160 | if self.dropout > 0: 161 | x = F.dropout(x, p=self.dropout, training=self.training) 162 | x = residual + x 163 | return x -------------------------------------------------------------------------------- /fairnr/modules/reader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import random, os, glob 5 | 6 | from fairnr.data.geometry import get_ray_direction, r6d2mat, get_ray_direction_proj 7 | 8 | torch.autograd.set_detect_anomaly(True) 9 | TINY = 1e-9 10 | READER_REGISTRY = {} 11 | 12 | def register_reader(name): 13 | def register_reader_cls(cls): 14 | if name in READER_REGISTRY: 15 | raise ValueError('Cannot register duplicate module ({})'.format(name)) 16 | READER_REGISTRY[name] = cls 17 | return cls 18 | return register_reader_cls 19 | 20 | 21 | def get_reader(name): 22 | if name not in READER_REGISTRY: 23 | raise ValueError('Cannot find module {}'.format(name)) 24 | return READER_REGISTRY[name] 25 | 26 | 27 | @register_reader('abstract_reader') 28 | class Reader(nn.Module): 29 | def __init__(self, args): 30 | super().__init__() 31 | self.args = args 32 | 33 | def forward(self, **kwargs): 34 | raise NotImplementedError 35 | 36 | @staticmethod 37 | def add_args(parser): 38 | pass 39 | 40 | 41 | @register_reader('image_reader') 42 | class ImageReader(Reader): 43 | """ 44 | basic image reader 45 | """ 46 | def __init__(self, args): 47 | super().__init__(args) 48 | self.num_pixels = args.pixel_per_view 49 | self.no_sampling = getattr(args, "no_sampling_at_reader", False) 50 | self.num_pixels_ratio = getattr(args, "pixel_per_view_down", None) 51 | if self.num_pixels_ratio is not None: 52 | self.num_pixels_ratio = [float(s) for s in self.num_pixels_ratio.split(',')] 53 | 54 | self.deltas = None 55 | if getattr(args, "trainable_extrinsics", False): 56 | self.all_data = self.find_data() 57 | self.all_data_idx = {data_img: (s, v) 
58 | for s, data in enumerate(self.all_data) 59 | for v, data_img in enumerate(data)} 60 | self.deltas = nn.ParameterList([ 61 | nn.Parameter(torch.tensor( 62 | [[1., 0., 0., 0., 1., 0., 0., 0., 0.]]).repeat(len(data), 1)) 63 | for data in self.all_data]) 64 | 65 | def find_data(self): 66 | paths = self.args.data 67 | if os.path.isdir(paths): 68 | self.paths = [paths] 69 | else: 70 | self.paths = [line.strip() for line in open(paths)] 71 | return [sorted(glob.glob("{}/rgb/*".format(p))) for p in self.paths] 72 | 73 | @staticmethod 74 | def add_args(parser): 75 | parser.add_argument('--pixel-per-view', type=float, metavar='N', 76 | help='number of pixels sampled for each view') 77 | parser.add_argument('--pixel-per-view-down', type=str, 78 | help='comma-separated downsampling ratios applied to the number of pixels sampled per view') 79 | parser.add_argument("--sampling-on-mask", nargs='?', const=0.9, type=float, 80 | help="this value determines the probability of sampling rays on masks") 81 | parser.add_argument("--sampling-at-center", type=float, 82 | help="only used for training, where we restrict sampling to the center of the image") 83 | parser.add_argument("--sampling-on-bbox", action='store_true', 84 | help="restrict sampling to the bounding box of the mask") 85 | parser.add_argument("--sampling-patch-size", type=int, 86 | help="sample pixels based on patches instead of independent pixels") 87 | parser.add_argument("--sampling-skipping-size", type=int, 88 | help="stride used to skip pixels when sampling patches") 89 | parser.add_argument("--no-sampling-at-reader", action='store_true', 90 | help="do not perform sampling.") 91 | parser.add_argument("--trainable-extrinsics", action='store_true', 92 | help="if set, we assume extrinsics are trainable. We use 6D representations for rotation") 93 | 94 | def forward(self, uv, intrinsics, extrinsics, size, projections=None, path=None, **kwargs): 95 | S, V = uv.size()[:2] 96 | if (not self.training) or self.no_sampling: 97 | uv = uv.reshape(S, V, 2, -1, 1, 1) 98 | flatten_uv = uv.reshape(S, V, 2, -1) 99 | else: 100 | uv, _ = self.sample_pixels(uv, size, **kwargs) 101 | flatten_uv = uv.reshape(S, V, 2, -1) 102 | 103 | # go over all shapes 104 | ray_start, ray_dir = [[] for _ in range(S)], [[] for _ in range(S)] 105 | if self.args.dis and 'dis_extrinsics' in kwargs: 106 | dis_ray_start, dis_ray_dir = [[] for _ in range(S)], [[] for _ in range(S)] 107 | for s in range(S): 108 | for v in range(V): 109 | ixt = intrinsics[s] if intrinsics.dim() == 3 else intrinsics[s, v] 110 | ext = extrinsics[s, v] 111 | translation, rotation = ext[:3, 3], ext[:3, :3] 112 | if (self.deltas is not None) and (path is not None): 113 | shape_id, view_id = self.all_data_idx[path[s][v]] 114 | delta = self.deltas[shape_id][view_id] 115 | d_t, d_r = delta[6:], r6d2mat(delta[None, :6]).squeeze(0) 116 | rotation = rotation @ d_r 117 | translation = translation + d_t 118 | ext = torch.cat([torch.cat([rotation, translation[:, None]], 1), ext[3:]], 0) 119 | 120 | ray_start[s] += [translation] 121 | if projections is not None: 122 | proj = projections[s, v] 123 | ray_dir[s] += [get_ray_direction_proj(translation, flatten_uv[s, v], ext, proj, size, 1)] 124 | else: 125 | ray_dir[s] += [get_ray_direction(translation, flatten_uv[s, v], ixt, ext, 1)] 126 | 127 | if self.args.dis and 'dis_extrinsics' in kwargs: 128 | dis_ext = kwargs['dis_extrinsics'][s, v] 129 | dis_translation, dis_rotation = dis_ext[:3, 3], dis_ext[:3, :3] 130 | dis_ray_start[s] += [dis_translation] 131 | if projections is not None: 132 | dis_proj =
kwargs['dis_projections'][s, v] 133 | dis_ray_dir[s] += [get_ray_direction_proj(dis_translation, flatten_uv[s, v], dis_ext, dis_proj, size, 1)] 134 | else: 135 | dis_ray_dir[s] += [get_ray_direction(dis_translation, flatten_uv[s, v], ixt, dis_ext, 1)] 136 | 137 | ray_start = torch.stack([torch.stack(r) for r in ray_start]) 138 | ray_dir = torch.stack([torch.stack(r) for r in ray_dir]) 139 | if self.args.dis and 'dis_extrinsics' in kwargs: 140 | dis_ray_start = torch.stack([torch.stack(r) for r in dis_ray_start]) 141 | dis_ray_dir = torch.stack([torch.stack(r) for r in dis_ray_dir]) 142 | return ray_start.unsqueeze(-2), ray_dir.transpose(2, 3), uv, dis_ray_start.unsqueeze(-2), dis_ray_dir.transpose(2, 3) 143 | else: 144 | return ray_start.unsqueeze(-2), ray_dir.transpose(2, 3), uv, None, None 145 | 146 | @torch.no_grad() 147 | def sample_pixels(self, uv, size, alpha=None, mask=None, **kwargs): 148 | H, W = int(size[0,0,0]), int(size[0,0,1]) 149 | S, V = uv.size()[:2] 150 | 151 | if mask is None: 152 | if alpha is not None: 153 | mask = (alpha > 0) 154 | else: 155 | mask = uv.new_ones(S, V, uv.size(-1)).bool() 156 | mask = mask.float().reshape(S, V, H, W) 157 | 158 | if self.args.sampling_at_center < 1.0: 159 | r = (1 - self.args.sampling_at_center) / 2.0 160 | mask0 = mask.new_zeros(S, V, H, W) 161 | mask0[:, :, int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 162 | mask = mask * mask0 163 | 164 | if self.args.sampling_on_bbox: 165 | x_has_points = mask.sum(2, keepdim=True) > 0 166 | y_has_points = mask.sum(3, keepdim=True) > 0 167 | mask = (x_has_points & y_has_points).float() 168 | 169 | probs = mask / (mask.sum() + 1e-8) 170 | if self.args.sampling_on_mask > 0.0: 171 | probs = self.args.sampling_on_mask * probs + (1 - self.args.sampling_on_mask) * 1.0 / (H * W) 172 | 173 | num_pixels = int(self.args.pixel_per_view) 174 | patch_size, skip_size = self.args.sampling_patch_size, self.args.sampling_skipping_size 175 | C = patch_size * skip_size 176 | 177 | if C > 1: 178 | probs = probs.reshape(S, V, H // C, C, W // C, C).sum(3).sum(-1) 179 | num_pixels = num_pixels // patch_size // patch_size 180 | 181 | flatten_probs = probs.reshape(S, V, -1) 182 | sampled_index = sampling_without_replacement(torch.log(flatten_probs+ TINY), num_pixels) 183 | sampled_masks = torch.zeros_like(flatten_probs).scatter_(-1, sampled_index, 1).reshape(S, V, H // C, W // C) 184 | 185 | if C > 1: 186 | sampled_masks = sampled_masks[:, :, :, None, :, None].repeat( 187 | 1, 1, 1, patch_size, 1, patch_size).reshape(S, V, H // skip_size, W // skip_size) 188 | if skip_size > 1: 189 | full_datamask = sampled_masks.new_zeros(S, V, skip_size * skip_size, H // skip_size, W // skip_size) 190 | full_index = torch.randint(skip_size*skip_size, (S, V)) 191 | for i in range(S): 192 | for j in range(V): 193 | full_datamask[i, j, full_index[i, j]] = sampled_masks[i, j] 194 | sampled_masks = full_datamask.reshape( 195 | S, V, skip_size, skip_size, H // skip_size, W // skip_size).permute(0, 1, 4, 2, 5, 3).reshape(S, V, H, W) 196 | 197 | X, Y = uv[:,:,0].reshape(S, V, H, W), uv[:,:,1].reshape(S, V, H, W) 198 | X = X[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) 199 | Y = Y[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) 200 | return torch.cat([X, Y], 2), sampled_masks 201 | 202 | @torch.no_grad() 203 | def downsample_pixels_num(self, times): 204 | if self.num_pixels_ratio is not None: 205 | self.num_pixels = int(self.num_pixels * self.num_pixels_ratio[times]) 206 | return self.num_pixels 207 | 208 
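For reference, `sample_pixels` above draws distinct pixels in proportion to the probability map via the Gumbel-top-k trick implemented by the `sampling_without_replacement` helper defined just below. A minimal, self-contained usage sketch (toy shapes only; not part of the original file):

```python
import torch
from fairnr.modules.reader import sampling_without_replacement

TINY = 1e-9

# a toy probability map over a flattened 8x8 view: (S=1 shape, V=1 view, 64 pixels)
probs = torch.rand(1, 1, 64)
probs = probs / probs.sum(dim=-1, keepdim=True)

# perturb the log-probabilities with Gumbel noise and keep the top-k scores:
# the returned indices are k distinct pixels, drawn in proportion to `probs`
idx = sampling_without_replacement(torch.log(probs + TINY), 16)
print(idx.shape)                    # torch.Size([1, 1, 16])
print(idx.unique().numel() == 16)   # no repeated indices -> True
```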
| 209 | def sampling_without_replacement(logp, k): 210 | def gumbel_like(u): 211 | return -torch.log(-torch.log(torch.rand_like(u) + TINY) + TINY) 212 | scores = logp + gumbel_like(logp) 213 | return scores.topk(k, dim=-1)[1] -------------------------------------------------------------------------------- /fairnr/modules/renderer.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from fairnr.modules.linear import FCLayer 10 | from fairnr.data.geometry import ray 11 | import os 12 | import numpy as np 13 | from fairnr.data.geometry import offset_points, trilinear_interp 14 | 15 | MAX_DEPTH = 10000.0 16 | RENDERER_REGISTRY = {} 17 | 18 | def register_renderer(name): 19 | def register_renderer_cls(cls): 20 | if name in RENDERER_REGISTRY: 21 | raise ValueError('Cannot register duplicate module ({})'.format(name)) 22 | RENDERER_REGISTRY[name] = cls 23 | return cls 24 | return register_renderer_cls 25 | 26 | 27 | def get_renderer(name): 28 | if name not in RENDERER_REGISTRY: 29 | raise ValueError('Cannot find module {}'.format(name)) 30 | return RENDERER_REGISTRY[name] 31 | 32 | 33 | @register_renderer('abstract_renderer') 34 | class Renderer(nn.Module): 35 | """ 36 | Abstract class for ray marching 37 | """ 38 | def __init__(self, args): 39 | super().__init__() 40 | self.args = args 41 | 42 | def forward(self, **kwargs): 43 | raise NotImplementedError 44 | 45 | @staticmethod 46 | def add_args(parser): 47 | pass 48 | 49 | 50 | @register_renderer('volume_rendering') 51 | class VolumeRenderer(Renderer): 52 | 53 | def __init__(self, args): 54 | super().__init__(args) 55 | self.chunk_size = 1024 * getattr(args, "chunk_size", 64) 56 | self.valid_chunk_size = 1024 * getattr(args, "valid_chunk_size", self.chunk_size // 1024) 57 | self.ray_chunk_size = 1024 * getattr(args, "ray_chunk_size", 64) 58 | self.discrete_reg = getattr(args, "discrete_regularization", False) 59 | self.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0.0) 60 | 61 | @staticmethod 62 | def add_args(parser): 63 | # ray-marching parameters 64 | parser.add_argument('--discrete-regularization', action='store_true', 65 | help='if set, a zero-mean unit-variance gaussian will be added to encourage discreteness') 66 | 67 | # additional arguments 68 | parser.add_argument('--chunk-size', type=int, metavar='D', 69 | help='chunk size (in units of 1024 point samples) used to split the forward pass; trade time for memory') 70 | parser.add_argument('--valid-chunk-size', type=int, metavar='D', 71 | help='chunk size used when not training; defaults to the same value as --chunk-size') 72 | parser.add_argument('--ray-chunk-size', type=int, metavar='D', 73 | help='ray chunk size used when not training; defaults to the same value as --chunk-size') 74 | parser.add_argument('--raymarching-tolerance', type=float, default=0) 75 | 76 | def forward_once( 77 | self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, 78 | early_stop=None, output_types=['sigma', 'texture'] 79 | ): 80 | """ 81 | chunks: set > 1 if running out of memory; trades time for memory.
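        Evaluates the implicit field only at sampled points that hit a voxel (and are not early-stopped), then scatters the per-point density/texture predictions back to the full (ray, step) layout.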
82 | """ 83 | sampled_depth = samples['sampled_point_depth'] 84 | sampled_dists = samples['sampled_point_distance'] 85 | sampled_idx = samples['sampled_point_voxel_idx'].long() 86 | 87 | # only compute when the ray hits 88 | sample_mask = sampled_idx.ne(-1) 89 | if early_stop is not None: 90 | sample_mask = sample_mask & (~early_stop.unsqueeze(-1)) 91 | if sample_mask.sum() == 0: # miss everything skip 92 | return None, 0 93 | 94 | sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2)) 95 | sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1]) 96 | samples['sampled_point_xyz'] = sampled_xyz 97 | samples['sampled_point_ray_direction'] = sampled_dir 98 | 99 | # apply mask 100 | samples = {name: s[sample_mask] for name, s in samples.items()} 101 | 102 | # get encoder features as inputs 103 | field_inputs = input_fn(samples, encoder_states) 104 | 105 | if 'voxel_color' in encoder_states: 106 | voxel_color = encoder_states['voxel_color'].float() 107 | sampled_idx = samples['sampled_point_voxel_idx'].long() 108 | voxel_color_feat = F.embedding(sampled_idx, voxel_color) 109 | else: 110 | voxel_color_feat = None 111 | 112 | def masked_scatter(mask, x): 113 | B, K = mask.size() 114 | if x.dim() == 1: 115 | return x.new_zeros(B, K).masked_scatter(mask, x) 116 | return x.new_zeros(B, K, x.size(-1)).masked_scatter( 117 | mask.unsqueeze(-1).expand(B, K, x.size(-1)), x) 118 | 119 | # forward implicit fields 120 | field_outputs = field_fn(field_inputs, outputs=output_types, color_feat=voxel_color_feat) 121 | outputs = {'sample_mask': sample_mask} 122 | 123 | # post processing 124 | if 'sigma' in field_outputs: 125 | sigma, sampled_dists= field_outputs['sigma'], samples['sampled_point_distance'] 126 | noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_() 127 | free_energy = torch.relu(noise + sigma) * sampled_dists 128 | free_energy = free_energy * 7.0 # magic operation 129 | # free_energy = (F.elu(sigma - 3, alpha=1) + 1) * sampled_dists 130 | # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists 131 | outputs['free_energy'] = masked_scatter(sample_mask, free_energy) 132 | if 'texture' in field_outputs: 133 | if self.args.res: 134 | field_outputs['texture'] = field_outputs['texture'] + voxel_color_feat 135 | outputs['texture'] = masked_scatter(sample_mask, field_outputs['texture']) 136 | if 'normal' in field_outputs: 137 | outputs['normal'] = masked_scatter(sample_mask, field_outputs['normal']) 138 | return outputs, sample_mask.sum() 139 | 140 | def forward_chunk( 141 | self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, 142 | gt_depths=None, output_types=['sigma', 'texture'], global_weights=None, 143 | ): 144 | sampled_depth = samples['sampled_point_depth'] 145 | sampled_idx = samples['sampled_point_voxel_idx'].long() 146 | 147 | tolerance = self.raymarching_tolerance 148 | chunk_size = self.chunk_size if self.training else self.valid_chunk_size 149 | early_stop = None 150 | if tolerance > 0: 151 | tolerance = -math.log(tolerance) 152 | 153 | hits = sampled_idx.ne(-1).long() 154 | outputs = defaultdict(lambda: []) 155 | size_so_far, start_step = 0, 0 156 | accumulated_free_energy = 0 157 | accumulated_evaluations = 0 158 | for i in range(hits.size(1) + 1): 159 | if ((i == hits.size(1)) or (size_so_far + hits[:, i].sum() > chunk_size)) and (i > start_step): 160 | _outputs, _evals = self.forward_once( 161 | input_fn, field_fn, 162 | ray_start, ray_dir, 163 | 
{name: s[:, start_step: i] 164 | for name, s in samples.items()}, 165 | encoder_states, 166 | early_stop=early_stop, 167 | output_types=output_types) 168 | if _outputs is not None: 169 | accumulated_evaluations += _evals 170 | 171 | if 'free_energy' in _outputs: 172 | accumulated_free_energy += _outputs['free_energy'].sum(1) 173 | if tolerance > 0: 174 | early_stop = accumulated_free_energy > tolerance 175 | hits[early_stop] *= 0 176 | 177 | for key in _outputs: 178 | outputs[key] += [_outputs[key]] 179 | else: 180 | for key in outputs: 181 | outputs[key] += [outputs[key][-1].new_zeros( 182 | outputs[key][-1].size(0), 183 | sampled_depth[:, start_step: i].size(1), 184 | *outputs[key][-1].size()[2:] 185 | )] 186 | start_step, size_so_far = i, 0 187 | 188 | if (i < hits.size(1)): 189 | size_so_far += hits[:, i].sum() 190 | 191 | outputs = {key: torch.cat(outputs[key], 1) for key in outputs} 192 | results = {} 193 | 194 | if 'free_energy' in outputs: 195 | free_energy = outputs['free_energy'] 196 | shifted_free_energy = torch.cat([free_energy.new_zeros(sampled_depth.size(0), 1), free_energy[:, :-1]], dim=-1) # shift one step 197 | a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here 198 | b = torch.exp(-torch.cumsum(shifted_free_energy.float(), dim=-1)) # probability of everything is empty up to now 199 | probs = (a * b).type_as(free_energy) # probability of the ray hits something here 200 | else: 201 | probs = outputs['sample_mask'].type_as(sampled_depth) / sampled_depth.size(-1) # assuming a uniform distribution 202 | 203 | if global_weights is not None: 204 | probs = probs * global_weights 205 | 206 | depth = (sampled_depth * probs).sum(-1) 207 | missed = 1 - probs.sum(-1) 208 | results.update({'probs': probs, 'depths': depth, 'missed': missed, 'ae': accumulated_evaluations}) 209 | 210 | if 'texture' in outputs: 211 | results['colors'] = (outputs['texture'] * probs.unsqueeze(-1)).sum(-2) 212 | if 'normal' in outputs: 213 | results['normal'] = (outputs['normal'] * probs.unsqueeze(-1)).sum(-2) 214 | return results 215 | 216 | def forward(self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, *args, **kwargs): 217 | chunk_size = self.chunk_size if self.training else self.valid_chunk_size 218 | ray_chunk_size = self.ray_chunk_size 219 | if ray_start.size(0) <= chunk_size: 220 | results = self.forward_chunk(input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, *args, **kwargs) 221 | return results 222 | 223 | # the number of rays is larger than maximum forward passes. pre-chuncking.. 
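        # process the rays in chunks of `chunk_size` and merge the per-chunk results below:
        # tensor outputs are concatenated along the ray dimension, while scalar counters (e.g. the evaluation count 'ae') are summed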
224 | results = [ 225 | self.forward_chunk(input_fn, field_fn, 226 | ray_start[i: i+chunk_size], ray_dir[i: i+chunk_size], 227 | {name: s[i: i+chunk_size] for name, s in samples.items()}, encoder_states, *args, **kwargs) 228 | for i in range(0, ray_start.size(0), chunk_size) 229 | ] 230 | results = {name: torch.cat([r[name] for r in results], 0) 231 | if results[0][name].dim() > 0 else sum([r[name] for r in results]) 232 | for name in results[0]} 233 | 234 | if getattr(input_fn, "track_max_probs", False): 235 | input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs']) 236 | return results 237 | 238 | -------------------------------------------------------------------------------- /fairnr/options.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import sys 4 | import torch 5 | 6 | 7 | def add_rendering_args(parser): 8 | group = parser.add_argument_group("Rendering") 9 | group.add_argument("--path", type=str, default=None, 10 | help="checkpoint path") 11 | group.add_argument("--model-overrides", type=str, default=None, 12 | help="change some parameters during rendering") 13 | group.add_argument("--render-beam", default=5, type=int, metavar="N", 14 | help="beam size for parallel rendering") 15 | group.add_argument("--render-resolution", default="512x512", type=str, metavar="N", help='if provide two numbers, means H x W') 16 | group.add_argument("--render-angular-speed", default=1, type=float, metavar="D", 17 | help="angular speed when rendering around the object") 18 | group.add_argument("--render-num-frames", default=500, type=int, metavar="N") 19 | group.add_argument("--render-path-style", default="circle", choices=["circle", "zoomin_circle", "zoomin_line"], type=str) 20 | group.add_argument("--render-path-args", default="{'radius': 2.5, 'h': 0.0}", 21 | help="specialized arguments for rendering paths") 22 | group.add_argument("--render-output", default=None, type=str) 23 | group.add_argument("--render-at-vector", default="(0,0,0)", type=str) 24 | group.add_argument("--render-up-vector", default="(0,0,-1)", type=str) 25 | group.add_argument("--render-output-types", nargs="+", type=str, default=["rgb"], 26 | choices=["target", "color", "depth", "voxeldepth", "normal", "voxel", "predn"]) 27 | group.add_argument("--render-raymarching-steps", default=None, type=int) 28 | group.add_argument("--render-save-fps", default=24, type=int) 29 | group.add_argument("--render-combine-output", action='store_true', 30 | help="if set, concat the images into one file.") 31 | group.add_argument("--render-camera-poses", default=None, type=str, 32 | help="text file saved for the testing trajectories") 33 | group.add_argument("--render-camera-projs", default=None, type=str, 34 | help="text file saved for projection matrices") 35 | group.add_argument("--render-camera-intrinsics", default=None, type=str) 36 | group.add_argument("--render-views", type=str, default=None, 37 | help="views sampled for rendering, you can set specific view id, or a range") 38 | group.add_argument("--render-move-vector", default="(0,0,0)", type=str, help="move voxel center") 39 | group.add_argument("--render-depth-rawoutput", action="store_true", help="output raw depth") -------------------------------------------------------------------------------- /fairnr/renderer.py: -------------------------------------------------------------------------------- 1 | 2 | import os, tempfile, shutil, glob 3 | import time 4 | import torch 5 | import numpy as np 
6 | import logging 7 | import imageio 8 | import types 9 | 10 | from torchvision.utils import save_image 11 | from fairnr.data import trajectory, geometry, data_utils 12 | from fairnr.data.data_utils import recover_image, get_uv, parse_views 13 | from pathlib import Path 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class NeuralRenderer(object): 19 | 20 | def __init__(self, 21 | resolution="512x512", 22 | frames=501, 23 | speed=5, 24 | raymarching_steps=None, 25 | path_gen=None, 26 | beam=10, 27 | at=(0,0,0), 28 | up=(0,1,0), 29 | output_dir=None, 30 | output_type=None, 31 | fps=24, 32 | test_camera_poses=None, 33 | test_camera_projs=None, 34 | test_camera_intrinsics=None, 35 | test_camera_views=None, 36 | object_move=(0,0,0), 37 | raw_depth_output=False): 38 | 39 | self.frames = frames 40 | self.speed = speed 41 | self.raymarching_steps = raymarching_steps 42 | self.path_gen = path_gen 43 | 44 | if isinstance(resolution, str): 45 | self.resolution = [int(r) for r in resolution.split('x')] 46 | else: 47 | self.resolution = [resolution, resolution] 48 | 49 | self.beam = beam 50 | self.output_dir = output_dir 51 | self.output_type = output_type 52 | self.at = at 53 | self.up = up 54 | self.fps = fps 55 | self.object_move = object_move 56 | self.raw_depth_output = raw_depth_output 57 | 58 | if self.path_gen is None: 59 | self.path_gen = trajectory.circle() 60 | if self.output_type is None: 61 | self.output_type = ["rgb"] 62 | 63 | if test_camera_intrinsics is not None: 64 | self.test_int = data_utils.load_intrinsics(test_camera_intrinsics) 65 | else: 66 | self.test_int = None 67 | 68 | if test_camera_views is not None: 69 | self.render_views = parse_views(test_camera_views) 70 | self.test_frameids = None 71 | if test_camera_poses is not None: 72 | if os.path.isdir(test_camera_poses): 73 | self.test_poses = [ 74 | np.loadtxt(f)[None, :, :] for f in sorted(glob.glob(test_camera_poses + "/*.txt"))] 75 | self.test_poses = np.concatenate(self.test_poses, 0) 76 | else: 77 | self.test_poses = data_utils.load_matrix(test_camera_poses) 78 | if self.test_poses.shape[1] == 17: 79 | self.test_frameids = self.test_poses[:, -1].astype(np.int32) 80 | self.test_poses = self.test_poses[:, :-1] 81 | self.test_poses = self.test_poses.reshape(-1, 4, 4) 82 | 83 | if test_camera_views is not None: 84 | self.test_poses = np.stack([self.test_poses[r] for r in self.render_views]) 85 | 86 | else: 87 | self.test_poses = None 88 | 89 | if test_camera_projs is not None: 90 | if os.path.isdir(test_camera_projs): 91 | self.test_projs = [ 92 | np.loadtxt(f)[None, :, :] for f in sorted(glob.glob(test_camera_projs + "/*.txt"))] 93 | self.test_projs = np.concatenate(self.test_projs, 0) 94 | else: 95 | self.test_projs = data_utils.load_matrix(test_camera_projs) 96 | if self.test_projs.shape[1] == 17: 97 | self.test_frameids = self.test_projs[:, -1].astype(np.int32) 98 | self.test_projs = self.test_projs[:, :-1] 99 | self.test_projs = self.test_projs.reshape(-1, 4, 4) 100 | 101 | if test_camera_views is not None: 102 | self.test_projs = np.stack([self.test_projs[r] for r in self.render_views]) 103 | else: 104 | self.test_projs = np.tile(self.test_projs, (self.frames, 1, 1)) 105 | 106 | else: 107 | self.test_projs = None 108 | 109 | def generate_rays(self, t, intrinsics, img_size, projs=None, inv_RT=None, action='none'): 110 | if inv_RT is None: 111 | cam_pos = torch.tensor(self.path_gen(t * self.speed / 180 * np.pi), 112 | device=intrinsics.device, dtype=intrinsics.dtype) 113 | cam_rot = 
geometry.look_at_rotation(cam_pos, at=self.at, up=self.up, inverse=True, cv=True) 114 | 115 | inv_RT = cam_pos.new_zeros(4, 4) 116 | inv_RT[:3, :3] = cam_rot 117 | inv_RT[:3, 3] = cam_pos 118 | inv_RT[3, 3] = 1 119 | else: 120 | inv_RT = torch.from_numpy(inv_RT).type_as(intrinsics) 121 | inv_RT = torch.inverse(inv_RT) 122 | 123 | if projs is not None: 124 | projs = torch.from_numpy(projs).type_as(intrinsics) 125 | 126 | h, w, rh, rw = img_size[0], img_size[1], img_size[2], img_size[3] 127 | if self.test_int is not None: 128 | uv = torch.from_numpy(get_uv(h, w, h, w)[0]).type_as(intrinsics) 129 | intrinsics = self.test_int 130 | else: 131 | uv = torch.from_numpy(get_uv(h * rh, w * rw, h, w)[0]).type_as(intrinsics) 132 | 133 | uv = uv.reshape(2, -1) 134 | return uv, inv_RT, projs 135 | 136 | def parse_sample(self,sample): 137 | if len(sample) == 1: 138 | return sample[0], 0, self.frames 139 | elif len(sample) == 2: 140 | return sample[0], sample[1], self.frames 141 | elif len(sample) == 3: 142 | return sample[0], sample[1], sample[2] 143 | else: 144 | raise NotImplementedError 145 | 146 | @torch.no_grad() 147 | def generate(self, models, sample, **kwargs): 148 | model = models[0] 149 | model.eval() 150 | 151 | print("rendering starts.") 152 | output_path = self.output_dir 153 | image_names = [] 154 | sample, step, frames = self.parse_sample(sample) 155 | 156 | # fix the rendering size 157 | a = sample['size'][0,0,0] / self.resolution[0] 158 | b = sample['size'][0,0,1] / self.resolution[1] 159 | sample['size'][:, :, 0] /= a 160 | sample['size'][:, :, 1] /= b 161 | sample['size'][:, :, 2] *= a 162 | sample['size'][:, :, 3] *= b 163 | 164 | for shape in range(sample['shape'].size(0)): 165 | max_step = step + frames 166 | while step < max_step: 167 | next_step = min(step + self.beam, max_step) 168 | uv, inv_RT, projs = zip(*[ 169 | self.generate_rays( 170 | k, 171 | sample['intrinsics'][shape], 172 | sample['size'][shape, 0], 173 | self.test_projs[k] if self.test_projs is not None else None, 174 | self.test_poses[k] if self.test_poses is not None else None) 175 | for k in range(step, next_step) 176 | ]) 177 | if self.test_frameids is not None: 178 | assert next_step - step == 1 179 | ids = torch.tensor(self.test_frameids[step: next_step]).type_as(sample['id']) 180 | else: 181 | ids = sample['id'][shape:shape+1] 182 | 183 | real_images = sample['full_rgb'] if 'full_rgb' in sample else sample['colors'] 184 | real_images = real_images.transpose(2, 3) if real_images.size(-1) != 3 else real_images 185 | 186 | _sample = { 187 | 'id': ids, 188 | 'colors': torch.cat([real_images[shape:shape+1] for _ in range(step, next_step)], 1), 189 | 'intrinsics': sample['intrinsics'][shape:shape+1], 190 | 'extrinsics': torch.stack(inv_RT, 0).unsqueeze(0), 191 | 'projections': torch.stack(projs, 0).unsqueeze(0) if projs[0] is not None else None, 192 | 'uv': torch.stack(uv, 0).unsqueeze(0), 193 | 'shape': sample['shape'][shape:shape+1], 194 | 'view': torch.arange( 195 | step, next_step, 196 | device=sample['shape'].device).unsqueeze(0), 197 | 'size': torch.cat([sample['size'][shape:shape+1] for _ in range(step, next_step)], 1), 198 | 'step': step 199 | } 200 | if isinstance(self.object_move[0], tuple): 201 | assert len(self.object_move) == sample['shape'].size(0) 202 | _sample['move'] = torch.tensor(self.object_move[shape]).type_as(_sample['uv']) 203 | else: 204 | _sample['move'] = torch.tensor(self.object_move).type_as(_sample['uv']) 205 | with data_utils.GPUTimer() as timer: 206 | outs = model(**_sample) 207 | # 
logger.info("rendering frame={}\ttotal time={:.4f}\tvoxel={:.4f}".format(step, timer.sum, outs['other_logs']['tvox_log'])) 208 | print("rendering frame={}\ttotal time={:.4f}\tvoxel={:.4f}".format(step, timer.sum, outs['other_logs']['tvox_log'])) 209 | 210 | for k in range(step, next_step): 211 | images = model.visualize(_sample, None, 0, k-step, raw_depth=self.raw_depth_output) 212 | # image_name = "{:04d}".format(k) 213 | image_name = "{:04d}_{:04d}".format(shape, self.render_views[k]) if sample['shape'].size(0) > 1 else "{:04d}".format(self.render_views[k]) 214 | 215 | for key in images: 216 | name, type = key.split('/')[0].split('_') 217 | if type in self.output_type and name == 'render': 218 | if self.raw_depth_output and type == 'voxeldepth': 219 | prefix = os.path.join(output_path, type) 220 | Path(prefix).mkdir(parents=True, exist_ok=True) 221 | depth = images[key].cpu().numpy() 222 | height, width = int(_sample['size'][0,0,0]), int(_sample['size'][0,0,1]) 223 | depth = depth.reshape((height, width)) 224 | np.savetxt(os.path.join(prefix, image_name + '.txt'), depth) 225 | else: 226 | prefix = os.path.join(output_path, type) 227 | Path(prefix).mkdir(parents=True, exist_ok=True) 228 | image = images[key].permute(2, 0, 1) \ 229 | if images[key].dim() == 3 else torch.stack(3*[images[key]], 0) 230 | save_image(image, os.path.join(prefix, image_name + '.png'), format=None) 231 | image_names.append(os.path.join(prefix, image_name + '.png')) 232 | 233 | # save pose matrix 234 | prefix = os.path.join(output_path, 'pose') 235 | Path(prefix).mkdir(parents=True, exist_ok=True) 236 | pose = self.test_poses[k] if self.test_poses is not None else inv_RT[k].cpu().numpy() 237 | np.savetxt(os.path.join(prefix, image_name + '.txt'), pose) 238 | 239 | step = next_step 240 | model.clean_caches() 241 | 242 | # logger.info("done") 243 | print("done") 244 | return step, image_names 245 | 246 | def save_images(self, output_files, steps=None, combine_output=True): 247 | if not os.path.exists(self.output_dir): 248 | os.mkdir(self.output_dir) 249 | timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) 250 | if steps is not None: 251 | timestamp = "step_{}.".format(steps) + timestamp 252 | 253 | if not combine_output: 254 | for type in self.output_type: 255 | images = [imageio.imread(file_path) for file_path in output_files if type in file_path] 256 | # imageio.mimsave('{}/{}_{}.gif'.format(self.output_dir, type, timestamp), images, fps=self.fps) 257 | imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, type, timestamp), images, fps=self.fps, quality=8) 258 | else: 259 | images = [[imageio.imread(file_path) for file_path in output_files if type in file_path] for type in self.output_type] 260 | images = [np.concatenate([images[j][i] for j in range(len(images))], 1) for i in range(len(images[0]))] 261 | imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, 'full', timestamp), images, fps=self.fps, quality=8) 262 | 263 | return timestamp 264 | 265 | def merge_videos(self, timestamps): 266 | # logger.info("mergining mp4 files..") 267 | print("mergining mp4 files..") 268 | timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) 269 | writer = imageio.get_writer( 270 | os.path.join(self.output_dir, 'full_' + timestamp + '.mp4'), fps=self.fps) 271 | for timestamp in timestamps: 272 | tempfile = os.path.join(self.output_dir, 'full_' + timestamp + '.mp4') 273 | reader = imageio.get_reader(tempfile) 274 | for im in reader: 275 | writer.append_data(im) 276 | 
os.remove(tempfile) 277 | writer.close() -------------------------------------------------------------------------------- /fairnr/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import os 4 | 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith('.py') and not file.startswith('_'): 7 | task_name = file[:file.find('.py')] 8 | importlib.import_module('fairnr.tasks.' + task_name) 9 | -------------------------------------------------------------------------------- /fairnr_cli/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fairnr_cli/extract.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import numpy as np 4 | import torch 5 | import sys, os 6 | import argparse 7 | import open3d as o3d 8 | 9 | from plyfile import PlyData, PlyElement 10 | 11 | from fairnr.tasks.neural_rendering import SingleObjRenderingTask 12 | from fairnr.models.nsvf import NSVFModel, my_base_architecture 13 | from fairnr.criterions.rendering_loss import SRNLossCriterion 14 | 15 | 16 | def main(parser): 17 | logging.basicConfig( 18 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 19 | datefmt='%Y-%m-%d %H:%M:%S', 20 | level=logging.INFO, 21 | stream=sys.stdout, 22 | ) 23 | logger = logging.getLogger('RGBDNeRF.extract') 24 | 25 | SingleObjRenderingTask.add_args(parser) 26 | NSVFModel.add_args(parser) 27 | SRNLossCriterion.add_args(parser) 28 | args = parser.parse_args() 29 | 30 | use_cuda = torch.cuda.is_available() and not args.cpu 31 | 32 | if use_cuda: 33 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 34 | 35 | ckpt = torch.load(args.path) 36 | if 'args' in ckpt: 37 | load_args = ckpt['args'] 38 | 39 | for key in vars(args).keys(): 40 | if key in vars(load_args).keys(): 41 | setattr(args, key, getattr(load_args, key)) 42 | 43 | my_base_architecture(args) 44 | task = SingleObjRenderingTask(args) 45 | model = task.build_model(args) 46 | model.load_state_dict(ckpt['model']) 47 | 48 | if use_cuda: 49 | model.cuda() 50 | 51 | if args.format == 'mc_mesh': 52 | plydata = model.encoder.export_surfaces( 53 | model.field, th=args.mc_threshold, 54 | bits=2 * args.mc_num_samples_per_halfvoxel) 55 | elif args.format == 'voxel_center': 56 | plydata = model.encoder.export_voxels(False) 57 | elif args.format == 'voxel_mesh': 58 | plydata = model.encoder.export_voxels(True) 59 | else: 60 | raise NotImplementedError 61 | 62 | # write to ply file. 63 | if not os.path.exists(args.output): 64 | os.makedirs(args.output) 65 | plydata.text = args.savetext 66 | plydata.write(open(os.path.join(args.output, args.name + '.ply'), 'wb')) 67 | 68 | 69 | def cli_main(): 70 | parser = argparse.ArgumentParser(description='Extract geometry from a trained model (only for learnable embeddings).') 71 | parser.add_argument('--path', type=str, required=True) 72 | parser.add_argument('--output', type=str, required=True) 73 | parser.add_argument('--name', type=str, default='sparsevoxel') 74 | parser.add_argument('--format', type=str, choices=['voxel_center', 'voxel_mesh', 'mc_mesh']) 75 | parser.add_argument('--savetext', action='store_true', help='save .ply in plain text') 76 | parser.add_argument('--mc-num-samples-per-halfvoxel', type=int, default=8, 77 | help="""the number of point samples every half voxel-size for marching cube. 
78 | For instance, setting it to 8 will use (8 x 2) ^ 3 = 4096 points to compute the density for each voxel. 79 | In practice, the larger this number is, the more accurate the surface you get. 80 | """) 81 | parser.add_argument('--mc-threshold', type=float, default=0.5, 82 | help="""the threshold used to find the isosurface from the learned implicit field. 83 | In our implementation, we define our values as ``1 - exp(-max(0, density))`` 84 | where "0" is empty and "1" is fully occupied. 85 | """) 86 | parser.add_argument('--cpu', action='store_true') 87 | # args = parser.parse_args() 88 | main(parser) 89 | 90 | 91 | if __name__ == '__main__': 92 | cli_main() 93 | -------------------------------------------------------------------------------- /fairnr_cli/myrender.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | import logging 4 | import math 5 | import os 6 | import sys 7 | import time 8 | import torch 9 | import imageio 10 | import numpy as np 11 | 12 | from fairnr import options 13 | 14 | import argparse 15 | from torch.utils.data import DataLoader 16 | from fairnr.tasks.neural_rendering import SingleObjRenderingTask 17 | from fairnr.models.nsvf import NSVFModel, my_base_architecture 18 | from fairnr.criterions.rendering_loss import SRNLossCriterion 19 | 20 | def my_generation(): 21 | logging.basicConfig( 22 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 23 | datefmt='%Y-%m-%d %H:%M:%S', 24 | level=logging.INFO, 25 | stream=sys.stdout, 26 | ) 27 | logger = logging.getLogger('RGBDNeRF.render') 28 | 29 | parser = argparse.ArgumentParser() 30 | SingleObjRenderingTask.add_args(parser) 31 | NSVFModel.add_args(parser) 32 | SRNLossCriterion.add_args(parser) 33 | options.add_rendering_args(parser) 34 | args = parser.parse_args() 35 | 36 | assert args.path is not None, '--path required for generation!'
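# Note (descriptive comment, added for clarity): the block below restores the hyper-parameters
# stored in the checkpoint. Values saved under ckpt['args'] overwrite the CLI defaults, and
# --model-overrides, a Python dict literal parsed with eval() such as
# '{"chunk_size": 8, "raymarching_tolerance": 0.01}' (see render.sh), is applied last,
# so individual settings can still be changed at render time.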
37 | ckpt = torch.load(args.path) 38 | if 'args' in ckpt: 39 | load_args = ckpt['args'] 40 | 41 | # args = {key: load_args[key] for key in args.keys() if key in load_args.keys() else args[key]} 42 | for key in vars(args).keys(): 43 | if key in vars(load_args).keys(): 44 | setattr(args, key, getattr(load_args, key)) 45 | 46 | if args.model_overrides is not None: 47 | arg_overrides = eval(args.model_overrides) 48 | for key in arg_overrides: 49 | setattr(args, key, arg_overrides[key]) 50 | 51 | my_base_architecture(args) 52 | print(args) 53 | task = SingleObjRenderingTask(args) 54 | 55 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 56 | 57 | # load dataset 58 | task.load_dataset('test') 59 | gen_loader = DataLoader( 60 | task.datasets['test'], collate_fn=task.datasets['test'].collater, batch_size=1, num_workers=0 61 | ) 62 | 63 | # Build mocel and generator 64 | model = task.build_model(args).cuda() 65 | generator = task.build_generator(args) 66 | model.load_state_dict(ckpt['model']) 67 | if generator.test_poses is not None: 68 | frames = generator.test_poses.shape[0] 69 | else: 70 | frames = args.render_num_frames 71 | 72 | output_files, step= [], 0 73 | for i, sample in enumerate(gen_loader): 74 | sample = {key: sample[key].cuda() for key in sample.keys() if isinstance(sample[key], torch.Tensor)} 75 | step, _output_files = task.inference_step(generator, [model], [sample, step, frames]) 76 | output_files += _output_files 77 | print(step) 78 | 79 | generator.save_images(output_files, combine_output=args.render_combine_output) 80 | 81 | if __name__ == '__main__': 82 | my_generation() 83 | -------------------------------------------------------------------------------- /fairnr_cli/mytrain.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import math 4 | import os 5 | import random 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import argparse 12 | from torch.utils.data import DataLoader 13 | from torch import nn, optim 14 | from torch.optim.lr_scheduler import StepLR 15 | 16 | from fairnr.tasks.neural_rendering import SingleObjRenderingTask 17 | from fairnr.models.nsvf import NSVFModel, my_base_architecture 18 | from fairnr.criterions.rendering_loss import SRNLossCriterion 19 | 20 | logging.basicConfig( 21 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 22 | datefmt='%Y-%m-%d %H:%M:%S', 23 | level=logging.INFO, 24 | stream=sys.stdout, 25 | ) 26 | logger = logging.getLogger('RGBDNeRF.train') 27 | 28 | 29 | class PolynomialDecayLRSchedule(object): 30 | """Decay the LR on a fixed schedule.""" 31 | 32 | def __init__(self, total_num_update, lr, optimizer, warmup_updates=0, force_anneal=None, end_learning_rate=0.0, power=1.0): 33 | super().__init__() 34 | 35 | assert total_num_update > 0 36 | 37 | self.lr = lr 38 | if warmup_updates > 0: 39 | self.warmup_factor = 1.0 / warmup_updates 40 | else: 41 | self.warmup_factor = 1 42 | self.warmup_updates = warmup_updates 43 | self.force_anneal = force_anneal 44 | self.end_learning_rate = end_learning_rate 45 | self.total_num_update = total_num_update 46 | self.power = power 47 | self.optimizer = optimizer 48 | # self.set_lr(self.warmup_factor * self.lr) # set lr when we define optimizer 49 | 50 | def set_lr(self, lr): 51 | for param_group in self.optimizer.param_groups: 52 | param_group["lr"] = lr 53 | 54 | def get_lr(self): 55 | return self.optimizer.param_groups[0]["lr"] 56 | 57 | def get_next_lr(self, epoch): 58 | lrs = self.lr 59 | if 
self.force_anneal is None or epoch < self.force_anneal: 60 | # use fixed LR schedule 61 | next_lr = lrs[min(epoch, len(lrs) - 1)] 62 | else: 63 | # annneal based on lr_shrink 64 | next_lr = self.get_lr() 65 | return next_lr 66 | 67 | def step_begin_epoch(self, epoch): 68 | """Update the learning rate at the beginning of the given epoch.""" 69 | self.lr = self.get_next_lr(epoch) 70 | self.set_lr(self.warmup_factor * self.lr) 71 | return self.get_lr() 72 | 73 | def step_update(self, num_updates): 74 | """Update the learning rate after each update.""" 75 | if self.warmup_updates > 0 and num_updates <= self.warmup_updates: 76 | self.warmup_factor = num_updates / float(self.warmup_updates) 77 | lr = self.warmup_factor * self.lr 78 | elif num_updates >= self.total_num_update: 79 | lr = self.end_learning_rate 80 | else: 81 | warmup = self.warmup_updates 82 | lr_range = self.lr - self.end_learning_rate 83 | pct_remaining = 1 - (num_updates - warmup) / (self.total_num_update - warmup) 84 | lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate 85 | self.set_lr(lr) 86 | return self.get_lr() 87 | 88 | def step(self, num_updates): 89 | return self.step_update(num_updates) 90 | 91 | 92 | def my_train(): 93 | parser = argparse.ArgumentParser() 94 | SingleObjRenderingTask.add_args(parser) 95 | NSVFModel.add_args(parser) 96 | SRNLossCriterion.add_args(parser) 97 | args = parser.parse_args() 98 | my_base_architecture(args) 99 | print(args) 100 | 101 | task = SingleObjRenderingTask(args) 102 | gpu_id = args.gpu_id.split(',')[0] 103 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 104 | 105 | torch.cuda.set_device(int(gpu_id)) 106 | 107 | if not os.path.exists(args.save_dir): 108 | os.makedirs(args.save_dir, exist_ok=True) 109 | 110 | if args.load_pretrain and os.path.exists(os.path.join(args.save_dir, 'checkpoint_mesh_pretrain.pt')): 111 | ckpt = torch.load(os.path.join(args.save_dir, 'checkpoint_mesh_pretrain.pt')) 112 | logger.info('checkpoint load!') 113 | start_epoch = ckpt['epoch'] + 1 114 | model_state = ckpt['model'] 115 | if 'lr' in ckpt: 116 | lr = ckpt['lr'] 117 | elif os.path.exists(os.path.join(args.save_dir, 'checkpoint_last.pt')): 118 | ckpt = torch.load(os.path.join(args.save_dir, 'checkpoint_last.pt')) 119 | logger.info('checkpoint load!') 120 | start_epoch = ckpt['epoch'] + 1 121 | model_state = ckpt['model'] 122 | if 'lr' in ckpt: 123 | lr = ckpt['lr'] 124 | else: 125 | start_epoch = 0 126 | model_state = None 127 | lr = args.lr 128 | 129 | # load dataset 130 | task.load_dataset('train') 131 | task.load_dataset('valid') 132 | # itr = task.get_batch_iterator(task.datasets['train']) 133 | train_loader = DataLoader( 134 | task.datasets['train'], collate_fn=task.datasets['train'].collater, batch_size=args.batch_size, num_workers=args.num_workers 135 | ) 136 | test_loader = DataLoader( 137 | task.datasets['valid'], collate_fn=task.datasets['valid'].dataset.collater, batch_size=args.batch_size, num_workers=args.num_workers 138 | ) 139 | 140 | if args.mesh_data is not None and start_epoch < args.mesh_pretrain_num: 141 | mesh_pretrain = True 142 | task.load_mesh_dataset('train') 143 | task.load_mesh_dataset('valid') 144 | mesh_train_loader = DataLoader( 145 | task.mesh_datasets['train'], collate_fn=task.mesh_datasets['train'].collater, batch_size=args.batch_size, num_workers=args.num_workers 146 | ) 147 | mesh_test_loader = DataLoader( 148 | task.mesh_datasets['valid'], collate_fn=task.mesh_datasets['valid'].dataset.collater, batch_size=args.batch_size, num_workers=args.num_workers 
149 | ) 150 | else: 151 | mesh_pretrain = False 152 | 153 | # Build model and criterion 154 | model = task.build_model(args).cuda() 155 | criterion = task.build_criterion(args) 156 | # logger.info(model) 157 | logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) 158 | logger.info('num. model params: {} (num. trained: {})'.format( 159 | sum(p.numel() for p in model.parameters()), 160 | sum(p.numel() for p in model.parameters() if p.requires_grad), 161 | )) 162 | if model_state is not None: 163 | model.load_state_dict(model_state) 164 | 165 | model.encoder.max_hits = torch.scalar_tensor(args.max_hits) 166 | 167 | if len(args.tensorboard_logdir) > 0: 168 | from torch.utils.tensorboard import SummaryWriter 169 | train_writer = SummaryWriter(args.tensorboard_logdir + '/train') 170 | valid_writer = SummaryWriter(args.tensorboard_logdir + '/valid') 171 | image_writer = SummaryWriter(args.tensorboard_logdir + '/images') 172 | else: 173 | train_writer, valid_writer, image_writer = None, None, None 174 | 175 | if args.dis: 176 | if args.optimizer == 'adam': 177 | optimizer = optim.Adam([{'params':model.reader.parameters()}, {'params':model.encoder.parameters()}, {'params':model.field.parameters()}, {'params':model.raymarcher.parameters()}], lr=lr, betas=eval(args.adam_betas)) 178 | dis_optimizer = optim.Adam(model.discriminator.parameters(), lr=lr, betas=eval(args.adam_betas)) 179 | elif args.optimizer == 'rmsprop': 180 | optimizer = optim.RMSprop([{'params':model.reader.parameters()}, {'params':model.encoder.parameters()}, {'params':model.field.parameters()}, {'params':model.raymarcher.parameters()}], lr=lr, alpha=0.99, eps=1e-8) 181 | dis_optimizer = optim.RMSprop(model.discriminator.parameters(), lr=lr, alpha=0.99, eps=1e-8) 182 | elif args.optimizer == 'sgd': 183 | optimizer = optim.SGD([{'params':model.reader.parameters()}, {'params':model.encoder.parameters()}, {'params':model.field.parameters()}, {'params':model.raymarcher.parameters()}], lr=lr, momentum=0.) 184 | dis_optimizer = optim.SGD(model.discriminator.parameters(), lr=lr, momentum=0.) 185 | else: 186 | if args.optimizer == 'adam': 187 | optimizer = optim.Adam(model.parameters(), lr=lr) 188 | elif args.optimizer == 'rmsprop': 189 | optimizer = optim.RMSprop(model.parameters(), lr=lr, alpha=0.99, eps=1e-8) 190 | elif args.optimizer == 'sgd': 191 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.) 
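# Scheduler note (descriptive comment, added for clarity): with the constructor defaults used
# below (warmup_updates=0, end_learning_rate=0.0, power=1.0), PolynomialDecayLRSchedule reduces
# to a linear decay, lr(t) = args.lr * (1 - t / total_num_update); for example, --lr 0.001 with
# --total-num-update 150000 (the settings in run_scan_plant.sh) gives 5e-4 after 75k updates.
# Any other --lr-scheduler value falls back to StepLR, which multiplies the learning rate by
# 0.99 every args.lr_step scheduler steps.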
192 | if args.lr_scheduler == 'polynomial_decay': 193 | scheduler = PolynomialDecayLRSchedule(args.total_num_update, args.lr, optimizer) 194 | else: 195 | scheduler = StepLR(optimizer, int(args.lr_step), gamma=0.99) 196 | 197 | 198 | for epoch in range(start_epoch, args.total_num_update): 199 | if mesh_pretrain and epoch <= args.mesh_pretrain_num: 200 | loss, sample_size, logging_output, _ = task.train_step(mesh_train_loader, model, criterion, optimizer, scheduler, epoch) 201 | logger.info('epoch: {} | stage: {} | loss: {:.4f} | color: {:.4f} | alpha: {:.4f} | voxs_log: {:.4f} | stps_log: {:.4f} | tvox_log: {:.4f} | asf_log: {:.4f} | ash_log: {:.4f} | lr: {}'.format( 202 | epoch, 'mesh_pretrain', logging_output['loss'], logging_output['color_loss'], logging_output['alpha_loss'], logging_output['voxs_log'], 203 | logging_output['stps_log'], logging_output['tvox_log'], logging_output['asf_log'], logging_output['ash_log'], logging_output['lr'])) 204 | else: 205 | loss, sample_size, logging_output, _ = task.train_step(train_loader, model, criterion, optimizer, scheduler, epoch) 206 | if args.dis: 207 | for param_group in dis_optimizer.param_groups: 208 | param_group["lr"] = optimizer.param_groups[0]["lr"] 209 | _, _, _, loss_D = task.dis_train_step(train_loader, model, criterion, dis_optimizer) 210 | logger.info('epoch: {} | stage: {} | loss: {:.4f} | color: {:.4f} | alpha: {:.4f} | gen: {:.4f} | dis: {:.4f} | voxs_log: {:.4f} | stps_log: {:.4f} | tvox_log: {:.4f} | asf_log: {:.4f} | ash_log: {:.4f} | lr: {}'.format( 211 | epoch, 'real_train', logging_output['loss'], logging_output['color_loss'], logging_output['alpha_loss'], logging_output['gen_loss'], loss_D.data.item(), logging_output['voxs_log'], 212 | logging_output['stps_log'], logging_output['tvox_log'], logging_output['asf_log'], logging_output['ash_log'], logging_output['lr'])) 213 | else: 214 | logger.info('epoch: {} | stage: {} | loss: {:.4f} | color: {:.4f} | alpha: {:.4f} | voxs_log: {:.4f} | stps_log: {:.4f} | tvox_log: {:.4f} | asf_log: {:.4f} | ash_log: {:.4f} | lr: {}'.format( 215 | epoch, 'real_train', logging_output['loss'], logging_output['color_loss'], logging_output['alpha_loss'], logging_output['voxs_log'], 216 | logging_output['stps_log'], logging_output['tvox_log'], logging_output['asf_log'], logging_output['ash_log'], logging_output['lr'])) 217 | 218 | if train_writer is not None: 219 | for key in logging_output.keys(): 220 | if 'loss' in key: 221 | train_writer.add_scalar(key, logging_output[key], epoch) 222 | 223 | if epoch == args.mesh_pretrain_num: 224 | torch.save( 225 | {'model': model.state_dict(), 'args': args, 'epoch':epoch, 'lr':logging_output['lr']}, 226 | f'{args.save_dir}/checkpoint_mesh_pretrain.pt', 227 | ) 228 | 229 | if epoch > 0 and args.save_interval_updates > 0 and epoch % args.save_interval_updates == 0: 230 | last_epoch = epoch - args.keep_last_epochs * args.save_interval_updates 231 | if last_epoch > 0 and os.path.exists(os.path.join(args.save_dir,'checkpoint_'+str(last_epoch)+'.pt')): 232 | os.remove(os.path.join(args.save_dir,'checkpoint_'+str(last_epoch)+'.pt')) 233 | torch.save( 234 | {'model': model.state_dict(), 'args': args, 'epoch':epoch, 'lr':logging_output['lr']}, 235 | f'{args.save_dir}/checkpoint_{str(epoch)}.pt', 236 | ) 237 | torch.save( 238 | {'model': model.state_dict(), 'args': args, 'epoch':epoch, 'lr':logging_output['lr']}, 239 | f'{args.save_dir}/checkpoint_last.pt', 240 | ) 241 | logger.info('Save checkpoint!') 242 | 243 | if mesh_pretrain and epoch <= 
args.mesh_pretrain_num: 244 | for step, sample in enumerate(mesh_test_loader): 245 | sample = {key: sample[key].cuda() for key in sample.keys() if isinstance(sample[key], torch.Tensor)} 246 | valid_loss, valid_sample_size, valid_logging_output, valid_loss_D = task.valid_step(sample, model, criterion, image_writer) 247 | logger.info('epoch: {} | stage: {} | valid: {}/{} | loss: {:.4f} | color: {:.4f} | alpha: {:.4f} | voxs_log: {:.4f} | stps_log: {:.4f} | tvox_log: {:.4f} | asf_log: {:.4f} | ash_log: {:.4f}'.format( 248 | epoch, 'mesh_pretrain', step, len(mesh_test_loader), valid_logging_output['loss'], valid_logging_output['color_loss'], valid_logging_output['alpha_loss'], valid_logging_output['voxs_log'], 249 | valid_logging_output['stps_log'], valid_logging_output['tvox_log'], valid_logging_output['asf_log'], valid_logging_output['ash_log'])) 250 | else: 251 | for step, sample in enumerate(test_loader): 252 | sample = {key: sample[key].cuda() for key in sample.keys() if isinstance(sample[key], torch.Tensor)} 253 | valid_loss, valid_sample_size, valid_logging_output, valid_loss_D = task.valid_step(sample, model, criterion, image_writer) 254 | logger.info('epoch: {} | stage: {} | valid: {}/{} | loss: {:.4f} | color: {:.4f} | alpha: {:.4f} | voxs_log: {:.4f} | stps_log: {:.4f} | tvox_log: {:.4f} | asf_log: {:.4f} | ash_log: {:.4f}'.format( 255 | epoch, 'real_train', step, len(test_loader), valid_logging_output['loss'], valid_logging_output['color_loss'], valid_logging_output['alpha_loss'], valid_logging_output['voxs_log'], 256 | valid_logging_output['stps_log'], valid_logging_output['tvox_log'], valid_logging_output['asf_log'], valid_logging_output['ash_log'])) 257 | if valid_writer is not None: 258 | for key in valid_logging_output.keys(): 259 | if 'loss' in key: 260 | valid_writer.add_scalar(key, valid_logging_output[key], epoch) 261 | 262 | logger.info('done training!') 263 | 264 | 265 | if __name__ == '__main__': 266 | my_train() 267 | -------------------------------------------------------------------------------- /img/tpami.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/RGBDNeRF/f1c4b164337cd610b4e9da966a0c3755cecd2a6a/img/tpami.jpg -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | 2 | from fairnr_cli.myrender import my_generation 3 | 4 | 5 | if __name__ == '__main__': 6 | my_generation() 7 | -------------------------------------------------------------------------------- /render.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python render.py ./sample_data/scan_plant --path ./logs/test_scan_plant_save/checkpoint_last.pt --model-overrides '{"valid_chunk_size":64,"chunk_size":8,"raymarching_tolerance":0.01}' --render-save-fps 24 --render-resolution "540x720" --render-beam 1 --render-camera-poses ./sample_data/scan_plant_all/extrinsic --render-depth-rawoutput --render-views "0..339" --render-output ./logs/test_scan_plant_save/output_all --render-output-types "color" # 4 | 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open3d==0.10.0 2 | opencv_python==4.2.0.32 3 | tqdm==4.43.0 4 | pandas==0.25.3 5 | imageio==2.6.1 6 | scikit_image==0.16.2 7 | scipy==1.4.1 8 | 
plyfile==0.7.1 9 | matplotlib==3.1.2 10 | numpy==1.16.4 11 | mathutils==2.81.2 12 | tensorboardX==2.0 13 | imageio-ffmpeg==0.4.2 14 | numba 15 | lpips 16 | -------------------------------------------------------------------------------- /run_scan_plant.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET=./sample_data/scan_plant 3 | save_dir=./logs/test_scan_plant_save 4 | mkdir -p "${save_dir}" 5 | python -u train.py "${DATASET}" --mesh-data ./sample_data/scan_plant_mesh \ 6 | --gpu-id "0" \ 7 | --train-views "0..9" --mesh-train-views "0..160" \ 8 | --view-resolution "540x720" --view-per-batch 1 --pixel-per-view 1024 \ 9 | --down-pixels-per-view-at "9000" --pixel-per-view-down "0.125" \ 10 | --no-preload --sampling-on-mask 1.0 --no-sampling-at-reader \ 11 | --valid-views "9..17" --mesh-valid-views "0..160:20" \ 12 | --valid-view-resolution "180x240" --valid-view-per-batch 1 \ 13 | --transparent-background "1.0,1.0,1.0" --background-stop-gradient \ 14 | --use-octree --raymarching-stepsize-ratio 0.125 --discrete-regularization \ 15 | --color-weight 128.0 --alpha-weight 1.0 --lr 0.001 \ 16 | --mesh-pretrain-num 9000 --total-num-update 150000 --save-interval-updates 500 \ 17 | --reduce-step-size-at "5000,25000,75000" \ 18 | --save-dir "${save_dir}" --tensorboard-logdir "${save_dir}"/tensorboard --chunk-size 8 \ 19 | --voxel-path "${DATASET}"/OccuVoxel_low.ply --octree-path "${DATASET}"/octree_low.npz --voxel-size 0.2 \ 20 | --dis --dis-views "9..323" --gan-weight 2.0 --gan-norm-layer "instance" --n-layers 3 --patch-size 32 \ 21 | --load-pc --pc-path "${DATASET}"/pc_color.ply --voxel-color-path "${DATASET}"/voxel_color.txt --pc-pose-dim 3 \ 22 | | tee -a "${save_dir}"/train.log 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import glob 5 | 6 | # build clib 7 | _ext_src_root = "fairnr/clib" 8 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 9 | "{}/src/*.cu".format(_ext_src_root) 10 | ) 11 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 12 | 13 | setup( 14 | name='fairnr', 15 | ext_modules=[ 16 | CUDAExtension( 17 | name='fairnr.clib._ext', 18 | sources=_ext_sources, 19 | extra_compile_args={ 20 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 21 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 22 | }, 23 | ) 24 | ], 25 | cmdclass={ 26 | 'build_ext': BuildExtension 27 | }, 28 | entry_points={ 29 | 'console_scripts': [ 30 | 'fairnr-render = fairnr_cli.render:cli_main', 31 | 'fairnr-train = fairseq_cli.train:cli_main' 32 | ], 33 | }, 34 | ) 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from fairnr_cli.mytrain import my_train 3 | 4 | if __name__ == '__main__': 5 | my_train() 6 | --------------------------------------------------------------------------------