├── models ├── __init__.py ├── csrc │ ├── setup.py │ ├── include │ │ ├── utils.h │ │ └── helper_math.h │ ├── binding.cpp │ ├── losses.cu │ ├── intersection.cu │ └── volumerendering.cu ├── rendering.py ├── custom_functions.py └── nerfusion.py ├── representations ├── grufusion │ ├── __init__.py │ ├── backbone.py │ ├── back_project.py │ ├── torchsparse_utils.py │ ├── modules.py │ └── gru_fusion.py ├── __init__.py └── sparse_voxel_grid.py ├── assets ├── teaser.png └── pipeline.png ├── requirements.txt ├── metrics.py ├── datasets ├── __init__.py ├── color_utils.py ├── depth_utils.py ├── base.py ├── nerfpp.py ├── scannet.py ├── nerf.py ├── nsvf.py ├── _google_scanned_obj.py ├── colmap.py ├── ray_utils.py ├── data_utils.py └── colmap_utils.py ├── scripts └── scannet_get_bbox.py ├── LICENSE ├── utils.py ├── losses.py ├── .gitignore ├── opt.py ├── README.md └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /representations/grufusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse_voxel_grid import SparseVoxelGrid -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jetd1/NeRFusion/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jetd1/NeRFusion/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | kornia==0.6.5 3 | pytorch-lightning==1.7.7 4 | matplotlib==3.5.2 5 | opencv-python==4.6.0.66 6 | lpips 7 | imageio 8 | imageio-ffmpeg 9 | jupyter 10 | scipy 11 | pymcubes 12 | trimesh 13 | dearpygui -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | 13 | @torch.no_grad() 14 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 15 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 16 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nerf import NeRFDataset 2 | from .nsvf import NSVFDataset 3 | from .colmap import ColmapDataset 4 | from .nerfpp import NeRFPPDataset 5 | from .scannet import ScanNetDataset 6 | from ._google_scanned_obj import GoogleScannedDataset 7 | 8 | 9 | dataset_dict = {'nerf': NeRFDataset, 10 | 'nsvf': NSVFDataset, 11 | 'colmap': ColmapDataset, 12 | 'nerfpp': NeRFPPDataset, 13 | 'scannet': 
ScanNetDataset, 14 | 'google_scanned': GoogleScannedDataset} -------------------------------------------------------------------------------- /scripts/scannet_get_bbox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | 5 | import numpy as np 6 | 7 | if __name__ == '__main__': 8 | root_dir = sys.argv[1] 9 | xyzs = [] 10 | for pose_file in tqdm.tqdm(os.listdir(os.path.join(root_dir, 'pose'))): 11 | pose = np.loadtxt(os.path.join(root_dir, f'pose/{pose_file}')) 12 | xyz = pose[:3, -1] 13 | xyzs.append(xyz) 14 | 15 | xyzs = np.array(xyzs) 16 | xyz_min = xyzs.min(axis=0) 17 | xyz_max = xyzs.max(axis=0) 18 | 19 | output = np.array([xyz_min, xyz_max]) 20 | np.savetxt(os.path.join(root_dir, 'cam_bbox.txt'), output) 21 | 22 | -------------------------------------------------------------------------------- /datasets/color_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from einops import rearrange 3 | import imageio 4 | import numpy as np 5 | 6 | 7 | def srgb_to_linear(img): 8 | limit = 0.04045 9 | return np.where(img>limit, ((img+0.055)/1.055)**2.4, img/12.92) 10 | 11 | 12 | def linear_to_srgb(img): 13 | limit = 0.0031308 14 | img = np.where(img>limit, 1.055*img**(1/2.4)-0.055, 12.92*img) 15 | img[img>1] = 1 # "clamp" tonemapper 16 | return img 17 | 18 | 19 | def read_image(img_path, img_wh, blend_a=True, unpad=0): 20 | img = imageio.imread(img_path).astype(np.float32)/255.0 21 | # img[..., :3] = srgb_to_linear(img[..., :3]) 22 | if img.shape[2] == 4: # blend A to RGB 23 | if blend_a: 24 | img = img[..., :3]*img[..., -1:]+(1-img[..., -1:]) 25 | else: 26 | img = img[..., :3]*img[..., -1:] 27 | 28 | if unpad > 0: 29 | img = img[unpad:-unpad, unpad:-unpad] 30 | 31 | img = cv2.resize(img, img_wh) 32 | img = rearrange(img, 'h w c -> (h w) c') 33 | 34 | return img -------------------------------------------------------------------------------- /models/csrc/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h 10 | 11 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 12 | 13 | 14 | setup( 15 | name='vren', 16 | version='2.0', 17 | author='kwea123', 18 | author_email='kwea123@gmail.com', 19 | description='cuda volume rendering library', 20 | long_description='cuda volume rendering library', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='vren', 24 | sources=sources, 25 | include_dirs=include_dirs, 26 | extra_compile_args={'cxx': ['-O2'], 27 | 'nvcc': ['-O2']} 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Xiaoshuai Zhang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 
Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /representations/sparse_voxel_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchsparse import PointTensor, SparseTensor 4 | from torchsparse.utils.quantize import sparse_quantize 5 | 6 | 7 | # https://github.com/zju3dv/NeuralRecon/blob/master/ops/generate_grids.py 8 | def generate_grid(n_vox, interval): 9 | with torch.no_grad(): 10 | # Create voxel grid 11 | grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)] 12 | grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2])) # 3 dx dy dz 13 | grid = grid.unsqueeze(0).cuda().float() # 1 3 dx dy dz 14 | grid = grid.view(1, 3, -1) 15 | return grid 16 | 17 | 18 | class SparseVoxelGrid(nn.Module): 19 | def __init__(self, scale, resolution, feat_dim): 20 | """ 21 | scale: range of xyz. 0.5 -> (-0.5, 0.5) 22 | resolution: #voxels within each dim. 128 -> 128x128x128 23 | """ 24 | super().__init__() 25 | 26 | self.scale = scale 27 | self.resolution = resolution 28 | self.voxel_size = scale * 2 / resolution 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | 5 | def read_pfm(path): 6 | """Read pfm file. 
7 | 8 | Args: 9 | path (str): path to file 10 | 11 | Returns: 12 | tuple: (data, scale) 13 | """ 14 | with open(path, "rb") as file: 15 | 16 | color = None 17 | width = None 18 | height = None 19 | scale = None 20 | endian = None 21 | 22 | header = file.readline().rstrip() 23 | if header.decode("ascii") == "PF": 24 | color = True 25 | elif header.decode("ascii") == "Pf": 26 | color = False 27 | else: 28 | raise Exception("Not a PFM file: " + path) 29 | 30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 31 | if dim_match: 32 | width, height = list(map(int, dim_match.groups())) 33 | else: 34 | raise Exception("Malformed PFM header.") 35 | 36 | scale = float(file.readline().decode("ascii").rstrip()) 37 | if scale < 0: 38 | # little-endian 39 | endian = "<" 40 | scale = -scale 41 | else: 42 | # big-endian 43 | endian = ">" 44 | 45 | data = np.fromfile(file, endian + "f") 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | 51 | return data, scale -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): 5 | checkpoint = torch.load(ckpt_path, map_location='cpu') 6 | checkpoint_ = {} 7 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 8 | checkpoint = checkpoint['state_dict'] 9 | for k, v in checkpoint.items(): 10 | if not k.startswith(model_name): 11 | continue 12 | k = k[len(model_name)+1:] 13 | for prefix in prefixes_to_ignore: 14 | if k.startswith(prefix): 15 | break 16 | else: 17 | checkpoint_[k] = v 18 | return checkpoint_ 19 | 20 | 21 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 22 | if not ckpt_path: return 23 | model_dict = model.state_dict() 24 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 25 | model_dict.update(checkpoint_) 26 | model.load_state_dict(model_dict) 27 | 28 | 29 | def slim_ckpt(ckpt_path, save_poses=False): 30 | ckpt = torch.load(ckpt_path, map_location='cpu') 31 | # pop unused parameters 32 | keys_to_pop = ['directions', 'model.density_grid', 'model.grid_coords'] 33 | if not save_poses: keys_to_pop += ['poses'] 34 | for k in ckpt['state_dict']: 35 | if k.startswith('val_lpips'): 36 | keys_to_pop += [k] 37 | for k in keys_to_pop: 38 | ckpt['state_dict'].pop(k, None) 39 | return ckpt['state_dict'] 40 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(Dataset): 6 | """ 7 | Define length and sampling method 8 | """ 9 | def __init__(self, root_dir, split='train', downsample=1.0): 10 | self.root_dir = root_dir 11 | self.split = split 12 | self.downsample = downsample 13 | 14 | def read_intrinsics(self): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | if self.split.startswith('train'): 19 | return 1000 20 | return len(self.poses) 21 | 22 | def __getitem__(self, idx): 23 | if self.split.startswith('train'): 24 | # training pose is retrieved in train.py 25 | if self.ray_sampling_strategy == 'all_images': # randomly select images 26 | img_idxs = np.random.choice(len(self.poses), self.batch_size) 27 | elif 
self.ray_sampling_strategy == 'same_image': # randomly select ONE image 28 | img_idxs = np.random.choice(len(self.poses), 1)[0] 29 | # randomly select pixels 30 | pix_idxs = np.random.choice(self.img_wh[0]*self.img_wh[1], self.batch_size) 31 | rays = self.rays[img_idxs, pix_idxs] 32 | sample = {'img_idxs': img_idxs, 'pix_idxs': pix_idxs, 33 | 'rgb': rays[:, :3]} 34 | if self.rays.shape[-1] == 4: # HDR-NeRF data 35 | sample['exposure'] = rays[:, 3:] 36 | else: 37 | sample = {'pose': self.poses[idx], 'img_idxs': idx} 38 | if len(self.rays)>0: # if ground truth available 39 | rays = self.rays[idx] 40 | sample['rgb'] = rays[:, :3] 41 | if rays.shape[1] == 4: # HDR-NeRF data 42 | sample['exposure'] = rays[0, 3] # same exposure for all rays 43 | 44 | return sample -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import vren 4 | 5 | 6 | class DistortionLoss(torch.autograd.Function): 7 | """ 8 | Distortion loss proposed in Mip-NeRF 360 (https://arxiv.org/pdf/2111.12077.pdf) 9 | Implementation is based on DVGO-v2 (https://arxiv.org/pdf/2206.05085.pdf) 10 | 11 | Inputs: 12 | ws: (N) sample point weights 13 | deltas: (N) considered as intervals 14 | ts: (N) considered as midpoints 15 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 16 | meaning each entry corresponds to the @ray_idx th ray, 17 | whose samples are [start_idx:start_idx+N_samples] 18 | 19 | Outputs: 20 | loss: (N_rays) 21 | """ 22 | @staticmethod 23 | def forward(ctx, ws, deltas, ts, rays_a): 24 | loss, ws_inclusive_scan, wts_inclusive_scan = \ 25 | vren.distortion_loss_fw(ws, deltas, ts, rays_a) 26 | ctx.save_for_backward(ws_inclusive_scan, wts_inclusive_scan, 27 | ws, deltas, ts, rays_a) 28 | return loss 29 | 30 | @staticmethod 31 | def backward(ctx, dL_dloss): 32 | (ws_inclusive_scan, wts_inclusive_scan, 33 | ws, deltas, ts, rays_a) = ctx.saved_tensors 34 | dL_dws = vren.distortion_loss_bw(dL_dloss, ws_inclusive_scan, 35 | wts_inclusive_scan, 36 | ws, deltas, ts, rays_a) 37 | return dL_dws, None, None, None 38 | 39 | 40 | class NeRFLoss(nn.Module): 41 | def __init__(self, lambda_opacity=1e-3, lambda_distortion=1e-3): 42 | super().__init__() 43 | 44 | self.lambda_opacity = lambda_opacity 45 | self.lambda_distortion = lambda_distortion 46 | 47 | def forward(self, results, target, **kwargs): 48 | d = {} 49 | d['rgb'] = (results['rgb']-target['rgb'])**2 50 | 51 | o = results['opacity']+1e-10 52 | # encourage opacity to be either 0 or 1 to avoid floater 53 | d['opacity'] = self.lambda_opacity*(-o*torch.log(o)) 54 | 55 | if self.lambda_distortion > 0: 56 | d['distortion'] = self.lambda_distortion * \ 57 | DistortionLoss.apply(results['ws'], results['deltas'], 58 | results['ts'], results['rays_a']) 59 | 60 | return d 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files 
are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | .idea 133 | /logs/ 134 | /results/ 135 | /ckpts/ 136 | -------------------------------------------------------------------------------- /datasets/nerfpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from .ray_utils import get_ray_directions 9 | from .color_utils import read_image 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class NeRFPPDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0], 25 | dtype=np.float32).reshape(4, 4)[:3, :3] 26 | K[:2] *= self.downsample 27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size 28 | w, h = int(w*self.downsample), int(h*self.downsample) 29 | self.K = torch.FloatTensor(K) 30 | self.directions = get_ray_directions(h, w, self.K) 31 | self.img_wh = (w, h) 32 | 33 | def read_meta(self, split): 34 | self.rays = [] 35 | self.poses = [] 36 | 37 | if split == 'test_traj': 38 | poses_path = \ 39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt'))) 40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path] 41 | else: 42 | if split=='trainval': 43 | img_paths = 
sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\ 44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*'))) 45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\ 46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt'))) 47 | else: 48 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*'))) 49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt'))) 50 | 51 | print(f'Loading {len(img_paths)} {split} images ...') 52 | for img_path, pose in tqdm(zip(img_paths, poses)): 53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]] 54 | 55 | img = read_image(img_path, self.img_wh) 56 | self.rays += [img] 57 | 58 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 59 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 60 | -------------------------------------------------------------------------------- /datasets/scannet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | SCANNET_FAR = 2.0 13 | 14 | 15 | class ScanNetDataset(BaseDataset): 16 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 17 | super().__init__(root_dir, split, downsample) 18 | 19 | self.unpad = 24 20 | 21 | self.read_intrinsics() 22 | 23 | if kwargs.get('read_meta', True): 24 | self.read_meta(split) 25 | 26 | def read_intrinsics(self): 27 | K = np.loadtxt(os.path.join(self.root_dir, "./intrinsic/intrinsic_color.txt"))[:3, :3] 28 | H, W = 968 - 2 * self.unpad, 1296 - 2 * self.unpad 29 | K[:2, 2] -= self.unpad 30 | self.K = torch.FloatTensor(K) 31 | self.directions = get_ray_directions(H, W, self.K) 32 | self.img_wh = (W, H) 33 | 34 | def read_meta(self, split): 35 | self.rays = [] 36 | self.poses = [] 37 | 38 | if split == 'train': 39 | with open(os.path.join(self.root_dir, "train.txt"), 'r') as f: 40 | frames = f.read().strip().split() 41 | frames = frames[:800] 42 | else: 43 | with open(os.path.join(self.root_dir, f"{split}.txt"), 'r') as f: 44 | frames = f.read().strip().split() 45 | frames = frames[:80] 46 | 47 | cam_bbox = np.loadtxt(os.path.join(self.root_dir, f"cam_bbox.txt")) 48 | sbbox_scale = (cam_bbox[1] - cam_bbox[0]).max() + 2 * SCANNET_FAR 49 | sbbox_shift = cam_bbox.mean(axis=0) 50 | 51 | print(f'Loading {len(frames)} {split} images ...') 52 | for frame in tqdm(frames): 53 | c2w = np.loadtxt(os.path.join(self.root_dir, f"pose/{frame}.txt"))[:3] 54 | 55 | # add shift 56 | c2w[0, 3] -= sbbox_shift[0] 57 | c2w[1, 3] -= sbbox_shift[1] 58 | c2w[2, 3] -= sbbox_shift[2] 59 | c2w[:, 3] /= sbbox_scale 60 | 61 | self.poses += [c2w] 62 | 63 | try: 64 | img_path = os.path.join(self.root_dir, f"color/{frame}.jpg") 65 | img = read_image(img_path, self.img_wh, unpad=self.unpad) 66 | self.rays += [img] 67 | except: pass 68 | 69 | if len(self.rays)>0: 70 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 
71 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 72 | -------------------------------------------------------------------------------- /representations/grufusion/backbone.py: -------------------------------------------------------------------------------- 1 | # ported from NeuralRecon (https://github.com/zju3dv/NeuralRecon) 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | 7 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 8 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 9 | bias, will round up, unless the number is no more than 10% greater than the 10 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ 11 | assert 0.0 < round_up_bias < 1.0 12 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 13 | return new_val if new_val >= round_up_bias * val else new_val + divisor 14 | 15 | 16 | def _get_depths(alpha): 17 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 18 | rather than down. """ 19 | depths = [32, 16, 24, 40, 80, 96, 192, 320] 20 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] 21 | 22 | 23 | class MnasMulti(nn.Module): 24 | 25 | def __init__(self, alpha=1.0): 26 | super(MnasMulti, self).__init__() 27 | depths = _get_depths(alpha) 28 | if alpha == 1.0: 29 | MNASNet = torchvision.models.mnasnet1_0(pretrained=True, progress=True) 30 | else: 31 | MNASNet = torchvision.models.MNASNet(alpha=alpha) 32 | 33 | self.conv0 = nn.Sequential( 34 | MNASNet.layers._modules['0'], 35 | MNASNet.layers._modules['1'], 36 | MNASNet.layers._modules['2'], 37 | MNASNet.layers._modules['3'], 38 | MNASNet.layers._modules['4'], 39 | MNASNet.layers._modules['5'], 40 | MNASNet.layers._modules['6'], 41 | MNASNet.layers._modules['7'], 42 | MNASNet.layers._modules['8'], 43 | ) 44 | 45 | self.conv1 = MNASNet.layers._modules['9'] 46 | self.conv2 = MNASNet.layers._modules['10'] 47 | 48 | self.out1 = nn.Conv2d(depths[4], depths[4], 1, bias=False) 49 | self.out_channels = [depths[4]] 50 | 51 | final_chs = depths[4] 52 | self.inner1 = nn.Conv2d(depths[3], final_chs, 1, bias=True) 53 | self.inner2 = nn.Conv2d(depths[2], final_chs, 1, bias=True) 54 | 55 | self.out2 = nn.Conv2d(final_chs, depths[3], 3, padding=1, bias=False) 56 | self.out3 = nn.Conv2d(final_chs, depths[2], 3, padding=1, bias=False) 57 | self.out_channels.append(depths[3]) 58 | self.out_channels.append(depths[2]) 59 | 60 | def forward(self, x): 61 | conv0 = self.conv0(x) 62 | conv1 = self.conv1(conv0) 63 | conv2 = self.conv2(conv1) 64 | 65 | intra_feat = conv2 66 | outputs = [] 67 | out = self.out1(intra_feat) 68 | outputs.append(out) 69 | 70 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1) 71 | out = self.out2(intra_feat) 72 | outputs.append(out) 73 | 74 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0) 75 | out = self.out3(intra_feat) 76 | outputs.append(out) 77 | 78 | return outputs[::-1] 79 | -------------------------------------------------------------------------------- /datasets/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NeRFDataset(BaseDataset): 14 | def __init__(self, root_dir, 
split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | self.read_meta(split) 21 | 22 | def read_intrinsics(self): 23 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 24 | meta = json.load(f) 25 | 26 | w = h = int(800*self.downsample) 27 | fx = fy = 0.5*800/np.tan(0.5*meta['camera_angle_x'])*self.downsample 28 | 29 | K = np.float32([[fx, 0, w/2], 30 | [0, fy, h/2], 31 | [0, 0, 1]]) 32 | 33 | self.K = torch.FloatTensor(K) 34 | self.directions = get_ray_directions(h, w, self.K) 35 | self.img_wh = (w, h) 36 | 37 | def read_meta(self, split): 38 | self.rays = [] 39 | self.poses = [] 40 | 41 | if split == 'trainval': 42 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 43 | frames = json.load(f)["frames"] 44 | with open(os.path.join(self.root_dir, "transforms_val.json"), 'r') as f: 45 | frames+= json.load(f)["frames"] 46 | else: 47 | with open(os.path.join(self.root_dir, f"transforms_{split}.json"), 'r') as f: 48 | frames = json.load(f)["frames"] 49 | 50 | print(f'Loading {len(frames)} {split} images ...') 51 | for frame in tqdm(frames): 52 | c2w = np.array(frame['transform_matrix'])[:3, :4] 53 | 54 | # determine scale 55 | if 'Jrender_Dataset' in self.root_dir: 56 | c2w[:, :2] *= -1 # [left up front] to [right down front] 57 | folder = self.root_dir.split('/') 58 | scene = folder[-1] if folder[-1] != '' else folder[-2] 59 | if scene=='Easyship': 60 | pose_radius_scale = 1.2 61 | elif scene=='Scar': 62 | pose_radius_scale = 1.8 63 | elif scene=='Coffee': 64 | pose_radius_scale = 2.5 65 | elif scene=='Car': 66 | pose_radius_scale = 0.8 67 | else: 68 | pose_radius_scale = 1.5 69 | else: 70 | c2w[:, 1:3] *= -1 # [right up back] to [right down front] 71 | pose_radius_scale = 1.5 72 | c2w[:, 3] /= np.linalg.norm(c2w[:, 3])/pose_radius_scale 73 | 74 | # add shift 75 | if 'Jrender_Dataset' in self.root_dir: 76 | if scene=='Coffee': 77 | c2w[1, 3] -= 0.4465 78 | elif scene=='Car': 79 | c2w[0, 3] -= 0.7 80 | self.poses += [c2w] 81 | 82 | try: 83 | img_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 84 | img = read_image(img_path, self.img_wh) 85 | self.rays += [img] 86 | except: pass 87 | 88 | if len(self.rays)>0: 89 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 
90 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 91 | -------------------------------------------------------------------------------- /representations/grufusion/back_project.py: -------------------------------------------------------------------------------- 1 | # ported from NeuralRecon (https://github.com/zju3dv/NeuralRecon) 2 | import torch 3 | from torch.nn.functional import grid_sample 4 | 5 | 6 | def back_project(coords, origin, voxel_size, feats, KRcam): 7 | ''' 8 | Unproject the image fetures to form a 3D (sparse) feature volume 9 | 10 | :param coords: coordinates of voxels, 11 | dim: (num of voxels, 4) (4 : batch ind, x, y, z) 12 | :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) 13 | dim: (batch size, 3) (3: x, y, z) 14 | :param voxel_size: floats specifying the size of a voxel 15 | :param feats: image features 16 | dim: (num of views, batch size, C, H, W) 17 | :param KRcam: projection matrix 18 | dim: (num of views, batch size, 4, 4) 19 | :return: feature_volume_all: 3D feature volumes 20 | dim: (num of voxels, c + 1) 21 | :return: count: number of times each voxel can be seen 22 | dim: (num of voxels,) 23 | ''' 24 | n_views, bs, c, h, w = feats.shape 25 | 26 | feature_volume_all = torch.zeros(coords.shape[0], c + 1).cuda() 27 | count = torch.zeros(coords.shape[0]).cuda() 28 | 29 | for batch in range(bs): 30 | batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1) 31 | coords_batch = coords[batch_ind][:, 1:] 32 | 33 | coords_batch = coords_batch.view(-1, 3) 34 | origin_batch = origin[batch].unsqueeze(0) 35 | feats_batch = feats[:, batch] 36 | proj_batch = KRcam[:, batch] 37 | 38 | grid_batch = coords_batch * voxel_size + origin_batch.float() 39 | rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1) 40 | rs_grid = rs_grid.permute(0, 2, 1).contiguous() 41 | nV = rs_grid.shape[-1] 42 | rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).cuda()], dim=1) 43 | 44 | # Project grid 45 | im_p = proj_batch @ rs_grid 46 | im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] 47 | im_x = im_x / im_z 48 | im_y = im_y / im_z 49 | 50 | im_grid = torch.stack([2 * im_x / (w - 1) - 1, 2 * im_y / (h - 1) - 1], dim=-1) 51 | mask = im_grid.abs() <= 1 52 | mask = (mask.sum(dim=-1) == 2) & (im_z > 0) 53 | 54 | feats_batch = feats_batch.view(n_views, c, h, w) 55 | im_grid = im_grid.view(n_views, 1, -1, 2) 56 | features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True) 57 | 58 | features = features.view(n_views, c, -1) 59 | mask = mask.view(n_views, -1) 60 | im_z = im_z.view(n_views, -1) 61 | # remove nan 62 | features[mask.unsqueeze(1).expand(-1, c, -1) == False] = 0 63 | im_z[mask == False] = 0 64 | 65 | count[batch_ind] = mask.sum(dim=0).float() 66 | 67 | # aggregate multi view 68 | features = features.sum(dim=0) 69 | mask = mask.sum(dim=0) 70 | invalid_mask = mask == 0 71 | mask[invalid_mask] = 1 72 | in_scope_mask = mask.unsqueeze(0) 73 | features /= in_scope_mask 74 | features = features.permute(1, 0).contiguous() 75 | 76 | # concat normalized depth value 77 | im_z = im_z.sum(dim=0).unsqueeze(1) / in_scope_mask.permute(1, 0).contiguous() 78 | im_z_mean = im_z[im_z > 0].mean() 79 | im_z_std = torch.norm(im_z[im_z > 0] - im_z_mean) + 1e-5 80 | im_z_norm = (im_z - im_z_mean) / im_z_std 81 | im_z_norm[im_z <= 0] = 0 82 | features = torch.cat([features, im_z_norm], dim=1) 83 | 84 | feature_volume_all[batch_ind] = features 85 | return feature_volume_all, count 86 | 
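A quick usage illustration for the back-projection routine above (added as a sketch; it is not part of the original NeuralRecon port). The tensor shapes follow the docstring, while the sizes, the voxel size and the identity projection matrices are arbitrary placeholder assumptions; a CUDA device is required because back_project allocates GPU tensors internally.

import torch
from representations.grufusion.back_project import back_project

# Hypothetical sizes: 2 views, batch size 1, 24-channel features on a 120x160 feature map.
n_views, bs, c, h, w = 2, 1, 24, 120, 160
n_vox = 4096
voxel_size = 0.04  # assumed metric size of one voxel

# (num voxels, 4): batch index followed by integer voxel coordinates
coords = torch.cat([torch.zeros(n_vox, 1),
                    torch.randint(1, 96, (n_vox, 3)).float()], dim=1).cuda()
origin = torch.zeros(bs, 3).cuda()                       # xyz position of voxel (0, 0, 0)
feats = torch.randn(n_views, bs, c, h, w).cuda()         # per-view image features
KRcam = torch.eye(4).expand(n_views, bs, 4, 4).cuda()    # placeholder projection matrices

volume, count = back_project(coords, origin, voxel_size, feats, KRcam)
print(volume.shape, count.shape)  # (n_vox, c + 1) and (n_vox,)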
-------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # dataset parameters 7 | parser.add_argument('--root_dir', type=str, required=True, 8 | help='root directory of dataset') 9 | parser.add_argument('--dataset_name', type=str, default='nsvf', 10 | choices=['nerf', 'nsvf', 'colmap', 'nerfpp', 'rtmv', 'scannet', 'google_scanned'], 11 | help='which dataset to train/test') 12 | parser.add_argument('--split', type=str, default='train', 13 | choices=['train', 'trainval', 'trainvaltest'], 14 | help='use which split to train') 15 | parser.add_argument('--downsample', type=float, default=1.0, 16 | help='downsample factor (<=1.0) for the images') 17 | 18 | # model parameters 19 | parser.add_argument('--scale', type=float, default=0.5, 20 | help='scene scale (whole scene must lie in [-scale, scale]^3)') 21 | 22 | # loss parameters 23 | parser.add_argument('--distortion_loss_w', type=float, default=0, 24 | help='''weight of distortion loss (see losses.py), 25 | 0 to disable (default), to enable, 26 | a good value is 1e-3 for real scene and 1e-2 for synthetic scene 27 | ''') 28 | 29 | # training options 30 | parser.add_argument('--batch_size', type=int, default=8192, 31 | help='number of rays in a batch') 32 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images', 33 | choices=['all_images', 'same_image'], 34 | help=''' 35 | all_images: uniformly from all pixels of ALL images 36 | same_image: uniformly from all pixels of a SAME image 37 | ''') 38 | parser.add_argument('--num_epochs', type=int, default=30, 39 | help='number of training epochs') 40 | parser.add_argument('--num_gpus', type=int, default=1, 41 | help='number of gpus') 42 | parser.add_argument('--lr', type=float, default=1e-2, 43 | help='learning rate') 44 | # experimental training options 45 | parser.add_argument('--optimize_ext', action='store_true', default=False, 46 | help='whether to optimize extrinsics') 47 | parser.add_argument('--random_bg', action='store_true', default=False, 48 | help='''whether to train with random bg color (real scene only) 49 | to avoid objects with black color to be predicted as transparent 50 | ''') 51 | 52 | # validation options 53 | parser.add_argument('--eval_lpips', action='store_true', default=False, 54 | help='evaluate lpips metric (consumes more VRAM)') 55 | parser.add_argument('--val_only', action='store_true', default=False, 56 | help='run only validation (need to provide ckpt_path)') 57 | parser.add_argument('--no_save_test', action='store_true', default=False, 58 | help='whether to save test image and video') 59 | 60 | # scripts 61 | parser.add_argument('--exp_name', type=str, default='exp', 62 | help='experiment name') 63 | parser.add_argument('--ckpt_path', type=str, default=None, 64 | help='pretrained checkpoint to load (including optimizers, etc)') 65 | parser.add_argument('--weight_path', type=str, default=None, 66 | help='pretrained checkpoint to load (excluding optimizers, etc)') 67 | 68 | return parser.parse_args() 69 | -------------------------------------------------------------------------------- /models/csrc/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) 
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | 9 | std::vector<torch::Tensor> ray_aabb_intersect_cu( 10 | const torch::Tensor rays_o, 11 | const torch::Tensor rays_d, 12 | const torch::Tensor centers, 13 | const torch::Tensor half_sizes, 14 | const int max_hits 15 | ); 16 | 17 | 18 | std::vector<torch::Tensor> ray_sphere_intersect_cu( 19 | const torch::Tensor rays_o, 20 | const torch::Tensor rays_d, 21 | const torch::Tensor centers, 22 | const torch::Tensor radii, 23 | const int max_hits 24 | ); 25 | 26 | 27 | void packbits_cu( 28 | torch::Tensor density_grid, 29 | const float density_threshold, 30 | torch::Tensor density_bitfield 31 | ); 32 | 33 | 34 | torch::Tensor morton3D_cu(const torch::Tensor coords); 35 | torch::Tensor morton3D_invert_cu(const torch::Tensor indices); 36 | 37 | 38 | std::vector<torch::Tensor> raymarching_train_cu( 39 | const torch::Tensor rays_o, 40 | const torch::Tensor rays_d, 41 | const torch::Tensor hits_t, 42 | const torch::Tensor density_bitfield, 43 | const int cascades, 44 | const float scale, 45 | const float exp_step_factor, 46 | const torch::Tensor noise, 47 | const int grid_size, 48 | const int max_samples 49 | ); 50 | 51 | 52 | std::vector<torch::Tensor> raymarching_test_cu( 53 | const torch::Tensor rays_o, 54 | const torch::Tensor rays_d, 55 | torch::Tensor hits_t, 56 | const torch::Tensor alive_indices, 57 | const torch::Tensor density_bitfield, 58 | const int cascades, 59 | const float scale, 60 | const float exp_step_factor, 61 | const int grid_size, 62 | const int max_samples, 63 | const int N_samples 64 | ); 65 | 66 | 67 | std::vector<torch::Tensor> composite_train_fw_cu( 68 | const torch::Tensor sigmas, 69 | const torch::Tensor rgbs, 70 | const torch::Tensor deltas, 71 | const torch::Tensor ts, 72 | const torch::Tensor rays_a, 73 | const float T_threshold 74 | ); 75 | 76 | 77 | std::vector<torch::Tensor> composite_train_bw_cu( 78 | const torch::Tensor dL_dopacity, 79 | const torch::Tensor dL_ddepth, 80 | const torch::Tensor dL_drgb, 81 | const torch::Tensor dL_dws, 82 | const torch::Tensor sigmas, 83 | const torch::Tensor rgbs, 84 | const torch::Tensor ws, 85 | const torch::Tensor deltas, 86 | const torch::Tensor ts, 87 | const torch::Tensor rays_a, 88 | const torch::Tensor opacity, 89 | const torch::Tensor depth, 90 | const torch::Tensor rgb, 91 | const float T_threshold 92 | ); 93 | 94 | 95 | void composite_test_fw_cu( 96 | const torch::Tensor sigmas, 97 | const torch::Tensor rgbs, 98 | const torch::Tensor deltas, 99 | const torch::Tensor ts, 100 | const torch::Tensor hits_t, 101 | const torch::Tensor alive_indices, 102 | const float T_threshold, 103 | const torch::Tensor N_eff_samples, 104 | torch::Tensor opacity, 105 | torch::Tensor depth, 106 | torch::Tensor rgb 107 | ); 108 | 109 | 110 | std::vector<torch::Tensor> distortion_loss_fw_cu( 111 | const torch::Tensor ws, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a 115 | ); 116 | 117 | 118 | torch::Tensor distortion_loss_bw_cu( 119 | const torch::Tensor dL_dloss, 120 | const torch::Tensor ws_inclusive_scan, 121 | const torch::Tensor wts_inclusive_scan, 122 | const torch::Tensor ws, 123 | const torch::Tensor deltas, 124 | const torch::Tensor ts, 125 | const torch::Tensor rays_a 126 | ); -------------------------------------------------------------------------------- /representations/grufusion/torchsparse_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from: 3 | 
https://github.com/mit-han-lab/spvnas/blob/b24f50379ed888d3a0e784508a809d4e92e820c0/core/models/utils.py 4 | """ 5 | import torch 6 | import torchsparse.nn.functional as F 7 | from torchsparse import PointTensor, SparseTensor 8 | from torchsparse.nn.utils import get_kernel_offsets 9 | 10 | __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] 11 | 12 | 13 | # z: PointTensor 14 | # return: SparseTensor 15 | def initial_voxelize(z, init_res, after_res): 16 | new_float_coord = torch.cat( 17 | [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) 18 | 19 | pc_hash = F.sphash(torch.floor(new_float_coord).int()) 20 | sparse_hash = torch.unique(pc_hash) 21 | idx_query = F.sphashquery(pc_hash, sparse_hash) 22 | counts = F.spcount(idx_query.int(), len(sparse_hash)) 23 | 24 | inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, 25 | counts) 26 | inserted_coords = torch.round(inserted_coords).int() 27 | inserted_feat = F.spvoxelize(z.F, idx_query, counts) 28 | 29 | new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) 30 | new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) 31 | z.additional_features['idx_query'][1] = idx_query 32 | z.additional_features['counts'][1] = counts 33 | z.C = new_float_coord 34 | 35 | return new_tensor 36 | 37 | 38 | # x: SparseTensor, z: PointTensor 39 | # return: SparseTensor 40 | def point_to_voxel(x, z): 41 | if z.additional_features is None or z.additional_features.get('idx_query') is None\ 42 | or z.additional_features['idx_query'].get(x.s) is None: 43 | #pc_hash = hash_gpu(torch.floor(z.C).int()) 44 | pc_hash = F.sphash( 45 | torch.cat([ 46 | torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], 47 | z.C[:, -1].int().view(-1, 1) 48 | ], 1)) 49 | sparse_hash = F.sphash(x.C) 50 | idx_query = F.sphashquery(pc_hash, sparse_hash) 51 | counts = F.spcount(idx_query.int(), x.C.shape[0]) 52 | z.additional_features['idx_query'][x.s] = idx_query 53 | z.additional_features['counts'][x.s] = counts 54 | else: 55 | idx_query = z.additional_features['idx_query'][x.s] 56 | counts = z.additional_features['counts'][x.s] 57 | 58 | inserted_feat = F.spvoxelize(z.F, idx_query, counts) 59 | new_tensor = SparseTensor(inserted_feat, x.C, x.s) 60 | new_tensor.cmaps = x.cmaps 61 | new_tensor.kmaps = x.kmaps 62 | 63 | return new_tensor 64 | 65 | 66 | # x: SparseTensor, z: PointTensor 67 | # return: PointTensor 68 | def voxel_to_point(x, z, nearest=False): 69 | if z.idx_query is None or z.weights is None or z.idx_query.get( 70 | x.s) is None or z.weights.get(x.s) is None: 71 | off = get_kernel_offsets(2, x.s, 1, device=z.F.device) 72 | #old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off) 73 | old_hash = F.sphash( 74 | torch.cat([ 75 | torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], 76 | z.C[:, -1].int().view(-1, 1) 77 | ], 1), off) 78 | pc_hash = F.sphash(x.C.to(z.F.device)) 79 | idx_query = F.sphashquery(old_hash, pc_hash) 80 | weights = F.calc_ti_weights(z.C, idx_query, 81 | scale=x.s[0]).transpose(0, 1).contiguous() 82 | idx_query = idx_query.transpose(0, 1).contiguous() 83 | if nearest: 84 | weights[:, 1:] = 0. 
85 | idx_query[:, 1:] = -1 86 | new_feat = F.spdevoxelize(x.F, idx_query, weights) 87 | new_tensor = PointTensor(new_feat, 88 | z.C, 89 | idx_query=z.idx_query, 90 | weights=z.weights) 91 | new_tensor.additional_features = z.additional_features 92 | new_tensor.idx_query[x.s] = idx_query 93 | new_tensor.weights[x.s] = weights 94 | z.idx_query[x.s] = idx_query 95 | z.weights[x.s] = weights 96 | 97 | else: 98 | new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), 99 | z.weights.get(x.s)) 100 | new_tensor = PointTensor(new_feat, 101 | z.C, 102 | idx_query=z.idx_query, 103 | weights=z.weights) 104 | new_tensor.additional_features = z.additional_features 105 | 106 | return new_tensor -------------------------------------------------------------------------------- /datasets/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NSVFDataset(BaseDataset): 14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | xyz_min, xyz_max = \ 21 | np.loadtxt(os.path.join(root_dir, 'bbox.txt'))[:6].reshape(2, 3) 22 | self.shift = (xyz_max+xyz_min)/2 23 | self.scale = (xyz_max-xyz_min).max()/2 * 1.05 # enlarge a little 24 | 25 | # hard-code fix the bound error for some scenes... 26 | if 'Mic' in self.root_dir: self.scale *= 1.2 27 | elif 'Lego' in self.root_dir: self.scale *= 1.1 28 | 29 | self.read_meta(split) 30 | 31 | def read_intrinsics(self): 32 | if 'Synthetic' in self.root_dir or 'Ignatius' in self.root_dir: 33 | with open(os.path.join(self.root_dir, 'intrinsics.txt')) as f: 34 | fx = fy = float(f.readline().split()[0]) * self.downsample 35 | if 'Synthetic' in self.root_dir: 36 | w = h = int(800*self.downsample) 37 | else: 38 | w, h = int(1920*self.downsample), int(1080*self.downsample) 39 | 40 | K = np.float32([[fx, 0, w/2], 41 | [0, fy, h/2], 42 | [0, 0, 1]]) 43 | else: 44 | K = np.loadtxt(os.path.join(self.root_dir, 'intrinsics.txt'), 45 | dtype=np.float32)[:3, :3] 46 | if 'BlendedMVS' in self.root_dir: 47 | w, h = int(768*self.downsample), int(576*self.downsample) 48 | elif 'Tanks' in self.root_dir: 49 | w, h = int(1920*self.downsample), int(1080*self.downsample) 50 | K[:2] *= self.downsample 51 | 52 | self.K = torch.FloatTensor(K) 53 | self.directions = get_ray_directions(h, w, self.K) 54 | self.img_wh = (w, h) 55 | 56 | def read_meta(self, split): 57 | self.rays = [] 58 | self.poses = [] 59 | 60 | if split == 'test_traj': # BlendedMVS and TanksAndTemple 61 | if 'Ignatius' in self.root_dir: 62 | poses_path = \ 63 | sorted(glob.glob(os.path.join(self.root_dir, 'test_pose/*.txt'))) 64 | poses = [np.loadtxt(p) for p in poses_path] 65 | else: 66 | poses = np.loadtxt(os.path.join(self.root_dir, 'test_traj.txt')) 67 | poses = poses.reshape(-1, 4, 4) 68 | for pose in poses: 69 | c2w = pose[:3] 70 | c2w[:, 0] *= -1 # [left down front] to [right down front] 71 | c2w[:, 3] -= self.shift 72 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5] 73 | self.poses += [c2w] 74 | else: 75 | if split == 'train': prefix = '0_' 76 | elif split == 'trainval': prefix = '[0-1]_' 77 | elif split == 'trainvaltest': prefix = '[0-2]_' 78 | elif split == 'val': prefix = '1_' 79 | elif 
'Synthetic' in self.root_dir: prefix = '2_' # test set for synthetic scenes 80 | elif split == 'test': prefix = '1_' # test set for real scenes 81 | else: raise ValueError(f'{split} split not recognized!') 82 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'rgb', prefix+'*.png'))) 83 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'pose', prefix+'*.txt'))) 84 | 85 | print(f'Loading {len(img_paths)} {split} images ...') 86 | for img_path, pose in tqdm(zip(img_paths, poses)): 87 | c2w = np.loadtxt(pose)[:3] 88 | c2w[:, 3] -= self.shift 89 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5] 90 | self.poses += [c2w] 91 | 92 | img = read_image(img_path, self.img_wh) 93 | if 'Jade' in self.root_dir or 'Fountain' in self.root_dir: 94 | # these scenes have black background, changing to white 95 | img[torch.all(img<=0.1, dim=-1)] = 1.0 96 | 97 | self.rays += [img] 98 | 99 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 100 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 101 | -------------------------------------------------------------------------------- /datasets/_google_scanned_obj.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from IBRNet's codebase. 2 | # Copyright 2020 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import os 18 | import numpy as np 19 | import imageio 20 | import torch 21 | from torch.utils.data import Dataset 22 | import glob 23 | import sys 24 | sys.path.append('../') 25 | from .data_utils import rectify_inplane_rotation, get_nearest_pose_ids 26 | 27 | 28 | # only for training 29 | class GoogleScannedDataset(Dataset): 30 | def __init__(self, args, mode, **kwargs): 31 | self.folder_path = os.path.join(args.rootdir, 'data/google_scanned_objects/') 32 | self.num_source_views = args.num_source_views 33 | self.rectify_inplane_rotation = args.rectify_inplane_rotation 34 | self.scene_path_list = glob.glob(os.path.join(self.folder_path, '*')) 35 | 36 | all_rgb_files = [] 37 | all_pose_files = [] 38 | all_intrinsics_files = [] 39 | num_files = 250 40 | for i, scene_path in enumerate(self.scene_path_list): 41 | rgb_files = [os.path.join(scene_path, 'rgb', f) 42 | for f in sorted(os.listdir(os.path.join(scene_path, 'rgb')))] 43 | pose_files = [f.replace('rgb', 'pose').replace('png', 'txt') for f in rgb_files] 44 | intrinsics_files = [f.replace('rgb', 'intrinsics').replace('png', 'txt') for f in rgb_files] 45 | 46 | if np.min([len(rgb_files), len(pose_files), len(intrinsics_files)]) \ 47 | < num_files: 48 | print(scene_path) 49 | continue 50 | 51 | all_rgb_files.append(rgb_files) 52 | all_pose_files.append(pose_files) 53 | all_intrinsics_files.append(intrinsics_files) 54 | 55 | index = np.arange(len(all_rgb_files)) 56 | self.all_rgb_files = np.array(all_rgb_files)[index] 57 | self.all_pose_files = np.array(all_pose_files)[index] 58 | self.all_intrinsics_files = np.array(all_intrinsics_files)[index] 59 | 60 | def __len__(self): 61 | return len(self.all_rgb_files) 62 | 63 | def __getitem__(self, idx): 64 | rgb_files = self.all_rgb_files[idx] 65 | pose_files = self.all_pose_files[idx] 66 | intrinsics_files = self.all_intrinsics_files[idx] 67 | 68 | id_render = np.random.choice(np.arange(len(rgb_files))) 69 | train_poses = np.stack([np.loadtxt(file).reshape(4, 4) for file in pose_files], axis=0) 70 | render_pose = train_poses[id_render] 71 | subsample_factor = np.random.choice(np.arange(1, 6), p=[0.3, 0.25, 0.2, 0.2, 0.05]) 72 | 73 | id_feat_pool = get_nearest_pose_ids(render_pose, 74 | train_poses, 75 | self.num_source_views*subsample_factor, 76 | tar_id=id_render, 77 | angular_dist_method='vector') 78 | id_feat = np.random.choice(id_feat_pool, self.num_source_views, replace=False) 79 | 80 | assert id_render not in id_feat 81 | # occasionally include input image 82 | if np.random.choice([0, 1], p=[0.995, 0.005]): 83 | id_feat[np.random.choice(len(id_feat))] = id_render 84 | 85 | rgb = imageio.imread(rgb_files[id_render]).astype(np.float32) / 255. 86 | 87 | intrinsics = np.loadtxt(intrinsics_files[id_render]) 88 | img_size = rgb.shape[:2] 89 | camera = np.concatenate((list(img_size), intrinsics, render_pose.flatten())).astype(np.float32) 90 | 91 | # get depth range 92 | min_ratio = 0.1 93 | origin_depth = np.linalg.inv(render_pose)[2, 3] 94 | max_radius = 0.5 * np.sqrt(2) * 1.1 95 | near_depth = max(origin_depth - max_radius, min_ratio * origin_depth) 96 | far_depth = origin_depth + max_radius 97 | depth_range = torch.tensor([near_depth, far_depth]) 98 | 99 | src_rgbs = [] 100 | src_cameras = [] 101 | for id in id_feat: 102 | src_rgb = imageio.imread(rgb_files[id]).astype(np.float32) / 255. 
103 | pose = np.loadtxt(pose_files[id]) 104 | if self.rectify_inplane_rotation: 105 | pose, src_rgb = rectify_inplane_rotation(pose.reshape(4, 4), render_pose, src_rgb) 106 | 107 | src_rgbs.append(src_rgb) 108 | intrinsics = np.loadtxt(intrinsics_files[id]) 109 | img_size = src_rgb.shape[:2] 110 | src_camera = np.concatenate((list(img_size), intrinsics, pose.flatten())).astype(np.float32) 111 | src_cameras.append(src_camera) 112 | 113 | src_rgbs = np.stack(src_rgbs) 114 | src_cameras = np.stack(src_cameras) 115 | 116 | return {'rgb': torch.from_numpy(rgb), 117 | 'camera': torch.from_numpy(camera), 118 | 'rgb_path': rgb_files[id_render], 119 | 'src_rgbs': torch.from_numpy(src_rgbs), 120 | 'src_cameras': torch.from_numpy(src_cameras), 121 | 'depth_range': depth_range 122 | } -------------------------------------------------------------------------------- /models/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .custom_functions import \ 3 | RayAABBIntersector, RayMarcher, VolumeRenderer 4 | from einops import rearrange 5 | import vren 6 | 7 | MAX_SAMPLES = 1024 8 | NEAR_DISTANCE = 0.01 9 | 10 | 11 | @torch.cuda.amp.autocast() 12 | def render(model, rays_o, rays_d, **kwargs): 13 | """ 14 | Render rays by 15 | 1. Compute the intersection of the rays with the scene bounding box 16 | 2. Follow the process in @render_func (different for train/test) 17 | 18 | Inputs: 19 | model: 20 | rays_o: (N_rays, 3) ray origins 21 | rays_d: (N_rays, 3) ray directions 22 | 23 | Outputs: 24 | result: dictionary containing final rgb and depth 25 | """ 26 | rays_o = rays_o.contiguous(); rays_d = rays_d.contiguous() 27 | _, hits_t, _ = \ 28 | RayAABBIntersector.apply(rays_o, rays_d, model.center, model.half_size, 1) 29 | hits_t[(hits_t[:, 0, 0]>=0)&(hits_t[:, 0, 0] (n1 n2) c') 90 | dirs = rearrange(dirs, 'n1 n2 c -> (n1 n2) c') 91 | valid_mask = ~torch.all(dirs==0, dim=1) 92 | if valid_mask.sum()==0: break 93 | 94 | sigmas = torch.zeros(len(xyzs), device=device) 95 | rgbs = torch.zeros(len(xyzs), 3, device=device) 96 | sigmas[valid_mask], _rgbs = model(xyzs[valid_mask], dirs[valid_mask], **kwargs) 97 | rgbs[valid_mask] = _rgbs.float() 98 | sigmas = rearrange(sigmas, '(n1 n2) -> n1 n2', n2=N_samples) 99 | rgbs = rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=N_samples) 100 | 101 | vren.composite_test_fw( 102 | sigmas, rgbs, deltas, ts, 103 | hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4), 104 | N_eff_samples, opacity, depth, rgb) 105 | alive_indices = alive_indices[alive_indices>=0] # remove converged rays 106 | 107 | results['opacity'] = opacity 108 | results['depth'] = depth 109 | results['rgb'] = rgb 110 | results['total_samples'] = total_samples # total samples for all rays 111 | 112 | if exp_step_factor==0: # synthetic 113 | rgb_bg = torch.ones(3, device=device) 114 | else: # real 115 | rgb_bg = torch.zeros(3, device=device) 116 | results['rgb'] += rgb_bg*rearrange(1-opacity, 'n -> n 1') 117 | 118 | return results 119 | 120 | 121 | def __render_rays_train(model, rays_o, rays_d, hits_t, **kwargs): 122 | """ 123 | Render rays by 124 | 1. March the rays along their directions, querying @density_bitfield 125 | to skip empty space, and get the effective sample points (where 126 | there is object) 127 | 2. Infer the NN at these positions and view directions to get properties 128 | (currently sigmas and rgbs) 129 | 3. 
Use volume rendering to combine the result (front to back compositing 130 | and early stop the ray if its transmittance is below a threshold) 131 | """ 132 | exp_step_factor = kwargs.get('exp_step_factor', 0.) 133 | results = {} 134 | 135 | (rays_a, xyzs, dirs, 136 | results['deltas'], results['ts'], results['rm_samples']) = \ 137 | RayMarcher.apply( 138 | rays_o, rays_d, hits_t[:, 0], model.density_bitfield, 139 | model.cascades, model.scale, 140 | exp_step_factor, model.grid_size, MAX_SAMPLES) 141 | 142 | for k, v in kwargs.items(): # supply additional inputs, repeated per ray 143 | if isinstance(v, torch.Tensor): 144 | kwargs[k] = torch.repeat_interleave(v[rays_a[:, 0]], rays_a[:, 2], 0) 145 | sigmas, rgbs = model(xyzs, dirs, **kwargs) 146 | 147 | (results['vr_samples'], results['opacity'], 148 | results['depth'], results['rgb'], results['ws']) = \ 149 | VolumeRenderer.apply(sigmas, rgbs, results['deltas'], results['ts'], 150 | rays_a, kwargs.get('T_threshold', 1e-4)) 151 | results['rays_a'] = rays_a 152 | 153 | if exp_step_factor==0: # synthetic 154 | rgb_bg = torch.ones(3, device=rays_o.device) 155 | else: # real 156 | if kwargs.get('random_bg', False): 157 | rgb_bg = torch.rand(3, device=rays_o.device) 158 | else: 159 | rgb_bg = torch.zeros(3, device=rays_o.device) 160 | results['rgb'] = results['rgb'] + \ 161 | rgb_bg*rearrange(1-results['opacity'], 'n -> n 1') 162 | 163 | return results -------------------------------------------------------------------------------- /models/custom_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vren 3 | from torch.cuda.amp import custom_fwd, custom_bwd 4 | from torch_scatter import segment_csr 5 | from einops import rearrange 6 | 7 | 8 | class RayAABBIntersector(torch.autograd.Function): 9 | """ 10 | Computes the intersections of rays and axis-aligned voxels. 11 | 12 | Inputs: 13 | rays_o: (N_rays, 3) ray origins 14 | rays_d: (N_rays, 3) ray directions 15 | centers: (N_voxels, 3) voxel centers 16 | half_sizes: (N_voxels, 3) voxel half sizes 17 | max_hits: maximum number of intersected voxels to keep for one ray 18 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) 19 | 20 | Outputs: 21 | hits_cnt: (N_rays) number of hits for each ray 22 | (followings are from near to far) 23 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 24 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) 25 | """ 26 | @staticmethod 27 | @custom_fwd(cast_inputs=torch.float32) 28 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits): 29 | return vren.ray_aabb_intersect(rays_o, rays_d, center, half_size, max_hits) 30 | 31 | 32 | class RaySphereIntersector(torch.autograd.Function): 33 | """ 34 | Computes the intersections of rays and spheres. 
35 | 36 | Inputs: 37 | rays_o: (N_rays, 3) ray origins 38 | rays_d: (N_rays, 3) ray directions 39 | centers: (N_spheres, 3) sphere centers 40 | radii: (N_spheres, 3) radii 41 | max_hits: maximum number of intersected spheres to keep for one ray 42 | 43 | Outputs: 44 | hits_cnt: (N_rays) number of hits for each ray 45 | (followings are from near to far) 46 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 47 | hits_sphere_idx: (N_rays, max_hits) hit sphere indices (-1 if no hit) 48 | """ 49 | @staticmethod 50 | @custom_fwd(cast_inputs=torch.float32) 51 | def forward(ctx, rays_o, rays_d, center, radii, max_hits): 52 | return vren.ray_sphere_intersect(rays_o, rays_d, center, radii, max_hits) 53 | 54 | 55 | class RayMarcher(torch.autograd.Function): 56 | """ 57 | March the rays to get sample point positions and directions. 58 | 59 | Inputs: 60 | rays_o: (N_rays, 3) ray origins 61 | rays_d: (N_rays, 3) normalized ray directions 62 | hits_t: (N_rays, 2) near and far bounds from aabb intersection 63 | density_bitfield: (C*G**3//8) 64 | cascades: int 65 | scale: float 66 | exp_step_factor: the exponential factor to scale the steps 67 | grid_size: int 68 | max_samples: int 69 | 70 | Outputs: 71 | rays_a: (N_rays) ray_idx, start_idx, N_samples 72 | xyzs: (N, 3) sample positions 73 | dirs: (N, 3) sample view directions 74 | deltas: (N) dt for integration 75 | ts: (N) sample ts 76 | """ 77 | @staticmethod 78 | @custom_fwd(cast_inputs=torch.float32) 79 | def forward(ctx, rays_o, rays_d, hits_t, 80 | density_bitfield, cascades, scale, exp_step_factor, 81 | grid_size, max_samples): 82 | # noise to perturb the first sample of each ray 83 | noise = torch.rand_like(rays_o[:, 0]) 84 | 85 | rays_a, xyzs, dirs, deltas, ts, counter = \ 86 | vren.raymarching_train( 87 | rays_o, rays_d, hits_t, 88 | density_bitfield, cascades, scale, 89 | exp_step_factor, noise, grid_size, max_samples) 90 | 91 | total_samples = counter[0] # total samples for all rays 92 | # remove redundant output 93 | xyzs = xyzs[:total_samples] 94 | dirs = dirs[:total_samples] 95 | deltas = deltas[:total_samples] 96 | ts = ts[:total_samples] 97 | 98 | ctx.save_for_backward(rays_a, ts) 99 | 100 | return rays_a, xyzs, dirs, deltas, ts, total_samples 101 | 102 | @staticmethod 103 | @custom_bwd 104 | def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, 105 | dL_ddeltas, dL_dts, dL_dtotal_samples): 106 | rays_a, ts = ctx.saved_tensors 107 | segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1]+rays_a[-1:, 2]]) 108 | dL_drays_o = segment_csr(dL_dxyzs, segments) 109 | dL_drays_d = \ 110 | segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments) 111 | 112 | return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None 113 | 114 | 115 | class VolumeRenderer(torch.autograd.Function): 116 | """ 117 | Volume rendering with different number of samples per ray 118 | Used in training only 119 | 120 | Inputs: 121 | sigmas: (N) 122 | rgbs: (N, 3) 123 | deltas: (N) 124 | ts: (N) 125 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 126 | meaning each entry corresponds to the @ray_idx th ray, 127 | whose samples are [start_idx:start_idx+N_samples] 128 | T_threshold: float, stop the ray if the transmittance is below it 129 | 130 | Outputs: 131 | total_samples: int, total effective samples 132 | opacity: (N_rays) 133 | depth: (N_rays) 134 | rgb: (N_rays, 3) 135 | ws: (N) sample point weights 136 | """ 137 | @staticmethod 138 | @custom_fwd(cast_inputs=torch.float32) 139 | def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold): 140 | 
total_samples, opacity, depth, rgb, ws = \ 141 | vren.composite_train_fw(sigmas, rgbs, deltas, ts, 142 | rays_a, T_threshold) 143 | ctx.save_for_backward(sigmas, rgbs, deltas, ts, rays_a, 144 | opacity, depth, rgb, ws) 145 | ctx.T_threshold = T_threshold 146 | return total_samples.sum(), opacity, depth, rgb, ws 147 | 148 | @staticmethod 149 | @custom_bwd 150 | def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth, dL_drgb, dL_dws): 151 | sigmas, rgbs, deltas, ts, rays_a, \ 152 | opacity, depth, rgb, ws = ctx.saved_tensors 153 | dL_dsigmas, dL_drgbs = \ 154 | vren.composite_train_bw(dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 155 | sigmas, rgbs, ws, deltas, ts, 156 | rays_a, 157 | opacity, depth, rgb, 158 | ctx.T_threshold) 159 | return dL_dsigmas, dL_drgbs, None, None, None, None 160 | 161 | 162 | class TruncExp(torch.autograd.Function): 163 | @staticmethod 164 | @custom_fwd(cast_inputs=torch.float32) 165 | def forward(ctx, x): 166 | ctx.save_for_backward(x) 167 | return torch.exp(x) 168 | 169 | @staticmethod 170 | @custom_bwd 171 | def backward(ctx, dL_dout): 172 | x = ctx.saved_tensors[0] 173 | return dL_dout * torch.exp(x.clamp(-15, 15)) 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeRFusion: Fusing Radiance Fields for Large-Scale Scene Reconstruction (CVPR 2022 Oral) 2 | 3 | [Project Sites](https://jetd1.github.io/NeRFusion-Web/) 4 | | [Paper](https://arxiv.org/abs/2203.11283) | 5 | Primary contact: [Xiaoshuai Zhang](https://jetd1.github.io/NeRFusion-Web/) 6 | 7 | ## Note 8 | 9 | This `dev` branch is currently **under development**. We will finish and merge this into `main` in a few days. This is a re-development of the original NeRFusion code, based heavily on [nerf_pl](https://github.com/kwea123/nerf_pl), [NeuralRecon](https://github.com/zju3dv/NeuralRecon), and [MVSNeRF](https://github.com/apchenstu/mvsnerf). We thank the authors for sharing their code. The model released in this repo is further optimized for large-scale scenes compared to the CVPR submission. A changelist will be provided. 10 | 11 | 12 | ## Introduction 13 | 14 | 15 | 16 | While NeRF has shown great success for neural reconstruction and rendering, its limited MLP capacity and long per-scene optimization times make it challenging to model large-scale indoor scenes. In contrast, classical 3D reconstruction methods can handle large-scale scenes but do not produce realistic renderings. We propose NeRFusion, a method that combines the advantages of NeRF and TSDF-based fusion techniques to achieve efficient large-scale reconstruction and photo-realistic rendering. We process the input image sequence to predict per-frame local radiance fields via direct network inference. These are then fused using a novel recurrent neural network that incrementally reconstructs a global, sparse scene representation in real-time at 22 fps. This volume can be further fine-tuned to boost rendering quality. We demonstrate that NeRFusion achieves state-of-the-art quality on both large-scale indoor and small-scale object scenes, with substantially faster reconstruction speed than NeRF and other recent methods. 17 | 18 | 19 | 20 | ## Reference 21 | Please cite our paper if you find our work useful: 22 | NeRFusion: Fusing Radiance Fields for Large-Scale Scene Reconstruction.     
23 | ``` 24 | @article{zhang2022nerfusion, 25 | author = {Zhang, Xiaoshuai and Bi, Sai and Sunkavalli, Kalyan and Su, Hao and Xu, Zexiang}, 26 | title = {NeRFusion: Fusing Radiance Fields for Large-Scale Scene Reconstruction}, 27 | journal = {CVPR}, 28 | year = {2022}, 29 | } 30 | ``` 31 | 32 | 33 | ## Installation 34 | 35 | ### Requirements 36 | All the code is tested in the following environment: 37 | * Linux (Ubuntu 20.04 or above) 38 | * 32GB RAM (in order to load full-size images) 39 | * NVIDIA GPU with Compute Capability >= 7.5 and VRAM >= 6GB, CUDA >= 11.3 40 | 41 | ### Dependencies 42 | * Python>=3.8 (installation via [anaconda](https://www.anaconda.com/distribution/) is recommended; use `conda create -n ngp_pl python=3.8` to create a conda environment and activate it with `conda activate ngp_pl`) 43 | * Python libraries 44 | * Install `pytorch>=1.11.0` by `pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113` 45 | * Install `torch-scatter` following their [instruction](https://github.com/rusty1s/pytorch_scatter#installation) 46 | * Install `tinycudann` following their [instruction](https://github.com/NVlabs/tiny-cuda-nn#requirements) (compilation and pytorch extension) 47 | * Install `apex` following their [instruction](https://github.com/NVIDIA/apex#linux) 48 | * Install `torchsparse` following their [instruction](https://github.com/mit-han-lab/torchsparse#installation) 49 | * Install core requirements by `pip install -r requirements.txt` 50 | 51 | * CUDA extension: upgrade `pip` to >= 22.1 and run `pip install models/csrc/` (please run this each time you `pull` the code) 52 | 53 | ## Data Preparation 54 | We follow the same data organization as the original NeRF, which expects camera parameters to be provided in a `transforms.json` file. We also support data from NSVF, NeRF++, colmap and ScanNet. 55 | 56 | ### Custom Sequence 57 | You can test our pre-trained model on custom sequences captured under casual settings. To do so, the data should be organized in the original NeRF style: 58 | 59 | ``` 60 | data 61 | ├── transforms.json 62 | ├── images 63 | │   ├── 0000.jpg 64 | │   ├── 0001.jpg 65 | │   └── ... 66 | ``` 67 | 68 | If a video is all you have (no camera parameters), install `ffmpeg` and `colmap`, then follow the instructions in [instant-ngp](https://github.com/NVlabs/instant-ngp/blob/master/scripts/colmap2nerf.py) to generate the `transforms.json` (a minimal skeleton of this file is sketched at the end of this README). 69 | 70 | ## Inference using Pre-trained Network 71 | ```bash 72 | python train.py --dataset_name scannet --root_dir DIR_TO_SCANNET_SCENE0000_01 --exp_name EXP_NAME --ckpt_path PATH_TO_G_CKPT 73 | ``` 74 | Please find the pre-trained network weights [here](https://drive.google.com/file/d/1YjwO1Q2CAn7tdnwVzDgL_iEH_m7cSiHW/view?usp=sharing). 75 | 76 | ### Per-Scene Optimization 77 | Note: currently this script trains the model from scratch. We are updating the generalized pipeline. 78 | ```bash 79 | python train.py --dataset_name DATASET_NAME --root_dir DIR_TO_SCANNET_SCENE --exp_name EXP_NAME 80 | ``` 81 | 82 | You can test using our [sample data](https://drive.google.com/file/d/1vy5whVQbMcyKTK5W0LJsTlDgCS7wGih7/view?usp=sharing) on ScanNet. 
You can also try evaluation using our [sample checkpoint](https://drive.google.com/file/d/1wHSPMSGhy1TVSWCYttz2JDNUTMTeI9w0/view?usp=sharing) on ScanNet: 83 | ```bash 84 | python train.py --dataset_name scannet --root_dir DIR_TO_SCANNET_SCENE0000_01 --exp_name EXP_NAME --val_only --ckpt_path PATH_TO_SCANNET_SCENE0000_01_CKPT 85 | ``` 86 | 87 | ## Training Procedure 88 | 89 | Please download and organize the datasets in the following manner: 90 | ``` 91 | ├── data/ 92 | │   ├── DTU/ 93 | │   ├── google_scanned_objects/ 94 | │   └── ScanNet/ 95 | ``` 96 | 97 | For Google Scanned Objects, we used the [renderings](https://drive.google.com/file/d/1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2/view?usp=sharing) from IBRNet. Download with: 98 | 99 | ``` 100 | gdown https://drive.google.com/uc?id=1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2 101 | unzip google_scanned_objects_renderings.zip 102 | ``` 103 | 104 | For DTU and ScanNet, please use the official toolkits to download and process the data, and unpack the root directories into the `data` folder mentioned above. Train with: 105 | 106 | ```bash 107 | python train.py --train_root_dir DIR_TO_DATA --exp_name EXP_NAME 108 | ``` 109 | 110 | See `opt.py` for more options. 111 | 112 | 113 | ## Performance 114 | 115 | We further optimized this code base for large-scale scenes, so the performance may not exactly match all numbers in the paper. Our test results with this code base are reported here (PSNR/SSIM/LPIPS). In the generalized setting (no per-scene optimization), we achieve 23.35/0.844/0.333 on the eight ScanNet scenes, 26.23/0.925/0.169 on DTU, and 24.21/0.888/0.129 on NeRF Synthetic. With per-scene optimization, we achieve 27.78/0.917/0.199 on the eight ScanNet scenes, 31.76/0.961/0.118 on DTU, and 29.88/0.949/0.099 on NeRF Synthetic. 116 | 117 | 118 | ## Acknowledgement 119 | Our repo is developed based on [nerf_pl](https://github.com/kwea123/nerf_pl), [NeuralRecon](https://github.com/zju3dv/NeuralRecon) and [MVSNeRF](https://github.com/apchenstu/mvsnerf). Please also consider citing the corresponding papers. 120 | 121 | The project is conducted collaboratively between Adobe Research and the University of California, San Diego. 122 | 123 | ## LICENSE 124 | 125 | The code is released under the MIT License. 
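## Appendix: Minimal `transforms.json` Skeleton

The snippet below is a rough sketch (not taken from this repo) of the NeRF-style `transforms.json` referenced in the Data Preparation section: a shared `camera_angle_x` plus one entry per frame with the image path and a 4x4 camera-to-world matrix. Field names follow the original NeRF / instant-ngp convention; the numeric values and paths here are placeholders only.

```python
import json

# Hypothetical example: write a NeRF-style transforms.json for two frames.
# camera_angle_x is the horizontal field of view in radians; transform_matrix
# is a 4x4 camera-to-world pose given as a row-major list of lists.
meta = {
    "camera_angle_x": 0.8575,  # placeholder FoV
    "frames": [
        {
            "file_path": "images/0000",
            "transform_matrix": [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
        },
        {
            "file_path": "images/0001",
            "transform_matrix": [
                [1.0, 0.0, 0.0, 0.1],  # placeholder pose with a small x translation
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
        },
    ],
}

with open("transforms.json", "w") as f:
    json.dump(meta, f, indent=2)
```

In practice, scripts such as instant-ngp's `colmap2nerf.py` generate this file automatically from COLMAP output; the sketch is mainly a sanity check of what the loaders expect.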
126 | -------------------------------------------------------------------------------- /models/csrc/binding.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | 4 | std::vector ray_aabb_intersect( 5 | const torch::Tensor rays_o, 6 | const torch::Tensor rays_d, 7 | const torch::Tensor centers, 8 | const torch::Tensor half_sizes, 9 | const int max_hits 10 | ){ 11 | CHECK_INPUT(rays_o); 12 | CHECK_INPUT(rays_d); 13 | CHECK_INPUT(centers); 14 | CHECK_INPUT(half_sizes); 15 | return ray_aabb_intersect_cu(rays_o, rays_d, centers, half_sizes, max_hits); 16 | } 17 | 18 | 19 | std::vector ray_sphere_intersect( 20 | const torch::Tensor rays_o, 21 | const torch::Tensor rays_d, 22 | const torch::Tensor centers, 23 | const torch::Tensor radii, 24 | const int max_hits 25 | ){ 26 | CHECK_INPUT(rays_o); 27 | CHECK_INPUT(rays_d); 28 | CHECK_INPUT(centers); 29 | CHECK_INPUT(radii); 30 | return ray_sphere_intersect_cu(rays_o, rays_d, centers, radii, max_hits); 31 | } 32 | 33 | 34 | void packbits( 35 | torch::Tensor density_grid, 36 | const float density_threshold, 37 | torch::Tensor density_bitfield 38 | ){ 39 | CHECK_INPUT(density_grid); 40 | CHECK_INPUT(density_bitfield); 41 | 42 | return packbits_cu(density_grid, density_threshold, density_bitfield); 43 | } 44 | 45 | 46 | torch::Tensor morton3D(const torch::Tensor coords){ 47 | CHECK_INPUT(coords); 48 | 49 | return morton3D_cu(coords); 50 | } 51 | 52 | 53 | torch::Tensor morton3D_invert(const torch::Tensor indices){ 54 | CHECK_INPUT(indices); 55 | 56 | return morton3D_invert_cu(indices); 57 | } 58 | 59 | 60 | std::vector raymarching_train( 61 | const torch::Tensor rays_o, 62 | const torch::Tensor rays_d, 63 | const torch::Tensor hits_t, 64 | const torch::Tensor density_bitfield, 65 | const int cascades, 66 | const float scale, 67 | const float exp_step_factor, 68 | const torch::Tensor noise, 69 | const int grid_size, 70 | const int max_samples 71 | ){ 72 | CHECK_INPUT(rays_o); 73 | CHECK_INPUT(rays_d); 74 | CHECK_INPUT(hits_t); 75 | CHECK_INPUT(density_bitfield); 76 | CHECK_INPUT(noise); 77 | 78 | return raymarching_train_cu( 79 | rays_o, rays_d, hits_t, density_bitfield, cascades, 80 | scale, exp_step_factor, noise, grid_size, max_samples); 81 | } 82 | 83 | 84 | std::vector raymarching_test( 85 | const torch::Tensor rays_o, 86 | const torch::Tensor rays_d, 87 | torch::Tensor hits_t, 88 | const torch::Tensor alive_indices, 89 | const torch::Tensor density_bitfield, 90 | const int cascades, 91 | const float scale, 92 | const float exp_step_factor, 93 | const int grid_size, 94 | const int max_samples, 95 | const int N_samples 96 | ){ 97 | CHECK_INPUT(rays_o); 98 | CHECK_INPUT(rays_d); 99 | CHECK_INPUT(hits_t); 100 | CHECK_INPUT(alive_indices); 101 | CHECK_INPUT(density_bitfield); 102 | 103 | return raymarching_test_cu( 104 | rays_o, rays_d, hits_t, alive_indices, density_bitfield, cascades, 105 | scale, exp_step_factor, grid_size, max_samples, N_samples); 106 | } 107 | 108 | 109 | std::vector composite_train_fw( 110 | const torch::Tensor sigmas, 111 | const torch::Tensor rgbs, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a, 115 | const float opacity_threshold 116 | ){ 117 | CHECK_INPUT(sigmas); 118 | CHECK_INPUT(rgbs); 119 | CHECK_INPUT(deltas); 120 | CHECK_INPUT(ts); 121 | CHECK_INPUT(rays_a); 122 | 123 | return composite_train_fw_cu( 124 | sigmas, rgbs, deltas, ts, 125 | rays_a, opacity_threshold); 126 | } 127 | 128 | 129 | std::vector 
composite_train_bw( 130 | const torch::Tensor dL_dopacity, 131 | const torch::Tensor dL_ddepth, 132 | const torch::Tensor dL_drgb, 133 | const torch::Tensor dL_dws, 134 | const torch::Tensor sigmas, 135 | const torch::Tensor rgbs, 136 | const torch::Tensor ws, 137 | const torch::Tensor deltas, 138 | const torch::Tensor ts, 139 | const torch::Tensor rays_a, 140 | const torch::Tensor opacity, 141 | const torch::Tensor depth, 142 | const torch::Tensor rgb, 143 | const float opacity_threshold 144 | ){ 145 | CHECK_INPUT(dL_dopacity); 146 | CHECK_INPUT(dL_ddepth); 147 | CHECK_INPUT(dL_drgb); 148 | CHECK_INPUT(dL_dws); 149 | CHECK_INPUT(sigmas); 150 | CHECK_INPUT(rgbs); 151 | CHECK_INPUT(ws); 152 | CHECK_INPUT(deltas); 153 | CHECK_INPUT(ts); 154 | CHECK_INPUT(rays_a); 155 | CHECK_INPUT(opacity); 156 | CHECK_INPUT(depth); 157 | CHECK_INPUT(rgb); 158 | 159 | return composite_train_bw_cu( 160 | dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 161 | sigmas, rgbs, ws, deltas, ts, rays_a, 162 | opacity, depth, rgb, opacity_threshold); 163 | } 164 | 165 | 166 | void composite_test_fw( 167 | const torch::Tensor sigmas, 168 | const torch::Tensor rgbs, 169 | const torch::Tensor deltas, 170 | const torch::Tensor ts, 171 | const torch::Tensor hits_t, 172 | const torch::Tensor alive_indices, 173 | const float T_threshold, 174 | const torch::Tensor N_eff_samples, 175 | torch::Tensor opacity, 176 | torch::Tensor depth, 177 | torch::Tensor rgb 178 | ){ 179 | CHECK_INPUT(sigmas); 180 | CHECK_INPUT(rgbs); 181 | CHECK_INPUT(deltas); 182 | CHECK_INPUT(ts); 183 | CHECK_INPUT(hits_t); 184 | CHECK_INPUT(alive_indices); 185 | CHECK_INPUT(N_eff_samples); 186 | CHECK_INPUT(opacity); 187 | CHECK_INPUT(depth); 188 | CHECK_INPUT(rgb); 189 | 190 | composite_test_fw_cu( 191 | sigmas, rgbs, deltas, ts, hits_t, alive_indices, 192 | T_threshold, N_eff_samples, 193 | opacity, depth, rgb); 194 | } 195 | 196 | 197 | std::vector distortion_loss_fw( 198 | const torch::Tensor ws, 199 | const torch::Tensor deltas, 200 | const torch::Tensor ts, 201 | const torch::Tensor rays_a 202 | ){ 203 | CHECK_INPUT(ws); 204 | CHECK_INPUT(deltas); 205 | CHECK_INPUT(ts); 206 | CHECK_INPUT(rays_a); 207 | 208 | return distortion_loss_fw_cu(ws, deltas, ts, rays_a); 209 | } 210 | 211 | 212 | torch::Tensor distortion_loss_bw( 213 | const torch::Tensor dL_dloss, 214 | const torch::Tensor ws_inclusive_scan, 215 | const torch::Tensor wts_inclusive_scan, 216 | const torch::Tensor ws, 217 | const torch::Tensor deltas, 218 | const torch::Tensor ts, 219 | const torch::Tensor rays_a 220 | ){ 221 | CHECK_INPUT(dL_dloss); 222 | CHECK_INPUT(ws_inclusive_scan); 223 | CHECK_INPUT(wts_inclusive_scan); 224 | CHECK_INPUT(ws); 225 | CHECK_INPUT(deltas); 226 | CHECK_INPUT(ts); 227 | CHECK_INPUT(rays_a); 228 | 229 | return distortion_loss_bw_cu(dL_dloss, ws_inclusive_scan, wts_inclusive_scan, 230 | ws, deltas, ts, rays_a); 231 | } 232 | 233 | 234 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 235 | m.def("ray_aabb_intersect", &ray_aabb_intersect); 236 | m.def("ray_sphere_intersect", &ray_sphere_intersect); 237 | 238 | m.def("morton3D", &morton3D); 239 | m.def("morton3D_invert", &morton3D_invert); 240 | m.def("packbits", &packbits); 241 | 242 | m.def("raymarching_train", &raymarching_train); 243 | m.def("raymarching_test", &raymarching_test); 244 | m.def("composite_train_fw", &composite_train_fw); 245 | m.def("composite_train_bw", &composite_train_bw); 246 | m.def("composite_test_fw", &composite_test_fw); 247 | 248 | m.def("distortion_loss_fw", &distortion_loss_fw); 249 | 
m.def("distortion_loss_bw", &distortion_loss_bw); 250 | 251 | } -------------------------------------------------------------------------------- /models/csrc/losses.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | 4 | 5 | // for details of the formulae, please see https://arxiv.org/pdf/2206.05085.pdf 6 | 7 | template 8 | __global__ void prefix_sums_kernel( 9 | const scalar_t* __restrict__ ws, 10 | const scalar_t* __restrict__ wts, 11 | const torch::PackedTensorAccessor64 rays_a, 12 | scalar_t* __restrict__ ws_inclusive_scan, 13 | scalar_t* __restrict__ ws_exclusive_scan, 14 | scalar_t* __restrict__ wts_inclusive_scan, 15 | scalar_t* __restrict__ wts_exclusive_scan 16 | ){ 17 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (n >= rays_a.size(0)) return; 19 | 20 | const int start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 21 | 22 | // compute prefix sum of ws and ws*ts 23 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 24 | thrust::inclusive_scan(thrust::device, 25 | ws+start_idx, 26 | ws+start_idx+N_samples, 27 | ws_inclusive_scan+start_idx); 28 | thrust::inclusive_scan(thrust::device, 29 | wts+start_idx, 30 | wts+start_idx+N_samples, 31 | wts_inclusive_scan+start_idx); 32 | // [a0, a1, a2, a3, ...] -> [0, a0, a0+a1, a0+a1+a2, ...] 33 | thrust::exclusive_scan(thrust::device, 34 | ws+start_idx, 35 | ws+start_idx+N_samples, 36 | ws_exclusive_scan+start_idx); 37 | thrust::exclusive_scan(thrust::device, 38 | wts+start_idx, 39 | wts+start_idx+N_samples, 40 | wts_exclusive_scan+start_idx); 41 | } 42 | 43 | 44 | template 45 | __global__ void distortion_loss_fw_kernel( 46 | const scalar_t* __restrict__ _loss, 47 | const torch::PackedTensorAccessor64 rays_a, 48 | torch::PackedTensorAccessor loss 49 | ){ 50 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 51 | if (n >= rays_a.size(0)) return; 52 | 53 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 54 | 55 | loss[ray_idx] = thrust::reduce(thrust::device, 56 | _loss+start_idx, 57 | _loss+start_idx+N_samples, 58 | (scalar_t)0); 59 | } 60 | 61 | 62 | std::vector distortion_loss_fw_cu( 63 | const torch::Tensor ws, 64 | const torch::Tensor deltas, 65 | const torch::Tensor ts, 66 | const torch::Tensor rays_a 67 | ){ 68 | const int N_rays = rays_a.size(0), N = ws.size(0); 69 | 70 | auto wts = ws * ts; 71 | 72 | auto ws_inclusive_scan = torch::zeros({N}, ws.options()); 73 | auto ws_exclusive_scan = torch::zeros({N}, ws.options()); 74 | auto wts_inclusive_scan = torch::zeros({N}, ws.options()); 75 | auto wts_exclusive_scan = torch::zeros({N}, ws.options()); 76 | 77 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu_prefix_sums", 80 | ([&] { 81 | prefix_sums_kernel<<>>( 82 | ws.data_ptr(), 83 | wts.data_ptr(), 84 | rays_a.packed_accessor64(), 85 | ws_inclusive_scan.data_ptr(), 86 | ws_exclusive_scan.data_ptr(), 87 | wts_inclusive_scan.data_ptr(), 88 | wts_exclusive_scan.data_ptr() 89 | ); 90 | })); 91 | 92 | auto _loss = 2*(wts_inclusive_scan*ws_exclusive_scan- 93 | ws_inclusive_scan*wts_exclusive_scan) + 1.0f/3*ws*ws*deltas; 94 | 95 | auto loss = torch::zeros({N_rays}, ws.options()); 96 | 97 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu", 98 | ([&] { 99 | distortion_loss_fw_kernel<<>>( 100 | _loss.data_ptr(), 101 | rays_a.packed_accessor64(), 102 | loss.packed_accessor() 103 | ); 104 | 
})); 105 | 106 | return {loss, ws_inclusive_scan, wts_inclusive_scan}; 107 | } 108 | 109 | 110 | template 111 | __global__ void distortion_loss_bw_kernel( 112 | const torch::PackedTensorAccessor dL_dloss, 113 | const torch::PackedTensorAccessor ws_inclusive_scan, 114 | const torch::PackedTensorAccessor wts_inclusive_scan, 115 | const torch::PackedTensorAccessor ws, 116 | const torch::PackedTensorAccessor deltas, 117 | const torch::PackedTensorAccessor ts, 118 | const torch::PackedTensorAccessor64 rays_a, 119 | torch::PackedTensorAccessor dL_dws 120 | ){ 121 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 122 | if (n >= rays_a.size(0)) return; 123 | 124 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 125 | const int end_idx = start_idx+N_samples-1; 126 | 127 | const scalar_t ws_sum = ws_inclusive_scan[end_idx]; 128 | const scalar_t wts_sum = wts_inclusive_scan[end_idx]; 129 | // fill in dL_dws from start_idx to end_idx 130 | for (int s=start_idx; s<=end_idx; s++){ 131 | dL_dws[s] = dL_dloss[ray_idx] * 2 * ( 132 | (s==start_idx? 133 | (scalar_t)0: 134 | (ts[s]*ws_inclusive_scan[s-1]-wts_inclusive_scan[s-1]) 135 | ) + 136 | (wts_sum-wts_inclusive_scan[s]-ts[s]*(ws_sum-ws_inclusive_scan[s])) 137 | ); 138 | dL_dws[s] += dL_dloss[ray_idx] * (scalar_t)2/3*ws[s]*deltas[s]; 139 | } 140 | } 141 | 142 | 143 | torch::Tensor distortion_loss_bw_cu( 144 | const torch::Tensor dL_dloss, 145 | const torch::Tensor ws_inclusive_scan, 146 | const torch::Tensor wts_inclusive_scan, 147 | const torch::Tensor ws, 148 | const torch::Tensor deltas, 149 | const torch::Tensor ts, 150 | const torch::Tensor rays_a 151 | ){ 152 | const int N_rays = rays_a.size(0), N = ws.size(0); 153 | 154 | auto dL_dws = torch::zeros({N}, dL_dloss.options()); 155 | 156 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 157 | 158 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_bw_cu", 159 | ([&] { 160 | distortion_loss_bw_kernel<<>>( 161 | dL_dloss.packed_accessor(), 162 | ws_inclusive_scan.packed_accessor(), 163 | wts_inclusive_scan.packed_accessor(), 164 | ws.packed_accessor(), 165 | deltas.packed_accessor(), 166 | ts.packed_accessor(), 167 | rays_a.packed_accessor64(), 168 | dL_dws.packed_accessor() 169 | ); 170 | })); 171 | 172 | return dL_dws; 173 | } -------------------------------------------------------------------------------- /representations/grufusion/modules.py: -------------------------------------------------------------------------------- 1 | # ported from NeuralRecon (https://github.com/zju3dv/NeuralRecon) 2 | import torch 3 | import torch.nn as nn 4 | import torchsparse 5 | import torchsparse.nn as spnn 6 | from torchsparse.tensor import PointTensor 7 | from torchsparse.utils import * 8 | 9 | from .torchsparse_utils import * 10 | 11 | __all__ = ['SPVCNN', 'SConv3d', 'ConvGRU'] 12 | 13 | 14 | class BasicConvolutionBlock(nn.Module): 15 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 16 | super().__init__() 17 | self.net = nn.Sequential( 18 | spnn.Conv3d(inc, 19 | outc, 20 | kernel_size=ks, 21 | dilation=dilation, 22 | stride=stride), spnn.BatchNorm(outc), 23 | spnn.ReLU(True)) 24 | 25 | def forward(self, x): 26 | out = self.net(x) 27 | return out 28 | 29 | 30 | class BasicDeconvolutionBlock(nn.Module): 31 | def __init__(self, inc, outc, ks=3, stride=1): 32 | super().__init__() 33 | self.net = nn.Sequential( 34 | spnn.Conv3d(inc, 35 | outc, 36 | kernel_size=ks, 37 | stride=stride, 38 | transposed=True), spnn.BatchNorm(outc), 39 | 
spnn.ReLU(True)) 40 | 41 | def forward(self, x): 42 | return self.net(x) 43 | 44 | 45 | class ResidualBlock(nn.Module): 46 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 47 | super().__init__() 48 | self.net = nn.Sequential( 49 | spnn.Conv3d(inc, 50 | outc, 51 | kernel_size=ks, 52 | dilation=dilation, 53 | stride=stride), spnn.BatchNorm(outc), 54 | spnn.ReLU(True), 55 | spnn.Conv3d(outc, 56 | outc, 57 | kernel_size=ks, 58 | dilation=dilation, 59 | stride=1), spnn.BatchNorm(outc)) 60 | 61 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ 62 | nn.Sequential( 63 | spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), 64 | spnn.BatchNorm(outc) 65 | ) 66 | 67 | self.relu = spnn.ReLU(True) 68 | 69 | def forward(self, x): 70 | out = self.relu(self.net(x) + self.downsample(x)) 71 | return out 72 | 73 | 74 | class SPVCNN(nn.Module): 75 | def __init__(self, **kwargs): 76 | super().__init__() 77 | 78 | self.dropout = kwargs['dropout'] 79 | 80 | cr = kwargs.get('cr', 1.0) 81 | cs = [32, 64, 128, 96, 96] 82 | cs = [int(cr * x) for x in cs] 83 | 84 | if 'pres' in kwargs and 'vres' in kwargs: 85 | self.pres = kwargs['pres'] 86 | self.vres = kwargs['vres'] 87 | 88 | self.stem = nn.Sequential( 89 | spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1), 90 | spnn.BatchNorm(cs[0]), spnn.ReLU(True) 91 | ) 92 | 93 | self.stage1 = nn.Sequential( 94 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), 95 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), 96 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), 97 | ) 98 | 99 | self.stage2 = nn.Sequential( 100 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), 101 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), 102 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), 103 | ) 104 | 105 | self.up1 = nn.ModuleList([ 106 | BasicDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2), 107 | nn.Sequential( 108 | ResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1, 109 | dilation=1), 110 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), 111 | ) 112 | ]) 113 | 114 | self.up2 = nn.ModuleList([ 115 | BasicDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2), 116 | nn.Sequential( 117 | ResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1, 118 | dilation=1), 119 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), 120 | ) 121 | ]) 122 | 123 | self.point_transforms = nn.ModuleList([ 124 | nn.Sequential( 125 | nn.Linear(cs[0], cs[2]), 126 | nn.BatchNorm1d(cs[2]), 127 | nn.ReLU(True), 128 | ), 129 | nn.Sequential( 130 | nn.Linear(cs[2], cs[4]), 131 | nn.BatchNorm1d(cs[4]), 132 | nn.ReLU(True), 133 | ) 134 | ]) 135 | 136 | self.weight_initialization() 137 | 138 | if self.dropout: 139 | self.dropout = nn.Dropout(0.3, True) 140 | 141 | def weight_initialization(self): 142 | for m in self.modules(): 143 | if isinstance(m, nn.BatchNorm1d): 144 | nn.init.constant_(m.weight, 1) 145 | nn.init.constant_(m.bias, 0) 146 | 147 | def forward(self, z): 148 | # x: SparseTensor z: PointTensor 149 | x0 = initial_voxelize(z, self.pres, self.vres) 150 | 151 | x0 = self.stem(x0) 152 | z0 = voxel_to_point(x0, z, nearest=False) 153 | z0.F = z0.F 154 | 155 | x1 = point_to_voxel(x0, z0) 156 | x1 = self.stage1(x1) 157 | x2 = self.stage2(x1) 158 | z1 = voxel_to_point(x2, z0) 159 | z1.F = z1.F + self.point_transforms[0](z0.F) 160 | 161 | y3 = point_to_voxel(x2, z1) 162 | if self.dropout: 163 | y3.F = self.dropout(y3.F) 164 | y3 = self.up1[0](y3) 165 | y3 = torchsparse.cat([y3, 
x1]) 166 | y3 = self.up1[1](y3) 167 | 168 | y4 = self.up2[0](y3) 169 | y4 = torchsparse.cat([y4, x0]) 170 | y4 = self.up2[1](y4) 171 | z3 = voxel_to_point(y4, z1) 172 | z3.F = z3.F + self.point_transforms[1](z1.F) 173 | 174 | return z3.F 175 | 176 | 177 | class SConv3d(nn.Module): 178 | def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1): 179 | super().__init__() 180 | self.net = spnn.Conv3d(inc, 181 | outc, 182 | kernel_size=ks, 183 | dilation=dilation, 184 | stride=stride) 185 | self.point_transforms = nn.Sequential( 186 | nn.Linear(inc, outc), 187 | ) 188 | self.pres = pres 189 | self.vres = vres 190 | 191 | def forward(self, z): 192 | x = initial_voxelize(z, self.pres, self.vres) 193 | x = self.net(x) 194 | out = voxel_to_point(x, z, nearest=False) 195 | out.F = out.F + self.point_transforms(z.F) 196 | return out 197 | 198 | 199 | class ConvGRU(nn.Module): 200 | def __init__(self, hidden_dim=128, input_dim=192 + 128, pres=1, vres=1): 201 | super(ConvGRU, self).__init__() 202 | self.convz = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3) 203 | self.convr = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3) 204 | self.convq = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3) 205 | 206 | def forward(self, h, x): 207 | ''' 208 | 209 | :param h: PintTensor 210 | :param x: PintTensor 211 | :return: h.F: Tensor (N, C) 212 | ''' 213 | hx = PointTensor(torch.cat([h.F, x.F], dim=1), h.C) 214 | 215 | z = torch.sigmoid(self.convz(hx).F) 216 | r = torch.sigmoid(self.convr(hx).F) 217 | x.F = torch.cat([r * h.F, x.F], dim=1) 218 | q = torch.tanh(self.convq(x).F) 219 | 220 | h.F = (1 - z) * h.F + z * q 221 | return h.F 222 | 223 | -------------------------------------------------------------------------------- /datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import glob 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import * 8 | from .color_utils import read_image 9 | from .colmap_utils import \ 10 | read_cameras_binary, read_images_binary, read_points3d_binary 11 | 12 | from .base import BaseDataset 13 | 14 | 15 | class ColmapDataset(BaseDataset): 16 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 17 | super().__init__(root_dir, split, downsample) 18 | 19 | self.read_intrinsics() 20 | 21 | if kwargs.get('read_meta', True): 22 | self.read_meta(split, **kwargs) 23 | 24 | def read_intrinsics(self): 25 | # Step 1: read and scale intrinsics (same for all images) 26 | camdata = read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin')) 27 | h = int(camdata[1].height*self.downsample) 28 | w = int(camdata[1].width*self.downsample) 29 | self.img_wh = (w, h) 30 | 31 | if camdata[1].model == 'SIMPLE_RADIAL': 32 | fx = fy = camdata[1].params[0]*self.downsample 33 | cx = camdata[1].params[1]*self.downsample 34 | cy = camdata[1].params[2]*self.downsample 35 | elif camdata[1].model in ['PINHOLE', 'OPENCV']: 36 | fx = camdata[1].params[0]*self.downsample 37 | fy = camdata[1].params[1]*self.downsample 38 | cx = camdata[1].params[2]*self.downsample 39 | cy = camdata[1].params[3]*self.downsample 40 | else: 41 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") 42 | self.K = torch.FloatTensor([[fx, 0, cx], 43 | [0, fy, cy], 44 | [0, 0, 1]]) 45 | self.directions = get_ray_directions(h, w, self.K) 46 | 47 | def read_meta(self, split, **kwargs): 48 | # Step 2: correct poses 49 | # read 
extrinsics (of successfully reconstructed images) 50 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin')) 51 | img_names = [imdata[k].name for k in imdata] 52 | perm = np.argsort(img_names) 53 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data 54 | folder = f'images_{int(1/self.downsample)}' 55 | else: 56 | folder = 'images' 57 | # read successfully reconstructed images and ignore others 58 | img_paths = [os.path.join(self.root_dir, folder, name) 59 | for name in sorted(img_names)] 60 | w2c_mats = [] 61 | bottom = np.array([[0, 0, 0, 1.]]) 62 | for k in imdata: 63 | im = imdata[k] 64 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1) 65 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 66 | w2c_mats = np.stack(w2c_mats, 0) 67 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices 68 | 69 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin')) 70 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3) 71 | 72 | self.poses, self.pts3d = center_poses(poses, pts3d) 73 | 74 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min() 75 | self.poses[..., 3] /= scale 76 | self.pts3d /= scale 77 | 78 | self.rays = [] 79 | if split == 'test_traj': # use precomputed test poses 80 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean()) 81 | self.poses = torch.FloatTensor(self.poses) 82 | return 83 | 84 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data 85 | if 'syndata' in self.root_dir: # synthetic 86 | # first 17 are test, last 18 are train 87 | self.unit_exposure_rgb = 0.73 88 | if split=='train': 89 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 90 | f'train/*[024].png'))) 91 | self.poses = np.repeat(self.poses[-18:], 3, 0) 92 | elif split=='test': 93 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 94 | f'test/*[13].png'))) 95 | self.poses = np.repeat(self.poses[:17], 2, 0) 96 | else: 97 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 98 | else: # real 99 | self.unit_exposure_rgb = 0.5 100 | # even numbers are train, odd numbers are test 101 | if split=='train': 102 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 103 | f'input_images/*0.jpg')))[::2] 104 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 105 | f'input_images/*2.jpg')))[::2] 106 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 107 | f'input_images/*4.jpg')))[::2] 108 | self.poses = np.tile(self.poses[::2], (3, 1, 1)) 109 | elif split=='test': 110 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 111 | f'input_images/*1.jpg')))[1::2] 112 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 113 | f'input_images/*3.jpg')))[1::2] 114 | self.poses = np.tile(self.poses[1::2], (2, 1, 1)) 115 | else: 116 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 117 | else: 118 | # use every 8th image as test set 119 | if split=='train': 120 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0] 121 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0]) 122 | elif split=='test': 123 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0] 124 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0]) 125 | 126 | print(f'Loading {len(img_paths)} {split} images ...') 127 | for img_path in tqdm(img_paths): 128 | buf = [] # buffer for ray attributes: rgb, etc 129 | 130 | img = read_image(img_path, self.img_wh, blend_a=False) 131 | img = torch.FloatTensor(img) 132 | buf += [img] 133 
| 134 | if 'HDR-NeRF' in self.root_dir: # get exposure 135 | folder = self.root_dir.split('/') 136 | scene = folder[-1] if folder[-1] != '' else folder[-2] 137 | if scene in ['bathroom', 'bear', 'chair', 'desk']: 138 | e_dict = {e: 1/8*4**e for e in range(5)} 139 | elif scene in ['diningroom', 'dog']: 140 | e_dict = {e: 1/16*4**e for e in range(5)} 141 | elif scene in ['sofa']: 142 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16} 143 | elif scene in ['sponza']: 144 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32} 145 | elif scene in ['box']: 146 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05} 147 | elif scene in ['computer']: 148 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60} 149 | elif scene in ['flower']: 150 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45} 151 | elif scene in ['luckycat']: 152 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125} 153 | e = int(img_path.split('.')[0][-1]) 154 | buf += [e_dict[e]*torch.ones_like(img[:, :1])] 155 | 156 | self.rays += [torch.cat(buf, 1)] 157 | 158 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 159 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) -------------------------------------------------------------------------------- /datasets/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from kornia import create_meshgrid 4 | from einops import rearrange 5 | 6 | 7 | @torch.cuda.amp.autocast(dtype=torch.float32) 8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True): 9 | """ 10 | Get ray directions for all pixels in camera coordinate [right down front]. 11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 12 | ray-tracing-generating-camera-rays/standard-coordinate-systems 13 | 14 | Inputs: 15 | H, W: image height and width 16 | K: (3, 3) camera intrinsics 17 | random: whether the ray passes randomly inside the pixel 18 | return_uv: whether to return uv image coordinates 19 | 20 | Outputs: (shape depends on @flatten) 21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate 22 | uv: (H, W, 2) or (H*W, 2) image coordinates 23 | """ 24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2) 25 | u, v = grid.unbind(-1) 26 | 27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 28 | if random: 29 | directions = \ 30 | torch.stack([(u-cx+torch.rand_like(u))/fx, 31 | (v-cy+torch.rand_like(v))/fy, 32 | torch.ones_like(u)], -1) 33 | else: # pass by the center 34 | directions = \ 35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1) 36 | if flatten: 37 | directions = directions.reshape(-1, 3) 38 | grid = grid.reshape(-1, 2) 39 | 40 | if return_uv: 41 | return directions, grid 42 | return directions 43 | 44 | 45 | @torch.cuda.amp.autocast(dtype=torch.float32) 46 | def get_rays(directions, c2w): 47 | """ 48 | Get ray origin and directions in world coordinate for all pixels in one image. 
49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 50 | ray-tracing-generating-camera-rays/standard-coordinate-systems 51 | 52 | Inputs: 53 | directions: (N, 3) ray directions in camera coordinate 54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate 55 | 56 | Outputs: 57 | rays_o: (N, 3), the origin of the rays in world coordinate 58 | rays_d: (N, 3), the direction of the rays in world coordinate 59 | """ 60 | if c2w.ndim==2: 61 | # Rotate ray directions from camera coordinate to the world coordinate 62 | rays_d = directions @ c2w[:, :3].T 63 | else: 64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ \ 65 | rearrange(c2w[..., :3], 'n a b -> n b a') 66 | rays_d = rearrange(rays_d, 'n 1 c -> n c') 67 | # The origin of all rays is the camera origin in world coordinate 68 | rays_o = c2w[..., 3].expand_as(rays_d) 69 | 70 | return rays_o, rays_d 71 | 72 | 73 | @torch.cuda.amp.autocast(dtype=torch.float32) 74 | def axisangle_to_R(v): 75 | """ 76 | Convert an axis-angle vector to rotation matrix 77 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47 78 | 79 | Inputs: 80 | v: (3) or (B, 3) 81 | 82 | Outputs: 83 | R: (3, 3) or (B, 3, 3) 84 | """ 85 | v_ndim = v.ndim 86 | if v_ndim==1: 87 | v = rearrange(v, 'c -> 1 c') 88 | zero = torch.zeros_like(v[:, :1]) # (B, 1) 89 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3) 90 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1) 91 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1) 92 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3) 93 | 94 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1') 95 | eye = torch.eye(3, device=v.device) 96 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \ 97 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v) 98 | if v_ndim==1: 99 | R = rearrange(R, '1 c d -> c d') 100 | return R 101 | 102 | 103 | def normalize(v): 104 | """Normalize a vector.""" 105 | return v/np.linalg.norm(v) 106 | 107 | 108 | def average_poses(poses, pts3d=None): 109 | """ 110 | Calculate the average pose, which is then used to center all poses 111 | using @center_poses. Its computation is as follows: 112 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras). 113 | 2. Compute the z axis: the normalized average z axis. 114 | 3. Compute axis y': the average y axis. 115 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 116 | 5. Compute the y axis: z cross product x. 117 | 118 | Note that at step 3, we cannot directly use y' as y axis since it's 119 | not necessarily orthogonal to z axis. We need to pass from x to y. 120 | Inputs: 121 | poses: (N_images, 3, 4) 122 | pts3d: (N, 3) 123 | 124 | Outputs: 125 | pose_avg: (3, 4) the average pose 126 | """ 127 | # 1. Compute the center 128 | if pts3d is not None: 129 | center = pts3d.mean(0) 130 | else: 131 | center = poses[..., 3].mean(0) 132 | 133 | # 2. Compute the z axis 134 | z = normalize(poses[..., 2].mean(0)) # (3) 135 | 136 | # 3. Compute axis y' (no need to normalize as it's not the final output) 137 | y_ = poses[..., 1].mean(0) # (3) 138 | 139 | # 4. Compute the x axis 140 | x = normalize(np.cross(y_, z)) # (3) 141 | 142 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 143 | y = np.cross(z, x) # (3) 144 | 145 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 146 | 147 | return pose_avg 148 | 149 | 150 | def center_poses(poses, pts3d=None): 151 | """ 152 | See https://github.com/bmild/nerf/issues/34 153 | Inputs: 154 | poses: (N_images, 3, 4) 155 | pts3d: (N, 3) reconstructed point cloud 156 | 157 | Outputs: 158 | poses_centered: (N_images, 3, 4) the centered poses 159 | pts3d_centered: (N, 3) centered point cloud 160 | """ 161 | 162 | pose_avg = average_poses(poses, pts3d) # (3, 4) 163 | pose_avg_homo = np.eye(4) 164 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 165 | # by simply adding 0, 0, 0, 1 as the last row 166 | pose_avg_inv = np.linalg.inv(pose_avg_homo) 167 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 168 | poses_homo = \ 169 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 170 | 171 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4) 172 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 173 | 174 | if pts3d is not None: 175 | pts3d_centered = pts3d @ pose_avg_inv[:, :3].T + pose_avg_inv[:, 3:].T 176 | return poses_centered, pts3d_centered 177 | 178 | return poses_centered 179 | 180 | def create_spheric_poses(radius, mean_h, n_poses=120): 181 | """ 182 | Create circular poses around z axis. 183 | Inputs: 184 | radius: the (negative) height and the radius of the circle. 185 | mean_h: mean camera height 186 | Outputs: 187 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 188 | """ 189 | def spheric_pose(theta, phi, radius): 190 | trans_t = lambda t : np.array([ 191 | [1,0,0,0], 192 | [0,1,0,2*mean_h], 193 | [0,0,1,-t] 194 | ]) 195 | 196 | rot_phi = lambda phi : np.array([ 197 | [1,0,0], 198 | [0,np.cos(phi),-np.sin(phi)], 199 | [0,np.sin(phi), np.cos(phi)] 200 | ]) 201 | 202 | rot_theta = lambda th : np.array([ 203 | [np.cos(th),0,-np.sin(th)], 204 | [0,1,0], 205 | [np.sin(th),0, np.cos(th)] 206 | ]) 207 | 208 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 209 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w 210 | return c2w 211 | 212 | spheric_poses = [] 213 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: 214 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)] 215 | return np.stack(spheric_poses, 0) -------------------------------------------------------------------------------- /models/csrc/intersection.cu: -------------------------------------------------------------------------------- 1 | #include "helper_math.h" 2 | #include "utils.h" 3 | 4 | 5 | __device__ __forceinline__ float2 _ray_aabb_intersect( 6 | const float3 ray_o, 7 | const float3 inv_d, 8 | const float3 center, 9 | const float3 half_size 10 | ){ 11 | 12 | const float3 t_min = (center-half_size-ray_o)*inv_d; 13 | const float3 t_max = (center+half_size-ray_o)*inv_d; 14 | 15 | const float3 _t1 = fminf(t_min, t_max); 16 | const float3 _t2 = fmaxf(t_min, t_max); 17 | const float t1 = fmaxf(fmaxf(_t1.x, _t1.y), _t1.z); 18 | const float t2 = fminf(fminf(_t2.x, _t2.y), _t2.z); 19 | 20 | if (t1 > t2) return make_float2(-1.0f); // no intersection 21 | return make_float2(t1, t2); 22 | } 23 | 24 | 25 | __global__ void ray_aabb_intersect_kernel( 26 | const torch::PackedTensorAccessor32 rays_o, 27 | const torch::PackedTensorAccessor32 rays_d, 28 | const torch::PackedTensorAccessor32 centers, 29 | const torch::PackedTensorAccessor32 half_sizes, 30 | 
const int max_hits, 31 | int* __restrict__ hit_cnt, 32 | torch::PackedTensorAccessor32 hits_t, 33 | torch::PackedTensorAccessor64 hits_voxel_idx 34 | ){ 35 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 36 | const int v = blockIdx.y * blockDim.y + threadIdx.y; 37 | 38 | if (v>=centers.size(0) || r>=rays_o.size(0)) return; 39 | 40 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 41 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 42 | const float3 inv_d = 1.0f/ray_d; 43 | 44 | const float3 center = make_float3(centers[v][0], centers[v][1], centers[v][2]); 45 | const float3 half_size = make_float3(half_sizes[v][0], half_sizes[v][1], half_sizes[v][2]); 46 | const float2 t1t2 = _ray_aabb_intersect(ray_o, inv_d, center, half_size); 47 | 48 | if (t1t2.y > 0){ // if ray hits the voxel 49 | const int cnt = atomicAdd(&hit_cnt[r], 1); 50 | if (cnt < max_hits){ 51 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 52 | hits_t[r][cnt][1] = t1t2.y; 53 | hits_voxel_idx[r][cnt] = v; 54 | } 55 | } 56 | } 57 | 58 | 59 | std::vector ray_aabb_intersect_cu( 60 | const torch::Tensor rays_o, 61 | const torch::Tensor rays_d, 62 | const torch::Tensor centers, 63 | const torch::Tensor half_sizes, 64 | const int max_hits 65 | ){ 66 | 67 | const int N_rays = rays_o.size(0), N_voxels = centers.size(0); 68 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 69 | auto hits_voxel_idx = 70 | torch::zeros({N_rays, max_hits}, 71 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 72 | auto hit_cnt = 73 | torch::zeros({N_rays}, 74 | torch::dtype(torch::kInt32).device(rays_o.device())); 75 | 76 | const dim3 threads(256, 1); 77 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 78 | (N_voxels+threads.y-1)/threads.y); 79 | 80 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_aabb_intersect_cu", 81 | ([&] { 82 | ray_aabb_intersect_kernel<<>>( 83 | rays_o.packed_accessor32(), 84 | rays_d.packed_accessor32(), 85 | centers.packed_accessor32(), 86 | half_sizes.packed_accessor32(), 87 | max_hits, 88 | hit_cnt.data_ptr(), 89 | hits_t.packed_accessor32(), 90 | hits_voxel_idx.packed_accessor64() 91 | ); 92 | })); 93 | 94 | // sort intersections from near to far based on t1 95 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 96 | hits_voxel_idx = torch::gather(hits_voxel_idx, 1, hits_order); 97 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 98 | 99 | return {hit_cnt, hits_t, hits_voxel_idx}; 100 | } 101 | 102 | 103 | __device__ __forceinline__ float2 _ray_sphere_intersect( 104 | const float3 ray_o, 105 | const float3 ray_d, 106 | const float3 center, 107 | const float radius 108 | ){ 109 | const float3 co = ray_o-center; 110 | 111 | const float a = dot(ray_d, ray_d); 112 | const float half_b = dot(ray_d, co); 113 | const float c = dot(co, co)-radius*radius; 114 | 115 | const float discriminant = half_b*half_b-a*c; 116 | 117 | if (discriminant < 0) return make_float2(-1.0f); // no intersection 118 | 119 | const float disc_sqrt = sqrtf(discriminant); 120 | return make_float2(-half_b-disc_sqrt, -half_b+disc_sqrt)/a; 121 | } 122 | 123 | 124 | __global__ void ray_sphere_intersect_kernel( 125 | const torch::PackedTensorAccessor32 rays_o, 126 | const torch::PackedTensorAccessor32 rays_d, 127 | const torch::PackedTensorAccessor32 centers, 128 | const torch::PackedTensorAccessor32 radii, 129 | const int max_hits, 130 | int* __restrict__ hit_cnt, 131 | torch::PackedTensorAccessor32 hits_t, 132 | 
torch::PackedTensorAccessor64 hits_sphere_idx 133 | ){ 134 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 135 | const int s = blockIdx.y * blockDim.y + threadIdx.y; 136 | 137 | if (s>=centers.size(0) || r>=rays_o.size(0)) return; 138 | 139 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 140 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 141 | const float3 center = make_float3(centers[s][0], centers[s][1], centers[s][2]); 142 | 143 | const float2 t1t2 = _ray_sphere_intersect(ray_o, ray_d, center, radii[s]); 144 | 145 | if (t1t2.y > 0){ // if ray hits the sphere 146 | const int cnt = atomicAdd(&hit_cnt[r], 1); 147 | if (cnt < max_hits){ 148 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 149 | hits_t[r][cnt][1] = t1t2.y; 150 | hits_sphere_idx[r][cnt] = s; 151 | } 152 | } 153 | } 154 | 155 | 156 | std::vector ray_sphere_intersect_cu( 157 | const torch::Tensor rays_o, 158 | const torch::Tensor rays_d, 159 | const torch::Tensor centers, 160 | const torch::Tensor radii, 161 | const int max_hits 162 | ){ 163 | 164 | const int N_rays = rays_o.size(0), N_spheres = centers.size(0); 165 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 166 | auto hits_sphere_idx = 167 | torch::zeros({N_rays, max_hits}, 168 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 169 | auto hit_cnt = 170 | torch::zeros({N_rays}, 171 | torch::dtype(torch::kInt32).device(rays_o.device())); 172 | 173 | const dim3 threads(256, 1); 174 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 175 | (N_spheres+threads.y-1)/threads.y); 176 | 177 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_sphere_intersect_cu", 178 | ([&] { 179 | ray_sphere_intersect_kernel<<>>( 180 | rays_o.packed_accessor32(), 181 | rays_d.packed_accessor32(), 182 | centers.packed_accessor32(), 183 | radii.packed_accessor32(), 184 | max_hits, 185 | hit_cnt.data_ptr(), 186 | hits_t.packed_accessor32(), 187 | hits_sphere_idx.packed_accessor64() 188 | ); 189 | })); 190 | 191 | // sort intersections from near to far based on t1 192 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 193 | hits_sphere_idx = torch::gather(hits_sphere_idx, 1, hits_order); 194 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 195 | 196 | return {hit_cnt, hits_t, hits_sphere_idx}; 197 | } -------------------------------------------------------------------------------- /models/nerfusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import tinycudann as tcnn 4 | from .custom_functions import TruncExp 5 | import vren 6 | from einops import rearrange 7 | from kornia.utils.grid import create_meshgrid3d 8 | from .rendering import NEAR_DISTANCE 9 | 10 | class NeRFusion2(nn.Module): 11 | def __init__(self, scale, grid_size=128, global_representation=None): 12 | super().__init__() 13 | 14 | # scene bounding box 15 | # TODO: this is a temp easy solution 16 | self.scale = scale 17 | self.register_buffer('center', torch.zeros(1, 3)) 18 | self.register_buffer('xyz_min', -torch.ones(1, 3)*scale) 19 | self.register_buffer('xyz_max', torch.ones(1, 3)*scale) 20 | self.register_buffer('half_size', (self.xyz_max-self.xyz_min)/2) 21 | 22 | self.grid_size = grid_size 23 | 24 | self.cascades = 1 25 | self.register_buffer('density_bitfield', 26 | torch.ones(self.grid_size**3//8, dtype=torch.uint8)) # dummy 27 | self.register_buffer('density_grid', 28 | torch.zeros(self.cascades, 
self.grid_size**3)) 29 | self.register_buffer('grid_coords', 30 | create_meshgrid3d(self.grid_size, self.grid_size, self.grid_size, False, dtype=torch.int32).reshape(-1, 3)) 31 | 32 | self.global_representation = global_representation 33 | if global_representation is not None: 34 | self.initialize_global_volume(global_representation) 35 | self.xyz_encoder = \ 36 | tcnn.Network( 37 | n_input_dims=16, n_output_dims=16, 38 | network_config={ 39 | "otype": "FullyFusedMLP", 40 | "activation": "ReLU", 41 | "output_activation": "None", 42 | "n_neurons": 64, 43 | "n_hidden_layers": 1, 44 | } 45 | ) 46 | else: 47 | self.xyz_encoder = \ 48 | tcnn.NetworkWithInputEncoding( 49 | n_input_dims=3, n_output_dims=16, 50 | encoding_config={ 51 | "otype": "Grid", 52 | "type": "Dense", 53 | "n_levels": 3, 54 | "n_feature_per_level": 2, 55 | "base_resolution": 128, 56 | "per_level_scale": 2.0, 57 | "interpolation": "Linear", 58 | }, 59 | network_config={ 60 | "otype": "FullyFusedMLP", 61 | "activation": "ReLU", 62 | "output_activation": "None", 63 | "n_neurons": 64, 64 | "n_hidden_layers": 1, 65 | } 66 | ) 67 | 68 | self.dir_encoder = \ 69 | tcnn.Encoding( 70 | n_input_dims=3, 71 | encoding_config={ 72 | "otype": "SphericalHarmonics", 73 | "degree": 4, 74 | }, 75 | ) 76 | 77 | self.rgb_net = \ 78 | tcnn.Network( 79 | n_input_dims=32, n_output_dims=3, 80 | network_config={ 81 | "otype": "FullyFusedMLP", 82 | "activation": "ReLU", 83 | "output_activation": "Sigmoid", 84 | "n_neurons": 64, 85 | "n_hidden_layers": 2, 86 | } 87 | ) 88 | 89 | def density(self, x, return_feat=False): 90 | """ 91 | Inputs: 92 | x: (N, 3) xyz in [-scale, scale] 93 | return_feat: whether to return intermediate feature 94 | 95 | Outputs: 96 | sigmas: (N) 97 | """ 98 | x = (x-self.xyz_min)/(self.xyz_max-self.xyz_min) 99 | h = self.xyz_encoder(x) 100 | sigmas = TruncExp.apply(h[:, 0]) 101 | if return_feat: return sigmas, h 102 | return sigmas 103 | 104 | def forward(self, x, d, **kwargs): 105 | """ 106 | Inputs: 107 | x: (N, 3) xyz in [-scale, scale] 108 | d: (N, 3) directions 109 | 110 | Outputs: 111 | sigmas: (N) 112 | rgbs: (N, 3) 113 | """ 114 | if self.global_representation is not None: 115 | x = self.get_global_feature(x) 116 | sigmas, h = self.density(x, return_feat=True) 117 | d = d/torch.norm(d, dim=1, keepdim=True) 118 | d = self.dir_encoder((d+1)/2) 119 | rgbs = self.rgb_net(torch.cat([d, h], 1)) 120 | 121 | return sigmas, rgbs 122 | 123 | @torch.no_grad() 124 | def get_all_cells(self): 125 | """ 126 | Get all cells from the density grid. 
127 | 128 | Outputs: 129 | cells: list (of length self.cascades) of indices and coords 130 | selected at each cascade 131 | """ 132 | indices = vren.morton3D(self.grid_coords).long() 133 | cells = [(indices, self.grid_coords)] * self.cascades 134 | 135 | return cells 136 | 137 | @torch.no_grad() 138 | def sample_uniform_and_occupied_cells(self, M, density_threshold): 139 | """ 140 | Sample both M uniform and occupied cells (per cascade) 141 | occupied cells are sample from cells with density > @density_threshold 142 | 143 | Outputs: 144 | cells: list (of length self.cascades) of indices and coords 145 | selected at each cascade 146 | """ 147 | cells = [] 148 | for c in range(self.cascades): 149 | # uniform cells 150 | coords1 = torch.randint(self.grid_size, (M, 3), dtype=torch.int32, 151 | device=self.density_grid.device) 152 | indices1 = vren.morton3D(coords1).long() 153 | # occupied cells 154 | indices2 = torch.nonzero(self.density_grid[c] > density_threshold)[:, 0] 155 | if len(indices2) > 0: 156 | rand_idx = torch.randint(len(indices2), (M,), 157 | device=self.density_grid.device) 158 | indices2 = indices2[rand_idx] 159 | coords2 = vren.morton3D_invert(indices2.int()) 160 | # concatenate 161 | cells += [(torch.cat([indices1, indices2]), torch.cat([coords1, coords2]))] 162 | 163 | return cells 164 | 165 | @torch.no_grad() 166 | def prune_cells(self, K, poses, img_wh, chunk=64 ** 3): 167 | """ 168 | mark the cells that aren't covered by the cameras with density -1 169 | only executed once before training starts 170 | 171 | Inputs: 172 | K: (3, 3) camera intrinsics 173 | poses: (N, 3, 4) camera to world poses 174 | img_wh: image width and height 175 | chunk: the chunk size to split the cells (to avoid OOM) 176 | """ 177 | N_cams = poses.shape[0] 178 | self.count_grid = torch.zeros_like(self.density_grid) 179 | w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3) 180 | w2c_T = -w2c_R @ poses[:, :3, 3:] # (N_cams, 3, 1) 181 | cells = self.get_all_cells() 182 | for c in range(self.cascades): 183 | indices, coords = cells[c] 184 | for i in range(0, len(indices), chunk): 185 | xyzs = coords[i:i + chunk] / (self.grid_size - 1) * 2 - 1 186 | s = min(2 ** (c - 1), self.scale) 187 | half_grid_size = s / self.grid_size 188 | xyzs_w = (xyzs * (s - half_grid_size)).T # (3, chunk) 189 | xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk) 190 | uvd = K @ xyzs_c # (N_cams, 3, chunk) 191 | uv = uvd[:, :2] / uvd[:, 2:] # (N_cams, 2, chunk) 192 | in_image = (uvd[:, 2] >= 0) & \ 193 | (uv[:, 0] >= 0) & (uv[:, 0] < img_wh[0]) & \ 194 | (uv[:, 1] >= 0) & (uv[:, 1] < img_wh[1]) 195 | covered_by_cam = (uvd[:, 2] >= NEAR_DISTANCE) & in_image # (N_cams, chunk) 196 | # if the cell is visible by at least one camera 197 | self.count_grid[c, indices[i:i + chunk]] = \ 198 | count = covered_by_cam.sum(0) / N_cams 199 | 200 | too_near_to_cam = (uvd[:, 2] < NEAR_DISTANCE) & in_image # (N, chunk) 201 | # if the cell is too close (in front) to any camera 202 | too_near_to_any_cam = too_near_to_cam.any(0) 203 | # a valid cell should be visible by at least one camera and not too close to any camera 204 | valid_mask = (count > 0) & (~too_near_to_any_cam) 205 | self.density_grid[c, indices[i:i + chunk]] = \ 206 | torch.where(valid_mask, 0., -1.) 
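        # NOTE: `update_density_grid` below keeps the occupancy grid in sync with the current
        # density field. During warmup it re-evaluates the density at every cell; afterwards it
        # samples a mix of uniform and currently-occupied cells, jitters each sampled position
        # inside its cell, decays stale values, and finally packs the thresholded grid into
        # `density_bitfield` via vren.packbits so the ray marcher can skip empty space.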
207 | 208 | @torch.no_grad() 209 | def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False): 210 | density_grid_tmp = torch.zeros_like(self.density_grid) 211 | if warmup: # during the first steps 212 | cells = self.get_all_cells() 213 | else: 214 | cells = self.sample_uniform_and_occupied_cells(self.grid_size ** 3 // 4, 215 | density_threshold) 216 | # infer sigmas 217 | for c in range(self.cascades): 218 | indices, coords = cells[c] 219 | s = min(2 ** (c - 1), self.scale) 220 | half_grid_size = s / self.grid_size 221 | xyzs_w = (coords / (self.grid_size - 1) * 2 - 1) * (s - half_grid_size) 222 | # pick random position in the cell by adding noise in [-hgs, hgs] 223 | xyzs_w += (torch.rand_like(xyzs_w) * 2 - 1) * half_grid_size 224 | density_grid_tmp[c, indices] = self.density(xyzs_w) 225 | 226 | if erode: 227 | # My own logic. decay more the cells that are visible to few cameras 228 | decay = torch.clamp(decay ** (1 / self.count_grid), 0.1, 0.95) 229 | self.density_grid = \ 230 | torch.where(self.density_grid < 0, 231 | self.density_grid, 232 | torch.maximum(self.density_grid * decay, density_grid_tmp)) 233 | 234 | mean_density = self.density_grid[self.density_grid > 0].mean().item() 235 | 236 | vren.packbits(self.density_grid, min(mean_density, density_threshold), 237 | self.density_bitfield) 238 | 239 | 240 | -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import math 17 | from PIL import Image 18 | import torchvision.transforms as transforms 19 | import torch 20 | from scipy.spatial.transform import Rotation as R 21 | import cv2 22 | 23 | rng = np.random.RandomState(234) 24 | _EPS = np.finfo(float).eps * 4.0 25 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 26 | 27 | 28 | def vector_norm(data, axis=None, out=None): 29 | """Return length, i.e. eucledian norm, of ndarray along axis. 30 | """ 31 | data = np.array(data, dtype=np.float64, copy=True) 32 | if out is None: 33 | if data.ndim == 1: 34 | return math.sqrt(np.dot(data, data)) 35 | data *= data 36 | out = np.atleast_1d(np.sum(data, axis=axis)) 37 | np.sqrt(out, out) 38 | return out 39 | else: 40 | data *= data 41 | np.sum(data, axis=axis, out=out) 42 | np.sqrt(out, out) 43 | 44 | 45 | def quaternion_about_axis(angle, axis): 46 | """Return quaternion for rotation about axis. 47 | """ 48 | quaternion = np.zeros((4, ), dtype=np.float64) 49 | quaternion[:3] = axis[:3] 50 | qlen = vector_norm(quaternion) 51 | if qlen > _EPS: 52 | quaternion *= math.sin(angle/2.0) / qlen 53 | quaternion[3] = math.cos(angle/2.0) 54 | return quaternion 55 | 56 | 57 | def quaternion_matrix(quaternion): 58 | """Return homogeneous rotation matrix from quaternion. 
59 | """ 60 | q = np.array(quaternion[:4], dtype=np.float64, copy=True) 61 | nq = np.dot(q, q) 62 | if nq < _EPS: 63 | return np.identity(4) 64 | q *= math.sqrt(2.0 / nq) 65 | q = np.outer(q, q) 66 | return np.array(( 67 | (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0), 68 | ( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0), 69 | ( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0), 70 | ( 0.0, 0.0, 0.0, 1.0) 71 | ), dtype=np.float64) 72 | 73 | 74 | def rectify_inplane_rotation(src_pose, tar_pose, src_img, th=40): 75 | relative = np.linalg.inv(tar_pose).dot(src_pose) 76 | relative_rot = relative[:3, :3] 77 | r = R.from_matrix(relative_rot) 78 | euler = r.as_euler('zxy', degrees=True) 79 | euler_z = euler[0] 80 | if np.abs(euler_z) < th: 81 | return src_pose, src_img 82 | 83 | R_rectify = R.from_euler('z', -euler_z, degrees=True).as_matrix() 84 | src_R_rectified = src_pose[:3, :3].dot(R_rectify) 85 | out_pose = np.eye(4) 86 | out_pose[:3, :3] = src_R_rectified 87 | out_pose[:3, 3:4] = src_pose[:3, 3:4] 88 | h, w = src_img.shape[:2] 89 | center = ((w - 1.) / 2., (h - 1.) / 2.) 90 | M = cv2.getRotationMatrix2D(center, -euler_z, 1) 91 | src_img = np.clip((255*src_img).astype(np.uint8), a_max=255, a_min=0) 92 | rotated = cv2.warpAffine(src_img, M, (w, h), borderValue=(255, 255, 255), flags=cv2.INTER_LANCZOS4) 93 | rotated = rotated.astype(np.float32) / 255. 94 | return out_pose, rotated 95 | 96 | 97 | def random_crop(rgb, camera, src_rgbs, src_cameras, size=(400, 600), center=None): 98 | h, w = rgb.shape[:2] 99 | out_h, out_w = size[0], size[1] 100 | if out_w >= w or out_h >= h: 101 | return rgb, camera, src_rgbs, src_cameras 102 | 103 | if center is not None: 104 | center_h, center_w = center 105 | else: 106 | center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1) 107 | center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1) 108 | 109 | rgb_out = rgb[center_h - out_h // 2:center_h + out_h // 2, center_w - out_w // 2:center_w + out_w // 2, :] 110 | src_rgbs = np.array(src_rgbs) 111 | src_rgbs = src_rgbs[:, center_h - out_h // 2:center_h + out_h // 2, 112 | center_w - out_w // 2:center_w + out_w // 2, :] 113 | camera[0] = out_h 114 | camera[1] = out_w 115 | camera[4] -= center_w - out_w // 2 116 | camera[8] -= center_h - out_h // 2 117 | src_cameras[:, 4] -= center_w - out_w // 2 118 | src_cameras[:, 8] -= center_h - out_h // 2 119 | src_cameras[:, 0] = out_h 120 | src_cameras[:, 1] = out_w 121 | return rgb_out, camera, src_rgbs, src_cameras 122 | 123 | 124 | def random_flip(rgb, camera, src_rgbs, src_cameras): 125 | h, w = rgb.shape[:2] 126 | h_r, w_r = src_rgbs.shape[1:3] 127 | rgb_out = np.flip(rgb, axis=1).copy() 128 | src_rgbs = np.flip(src_rgbs, axis=-2).copy() 129 | camera[2] *= -1 130 | camera[4] = w - 1. - camera[4] 131 | src_cameras[:, 2] *= -1 132 | src_cameras[:, 4] = w_r - 1. 
- src_cameras[:, 4] 133 | return rgb_out, camera, src_rgbs, src_cameras 134 | 135 | 136 | def get_color_jitter_params(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2): 137 | color_jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 138 | transform = transforms.ColorJitter.get_params(color_jitter.brightness, 139 | color_jitter.contrast, 140 | color_jitter.saturation, 141 | color_jitter.hue) 142 | return transform 143 | 144 | 145 | def color_jitter(img, transform): 146 | ''' 147 | Args: 148 | img: np.float32 [h, w, 3] 149 | transform: 150 | Returns: transformed np.float32 151 | ''' 152 | img = Image.fromarray((255.*img).astype(np.uint8)) 153 | img_trans = transform(img) 154 | img_trans = np.array(img_trans).astype(np.float32) / 255. 155 | return img_trans 156 | 157 | 158 | def color_jitter_all_rgbs(rgb, ref_rgbs, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2): 159 | transform = get_color_jitter_params(brightness, contrast, saturation, hue) 160 | rgb_trans = color_jitter(rgb, transform) 161 | ref_rgbs_trans = [] 162 | for ref_rgb in ref_rgbs: 163 | ref_rgbs_trans.append(color_jitter(ref_rgb, transform)) 164 | 165 | ref_rgbs_trans = np.array(ref_rgbs_trans) 166 | return rgb_trans, ref_rgbs_trans 167 | 168 | 169 | def deepvoxels_parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 170 | # Get camera intrinsics 171 | with open(filepath, 'r') as file: 172 | f, cx, cy = list(map(float, file.readline().split()))[:3] 173 | grid_barycenter = torch.Tensor(list(map(float, file.readline().split()))) 174 | near_plane = float(file.readline()) 175 | scale = float(file.readline()) 176 | height, width = map(float, file.readline().split()) 177 | 178 | try: 179 | world2cam_poses = int(file.readline()) 180 | except ValueError: 181 | world2cam_poses = None 182 | 183 | if world2cam_poses is None: 184 | world2cam_poses = False 185 | 186 | world2cam_poses = bool(world2cam_poses) 187 | 188 | cx = cx / width * trgt_sidelength 189 | cy = cy / height * trgt_sidelength 190 | f = trgt_sidelength / height * f 191 | 192 | fx = f 193 | if invert_y: 194 | fy = -f 195 | else: 196 | fy = f 197 | 198 | # Build the intrinsic matrices 199 | full_intrinsic = np.array([[fx, 0., cx, 0.], 200 | [0., fy, cy, 0], 201 | [0., 0, 1, 0], 202 | [0, 0, 0, 1]]) 203 | 204 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 205 | 206 | 207 | def angular_dist_between_2_vectors(vec1, vec2): 208 | vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER) 209 | vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER) 210 | angular_dists = np.arccos(np.clip(np.sum(vec1_unit*vec2_unit, axis=-1), -1.0, 1.0)) 211 | return angular_dists 212 | 213 | 214 | def batched_angular_dist_rot_matrix(R1, R2): 215 | ''' 216 | calculate the angular distance between two rotation matrices (batched) 217 | :param R1: the first rotation matrix [N, 3, 3] 218 | :param R2: the second rotation matrix [N, 3, 3] 219 | :return: angular distance in radiance [N, ] 220 | ''' 221 | assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3 222 | return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2., 223 | a_min=-1 + TINY_NUMBER, a_max=1 - TINY_NUMBER)) 224 | 225 | 226 | def get_nearest_pose_ids(tar_pose, ref_poses, num_select, tar_id=-1, angular_dist_method='vector', 227 | scene_center=(0, 0, 0)): 228 | ''' 229 | Args: 230 | tar_pose: target pose [3, 3] 231 | 
ref_poses: reference poses [N, 3, 3] 232 | num_select: the number of nearest views to select 233 | Returns: the selected indices 234 | ''' 235 | num_cams = len(ref_poses) 236 | num_select = min(num_select, num_cams-1) 237 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) 238 | 239 | if angular_dist_method == 'matrix': 240 | dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]) 241 | elif angular_dist_method == 'vector': 242 | tar_cam_locs = batched_tar_pose[:, :3, 3] 243 | ref_cam_locs = ref_poses[:, :3, 3] 244 | scene_center = np.array(scene_center)[None, ...] 245 | tar_vectors = tar_cam_locs - scene_center 246 | ref_vectors = ref_cam_locs - scene_center 247 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors) 248 | elif angular_dist_method == 'dist': 249 | tar_cam_locs = batched_tar_pose[:, :3, 3] 250 | ref_cam_locs = ref_poses[:, :3, 3] 251 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1) 252 | else: 253 | raise Exception('unknown angular distance calculation method!') 254 | 255 | if tar_id >= 0: 256 | assert tar_id < num_cams 257 | dists[tar_id] = 1e3 # make sure not to select the target id itself 258 | 259 | sorted_ids = np.argsort(dists) 260 | selected_ids = sorted_ids[:num_select] 261 | # print(angular_dists[selected_ids] * 180 / np.pi) 262 | return selected_ids -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from opt import get_opts 4 | import os 5 | import glob 6 | import imageio 7 | import numpy as np 8 | import cv2 9 | from einops import rearrange 10 | 11 | # data 12 | from torch.utils.data import DataLoader 13 | from datasets import dataset_dict 14 | from datasets.ray_utils import axisangle_to_R, get_rays 15 | 16 | # models 17 | from models.nerfusion import NeRFusion2 18 | from models.rendering import render, MAX_SAMPLES 19 | 20 | # optimizer, losses 21 | from apex.optimizers import FusedAdam 22 | from torch.optim.lr_scheduler import CosineAnnealingLR 23 | from losses import NeRFLoss 24 | 25 | # metrics 26 | from torchmetrics import ( 27 | PeakSignalNoiseRatio, 28 | StructuralSimilarityIndexMeasure 29 | ) 30 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 31 | 32 | # pytorch-lightning 33 | from pytorch_lightning.plugins import DDPPlugin 34 | from pytorch_lightning import LightningModule, Trainer 35 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint 36 | from pytorch_lightning.loggers import TensorBoardLogger 37 | from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available 38 | 39 | from utils import slim_ckpt, load_ckpt 40 | 41 | import warnings; warnings.filterwarnings("ignore") 42 | 43 | 44 | def depth2img(depth): 45 | depth = (depth-depth.min())/(depth.max()-depth.min()) 46 | depth_img = cv2.applyColorMap((depth*255).astype(np.uint8), 47 | cv2.COLORMAP_TURBO) 48 | 49 | return depth_img 50 | 51 | 52 | class NeRFSystem(LightningModule): 53 | def __init__(self, hparams): 54 | super().__init__() 55 | self.save_hyperparameters(hparams) 56 | 57 | self.warmup_steps = 256 58 | self.update_interval = 16 59 | 60 | self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w) 61 | self.train_psnr = PeakSignalNoiseRatio(data_range=1) 62 | self.val_psnr = PeakSignalNoiseRatio(data_range=1) 63 | self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1) 64 | if 
self.hparams.eval_lpips: 65 | self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg') 66 | for p in self.val_lpips.net.parameters(): 67 | p.requires_grad = False 68 | 69 | self.model = NeRFusion2(scale=self.hparams.scale) 70 | 71 | def forward(self, batch, split): 72 | if split=='train': 73 | poses = self.poses[batch['img_idxs']] 74 | directions = self.directions[batch['pix_idxs']] 75 | else: 76 | poses = batch['pose'] 77 | directions = self.directions 78 | 79 | if self.hparams.optimize_ext: 80 | dR = axisangle_to_R(self.dR[batch['img_idxs']]) 81 | poses[..., :3] = dR @ poses[..., :3] 82 | poses[..., 3] += self.dT[batch['img_idxs']] 83 | 84 | rays_o, rays_d = get_rays(directions, poses) 85 | 86 | kwargs = {'test_time': split!='train', 87 | 'random_bg': self.hparams.random_bg} 88 | if self.hparams.scale > 0.5: 89 | kwargs['exp_step_factor'] = 1/256 90 | 91 | return render(self.model, rays_o, rays_d, **kwargs) 92 | 93 | def setup(self, stage): 94 | dataset = dataset_dict[self.hparams.dataset_name] 95 | kwargs = {'root_dir': self.hparams.root_dir, 96 | 'downsample': self.hparams.downsample} 97 | self.train_dataset = dataset(split=self.hparams.split, **kwargs) 98 | self.train_dataset.batch_size = self.hparams.batch_size 99 | self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy 100 | 101 | self.test_dataset = dataset(split='test', **kwargs) 102 | 103 | # define additional parameters 104 | self.register_buffer('directions', self.train_dataset.directions.to(self.device)) 105 | self.register_buffer('poses', self.train_dataset.poses.to(self.device)) 106 | 107 | if self.hparams.optimize_ext: 108 | N = len(self.train_dataset.poses) 109 | self.register_parameter('dR', 110 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 111 | self.register_parameter('dT', 112 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 113 | 114 | def configure_optimizers(self): 115 | load_ckpt(self.model, self.hparams.weight_path) 116 | 117 | net_params = [] 118 | for n, p in self.named_parameters(): 119 | if n not in ['dR', 'dT']: net_params += [p] 120 | 121 | opts = [] 122 | self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15) 123 | opts += [self.net_opt] 124 | if self.hparams.optimize_ext: 125 | opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded 126 | net_sch = CosineAnnealingLR(self.net_opt, 127 | self.hparams.num_epochs, 128 | self.hparams.lr/30) 129 | 130 | return opts, [net_sch] 131 | 132 | def train_dataloader(self): 133 | return DataLoader(self.train_dataset, 134 | num_workers=16, 135 | persistent_workers=True, 136 | batch_size=None, 137 | pin_memory=True) 138 | 139 | def val_dataloader(self): 140 | return DataLoader(self.test_dataset, 141 | num_workers=8, 142 | batch_size=None, 143 | pin_memory=True) 144 | 145 | def on_train_start(self): 146 | self.model.prune_cells(self.train_dataset.K.to(self.device), 147 | self.poses, 148 | self.train_dataset.img_wh) 149 | 150 | def training_step(self, batch, batch_nb, *args): 151 | if self.global_step%self.update_interval == 0: 152 | self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5, 153 | warmup=self.global_step 1 c h w', h=h) 190 | rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h) 191 | self.val_ssim(rgb_pred, rgb_gt) 192 | logs['ssim'] = self.val_ssim.compute() 193 | self.val_ssim.reset() 194 | if self.hparams.eval_lpips: 195 | self.val_lpips(torch.clip(rgb_pred*2-1, -1, 1), 196 | torch.clip(rgb_gt*2-1, -1, 1)) 197 | logs['lpips'] = self.val_lpips.compute() 198 | self.val_lpips.reset() 199 | 
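        # Note on the LPIPS call above: LearnedPerceptualImagePatchSimilarity
        # expects inputs normalized to [-1, 1], which is why rgb_pred / rgb_gt
        # (in [0, 1]) are mapped with torch.clip(x*2 - 1, -1, 1) before scoring.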
200 | if not self.hparams.no_save_test: # save test image to disk 201 | idx = batch['img_idxs'] 202 | rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h) 203 | rgb_pred = (rgb_pred*255).astype(np.uint8) 204 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred) 205 | 206 | return logs 207 | 208 | def validation_epoch_end(self, outputs): 209 | psnrs = torch.stack([x['psnr'] for x in outputs]) 210 | mean_psnr = all_gather_ddp_if_available(psnrs).mean() 211 | self.log('test/psnr', mean_psnr, True) 212 | 213 | ssims = torch.stack([x['ssim'] for x in outputs]) 214 | mean_ssim = all_gather_ddp_if_available(ssims).mean() 215 | self.log('test/ssim', mean_ssim) 216 | 217 | if self.hparams.eval_lpips: 218 | lpipss = torch.stack([x['lpips'] for x in outputs]) 219 | mean_lpips = all_gather_ddp_if_available(lpipss).mean() 220 | self.log('test/lpips_vgg', mean_lpips) 221 | 222 | def get_progress_bar_dict(self): 223 | # don't show the version number 224 | items = super().get_progress_bar_dict() 225 | items.pop("v_num", None) 226 | return items 227 | 228 | 229 | if __name__ == '__main__': 230 | hparams = get_opts() 231 | if hparams.val_only and (not hparams.ckpt_path): 232 | raise ValueError('You need to provide a @ckpt_path for validation!') 233 | system = NeRFSystem(hparams) 234 | 235 | ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.dataset_name}/{hparams.exp_name}', 236 | filename='{epoch:d}', 237 | save_weights_only=False, 238 | every_n_epochs=hparams.num_epochs, 239 | save_on_train_epoch_end=True, 240 | save_top_k=-1) 241 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)] 242 | 243 | logger = TensorBoardLogger(save_dir=f"logs/{hparams.dataset_name}", 244 | name=hparams.exp_name, 245 | default_hp_metric=False) 246 | 247 | trainer = Trainer(max_epochs=hparams.num_epochs, 248 | check_val_every_n_epoch=hparams.num_epochs, 249 | callbacks=callbacks, 250 | logger=logger, 251 | enable_model_summary=False, 252 | accelerator='gpu', 253 | devices=hparams.num_gpus, 254 | strategy=DDPPlugin(find_unused_parameters=False) 255 | if hparams.num_gpus>1 else None, 256 | num_sanity_val_steps=-1 if hparams.val_only else 0, 257 | precision=16) 258 | 259 | trainer.fit(system, ckpt_path=hparams.ckpt_path) 260 | 261 | if not hparams.val_only: # save slimmed ckpt for the last epoch 262 | ckpt_ = \ 263 | slim_ckpt(f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/epoch={hparams.num_epochs-1}.ckpt', 264 | save_poses=hparams.optimize_ext) 265 | torch.save(ckpt_, f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/epoch={hparams.num_epochs-1}_slim.ckpt') 266 | 267 | if (not hparams.no_save_test) and \ 268 | hparams.dataset_name=='nsvf' and \ 269 | 'Synthetic' in hparams.root_dir: # save video 270 | imgs = sorted(glob.glob(os.path.join(system.val_dir, '*.png'))) 271 | imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'), 272 | [imageio.imread(img) for img in imgs[::2]], 273 | fps=30, macro_block_size=1) 274 | imageio.mimsave(os.path.join(system.val_dir, 'depth.mp4'), 275 | [imageio.imread(img) for img in imgs[1::2]], 276 | fps=30, macro_block_size=1) 277 | -------------------------------------------------------------------------------- /models/csrc/include/helper_math.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | * 3 | * Redistribution and use in source and binary forms, with or without 4 | * modification, are permitted provided that the following conditions 5 | * are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of NVIDIA CORPORATION nor the names of its 12 | * contributors may be used to endorse or promote products derived 13 | * from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | /* 29 | * This file implements common mathematical operations on vector types 30 | * (float3, float4 etc.) since these are not provided as standard by CUDA. 31 | * 32 | * The syntax is modeled on the Cg standard library. 33 | * 34 | * This is part of the Helper library includes 35 | * 36 | * Thanks to Linh Hah for additions and fixes. 37 | */ 38 | 39 | #ifndef HELPER_MATH_H 40 | #define HELPER_MATH_H 41 | 42 | #include "cuda_runtime.h" 43 | 44 | typedef unsigned int uint; 45 | typedef unsigned short ushort; 46 | 47 | #ifndef EXIT_WAIVED 48 | #define EXIT_WAIVED 2 49 | #endif 50 | 51 | #ifndef __CUDACC__ 52 | #include 53 | 54 | //////////////////////////////////////////////////////////////////////////////// 55 | // host implementations of CUDA functions 56 | //////////////////////////////////////////////////////////////////////////////// 57 | 58 | inline float fminf(float a, float b) 59 | { 60 | return a < b ? a : b; 61 | } 62 | 63 | inline float fmaxf(float a, float b) 64 | { 65 | return a > b ? a : b; 66 | } 67 | 68 | inline int max(int a, int b) 69 | { 70 | return a > b ? a : b; 71 | } 72 | 73 | inline int min(int a, int b) 74 | { 75 | return a < b ? 
a : b; 76 | } 77 | 78 | inline float rsqrtf(float x) 79 | { 80 | return 1.0f / sqrtf(x); 81 | } 82 | #endif 83 | 84 | //////////////////////////////////////////////////////////////////////////////// 85 | // constructors 86 | //////////////////////////////////////////////////////////////////////////////// 87 | 88 | inline __host__ __device__ float2 make_float2(float s) 89 | { 90 | return make_float2(s, s); 91 | } 92 | inline __host__ __device__ float2 make_float2(float3 a) 93 | { 94 | return make_float2(a.x, a.y); 95 | } 96 | inline __host__ __device__ float3 make_float3(float s) 97 | { 98 | return make_float3(s, s, s); 99 | } 100 | inline __host__ __device__ float3 make_float3(float2 a) 101 | { 102 | return make_float3(a.x, a.y, 0.0f); 103 | } 104 | inline __host__ __device__ float3 make_float3(float2 a, float s) 105 | { 106 | return make_float3(a.x, a.y, s); 107 | } 108 | 109 | //////////////////////////////////////////////////////////////////////////////// 110 | // negate 111 | //////////////////////////////////////////////////////////////////////////////// 112 | 113 | inline __host__ __device__ float3 operator-(float3 &a) 114 | { 115 | return make_float3(-a.x, -a.y, -a.z); 116 | } 117 | 118 | //////////////////////////////////////////////////////////////////////////////// 119 | // addition 120 | //////////////////////////////////////////////////////////////////////////////// 121 | 122 | inline __host__ __device__ float3 operator+(float3 a, float3 b) 123 | { 124 | return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); 125 | } 126 | inline __host__ __device__ void operator+=(float3 &a, float3 b) 127 | { 128 | a.x += b.x; 129 | a.y += b.y; 130 | a.z += b.z; 131 | } 132 | inline __host__ __device__ float3 operator+(float3 a, float b) 133 | { 134 | return make_float3(a.x + b, a.y + b, a.z + b); 135 | } 136 | inline __host__ __device__ void operator+=(float3 &a, float b) 137 | { 138 | a.x += b; 139 | a.y += b; 140 | a.z += b; 141 | } 142 | inline __host__ __device__ float3 operator+(float b, float3 a) 143 | { 144 | return make_float3(a.x + b, a.y + b, a.z + b); 145 | } 146 | 147 | //////////////////////////////////////////////////////////////////////////////// 148 | // subtract 149 | //////////////////////////////////////////////////////////////////////////////// 150 | 151 | inline __host__ __device__ float3 operator-(float3 a, float3 b) 152 | { 153 | return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); 154 | } 155 | inline __host__ __device__ void operator-=(float3 &a, float3 b) 156 | { 157 | a.x -= b.x; 158 | a.y -= b.y; 159 | a.z -= b.z; 160 | } 161 | inline __host__ __device__ float3 operator-(float3 a, float b) 162 | { 163 | return make_float3(a.x - b, a.y - b, a.z - b); 164 | } 165 | inline __host__ __device__ float3 operator-(float b, float3 a) 166 | { 167 | return make_float3(b - a.x, b - a.y, b - a.z); 168 | } 169 | inline __host__ __device__ void operator-=(float3 &a, float b) 170 | { 171 | a.x -= b; 172 | a.y -= b; 173 | a.z -= b; 174 | } 175 | 176 | //////////////////////////////////////////////////////////////////////////////// 177 | // multiply 178 | //////////////////////////////////////////////////////////////////////////////// 179 | 180 | inline __host__ __device__ float3 operator*(float3 a, float3 b) 181 | { 182 | return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); 183 | } 184 | inline __host__ __device__ void operator*=(float3 &a, float3 b) 185 | { 186 | a.x *= b.x; 187 | a.y *= b.y; 188 | a.z *= b.z; 189 | } 190 | inline __host__ __device__ float3 operator*(float3 a, float 
b) 191 | { 192 | return make_float3(a.x * b, a.y * b, a.z * b); 193 | } 194 | inline __host__ __device__ float3 operator*(float b, float3 a) 195 | { 196 | return make_float3(b * a.x, b * a.y, b * a.z); 197 | } 198 | inline __host__ __device__ void operator*=(float3 &a, float b) 199 | { 200 | a.x *= b; 201 | a.y *= b; 202 | a.z *= b; 203 | } 204 | 205 | //////////////////////////////////////////////////////////////////////////////// 206 | // divide 207 | //////////////////////////////////////////////////////////////////////////////// 208 | 209 | inline __host__ __device__ float2 operator/(float2 a, float2 b) 210 | { 211 | return make_float2(a.x / b.x, a.y / b.y); 212 | } 213 | inline __host__ __device__ void operator/=(float2 &a, float2 b) 214 | { 215 | a.x /= b.x; 216 | a.y /= b.y; 217 | } 218 | inline __host__ __device__ float2 operator/(float2 a, float b) 219 | { 220 | return make_float2(a.x / b, a.y / b); 221 | } 222 | inline __host__ __device__ void operator/=(float2 &a, float b) 223 | { 224 | a.x /= b; 225 | a.y /= b; 226 | } 227 | inline __host__ __device__ float2 operator/(float b, float2 a) 228 | { 229 | return make_float2(b / a.x, b / a.y); 230 | } 231 | 232 | inline __host__ __device__ float3 operator/(float3 a, float3 b) 233 | { 234 | return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); 235 | } 236 | inline __host__ __device__ void operator/=(float3 &a, float3 b) 237 | { 238 | a.x /= b.x; 239 | a.y /= b.y; 240 | a.z /= b.z; 241 | } 242 | inline __host__ __device__ float3 operator/(float3 a, float b) 243 | { 244 | return make_float3(a.x / b, a.y / b, a.z / b); 245 | } 246 | inline __host__ __device__ void operator/=(float3 &a, float b) 247 | { 248 | a.x /= b; 249 | a.y /= b; 250 | a.z /= b; 251 | } 252 | inline __host__ __device__ float3 operator/(float b, float3 a) 253 | { 254 | return make_float3(b / a.x, b / a.y, b / a.z); 255 | } 256 | 257 | //////////////////////////////////////////////////////////////////////////////// 258 | // min 259 | //////////////////////////////////////////////////////////////////////////////// 260 | 261 | inline __host__ __device__ float3 fminf(float3 a, float3 b) 262 | { 263 | return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); 264 | } 265 | 266 | //////////////////////////////////////////////////////////////////////////////// 267 | // max 268 | //////////////////////////////////////////////////////////////////////////////// 269 | 270 | inline __host__ __device__ float3 fmaxf(float3 a, float3 b) 271 | { 272 | return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); 273 | } 274 | 275 | //////////////////////////////////////////////////////////////////////////////// 276 | // clamp 277 | // - clamp the value v to be in the range [a, b] 278 | //////////////////////////////////////////////////////////////////////////////// 279 | 280 | inline __device__ __host__ float clamp(float f, float a, float b) 281 | { 282 | return fmaxf(a, fminf(f, b)); 283 | } 284 | inline __device__ __host__ int clamp(int f, int a, int b) 285 | { 286 | return max(a, min(f, b)); 287 | } 288 | inline __device__ __host__ uint clamp(uint f, uint a, uint b) 289 | { 290 | return max(a, min(f, b)); 291 | } 292 | 293 | inline __device__ __host__ float3 clamp(float3 v, float a, float b) 294 | { 295 | return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); 296 | } 297 | inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) 298 | { 299 | return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); 300 | } 
301 | 302 | //////////////////////////////////////////////////////////////////////////////// 303 | // dot product 304 | //////////////////////////////////////////////////////////////////////////////// 305 | 306 | inline __host__ __device__ float dot(float3 a, float3 b) 307 | { 308 | return a.x * b.x + a.y * b.y + a.z * b.z; 309 | } 310 | 311 | //////////////////////////////////////////////////////////////////////////////// 312 | // length 313 | //////////////////////////////////////////////////////////////////////////////// 314 | 315 | inline __host__ __device__ float length(float3 v) 316 | { 317 | return sqrtf(dot(v, v)); 318 | } 319 | 320 | //////////////////////////////////////////////////////////////////////////////// 321 | // normalize 322 | //////////////////////////////////////////////////////////////////////////////// 323 | 324 | inline __host__ __device__ float3 normalize(float3 v) 325 | { 326 | float invLen = rsqrtf(dot(v, v)); 327 | return v * invLen; 328 | } 329 | 330 | //////////////////////////////////////////////////////////////////////////////// 331 | // reflect 332 | // - returns reflection of incident ray I around surface normal N 333 | // - N should be normalized, reflected vector's length is equal to length of I 334 | //////////////////////////////////////////////////////////////////////////////// 335 | 336 | inline __host__ __device__ float3 reflect(float3 i, float3 n) 337 | { 338 | return i - 2.0f * n * dot(n,i); 339 | } 340 | 341 | //////////////////////////////////////////////////////////////////////////////// 342 | // cross product 343 | //////////////////////////////////////////////////////////////////////////////// 344 | 345 | inline __host__ __device__ float3 cross(float3 a, float3 b) 346 | { 347 | return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); 348 | } 349 | 350 | //////////////////////////////////////////////////////////////////////////////// 351 | // smoothstep 352 | // - returns 0 if x < a 353 | // - returns 1 if x > b 354 | // - otherwise returns smooth interpolation between 0 and 1 based on x 355 | //////////////////////////////////////////////////////////////////////////////// 356 | 357 | inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) 358 | { 359 | float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); 360 | return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); 361 | } 362 | 363 | #endif 364 | -------------------------------------------------------------------------------- /models/csrc/volumerendering.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | 4 | 5 | template 6 | __global__ void composite_train_fw_kernel( 7 | const torch::PackedTensorAccessor sigmas, 8 | const torch::PackedTensorAccessor rgbs, 9 | const torch::PackedTensorAccessor deltas, 10 | const torch::PackedTensorAccessor ts, 11 | const torch::PackedTensorAccessor64 rays_a, 12 | const scalar_t T_threshold, 13 | torch::PackedTensorAccessor64 total_samples, 14 | torch::PackedTensorAccessor opacity, 15 | torch::PackedTensorAccessor depth, 16 | torch::PackedTensorAccessor rgb, 17 | torch::PackedTensorAccessor ws 18 | ){ 19 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 20 | if (n >= opacity.size(0)) return; 21 | 22 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 23 | 24 | // front to back compositing 25 | int samples = 0; scalar_t T = 1.0f; 26 | 27 | while (samples < N_samples) { 28 | const int s = start_idx + samples; 29 
| const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]); 30 | const scalar_t w = a * T; // weight of the sample point 31 | 32 | rgb[ray_idx][0] += w*rgbs[s][0]; 33 | rgb[ray_idx][1] += w*rgbs[s][1]; 34 | rgb[ray_idx][2] += w*rgbs[s][2]; 35 | depth[ray_idx] += w*ts[s]; 36 | opacity[ray_idx] += w; 37 | ws[s] = w; 38 | T *= 1.0f-a; 39 | 40 | if (T <= T_threshold) break; // ray has enough opacity 41 | samples++; 42 | } 43 | total_samples[ray_idx] = samples; 44 | } 45 | 46 | 47 | std::vector composite_train_fw_cu( 48 | const torch::Tensor sigmas, 49 | const torch::Tensor rgbs, 50 | const torch::Tensor deltas, 51 | const torch::Tensor ts, 52 | const torch::Tensor rays_a, 53 | const float T_threshold 54 | ){ 55 | const int N_rays = rays_a.size(0), N = sigmas.size(0); 56 | 57 | auto opacity = torch::zeros({N_rays}, sigmas.options()); 58 | auto depth = torch::zeros({N_rays}, sigmas.options()); 59 | auto rgb = torch::zeros({N_rays, 3}, sigmas.options()); 60 | auto ws = torch::zeros({N}, sigmas.options()); 61 | auto total_samples = torch::zeros({N_rays}, torch::dtype(torch::kLong).device(sigmas.device())); 62 | 63 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 64 | 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_fw_cu", 66 | ([&] { 67 | composite_train_fw_kernel<<>>( 68 | sigmas.packed_accessor(), 69 | rgbs.packed_accessor(), 70 | deltas.packed_accessor(), 71 | ts.packed_accessor(), 72 | rays_a.packed_accessor64(), 73 | T_threshold, 74 | total_samples.packed_accessor64(), 75 | opacity.packed_accessor(), 76 | depth.packed_accessor(), 77 | rgb.packed_accessor(), 78 | ws.packed_accessor() 79 | ); 80 | })); 81 | 82 | return {total_samples, opacity, depth, rgb, ws}; 83 | } 84 | 85 | 86 | template 87 | __global__ void composite_train_bw_kernel( 88 | const torch::PackedTensorAccessor dL_dopacity, 89 | const torch::PackedTensorAccessor dL_ddepth, 90 | const torch::PackedTensorAccessor dL_drgb, 91 | const torch::PackedTensorAccessor dL_dws, 92 | scalar_t* __restrict__ dL_dws_times_ws, 93 | const torch::PackedTensorAccessor sigmas, 94 | const torch::PackedTensorAccessor rgbs, 95 | const torch::PackedTensorAccessor deltas, 96 | const torch::PackedTensorAccessor ts, 97 | const torch::PackedTensorAccessor64 rays_a, 98 | const torch::PackedTensorAccessor opacity, 99 | const torch::PackedTensorAccessor depth, 100 | const torch::PackedTensorAccessor rgb, 101 | const scalar_t T_threshold, 102 | torch::PackedTensorAccessor dL_dsigmas, 103 | torch::PackedTensorAccessor dL_drgbs 104 | ){ 105 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 106 | if (n >= opacity.size(0)) return; 107 | 108 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 109 | 110 | // front to back compositing 111 | int samples = 0; 112 | scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2]; 113 | scalar_t O = opacity[ray_idx], D = depth[ray_idx]; 114 | scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f; 115 | 116 | // compute prefix sum of dL_dws * ws 117 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 
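    // dL_dws_times_ws arrives holding dL_dws[i]*ws[i] for each sample (see the
    // host-side `dL_dws * ws` below); the in-place inclusive scan turns it into
    // prefix sums, so the suffix sum sum_{j>s} dL_dws[j]*ws[j] needed in
    // dL_dsigmas can be read off as (dL_dws_times_ws_sum - dL_dws_times_ws[s])
    // without a second pass over the samples.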
118 | thrust::inclusive_scan(thrust::device, 119 | dL_dws_times_ws+start_idx, 120 | dL_dws_times_ws+start_idx+N_samples, 121 | dL_dws_times_ws+start_idx); 122 | scalar_t dL_dws_times_ws_sum = dL_dws_times_ws[start_idx+N_samples-1]; 123 | 124 | while (samples < N_samples) { 125 | const int s = start_idx + samples; 126 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]); 127 | const scalar_t w = a * T; 128 | 129 | r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2]; 130 | d += w*ts[s]; 131 | T *= 1.0f-a; 132 | 133 | // compute gradients by math... 134 | dL_drgbs[s][0] = dL_drgb[ray_idx][0]*w; 135 | dL_drgbs[s][1] = dL_drgb[ray_idx][1]*w; 136 | dL_drgbs[s][2] = dL_drgb[ray_idx][2]*w; 137 | 138 | dL_dsigmas[s] = deltas[s] * ( 139 | dL_drgb[ray_idx][0]*(rgbs[s][0]*T-(R-r)) + 140 | dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) + 141 | dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) + // gradients from rgb 142 | dL_dopacity[ray_idx]*(1-O) + // gradient from opacity 143 | dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + // gradient from depth 144 | T*dL_dws[s]-(dL_dws_times_ws_sum-dL_dws_times_ws[s]) // gradient from ws 145 | ); 146 | 147 | if (T <= T_threshold) break; // ray has enough opacity 148 | samples++; 149 | } 150 | } 151 | 152 | 153 | std::vector composite_train_bw_cu( 154 | const torch::Tensor dL_dopacity, 155 | const torch::Tensor dL_ddepth, 156 | const torch::Tensor dL_drgb, 157 | const torch::Tensor dL_dws, 158 | const torch::Tensor sigmas, 159 | const torch::Tensor rgbs, 160 | const torch::Tensor ws, 161 | const torch::Tensor deltas, 162 | const torch::Tensor ts, 163 | const torch::Tensor rays_a, 164 | const torch::Tensor opacity, 165 | const torch::Tensor depth, 166 | const torch::Tensor rgb, 167 | const float T_threshold 168 | ){ 169 | const int N = sigmas.size(0), N_rays = rays_a.size(0); 170 | 171 | auto dL_dsigmas = torch::zeros({N}, sigmas.options()); 172 | auto dL_drgbs = torch::zeros({N, 3}, sigmas.options()); 173 | 174 | auto dL_dws_times_ws = dL_dws * ws; // auxiliary input 175 | 176 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 177 | 178 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_bw_cu", 179 | ([&] { 180 | composite_train_bw_kernel<<>>( 181 | dL_dopacity.packed_accessor(), 182 | dL_ddepth.packed_accessor(), 183 | dL_drgb.packed_accessor(), 184 | dL_dws.packed_accessor(), 185 | dL_dws_times_ws.data_ptr(), 186 | sigmas.packed_accessor(), 187 | rgbs.packed_accessor(), 188 | deltas.packed_accessor(), 189 | ts.packed_accessor(), 190 | rays_a.packed_accessor64(), 191 | opacity.packed_accessor(), 192 | depth.packed_accessor(), 193 | rgb.packed_accessor(), 194 | T_threshold, 195 | dL_dsigmas.packed_accessor(), 196 | dL_drgbs.packed_accessor() 197 | ); 198 | })); 199 | 200 | return {dL_dsigmas, dL_drgbs}; 201 | } 202 | 203 | 204 | template 205 | __global__ void composite_test_fw_kernel( 206 | const torch::PackedTensorAccessor sigmas, 207 | const torch::PackedTensorAccessor rgbs, 208 | const torch::PackedTensorAccessor deltas, 209 | const torch::PackedTensorAccessor ts, 210 | const torch::PackedTensorAccessor hits_t, 211 | torch::PackedTensorAccessor64 alive_indices, 212 | const scalar_t T_threshold, 213 | const torch::PackedTensorAccessor32 N_eff_samples, 214 | torch::PackedTensorAccessor opacity, 215 | torch::PackedTensorAccessor depth, 216 | torch::PackedTensorAccessor rgb 217 | ){ 218 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 219 | if (n >= alive_indices.size(0)) return; 220 | 221 | if (N_eff_samples[n]==0){ // no hit 222 | 
alive_indices[n] = -1; 223 | return; 224 | } 225 | 226 | const size_t r = alive_indices[n]; // ray index 227 | 228 | // front to back compositing 229 | int s = 0; scalar_t T = 1-opacity[r]; 230 | 231 | while (s < N_eff_samples[n]) { 232 | const scalar_t a = 1.0f - __expf(-sigmas[n][s]*deltas[n][s]); 233 | const scalar_t w = a * T; 234 | 235 | rgb[r][0] += w*rgbs[n][s][0]; 236 | rgb[r][1] += w*rgbs[n][s][1]; 237 | rgb[r][2] += w*rgbs[n][s][2]; 238 | depth[r] += w*ts[n][s]; 239 | opacity[r] += w; 240 | T *= 1.0f-a; 241 | 242 | if (T <= T_threshold){ // ray has enough opacity 243 | alive_indices[n] = -1; 244 | break; 245 | } 246 | s++; 247 | } 248 | } 249 | 250 | 251 | void composite_test_fw_cu( 252 | const torch::Tensor sigmas, 253 | const torch::Tensor rgbs, 254 | const torch::Tensor deltas, 255 | const torch::Tensor ts, 256 | const torch::Tensor hits_t, 257 | torch::Tensor alive_indices, 258 | const float T_threshold, 259 | const torch::Tensor N_eff_samples, 260 | torch::Tensor opacity, 261 | torch::Tensor depth, 262 | torch::Tensor rgb 263 | ){ 264 | const int N_rays = alive_indices.size(0); 265 | 266 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 267 | 268 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_test_fw_cu", 269 | ([&] { 270 | composite_test_fw_kernel<<>>( 271 | sigmas.packed_accessor(), 272 | rgbs.packed_accessor(), 273 | deltas.packed_accessor(), 274 | ts.packed_accessor(), 275 | hits_t.packed_accessor(), 276 | alive_indices.packed_accessor64(), 277 | T_threshold, 278 | N_eff_samples.packed_accessor32(), 279 | opacity.packed_accessor(), 280 | depth.packed_accessor(), 281 | rgb.packed_accessor() 282 | ); 283 | })); 284 | } -------------------------------------------------------------------------------- /datasets/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
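            # "idddddddi" unpacks one 64-byte record: image_id (int32), qvec
            # (4 x float64), tvec (3 x float64), camera_id (int32)
            # = 4 + 56 + 4 bytes, matching num_bytes=64 above.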
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
        images = read_images_binary(os.path.join(path, "images" + ext))
268 |         points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
269 |     return cameras, images, points3D
270 |
271 |
272 | def qvec2rotmat(qvec):
273 |     return np.array([
274 |         [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
275 |          2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
276 |          2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
277 |         [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
278 |          1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
279 |          2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
280 |         [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
281 |          2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
282 |          1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
283 |
284 |
285 | def rotmat2qvec(R):
286 |     Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
287 |     K = np.array([
288 |         [Rxx - Ryy - Rzz, 0, 0, 0],
289 |         [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
290 |         [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
291 |         [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
292 |     eigvals, eigvecs = np.linalg.eigh(K)
293 |     qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
294 |     if qvec[0] < 0:
295 |         qvec *= -1
296 |     return qvec
--------------------------------------------------------------------------------
/representations/grufusion/gru_fusion.py:
--------------------------------------------------------------------------------
1 | # ported from NeuralRecon (https://github.com/zju3dv/NeuralRecon)
2 | import torch
3 | import torch.nn as nn
4 | from torchsparse.tensor import PointTensor
5 | from .modules import ConvGRU
6 |
7 |
8 | def sparse_to_dense_torch(locs, values, dim, default_val, device):
9 |     dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device)
10 |     if locs.shape[0] > 0:
11 |         dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
12 |     return dense
13 |
14 |
15 | def sparse_to_dense_channel(locs, values, dim, c, default_val, device):
16 |     dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device)
17 |     if locs.shape[0] > 0:
18 |         dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
19 |     return dense
20 |
21 |
22 | class GRUFusion(nn.Module):
23 |     """
24 |     Two functionalities of this class:
25 |     1. GRU Fusion module as in the paper. Update hidden state features with ConvGRU.
26 |     2. Substitute TSDF in the global volume when direct_substitute = True.
27 |     """
28 |
29 |     def __init__(self, cfg, ch_in=None, direct_substitute=False):
30 |         super(GRUFusion, self).__init__()
31 |         self.cfg = cfg
32 |         # replace tsdf in the global tsdf volume by directly substituting the corresponding voxels
33 |         self.direct_substitute = direct_substitute
34 |
35 |         if direct_substitute:
36 |             # tsdf
37 |             self.ch_in = [1, 1, 1]
38 |             self.feat_init = 1
39 |         else:
40 |             # features
41 |             self.ch_in = ch_in
42 |             self.feat_init = 0
43 |
44 |         self.n_scales = len(cfg.THRESHOLDS) - 1
45 |         self.scene_name = [None, None, None]
46 |         self.global_origin = [None, None, None]
47 |         self.global_volume = [None, None, None]
48 |         self.target_tsdf_volume = [None, None, None]
49 |
50 |         if direct_substitute:
51 |             self.fusion_nets = None
52 |         else:
53 |             self.fusion_nets = nn.ModuleList()
54 |             for i, ch in enumerate(ch_in):
55 |                 self.fusion_nets.append(ConvGRU(hidden_dim=ch,
56 |                                                 input_dim=ch,
57 |                                                 pres=1,
58 |                                                 vres=self.cfg.VOXEL_SIZE * 2 ** (self.n_scales - i)))
59 |
60 |     def reset(self, i):
61 |         self.global_volume[i] = PointTensor(torch.Tensor([]), torch.Tensor([]).view(0, 3).long()).cuda()
62 |         self.target_tsdf_volume[i] = PointTensor(torch.Tensor([]), torch.Tensor([]).view(0, 3).long()).cuda()
63 |
64 |     def convert2dense(self, current_coords, current_values, coords_target_global, tsdf_target, relative_origin,
65 |                       scale):
66 |         '''
67 |         1. convert sparse feature to dense feature;
68 |         2. combine current feature coordinates and previous coordinates within FBV from global hidden state to get
69 |         new feature coordinates (updated_coords);
70 |         3. fuse ground truth tsdf.
71 |
72 |         :param current_coords: (Tensor), current coordinates, (N, 3)
73 |         :param current_values: (Tensor), current features/tsdf, (N, C)
74 |         :param coords_target_global: (Tensor), ground truth coordinates, (N', 3)
75 |         :param tsdf_target: (Tensor), tsdf ground truth, (N',)
76 |         :param relative_origin: (Tensor), origin in global volume, (3,)
77 |         :param scale:
78 |         :return: updated_coords: (Tensor), coordinates after combination, (N', 3)
79 |         :return: current_volume: (Tensor), current dense feature/tsdf volume, (DIM_X, DIM_Y, DIM_Z, C)
80 |         :return: global_volume: (Tensor), global dense feature/tsdf volume, (DIM_X, DIM_Y, DIM_Z, C)
81 |         :return: target_volume: (Tensor), dense target tsdf volume, (DIM_X, DIM_Y, DIM_Z, 1)
82 |         :return: valid: mask, 1 represents a voxel inside the current FBV, (N,)
83 |         :return: valid_target: gt mask, 1 represents a voxel inside the current FBV, (N,)
84 |         '''
85 |         # previous frame
86 |         global_coords = self.global_volume[scale].C
87 |         global_value = self.global_volume[scale].F
88 |         global_tsdf_target = self.target_tsdf_volume[scale].F
89 |         global_coords_target = self.target_tsdf_volume[scale].C
90 |
91 |         dim = (torch.Tensor(self.cfg.N_VOX).cuda() // 2 ** (self.cfg.N_LAYER - scale - 1)).int()
92 |         dim_list = dim.data.cpu().numpy().tolist()
93 |
94 |         # mask voxels that are out of the FBV
95 |         global_coords = global_coords - relative_origin
96 |         valid = ((global_coords < dim) & (global_coords >= 0)).all(dim=-1)
97 |         if self.cfg.FUSION.FULL is False:
98 |             valid_volume = sparse_to_dense_torch(current_coords, 1, dim_list, 0, global_value.device)
99 |             value = valid_volume[global_coords[valid][:, 0], global_coords[valid][:, 1], global_coords[valid][:, 2]]
100 |             all_true = valid[valid]
101 |             all_true[value == 0] = False
102 |             valid[valid] = all_true
103 |         # sparse to dense
104 |         global_volume = sparse_to_dense_channel(global_coords[valid], global_value[valid], dim_list, self.ch_in[scale],
105 |                                                 self.feat_init, global_value.device)
106 |
107 |         current_volume = sparse_to_dense_channel(current_coords, current_values, dim_list, self.ch_in[scale],
108 |                                                  self.feat_init, global_value.device)
109 |
110 |         if self.cfg.FUSION.FULL is True:
111 |             # change the structure of sparsity, combine current coordinates and previous coordinates from global volume
112 |             if self.direct_substitute:
113 |                 updated_coords = torch.nonzero((global_volume.abs() < 1).any(-1) | (current_volume.abs() < 1).any(-1))
114 |             else:
115 |                 updated_coords = torch.nonzero((global_volume != 0).any(-1) | (current_volume != 0).any(-1))
116 |         else:
117 |             updated_coords = current_coords
118 |
119 |         # fuse ground truth
120 |         if tsdf_target is not None:
121 |             # mask voxels that are out of the FBV
122 |             global_coords_target = global_coords_target - relative_origin
123 |             valid_target = ((global_coords_target < dim) & (global_coords_target >= 0)).all(dim=-1)
124 |             # combine current tsdf and global tsdf
125 |             coords_target = torch.cat([global_coords_target[valid_target], coords_target_global])[:, :3]
126 |             tsdf_target = torch.cat([global_tsdf_target[valid_target], tsdf_target.unsqueeze(-1)])
127 |             # sparse to dense
128 |             target_volume = sparse_to_dense_channel(coords_target, tsdf_target, dim_list, 1, 1,
129 |                                                     tsdf_target.device)
130 |         else:
131 |             target_volume = valid_target = None
132 |
133 |         return updated_coords, current_volume, global_volume, target_volume, valid, valid_target
134 |
135 |     def update_map(self, value, coords, target_volume, valid, valid_target,
136 |                    relative_origin, scale):
137 |         '''
138 |         Replace the hidden state/tsdf in the global hidden state/tsdf volume by directly substituting the corresponding voxels
139 |         :param value: (Tensor) fused feature (N, C)
140 |         :param coords: (Tensor) updated coords (N, 3)
141 |         :param target_volume: (Tensor) tsdf volume (DIM_X, DIM_Y, DIM_Z, 1)
142 |         :param valid: (Tensor) mask, 1 represents a voxel inside the current FBV (N,)
143 |         :param valid_target: (Tensor) gt mask, 1 represents a voxel inside the current FBV (N,)
144 |         :param relative_origin: (Tensor), origin in global volume, (3,)
145 |         :param scale:
146 |         :return:
147 |         '''
148 |         # pred
149 |         self.global_volume[scale].F = torch.cat(
150 |             [self.global_volume[scale].F[valid == False], value])
151 |         coords = coords + relative_origin
152 |         self.global_volume[scale].C = torch.cat([self.global_volume[scale].C[valid == False], coords])
153 |
154 |         # target
155 |         if target_volume is not None:
156 |             target_volume = target_volume.squeeze()
157 |             self.target_tsdf_volume[scale].F = torch.cat(
158 |                 [self.target_tsdf_volume[scale].F[valid_target == False],
159 |                  target_volume[target_volume.abs() < 1].unsqueeze(-1)])
160 |             target_coords = torch.nonzero(target_volume.abs() < 1) + relative_origin
161 |
162 |             self.target_tsdf_volume[scale].C = torch.cat(
163 |                 [self.target_tsdf_volume[scale].C[valid_target == False], target_coords])
164 |
165 |     def save_mesh(self, scale, outputs, scene):
166 |         if outputs is None:
167 |             outputs = dict()
168 |         if "scene_name" not in outputs:
169 |             outputs['origin'] = []
170 |             outputs['scene_tsdf'] = []
171 |             outputs['scene_name'] = []
172 |         # only keep the newest result
173 |         if scene in outputs['scene_name']:
174 |             # delete old
175 |             idx = outputs['scene_name'].index(scene)
176 |             del outputs['origin'][idx]
177 |             del outputs['scene_tsdf'][idx]
178 |             del outputs['scene_name'][idx]
179 |
180 |         # scene name
181 |         outputs['scene_name'].append(scene)
182 |
183 |         fuse_coords = self.global_volume[scale].C
184 |         tsdf = self.global_volume[scale].F.squeeze(-1)
185 |         max_c = torch.max(fuse_coords, dim=0)[0][:3]
186 |         min_c = torch.min(fuse_coords, dim=0)[0][:3]
187 |         outputs['origin'].append(min_c * self.cfg.VOXEL_SIZE * (2 ** (self.cfg.N_LAYER - scale - 1)))
188 |
189 |         ind_coords = fuse_coords - min_c
190 |         dim_list = (max_c - min_c + 1).int().data.cpu().numpy().tolist()
191 |         tsdf_volume = sparse_to_dense_torch(ind_coords, tsdf, dim_list, 1, tsdf.device)
192 |         outputs['scene_tsdf'].append(tsdf_volume)
193 |
194 |         return outputs
195 |
196 |     def forward(self, coords, values_in, inputs, scale=2, outputs=None, save_mesh=False):
197 |         '''
198 |         :param coords: (Tensor), coordinates of voxels, (N, 4) (4 : Batch ind, x, y, z)
199 |         :param values_in: (Tensor), features/tsdf, (N, C)
200 |         :param inputs: dict: meta data from dataloader
201 |         :param scale:
202 |         :param outputs:
203 |         :param save_mesh: a bool indicating whether or not to save the reconstructed mesh of the current sample
204 |         if direct_substitute:
205 |             :return: outputs: dict: {
206 |                 'origin': (List), origin of the predicted partial volume,
207 |                     [3]
208 |                 'scene_tsdf': (List), predicted tsdf volume,
209 |                     [(nx, ny, nz)]
210 |                 'target': (List), ground truth tsdf volume,
211 |                     [(nx', ny', nz')]
212 |                 'scene_name': (List), name of each scene in 'scene_tsdf',
213 |                     [string]
214 |             }
215 |         else:
216 |             :return: updated_coords_all: (Tensor), updated coordinates, (N', 4) (4 : Batch ind, x, y, z)
217 |             :return: values_all: (Tensor), features after gru fusion, (N', C)
218 |             :return: tsdf_target_all: (Tensor), tsdf ground truth, (N', 1)
219 |             :return: occ_target_all: (Tensor), occupancy ground truth, (N', 1)
220 |         '''
221 |         if self.global_volume[scale] is not None:
222 |             # detach from the computational graph to save memory
223 |             self.global_volume[scale] = self.global_volume[scale].detach()
224 |
225 |         batch_size = len(inputs['fragment'])
226 |         interval = 2 ** (self.cfg.N_LAYER - scale - 1)
227 |
228 |         tsdf_target_all = None
229 |         occ_target_all = None
230 |         values_all = None
231 |         updated_coords_all = None
232 |
233 |         # ---incremental fusion----
234 |         for i in range(batch_size):
235 |             scene = inputs['scene'][i]  # scene name
236 |             global_origin = inputs['vol_origin'][i]  # origin of global volume
237 |             origin = inputs['vol_origin_partial'][i]  # origin of partial volume
238 |
239 |             if scene != self.scene_name[scale] and self.scene_name[scale] is not None and self.direct_substitute:
240 |                 outputs = self.save_mesh(scale, outputs, self.scene_name[scale])
241 |
242 |             # if this fragment is from a new scene, reinitialize the backend map
243 |             if self.scene_name[scale] is None or scene != self.scene_name[scale]:
244 |                 self.scene_name[scale] = scene
245 |                 self.reset(scale)
246 |                 self.global_origin[scale] = global_origin
247 |
248 |             # each level has its corresponding voxel size
249 |             voxel_size = self.cfg.VOXEL_SIZE * interval
250 |
251 |             # relative origin in global volume
252 |             relative_origin = (origin - self.global_origin[scale]) / voxel_size
253 |             relative_origin = relative_origin.cuda().long()
254 |
255 |             batch_ind = torch.nonzero(coords[:, 0] == i).squeeze(1)
256 |             if len(batch_ind) == 0:
257 |                 continue
258 |             coords_b = coords[batch_ind, 1:].long() // interval
259 |             values = values_in[batch_ind]
260 |
261 |             if 'occ_list' in inputs.keys():
262 |                 # get partial gt
263 |                 occ_target = inputs['occ_list'][self.cfg.N_LAYER - scale - 1][i]
264 |                 tsdf_target = inputs['tsdf_list'][self.cfg.N_LAYER - scale - 1][i][occ_target]
265 |                 coords_target = torch.nonzero(occ_target)
266 |             else:
267 |                 coords_target = tsdf_target = None
268 |
269 |             # convert to dense: 1. convert sparse feature to dense feature; 2. combine current feature coordinates and
270 |             # previous feature coordinates within FBV from our backend map to get new feature coordinates (updated_coords)
271 |             updated_coords, current_volume, global_volume, target_volume, valid, valid_target = self.convert2dense(
272 |                 coords_b,
273 |                 values,
274 |                 coords_target,
275 |                 tsdf_target,
276 |                 relative_origin,
277 |                 scale)
278 |
279 |             # dense to sparse: get features using new feature coordinates (updated_coords)
280 |             values = current_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
281 |             global_values = global_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
282 |             # get fused gt
283 |             if target_volume is not None:
284 |                 tsdf_target = target_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
285 |                 occ_target = tsdf_target.abs() < 1
286 |             else:
287 |                 tsdf_target = occ_target = None
288 |
289 |             if not self.direct_substitute:
290 |                 # convert to aligned camera coordinate
291 |                 r_coords = updated_coords.detach().clone().float()
292 |                 r_coords = r_coords.permute(1, 0).contiguous().float() * voxel_size + origin.unsqueeze(-1).float()
293 |                 r_coords = torch.cat((r_coords, torch.ones_like(r_coords[:1])), dim=0)
294 |                 r_coords = inputs['world_to_aligned_camera'][i, :3, :] @ r_coords
295 |                 r_coords = torch.cat([r_coords, torch.zeros(1, r_coords.shape[-1]).to(r_coords.device)])
296 |                 r_coords = r_coords.permute(1, 0).contiguous()
297 |
298 |                 h = PointTensor(global_values, r_coords)
299 |                 x = PointTensor(values, r_coords)
300 |
301 |                 values = self.fusion_nets[scale](h, x)
302 |
303 |             # feed back to global volume (direct substitute)
304 |             self.update_map(values, updated_coords, target_volume, valid, valid_target, relative_origin, scale)
305 |
306 |             if updated_coords_all is None:
307 |                 updated_coords_all = torch.cat([torch.ones_like(updated_coords[:, :1]) * i, updated_coords * interval],
308 |                                                dim=1)
309 |                 values_all = values
310 |                 tsdf_target_all = tsdf_target
311 |                 occ_target_all = occ_target
312 |             else:
313 |                 updated_coords = torch.cat([torch.ones_like(updated_coords[:, :1]) * i, updated_coords * interval],
314 |                                            dim=1)
315 |                 updated_coords_all = torch.cat([updated_coords_all, updated_coords])
316 |                 values_all = torch.cat([values_all, values])
317 |                 if tsdf_target_all is not None:
318 |                     tsdf_target_all = torch.cat([tsdf_target_all, tsdf_target])
319 |                     occ_target_all = torch.cat([occ_target_all, occ_target])
320 |
321 |         if self.direct_substitute and save_mesh:
322 |             outputs = self.save_mesh(scale, outputs, self.scene_name[scale])
323 |
324 |         if self.direct_substitute:
325 |             return outputs
326 |         else:
327 |             return updated_coords_all, values_all, tsdf_target_all, occ_target_all
328 |
--------------------------------------------------------------------------------
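
A minimal CPU-only sketch (not a file from this repository) of the sparse-to-dense scatter and dense-to-sparse gather pattern that convert2dense and forward build on. The sparse_to_dense_channel helper is copied from gru_fusion.py above; the toy FBV size, channel count and coordinates are invented for illustration, and the real module runs the same pattern on CUDA with torchsparse PointTensors.

import torch

def sparse_to_dense_channel(locs, values, dim, c, default_val, device):
    # same logic as the helper in gru_fusion.py: scatter (N, c) features into a dense (X, Y, Z, c) grid
    dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device)
    if locs.shape[0] > 0:
        dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense

# toy fragment bounding volume (FBV) of 8^3 voxels with 4-channel features
dim_list, ch = [8, 8, 8], 4
device = torch.device('cpu')

# sparse voxels carried over from the global hidden state, already shifted by relative_origin
global_coords = torch.tensor([[1, 1, 1], [2, 3, 4], [9, 0, 0]])  # the last one lies outside the FBV
global_values = torch.randn(3, ch)
valid = ((global_coords < torch.tensor(dim_list)) & (global_coords >= 0)).all(dim=-1)

# sparse voxels produced for the current fragment
current_coords = torch.tensor([[2, 3, 4], [5, 5, 5]])
current_values = torch.randn(2, ch)

# scatter both sets into dense grids (feat_init = 0 for feature fusion)
global_volume = sparse_to_dense_channel(global_coords[valid], global_values[valid], dim_list, ch, 0, device)
current_volume = sparse_to_dense_channel(current_coords, current_values, dim_list, ch, 0, device)

# union of occupied coordinates, then gather features back at those coordinates
updated_coords = torch.nonzero((global_volume != 0).any(-1) | (current_volume != 0).any(-1))
values = current_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
global_values_new = global_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
print(updated_coords.shape, values.shape, global_values_new.shape)  # 3 occupied voxels in this toy case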
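
In the same spirit, a small sketch of the substitution step performed by update_map: entries of the global sparse volume that fall outside the current FBV are kept, while everything inside the FBV is replaced by the freshly fused coordinates (shifted back by relative_origin) and features. All tensor names and values here are invented for illustration.

import torch

# global sparse volume: coordinates (M, 3) and features (M, C), as stored in a PointTensor's C and F
global_C = torch.tensor([[1, 1, 1], [2, 3, 4], [9, 0, 0]])
global_F = torch.randn(3, 4)

# mask of global entries inside the current FBV (computed as in convert2dense)
valid = torch.tensor([True, True, False])

# fused output for the current fragment: coordinates local to the FBV and their features
updated_coords = torch.tensor([[1, 1, 1], [2, 3, 4], [5, 5, 5]])
fused_values = torch.randn(3, 4)
relative_origin = torch.tensor([0, 0, 0])

# keep what lies outside the FBV, substitute everything inside it
new_F = torch.cat([global_F[valid == False], fused_values])
new_C = torch.cat([global_C[valid == False], updated_coords + relative_origin])
print(new_C.shape, new_F.shape)  # (4, 3) and (4, 4): one kept voxel plus three substituted ones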