├── README.md ├── benchmarking └── benchmark_synthetic_nerf.sh ├── config ├── img │ └── config.json ├── nerf │ └── config.json └── sdf │ └── config.json ├── datasets ├── __init__.py ├── img │ └── imager.py ├── nerf │ ├── base.py │ ├── colmap.py │ ├── colmap_utils.py │ ├── color_utils.py │ ├── depth_utils.py │ ├── nerf.py │ ├── nerfpp.py │ ├── nsvf.py │ ├── ray_utils.py │ └── rtmv.py └── sdf │ └── sampler.py ├── docs └── figures │ ├── 2d_fitting.png │ ├── 3d_fitting.png │ ├── nvs.png │ └── teaser.png ├── models ├── __init__.py ├── csrc │ ├── binding.cpp │ ├── include │ │ ├── helper_math.h │ │ └── utils.h │ ├── intersection.cu │ ├── losses.cu │ ├── raymarching.cu │ ├── setup.py │ └── volumerendering.cu ├── loss │ └── nerf │ │ ├── __init__.py │ │ └── losses.py └── networks │ ├── FFB_encoder.py │ ├── Sine.py │ ├── __init__.py │ ├── img │ ├── NFFB_2d.py │ └── __init__.py │ ├── nerf │ ├── NFFB_nerf.py │ ├── __init__.py │ ├── custom_functions.py │ └── rendering.py │ └── sdf │ ├── NFFB_3d.py │ └── __init__.py ├── requirements.txt ├── scripts ├── img │ ├── common.py │ ├── opt.py │ └── utils.py ├── nvs │ ├── opt.py │ └── prepare_rtmv.py └── sdf │ ├── opt.py │ └── utils.py ├── train_img.py ├── train_nerf.py ├── train_sdf.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Neural Fourier Filter Bank 2 | 3 | 4 | This repository contains the code (in [PyTorch Lightning](https://www.pytorchlightning.ai/index.html)) for the paper: 5 | 6 | [__Neural Fourier Filter Bank__](https://arxiv.org/abs/2212.01735) 7 |
8 | [Zhijie Wu](https://zhijiew94.github.io/), [Yuhe Jin](https://scholar.google.ca/citations?user=oAYi1YQAAAAJ&hl=en), [Kwang Moo Yi](https://www.cs.ubc.ca/~kmyi/) 9 |
10 | CVPR 2023
11 | 
12 | 
13 | ## Introduction
14 | 
15 | In this project, we propose to learn a neural field by decomposing the signal both spatially and frequency-wise.
16 | We follow the grid-based paradigm for spatial decomposition but, unlike existing work, encourage specific frequencies to be stored in each grid via Fourier feature encodings.
17 | We then apply a multi-layer perceptron with sine activations, feeding these Fourier-encoded features in at the appropriate layers so that higher-frequency components are accumulated on top of lower-frequency ones, and sum the per-layer outputs to form the final result (an illustrative sketch is included near the end of this README).
18 | We evaluate on the tasks of 2D image fitting, 3D shape reconstruction, and neural radiance fields.
19 | All results are tested on an Nvidia RTX 3090.
20 | 
21 | If you have any questions, please feel free to contact Zhijie Wu (wzj.micker@gmail.com).
22 | 
23 | ![teaser](./docs/figures/teaser.png)
24 | 
25 | 
26 | 
27 | ## Key Requirements
28 | - Python 3.8
29 | - CUDA 11.6
30 | - [PyTorch 1.12.0](https://pytorch.org/)
31 | - [PyTorch Lightning](https://www.pytorchlightning.ai/index.html)
32 | - [torch-scatter](https://github.com/rusty1s/pytorch_scatter#installation)
33 | - [apex](https://github.com/NVIDIA/apex#linux)
34 | - [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn#pytorch-extension)
35 | - Install the remaining requirements with `pip install -r requirements.txt`
36 | 
37 | > Note: Our current implementation is heavily based on the [ngp-pl](https://github.com/kwea123/ngp_pl) repo.
38 | > For further details, please also refer to their codebase.
39 | 
40 | 
41 | 
42 | 
43 | ## Novel View Synthesis
44 | ![nvs](./docs/figures/nvs.png)
45 | 
46 | A quickstart:
47 | ```bash
48 | python train_nerf.py --root_dir <path/to/Synthetic_NeRF/Lego> --exp_name Lego --num_epochs 30 --lr 2e-2 --eval_lpips --no_save_test
49 | ```
50 | This trains the Lego scene for 30k steps; `--no_save_test` disables saving the synthesized test images.
51 | 
52 | More options can be found in `scripts/nvs/opt.py` and in `config/nerf/config.json`.
53 | 
54 | To compute the metrics for the eight Blender scenes, please run the script `benchmark_synthetic_nerf.sh` under the `benchmarking` folder.
55 | 
56 | 
57 | ## 2D Image Fitting
58 | ![2d_fitting](./docs/figures/2d_fitting.png)
59 | 
60 | ```bash
61 | python train_img.py --config <path/to/config.json> --input_path <path/to/image>
62 | ```
63 | 
64 | By default, the model is trained for 50k iterations, but in our experience it already reaches comparable results after roughly 20k iterations.
65 | 
66 | 
67 | ## 3D Shape Fitting
68 | ![3d_fitting](./docs/figures/3d_fitting.png)
69 | 
70 | ```bash
71 | python train_sdf.py --config <path/to/config.json> --input_path <path/to/mesh>
72 | ```
73 | 
74 | As in **2D Image Fitting**, the model is trained for 50k iterations to recover fine geometric details. Using `size=100` instead of `size=1000` for the train dataset in `train_sdf.py` slightly reduces output quality but significantly accelerates training.
75 | 
76 | ## Citation and License
77 | 
78 | ```
79 | @InProceedings{Wu_2023_CVPR,
80 |     author = {Wu, Zhijie and Jin, Yuhe and Yi, Kwang Moo},
81 |     title = {Neural Fourier Filter Bank},
82 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
83 |     month = {June},
84 |     year = {2023},
85 |     pages = {14153-14163}
86 | }
87 | ```
88 | 
89 | Our codebase is under the MIT License.
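## Architecture Sketch

For intuition, the frequency-wise composition described in the Introduction can be sketched as a toy model: per-level features are injected into successive sine-activated layers and the per-level outputs are summed. This is only an illustrative sketch, not the repository's implementation (see `models/networks/FFB_encoder.py`, `models/networks/Sine.py` and the task-specific `NFFB_*` networks for the real model); the grid and Fourier-feature encoders are replaced here by plain linear layers, and all names are hypothetical.

```python
import torch
import torch.nn as nn

class ToyNFFB(nn.Module):
    """Toy sketch: inject level-i features at layer i of a sine-activated MLP
    and sum the per-level outputs, so higher-frequency components are
    accumulated on top of lower-frequency ones."""
    def __init__(self, in_dim=2, n_levels=4, hidden=64, out_dim=3, w0=30.0):
        super().__init__()
        self.w0 = w0
        # stand-ins for the multi-resolution grid + Fourier feature encodings
        self.encoders = nn.ModuleList([nn.Linear(in_dim, hidden) for _ in range(n_levels)])
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_levels)])
        self.heads = nn.ModuleList([nn.Linear(hidden, out_dim) for _ in range(n_levels)])

    def forward(self, x):
        h = x.new_zeros(x.shape[0], self.layers[0].in_features)
        out = 0.0
        for enc, layer, head in zip(self.encoders, self.layers, self.heads):
            h = torch.sin(self.w0 * layer(h + enc(x)))  # add level-i features, then sine layer
            out = out + head(h)                         # accumulate this level's contribution
        return out

model = ToyNFFB()
rgb = model(torch.rand(1024, 2))  # (1024, 3)
```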
90 | 91 | 92 | ## TODO 93 | 94 | - [ ] Finish the CUDA version 95 | 96 | -------------------------------------------------------------------------------- /benchmarking/benchmark_synthetic_nerf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=NeRF/Synthetic_NeRF 4 | 5 | python train_nerf.py \ 6 | --root_dir $ROOT_DIR/Chair \ 7 | --exp_name Chair --no_save_test \ 8 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 9 | 10 | python train_nerf.py \ 11 | --root_dir $ROOT_DIR/Drums \ 12 | --exp_name Drums --no_save_test \ 13 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 14 | 15 | python train_nerf.py \ 16 | --root_dir $ROOT_DIR/Ficus \ 17 | --exp_name Ficus --no_save_test \ 18 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 19 | 20 | python train_nerf.py \ 21 | --root_dir $ROOT_DIR/Hotdog \ 22 | --exp_name Hotdog --no_save_test \ 23 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 24 | 25 | python train_nerf.py \ 26 | --root_dir $ROOT_DIR/Lego \ 27 | --exp_name Lego --no_save_test \ 28 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 29 | 30 | python train_nerf.py \ 31 | --root_dir $ROOT_DIR/Materials \ 32 | --exp_name Materials --no_save_test \ 33 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 34 | 35 | python train_nerf.py \ 36 | --root_dir $ROOT_DIR/Mic \ 37 | --exp_name Mic --no_save_test \ 38 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 39 | 40 | python train_nerf.py \ 41 | --root_dir $ROOT_DIR/Ship \ 42 | --exp_name Ship --no_save_test \ 43 | --num_epochs 30 --batch_size 4096 --lr 2e-2 --eval_lpips 44 | -------------------------------------------------------------------------------- /config/img/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "encoding": { 4 | "feat_dim": 2, 5 | "base_resolution": 96, 6 | "per_level_scale": 1.5, 7 | "base_sigma": 5.0, 8 | "exp_sigma": 2.0, 9 | "grid_embedding_std": 0.01 10 | }, 11 | "SIREN": { 12 | "dims" : [128, 128, 128, 128, 128, 128, 128, 128], 13 | "w0": 100.0, 14 | "w1": 100.0, 15 | "size_factor": 1 16 | }, 17 | "Backbone": { 18 | "dims": [64, 64] 19 | } 20 | }, 21 | "training": { 22 | "LR_scheduler" : [ 23 | { 24 | "type" : "Step", 25 | "initial" : 0.0001, 26 | "interval" : 5, 27 | "factor" : 0.5 28 | }] 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /config/nerf/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "encoding": { 4 | "feat_dim": 2, 5 | "base_resolution": 64, 6 | "per_level_scale": 2.0, 7 | "base_sigma": 8.0, 8 | "exp_sigma": 1.5, 9 | "grid_embedding_std": 0.001 10 | }, 11 | "SIREN": { 12 | "dims" : [128, 128, 128, 128, 128], 13 | "w0": 15.0, 14 | "w1": 25.0, 15 | "size_factor": 2 16 | } 17 | }, 18 | "training": { 19 | "LearningRateSchedule" : [ 20 | { 21 | "type" : "Step", 22 | "initial" : 0.0001, 23 | "interval" : 5000, 24 | "factor" : 0.5 25 | }, 26 | { 27 | "type" : "Step", 28 | "initial" : 0.0001, 29 | "interval" : 5000, 30 | "factor" : 0.5 31 | }, 32 | { 33 | "type" : "Step", 34 | "initial" : 0.001, 35 | "interval" : 5000, 36 | "factor" : 0.5 37 | }, 38 | { 39 | "type" : "Step", 40 | "initial" : 0.005, 41 | "interval" : 5000, 42 | "factor" : 0.5 43 | }], 44 | "lr_threshold": 1e-5 45 | } 46 | } 47 | -------------------------------------------------------------------------------- 
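Aside: the `LR_scheduler` / `LearningRateSchedule` entries in these config files describe step-decay schedules. Below is a minimal sketch of how such an entry is conventionally evaluated; the exact way the training scripts consume these fields may differ, and `lr_threshold` (from the NeRF config) is assumed to act as a lower bound.

```python
def step_lr(step, initial, interval, factor, lr_threshold=0.0):
    """Step decay: scale the learning rate by `factor` every `interval` steps,
    never dropping below `lr_threshold` (assumed semantics)."""
    return max(initial * factor ** (step // interval), lr_threshold)

# e.g. the last NeRF schedule entry: initial=0.005, interval=5000, factor=0.5
step_lr(12000, initial=0.005, interval=5000, factor=0.5, lr_threshold=1e-5)  # -> 0.00125
```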
/config/sdf/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "encoding": { 4 | "feat_dim": 2, 5 | "base_resolution": 8, 6 | "per_level_scale": 1.3, 7 | "base_sigma": 5.0, 8 | "exp_sigma": 1.2, 9 | "grid_embedding_std": 0.01 10 | }, 11 | "SIREN": { 12 | "dims" : [256, 256, 256, 256, 256, 256], 13 | "w0": 45.0, 14 | "w1": 45.0, 15 | "size_factor": 1 16 | } 17 | }, 18 | "training": { 19 | "LR_scheduler" : [ 20 | { 21 | "type" : "Step", 22 | "initial" : 0.0001, 23 | "interval" : 5, 24 | "factor" : 0.5 25 | }] 26 | } 27 | } -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.nerf.nerf import NeRFDataset 2 | from datasets.nerf.nsvf import NSVFDataset 3 | from datasets.nerf.colmap import ColmapDataset 4 | from datasets.nerf.nerfpp import NeRFPPDataset 5 | from datasets.nerf.rtmv import RTMVDataset 6 | 7 | 8 | dataset_dict = {'nerf': NeRFDataset, 9 | 'nsvf': NSVFDataset, 10 | 'colmap': ColmapDataset, 11 | 'nerfpp': NeRFPPDataset, 12 | 'rtmv': RTMVDataset} -------------------------------------------------------------------------------- /datasets/img/imager.py: -------------------------------------------------------------------------------- 1 | """ 2 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn) 3 | """ 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | import math 9 | 10 | 11 | class ImageDataset(Dataset): 12 | def __init__(self, data, size=100, num_samples=2**18, split='train'): 13 | super().__init__() 14 | 15 | # assign image 16 | self.data = data 17 | 18 | self.img_wh = (self.data.shape[0], self.data.shape[1]) 19 | self.img_shape = torch.tensor([self.img_wh[0], self.img_wh[1]]).float() 20 | 21 | print(f"[INFO] image: {self.data.shape}") 22 | 23 | self.num_samples = num_samples 24 | 25 | self.split = split 26 | self.size = size 27 | 28 | if self.split.startswith("test"): 29 | half_dx = 0.5 / self.img_wh[0] 30 | half_dy = 0.5 / self.img_wh[1] 31 | xs = torch.linspace(half_dx, 1-half_dx, self.img_wh[0]) 32 | ys = torch.linspace(half_dy, 1-half_dy, self.img_wh[1]) 33 | xv, yv = torch.meshgrid([xs, ys], indexing="ij") 34 | 35 | xy = torch.stack((xv.flatten(), yv.flatten())).t() 36 | 37 | xy_max_num = math.ceil(xy.shape[0] / 1024.0) 38 | padding_delta = xy_max_num * 1024 - xy.shape[0] 39 | zeros_padding = torch.zeros((padding_delta, 2)) 40 | self.xs = torch.cat([xy, zeros_padding], dim=0) 41 | 42 | 43 | def __len__(self): 44 | return self.size 45 | 46 | 47 | def __getitem__(self, _): 48 | if self.split.startswith('train'): 49 | xs = torch.rand([self.num_samples, 2], dtype=torch.float32) 50 | 51 | assert torch.sum(xs < 0) == 0, "The coordinates for input image should be non-negative." 
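            # The block below bilinearly interpolates the ground-truth RGB at each randomly
            # sampled continuous coordinate: the coordinate is scaled to pixel units and split
            # into an integer index plus a fractional lerp weight, and the four neighbouring
            # pixels (x0/x1, y0/y1) are blended with those weights.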
52 | 53 | with torch.no_grad(): 54 | scaled_xs = xs * self.img_shape 55 | indices = scaled_xs.long() 56 | lerp_weights = scaled_xs - indices.float() 57 | 58 | x0 = indices[:, 0].clamp(min=0, max=self.img_wh[0]-1).long() 59 | y0 = indices[:, 1].clamp(min=0, max=self.img_wh[1]-1).long() 60 | x1 = (x0 + 1).clamp(min=0, max=self.img_wh[0]-1).long() 61 | y1 = (y0 + 1).clamp(min=0, max=self.img_wh[1]-1).long() 62 | 63 | rgbs = self.data[x0, y0] * (1.0 - lerp_weights[:, 0:1]) * (1.0 - lerp_weights[:, 1:2]) + \ 64 | self.data[x0, y1] * (1.0 - lerp_weights[:, 0:1]) * lerp_weights[:, 1:2] + \ 65 | self.data[x1, y0] * lerp_weights[:, 0:1] * (1.0 - lerp_weights[:, 1:2]) + \ 66 | self.data[x1, y1] * lerp_weights[:, 0:1] * lerp_weights[:, 1:2] 67 | else: 68 | xs = self.xs 69 | rgbs = self.data 70 | 71 | results = { 72 | 'points': xs, 73 | 'rgbs': rgbs, 74 | } 75 | 76 | return results -------------------------------------------------------------------------------- /datasets/nerf/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(Dataset): 6 | """ 7 | Define length and sampling method 8 | """ 9 | def __init__(self, root_dir, split='train', downsample=1.0): 10 | self.root_dir = root_dir 11 | self.split = split 12 | self.downsample = downsample 13 | 14 | def read_intrinsics(self): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | if self.split.startswith('train'): 19 | return 1000 20 | return len(self.poses) 21 | 22 | def __getitem__(self, idx): 23 | if self.split.startswith('train'): 24 | # training pose is retrieved in train_nerf.py 25 | if self.ray_sampling_strategy == 'all_images': # randomly select images 26 | img_idxs = np.random.choice(len(self.poses), self.batch_size) 27 | elif self.ray_sampling_strategy == 'same_image': # randomly select ONE image 28 | img_idxs = np.random.choice(len(self.poses), 1)[0] 29 | # randomly select pixels 30 | pix_idxs = np.random.choice(self.img_wh[0]*self.img_wh[1], self.batch_size) 31 | rays = self.rays[img_idxs, pix_idxs] 32 | sample = {'img_idxs': img_idxs, 'pix_idxs': pix_idxs, 33 | 'rgb': rays[:, :3]} 34 | if self.rays.shape[-1] == 4: # HDR-NeRF data 35 | sample['exposure'] = rays[:, 3:] 36 | else: 37 | sample = {'pose': self.poses[idx], 'img_idxs': idx} 38 | if len(self.rays) > 0: # if ground truth available 39 | rays = self.rays[idx] 40 | sample['rgb'] = rays[:, :3] 41 | if rays.shape[1] == 4: # HDR-NeRF data 42 | sample['exposure'] = rays[0, 3] # same exposure for all rays 43 | 44 | return sample -------------------------------------------------------------------------------- /datasets/nerf/colmap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import glob 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import * 8 | from .color_utils import read_image 9 | from .colmap_utils import \ 10 | read_cameras_binary, read_images_binary, read_points3d_binary 11 | 12 | from .base import BaseDataset 13 | 14 | 15 | class ColmapDataset(BaseDataset): 16 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 17 | super().__init__(root_dir, split, downsample) 18 | 19 | self.read_intrinsics() 20 | 21 | if kwargs.get('read_meta', True): 22 | self.read_meta(split, **kwargs) 23 | 24 | def read_intrinsics(self): 25 | # Step 1: read and scale intrinsics (same for all images) 26 | camdata = 
read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin')) 27 | h = int(camdata[1].height*self.downsample) 28 | w = int(camdata[1].width*self.downsample) 29 | self.img_wh = (w, h) 30 | 31 | if camdata[1].model == 'SIMPLE_RADIAL': 32 | fx = fy = camdata[1].params[0]*self.downsample 33 | cx = camdata[1].params[1]*self.downsample 34 | cy = camdata[1].params[2]*self.downsample 35 | elif camdata[1].model in ['PINHOLE', 'OPENCV']: 36 | fx = camdata[1].params[0]*self.downsample 37 | fy = camdata[1].params[1]*self.downsample 38 | cx = camdata[1].params[2]*self.downsample 39 | cy = camdata[1].params[3]*self.downsample 40 | else: 41 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") 42 | self.K = torch.FloatTensor([[fx, 0, cx], 43 | [0, fy, cy], 44 | [0, 0, 1]]) 45 | self.directions = get_ray_directions(h, w, self.K) 46 | 47 | def read_meta(self, split, **kwargs): 48 | # Step 2: correct poses 49 | # read extrinsics (of successfully reconstructed images) 50 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin')) 51 | img_names = [imdata[k].name for k in imdata] 52 | perm = np.argsort(img_names) 53 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data 54 | folder = f'images_{int(1/self.downsample)}' 55 | else: 56 | folder = 'images' 57 | # read successfully reconstructed images and ignore others 58 | img_paths = [os.path.join(self.root_dir, folder, name) 59 | for name in sorted(img_names)] 60 | w2c_mats = [] 61 | bottom = np.array([[0, 0, 0, 1.]]) 62 | for k in imdata: 63 | im = imdata[k] 64 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1) 65 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 66 | w2c_mats = np.stack(w2c_mats, 0) 67 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices 68 | 69 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin')) 70 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3) 71 | 72 | self.poses, self.pts3d = center_poses(poses, pts3d) 73 | 74 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min() 75 | self.poses[..., 3] /= scale 76 | self.pts3d /= scale 77 | 78 | self.rays = [] 79 | if split == 'test_traj': # use precomputed test poses 80 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean()) 81 | self.poses = torch.FloatTensor(self.poses) 82 | return 83 | 84 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data 85 | if 'syndata' in self.root_dir: # synthetic 86 | # first 17 are test, last 18 are train 87 | self.unit_exposure_rgb = 0.73 88 | if split=='train': 89 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 90 | f'train/*[024].png'))) 91 | self.poses = np.repeat(self.poses[-18:], 3, 0) 92 | elif split=='test': 93 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 94 | f'test/*[13].png'))) 95 | self.poses = np.repeat(self.poses[:17], 2, 0) 96 | else: 97 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 98 | else: # real 99 | self.unit_exposure_rgb = 0.5 100 | # even numbers are train, odd numbers are test 101 | if split=='train': 102 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 103 | f'input_images/*0.jpg')))[::2] 104 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 105 | f'input_images/*2.jpg')))[::2] 106 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 107 | f'input_images/*4.jpg')))[::2] 108 | self.poses = np.tile(self.poses[::2], (3, 1, 1)) 109 | elif split=='test': 110 | img_paths = 
sorted(glob.glob(os.path.join(self.root_dir, 111 | f'input_images/*1.jpg')))[1::2] 112 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 113 | f'input_images/*3.jpg')))[1::2] 114 | self.poses = np.tile(self.poses[1::2], (2, 1, 1)) 115 | else: 116 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 117 | else: 118 | # use every 8th image as test set 119 | if split=='train': 120 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0] 121 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0]) 122 | elif split=='test': 123 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0] 124 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0]) 125 | 126 | print(f'Loading {len(img_paths)} {split} images ...') 127 | for img_path in tqdm(img_paths): 128 | buf = [] # buffer for ray attributes: rgb, etc 129 | 130 | img = read_image(img_path, self.img_wh, blend_a=False) 131 | img = torch.FloatTensor(img) 132 | buf += [img] 133 | 134 | if 'HDR-NeRF' in self.root_dir: # get exposure 135 | folder = self.root_dir.split('/') 136 | scene = folder[-1] if folder[-1] != '' else folder[-2] 137 | if scene in ['bathroom', 'bear', 'chair', 'desk']: 138 | e_dict = {e: 1/8*4**e for e in range(5)} 139 | elif scene in ['diningroom', 'dog']: 140 | e_dict = {e: 1/16*4**e for e in range(5)} 141 | elif scene in ['sofa']: 142 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16} 143 | elif scene in ['sponza']: 144 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32} 145 | elif scene in ['box']: 146 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05} 147 | elif scene in ['computer']: 148 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60} 149 | elif scene in ['flower']: 150 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45} 151 | elif scene in ['luckycat']: 152 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125} 153 | e = int(img_path.split('.')[0][-1]) 154 | buf += [e_dict[e]*torch.ones_like(img[:, :1])] 155 | 156 | self.rays += [torch.cat(buf, 1)] 157 | 158 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 159 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) -------------------------------------------------------------------------------- /datasets/nerf/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec -------------------------------------------------------------------------------- /datasets/nerf/color_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from einops import rearrange 3 | import imageio 4 | import numpy as np 5 | 6 | 7 | def srgb_to_linear(img): 8 | limit = 0.04045 9 | return np.where(img>limit, ((img+0.055)/1.055)**2.4, img/12.92) 10 | 11 | 12 | def linear_to_srgb(img): 13 | limit = 0.0031308 14 | img = np.where(img>limit, 1.055*img**(1/2.4)-0.055, 12.92*img) 15 | img[img>1] = 1 # "clamp" tonemapper 16 | return img 17 | 18 | 19 | def read_image(img_path, img_wh, blend_a=True): 20 | img = imageio.imread(img_path).astype(np.float32)/255.0 21 | # img[..., :3] = srgb_to_linear(img[..., :3]) 22 | if img.shape[2] == 4: # blend A to RGB 23 | if blend_a: 24 | img = img[..., :3]*img[..., -1:]+(1-img[..., -1:]) 25 | else: 26 | img = img[..., :3]*img[..., -1:] 27 | 28 | img = cv2.resize(img, img_wh) 29 | img = rearrange(img, 'h w c -> (h w) c') 30 | 31 | return img -------------------------------------------------------------------------------- /datasets/nerf/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | 5 | def read_pfm(path): 6 | """Read pfm file. 
7 | 8 | Args: 9 | path (str): path to file 10 | 11 | Returns: 12 | tuple: (data, scale) 13 | """ 14 | with open(path, "rb") as file: 15 | 16 | color = None 17 | width = None 18 | height = None 19 | scale = None 20 | endian = None 21 | 22 | header = file.readline().rstrip() 23 | if header.decode("ascii") == "PF": 24 | color = True 25 | elif header.decode("ascii") == "Pf": 26 | color = False 27 | else: 28 | raise Exception("Not a PFM file: " + path) 29 | 30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 31 | if dim_match: 32 | width, height = list(map(int, dim_match.groups())) 33 | else: 34 | raise Exception("Malformed PFM header.") 35 | 36 | scale = float(file.readline().decode("ascii").rstrip()) 37 | if scale < 0: 38 | # little-endian 39 | endian = "<" 40 | scale = -scale 41 | else: 42 | # big-endian 43 | endian = ">" 44 | 45 | data = np.fromfile(file, endian + "f") 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | 51 | return data, scale -------------------------------------------------------------------------------- /datasets/nerf/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NeRFDataset(BaseDataset): 14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | self.read_meta(split) 21 | 22 | def read_intrinsics(self): 23 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 24 | meta = json.load(f) 25 | 26 | w = h = int(800*self.downsample) 27 | fx = fy = 0.5*800/np.tan(0.5*meta['camera_angle_x'])*self.downsample 28 | 29 | K = np.float32([[fx, 0, w/2], 30 | [0, fy, h/2], 31 | [0, 0, 1]]) 32 | 33 | self.K = torch.FloatTensor(K) 34 | self.directions = get_ray_directions(h, w, self.K) 35 | self.img_wh = (w, h) 36 | 37 | def read_meta(self, split): 38 | self.rays = [] 39 | self.poses = [] 40 | 41 | if split == 'trainval': 42 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 43 | frames = json.load(f)["frames"] 44 | with open(os.path.join(self.root_dir, "transforms_val.json"), 'r') as f: 45 | frames += json.load(f)["frames"] 46 | else: 47 | with open(os.path.join(self.root_dir, f"transforms_{split}.json"), 'r') as f: 48 | frames = json.load(f)["frames"] 49 | 50 | print(f'Loading {len(frames)} {split} images ...') 51 | for frame in tqdm(frames): 52 | c2w = np.array(frame['transform_matrix'])[:3, :4] 53 | 54 | # determine scale 55 | if 'Jrender_Dataset' in self.root_dir: 56 | c2w[:, :2] *= -1 # [left up front] to [right down front] 57 | folder = self.root_dir.split('/') 58 | scene = folder[-1] if folder[-1] != '' else folder[-2] 59 | if scene=='Easyship': 60 | pose_radius_scale = 1.2 61 | elif scene=='Scar': 62 | pose_radius_scale = 1.8 63 | elif scene=='Coffee': 64 | pose_radius_scale = 2.5 65 | elif scene=='Car': 66 | pose_radius_scale = 0.8 67 | else: 68 | pose_radius_scale = 1.5 69 | else: 70 | c2w[:, 1:3] *= -1 # [right up back] to [right down front] 71 | pose_radius_scale = 1.5 72 | c2w[:, 3] /= np.linalg.norm(c2w[:, 3])/pose_radius_scale 73 | 74 | # add shift 75 | if 'Jrender_Dataset' in 
self.root_dir: 76 | if scene=='Coffee': 77 | c2w[1, 3] -= 0.4465 78 | elif scene=='Car': 79 | c2w[0, 3] -= 0.7 80 | self.poses += [c2w] 81 | 82 | try: 83 | img_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 84 | img = read_image(img_path, self.img_wh) 85 | self.rays += [img] 86 | except: pass 87 | 88 | if len(self.rays)>0: 89 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 90 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 91 | -------------------------------------------------------------------------------- /datasets/nerf/nerfpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from .ray_utils import get_ray_directions 9 | from .color_utils import read_image 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class NeRFPPDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0], 25 | dtype=np.float32).reshape(4, 4)[:3, :3] 26 | K[:2] *= self.downsample 27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size 28 | w, h = int(w*self.downsample), int(h*self.downsample) 29 | self.K = torch.FloatTensor(K) 30 | self.directions = get_ray_directions(h, w, self.K) 31 | self.img_wh = (w, h) 32 | 33 | def read_meta(self, split): 34 | self.rays = [] 35 | self.poses = [] 36 | 37 | if split == 'test_traj': 38 | poses_path = \ 39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt'))) 40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path] 41 | else: 42 | if split=='trainval': 43 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\ 44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*'))) 45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\ 46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt'))) 47 | else: 48 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*'))) 49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt'))) 50 | 51 | print(f'Loading {len(img_paths)} {split} images ...') 52 | for img_path, pose in tqdm(zip(img_paths, poses)): 53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]] 54 | 55 | img = read_image(img_path, self.img_wh) 56 | self.rays += [img] 57 | 58 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 
59 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 60 | -------------------------------------------------------------------------------- /datasets/nerf/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NSVFDataset(BaseDataset): 14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | xyz_min, xyz_max = \ 21 | np.loadtxt(os.path.join(root_dir, 'bbox.txt'))[:6].reshape(2, 3) 22 | self.shift = (xyz_max + xyz_min) / 2 23 | self.scale = (xyz_max - xyz_min).max() / 2 * 1.05 # enlarge a little 24 | 25 | # hard-code fix the bound error for some scenes... 26 | if 'Mic' in self.root_dir: self.scale *= 1.2 27 | elif 'Lego' in self.root_dir: self.scale *= 1.1 28 | 29 | self.read_meta(split) 30 | 31 | def read_intrinsics(self): 32 | if 'Synthetic' in self.root_dir or 'Ignatius' in self.root_dir: 33 | with open(os.path.join(self.root_dir, 'intrinsics.txt')) as f: 34 | fx = fy = float(f.readline().split()[0]) * self.downsample 35 | if 'Synthetic' in self.root_dir: 36 | w = h = int(800*self.downsample) 37 | else: 38 | w, h = int(1920*self.downsample), int(1080*self.downsample) 39 | 40 | K = np.float32([[fx, 0, w/2], 41 | [0, fy, h/2], 42 | [0, 0, 1]]) 43 | else: 44 | K = np.loadtxt(os.path.join(self.root_dir, 'intrinsics.txt'), 45 | dtype=np.float32)[:3, :3] 46 | if 'BlendedMVS' in self.root_dir: 47 | w, h = int(768*self.downsample), int(576*self.downsample) 48 | elif 'Tanks' in self.root_dir: 49 | w, h = int(1920*self.downsample), int(1080*self.downsample) 50 | K[:2] *= self.downsample 51 | 52 | self.K = torch.FloatTensor(K) 53 | self.directions = get_ray_directions(h, w, self.K) 54 | self.img_wh = (w, h) 55 | 56 | def read_meta(self, split): 57 | self.rays = [] 58 | self.poses = [] 59 | 60 | if split == 'test_traj': # BlendedMVS and TanksAndTemple 61 | if 'Ignatius' in self.root_dir: 62 | poses_path = \ 63 | sorted(glob.glob(os.path.join(self.root_dir, 'test_pose/*.txt'))) 64 | poses = [np.loadtxt(p) for p in poses_path] 65 | else: 66 | poses = np.loadtxt(os.path.join(self.root_dir, 'test_traj.txt')) 67 | poses = poses.reshape(-1, 4, 4) 68 | for pose in poses: 69 | c2w = pose[:3] 70 | c2w[:, 0] *= -1 # [left down front] to [right down front] 71 | c2w[:, 3] -= self.shift 72 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5] 73 | self.poses += [c2w] 74 | else: 75 | if split == 'train': prefix = '0_' 76 | elif split == 'trainval': prefix = '[0-1]_' 77 | elif split == 'trainvaltest': prefix = '[0-2]_' 78 | elif split == 'val': prefix = '1_' 79 | elif 'Synthetic' in self.root_dir: prefix = '2_' # test set for synthetic scenes 80 | elif split == 'test': prefix = '1_' # test set for real scenes 81 | else: raise ValueError(f'{split} split not recognized!') 82 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'rgb', prefix+'*.png'))) 83 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'pose', prefix+'*.txt'))) 84 | 85 | print(f'Loading {len(img_paths)} {split} images ...') 86 | for img_path, pose in tqdm(zip(img_paths, poses)): 87 | c2w = np.loadtxt(pose)[:3] 88 | c2w[:, 3] -= self.shift 89 | c2w[:, 3] /= 2*self.scale # to bound the 
scene inside [-0.5, 0.5] 90 | self.poses += [c2w] 91 | 92 | img = read_image(img_path, self.img_wh) 93 | if 'Jade' in self.root_dir or 'Fountain' in self.root_dir: 94 | # these scenes have black background, changing to white 95 | img[torch.all(img<=0.1, dim=-1)] = 1.0 96 | 97 | self.rays += [img] 98 | 99 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 100 | 101 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 102 | -------------------------------------------------------------------------------- /datasets/nerf/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from kornia import create_meshgrid 4 | from einops import rearrange 5 | 6 | 7 | @torch.cuda.amp.autocast(dtype=torch.float32) 8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True): 9 | """ 10 | Get ray directions for all pixels in camera coordinate [right down front]. 11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 12 | ray-tracing-generating-camera-rays/standard-coordinate-systems 13 | 14 | Inputs: 15 | H, W: image height and width 16 | K: (3, 3) camera intrinsics 17 | random: whether the ray passes randomly inside the pixel 18 | return_uv: whether to return uv image coordinates 19 | 20 | Outputs: (shape depends on @flatten) 21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate 22 | uv: (H, W, 2) or (H*W, 2) image coordinates 23 | """ 24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2) 25 | u, v = grid.unbind(-1) 26 | 27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 28 | if random: 29 | directions = \ 30 | torch.stack([(u-cx+torch.rand_like(u))/fx, 31 | (v-cy+torch.rand_like(v))/fy, 32 | torch.ones_like(u)], -1) 33 | else: # pass by the center 34 | directions = \ 35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1) 36 | if flatten: 37 | directions = directions.reshape(-1, 3) 38 | grid = grid.reshape(-1, 2) 39 | 40 | if return_uv: 41 | return directions, grid 42 | return directions 43 | 44 | 45 | @torch.cuda.amp.autocast(dtype=torch.float32) 46 | def get_rays(directions, c2w): 47 | """ 48 | Get ray origin and directions in world coordinate for all pixels in one image. 
49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 50 | ray-tracing-generating-camera-rays/standard-coordinate-systems 51 | 52 | Inputs: 53 | directions: (N, 3) ray directions in camera coordinate 54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate 55 | 56 | Outputs: 57 | rays_o: (N, 3), the origin of the rays in world coordinate 58 | rays_d: (N, 3), the direction of the rays in world coordinate 59 | """ 60 | if c2w.ndim==2: 61 | # Rotate ray directions from camera coordinate to the world coordinate 62 | rays_d = directions @ c2w[:, :3].T 63 | else: 64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ \ 65 | rearrange(c2w[..., :3], 'n a b -> n b a') 66 | rays_d = rearrange(rays_d, 'n 1 c -> n c') 67 | # The origin of all rays is the camera origin in world coordinate 68 | rays_o = c2w[..., 3].expand_as(rays_d) 69 | 70 | return rays_o, rays_d 71 | 72 | 73 | @torch.cuda.amp.autocast(dtype=torch.float32) 74 | def axisangle_to_R(v): 75 | """ 76 | Convert an axis-angle vector to rotation matrix 77 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47 78 | 79 | Inputs: 80 | v: (3) or (B, 3) 81 | 82 | Outputs: 83 | R: (3, 3) or (B, 3, 3) 84 | """ 85 | v_ndim = v.ndim 86 | if v_ndim==1: 87 | v = rearrange(v, 'c -> 1 c') 88 | zero = torch.zeros_like(v[:, :1]) # (B, 1) 89 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3) 90 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1) 91 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1) 92 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3) 93 | 94 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1') 95 | eye = torch.eye(3, device=v.device) 96 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \ 97 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v) 98 | 99 | if v_ndim==1: 100 | R = rearrange(R, '1 c d -> c d') 101 | 102 | return R 103 | 104 | 105 | def normalize(v): 106 | """Normalize a vector.""" 107 | return v/np.linalg.norm(v) 108 | 109 | 110 | def average_poses(poses, pts3d=None): 111 | """ 112 | Calculate the average pose, which is then used to center all poses 113 | using @center_poses. Its computation is as follows: 114 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras). 115 | 2. Compute the z axis: the normalized average z axis. 116 | 3. Compute axis y': the average y axis. 117 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 118 | 5. Compute the y axis: z cross product x. 119 | 120 | Note that at step 3, we cannot directly use y' as y axis since it's 121 | not necessarily orthogonal to z axis. We need to pass from x to y. 122 | Inputs: 123 | poses: (N_images, 3, 4) 124 | pts3d: (N, 3) 125 | 126 | Outputs: 127 | pose_avg: (3, 4) the average pose 128 | """ 129 | # 1. Compute the center 130 | if pts3d is not None: 131 | center = pts3d.mean(0) 132 | else: 133 | center = poses[..., 3].mean(0) 134 | 135 | # 2. Compute the z axis 136 | z = normalize(poses[..., 2].mean(0)) # (3) 137 | 138 | # 3. Compute axis y' (no need to normalize as it's not the final output) 139 | y_ = poses[..., 1].mean(0) # (3) 140 | 141 | # 4. Compute the x axis 142 | x = normalize(np.cross(y_, z)) # (3) 143 | 144 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 145 | y = np.cross(z, x) # (3) 146 | 147 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 148 | 149 | return pose_avg 150 | 151 | 152 | def center_poses(poses, pts3d=None): 153 | """ 154 | See https://github.com/bmild/nerf/issues/34 155 | Inputs: 156 | poses: (N_images, 3, 4) 157 | pts3d: (N, 3) reconstructed point cloud 158 | 159 | Outputs: 160 | poses_centered: (N_images, 3, 4) the centered poses 161 | pts3d_centered: (N, 3) centered point cloud 162 | """ 163 | 164 | pose_avg = average_poses(poses, pts3d) # (3, 4) 165 | pose_avg_homo = np.eye(4) 166 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 167 | # by simply adding 0, 0, 0, 1 as the last row 168 | pose_avg_inv = np.linalg.inv(pose_avg_homo) 169 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 170 | poses_homo = \ 171 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 172 | 173 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4) 174 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 175 | 176 | if pts3d is not None: 177 | pts3d_centered = pts3d @ pose_avg_inv[:, :3].T + pose_avg_inv[:, 3:].T 178 | return poses_centered, pts3d_centered 179 | 180 | return poses_centered 181 | 182 | 183 | def create_spheric_poses(radius, mean_h, n_poses=120): 184 | """ 185 | Create circular poses around z axis. 186 | Inputs: 187 | radius: the (negative) height and the radius of the circle. 188 | mean_h: mean camera height 189 | Outputs: 190 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 191 | """ 192 | def spheric_pose(theta, phi, radius): 193 | trans_t = lambda t : np.array([ 194 | [1,0,0,0], 195 | [0,1,0,2*mean_h], 196 | [0,0,1,-t] 197 | ]) 198 | 199 | rot_phi = lambda phi : np.array([ 200 | [1,0,0], 201 | [0,np.cos(phi),-np.sin(phi)], 202 | [0,np.sin(phi), np.cos(phi)] 203 | ]) 204 | 205 | rot_theta = lambda th : np.array([ 206 | [np.cos(th),0,-np.sin(th)], 207 | [0,1,0], 208 | [np.sin(th),0, np.cos(th)] 209 | ]) 210 | 211 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 212 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w 213 | return c2w 214 | 215 | spheric_poses = [] 216 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: 217 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)] 218 | return np.stack(spheric_poses, 0) -------------------------------------------------------------------------------- /datasets/nerf/rtmv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import json 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | 8 | from .ray_utils import get_ray_directions 9 | from .color_utils import read_image 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class RTMVDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | with open(os.path.join(self.root_dir, '00000.json'), 'r') as f: 25 | meta = json.load(f)['camera_data'] 26 | 27 | self.shift = np.array(meta['scene_center_3d_box']) 28 | self.scale = (np.array(meta['scene_max_3d_box'])- 29 | np.array(meta['scene_min_3d_box'])).max()/2 * 1.05 # enlarge a little 30 | 31 | fx = meta['intrinsics']['fx'] * self.downsample 32 | fy = 
meta['intrinsics']['fy'] * self.downsample 33 | cx = meta['intrinsics']['cx'] * self.downsample 34 | cy = meta['intrinsics']['cy'] * self.downsample 35 | w = int(meta['width']*self.downsample) 36 | h = int(meta['height']*self.downsample) 37 | K = np.float32([[fx, 0, cx], 38 | [0, fy, cy], 39 | [0, 0, 1]]) 40 | self.K = torch.FloatTensor(K) 41 | self.directions = get_ray_directions(h, w, self.K) 42 | self.img_wh = (w, h) 43 | 44 | def read_meta(self, split): 45 | self.rays = [] 46 | self.poses = [] 47 | 48 | if split == 'train': start_idx, end_idx = 0, 100 49 | elif split == 'trainval': start_idx, end_idx = 0, 105 50 | elif split == 'test': start_idx, end_idx = 105, 150 51 | else: start_idx, end_idx = 0, 150 52 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*')))[start_idx:end_idx] 53 | poses = sorted(glob.glob(os.path.join(self.root_dir, '*.json')))[start_idx:end_idx] 54 | 55 | print(f'Loading {len(img_paths)} {split} images ...') 56 | for img_path, pose in tqdm(zip(img_paths, poses)): 57 | with open(pose, 'r') as f: 58 | p = json.load(f)['camera_data'] 59 | c2w = np.array(p['cam2world']).T[:3] 60 | c2w[:, 1:3] *= -1 61 | if 'bricks' in self.root_dir: 62 | c2w[:, 3] -= self.shift 63 | c2w[:, 3] /= 2*self.scale # bound in [-0.5, 0.5] 64 | self.poses += [c2w] 65 | 66 | img = read_image(img_path, self.img_wh) 67 | self.rays += [img] 68 | 69 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 70 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 71 | -------------------------------------------------------------------------------- /datasets/sdf/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | These codes are adapted from torch-ngp (https://github.com/ashawkey/torch-ngp/tree/main) 3 | """ 4 | 5 | from torch.utils.data import Dataset 6 | 7 | import numpy as np 8 | import trimesh 9 | import pysdf 10 | 11 | 12 | class SDFDataset(Dataset): 13 | def __init__(self, path, size=100, num_samples=2**18, clip_sdf=None): 14 | super().__init__() 15 | self.path = path 16 | 17 | # load obj 18 | self.mesh = trimesh.load(path, force='mesh') 19 | 20 | # normalize to [-1, 1] (different from instant-sdf where is [0, 1]) 21 | vs = self.mesh.vertices 22 | vmin = vs.min(0) 23 | vmax = vs.max(0) 24 | v_center = (vmin + vmax) / 2 25 | v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95 26 | vs = (vs - v_center[None, :]) * v_scale 27 | self.mesh.vertices = vs 28 | 29 | print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}") 30 | 31 | if not self.mesh.is_watertight: 32 | print(f"[WARN] mesh is not watertight! SDF maybe incorrect.") 33 | 34 | self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces) 35 | 36 | self.num_samples = num_samples 37 | assert self.num_samples % 8 == 0, "num_samples must be divisible by 8." 
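        # Sampling scheme used in __getitem__ below: two thirds of the points lie on the mesh
        # surface (the latter half of those is perturbed with small Gaussian noise) and the
        # remaining third is drawn uniformly in [-1, 1]^3; pysdf is queried (and negated) only
        # for the perturbed and uniform points, while exact surface samples keep their
        # initialized SDF value of 0.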
38 | self.clip_sdf = clip_sdf 39 | 40 | self.size = size 41 | 42 | def __len__(self): 43 | return self.size 44 | 45 | def __getitem__(self, _): 46 | # online sampling 47 | sdfs = np.zeros((self.num_samples, 1)) 48 | # surface 49 | points_surface = self.mesh.sample(self.num_samples * 2 // 3) 50 | # perturb surface 51 | points_surface[self.num_samples // 3:] += 0.01 * np.random.randn(self.num_samples // 3, 3) 52 | # random 53 | points_uniform = np.random.rand(self.num_samples // 3, 3) * 2 - 1 54 | points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32) 55 | 56 | sdfs[self.num_samples // 3:] = -self.sdf_fn(points[self.num_samples // 3:])[:,None].astype(np.float32) 57 | 58 | # clip sdf 59 | if self.clip_sdf is not None: 60 | sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf) 61 | 62 | results = { 63 | 'sdfs': sdfs, 64 | 'points': points, 65 | } 66 | 67 | return results -------------------------------------------------------------------------------- /docs/figures/2d_fitting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/2d_fitting.png -------------------------------------------------------------------------------- /docs/figures/3d_fitting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/3d_fitting.png -------------------------------------------------------------------------------- /docs/figures/nvs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/nvs.png -------------------------------------------------------------------------------- /docs/figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/docs/figures/teaser.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/__init__.py -------------------------------------------------------------------------------- /models/csrc/binding.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | 4 | std::vector ray_aabb_intersect( 5 | const torch::Tensor rays_o, 6 | const torch::Tensor rays_d, 7 | const torch::Tensor centers, 8 | const torch::Tensor half_sizes, 9 | const int max_hits 10 | ){ 11 | CHECK_INPUT(rays_o); 12 | CHECK_INPUT(rays_d); 13 | CHECK_INPUT(centers); 14 | CHECK_INPUT(half_sizes); 15 | return ray_aabb_intersect_cu(rays_o, rays_d, centers, half_sizes, max_hits); 16 | } 17 | 18 | 19 | std::vector ray_sphere_intersect( 20 | const torch::Tensor rays_o, 21 | const torch::Tensor rays_d, 22 | const torch::Tensor centers, 23 | const torch::Tensor radii, 24 | const int max_hits 25 | ){ 26 | CHECK_INPUT(rays_o); 27 | CHECK_INPUT(rays_d); 28 | CHECK_INPUT(centers); 29 | CHECK_INPUT(radii); 30 | return ray_sphere_intersect_cu(rays_o, rays_d, centers, radii, max_hits); 31 | } 32 | 33 | 34 | void packbits( 35 | torch::Tensor density_grid, 36 | const float density_threshold, 37 | 
torch::Tensor density_bitfield 38 | ){ 39 | CHECK_INPUT(density_grid); 40 | CHECK_INPUT(density_bitfield); 41 | 42 | return packbits_cu(density_grid, density_threshold, density_bitfield); 43 | } 44 | 45 | 46 | torch::Tensor morton3D(const torch::Tensor coords){ 47 | CHECK_INPUT(coords); 48 | 49 | return morton3D_cu(coords); 50 | } 51 | 52 | 53 | torch::Tensor morton3D_invert(const torch::Tensor indices){ 54 | CHECK_INPUT(indices); 55 | 56 | return morton3D_invert_cu(indices); 57 | } 58 | 59 | 60 | std::vector raymarching_train( 61 | const torch::Tensor rays_o, 62 | const torch::Tensor rays_d, 63 | const torch::Tensor hits_t, 64 | const torch::Tensor density_bitfield, 65 | const int cascades, 66 | const float scale, 67 | const float exp_step_factor, 68 | const torch::Tensor noise, 69 | const int grid_size, 70 | const int max_samples 71 | ){ 72 | CHECK_INPUT(rays_o); 73 | CHECK_INPUT(rays_d); 74 | CHECK_INPUT(hits_t); 75 | CHECK_INPUT(density_bitfield); 76 | CHECK_INPUT(noise); 77 | 78 | return raymarching_train_cu( 79 | rays_o, rays_d, hits_t, density_bitfield, cascades, 80 | scale, exp_step_factor, noise, grid_size, max_samples); 81 | } 82 | 83 | 84 | std::vector raymarching_test( 85 | const torch::Tensor rays_o, 86 | const torch::Tensor rays_d, 87 | torch::Tensor hits_t, 88 | const torch::Tensor alive_indices, 89 | const torch::Tensor density_bitfield, 90 | const int cascades, 91 | const float scale, 92 | const float exp_step_factor, 93 | const int grid_size, 94 | const int max_samples, 95 | const int N_samples 96 | ){ 97 | CHECK_INPUT(rays_o); 98 | CHECK_INPUT(rays_d); 99 | CHECK_INPUT(hits_t); 100 | CHECK_INPUT(alive_indices); 101 | CHECK_INPUT(density_bitfield); 102 | 103 | return raymarching_test_cu( 104 | rays_o, rays_d, hits_t, alive_indices, density_bitfield, cascades, 105 | scale, exp_step_factor, grid_size, max_samples, N_samples); 106 | } 107 | 108 | 109 | std::vector composite_train_fw( 110 | const torch::Tensor sigmas, 111 | const torch::Tensor rgbs, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a, 115 | const float opacity_threshold 116 | ){ 117 | CHECK_INPUT(sigmas); 118 | CHECK_INPUT(rgbs); 119 | CHECK_INPUT(deltas); 120 | CHECK_INPUT(ts); 121 | CHECK_INPUT(rays_a); 122 | 123 | return composite_train_fw_cu( 124 | sigmas, rgbs, deltas, ts, 125 | rays_a, opacity_threshold); 126 | } 127 | 128 | 129 | std::vector composite_train_bw( 130 | const torch::Tensor dL_dopacity, 131 | const torch::Tensor dL_ddepth, 132 | const torch::Tensor dL_drgb, 133 | const torch::Tensor dL_dws, 134 | const torch::Tensor sigmas, 135 | const torch::Tensor rgbs, 136 | const torch::Tensor ws, 137 | const torch::Tensor deltas, 138 | const torch::Tensor ts, 139 | const torch::Tensor rays_a, 140 | const torch::Tensor opacity, 141 | const torch::Tensor depth, 142 | const torch::Tensor rgb, 143 | const float opacity_threshold 144 | ){ 145 | CHECK_INPUT(dL_dopacity); 146 | CHECK_INPUT(dL_ddepth); 147 | CHECK_INPUT(dL_drgb); 148 | CHECK_INPUT(dL_dws); 149 | CHECK_INPUT(sigmas); 150 | CHECK_INPUT(rgbs); 151 | CHECK_INPUT(ws); 152 | CHECK_INPUT(deltas); 153 | CHECK_INPUT(ts); 154 | CHECK_INPUT(rays_a); 155 | CHECK_INPUT(opacity); 156 | CHECK_INPUT(depth); 157 | CHECK_INPUT(rgb); 158 | 159 | return composite_train_bw_cu( 160 | dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 161 | sigmas, rgbs, ws, deltas, ts, rays_a, 162 | opacity, depth, rgb, opacity_threshold); 163 | } 164 | 165 | 166 | void composite_test_fw( 167 | const torch::Tensor sigmas, 168 | const 
torch::Tensor rgbs, 169 | const torch::Tensor deltas, 170 | const torch::Tensor ts, 171 | const torch::Tensor hits_t, 172 | const torch::Tensor alive_indices, 173 | const float T_threshold, 174 | const torch::Tensor N_eff_samples, 175 | torch::Tensor opacity, 176 | torch::Tensor depth, 177 | torch::Tensor rgb 178 | ){ 179 | CHECK_INPUT(sigmas); 180 | CHECK_INPUT(rgbs); 181 | CHECK_INPUT(deltas); 182 | CHECK_INPUT(ts); 183 | CHECK_INPUT(hits_t); 184 | CHECK_INPUT(alive_indices); 185 | CHECK_INPUT(N_eff_samples); 186 | CHECK_INPUT(opacity); 187 | CHECK_INPUT(depth); 188 | CHECK_INPUT(rgb); 189 | 190 | composite_test_fw_cu( 191 | sigmas, rgbs, deltas, ts, hits_t, alive_indices, 192 | T_threshold, N_eff_samples, 193 | opacity, depth, rgb); 194 | } 195 | 196 | 197 | std::vector distortion_loss_fw( 198 | const torch::Tensor ws, 199 | const torch::Tensor deltas, 200 | const torch::Tensor ts, 201 | const torch::Tensor rays_a 202 | ){ 203 | CHECK_INPUT(ws); 204 | CHECK_INPUT(deltas); 205 | CHECK_INPUT(ts); 206 | CHECK_INPUT(rays_a); 207 | 208 | return distortion_loss_fw_cu(ws, deltas, ts, rays_a); 209 | } 210 | 211 | 212 | torch::Tensor distortion_loss_bw( 213 | const torch::Tensor dL_dloss, 214 | const torch::Tensor ws_inclusive_scan, 215 | const torch::Tensor wts_inclusive_scan, 216 | const torch::Tensor ws, 217 | const torch::Tensor deltas, 218 | const torch::Tensor ts, 219 | const torch::Tensor rays_a 220 | ){ 221 | CHECK_INPUT(dL_dloss); 222 | CHECK_INPUT(ws_inclusive_scan); 223 | CHECK_INPUT(wts_inclusive_scan); 224 | CHECK_INPUT(ws); 225 | CHECK_INPUT(deltas); 226 | CHECK_INPUT(ts); 227 | CHECK_INPUT(rays_a); 228 | 229 | return distortion_loss_bw_cu(dL_dloss, ws_inclusive_scan, wts_inclusive_scan, 230 | ws, deltas, ts, rays_a); 231 | } 232 | 233 | 234 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 235 | m.def("ray_aabb_intersect", &ray_aabb_intersect); 236 | m.def("ray_sphere_intersect", &ray_sphere_intersect); 237 | 238 | m.def("morton3D", &morton3D); 239 | m.def("morton3D_invert", &morton3D_invert); 240 | m.def("packbits", &packbits); 241 | 242 | m.def("raymarching_train", &raymarching_train); 243 | m.def("raymarching_test", &raymarching_test); 244 | m.def("composite_train_fw", &composite_train_fw); 245 | m.def("composite_train_bw", &composite_train_bw); 246 | m.def("composite_test_fw", &composite_test_fw); 247 | 248 | m.def("distortion_loss_fw", &distortion_loss_fw); 249 | m.def("distortion_loss_bw", &distortion_loss_bw); 250 | 251 | } -------------------------------------------------------------------------------- /models/csrc/include/helper_math.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | * 3 | * Redistribution and use in source and binary forms, with or without 4 | * modification, are permitted provided that the following conditions 5 | * are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of NVIDIA CORPORATION nor the names of its 12 | * contributors may be used to endorse or promote products derived 13 | * from this software without specific prior written permission. 
14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | /* 29 | * This file implements common mathematical operations on vector types 30 | * (float3, float4 etc.) since these are not provided as standard by CUDA. 31 | * 32 | * The syntax is modeled on the Cg standard library. 33 | * 34 | * This is part of the Helper library includes 35 | * 36 | * Thanks to Linh Hah for additions and fixes. 37 | */ 38 | 39 | #ifndef HELPER_MATH_H 40 | #define HELPER_MATH_H 41 | 42 | #include "cuda_runtime.h" 43 | 44 | typedef unsigned int uint; 45 | typedef unsigned short ushort; 46 | 47 | #ifndef EXIT_WAIVED 48 | #define EXIT_WAIVED 2 49 | #endif 50 | 51 | #ifndef __CUDACC__ 52 | #include 53 | 54 | //////////////////////////////////////////////////////////////////////////////// 55 | // host implementations of CUDA functions 56 | //////////////////////////////////////////////////////////////////////////////// 57 | 58 | inline float fminf(float a, float b) 59 | { 60 | return a < b ? a : b; 61 | } 62 | 63 | inline float fmaxf(float a, float b) 64 | { 65 | return a > b ? a : b; 66 | } 67 | 68 | inline int max(int a, int b) 69 | { 70 | return a > b ? a : b; 71 | } 72 | 73 | inline int min(int a, int b) 74 | { 75 | return a < b ? 
a : b; 76 | } 77 | 78 | inline float rsqrtf(float x) 79 | { 80 | return 1.0f / sqrtf(x); 81 | } 82 | #endif 83 | 84 | //////////////////////////////////////////////////////////////////////////////// 85 | // constructors 86 | //////////////////////////////////////////////////////////////////////////////// 87 | 88 | inline __host__ __device__ float2 make_float2(float s) 89 | { 90 | return make_float2(s, s); 91 | } 92 | inline __host__ __device__ float2 make_float2(float3 a) 93 | { 94 | return make_float2(a.x, a.y); 95 | } 96 | inline __host__ __device__ float3 make_float3(float s) 97 | { 98 | return make_float3(s, s, s); 99 | } 100 | inline __host__ __device__ float3 make_float3(float2 a) 101 | { 102 | return make_float3(a.x, a.y, 0.0f); 103 | } 104 | inline __host__ __device__ float3 make_float3(float2 a, float s) 105 | { 106 | return make_float3(a.x, a.y, s); 107 | } 108 | 109 | //////////////////////////////////////////////////////////////////////////////// 110 | // negate 111 | //////////////////////////////////////////////////////////////////////////////// 112 | 113 | inline __host__ __device__ float3 operator-(float3 &a) 114 | { 115 | return make_float3(-a.x, -a.y, -a.z); 116 | } 117 | 118 | //////////////////////////////////////////////////////////////////////////////// 119 | // addition 120 | //////////////////////////////////////////////////////////////////////////////// 121 | 122 | inline __host__ __device__ float3 operator+(float3 a, float3 b) 123 | { 124 | return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); 125 | } 126 | inline __host__ __device__ void operator+=(float3 &a, float3 b) 127 | { 128 | a.x += b.x; 129 | a.y += b.y; 130 | a.z += b.z; 131 | } 132 | inline __host__ __device__ float3 operator+(float3 a, float b) 133 | { 134 | return make_float3(a.x + b, a.y + b, a.z + b); 135 | } 136 | inline __host__ __device__ void operator+=(float3 &a, float b) 137 | { 138 | a.x += b; 139 | a.y += b; 140 | a.z += b; 141 | } 142 | inline __host__ __device__ float3 operator+(float b, float3 a) 143 | { 144 | return make_float3(a.x + b, a.y + b, a.z + b); 145 | } 146 | 147 | //////////////////////////////////////////////////////////////////////////////// 148 | // subtract 149 | //////////////////////////////////////////////////////////////////////////////// 150 | 151 | inline __host__ __device__ float3 operator-(float3 a, float3 b) 152 | { 153 | return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); 154 | } 155 | inline __host__ __device__ void operator-=(float3 &a, float3 b) 156 | { 157 | a.x -= b.x; 158 | a.y -= b.y; 159 | a.z -= b.z; 160 | } 161 | inline __host__ __device__ float3 operator-(float3 a, float b) 162 | { 163 | return make_float3(a.x - b, a.y - b, a.z - b); 164 | } 165 | inline __host__ __device__ float3 operator-(float b, float3 a) 166 | { 167 | return make_float3(b - a.x, b - a.y, b - a.z); 168 | } 169 | inline __host__ __device__ void operator-=(float3 &a, float b) 170 | { 171 | a.x -= b; 172 | a.y -= b; 173 | a.z -= b; 174 | } 175 | 176 | //////////////////////////////////////////////////////////////////////////////// 177 | // multiply 178 | //////////////////////////////////////////////////////////////////////////////// 179 | 180 | inline __host__ __device__ float3 operator*(float3 a, float3 b) 181 | { 182 | return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); 183 | } 184 | inline __host__ __device__ void operator*=(float3 &a, float3 b) 185 | { 186 | a.x *= b.x; 187 | a.y *= b.y; 188 | a.z *= b.z; 189 | } 190 | inline __host__ __device__ float3 operator*(float3 a, float 
b) 191 | { 192 | return make_float3(a.x * b, a.y * b, a.z * b); 193 | } 194 | inline __host__ __device__ float3 operator*(float b, float3 a) 195 | { 196 | return make_float3(b * a.x, b * a.y, b * a.z); 197 | } 198 | inline __host__ __device__ void operator*=(float3 &a, float b) 199 | { 200 | a.x *= b; 201 | a.y *= b; 202 | a.z *= b; 203 | } 204 | 205 | //////////////////////////////////////////////////////////////////////////////// 206 | // divide 207 | //////////////////////////////////////////////////////////////////////////////// 208 | 209 | inline __host__ __device__ float2 operator/(float2 a, float2 b) 210 | { 211 | return make_float2(a.x / b.x, a.y / b.y); 212 | } 213 | inline __host__ __device__ void operator/=(float2 &a, float2 b) 214 | { 215 | a.x /= b.x; 216 | a.y /= b.y; 217 | } 218 | inline __host__ __device__ float2 operator/(float2 a, float b) 219 | { 220 | return make_float2(a.x / b, a.y / b); 221 | } 222 | inline __host__ __device__ void operator/=(float2 &a, float b) 223 | { 224 | a.x /= b; 225 | a.y /= b; 226 | } 227 | inline __host__ __device__ float2 operator/(float b, float2 a) 228 | { 229 | return make_float2(b / a.x, b / a.y); 230 | } 231 | 232 | inline __host__ __device__ float3 operator/(float3 a, float3 b) 233 | { 234 | return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); 235 | } 236 | inline __host__ __device__ void operator/=(float3 &a, float3 b) 237 | { 238 | a.x /= b.x; 239 | a.y /= b.y; 240 | a.z /= b.z; 241 | } 242 | inline __host__ __device__ float3 operator/(float3 a, float b) 243 | { 244 | return make_float3(a.x / b, a.y / b, a.z / b); 245 | } 246 | inline __host__ __device__ void operator/=(float3 &a, float b) 247 | { 248 | a.x /= b; 249 | a.y /= b; 250 | a.z /= b; 251 | } 252 | inline __host__ __device__ float3 operator/(float b, float3 a) 253 | { 254 | return make_float3(b / a.x, b / a.y, b / a.z); 255 | } 256 | 257 | //////////////////////////////////////////////////////////////////////////////// 258 | // min 259 | //////////////////////////////////////////////////////////////////////////////// 260 | 261 | inline __host__ __device__ float3 fminf(float3 a, float3 b) 262 | { 263 | return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); 264 | } 265 | 266 | //////////////////////////////////////////////////////////////////////////////// 267 | // max 268 | //////////////////////////////////////////////////////////////////////////////// 269 | 270 | inline __host__ __device__ float3 fmaxf(float3 a, float3 b) 271 | { 272 | return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); 273 | } 274 | 275 | //////////////////////////////////////////////////////////////////////////////// 276 | // clamp 277 | // - clamp the value v to be in the range [a, b] 278 | //////////////////////////////////////////////////////////////////////////////// 279 | 280 | inline __device__ __host__ float clamp(float f, float a, float b) 281 | { 282 | return fmaxf(a, fminf(f, b)); 283 | } 284 | inline __device__ __host__ int clamp(int f, int a, int b) 285 | { 286 | return max(a, min(f, b)); 287 | } 288 | inline __device__ __host__ uint clamp(uint f, uint a, uint b) 289 | { 290 | return max(a, min(f, b)); 291 | } 292 | 293 | inline __device__ __host__ float3 clamp(float3 v, float a, float b) 294 | { 295 | return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); 296 | } 297 | inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) 298 | { 299 | return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); 300 | } 
}
301 | 302 | //////////////////////////////////////////////////////////////////////////////// 303 | // dot product 304 | //////////////////////////////////////////////////////////////////////////////// 305 | 306 | inline __host__ __device__ float dot(float3 a, float3 b) 307 | { 308 | return a.x * b.x + a.y * b.y + a.z * b.z; 309 | } 310 | 311 | //////////////////////////////////////////////////////////////////////////////// 312 | // length 313 | //////////////////////////////////////////////////////////////////////////////// 314 | 315 | inline __host__ __device__ float length(float3 v) 316 | { 317 | return sqrtf(dot(v, v)); 318 | } 319 | 320 | //////////////////////////////////////////////////////////////////////////////// 321 | // normalize 322 | //////////////////////////////////////////////////////////////////////////////// 323 | 324 | inline __host__ __device__ float3 normalize(float3 v) 325 | { 326 | float invLen = rsqrtf(dot(v, v)); 327 | return v * invLen; 328 | } 329 | 330 | //////////////////////////////////////////////////////////////////////////////// 331 | // reflect 332 | // - returns reflection of incident ray I around surface normal N 333 | // - N should be normalized, reflected vector's length is equal to length of I 334 | //////////////////////////////////////////////////////////////////////////////// 335 | 336 | inline __host__ __device__ float3 reflect(float3 i, float3 n) 337 | { 338 | return i - 2.0f * n * dot(n,i); 339 | } 340 | 341 | //////////////////////////////////////////////////////////////////////////////// 342 | // cross product 343 | //////////////////////////////////////////////////////////////////////////////// 344 | 345 | inline __host__ __device__ float3 cross(float3 a, float3 b) 346 | { 347 | return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); 348 | } 349 | 350 | //////////////////////////////////////////////////////////////////////////////// 351 | // smoothstep 352 | // - returns 0 if x < a 353 | // - returns 1 if x > b 354 | // - otherwise returns smooth interpolation between 0 and 1 based on x 355 | //////////////////////////////////////////////////////////////////////////////// 356 | 357 | inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) 358 | { 359 | float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); 360 | return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); 361 | } 362 | 363 | #endif 364 | -------------------------------------------------------------------------------- /models/csrc/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | 9 | std::vector ray_aabb_intersect_cu( 10 | const torch::Tensor rays_o, 11 | const torch::Tensor rays_d, 12 | const torch::Tensor centers, 13 | const torch::Tensor half_sizes, 14 | const int max_hits 15 | ); 16 | 17 | 18 | std::vector ray_sphere_intersect_cu( 19 | const torch::Tensor rays_o, 20 | const torch::Tensor rays_d, 21 | const torch::Tensor centers, 22 | const torch::Tensor radii, 23 | const int max_hits 24 | ); 25 | 26 | 27 | void packbits_cu( 28 | torch::Tensor density_grid, 29 | const float density_threshold, 30 | torch::Tensor density_bitfield 31 | ); 32 | 33 | 34 | torch::Tensor morton3D_cu(const torch::Tensor coords); 35 | torch::Tensor 
morton3D_invert_cu(const torch::Tensor indices); 36 | 37 | 38 | std::vector raymarching_train_cu( 39 | const torch::Tensor rays_o, 40 | const torch::Tensor rays_d, 41 | const torch::Tensor hits_t, 42 | const torch::Tensor density_bitfield, 43 | const int cascades, 44 | const float scale, 45 | const float exp_step_factor, 46 | const torch::Tensor noise, 47 | const int grid_size, 48 | const int max_samples 49 | ); 50 | 51 | 52 | std::vector raymarching_test_cu( 53 | const torch::Tensor rays_o, 54 | const torch::Tensor rays_d, 55 | torch::Tensor hits_t, 56 | const torch::Tensor alive_indices, 57 | const torch::Tensor density_bitfield, 58 | const int cascades, 59 | const float scale, 60 | const float exp_step_factor, 61 | const int grid_size, 62 | const int max_samples, 63 | const int N_samples 64 | ); 65 | 66 | 67 | std::vector composite_train_fw_cu( 68 | const torch::Tensor sigmas, 69 | const torch::Tensor rgbs, 70 | const torch::Tensor deltas, 71 | const torch::Tensor ts, 72 | const torch::Tensor rays_a, 73 | const float T_threshold 74 | ); 75 | 76 | 77 | std::vector composite_train_bw_cu( 78 | const torch::Tensor dL_dopacity, 79 | const torch::Tensor dL_ddepth, 80 | const torch::Tensor dL_drgb, 81 | const torch::Tensor dL_dws, 82 | const torch::Tensor sigmas, 83 | const torch::Tensor rgbs, 84 | const torch::Tensor ws, 85 | const torch::Tensor deltas, 86 | const torch::Tensor ts, 87 | const torch::Tensor rays_a, 88 | const torch::Tensor opacity, 89 | const torch::Tensor depth, 90 | const torch::Tensor rgb, 91 | const float T_threshold 92 | ); 93 | 94 | 95 | void composite_test_fw_cu( 96 | const torch::Tensor sigmas, 97 | const torch::Tensor rgbs, 98 | const torch::Tensor deltas, 99 | const torch::Tensor ts, 100 | const torch::Tensor hits_t, 101 | const torch::Tensor alive_indices, 102 | const float T_threshold, 103 | const torch::Tensor N_eff_samples, 104 | torch::Tensor opacity, 105 | torch::Tensor depth, 106 | torch::Tensor rgb 107 | ); 108 | 109 | 110 | std::vector distortion_loss_fw_cu( 111 | const torch::Tensor ws, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a 115 | ); 116 | 117 | 118 | torch::Tensor distortion_loss_bw_cu( 119 | const torch::Tensor dL_dloss, 120 | const torch::Tensor ws_inclusive_scan, 121 | const torch::Tensor wts_inclusive_scan, 122 | const torch::Tensor ws, 123 | const torch::Tensor deltas, 124 | const torch::Tensor ts, 125 | const torch::Tensor rays_a 126 | ); -------------------------------------------------------------------------------- /models/csrc/intersection.cu: -------------------------------------------------------------------------------- 1 | #include "helper_math.h" 2 | #include "utils.h" 3 | 4 | 5 | __device__ __forceinline__ float2 _ray_aabb_intersect( 6 | const float3 ray_o, 7 | const float3 inv_d, 8 | const float3 center, 9 | const float3 half_size 10 | ){ 11 | 12 | const float3 t_min = (center-half_size-ray_o)*inv_d; 13 | const float3 t_max = (center+half_size-ray_o)*inv_d; 14 | 15 | const float3 _t1 = fminf(t_min, t_max); 16 | const float3 _t2 = fmaxf(t_min, t_max); 17 | const float t1 = fmaxf(fmaxf(_t1.x, _t1.y), _t1.z); 18 | const float t2 = fminf(fminf(_t2.x, _t2.y), _t2.z); 19 | 20 | if (t1 > t2) return make_float2(-1.0f); // no intersection 21 | return make_float2(t1, t2); 22 | } 23 | 24 | 25 | __global__ void ray_aabb_intersect_kernel( 26 | const torch::PackedTensorAccessor32 rays_o, 27 | const torch::PackedTensorAccessor32 rays_d, 28 | const torch::PackedTensorAccessor32 centers, 29 | const 
torch::PackedTensorAccessor32 half_sizes, 30 | const int max_hits, 31 | int* __restrict__ hit_cnt, 32 | torch::PackedTensorAccessor32 hits_t, 33 | torch::PackedTensorAccessor64 hits_voxel_idx 34 | ){ 35 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 36 | const int v = blockIdx.y * blockDim.y + threadIdx.y; 37 | 38 | if (v>=centers.size(0) || r>=rays_o.size(0)) return; 39 | 40 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 41 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 42 | const float3 inv_d = 1.0f/ray_d; 43 | 44 | const float3 center = make_float3(centers[v][0], centers[v][1], centers[v][2]); 45 | const float3 half_size = make_float3(half_sizes[v][0], half_sizes[v][1], half_sizes[v][2]); 46 | const float2 t1t2 = _ray_aabb_intersect(ray_o, inv_d, center, half_size); 47 | 48 | if (t1t2.y > 0){ // if ray hits the voxel 49 | const int cnt = atomicAdd(&hit_cnt[r], 1); 50 | if (cnt < max_hits){ 51 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 52 | hits_t[r][cnt][1] = t1t2.y; 53 | hits_voxel_idx[r][cnt] = v; 54 | } 55 | } 56 | } 57 | 58 | 59 | std::vector ray_aabb_intersect_cu( 60 | const torch::Tensor rays_o, 61 | const torch::Tensor rays_d, 62 | const torch::Tensor centers, 63 | const torch::Tensor half_sizes, 64 | const int max_hits 65 | ){ 66 | 67 | const int N_rays = rays_o.size(0), N_voxels = centers.size(0); 68 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 69 | auto hits_voxel_idx = 70 | torch::zeros({N_rays, max_hits}, 71 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 72 | auto hit_cnt = 73 | torch::zeros({N_rays}, 74 | torch::dtype(torch::kInt32).device(rays_o.device())); 75 | 76 | const dim3 threads(256, 1); 77 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 78 | (N_voxels+threads.y-1)/threads.y); 79 | 80 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_aabb_intersect_cu", 81 | ([&] { 82 | ray_aabb_intersect_kernel<<>>( 83 | rays_o.packed_accessor32(), 84 | rays_d.packed_accessor32(), 85 | centers.packed_accessor32(), 86 | half_sizes.packed_accessor32(), 87 | max_hits, 88 | hit_cnt.data_ptr(), 89 | hits_t.packed_accessor32(), 90 | hits_voxel_idx.packed_accessor64() 91 | ); 92 | })); 93 | 94 | // sort intersections from near to far based on t1 95 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 96 | hits_voxel_idx = torch::gather(hits_voxel_idx, 1, hits_order); 97 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 98 | 99 | return {hit_cnt, hits_t, hits_voxel_idx}; 100 | } 101 | 102 | 103 | __device__ __forceinline__ float2 _ray_sphere_intersect( 104 | const float3 ray_o, 105 | const float3 ray_d, 106 | const float3 center, 107 | const float radius 108 | ){ 109 | const float3 co = ray_o-center; 110 | 111 | const float a = dot(ray_d, ray_d); 112 | const float half_b = dot(ray_d, co); 113 | const float c = dot(co, co)-radius*radius; 114 | 115 | const float discriminant = half_b*half_b-a*c; 116 | 117 | if (discriminant < 0) return make_float2(-1.0f); // no intersection 118 | 119 | const float disc_sqrt = sqrtf(discriminant); 120 | return make_float2(-half_b-disc_sqrt, -half_b+disc_sqrt)/a; 121 | } 122 | 123 | 124 | __global__ void ray_sphere_intersect_kernel( 125 | const torch::PackedTensorAccessor32 rays_o, 126 | const torch::PackedTensorAccessor32 rays_d, 127 | const torch::PackedTensorAccessor32 centers, 128 | const torch::PackedTensorAccessor32 radii, 129 | const int max_hits, 130 | int* __restrict__ hit_cnt, 131 | 
torch::PackedTensorAccessor32 hits_t, 132 | torch::PackedTensorAccessor64 hits_sphere_idx 133 | ){ 134 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 135 | const int s = blockIdx.y * blockDim.y + threadIdx.y; 136 | 137 | if (s>=centers.size(0) || r>=rays_o.size(0)) return; 138 | 139 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 140 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 141 | const float3 center = make_float3(centers[s][0], centers[s][1], centers[s][2]); 142 | 143 | const float2 t1t2 = _ray_sphere_intersect(ray_o, ray_d, center, radii[s]); 144 | 145 | if (t1t2.y > 0){ // if ray hits the sphere 146 | const int cnt = atomicAdd(&hit_cnt[r], 1); 147 | if (cnt < max_hits){ 148 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 149 | hits_t[r][cnt][1] = t1t2.y; 150 | hits_sphere_idx[r][cnt] = s; 151 | } 152 | } 153 | } 154 | 155 | 156 | std::vector ray_sphere_intersect_cu( 157 | const torch::Tensor rays_o, 158 | const torch::Tensor rays_d, 159 | const torch::Tensor centers, 160 | const torch::Tensor radii, 161 | const int max_hits 162 | ){ 163 | 164 | const int N_rays = rays_o.size(0), N_spheres = centers.size(0); 165 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 166 | auto hits_sphere_idx = 167 | torch::zeros({N_rays, max_hits}, 168 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 169 | auto hit_cnt = 170 | torch::zeros({N_rays}, 171 | torch::dtype(torch::kInt32).device(rays_o.device())); 172 | 173 | const dim3 threads(256, 1); 174 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 175 | (N_spheres+threads.y-1)/threads.y); 176 | 177 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_sphere_intersect_cu", 178 | ([&] { 179 | ray_sphere_intersect_kernel<<>>( 180 | rays_o.packed_accessor32(), 181 | rays_d.packed_accessor32(), 182 | centers.packed_accessor32(), 183 | radii.packed_accessor32(), 184 | max_hits, 185 | hit_cnt.data_ptr(), 186 | hits_t.packed_accessor32(), 187 | hits_sphere_idx.packed_accessor64() 188 | ); 189 | })); 190 | 191 | // sort intersections from near to far based on t1 192 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 193 | hits_sphere_idx = torch::gather(hits_sphere_idx, 1, hits_order); 194 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 195 | 196 | return {hit_cnt, hits_t, hits_sphere_idx}; 197 | } -------------------------------------------------------------------------------- /models/csrc/losses.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | // for details of the formulae, please see https://arxiv.org/pdf/2206.05085.pdf 8 | 9 | template 10 | __global__ void prefix_sums_kernel( 11 | const scalar_t* __restrict__ ws, 12 | const scalar_t* __restrict__ wts, 13 | const torch::PackedTensorAccessor64 rays_a, 14 | scalar_t* __restrict__ ws_inclusive_scan, 15 | scalar_t* __restrict__ ws_exclusive_scan, 16 | scalar_t* __restrict__ wts_inclusive_scan, 17 | scalar_t* __restrict__ wts_exclusive_scan 18 | ){ 19 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 20 | if (n >= rays_a.size(0)) return; 21 | 22 | const int start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 23 | 24 | // compute prefix sum of ws and ws*ts 25 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 
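    // the scans are restricted to this ray's samples, [start_idx, start_idx+N_samples),
    // so each thread builds the per-ray prefix sums consumed by distortion_loss_fw_cu below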
26 | thrust::inclusive_scan(thrust::device, 27 | ws+start_idx, 28 | ws+start_idx+N_samples, 29 | ws_inclusive_scan+start_idx); 30 | thrust::inclusive_scan(thrust::device, 31 | wts+start_idx, 32 | wts+start_idx+N_samples, 33 | wts_inclusive_scan+start_idx); 34 | // [a0, a1, a2, a3, ...] -> [0, a0, a0+a1, a0+a1+a2, ...] 35 | thrust::exclusive_scan(thrust::device, 36 | ws+start_idx, 37 | ws+start_idx+N_samples, 38 | ws_exclusive_scan+start_idx); 39 | thrust::exclusive_scan(thrust::device, 40 | wts+start_idx, 41 | wts+start_idx+N_samples, 42 | wts_exclusive_scan+start_idx); 43 | } 44 | 45 | 46 | template 47 | __global__ void distortion_loss_fw_kernel( 48 | const scalar_t* __restrict__ _loss, 49 | const torch::PackedTensorAccessor64 rays_a, 50 | torch::PackedTensorAccessor loss 51 | ){ 52 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 53 | if (n >= rays_a.size(0)) return; 54 | 55 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 56 | 57 | loss[ray_idx] = thrust::reduce(thrust::device, 58 | _loss+start_idx, 59 | _loss+start_idx+N_samples, 60 | (scalar_t)0); 61 | } 62 | 63 | 64 | std::vector distortion_loss_fw_cu( 65 | const torch::Tensor ws, 66 | const torch::Tensor deltas, 67 | const torch::Tensor ts, 68 | const torch::Tensor rays_a 69 | ){ 70 | const int N_rays = rays_a.size(0), N = ws.size(0); 71 | 72 | auto wts = ws * ts; 73 | 74 | auto ws_inclusive_scan = torch::zeros({N}, ws.options()); 75 | auto ws_exclusive_scan = torch::zeros({N}, ws.options()); 76 | auto wts_inclusive_scan = torch::zeros({N}, ws.options()); 77 | auto wts_exclusive_scan = torch::zeros({N}, ws.options()); 78 | 79 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 80 | 81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu_prefix_sums", 82 | ([&] { 83 | prefix_sums_kernel<<>>( 84 | ws.data_ptr(), 85 | wts.data_ptr(), 86 | rays_a.packed_accessor64(), 87 | ws_inclusive_scan.data_ptr(), 88 | ws_exclusive_scan.data_ptr(), 89 | wts_inclusive_scan.data_ptr(), 90 | wts_exclusive_scan.data_ptr() 91 | ); 92 | })); 93 | 94 | auto _loss = 2*(wts_inclusive_scan*ws_exclusive_scan- 95 | ws_inclusive_scan*wts_exclusive_scan) + 1.0f/3*ws*ws*deltas; 96 | 97 | auto loss = torch::zeros({N_rays}, ws.options()); 98 | 99 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu", 100 | ([&] { 101 | distortion_loss_fw_kernel<<>>( 102 | _loss.data_ptr(), 103 | rays_a.packed_accessor64(), 104 | loss.packed_accessor() 105 | ); 106 | })); 107 | 108 | return {loss, ws_inclusive_scan, wts_inclusive_scan}; 109 | } 110 | 111 | 112 | template 113 | __global__ void distortion_loss_bw_kernel( 114 | const torch::PackedTensorAccessor dL_dloss, 115 | const torch::PackedTensorAccessor ws_inclusive_scan, 116 | const torch::PackedTensorAccessor wts_inclusive_scan, 117 | const torch::PackedTensorAccessor ws, 118 | const torch::PackedTensorAccessor deltas, 119 | const torch::PackedTensorAccessor ts, 120 | const torch::PackedTensorAccessor64 rays_a, 121 | torch::PackedTensorAccessor dL_dws 122 | ){ 123 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 124 | if (n >= rays_a.size(0)) return; 125 | 126 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 127 | const int end_idx = start_idx+N_samples-1; 128 | 129 | const scalar_t ws_sum = ws_inclusive_scan[end_idx]; 130 | const scalar_t wts_sum = wts_inclusive_scan[end_idx]; 131 | // fill in dL_dws from start_idx to end_idx 132 | for (int s=start_idx; s<=end_idx; s++){ 133 | dL_dws[s] = 
dL_dloss[ray_idx] * 2 * ( 134 | (s==start_idx? 135 | (scalar_t)0: 136 | (ts[s]*ws_inclusive_scan[s-1]-wts_inclusive_scan[s-1]) 137 | ) + 138 | (wts_sum-wts_inclusive_scan[s]-ts[s]*(ws_sum-ws_inclusive_scan[s])) 139 | ); 140 | dL_dws[s] += dL_dloss[ray_idx] * (scalar_t)2/3*ws[s]*deltas[s]; 141 | } 142 | } 143 | 144 | 145 | torch::Tensor distortion_loss_bw_cu( 146 | const torch::Tensor dL_dloss, 147 | const torch::Tensor ws_inclusive_scan, 148 | const torch::Tensor wts_inclusive_scan, 149 | const torch::Tensor ws, 150 | const torch::Tensor deltas, 151 | const torch::Tensor ts, 152 | const torch::Tensor rays_a 153 | ){ 154 | const int N_rays = rays_a.size(0), N = ws.size(0); 155 | 156 | auto dL_dws = torch::zeros({N}, dL_dloss.options()); 157 | 158 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 159 | 160 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_bw_cu", 161 | ([&] { 162 | distortion_loss_bw_kernel<<>>( 163 | dL_dloss.packed_accessor(), 164 | ws_inclusive_scan.packed_accessor(), 165 | wts_inclusive_scan.packed_accessor(), 166 | ws.packed_accessor(), 167 | deltas.packed_accessor(), 168 | ts.packed_accessor(), 169 | rays_a.packed_accessor64(), 170 | dL_dws.packed_accessor() 171 | ); 172 | })); 173 | 174 | return dL_dws; 175 | } -------------------------------------------------------------------------------- /models/csrc/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h 10 | 11 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 12 | 13 | 14 | setup( 15 | name='vren', 16 | version='2.0', 17 | author='kwea123', 18 | author_email='kwea123@gmail.com', 19 | description='cuda volume rendering library', 20 | long_description='cuda volume rendering library', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='vren', 24 | sources=sources, 25 | include_dirs=include_dirs, 26 | extra_compile_args={'cxx': ['-O2'], 27 | 'nvcc': ['-O2']} 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) -------------------------------------------------------------------------------- /models/csrc/volumerendering.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | #include 4 | 5 | 6 | template 7 | __global__ void composite_train_fw_kernel( 8 | const torch::PackedTensorAccessor sigmas, 9 | const torch::PackedTensorAccessor rgbs, 10 | const torch::PackedTensorAccessor deltas, 11 | const torch::PackedTensorAccessor ts, 12 | const torch::PackedTensorAccessor64 rays_a, 13 | const scalar_t T_threshold, 14 | torch::PackedTensorAccessor64 total_samples, 15 | torch::PackedTensorAccessor opacity, 16 | torch::PackedTensorAccessor depth, 17 | torch::PackedTensorAccessor rgb, 18 | torch::PackedTensorAccessor ws 19 | ){ 20 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 21 | if (n >= opacity.size(0)) return; 22 | 23 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 24 | 25 | // front to back compositing 26 | int samples = 0; scalar_t T = 1.0f; 27 | 28 | while (samples < N_samples) { 29 | const int s = start_idx + samples; 30 | const scalar_t a = 1.0f - 
__expf(-sigmas[s]*deltas[s]); 31 | const scalar_t w = a * T; // weight of the sample point 32 | 33 | rgb[ray_idx][0] += w*rgbs[s][0]; 34 | rgb[ray_idx][1] += w*rgbs[s][1]; 35 | rgb[ray_idx][2] += w*rgbs[s][2]; 36 | depth[ray_idx] += w*ts[s]; 37 | opacity[ray_idx] += w; 38 | ws[s] = w; 39 | T *= 1.0f-a; 40 | 41 | if (T <= T_threshold) break; // ray has enough opacity 42 | samples++; 43 | } 44 | total_samples[ray_idx] = samples; 45 | } 46 | 47 | 48 | std::vector composite_train_fw_cu( 49 | const torch::Tensor sigmas, 50 | const torch::Tensor rgbs, 51 | const torch::Tensor deltas, 52 | const torch::Tensor ts, 53 | const torch::Tensor rays_a, 54 | const float T_threshold 55 | ){ 56 | const int N_rays = rays_a.size(0), N = sigmas.size(0); 57 | 58 | auto opacity = torch::zeros({N_rays}, sigmas.options()); 59 | auto depth = torch::zeros({N_rays}, sigmas.options()); 60 | auto rgb = torch::zeros({N_rays, 3}, sigmas.options()); 61 | auto ws = torch::zeros({N}, sigmas.options()); 62 | auto total_samples = torch::zeros({N_rays}, torch::dtype(torch::kLong).device(sigmas.device())); 63 | 64 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 65 | 66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_fw_cu", 67 | ([&] { 68 | composite_train_fw_kernel<<>>( 69 | sigmas.packed_accessor(), 70 | rgbs.packed_accessor(), 71 | deltas.packed_accessor(), 72 | ts.packed_accessor(), 73 | rays_a.packed_accessor64(), 74 | T_threshold, 75 | total_samples.packed_accessor64(), 76 | opacity.packed_accessor(), 77 | depth.packed_accessor(), 78 | rgb.packed_accessor(), 79 | ws.packed_accessor() 80 | ); 81 | })); 82 | 83 | return {total_samples, opacity, depth, rgb, ws}; 84 | } 85 | 86 | 87 | template 88 | __global__ void composite_train_bw_kernel( 89 | const torch::PackedTensorAccessor dL_dopacity, 90 | const torch::PackedTensorAccessor dL_ddepth, 91 | const torch::PackedTensorAccessor dL_drgb, 92 | const torch::PackedTensorAccessor dL_dws, 93 | scalar_t* __restrict__ dL_dws_times_ws, 94 | const torch::PackedTensorAccessor sigmas, 95 | const torch::PackedTensorAccessor rgbs, 96 | const torch::PackedTensorAccessor deltas, 97 | const torch::PackedTensorAccessor ts, 98 | const torch::PackedTensorAccessor64 rays_a, 99 | const torch::PackedTensorAccessor opacity, 100 | const torch::PackedTensorAccessor depth, 101 | const torch::PackedTensorAccessor rgb, 102 | const scalar_t T_threshold, 103 | torch::PackedTensorAccessor dL_dsigmas, 104 | torch::PackedTensorAccessor dL_drgbs 105 | ){ 106 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 107 | if (n >= opacity.size(0)) return; 108 | 109 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 110 | 111 | // front to back compositing 112 | int samples = 0; 113 | scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2]; 114 | scalar_t O = opacity[ray_idx], D = depth[ray_idx]; 115 | scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f; 116 | 117 | // compute prefix sum of dL_dws * ws 118 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 
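    // the inclusive scan lets the compositing loop below recover the suffix sum
    // sum_{k>s} dL/dw_k * w_k as (total - prefix at s), which enters dL_dsigmas[s]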
119 | thrust::inclusive_scan(thrust::device, 120 | dL_dws_times_ws+start_idx, 121 | dL_dws_times_ws+start_idx+N_samples, 122 | dL_dws_times_ws+start_idx); 123 | scalar_t dL_dws_times_ws_sum = dL_dws_times_ws[start_idx+N_samples-1]; 124 | 125 | while (samples < N_samples) { 126 | const int s = start_idx + samples; 127 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]); 128 | const scalar_t w = a * T; 129 | 130 | r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2]; 131 | d += w*ts[s]; 132 | T *= 1.0f-a; 133 | 134 | // compute gradients by math... 135 | dL_drgbs[s][0] = dL_drgb[ray_idx][0]*w; 136 | dL_drgbs[s][1] = dL_drgb[ray_idx][1]*w; 137 | dL_drgbs[s][2] = dL_drgb[ray_idx][2]*w; 138 | 139 | dL_dsigmas[s] = deltas[s] * ( 140 | dL_drgb[ray_idx][0]*(rgbs[s][0]*T-(R-r)) + 141 | dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) + 142 | dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) + // gradients from rgb 143 | dL_dopacity[ray_idx]*(1-O) + // gradient from opacity 144 | dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + // gradient from depth 145 | T*dL_dws[s]-(dL_dws_times_ws_sum-dL_dws_times_ws[s]) // gradient from ws 146 | ); 147 | 148 | if (T <= T_threshold) break; // ray has enough opacity 149 | samples++; 150 | } 151 | } 152 | 153 | 154 | std::vector composite_train_bw_cu( 155 | const torch::Tensor dL_dopacity, 156 | const torch::Tensor dL_ddepth, 157 | const torch::Tensor dL_drgb, 158 | const torch::Tensor dL_dws, 159 | const torch::Tensor sigmas, 160 | const torch::Tensor rgbs, 161 | const torch::Tensor ws, 162 | const torch::Tensor deltas, 163 | const torch::Tensor ts, 164 | const torch::Tensor rays_a, 165 | const torch::Tensor opacity, 166 | const torch::Tensor depth, 167 | const torch::Tensor rgb, 168 | const float T_threshold 169 | ){ 170 | const int N = sigmas.size(0), N_rays = rays_a.size(0); 171 | 172 | auto dL_dsigmas = torch::zeros({N}, sigmas.options()); 173 | auto dL_drgbs = torch::zeros({N, 3}, sigmas.options()); 174 | 175 | auto dL_dws_times_ws = dL_dws * ws; // auxiliary input 176 | 177 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 178 | 179 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_bw_cu", 180 | ([&] { 181 | composite_train_bw_kernel<<>>( 182 | dL_dopacity.packed_accessor(), 183 | dL_ddepth.packed_accessor(), 184 | dL_drgb.packed_accessor(), 185 | dL_dws.packed_accessor(), 186 | dL_dws_times_ws.data_ptr(), 187 | sigmas.packed_accessor(), 188 | rgbs.packed_accessor(), 189 | deltas.packed_accessor(), 190 | ts.packed_accessor(), 191 | rays_a.packed_accessor64(), 192 | opacity.packed_accessor(), 193 | depth.packed_accessor(), 194 | rgb.packed_accessor(), 195 | T_threshold, 196 | dL_dsigmas.packed_accessor(), 197 | dL_drgbs.packed_accessor() 198 | ); 199 | })); 200 | 201 | return {dL_dsigmas, dL_drgbs}; 202 | } 203 | 204 | 205 | template 206 | __global__ void composite_test_fw_kernel( 207 | const torch::PackedTensorAccessor sigmas, 208 | const torch::PackedTensorAccessor rgbs, 209 | const torch::PackedTensorAccessor deltas, 210 | const torch::PackedTensorAccessor ts, 211 | const torch::PackedTensorAccessor hits_t, 212 | torch::PackedTensorAccessor64 alive_indices, 213 | const scalar_t T_threshold, 214 | const torch::PackedTensorAccessor32 N_eff_samples, 215 | torch::PackedTensorAccessor opacity, 216 | torch::PackedTensorAccessor depth, 217 | torch::PackedTensorAccessor rgb 218 | ){ 219 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 220 | if (n >= alive_indices.size(0)) return; 221 | 222 | if (N_eff_samples[n]==0){ // no hit 223 | 
alive_indices[n] = -1; 224 | return; 225 | } 226 | 227 | const size_t r = alive_indices[n]; // ray index 228 | 229 | // front to back compositing 230 | int s = 0; scalar_t T = 1-opacity[r]; 231 | 232 | while (s < N_eff_samples[n]) { 233 | const scalar_t a = 1.0f - __expf(-sigmas[n][s]*deltas[n][s]); 234 | const scalar_t w = a * T; 235 | 236 | rgb[r][0] += w*rgbs[n][s][0]; 237 | rgb[r][1] += w*rgbs[n][s][1]; 238 | rgb[r][2] += w*rgbs[n][s][2]; 239 | depth[r] += w*ts[n][s]; 240 | opacity[r] += w; 241 | T *= 1.0f-a; 242 | 243 | if (T <= T_threshold){ // ray has enough opacity 244 | alive_indices[n] = -1; 245 | break; 246 | } 247 | s++; 248 | } 249 | } 250 | 251 | 252 | void composite_test_fw_cu( 253 | const torch::Tensor sigmas, 254 | const torch::Tensor rgbs, 255 | const torch::Tensor deltas, 256 | const torch::Tensor ts, 257 | const torch::Tensor hits_t, 258 | torch::Tensor alive_indices, 259 | const float T_threshold, 260 | const torch::Tensor N_eff_samples, 261 | torch::Tensor opacity, 262 | torch::Tensor depth, 263 | torch::Tensor rgb 264 | ){ 265 | const int N_rays = alive_indices.size(0); 266 | 267 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 268 | 269 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_test_fw_cu", 270 | ([&] { 271 | composite_test_fw_kernel<<>>( 272 | sigmas.packed_accessor(), 273 | rgbs.packed_accessor(), 274 | deltas.packed_accessor(), 275 | ts.packed_accessor(), 276 | hits_t.packed_accessor(), 277 | alive_indices.packed_accessor64(), 278 | T_threshold, 279 | N_eff_samples.packed_accessor32(), 280 | opacity.packed_accessor(), 281 | depth.packed_accessor(), 282 | rgb.packed_accessor() 283 | ); 284 | })); 285 | } -------------------------------------------------------------------------------- /models/loss/nerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/loss/nerf/__init__.py -------------------------------------------------------------------------------- /models/loss/nerf/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import vren 4 | 5 | 6 | class DistortionLoss(torch.autograd.Function): 7 | """ 8 | Distortion loss proposed in Mip-NeRF 360 (https://arxiv.org/pdf/2111.12077.pdf) 9 | Implementation is based on DVGO-v2 (https://arxiv.org/pdf/2206.05085.pdf) 10 | 11 | Inputs: 12 | ws: (N) sample point weights 13 | deltas: (N) considered as intervals 14 | ts: (N) considered as midpoints 15 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 16 | meaning each entry corresponds to the @ray_idx th ray, 17 | whose samples are [start_idx:start_idx+N_samples] 18 | 19 | Outputs: 20 | loss: (N_rays) 21 | """ 22 | @staticmethod 23 | def forward(ctx, ws, deltas, ts, rays_a): 24 | loss, ws_inclusive_scan, wts_inclusive_scan = \ 25 | vren.distortion_loss_fw(ws, deltas, ts, rays_a) 26 | ctx.save_for_backward(ws_inclusive_scan, wts_inclusive_scan, 27 | ws, deltas, ts, rays_a) 28 | return loss 29 | 30 | @staticmethod 31 | def backward(ctx, dL_dloss): 32 | (ws_inclusive_scan, wts_inclusive_scan, 33 | ws, deltas, ts, rays_a) = ctx.saved_tensors 34 | dL_dws = vren.distortion_loss_bw(dL_dloss, ws_inclusive_scan, 35 | wts_inclusive_scan, 36 | ws, deltas, ts, rays_a) 37 | return dL_dws, None, None, None 38 | 39 | 40 | class NeRFLoss(nn.Module): 41 | def __init__(self, lambda_opacity=1e-3, lambda_distortion=1e-3): 42 | 
super().__init__() 43 | 44 | self.lambda_opacity = lambda_opacity 45 | self.lambda_distortion = lambda_distortion 46 | 47 | def forward(self, results, target, **kwargs): 48 | d = {} 49 | d['rgb'] = (results['rgb']-target['rgb'])**2 50 | 51 | o = results['opacity']+1e-10 52 | # encourage opacity to be either 0 or 1 to avoid floater 53 | d['opacity'] = self.lambda_opacity*(-o*torch.log(o)) 54 | 55 | if self.lambda_distortion > 0: 56 | d['distortion'] = self.lambda_distortion * \ 57 | DistortionLoss.apply(results['ws'], results['deltas'], 58 | results['ts'], results['rays_a']) 59 | 60 | return d 61 | -------------------------------------------------------------------------------- /models/networks/FFB_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import tinycudann as tcnn 5 | import math 6 | 7 | from models.networks.Sine import Sine, sine_init, first_layer_sine_init 8 | 9 | 10 | class FFB_encoder(nn.Module): 11 | def __init__(self, encoding_config, network_config, n_input_dims, bound=1.0, has_out=True): 12 | super().__init__() 13 | 14 | self.bound = bound 15 | 16 | ### The encoder part 17 | sin_dims = network_config["dims"] 18 | sin_dims = [n_input_dims] + sin_dims 19 | self.num_sin_layers = len(sin_dims) 20 | 21 | feat_dim = encoding_config["feat_dim"] 22 | base_resolution = encoding_config["base_resolution"] 23 | per_level_scale = encoding_config["per_level_scale"] 24 | 25 | assert self.num_sin_layers > 3, "The layer number (SIREN branch) should be greater than 3." 26 | grid_level = int(self.num_sin_layers - 2) 27 | self.grid_encoder = tcnn.Encoding( 28 | n_input_dims=n_input_dims, 29 | encoding_config={ 30 | "otype": "HashGrid", 31 | "n_levels": grid_level, 32 | "n_features_per_level": feat_dim, 33 | "log2_hashmap_size": 19, 34 | "base_resolution": base_resolution, 35 | "per_level_scale": per_level_scale, 36 | }, 37 | ) 38 | self.grid_level = grid_level 39 | print(f"Grid encoder levels: {grid_level}") 40 | 41 | self.feat_dim = feat_dim 42 | 43 | ### Create the ffn to map low-dim grid feats to map high-dim SIREN feats 44 | base_sigma = encoding_config["base_sigma"] 45 | exp_sigma = encoding_config["exp_sigma"] 46 | 47 | ffn_list = [] 48 | for i in range(grid_level): 49 | ffn = torch.randn((feat_dim, sin_dims[2 + i]), requires_grad=True) * base_sigma * exp_sigma ** i 50 | 51 | ffn_list.append(ffn) 52 | 53 | self.ffn = nn.Parameter(torch.stack(ffn_list, dim=0)) 54 | 55 | 56 | ### The low-frequency MLP part 57 | for layer in range(0, self.num_sin_layers - 1): 58 | setattr(self, "sin_lin" + str(layer), nn.Linear(sin_dims[layer], sin_dims[layer + 1])) 59 | 60 | self.sin_w0 = network_config["w0"] 61 | self.sin_activation = Sine(w0=self.sin_w0) 62 | self.init_siren() 63 | 64 | ### The output layers 65 | self.has_out = has_out 66 | if has_out: 67 | size_factor = network_config["size_factor"] 68 | self.out_dim = sin_dims[-1] * size_factor 69 | 70 | for layer in range(0, grid_level): 71 | setattr(self, "out_lin" + str(layer), nn.Linear(sin_dims[layer + 1], self.out_dim)) 72 | 73 | self.sin_w0_high = network_config["w1"] 74 | self.init_siren_out() 75 | self.out_activation = Sine(w0=self.sin_w0_high) 76 | else: 77 | self.out_dim = sin_dims[-1] * grid_level 78 | 79 | 80 | ### Initialize the parameters of SIREN branch 81 | def init_siren(self): 82 | for layer in range(0, self.num_sin_layers - 1): 83 | lin = getattr(self, "sin_lin" + str(layer)) 84 | 85 | if layer == 0: 86 | first_layer_sine_init(lin) 87 | 
else: 88 | sine_init(lin, w0=self.sin_w0) 89 | 90 | 91 | def init_siren_out(self): 92 | for layer in range(0, self.grid_level): 93 | lin = getattr(self, "out_lin" + str(layer)) 94 | 95 | sine_init(lin, w0=self.sin_w0_high) 96 | 97 | 98 | def forward(self, in_pos): 99 | """ 100 | in_pos: [N, 3], in [-bound, bound] 101 | 102 | in_pos (for grid features) should always be located in [0.0, 1.0] 103 | x (for SIREN branch) should always be located in [-1.0, 1.0] 104 | """ 105 | 106 | x = in_pos / self.bound # to [-1, 1] 107 | in_pos = (in_pos + self.bound) / (2 * self.bound) # to [0, 1] 108 | 109 | grid_x = self.grid_encoder(in_pos) 110 | grid_x = grid_x.view(-1, self.grid_level, self.feat_dim) 111 | grid_x = grid_x.permute(1, 0, 2) 112 | 113 | embedding_list = [] 114 | for i in range(self.grid_level): 115 | grid_output = torch.matmul(grid_x[i], self.ffn[i]) 116 | grid_output = torch.sin(2 * math.pi * grid_output) 117 | embedding_list.append(grid_output) 118 | 119 | if self.has_out: 120 | x_out = torch.zeros(x.shape[0], self.out_dim, device=in_pos.device) 121 | else: 122 | feat_list = [] 123 | 124 | ### Grid encoding 125 | for layer in range(0, self.num_sin_layers - 1): 126 | sin_lin = getattr(self, "sin_lin" + str(layer)) 127 | x = sin_lin(x) 128 | x = self.sin_activation(x) 129 | 130 | if layer > 0: 131 | x = embedding_list[layer-1] + x 132 | 133 | if self.has_out: 134 | out_lin = getattr(self, "out_lin" + str(layer-1)) 135 | x_high = out_lin(x) 136 | x_high = self.out_activation(x_high) 137 | 138 | x_out = x_out + x_high 139 | else: 140 | feat_list.append(x) 141 | 142 | if self.has_out: 143 | x = x_out 144 | else: 145 | x = feat_list 146 | 147 | return x -------------------------------------------------------------------------------- /models/networks/Sine.py: -------------------------------------------------------------------------------- 1 | """ 2 | These codes are adapted from SIREN (https://github.com/vsitzmann/siren) 3 | """ 4 | 5 | 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | 11 | class Sine(nn.Module): 12 | def __init__(self, w0): 13 | super().__init__() 14 | 15 | self.w0 = w0 16 | 17 | def forward(self, input): 18 | return torch.sin(input * self.w0) 19 | 20 | 21 | 22 | def sine_init(m, w0, num_input=None): 23 | with torch.no_grad(): 24 | if hasattr(m, 'weight'): 25 | if num_input is None: 26 | num_input = m.weight.size(-1) 27 | m.weight.uniform_(-np.sqrt(6 / num_input) / w0, np.sqrt(6 / num_input) / w0) 28 | 29 | 30 | def first_layer_sine_init(m): 31 | with torch.no_grad(): 32 | if hasattr(m, 'weight'): 33 | num_input = m.weight.size(-1) 34 | m.weight.uniform_(-1.0 / num_input, 1.0 / num_input) 35 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/img/NFFB_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.networks.FFB_encoder import FFB_encoder 5 | 6 | 7 | class NFFB(nn.Module): 8 | def __init__(self, config, out_dims=3): 9 | super().__init__() 10 | 11 | self.xyz_encoder = FFB_encoder(n_input_dims=2, encoding_config=config["encoding"], 12 | network_config=config["SIREN"], has_out=False) 13 | 14 | ### 
Initializing backbone part, to merge multi-scale grid features 15 | backbone_dims = config["Backbone"]["dims"] 16 | grid_feat_len = self.xyz_encoder.out_dim 17 | backbone_dims = [grid_feat_len] + backbone_dims + [out_dims] 18 | self.num_backbone_layers = len(backbone_dims) 19 | 20 | for layer in range(0, self.num_backbone_layers - 1): 21 | out_dim = backbone_dims[layer + 1] 22 | setattr(self, "backbone_lin" + str(layer), nn.Linear(backbone_dims[layer], out_dim)) 23 | 24 | self.relu_activation = nn.ReLU(inplace=True) 25 | 26 | 27 | @torch.no_grad() 28 | # optimizer utils 29 | def get_params(self, LR_schedulers): 30 | params = [ 31 | {'params': self.parameters(), 'lr': LR_schedulers[0]["initial"]} 32 | ] 33 | 34 | return params 35 | 36 | 37 | def forward(self, in_pos): 38 | """ 39 | Inputs: 40 | x: (N, 2) xy in [-scale, scale] 41 | Outputs: 42 | out: (N, 1 or 3), the RGB values 43 | """ 44 | x = (in_pos - 0.5) * 2.0 45 | 46 | grid_x = self.xyz_encoder(x) 47 | out_feat = torch.cat(grid_x, dim=1) 48 | 49 | 50 | ### Backbone transformation 51 | for layer in range(0, self.num_backbone_layers - 1): 52 | backbone_lin = getattr(self, "backbone_lin" + str(layer)) 53 | out_feat = backbone_lin(out_feat) 54 | 55 | if layer < self.num_backbone_layers - 2: 56 | out_feat = self.relu_activation(out_feat) 57 | 58 | out_feat = out_feat.clamp(-1.0, 1.0) 59 | 60 | return out_feat -------------------------------------------------------------------------------- /models/networks/img/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/img/__init__.py -------------------------------------------------------------------------------- /models/networks/nerf/NFFB_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import tinycudann as tcnn 4 | import vren 5 | from einops import rearrange 6 | from .custom_functions import TruncExp 7 | import numpy as np 8 | 9 | from .rendering import NEAR_DISTANCE 10 | 11 | from models.networks.FFB_encoder import FFB_encoder 12 | from models.networks.Sine import sine_init, Sine 13 | 14 | 15 | class NFFB(nn.Module): 16 | def __init__(self, config, scale, rgb_act='Sigmoid'): 17 | super().__init__() 18 | 19 | self.rgb_act = rgb_act 20 | 21 | # scene bounding box 22 | self.scale = scale 23 | self.register_buffer('center', torch.zeros(1, 3)) 24 | self.register_buffer('xyz_min', -torch.ones(1, 3)*scale) 25 | self.register_buffer('xyz_max', torch.ones(1, 3)*scale) 26 | self.register_buffer('half_size', (self.xyz_max-self.xyz_min)/2) 27 | 28 | # each density grid covers [-2^(k-1), 2^(k-1)]^3 for k in [0, C-1] 29 | self.cascades = max(1+int(np.ceil(np.log2(2*scale))), 1) 30 | self.grid_size = 128 ### This property is used to speed up training process 31 | self.register_buffer('density_bitfield', 32 | torch.zeros(self.cascades*self.grid_size**3//8, dtype=torch.uint8)) 33 | 34 | 35 | self.xyz_encoder = FFB_encoder(n_input_dims=3, encoding_config=config["encoding"], 36 | network_config=config["SIREN"], bound=self.scale) 37 | 38 | ## sigma network 39 | self.num_layers = num_layers = 1 40 | hidden_dim = 64 41 | geo_feat_dim = 15 42 | 43 | sigma_net = [] 44 | for l in range(num_layers): 45 | if l == 0: 46 | in_dim = self.xyz_encoder.out_dim 47 | else: 48 | in_dim = hidden_dim 49 | 50 | if l == num_layers - 1: 51 | out_dim = 1 + geo_feat_dim # 1 sigma + 15 SH features for 
color 52 | else: 53 | out_dim = hidden_dim 54 | 55 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 56 | self.sigma_net = nn.ModuleList(sigma_net) 57 | 58 | self.sin_w0 = config["SIREN"]["w1"] 59 | self.sin_activation = Sine(w0=self.sin_w0) 60 | self.init_siren() 61 | 62 | # ### sigma network 63 | # self.sigma_net = \ 64 | # tcnn.Network( 65 | # n_input_dims=self.xyz_encoder.out_dim, n_output_dims=16, 66 | # network_config={ 67 | # "otype": "FullyFusedMLP", 68 | # "activation": "ReLU", 69 | # "output_activation": "None", 70 | # "n_neurons": 64, 71 | # "n_hidden_layers": 1, 72 | # } 73 | # ) 74 | 75 | # self.dir_encoder = \ 76 | # tcnn.Encoding( 77 | # n_input_dims=3, 78 | # encoding_config={ 79 | # "otype": "SphericalHarmonics", 80 | # "degree": 4, 81 | # }, 82 | # ) 83 | 84 | self.dir_encoder = \ 85 | tcnn.Encoding( 86 | n_input_dims=3, 87 | encoding_config={ 88 | "otype": "Frequency", 89 | "n_frequencies": 5 90 | }, 91 | ) 92 | 93 | self.rgb_net = \ 94 | tcnn.Network( 95 | n_input_dims=46, n_output_dims=3, 96 | network_config={ 97 | "otype": "FullyFusedMLP", 98 | "activation": "ReLU", 99 | "output_activation": self.rgb_act, 100 | "n_neurons": 64, 101 | "n_hidden_layers": 2, 102 | } 103 | ) 104 | 105 | if self.rgb_act == 'None': # rgb_net output is log-radiance 106 | for i in range(3): # independent tonemappers for r,g,b 107 | tonemapper_net = \ 108 | tcnn.Network( 109 | n_input_dims=1, n_output_dims=1, 110 | network_config={ 111 | "otype": "FullyFusedMLP", 112 | "activation": "ReLU", 113 | "output_activation": "Sigmoid", 114 | "n_neurons": 64, 115 | "n_hidden_layers": 1, 116 | } 117 | ) 118 | setattr(self, f'tonemapper_net_{i}', tonemapper_net) 119 | 120 | ### Initialize the sine-activated parameters 121 | def init_siren(self): 122 | ### Initialize the sigma network 123 | for l in range(self.num_layers): 124 | lin = self.sigma_net[l] 125 | sine_init(lin, w0=self.sin_w0) 126 | 127 | ### TODO - Transform the input coordinates into right range when feeding it into xyz_encoder() 128 | def density(self, x, return_feat=False): 129 | """ 130 | Inputs: 131 | x: (N, 3) xyz in [-scale, scale] 132 | return_feat: whether to return intermediate feature 133 | 134 | Outputs: 135 | sigmas: (N) 136 | """ 137 | # x = (x-self.xyz_min)/(self.xyz_max-self.xyz_min) 138 | h = self.xyz_encoder(x) 139 | # h = self.sigma_net(h) 140 | # 141 | for l in range(self.num_layers): 142 | h = self.sigma_net[l](h) 143 | if l != self.num_layers - 1: 144 | # h = F.relu(h, inplace=True) 145 | h = self.sin_activation(h) 146 | # h = self.sin_activation(h * self.sin_w0) 147 | 148 | sigmas = TruncExp.apply(h[:, 0]) 149 | if return_feat: return sigmas, h 150 | return sigmas 151 | 152 | def log_radiance_to_rgb(self, log_radiances, **kwargs): 153 | """ 154 | Convert log-radiance to rgb as the setting in HDR-NeRF. 
155 | Called only when self.rgb_act == 'None' (with exposure) 156 | 157 | Inputs: 158 | log_radiances: (N, 3) 159 | 160 | Outputs: 161 | rgbs: (N, 3) 162 | """ 163 | if 'exposure' in kwargs: 164 | log_exposure = torch.log(kwargs['exposure']) 165 | else: # unit exposure by default 166 | log_exposure = 0 167 | 168 | out = [] 169 | for i in range(3): 170 | inp = log_radiances[:, i:i+1]+log_exposure 171 | out += [getattr(self, f'tonemapper_net_{i}')(inp)] 172 | rgbs = torch.cat(out, 1) 173 | return rgbs 174 | 175 | def forward(self, x, d, **kwargs): 176 | """ 177 | Inputs: 178 | x: (N, 3) xyz in [-scale, scale] 179 | d: (N, 3) directions 180 | 181 | Outputs: 182 | sigmas: (N) 183 | rgbs: (N, 3) 184 | """ 185 | sigmas, h = self.density(x, return_feat=True) 186 | d = d/torch.norm(d, dim=1, keepdim=True) 187 | d = self.dir_encoder((d+1)/2) 188 | rgbs = self.rgb_net(torch.cat([d, h], 1)) 189 | 190 | if self.rgb_act == 'None': # rgbs is log-radiance 191 | if kwargs.get('output_radiance', False): # output HDR map 192 | rgbs = TruncExp.apply(rgbs) 193 | else: # convert to LDR using tonemapper networks 194 | rgbs = self.log_radiance_to_rgb(rgbs, **kwargs) 195 | 196 | return sigmas, rgbs 197 | 198 | @torch.no_grad() 199 | def get_all_cells(self): 200 | """ 201 | Get all cells from the density grid. 202 | 203 | Outputs: 204 | cells: list (of length self.cascades) of indices and coords 205 | selected at each cascade 206 | """ 207 | indices = vren.morton3D(self.grid_coords).long() 208 | cells = [(indices, self.grid_coords)] * self.cascades 209 | 210 | return cells 211 | 212 | @torch.no_grad() 213 | def sample_uniform_and_occupied_cells(self, M, density_threshold): 214 | """ 215 | Sample both M uniform and occupied cells (per cascade) 216 | occupied cells are sample from cells with density > @density_threshold 217 | 218 | Outputs: 219 | cells: list (of length self.cascades) of indices and coords 220 | selected at each cascade 221 | """ 222 | cells = [] 223 | for c in range(self.cascades): 224 | # uniform cells 225 | coords1 = torch.randint(self.grid_size, (M, 3), dtype=torch.int32, 226 | device=self.density_grid.device) 227 | indices1 = vren.morton3D(coords1).long() 228 | # occupied cells 229 | indices2 = torch.nonzero(self.density_grid[c]>density_threshold)[:, 0] 230 | if len(indices2) > 0: 231 | ### Randomly pick M occupied cells 232 | rand_idx = torch.randint(len(indices2), (M,), device=self.density_grid.device) 233 | indices2 = indices2[rand_idx] 234 | coords2 = vren.morton3D_invert(indices2.int()) 235 | # concatenate 236 | cells += [(torch.cat([indices1, indices2]), torch.cat([coords1, coords2]))] 237 | 238 | return cells 239 | 240 | @torch.no_grad() 241 | def mark_invisible_cells(self, K, poses, img_wh, chunk=64**3): 242 | """ 243 | mark the cells that aren't covered by the cameras with density -1 244 | only executed once before training starts 245 | 246 | Inputs: 247 | K: (3, 3) camera intrinsics 248 | poses: (N, 3, 4) camera to world poses 249 | img_wh: image width and height 250 | chunk: the chunk size to split the cells (to avoid OOM) 251 | """ 252 | N_cams = poses.shape[0] 253 | self.count_grid = torch.zeros_like(self.density_grid) 254 | w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3) 255 | w2c_T = -w2c_R@poses[:, :3, 3:] # (N_cams, 3, 1) 256 | cells = self.get_all_cells() 257 | for c in range(self.cascades): 258 | indices, coords = cells[c] 259 | for i in range(0, len(indices), chunk): 260 | xyzs = coords[i:i+chunk]/(self.grid_size-1)*2-1 ### [-1, 1] 261 | s = 
min(2**(c-1), self.scale)
262 | half_grid_size = s/self.grid_size
263 | xyzs_w = (xyzs*(s-half_grid_size)).T # (3, chunk) ### The coordinates in world frame
264 | xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk) ### The coordinates in camera frame
265 | uvd = K @ xyzs_c # (N_cams, 3, chunk)
266 | uv = uvd[:, :2]/uvd[:, 2:] # (N_cams, 2, chunk) ### The coordinates in image frame
267 | in_image = (uvd[:, 2]>=0)& \
268 | (uv[:, 0]>=0)&(uv[:, 0]<img_wh[0])& \
269 | (uv[:, 1]>=0)&(uv[:, 1]<img_wh[1])
270 | covered_by_cam = (uvd[:, 2]>=NEAR_DISTANCE)&in_image # (N_cams, chunk)
271 | # if the cell is visible by at least one camera
272 | self.count_grid[c, indices[i:i+chunk]] = \
273 | count = covered_by_cam.sum(0)/N_cams
274 |
275 | too_near_to_cam = (uvd[:, 2]<NEAR_DISTANCE)&in_image # (N_cams, chunk)
276 | # if the cell is too close (in front) to any camera
277 | too_near_to_any_cam = too_near_to_cam.any(0)
278 | # a valid cell should be visible by at least one camera and not too close to any camera
279 | valid_mask = (count>0)&(~too_near_to_any_cam)
280 | self.density_grid[c, indices[i:i+chunk]] = \
281 | torch.where(valid_mask, 0., -1.)
282 |
283 | @torch.no_grad()
284 | def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False):
285 | density_grid_tmp = torch.zeros_like(self.density_grid)
286 | if warmup: # during the first steps
287 | cells = self.get_all_cells()
288 | else:
289 | cells = self.sample_uniform_and_occupied_cells(self.grid_size**3//4, density_threshold)
290 |
291 | # infer and then update sigmas, and store at the density_grid_tmp
292 | for c in range(self.cascades):
293 | indices, coords = cells[c]
294 | s = min(2**(c-1), self.scale)
295 | half_grid_size = s/self.grid_size
296 | xyzs_w = (coords/(self.grid_size-1)*2-1)*(s-half_grid_size)
297 | # pick random position in the cell by adding noise in [-hgs, hgs]
298 | xyzs_w += (torch.rand_like(xyzs_w)*2-1) * half_grid_size
299 | density_grid_tmp[c, indices] = self.density(xyzs_w)
300 |
301 | if erode:
302 | # My own logic. decay more the cells that are visible to few cameras
303 | decay = torch.clamp(decay**(1/self.count_grid), 0.1, 0.95)
304 | self.density_grid = \
305 | torch.where(self.density_grid < 0, self.density_grid,
306 | torch.maximum(self.density_grid*decay, density_grid_tmp))
307 |
308 | mean_density = self.density_grid[self.density_grid>0].mean().item()
309 |
310 | ### Pack the density grid into a bitfield (one bit per cell) so that empty space can be skipped cheaply during ray marching
311 | vren.packbits(self.density_grid, min(mean_density, density_threshold), self.density_bitfield)
312 |
313 | @torch.no_grad()
314 | # optimizer utils
315 | def get_params(self, LR_schedulers):
316 | params = [
317 | {'params': self.xyz_encoder.parameters(), 'lr': LR_schedulers[0]["initial"]},
318 | {'params': self.sigma_net.parameters(), 'lr': LR_schedulers[1]["initial"]},
319 | {'params': self.dir_encoder.parameters(), 'lr': LR_schedulers[2]["initial"]},
320 | {'params': self.rgb_net.parameters(), 'lr': LR_schedulers[3]["initial"]},
321 | ]
322 |
323 | return params
--------------------------------------------------------------------------------
/models/networks/nerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/nerf/__init__.py
--------------------------------------------------------------------------------
/models/networks/nerf/custom_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import vren
3 | from torch.cuda.amp import custom_fwd, custom_bwd
4 | from torch_scatter import segment_csr
5 | from einops import rearrange
6 |
7 |
8 | ### Compute the intersection information for rays and AABB
9 | class RayAABBIntersector(torch.autograd.Function):
10 | """ 11 | Computes the intersections of rays and axis-aligned voxels. 12 | 13 | Inputs: 14 | rays_o: (N_rays, 3) ray origins 15 | rays_d: (N_rays, 3) ray directions 16 | centers: (N_voxels, 3) voxel centers 17 | half_sizes: (N_voxels, 3) voxel half sizes 18 | max_hits: maximum number of intersected voxels to keep for one ray 19 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) 20 | 21 | Outputs: 22 | hits_cnt: (N_rays) number of hits for each ray 23 | (followings are from near to far) 24 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 25 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) 26 | """ 27 | @staticmethod 28 | @custom_fwd(cast_inputs=torch.float32) 29 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits): 30 | return vren.ray_aabb_intersect(rays_o, rays_d, center, half_size, max_hits) 31 | 32 | 33 | ### Compute the intersection information between rays and a set of spheres 34 | class RaySphereIntersector(torch.autograd.Function): 35 | """ 36 | Computes the intersections of rays and spheres. 37 | 38 | Inputs: 39 | rays_o: (N_rays, 3) ray origins 40 | rays_d: (N_rays, 3) ray directions 41 | centers: (N_spheres, 3) sphere centers 42 | radii: (N_spheres, 3) radii 43 | max_hits: maximum number of intersected spheres to keep for one ray 44 | 45 | Outputs: 46 | hits_cnt: (N_rays) number of hits for each ray 47 | (followings are from near to far) 48 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 49 | hits_sphere_idx: (N_rays, max_hits) hit sphere indices (-1 if no hit) 50 | """ 51 | @staticmethod 52 | @custom_fwd(cast_inputs=torch.float32) 53 | def forward(ctx, rays_o, rays_d, center, radii, max_hits): 54 | return vren.ray_sphere_intersect(rays_o, rays_d, center, radii, max_hits) 55 | 56 | 57 | class RayMarcher(torch.autograd.Function): 58 | """ 59 | March the rays to get sample point positions and directions. 
60 | 61 | Inputs: 62 | rays_o: (N_rays, 3) ray origins 63 | rays_d: (N_rays, 3) normalized ray directions 64 | hits_t: (N_rays, 2) near and far bounds from aabb intersection 65 | density_bitfield: (C*G**3//8) 66 | cascades: int 67 | scale: float 68 | exp_step_factor: the exponential factor to scale the steps 69 | grid_size: int 70 | max_samples: int 71 | 72 | Outputs: 73 | rays_a: (N_rays) ray_idx, start_idx, N_samples 74 | xyzs: (N, 3) sample positions 75 | dirs: (N, 3) sample view directions 76 | deltas: (N) dt for integration 77 | ts: (N) sample ts 78 | """ 79 | @staticmethod 80 | @custom_fwd(cast_inputs=torch.float32) 81 | def forward(ctx, rays_o, rays_d, hits_t, 82 | density_bitfield, cascades, scale, exp_step_factor, 83 | grid_size, max_samples): 84 | # noise to perturb the first sample of each ray 85 | noise = torch.rand_like(rays_o[:, 0]) 86 | 87 | rays_a, xyzs, dirs, deltas, ts, counter = \ 88 | vren.raymarching_train( 89 | rays_o, rays_d, hits_t, 90 | density_bitfield, cascades, scale, 91 | exp_step_factor, noise, grid_size, max_samples) 92 | 93 | total_samples = counter[0] # total samples for all rays 94 | # remove redundant output 95 | xyzs = xyzs[:total_samples] 96 | dirs = dirs[:total_samples] 97 | deltas = deltas[:total_samples] 98 | ts = ts[:total_samples] 99 | 100 | ctx.save_for_backward(rays_a, ts) 101 | 102 | return rays_a, xyzs, dirs, deltas, ts, total_samples 103 | 104 | @staticmethod 105 | @custom_bwd 106 | def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, 107 | dL_ddeltas, dL_dts, dL_dtotal_samples): 108 | rays_a, ts = ctx.saved_tensors 109 | segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1]+rays_a[-1:, 2]]) 110 | dL_drays_o = segment_csr(dL_dxyzs, segments) 111 | dL_drays_d = \ 112 | segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments) 113 | 114 | return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None 115 | 116 | 117 | ### Compute the information for RGB, depth and opacity 118 | class VolumeRenderer(torch.autograd.Function): 119 | """ 120 | Volume rendering with different number of samples per ray 121 | Used in training only 122 | 123 | Inputs: 124 | sigmas: (N) 125 | rgbs: (N, 3) 126 | deltas: (N) 127 | ts: (N) 128 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 129 | meaning each entry corresponds to the @ray_idx th ray, 130 | whose samples are [start_idx:start_idx+N_samples] 131 | T_threshold: float, stop the ray if the transmittance is below it 132 | 133 | Outputs: 134 | total_samples: int, total effective samples 135 | opacity: (N_rays) 136 | depth: (N_rays) 137 | rgb: (N_rays, 3) 138 | ws: (N) sample point weights 139 | """ 140 | @staticmethod 141 | @custom_fwd(cast_inputs=torch.float32) 142 | def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold): 143 | total_samples, opacity, depth, rgb, ws = \ 144 | vren.composite_train_fw(sigmas, rgbs, deltas, ts, 145 | rays_a, T_threshold) 146 | ctx.save_for_backward(sigmas, rgbs, deltas, ts, rays_a, 147 | opacity, depth, rgb, ws) 148 | ctx.T_threshold = T_threshold 149 | return total_samples.sum(), opacity, depth, rgb, ws 150 | 151 | @staticmethod 152 | @custom_bwd 153 | def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth, dL_drgb, dL_dws): 154 | sigmas, rgbs, deltas, ts, rays_a, \ 155 | opacity, depth, rgb, ws = ctx.saved_tensors 156 | dL_dsigmas, dL_drgbs = \ 157 | vren.composite_train_bw(dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 158 | sigmas, rgbs, ws, deltas, ts, 159 | rays_a, 160 | opacity, depth, rgb, 161 | ctx.T_threshold) 162 | return dL_dsigmas, dL_drgbs, None, 
None, None, None 163 | 164 | 165 | class TruncExp(torch.autograd.Function): 166 | @staticmethod 167 | @custom_fwd(cast_inputs=torch.float32) 168 | def forward(ctx, x): 169 | ctx.save_for_backward(x) 170 | return torch.exp(x) 171 | 172 | @staticmethod 173 | @custom_bwd 174 | def backward(ctx, dL_dout): 175 | x = ctx.saved_tensors[0] 176 | return dL_dout * torch.exp(x.clamp(-15, 15)) 177 | -------------------------------------------------------------------------------- /models/networks/nerf/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .custom_functions import \ 3 | RayAABBIntersector, RayMarcher, VolumeRenderer 4 | from einops import rearrange 5 | import vren 6 | 7 | MAX_SAMPLES = 1024 8 | NEAR_DISTANCE = 0.01 9 | 10 | 11 | @torch.cuda.amp.autocast() 12 | def render(model, rays_o, rays_d, **kwargs): 13 | """ 14 | Render rays by 15 | 1. Compute the intersection of the rays with the scene bounding box 16 | 2. Follow the process in @render_func (different for train/test) 17 | 18 | Inputs: 19 | model: NGP 20 | rays_o: (N_rays, 3) ray origins 21 | rays_d: (N_rays, 3) ray directions 22 | 23 | Outputs: 24 | result: dictionary containing final rgb and depth 25 | """ 26 | rays_o = rays_o.contiguous(); rays_d = rays_d.contiguous() 27 | _, hits_t, _ = \ 28 | RayAABBIntersector.apply(rays_o, rays_d, model.center, model.half_size, 1) 29 | hits_t[(hits_t[:, 0, 0]>=0)&(hits_t[:, 0, 0] (n1 n2) c') 91 | dirs = rearrange(dirs, 'n1 n2 c -> (n1 n2) c') 92 | valid_mask = ~torch.all(dirs==0, dim=1) 93 | if valid_mask.sum()==0: break 94 | 95 | sigmas = torch.zeros(len(xyzs), device=device) 96 | rgbs = torch.zeros(len(xyzs), 3, device=device) 97 | sigmas[valid_mask], _rgbs = model(xyzs[valid_mask], dirs[valid_mask], **kwargs) 98 | rgbs[valid_mask] = _rgbs.float() 99 | sigmas = rearrange(sigmas, '(n1 n2) -> n1 n2', n2=N_samples) 100 | rgbs = rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=N_samples) 101 | 102 | vren.composite_test_fw( 103 | sigmas, rgbs, deltas, ts, 104 | hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4), 105 | N_eff_samples, opacity, depth, rgb) 106 | alive_indices = alive_indices[alive_indices>=0] # remove converged rays 107 | 108 | results['opacity'] = opacity 109 | results['depth'] = depth 110 | results['rgb'] = rgb 111 | results['total_samples'] = total_samples # total samples for all rays 112 | 113 | if exp_step_factor==0: # synthetic 114 | rgb_bg = torch.ones(3, device=device) 115 | else: # real 116 | rgb_bg = torch.zeros(3, device=device) 117 | results['rgb'] += rgb_bg*rearrange(1-opacity, 'n -> n 1') 118 | 119 | return results 120 | 121 | 122 | ### Given the ray information, render RGB images for training stages 123 | def __render_rays_train(model, rays_o, rays_d, hits_t, **kwargs): 124 | """ 125 | Render rays by 126 | 1. March the rays along their directions, querying @density_bitfield 127 | to skip empty space, and get the effective sample points (where 128 | there is object) 129 | 2. Infer the NN at these positions and view directions to get properties 130 | (currently sigmas and rgbs) 131 | 3. Use volume rendering to combine the result (front to back compositing 132 | and early stop the ray if its transmittance is below a threshold) 133 | """ 134 | exp_step_factor = kwargs.get('exp_step_factor', 0.) 
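    ### exp_step_factor controls how the marching step grows with distance from the ray origin:
    ### 0 keeps a constant step (bounded synthetic scenes), while a positive value
    ### (train_nerf.py passes 1/256 when scale > 0.5) lets the step grow exponentially so that
    ### larger real scenes are covered within a limited sample budget.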
135 | results = {} 136 | 137 | (rays_a, xyzs, dirs, 138 | results['deltas'], results['ts'], results['rm_samples']) = \ 139 | RayMarcher.apply( 140 | rays_o, rays_d, hits_t[:, 0], model.density_bitfield, 141 | model.cascades, model.scale, 142 | exp_step_factor, model.grid_size, MAX_SAMPLES) 143 | 144 | for k, v in kwargs.items(): # supply additional inputs, repeated per ray 145 | if isinstance(v, torch.Tensor): 146 | kwargs[k] = torch.repeat_interleave(v[rays_a[:, 0]], rays_a[:, 2], 0) 147 | sigmas, rgbs = model(xyzs, dirs, **kwargs) 148 | 149 | (results['vr_samples'], results['opacity'], 150 | results['depth'], results['rgb'], results['ws']) = \ 151 | VolumeRenderer.apply(sigmas, rgbs.contiguous(), results['deltas'], results['ts'], 152 | rays_a, kwargs.get('T_threshold', 1e-4)) 153 | results['rays_a'] = rays_a 154 | 155 | if exp_step_factor==0: # synthetic 156 | rgb_bg = torch.ones(3, device=rays_o.device) 157 | else: # real 158 | if kwargs.get('random_bg', False): 159 | rgb_bg = torch.rand(3, device=rays_o.device) 160 | else: 161 | rgb_bg = torch.zeros(3, device=rays_o.device) 162 | results['rgb'] = results['rgb'] + \ 163 | rgb_bg*rearrange(1-results['opacity'], 'n -> n 1') 164 | 165 | return results 166 | -------------------------------------------------------------------------------- /models/networks/sdf/NFFB_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.networks.FFB_encoder import FFB_encoder 5 | from models.networks.Sine import sine_init 6 | 7 | 8 | class NFFB(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | 12 | self.xyz_encoder = FFB_encoder(n_input_dims=3, encoding_config=config["encoding"], 13 | network_config=config["SIREN"], has_out=False) 14 | enc_out_dim = self.xyz_encoder.out_dim 15 | 16 | self.out_lin = nn.Linear(enc_out_dim, 1) 17 | 18 | self.init_output(config["SIREN"]["dims"][-1]) 19 | 20 | 21 | def init_output(self, layer_size): 22 | sine_init(self.out_lin, self.xyz_encoder.sin_w0, layer_size) 23 | 24 | 25 | def forward(self, x): 26 | """ 27 | Inputs: 28 | x: (N, 3) xyz in [-scale, scale] 29 | Outputs: 30 | out: (N), the final sdf value 31 | """ 32 | out = self.xyz_encoder(x) 33 | 34 | out_feat = torch.cat(out, dim=1) 35 | out_feat = self.out_lin(out_feat) 36 | out = out_feat / self.xyz_encoder.grid_level 37 | 38 | return out 39 | 40 | 41 | @torch.no_grad() 42 | # optimizer utils 43 | def get_params(self, LR_schedulers): 44 | params = [ 45 | {'params': self.parameters(), 'lr': LR_schedulers[0]["initial"]} 46 | ] 47 | 48 | return params -------------------------------------------------------------------------------- /models/networks/sdf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/NFFB/59aec650e02f2401293e17d292bbbb73408beac7/models/networks/sdf/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | kornia==0.6.5 3 | pytorch-lightning==1.7.7 4 | matplotlib==3.5.2 5 | opencv-python==4.6.0.66 6 | lpips 7 | imageio 8 | imageio-ffmpeg 9 | jupyter 10 | scipy 11 | pymcubes 12 | trimesh 13 | dearpygui -------------------------------------------------------------------------------- /scripts/img/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | 3 | """ 4 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn) 5 | """ 6 | 7 | import imageio 8 | import numpy as np 9 | import os 10 | import struct 11 | 12 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13 | 14 | def mse2psnr(x): 15 | return -10.*np.log(x)/np.log(10.) 16 | 17 | def write_image_imageio(img_file, img, quality): 18 | img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) 19 | kwargs = {} 20 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]: 21 | if img.ndim >= 3 and img.shape[2] > 3: 22 | img = img[:,:,:3] 23 | kwargs["quality"] = quality 24 | kwargs["subsampling"] = 0 25 | imageio.imwrite(img_file, img, **kwargs) 26 | 27 | def read_image_imageio(img_file): 28 | img = imageio.imread(img_file) 29 | img = np.asarray(img).astype(np.float32) 30 | if len(img.shape) == 2: 31 | img = img[:,:,np.newaxis] 32 | return img / 255.0 33 | 34 | ### Do the exp and division operations to expand the expressivity of valid rgb values 35 | def srgb_to_linear(img): 36 | limit = 0.04045 37 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) 38 | 39 | def linear_to_srgb(img): 40 | limit = 0.0031308 41 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 42 | 43 | def read_image(file): 44 | if os.path.splitext(file)[1] == ".bin": 45 | with open(file, "rb") as f: 46 | bytes = f.read() 47 | h, w = struct.unpack("ii", bytes[:8]) 48 | img = np.frombuffer(bytes, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4]) 49 | else: 50 | img = read_image_imageio(file) 51 | if img.shape[2] == 4: 52 | img[...,0:3] = srgb_to_linear(img[...,0:3]) 53 | # Premultiply alpha 54 | img[...,0:3] *= img[...,3:4] 55 | else: 56 | img = srgb_to_linear(img) 57 | return img 58 | 59 | def write_image(file, img, quality=95): 60 | if os.path.splitext(file)[1] == ".bin": 61 | if img.shape[2] < 4: 62 | img = np.dstack((img, np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]]))) 63 | with open(file, "wb") as f: 64 | f.write(struct.pack("ii", img.shape[0], img.shape[1])) 65 | f.write(img.astype(np.float16).tobytes()) 66 | else: 67 | if img.shape[2] == 4: 68 | img = np.copy(img) 69 | # Unmultiply alpha 70 | img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0) 71 | img[...,0:3] = linear_to_srgb(img[...,0:3]) 72 | else: 73 | img = linear_to_srgb(img) 74 | write_image_imageio(file, img, quality) 75 | 76 | def trim(error, skip=0.000001): 77 | error = np.sort(error.flatten()) 78 | size = error.size 79 | skip = int(skip * size) 80 | return error[skip:size-skip].mean() 81 | 82 | def luminance(a): 83 | a = np.maximum(0, a)**0.4545454545 84 | return 0.2126 * a[:,:,0] + 0.7152 * a[:,:,1] + 0.0722 * a[:,:,2] 85 | 86 | def L1(img, ref): 87 | return np.abs(img - ref) 88 | 89 | def APE(img, ref): 90 | return L1(img, ref) / (1e-2 + ref) 91 | 92 | def SAPE(img, ref): 93 | return L1(img, ref) / (1e-2 + (ref + img) / 2.) 94 | 95 | def L2(img, ref): 96 | return (img - ref)**2 97 | 98 | def RSE(img, ref): 99 | return L2(img, ref) / (1e-2 + ref**2) 100 | 101 | def rgb_mean(img): 102 | return np.mean(img, axis=2) 103 | 104 | def compute_error_img(metric, img, ref): 105 | img[np.logical_not(np.isfinite(img))] = 0 106 | img = np.maximum(img, 0.) 
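    ### Non-finite predictions are zeroed and negatives clamped above; the relative metrics
    ### below (MAPE, SMAPE, MRSE, MRScE) add a small constant (1e-2) to their denominators so
    ### that near-zero reference values do not blow up the error.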
107 | if metric == "MAE": 108 | return L1(img, ref) 109 | elif metric == "MAPE": 110 | return APE(img, ref) 111 | elif metric == "SMAPE": 112 | return SAPE(img, ref) 113 | elif metric == "MSE": 114 | return L2(img, ref) 115 | elif metric == "MScE": 116 | return L2(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0)) 117 | elif metric == "MRSE": 118 | return RSE(img, ref) 119 | elif metric == "MtRSE": 120 | return trim(RSE(img, ref)) 121 | elif metric == "MRScE": 122 | return RSE(np.clip(img, 0, 100), np.clip(ref, 0, 100)) 123 | 124 | raise ValueError(f"Unknown metric: {metric}.") 125 | 126 | 127 | def compute_error(metric, img, ref): 128 | metric_map = compute_error_img(metric, img, ref) 129 | metric_map[np.logical_not(np.isfinite(metric_map))] = 0 130 | if len(metric_map.shape) == 3: 131 | metric_map = np.mean(metric_map, axis=2) 132 | mean = np.mean(metric_map) 133 | return mean 134 | -------------------------------------------------------------------------------- /scripts/img/opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_opts(): 5 | parser = argparse.ArgumentParser(description="Parsing parameters for 3D occupancy.") 6 | 7 | # config file 8 | parser.add_argument("--config", type=str, required=True, 9 | default="config/img/config.json", 10 | help="network configuration") 11 | 12 | # data file 13 | parser.add_argument("--input_path", type=str, required=True) 14 | parser.add_argument("--output_dir", type=str, default="experiments", 15 | help="output directory") 16 | 17 | # training options 18 | parser.add_argument('--batch_size', type=int, default=2**18, 19 | help='number of points in a batch') 20 | parser.add_argument('--num_epochs', type=int, default=50, 21 | help='number of training epochs') 22 | parser.add_argument('--seed', type=int, default=42, 23 | help='random seed for training') 24 | 25 | # validation options 26 | parser.add_argument('--val_only', action='store_true', default=False, 27 | help='run only validation (need to provide ckpt_path)') 28 | parser.add_argument('--no_save_test', action='store_true', default=False, 29 | help='whether to perform marching cubes for input shapes') 30 | 31 | # misc 32 | parser.add_argument('--ckpt_path', type=str, default=None, 33 | help='pretrained checkpoint to load') 34 | parser.add_argument('--clamp_distance', type=float, default=1.0, 35 | help='the value range for sdfs') 36 | 37 | 38 | args = parser.parse_args() 39 | return args -------------------------------------------------------------------------------- /scripts/img/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | These codes are adapted from tiny-cuda-nn (https://github.com/NVlabs/tiny-cuda-nn) 5 | """ 6 | 7 | import imageio 8 | import numpy as np 9 | import os 10 | import struct 11 | 12 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13 | 14 | 15 | def write_image_imageio(img_file, img, quality): 16 | img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) 17 | kwargs = {} 18 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]: 19 | if img.ndim >= 3 and img.shape[2] > 3: 20 | img = img[:,:,:3] 21 | kwargs["quality"] = quality 22 | kwargs["subsampling"] = 0 23 | imageio.imwrite(img_file, img, **kwargs) 24 | 25 | def read_image_imageio(img_file): 26 | img = imageio.imread(img_file) 27 | img = np.asarray(img).astype(np.float32) 28 | if len(img.shape) == 2: 29 | img = 
img[:,:,np.newaxis] 30 | return img / 255.0 31 | 32 | ### Do the exp and division operations to expand the expressivity of valid rgb values 33 | def srgb_to_linear(img): 34 | limit = 0.04045 35 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) 36 | 37 | def linear_to_srgb(img): 38 | limit = 0.0031308 39 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 40 | 41 | def read_image(file): 42 | if os.path.splitext(file)[1] == ".bin": 43 | with open(file, "rb") as f: 44 | bytes = f.read() 45 | h, w = struct.unpack("ii", bytes[:8]) 46 | img = np.frombuffer(bytes, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4]) 47 | else: 48 | img = read_image_imageio(file) 49 | if img.shape[2] == 4: 50 | img[...,0:3] = srgb_to_linear(img[...,0:3]) 51 | # Premultiply alpha 52 | img[...,0:3] *= img[...,3:4] 53 | else: 54 | img = srgb_to_linear(img) 55 | return img 56 | 57 | def write_image(file, img, quality=95): 58 | if os.path.splitext(file)[1] == ".bin": 59 | if img.shape[2] < 4: 60 | img = np.dstack((img, np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]]))) 61 | with open(file, "wb") as f: 62 | f.write(struct.pack("ii", img.shape[0], img.shape[1])) 63 | f.write(img.astype(np.float16).tobytes()) 64 | else: 65 | if img.shape[2] == 4: 66 | img = np.copy(img) 67 | # Unmultiply alpha 68 | img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0) 69 | img[...,0:3] = linear_to_srgb(img[...,0:3]) 70 | else: 71 | img = linear_to_srgb(img) 72 | write_image_imageio(file, img, quality) -------------------------------------------------------------------------------- /scripts/nvs/opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # dataset parameters 7 | parser.add_argument('--root_dir', type=str, required=True, 8 | help='root directory of dataset') 9 | parser.add_argument('--dataset_name', type=str, default='nsvf', 10 | choices=['nerf', 'nsvf', 'colmap', 'nerfpp', 'rtmv'], 11 | help='which dataset to train/test') 12 | parser.add_argument('--split', type=str, default='train', 13 | choices=['train', 'trainval', 'trainvaltest'], 14 | help='use which split to train') 15 | parser.add_argument('--downsample', type=float, default=1.0, 16 | help='downsample factor (<=1.0) for the images') 17 | 18 | # model parameters 19 | parser.add_argument('--scale', type=float, default=0.5, 20 | help='scene scale (whole scene must lie in [-scale, scale]^3') 21 | parser.add_argument('--use_exposure', action='store_true', default=False, 22 | help='whether to train in HDR-NeRF setting') 23 | parser.add_argument('--config', nargs="?", type=str, default="config/nerf/config.json") 24 | 25 | # loss parameters 26 | parser.add_argument('--distortion_loss_w', type=float, default=0, 27 | help='''weight of distortion loss (see losses.py), 28 | 0 to disable (default), to enable, 29 | a good value is 1e-3 for real scene and 1e-2 for synthetic scene 30 | ''') 31 | 32 | # training options 33 | parser.add_argument('--batch_size', type=int, default=4096, 34 | help='number of rays in a batch') 35 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images', 36 | choices=['all_images', 'same_image'], 37 | help=''' 38 | all_images: uniformly from all pixels of ALL images 39 | same_image: uniformly from all pixels of a SAME image 40 | ''') 41 | 
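    # Note: --scale (defined above) must bound the whole scene; values above 0.5 make the
    # renderer use several density-grid cascades and exponentially growing marching steps
    # (see NFFB_nerf.py and train_nerf.py).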
parser.add_argument('--num_epochs', type=int, default=30, 42 | help='number of training epochs') 43 | parser.add_argument('--num_gpus', type=int, default=1, 44 | help='number of gpus') 45 | parser.add_argument('--lr', type=float, default=1e-2, 46 | help='learning rate') 47 | parser.add_argument('--seed', type=int, default=42, 48 | help='random seed for training') 49 | # experimental training options 50 | parser.add_argument('--optimize_ext', action='store_true', default=False, 51 | help='whether to optimize extrinsics') 52 | parser.add_argument('--random_bg', action='store_true', default=False, 53 | help='''whether to train with random bg color (real scene only) 54 | to avoid objects with black color to be predicted as transparent 55 | ''') 56 | 57 | # validation options 58 | parser.add_argument('--eval_lpips', action='store_true', default=False, 59 | help='evaluate lpips metric (consumes more VRAM)') 60 | parser.add_argument('--val_only', action='store_true', default=False, 61 | help='run only validation (need to provide ckpt_path)') 62 | parser.add_argument('--no_save_test', action='store_true', default=False, 63 | help='whether to save test image and video') 64 | 65 | # misc 66 | parser.add_argument('--exp_name', type=str, default='exp', 67 | help='experiment name') 68 | parser.add_argument('--ckpt_path', type=str, default=None, 69 | help='pretrained checkpoint to load (including optimizers, etc)') 70 | parser.add_argument('--weight_path', type=str, default=None, 71 | help='pretrained checkpoint to load (excluding optimizers, etc)') 72 | 73 | return parser.parse_args() 74 | -------------------------------------------------------------------------------- /scripts/nvs/prepare_rtmv.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import glob 3 | import sys 4 | from tqdm import tqdm 5 | import os 6 | import numpy as np 7 | sys.path.append('datasets') 8 | from color_utils import linear_to_srgb 9 | 10 | import warnings; warnings.filterwarnings("ignore") 11 | 12 | 13 | if __name__ == '__main__': 14 | # convert hdr images to ldr by applying linear_to_srgb and clamping tone-mapping 15 | # and save into images/ folder to accelerate reading 16 | root_dir = sys.argv[1] 17 | envs = sorted(os.listdir(root_dir)) 18 | print('Generating ldr images from hdr images ...') 19 | for env in tqdm(envs): 20 | for scene in tqdm(sorted(os.listdir(os.path.join(root_dir, env)))): 21 | os.makedirs(os.path.join(root_dir, env, scene, 'images'), exist_ok=True) 22 | for i, img_p in enumerate(tqdm(sorted(glob.glob(os.path.join(root_dir, env, scene, '*[0-9].exr'))))): 23 | img = imageio.imread(img_p) # hdr 24 | img[..., :3] = linear_to_srgb(img[..., :3]) 25 | img = (255*img).astype(np.uint8) 26 | imageio.imsave(os.path.join(root_dir, env, scene, f'images/{i:05d}.png'), img) -------------------------------------------------------------------------------- /scripts/sdf/opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_opts(): 5 | parser = argparse.ArgumentParser(description="Parsing parameters for 3D occupancy.") 6 | 7 | # config file 8 | parser.add_argument("--config", type=str, required=True, 9 | default="config/sdf/config.json", 10 | help="network configuration") 11 | 12 | # data file 13 | parser.add_argument("--input_path", type=str, required=True) 14 | parser.add_argument("--output_dir", type=str, default="experiments", 15 | help="output directory") 16 | 17 | # training options 18 | 
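    # With these defaults and the size=1000 SDFDataset used in train_sdf.py, training runs
    # num_epochs (50) x 1000 iterations = 50k optimization steps, sampling batch_size points
    # per step.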
parser.add_argument('--batch_size', type=int, default=49152, 19 | help='number of points in a batch') 20 | parser.add_argument('--num_epochs', type=int, default=50, 21 | help='number of training epochs') 22 | parser.add_argument('--seed', type=int, default=42, 23 | help='random seed for training') 24 | 25 | # validation options 26 | parser.add_argument('--val_only', action='store_true', default=False, 27 | help='run only validation (need to provide ckpt_path)') 28 | parser.add_argument('--no_save_test', action='store_true', default=False, 29 | help='whether to perform marching cubes for input shapes') 30 | 31 | # misc 32 | parser.add_argument('--ckpt_path', type=str, default=None, 33 | help='pretrained checkpoint to load') 34 | parser.add_argument('--clamp_distance', type=float, default=0.1, 35 | help='the value range for sdfs') 36 | 37 | 38 | args = parser.parse_args() 39 | return args -------------------------------------------------------------------------------- /scripts/sdf/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import trimesh 3 | import mcubes 4 | 5 | 6 | def create_mesh(model, mesh_out_path, grid_res): 7 | # Prepare directory 8 | num_samples = grid_res ** 3 9 | 10 | sdf_values = torch.zeros(num_samples, 1) 11 | 12 | bound_min = torch.FloatTensor([-1.0, -1.0, -1.0]) 13 | bound_max = torch.FloatTensor([1.0, 1.0, 1.0]) 14 | 15 | X = torch.linspace(bound_min[0], bound_max[0], grid_res) 16 | Y = torch.linspace(bound_min[1], bound_max[1], grid_res) 17 | Z = torch.linspace(bound_min[2], bound_max[2], grid_res) 18 | 19 | xx, yy, zz = torch.meshgrid(X, Y, Z, indexing='ij') 20 | inputs = torch.concat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda() # [N, 3] 21 | 22 | head = 0 23 | max_batch = int(2 ** 18) 24 | 25 | while head < num_samples: 26 | sample_subset = inputs[head : min(head + max_batch, num_samples), :] 27 | 28 | sdf_values[head : min(head + max_batch, num_samples), 0] = ( 29 | model(sample_subset).squeeze(1).detach().cpu() 30 | ) 31 | head += max_batch 32 | 33 | sdf_values = sdf_values.reshape(grid_res, grid_res, grid_res) 34 | 35 | numpy_3d_sdf_tensor = sdf_values.data.cpu().numpy() 36 | 37 | verts, faces = mcubes.marching_cubes(numpy_3d_sdf_tensor, 0.0) 38 | 39 | vertices = verts / (grid_res - 1.0) * 2.0 - 1.0 40 | 41 | print(f'\nSaving mesh to {mesh_out_path}...', end="") 42 | 43 | mesh = trimesh.Trimesh(vertices, faces, process=False) # important, process=True leads to seg fault... 
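    ### The vertices were mapped from marching-cubes grid indices back to the [-1, 1] cube
    ### above, so the exported mesh lives in the same coordinate frame the SDF was fitted in.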
44 | mesh.export(mesh_out_path) 45 | 46 | print(f"==> Finished saving mesh.") 47 | -------------------------------------------------------------------------------- /train_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from scripts.img.opt import get_opts 4 | 5 | # data 6 | from torch.utils.data import DataLoader 7 | from datasets.img.imager import ImageDataset 8 | from scripts.img.common import read_image 9 | 10 | # models 11 | import commentjson as json 12 | from models.networks.img.NFFB_2d import NFFB 13 | 14 | # optimizer, losses 15 | from apex.optimizers import FusedAdam 16 | from torch.optim.lr_scheduler import StepLR 17 | 18 | 19 | # pytorch-lightning 20 | from pytorch_lightning import LightningModule, Trainer 21 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint 22 | from pytorch_lightning.loggers import TensorBoardLogger 23 | 24 | 25 | from utils import load_ckpt, seed_everything, process_batch_in_chunks 26 | 27 | # output 28 | import time 29 | from scripts.img.utils import write_image 30 | 31 | 32 | import warnings; warnings.filterwarnings("ignore") 33 | 34 | 35 | class ImageSystem(LightningModule): 36 | def __init__(self, hparams): 37 | super().__init__() 38 | self.save_hyperparameters(hparams) 39 | 40 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) 41 | 42 | exp_dir = os.path.join(self.hparams.output_dir, self.time) 43 | if not os.path.isdir(exp_dir): 44 | os.makedirs(exp_dir) 45 | 46 | ### Load the configuration file 47 | with open(self.hparams.config) as config_file: 48 | self.config = json.load(config_file) 49 | 50 | ### Save the configuration file 51 | path = f"{exp_dir}/config.json" 52 | with open(path, 'w') as f: 53 | json.dump(self.config, f, indent=4, separators=(", ", ": "), sort_keys=True) 54 | 55 | self.img_data = torch.from_numpy(read_image(self.hparams.input_path)).float() 56 | 57 | 58 | def setup(self, stage): 59 | self.model = NFFB(self.config["network"], out_dims=self.img_data.shape[2]) 60 | 61 | ema_decay = 0.95 62 | ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \ 63 | ema_decay * averaged_model_parameter + (1-ema_decay) * model_parameter 64 | self.ema_model = torch.optim.swa_utils.AveragedModel(self.model, avg_fn=ema_avg) 65 | 66 | 67 | self.train_dataset = ImageDataset(data=self.img_data, 68 | size=1000, 69 | num_samples=self.hparams.batch_size, 70 | split='train') 71 | 72 | self.test_dataset = ImageDataset(data=self.img_data, 73 | size=1, 74 | num_samples=self.hparams.batch_size, 75 | split='test') 76 | 77 | 78 | def forward(self, batch): 79 | b_pos = batch["points"] 80 | 81 | pred = self.model(b_pos) 82 | 83 | return pred 84 | 85 | 86 | def on_fit_start(self): 87 | seed_everything(self.hparams.seed) 88 | 89 | 90 | def configure_optimizers(self): 91 | load_ckpt(self.model, self.hparams.ckpt_path) 92 | 93 | opts = [] 94 | net_params = self.model.get_params(self.config["training"]["LR_scheduler"]) 95 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15) 96 | opts += [self.net_opt] 97 | 98 | lr_interval = self.config["training"]["LR_scheduler"][0]["interval"] 99 | lr_factor = self.config["training"]["LR_scheduler"][0]["factor"] 100 | 101 | if self.config["training"]["LR_scheduler"][0]["type"] == "Step": 102 | net_sch = StepLR(self.net_opt, step_size=lr_interval, gamma=lr_factor) 103 | else: 104 | net_sch = None 105 | 106 | return opts, [net_sch] 107 | 108 | 109 | def train_dataloader(self): 110 | 
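        ### batch_size=None disables the DataLoader's automatic batching: each ImageDataset
        ### item is already a full batch of --batch_size sampled pixels, so items are passed
        ### through to training_step unchanged.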
return DataLoader(self.train_dataset, 111 | num_workers=16, 112 | persistent_workers=True, 113 | batch_size=None, 114 | pin_memory=True) 115 | 116 | 117 | def val_dataloader(self): 118 | return DataLoader(self.test_dataset, 119 | num_workers=8, 120 | batch_size=None, 121 | pin_memory=True) 122 | 123 | 124 | def predict_dataloader(self): 125 | return DataLoader(self.test_dataset, 126 | num_workers=8, 127 | batch_size=None, 128 | pin_memory=True) 129 | 130 | 131 | def training_step(self, batch, batch_nb, *args): 132 | results = self(batch) 133 | 134 | b_occ = batch['rgbs'].to(results.dtype) 135 | 136 | batch_loss = (results - b_occ)**2 / (b_occ.detach()**2 + 1e-2) 137 | loss = batch_loss.mean() 138 | 139 | self.log('lr/network', self.net_opt.param_groups[0]['lr'], True) 140 | self.log('train/loss', loss) 141 | 142 | return loss 143 | 144 | 145 | def training_epoch_end(self, training_step_outputs): 146 | for name, cur_para in self.model.named_parameters(): 147 | if len(cur_para) == 0: 148 | print(f"The len of parameter {name} is 0 at epoch {self.current_epoch}.") 149 | continue 150 | 151 | if cur_para is not None and cur_para.requires_grad and cur_para.grad is not None: 152 | para_norm = torch.norm(cur_para.grad.detach(), 2) 153 | self.log('Grad/%s_norm' % name, para_norm) 154 | 155 | 156 | def on_before_zero_grad(self, optimizer): 157 | if self.ema_model is not None: 158 | self.ema_model.update_parameters(self.model) 159 | 160 | 161 | def backward(self, loss, optimizer, optimizer_idx): 162 | # do a custom way of backward to retain graph 163 | loss.backward(retain_graph=True) 164 | 165 | 166 | def on_train_start(self): 167 | gt_img = self.img_data.reshape(self.img_data.shape).float().clamp(0.0, 1.0) 168 | gt_img = gt_img.cpu().numpy() 169 | 170 | img_path = f'{self.hparams.output_dir}/{self.time}/reference.jpg' 171 | write_image(img_path, gt_img) 172 | print(f"\nWriting '{img_path}'... 
", end="") 173 | 174 | 175 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 176 | self.log("misc/model_size", model_size) 177 | print(f"\nThe model size: {model_size}") 178 | 179 | 180 | def on_train_end(self): 181 | # The final validation will use the ema model, as it replaces our normal model 182 | if self.ema_model is not None: 183 | print("Replacing the standard model with the EMA model for last validation run") 184 | self.model = self.ema_model 185 | 186 | 187 | def on_validation_start(self): 188 | torch.cuda.empty_cache() 189 | 190 | if not self.hparams.no_save_test: 191 | self.val_dir = f'{self.hparams.output_dir}/{self.time}/validation/' 192 | os.makedirs(self.val_dir, exist_ok=True) 193 | 194 | 195 | def validation_step(self, batch, batch_nb): 196 | img_size = self.img_data.shape[0] * self.img_data.shape[1] 197 | 198 | pred_img = process_batch_in_chunks(batch["points"], self.ema_model, max_chunk_size=2**18) 199 | pred_img = pred_img[:img_size, :].reshape(self.img_data.shape).float().clamp(0.0, 1.0) 200 | 201 | pred_img = pred_img.cpu().numpy() 202 | 203 | if not self.hparams.no_save_test: 204 | img_path = f"{self.val_dir}/{self.current_epoch}.jpg" 205 | write_image(img_path, pred_img) 206 | 207 | 208 | def predict_step(self, batch, batch_idx): 209 | img_size = self.img_data.shape[0] * self.img_data.shape[1] 210 | 211 | pred_img = process_batch_in_chunks(batch["points"], self.ema_model, max_chunk_size=2**18) 212 | pred_img = pred_img[:img_size, :].reshape(self.img_data.shape).float().clamp(0.0, 1.0) 213 | pred_img = pred_img.cpu().numpy() 214 | 215 | img_path = f"{self.val_dir}/result.jpg" 216 | write_image(img_path, pred_img) 217 | 218 | 219 | def get_progress_bar_dict(self): 220 | # don't show the version number 221 | items = super().get_progress_bar_dict() 222 | items.pop("v_num", None) 223 | return items 224 | 225 | 226 | if __name__ == '__main__': 227 | hparams = get_opts() 228 | if hparams.val_only and (not hparams.ckpt_path): 229 | raise ValueError('You need to provide a @ckpt_path for validation!') 230 | system = ImageSystem(hparams) 231 | 232 | ckpt_cb = ModelCheckpoint(dirpath=f'{hparams.output_dir}/{system.time}/ckpts/', 233 | filename='{epoch:d}', 234 | save_weights_only=True, 235 | every_n_epochs=hparams.num_epochs, 236 | save_on_train_epoch_end=True, 237 | save_top_k=-1) 238 | 239 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)] 240 | 241 | logger = TensorBoardLogger(save_dir=f"{hparams.output_dir}/{system.time}/logs/", 242 | name="", 243 | default_hp_metric=False) 244 | 245 | trainer = Trainer(max_epochs=hparams.num_epochs, 246 | check_val_every_n_epoch=5, 247 | callbacks=callbacks, 248 | logger=logger, 249 | enable_model_summary=False, 250 | accelerator='gpu', 251 | gradient_clip_val=1.0, 252 | strategy=None, 253 | num_sanity_val_steps=-1 if hparams.val_only else 0, 254 | precision=16) 255 | 256 | if hparams.val_only: 257 | trainer.predict(system, ckpt_path=hparams.ckpt_path) 258 | system.output_metrics(logger) 259 | else: 260 | trainer.fit(system, ckpt_path=hparams.ckpt_path) 261 | trainer.predict() 262 | system.output_metrics(logger) -------------------------------------------------------------------------------- /train_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from scripts.nvs.opt import get_opts 4 | import os 5 | import glob 6 | import imageio 7 | import numpy as np 8 | import cv2 9 | from einops import rearrange 10 | 11 | # data 
12 | from torch.utils.data import DataLoader 13 | from datasets import dataset_dict 14 | from datasets.nerf.ray_utils import axisangle_to_R, get_rays 15 | 16 | # models 17 | import commentjson as json 18 | from kornia.utils.grid import create_meshgrid3d 19 | from models.networks.nerf.NFFB_nerf import NFFB 20 | from models.networks.nerf.rendering import render, MAX_SAMPLES 21 | 22 | # optimizer, losses 23 | from apex.optimizers import FusedAdam 24 | from torch.optim.lr_scheduler import CosineAnnealingLR 25 | from models.loss.nerf.losses import NeRFLoss 26 | 27 | # metrics 28 | from torchmetrics import ( 29 | PeakSignalNoiseRatio, 30 | StructuralSimilarityIndexMeasure 31 | ) 32 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 33 | 34 | # pytorch-lightning 35 | from pytorch_lightning.plugins import DDPPlugin 36 | from pytorch_lightning import LightningModule, Trainer 37 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint 38 | from pytorch_lightning.loggers import TensorBoardLogger 39 | from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available 40 | 41 | from utils import slim_ckpt, load_ckpt, seed_everything 42 | 43 | # output 44 | import time 45 | 46 | import warnings; warnings.filterwarnings("ignore") 47 | 48 | 49 | def depth2img(depth): 50 | depth = (depth-depth.min())/(depth.max()-depth.min()) 51 | depth_img = cv2.applyColorMap((depth*255).astype(np.uint8), cv2.COLORMAP_TURBO) 52 | 53 | return depth_img 54 | 55 | class NeRFSystem(LightningModule): 56 | def __init__(self, hparams): 57 | super().__init__() 58 | self.save_hyperparameters(hparams) 59 | 60 | self.warmup_steps = 256 61 | self.update_interval = 16 62 | 63 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) 64 | 65 | exp_dir = os.path.join(f"experiments/", self.time) 66 | if not os.path.isdir(exp_dir): 67 | os.makedirs(exp_dir) 68 | 69 | with open(self.hparams.config) as config_file: 70 | self.net_config = json.load(config_file) 71 | 72 | ### Save the configuration file 73 | path = f"{exp_dir}/config.json" 74 | with open(path, 'w') as f: 75 | json.dump(self.net_config, f, indent=4, separators=(", ", ": "), sort_keys=True) 76 | 77 | self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w) 78 | self.train_psnr = PeakSignalNoiseRatio(data_range=1) 79 | self.val_psnr = PeakSignalNoiseRatio(data_range=1) 80 | self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1) 81 | if self.hparams.eval_lpips: 82 | self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg') 83 | ### Do not train the network parameters which are used to compute metrics 84 | for p in self.val_lpips.net.parameters(): 85 | p.requires_grad = False 86 | 87 | rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid' 88 | self.model = NFFB(self.net_config["network"], scale=self.hparams.scale, rgb_act=rgb_act) 89 | G = self.model.grid_size 90 | self.model.register_buffer('density_grid', torch.zeros(self.model.cascades, G**3)) 91 | self.model.register_buffer('grid_coords', 92 | create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3)) 93 | 94 | def forward(self, batch, split): 95 | if split=='train': 96 | poses = self.poses[batch['img_idxs']] 97 | directions = self.directions[batch['pix_idxs']] 98 | else: 99 | poses = batch['pose'] 100 | directions = self.directions 101 | 102 | if self.hparams.optimize_ext: 103 | dR = axisangle_to_R(self.dR[batch['img_idxs']]) 104 | poses[..., :3] = dR @ poses[..., :3] # Do the rotation for poses 105 | poses[..., 3] 
+= self.dT[batch['img_idxs']] # Do the translation for poses 106 | 107 | rays_o, rays_d = get_rays(directions, poses) 108 | 109 | kwargs = {'test_time': split!='train', 110 | 'random_bg': self.hparams.random_bg} 111 | if self.hparams.scale > 0.5: 112 | kwargs['exp_step_factor'] = 1 / 256 113 | if self.hparams.use_exposure: 114 | kwargs['exposure'] = batch['exposure'] 115 | 116 | return render(self.model, rays_o, rays_d, **kwargs) 117 | 118 | ### Setup the dataset for training and testing 119 | def setup(self, stage): 120 | dataset = dataset_dict[self.hparams.dataset_name] 121 | kwargs = {'root_dir': self.hparams.root_dir, 122 | 'downsample': self.hparams.downsample} 123 | self.train_dataset = dataset(split=self.hparams.split, **kwargs) 124 | self.train_dataset.batch_size = self.hparams.batch_size 125 | self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy 126 | 127 | self.test_dataset = dataset(split='test', **kwargs) 128 | 129 | def on_fit_start(self): 130 | seed_everything(self.hparams.seed) 131 | 132 | def configure_optimizers(self): 133 | # define additional parameters 134 | self.register_buffer('directions', self.train_dataset.directions.to(self.device)) 135 | self.register_buffer('poses', self.train_dataset.poses.to(self.device)) 136 | 137 | if self.hparams.optimize_ext: 138 | N = len(self.train_dataset.poses) 139 | self.register_parameter('dR', 140 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 141 | self.register_parameter('dT', 142 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 143 | 144 | load_ckpt(self.model, self.hparams.weight_path) 145 | 146 | ### Exclude the parameters of camera extrinsics 147 | net_params = [] 148 | for n, p in self.named_parameters(): 149 | if n not in ['dR', 'dT']: net_params += [p] 150 | 151 | opts = [] 152 | net_params = self.model.get_params(self.net_config["training"]["LearningRateSchedule"]) 153 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15) 154 | opts += [self.net_opt] 155 | if self.hparams.optimize_ext: 156 | opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded 157 | net_sch = CosineAnnealingLR(self.net_opt, 158 | self.hparams.num_epochs, 159 | self.net_config["training"]["lr_threshold"]) 160 | 161 | return opts, [net_sch] 162 | 163 | def train_dataloader(self): 164 | return DataLoader(self.train_dataset, 165 | num_workers=16, 166 | persistent_workers=True, 167 | batch_size=None, 168 | pin_memory=True) 169 | 170 | def val_dataloader(self): 171 | return DataLoader(self.test_dataset, 172 | num_workers=8, 173 | batch_size=None, 174 | pin_memory=True) 175 | 176 | def on_train_start(self): 177 | self.model.mark_invisible_cells(self.train_dataset.K.to(self.device), 178 | self.poses, 179 | self.train_dataset.img_wh) 180 | 181 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 182 | self.log("misc/model_size", model_size) 183 | print(f"\nThe model size: {model_size}") 184 | 185 | 186 | def training_step(self, batch, batch_nb, *args): 187 | if self.global_step % self.update_interval == 0: 188 | self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5, 189 | warmup=self.global_step 1 c h w', h=h) 248 | rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h) 249 | self.val_ssim(rgb_pred, rgb_gt) 250 | logs['ssim'] = self.val_ssim.compute() 251 | self.val_ssim.reset() 252 | if self.hparams.eval_lpips: 253 | self.val_lpips(torch.clip(rgb_pred * 2 - 1, -1, 1), 254 | torch.clip(rgb_gt * 2 - 1, -1, 1)) 255 | logs['lpips'] = self.val_lpips.compute() 256 | 
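                ### compute() followed by reset() yields a per-image LPIPS value; the per-image
                ### metrics collected in logs are gathered across GPUs and averaged in
                ### validation_epoch_end.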
self.val_lpips.reset() 257 | 258 | if not self.hparams.no_save_test: # save test image to disk 259 | idx = batch['img_idxs'] 260 | rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h) 261 | rgb_pred = (rgb_pred*255).astype(np.uint8) 262 | depth = depth2img(rearrange(results['depth'].cpu().numpy(), '(h w) -> h w', h=h)) 263 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred) 264 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}_d.png'), depth) 265 | 266 | return logs 267 | 268 | def validation_epoch_end(self, outputs): 269 | psnrs = torch.stack([x['psnr'] for x in outputs]) 270 | mean_psnr = all_gather_ddp_if_available(psnrs).mean() 271 | self.log('test/psnr', mean_psnr, True) 272 | 273 | ssims = torch.stack([x['ssim'] for x in outputs]) 274 | mean_ssim = all_gather_ddp_if_available(ssims).mean() 275 | self.log('test/ssim', mean_ssim) 276 | 277 | if self.hparams.eval_lpips: 278 | lpipss = torch.stack([x['lpips'] for x in outputs]) 279 | mean_lpips = all_gather_ddp_if_available(lpipss).mean() 280 | self.log('test/lpips_vgg', mean_lpips) 281 | 282 | def get_progress_bar_dict(self): 283 | # don't show the version number 284 | items = super().get_progress_bar_dict() 285 | items.pop("v_num", None) 286 | return items 287 | 288 | 289 | if __name__ == '__main__': 290 | hparams = get_opts() 291 | if hparams.val_only and (not hparams.ckpt_path): 292 | raise ValueError('You need to provide a @ckpt_path for validation!') 293 | system = NeRFSystem(hparams) 294 | 295 | ckpt_cb = ModelCheckpoint(dirpath=f'experiments/{system.time}/ckpts/', 296 | filename='{epoch:d}', 297 | save_weights_only=True, 298 | every_n_epochs=hparams.num_epochs, 299 | save_on_train_epoch_end=True, 300 | save_top_k=-1) 301 | 302 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)] 303 | 304 | logger = TensorBoardLogger(save_dir=f"experiments/{system.time}/logs/", 305 | name="", 306 | default_hp_metric=False) 307 | 308 | trainer = Trainer(max_epochs=hparams.num_epochs, 309 | check_val_every_n_epoch=hparams.num_epochs, 310 | callbacks=callbacks, 311 | logger=logger, 312 | enable_model_summary=False, 313 | accelerator='gpu', 314 | devices=hparams.num_gpus, 315 | strategy=DDPPlugin(find_unused_parameters=True) 316 | if hparams.num_gpus>1 else None, 317 | num_sanity_val_steps=-1 if hparams.val_only else 0, 318 | precision=16) 319 | 320 | trainer.fit(system, ckpt_path=hparams.ckpt_path) 321 | 322 | if not hparams.val_only: # save slimmed ckpt for the last epoch 323 | ckpt_ = \ 324 | slim_ckpt(f'experiments/{system.time}/ckpts/epoch={hparams.num_epochs-1}.ckpt', 325 | save_poses=hparams.optimize_ext) 326 | torch.save(ckpt_, f'experiments/{system.time}/ckpts/epoch={hparams.num_epochs-1}_slim.ckpt') 327 | 328 | if (not hparams.no_save_test) and \ 329 | hparams.dataset_name=='nsvf' and \ 330 | 'Synthetic' in hparams.root_dir: # save video 331 | imgs = sorted(glob.glob(os.path.join(system.val_dir, '*.png'))) 332 | imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'), 333 | [imageio.imread(img) for img in imgs[::2]], 334 | fps=30, macro_block_size=1) 335 | imageio.mimsave(os.path.join(system.val_dir, 'depth.mp4'), 336 | [imageio.imread(img) for img in imgs[1::2]], 337 | fps=30, macro_block_size=1) -------------------------------------------------------------------------------- /train_sdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from scripts.sdf.opt import get_opts 4 | 5 | # data 6 | from torch.utils.data 
import DataLoader 7 | from datasets.sdf.sampler import SDFDataset 8 | 9 | # models 10 | import commentjson as json 11 | from models.networks.sdf.NFFB_3d import NFFB 12 | 13 | # optimizer, losses 14 | from apex.optimizers import FusedAdam 15 | from torch.optim.lr_scheduler import StepLR 16 | 17 | # pytorch-lightning 18 | from pytorch_lightning import LightningModule, Trainer 19 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint 20 | from pytorch_lightning.loggers import TensorBoardLogger 21 | 22 | from utils import load_ckpt, seed_everything 23 | 24 | # output 25 | import time 26 | from scripts.sdf.utils import create_mesh 27 | 28 | 29 | import warnings; warnings.filterwarnings("ignore") 30 | 31 | 32 | class SDFSystem(LightningModule): 33 | def __init__(self, hparams): 34 | super().__init__() 35 | self.save_hyperparameters(hparams) 36 | 37 | self.time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) 38 | 39 | exp_dir = os.path.join(self.hparams.output_dir, self.time) 40 | if not os.path.isdir(exp_dir): 41 | os.makedirs(exp_dir) 42 | 43 | ### Load the configuration file 44 | with open(self.hparams.config) as config_file: 45 | self.config = json.load(config_file) 46 | 47 | ### Save the configuration file 48 | path = f"{exp_dir}/config.json" 49 | with open(path, 'w') as f: 50 | json.dump(self.config, f, indent=4, separators=(", ", ": "), sort_keys=True) 51 | 52 | self.model = NFFB(self.config["network"]) 53 | 54 | ema_decay = 0.95 55 | ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: \ 56 | ema_decay * averaged_model_parameter + (1-ema_decay) * model_parameter 57 | self.ema_model = torch.optim.swa_utils.AveragedModel(self.model, avg_fn=ema_avg) 58 | 59 | 60 | def setup(self, stage): 61 | self.train_dataset = SDFDataset(path=self.hparams.input_path, 62 | size=1000, 63 | num_samples=self.hparams.batch_size, 64 | clip_sdf=self.hparams.clamp_distance) 65 | 66 | self.test_dataset = SDFDataset(path=self.hparams.input_path, 67 | size=1, 68 | num_samples=self.hparams.batch_size, 69 | clip_sdf=self.hparams.clamp_distance) 70 | 71 | 72 | def forward(self, batch): 73 | b_pos = batch["points"] 74 | 75 | pred = self.model(b_pos) 76 | 77 | if self.hparams.clamp_distance > 0.0: 78 | pred = torch.clamp(pred, -self.hparams.clamp_distance, self.hparams.clamp_distance) 79 | 80 | return pred 81 | 82 | 83 | def on_fit_start(self): 84 | seed_everything(self.hparams.seed) 85 | 86 | 87 | def configure_optimizers(self): 88 | load_ckpt(self.model, self.hparams.ckpt_path) 89 | 90 | opts = [] 91 | net_params = self.model.get_params(self.config["training"]["LR_scheduler"]) 92 | self.net_opt = FusedAdam(net_params, betas=(0.9, 0.99), eps=1e-15) 93 | opts += [self.net_opt] 94 | 95 | lr_interval = self.config["training"]["LR_scheduler"][0]["interval"] 96 | lr_factor = self.config["training"]["LR_scheduler"][0]["factor"] 97 | 98 | if self.config["training"]["LR_scheduler"][0]["type"] == "Step": 99 | net_sch = StepLR(self.net_opt, step_size=lr_interval, gamma=lr_factor) 100 | else: 101 | net_sch = None 102 | 103 | return opts, [net_sch] 104 | 105 | def train_dataloader(self): 106 | return DataLoader(self.train_dataset, 107 | num_workers=16, 108 | persistent_workers=True, 109 | batch_size=None, 110 | pin_memory=True) 111 | 112 | def val_dataloader(self): 113 | return DataLoader(self.test_dataset, 114 | num_workers=8, 115 | batch_size=None, 116 | pin_memory=True) 117 | 118 | def predict_dataloader(self): 119 | return DataLoader(self.test_dataset, 120 | num_workers=8, 
121 | batch_size=None, 122 | pin_memory=True) 123 | 124 | def training_step(self, batch, batch_nb, *args): 125 | results = self(batch) 126 | 127 | b_occ = batch['sdfs'].to(results.dtype) 128 | if self.hparams.clamp_distance > 0.0: 129 | b_occ = torch.clamp(b_occ, -self.hparams.clamp_distance, self.hparams.clamp_distance) 130 | 131 | batch_loss = (results - b_occ)**2 / (b_occ.detach()**2 + 1e-4)  # relative L2: squared error scaled by the target SDF magnitude 132 | loss = batch_loss.mean() 133 | 134 | self.log('lr/network', self.net_opt.param_groups[0]['lr'], True) 135 | self.log('train/loss', loss) 136 | 137 | return loss 138 | 139 | def training_epoch_end(self, training_step_outputs): 140 | for name, cur_para in self.model.named_parameters():  # log per-parameter gradient norms 141 | if cur_para.numel() == 0: 142 | print(f"Parameter {name} has no elements at epoch {self.current_epoch}.") 143 | continue 144 | 145 | if cur_para.requires_grad and cur_para.grad is not None: 146 | para_norm = torch.norm(cur_para.grad.detach(), 2) 147 | self.log('Grad/%s_norm' % name, para_norm) 148 | 149 | def on_before_zero_grad(self, optimizer): 150 | if self.ema_model is not None: 151 | self.ema_model.update_parameters(self.model) 152 | 153 | def backward(self, loss, optimizer, optimizer_idx): 154 | # retain the computation graph so backward can be called on it more than once 155 | loss.backward(retain_graph=True) 156 | 157 | def on_train_start(self): 158 | model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 159 | self.log("misc/model_size", model_size) 160 | print(f"\nThe model size: {model_size}") 161 | 162 | def on_train_end(self): 163 | # Swap in the EMA weights so the final validation/prediction uses the averaged model 164 | if self.ema_model is not None: 165 | print("Replacing the standard model with the EMA model for the final validation run") 166 | self.model = self.ema_model 167 | 168 | def on_validation_start(self): 169 | torch.cuda.empty_cache() 170 | 171 | if not self.hparams.no_save_test: 172 | self.val_dir = f'{self.hparams.output_dir}/{self.time}/validation/' 173 | os.makedirs(self.val_dir, exist_ok=True) 174 | 175 | def validation_step(self, batch, batch_nb): 176 | if not self.hparams.no_save_test: 177 | res = 256 178 | mesh_path = os.path.join(self.val_dir, f'val_{self.current_epoch}_{res}.ply') 179 | 180 | create_mesh(self.ema_model, mesh_path, res) 181 | 182 | 183 | def on_predict_start(self): 184 | torch.cuda.empty_cache() 185 | 186 | if not self.hparams.no_save_test: 187 | self.pred_dir = f'{self.hparams.output_dir}/{self.time}/results/' 188 | os.makedirs(self.pred_dir, exist_ok=True) 189 | 190 | def predict_step(self, batch, batch_nb): 191 | if not self.hparams.no_save_test: 192 | res = 1024 193 | mesh_path = os.path.join(self.pred_dir, f'output_{res}.ply') 194 | 195 | create_mesh(self.model, mesh_path, res) 196 | 197 | 198 | def get_progress_bar_dict(self): 199 | # don't show the version number 200 | items = super().get_progress_bar_dict() 201 | items.pop("v_num", None) 202 | return items 203 | 204 | 205 | if __name__ == '__main__': 206 | hparams = get_opts() 207 | if hparams.val_only and (not hparams.ckpt_path): 208 | raise ValueError('You need to provide a @ckpt_path for validation!') 209 | system = SDFSystem(hparams) 210 | 211 | ckpt_cb = ModelCheckpoint(dirpath=f'{hparams.output_dir}/{system.time}/ckpts/', 212 | filename='{epoch:d}', 213 | save_weights_only=True, 214 | every_n_epochs=hparams.num_epochs, 215 | save_on_train_epoch_end=True, 216 | save_top_k=-1) 217 | 218 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)] 219 | 220 | logger = TensorBoardLogger(save_dir=f"{hparams.output_dir}/{system.time}/logs/", 221 | 
name="", 222 | default_hp_metric=False) 223 | 224 | trainer = Trainer(max_epochs=hparams.num_epochs, 225 | check_val_every_n_epoch=5, 226 | callbacks=callbacks, 227 | logger=logger, 228 | enable_model_summary=False, 229 | accelerator='gpu', 230 | gradient_clip_val=1.0, 231 | strategy=None, 232 | num_sanity_val_steps=-1 if hparams.val_only else 0, 233 | precision=16) 234 | 235 | if hparams.val_only: 236 | trainer.predict(system, ckpt_path=hparams.ckpt_path) 237 | else: 238 | trainer.fit(system, ckpt_path=hparams.ckpt_path) 239 | 240 | if (not hparams.no_save_test): # save mesh 241 | trainer.predict(system) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import random 5 | 6 | import torch 7 | import pytorch_lightning 8 | 9 | 10 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): 11 | checkpoint = torch.load(ckpt_path, map_location='cpu') 12 | checkpoint_ = {} 13 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 14 | checkpoint = checkpoint['state_dict'] 15 | for k, v in checkpoint.items(): 16 | if not k.startswith(model_name): 17 | continue 18 | k = k[len(model_name)+1:] 19 | for prefix in prefixes_to_ignore: 20 | if k.startswith(prefix): 21 | break 22 | else: 23 | checkpoint_[k] = v 24 | return checkpoint_ 25 | 26 | 27 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 28 | if not ckpt_path: return 29 | model_dict = model.state_dict() 30 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 31 | model_dict.update(checkpoint_) 32 | model.load_state_dict(model_dict) 33 | 34 | 35 | def slim_ckpt(ckpt_path, save_poses=False): 36 | ckpt = torch.load(ckpt_path, map_location='cpu') 37 | # pop unused parameters 38 | keys_to_pop = ['directions', 'model.density_grid', 'model.grid_coords'] 39 | if not save_poses: keys_to_pop += ['poses'] 40 | for k in ckpt['state_dict']: 41 | if k.startswith('val_lpips'): 42 | keys_to_pop += [k] 43 | for k in keys_to_pop: 44 | ckpt['state_dict'].pop(k, None) 45 | return ckpt['state_dict'] 46 | 47 | 48 | 49 | def seed_everything(seed): 50 | random.seed(seed) 51 | os.environ['PYTHONHASHSEED'] = str(seed) 52 | np.random.seed(seed) 53 | torch.manual_seed(seed) 54 | torch.cuda.manual_seed(seed) 55 | pytorch_lightning.seed_everything(seed, workers=True) 56 | #torch.backends.cudnn.deterministic = True 57 | #torch.backends.cudnn.benchmark = True 58 | 59 | 60 | 61 | def process_batch_in_chunks(in_ccords, model, max_chunk_size=1024): 62 | chunk_outs = [] 63 | 64 | coord_chunks = torch.split(in_ccords, max_chunk_size) 65 | for chunk_batched_in in coord_chunks: 66 | tmp_img = model(chunk_batched_in) 67 | chunk_outs.append(tmp_img.detach()) 68 | 69 | batched_out = torch.cat(chunk_outs, dim=0) 70 | 71 | return batched_out --------------------------------------------------------------------------------